5

次の操作を倍精度で行う必要があります。

ここに画像の説明を入力

数値は、値がメモリに格納される方法を表します。これをAVXで実装したい。最初にの列を 8 要素までパディングして[QK]から、行列ベクトルの乗算を実行し、その後[x][QK]内積を実行するのが最善でしょうか?

編集: わかりましたので、以下に示すように、ベクトルが埋め込まれたFLOAT 32 ビットバージョンを実装することにしました。

// Perform matrix vector multiplication of QK*x            
// Load first four columns QK into 4 ymm registers    
ymm0 = _mm256_load_ps((float *)(QK));     
ymm1 = _mm256_load_ps((float *)(QK+8));       
ymm2 = _mm256_load_ps((float *)(QK+16));    
ymm3 = _mm256_load_ps((float *)(QK+24));  

// Load first four values of x
ymm4 = _mm256_broadcast_ss((float *)(x));
ymm5 = _mm256_broadcast_ss((float *)(x+1));
ymm6 = _mm256_broadcast_ss((float *)(x+2));
ymm7 = _mm256_broadcast_ss((float *)(x+3));

// Multiply in place - frees up ymm4,ymm5,ymm6,ymm7
ymm0 = _mm256_mul_ps(ymm0, ymm4);
ymm1 = _mm256_mul_ps(ymm1, ymm5);
ymm2 = _mm256_mul_ps(ymm2, ymm6);
ymm3 = _mm256_mul_ps(ymm3, ymm7);

// Add together, this frees up ymm1,ymm2,and ymm3
ymm0 = _mm256_add_ps(ymm0, ymm1);
ymm2 = _mm256_add_ps(ymm2, ymm3);
ymm0 = _mm256_add_ps(ymm0, ymm2);

// Load next last columns of QK
ymm1 = _mm256_load_ps((float *)(QK+32));     
ymm2 = _mm256_load_ps((float *)(QK+40)); 

// Load last two values of x
ymm6 = _mm256_broadcast_ss((float *)(x+4));
ymm7 = _mm256_broadcast_ss((float *)(x+5));

// Multiply in place
ymm1 = _mm256_mul_ps(ymm1, ymm6);
ymm2 = _mm256_mul_ps(ymm2, ymm7);

// Add together, this frees up every register except for ymm0
ymm0 = _mm256_add_ps(ymm0, ymm1);
ymm0 = _mm256_add_ps(ymm0, ymm2);  

// Answer stored in ymm0 and ymm1   
// Calculate dot product of y*(QK*x)
// Load x
ymm1 = _mm256_load_ps((float *)(y));   

// Do dotproduct by using horizontal multiply followed by horizontal add
// Multiply in place
ymm0 = _mm256_mul_ps(ymm0, ymm1);  

// Do horizontal sum
__m128 xmm1 = _mm256_extractf128_ps(ymm0, 1);
__m128 xmm2 = _mm256_extractf128_ps(ymm0, 0);
xmm2 = _mm_add_ps(xmm1, xmm2);
xmm1 = _mm_movehl_ps(xmm1, xmm2);
xmm2 = _mm_add_ps(xmm1, xmm2);
xmm1 = _mm_shuffle_ps(xmm2, xmm2, 1);
xmm2 = _mm_add_ss(xmm1, xmm2);
ans[0] = _mm_cvtss_f32(xmm2);  

現状では、以下よりも約 3 倍高速に実行されます。

ans[0] = (QK[0]*x[0]+QK[8]*x[1]+QK[16]*x[2]+QK[24]*x[3]+QK[32]*x[4]+QK[40]*x[5])*y[0]+
         (QK[1]*x[0]+QK[9]*x[1]+QK[17]*x[2]+QK[25]*x[3]+QK[33]*x[4]+QK[41]*x[5])*y[1]+
         (QK[2]*x[0]+QK[10]*x[1]+QK[18]*x[2]+QK[26]*x[3]+QK[34]*x[4]+QK[42]*x[5])*y[2]+
         (QK[3]*x[0]+QK[11]*x[1]+QK[19]*x[2]+QK[27]*x[3]+QK[35]*x[4]+QK[43]*x[5])*y[3]+
         (QK[4]*x[0]+QK[12]*x[1]+QK[20]*x[2]+QK[28]*x[3]+QK[36]*x[4]+QK[44]*x[5])*y[4]+
         (QK[5]*x[0]+QK[13]*x[1]+QK[21]*x[2]+QK[29]*x[3]+QK[37]*x[4]+QK[45]*x[5])*y[5];

5 億回の反復では、標準の C バージョンは約 9 秒で実行され、単精度の AVX バージョンは約 3.5 秒で実行されます。最後に水平合計をコメントアウトすると、約 0.5 秒で実行されます。水平合計は本当にパフォーマンスを殺します...

4

1 に答える 1