私は素晴らしいSSE命令を調べており、それらを使用する関数と「標準」コード(つまり非SSE)を使用する同じ関数との違いを測定するためにいくつかの簡単なコードを使い始めました。コードを (-O3 フラグを指定して) コンパイルしたとき、関数の SSE バージョンを使用するバージョンは、SSE 命令を使用していないプログラムのバージョンよりも実際には (ごくわずかに) "遅い" ことに気付きました。私の推測は次のとおりです。
- コンパイラはコードの最適化において優れた仕事をします
- SSE 関数はより高速に実行できますが、float をレジスターにロードするコストがかかるため、SSE 命令を使用するメリットが相殺されます。
- testSSE() 関数は、SSE を使用するプログラムのバージョンと使用しないプログラムの違いを実際に示すほど複雑ではありません。
これについて彼/彼女の考えを誰か教えてもらえますか? どうもありがとう -
EDIT:コードを修正しました(2つのコードリストの下を参照)。より短い修正バージョンを使用しても、SSE バージョンでは 2 インチ 48、非 SSE バージョンでは 1 インチ 36 であり、この場合、コンパイラは私よりも優れた仕事をするという事実を確認しています!
編集:バグのある古いコード(以下の修正版を参照)
// compiled with c++ tmp.cpp -msse4 -o testSSE -O3
#include <iostream>
#include <cmath>
#include <stdio.h>
#include <pmmintrin.h>
inline void testSSE(float *node1, float *node2, float *node3, float *node4, float *result)
{
__m128 tmp0, tmp1, tmp2, tmp3;
__m128 l, r;
l = _mm_load_ps(node1); //_mm_store_ps(result, l); fprintf(stderr, "1 %f %f %f %f\n", result[0], result[1], result[2], result[3]);
r = _mm_load_ps(node1 + 4); //_mm_store_ps(result, r); fprintf(stderr, "2 %f %f %f %f\n", result[0], result[1], result[2], result[3]);
tmp0 = _mm_hadd_ps(l, r); //_mm_store_ps(result, tmp0); fprintf(stderr, "3 %f %f %f %f\n", result[0], result[1], result[2], result[3]);
l = _mm_load_ps(node2); //_mm_store_ps(result, l); fprintf(stderr, "4 %f %f %f %f\n", result[0], result[1], result[2], result[3]);
r = _mm_load_ps(node2 + 4); //_mm_store_ps(result, r); fprintf(stderr, "5 %f %f %f %f\n", result[0], result[1], result[2], result[3]);
tmp1 = _mm_hadd_ps(l, r); //_mm_store_ps(result, tmp0); fprintf(stderr, "6 %f %f %f %f\n", result[0], result[1], result[2], result[3]);
l = _mm_load_ps(node3);
r = _mm_load_ps(node3 + 4);
tmp2 = _mm_hadd_ps(l, r);
l = _mm_load_ps(node4); //_mm_store_ps(result, l); fprintf(stderr, "10 %f %f %f %f\n", result[0], result[1], result[2], result[3]);
r = _mm_load_ps(node4 + 4); //_mm_store_ps(result, r); fprintf(stderr, "11 %f %f %f %f\n", result[0], result[1], result[2], result[3]);
tmp3 = _mm_hadd_ps(l, r); //_mm_store_ps(result, tmp0); fprintf(stderr, "12 %f %f %f %f\n", result[0], result[1], result[2], result[3]);
l = _mm_hadd_ps(tmp0, tmp1);
r = _mm_hadd_ps(tmp2, tmp3);
__m128 pDest = _mm_hadd_ps(l, r);
_mm_store_ps(result, pDest); // fprintf(stderr, "FINAL %f %f %f %f\n", result[0], result[1], result[2], result[3]);
}
void test(float *node1, float *node2, float *node3, float *node4, float *result)
{
float tmp0[4], tmp1[4], tmp2[4], tmp3[4];
tmp0[0] = node1[0] + node1[1];
tmp0[1] = node1[2] + node1[3];
tmp0[2] = node1[4] + node1[5];
tmp0[3] = node1[6] + node1[7];
tmp1[0] = node2[0] + node2[1];
tmp1[1] = node2[2] + node2[3];
tmp1[2] = node2[4] + node2[5];
tmp1[3] = node2[6] + node2[7];
tmp2[0] = node3[0] + node3[1];
tmp2[1] = node3[2] + node3[3];
tmp2[2] = node3[4] + node3[5];
tmp2[3] = node3[6] + node3[7];
tmp3[0] = node4[0] + node4[1];
tmp3[1] = node4[2] + node4[3];
tmp3[2] = node4[4] + node4[5];
tmp3[3] = node4[6] + node4[7];
float l[4], r[4];
l[0] = tmp0[0] + tmp0[1];
l[1] = tmp0[2] + tmp0[3];
l[2] = tmp1[0] + tmp1[1];
l[3] = tmp1[2] + tmp1[3];
r[0] = tmp2[0] + tmp2[1];
r[1] = tmp2[2] + tmp2[3];
r[2] = tmp3[0] + tmp3[1];
r[3] = tmp3[2] + tmp3[3];
result[0] = l[0] + l[1];
result[1] = l[2] + l[3];
result[2] = r[0] + r[1];
result[3] = r[2] + r[3];
}
int main(int argc, char **argv)
{
int nnodes = 4;
double t = clock();
for (int k = 0; k < 10000000; ++k) {
float *data = new float [nnodes * 8];
for (int i = 0; i < nnodes * 8; ++i) { data[i] = (i / 8) + 1; /* fprintf(stderr, "data %02d %f\n", i, data[i]); */ }
float result[4];
int off = sizeof(float) * 8;
testSSE(data, data + 8, data + 16, data + 24, result);
delete [] data;
}
fprintf(stderr, "%02f (sec)\n", (clock() - t) / (float)CLOCKS_PER_SEC);
return 0;
}
編集: 新しい (修正された) コード
#include <iostream>
#include <cmath>
#include <stdio.h>
#include <pmmintrin.h>
inline void testSSE(float *node1, float *node2, float *node3, float *node4, float *result)
{
__m128 tmp0, tmp1, tmp2, tmp3;
tmp0 = _mm_load_ps(node1);
tmp1 = _mm_load_ps(node2);
tmp2 = _mm_hadd_ps(tmp0, tmp1);
tmp0 = _mm_load_ps(node3);
tmp1 = _mm_load_ps(node4);
tmp3 = _mm_hadd_ps(tmp0, tmp1);
tmp0 = _mm_hadd_ps(tmp2, tmp3);
_mm_store_ps(result, tmp0);
}
void test(float *node1, float *node2, float *node3, float *node4, float *result)
{
float tmp0[4], tmp1[4], tmp2[4], tmp3[4];
tmp0[0] = node1[0] + node1[1];
tmp0[1] = node1[2] + node1[3];
tmp0[2] = node1[4] + node1[5];
tmp0[3] = node1[6] + node1[7];
tmp1[0] = node2[0] + node2[1];
tmp1[1] = node2[2] + node2[3];
tmp1[2] = node2[4] + node2[5];
tmp1[3] = node2[6] + node2[7];
tmp2[0] = node3[0] + node3[1];
tmp2[1] = node3[2] + node3[3];
tmp2[2] = node3[4] + node3[5];
tmp2[3] = node3[6] + node3[7];
tmp3[0] = node4[0] + node4[1];
tmp3[1] = node4[2] + node4[3];
tmp3[2] = node4[4] + node4[5];
tmp3[3] = node4[6] + node4[7];
float l[4], r[4];
l[0] = tmp0[0] + tmp0[1];
l[1] = tmp0[2] + tmp0[3];
l[2] = tmp1[0] + tmp1[1];
l[3] = tmp1[2] + tmp1[3];
r[0] = tmp2[0] + tmp2[1];
r[1] = tmp2[2] + tmp2[3];
r[2] = tmp3[0] + tmp3[1];
r[3] = tmp3[2] + tmp3[3];
result[0] = l[0] + l[1];
result[1] = l[2] + l[3];
result[2] = r[0] + r[1];
result[3] = r[2] + r[3];
}
int main(int argc, char **argv)
{
int nnodes = 4;
float *data = new float [nnodes * 8];
for (int i = 0; i < nnodes * 8; ++i) { data[i] = (i / 8) + 1; /* fprintf(stderr, "data %02d %f\n", i, data[i]); */ }
double t = clock();
for (int k = 0; k < 1e+9; ++k) {
float result[4];
int off = sizeof(float) * 8;
test(data, data + 8, data + 16, data + 24, result);
}
fprintf(stderr, "%02f (sec)\n", (clock() - t) / (float)CLOCKS_PER_SEC);
delete [] data;
return 0;
}