次の操作を倍精度で行う必要があります。
数値は、値がメモリに格納される方法を表します。これをAVXで実装したい。最初にの列を 8 要素までパディングして[QK]
から、行列ベクトルの乗算を実行し、その後[x]
に[QK]
内積を実行するのが最善でしょうか?
編集: わかりましたので、以下に示すように、ベクトルが埋め込まれたFLOAT 32 ビットバージョンを実装することにしました。
// Perform matrix vector multiplication of QK*x
// Load first four columns QK into 4 ymm registers
ymm0 = _mm256_load_ps((float *)(QK));
ymm1 = _mm256_load_ps((float *)(QK+8));
ymm2 = _mm256_load_ps((float *)(QK+16));
ymm3 = _mm256_load_ps((float *)(QK+24));
// Load first four values of x
ymm4 = _mm256_broadcast_ss((float *)(x));
ymm5 = _mm256_broadcast_ss((float *)(x+1));
ymm6 = _mm256_broadcast_ss((float *)(x+2));
ymm7 = _mm256_broadcast_ss((float *)(x+3));
// Multiply in place - frees up ymm4,ymm5,ymm6,ymm7
ymm0 = _mm256_mul_ps(ymm0, ymm4);
ymm1 = _mm256_mul_ps(ymm1, ymm5);
ymm2 = _mm256_mul_ps(ymm2, ymm6);
ymm3 = _mm256_mul_ps(ymm3, ymm7);
// Add together, this frees up ymm1,ymm2,and ymm3
ymm0 = _mm256_add_ps(ymm0, ymm1);
ymm2 = _mm256_add_ps(ymm2, ymm3);
ymm0 = _mm256_add_ps(ymm0, ymm2);
// Load next last columns of QK
ymm1 = _mm256_load_ps((float *)(QK+32));
ymm2 = _mm256_load_ps((float *)(QK+40));
// Load last two values of x
ymm6 = _mm256_broadcast_ss((float *)(x+4));
ymm7 = _mm256_broadcast_ss((float *)(x+5));
// Multiply in place
ymm1 = _mm256_mul_ps(ymm1, ymm6);
ymm2 = _mm256_mul_ps(ymm2, ymm7);
// Add together, this frees up every register except for ymm0
ymm0 = _mm256_add_ps(ymm0, ymm1);
ymm0 = _mm256_add_ps(ymm0, ymm2);
// Answer stored in ymm0 and ymm1
// Calculate dot product of y*(QK*x)
// Load x
ymm1 = _mm256_load_ps((float *)(y));
// Do dotproduct by using horizontal multiply followed by horizontal add
// Multiply in place
ymm0 = _mm256_mul_ps(ymm0, ymm1);
// Do horizontal sum
__m128 xmm1 = _mm256_extractf128_ps(ymm0, 1);
__m128 xmm2 = _mm256_extractf128_ps(ymm0, 0);
xmm2 = _mm_add_ps(xmm1, xmm2);
xmm1 = _mm_movehl_ps(xmm1, xmm2);
xmm2 = _mm_add_ps(xmm1, xmm2);
xmm1 = _mm_shuffle_ps(xmm2, xmm2, 1);
xmm2 = _mm_add_ss(xmm1, xmm2);
ans[0] = _mm_cvtss_f32(xmm2);
現状では、以下よりも約 3 倍高速に実行されます。
ans[0] = (QK[0]*x[0]+QK[8]*x[1]+QK[16]*x[2]+QK[24]*x[3]+QK[32]*x[4]+QK[40]*x[5])*y[0]+
(QK[1]*x[0]+QK[9]*x[1]+QK[17]*x[2]+QK[25]*x[3]+QK[33]*x[4]+QK[41]*x[5])*y[1]+
(QK[2]*x[0]+QK[10]*x[1]+QK[18]*x[2]+QK[26]*x[3]+QK[34]*x[4]+QK[42]*x[5])*y[2]+
(QK[3]*x[0]+QK[11]*x[1]+QK[19]*x[2]+QK[27]*x[3]+QK[35]*x[4]+QK[43]*x[5])*y[3]+
(QK[4]*x[0]+QK[12]*x[1]+QK[20]*x[2]+QK[28]*x[3]+QK[36]*x[4]+QK[44]*x[5])*y[4]+
(QK[5]*x[0]+QK[13]*x[1]+QK[21]*x[2]+QK[29]*x[3]+QK[37]*x[4]+QK[45]*x[5])*y[5];
5 億回の反復では、標準の C バージョンは約 9 秒で実行され、単精度の AVX バージョンは約 3.5 秒で実行されます。最後に水平合計をコメントアウトすると、約 0.5 秒で実行されます。水平合計は本当にパフォーマンスを殺します...