複数語の整数演算を行う 1 つの方法は、double-double 演算を使用することです。いくつかの double-double 乗算コードから始めましょう
#include <math.h>
typedef struct {
double hi;
double lo;
} doubledouble;
static doubledouble quick_two_sum(double a, double b) {
double s = a + b;
double e = b - (s - a);
return (doubledouble){s, e};
}
static doubledouble two_prod(double a, double b) {
double p = a*b;
double e = fma(a, b, -p);
return (doubledouble){p, e};
}
doubledouble df64_mul(doubledouble a, doubledouble b) {
doubledouble p = two_prod(a.hi, b.hi);
p.lo += a.hi*b.lo;
p.lo += a.lo*b.hi;
return quick_two_sum(p.hi, p.lo);
}
この関数two_prod
は、整数 53bx53b -> 106b を 2 つの命令で実行できます。この関数df64_mul
は、整数 106bx106b -> 106b を実行できます。
これを整数 128bx128b -> 整数ハードウェアの 128b と比較してみましょう。
__int128 mul128(__int128 a, __int128 b) {
return a*b;
}
のアセンブリmul128
imul rsi, rdx
mov rax, rdi
imul rcx, rdi
mul rdx
add rcx, rsi
add rdx, rcx
のアセンブリdf64_mul
(でコンパイルgcc -O3 -S i128.c -masm=intel -mfma -ffp-contract=off
)
vmulsd xmm4, xmm0, xmm2
vmulsd xmm3, xmm0, xmm3
vmulsd xmm1, xmm2, xmm1
vfmsub132sd xmm0, xmm4, xmm2
vaddsd xmm3, xmm3, xmm0
vaddsd xmm1, xmm3, xmm1
vaddsd xmm0, xmm1, xmm4
vsubsd xmm4, xmm0, xmm4
vsubsd xmm1, xmm1, xmm4
mul128
は、3 つのスカラー乗算と 2 つのスカラー加算/減算を実行しますdf64_mul
が、3 つの SIMD 乗算、1 つの SIMD FMA、および 5 つの SIMD 加算/減算を実行します。私はこれらの方法をプロファイリングしていませんが、AVX レジスターごとに 4-double を使用した場合df64_mul
よりも優れたパフォーマンスを発揮できることは、私には不合理ではないように思えます (およびに変更)。mul128
sd
pd
xmm
ymm
問題は整数領域に戻ることだと言いたくなる。しかし、なぜこれが必要なのですか?浮動小数点ドメインですべてを行うことができます。いくつかの例を見てみましょう。float
よりも単体テストの方が簡単だと思いますdouble
。
doublefloat two_prod(float a, float b) {
float p = a*b;
float e = fma(a, b, -p);
return (doublefloat){p, e};
}
//3202129*4807935=15395628093615
x = two_prod(3202129,4807935)
int64_t hi = p, lo = e, s = hi+lo
//p = 1.53956280e+13, e = 1.02575000e+05
//hi = 15395627991040, lo = 102575, s = 15395628093615
//1450779*1501672=2178594202488
y = two_prod(1450779, 1501672)
int64_t hi = p, lo = e, s = hi+lo
//p = 2.17859424e+12, e = -4.00720000e+04
//hi = 2178594242560 lo = -40072, s = 2178594202488
そのため、最終的に異なる範囲になり、2 番目のケースでは誤差 ( e
) は負でもありますが、合計は正しいままです。2 つの doublefloat 値x
を加算しy
て (double-double 加算の方法がわかったら、最後のコードを参照してください)、 get を取得することもでき15395628093615+2178594202488
ます。結果を正規化する必要はありません。
しかし、足し算は double-double 演算の主な問題を引き起こします。つまり、加算/減算は遅いです。たとえば、128b+128b -> 128bは少なくとも 11 回の浮動小数点加算add
が必要ですが、整数の場合は 2 回 (と)しか必要ありませんadc
。
したがって、アルゴリズムが乗算に重点を置いていて、加算が少ない場合は、double-double を使用して複数ワードの整数演算を行うことが勝つ可能性があります。
補足として、C 言語は、整数が浮動小数点ハードウェアを介して完全に実装される実装を可能にするのに十分柔軟です。 int
24 ビット (単一の浮動小数点から) の場合long
もあれば、54 ビットの場合もあります。(double 浮動小数点から)、long long
106 ビット (double-double から) になる可能性があります。C では 2 の補数さえ必要としないため、浮動小数点で通常行われるように、整数は負の数に符号付きマグニチュードを使用できます。
sqrt
これは、誰かがそれを試してみたい場合に備えて、double-double 乗算と加算を使用した作業中の C コードです (私は除算やその他の演算を実装していませんが、これを行う方法を示す論文があります)。これが整数用に最適化できるかどうかを見るのは興味深いでしょう。
//if compiling with -mfma you must also use -ffp-contract=off
//float-float is easier to debug. If you want double-double replace
//all float words with double and fmaf with fma
#include <stdio.h>
#include <math.h>
#include <inttypes.h>
#include <x86intrin.h>
#include <stdlib.h>
//#include <float.h>
typedef struct {
float hi;
float lo;
} doublefloat;
typedef union {
float f;
int i;
struct {
unsigned mantisa : 23;
unsigned exponent: 8;
unsigned sign: 1;
};
} float_cast;
void print_float(float_cast a) {
printf("%.8e, 0x%x, mantisa 0x%x, exponent 0x%x, expondent-127 %d, sign %u\n", a.f, a.i, a.mantisa, a.exponent, a.exponent-127, a.sign);
}
void print_doublefloat(doublefloat a) {
float_cast hi = {a.hi};
float_cast lo = {a.lo};
printf("hi: "); print_float(hi);
printf("lo: "); print_float(lo);
}
doublefloat quick_two_sum(float a, float b) {
float s = a + b;
float e = b - (s - a);
return (doublefloat){s, e};
// 3 add
}
doublefloat two_sum(float a, float b) {
float s = a + b;
float v = s - a;
float e = (a - (s - v)) + (b - v);
return (doublefloat){s, e};
// 6 add
}
doublefloat df64_add(doublefloat a, doublefloat b) {
doublefloat s, t;
s = two_sum(a.hi, b.hi);
t = two_sum(a.lo, b.lo);
s.lo += t.hi;
s = quick_two_sum(s.hi, s.lo);
s.lo += t.lo;
s = quick_two_sum(s.hi, s.lo);
return s;
// 2*two_sum, 2 add, 2*quick_two_sum = 2*6 + 2 + 2*3 = 20 add
}
doublefloat split(float a) {
//#define SPLITTER (1<<27) + 1
#define SPLITTER (1<<12) + 1
float t = (SPLITTER)*a;
float hi = t - (t - a);
float lo = a - hi;
return (doublefloat){hi, lo};
// 1 mul, 3 add
}
doublefloat split_sse(float a) {
__m128 k = _mm_set1_ps(4097.0f);
__m128 a4 = _mm_set1_ps(a);
__m128 t = _mm_mul_ps(k,a4);
__m128 hi4 = _mm_sub_ps(t,_mm_sub_ps(t, a4));
__m128 lo4 = _mm_sub_ps(a4, hi4);
float tmp[4];
_mm_storeu_ps(tmp, hi4);
float hi = tmp[0];
_mm_storeu_ps(tmp, lo4);
float lo = tmp[0];
return (doublefloat){hi,lo};
}
float mult_sub(float a, float b, float c) {
doublefloat as = split(a), bs = split(b);
//print_doublefloat(as);
//print_doublefloat(bs);
return ((as.hi*bs.hi - c) + as.hi*bs.lo + as.lo*bs.hi) + as.lo*bs.lo;
// 4 mul, 4 add, 2 split = 6 mul, 10 add
}
doublefloat two_prod(float a, float b) {
float p = a*b;
float e = mult_sub(a, b, p);
return (doublefloat){p, e};
// 1 mul, one mult_sub
// 7 mul, 10 add
}
float mult_sub2(float a, float b, float c) {
doublefloat as = split(a);
return ((as.hi*as.hi -c ) + 2*as.hi*as.lo) + as.lo*as.lo;
}
doublefloat two_sqr(float a) {
float p = a*a;
float e = mult_sub2(a, a, p);
return (doublefloat){p, e};
}
doublefloat df64_mul(doublefloat a, doublefloat b) {
doublefloat p = two_prod(a.hi, b.hi);
p.lo += a.hi*b.lo;
p.lo += a.lo*b.hi;
return quick_two_sum(p.hi, p.lo);
//two_prod, 2 add, 2mul, 1 quick_two_sum = 9 mul, 15 add
//or 1 mul, 1 fma, 2add 2mul, 1 quick_two_sum = 3 mul, 1 fma, 5 add
}
doublefloat df64_sqr(doublefloat a) {
doublefloat p = two_sqr(a.hi);
p.lo += 2*a.hi*a.lo;
return quick_two_sum(p.hi, p.lo);
}
int float2int(float a) {
int M = 0xc00000; //1100 0000 0000 0000 0000 0000
a += M;
float_cast x;
x.f = a;
return x.i - 0x4b400000;
}
doublefloat add22(doublefloat a, doublefloat b) {
float r = a.hi + b.hi;
float s = fabsf(a.hi) > fabsf(b.hi) ?
(((a.hi - r) + b.hi) + b.lo ) + a.lo :
(((b.hi - r) + a.hi) + a.lo ) + b.lo;
return two_sum(r, s);
//11 add
}
int main(void) {
//print_float((float_cast){1.0f});
//print_float((float_cast){-2.0f});
//print_float((float_cast){0.0f});
//print_float((float_cast){3.14159f});
//print_float((float_cast){1.5f});
//print_float((float_cast){3.0f});
//print_float((float_cast){7.0f});
//print_float((float_cast){15.0f});
//print_float((float_cast){31.0f});
//uint64_t t = 0xffffff;
//print_float((float_cast){1.0f*t});
//printf("%" PRId64 " %" PRIx64 "\n", t*t,t*t);
/*
float_cast t1;
t1.mantisa = 0x7fffff;
t1.exponent = 0xfe;
t1.sign = 0;
print_float(t1);
*/
//doublefloat z = two_prod(1.0f*t, 1.0f*t);
//print_doublefloat(z);
//double z2 = (double)z.hi + (double)z.lo;
//printf("%.16e\n", z2);
doublefloat s = {0};
int64_t si = 0;
for(int i=0; i<100000; i++) {
int ai = rand()%0x800, bi = rand()%0x800000;
float a = ai, b = bi;
doublefloat z = two_prod(a,b);
int64_t zi = (int64_t)ai*bi;
//print_doublefloat(z);
//s = df64_add(s,z);
s = add22(s,z);
si += zi;
print_doublefloat(z);
printf("%d %d ", ai,bi);
int64_t h = z.hi;
int64_t l = z.lo;
int64_t t = h+l;
//if(t != zi) printf("%" PRId64 " %" PRId64 "\n", h, l);
printf("%" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 "\n", zi, h, l, h+l);
h = s.hi;
l = s.lo;
t = h + l;
//if(si != t) printf("%" PRId64 " %" PRId64 "\n", h, l);
if(si > (1LL<<48)) {
printf("overflow after %d iterations\n", i); break;
}
}
print_doublefloat(s);
printf("%" PRId64 "\n", si);
int64_t x = s.hi;
int64_t y = s.lo;
int64_t z = x+y;
//int hi = float2int(s.hi);
printf("%" PRId64 " %" PRId64 " %" PRId64 "\n", z,x,y);
}