OpenMPを使用して並列化した素数を計算するコードがあります。
#pragma omp parallel for private(i,j) reduction(+:pcount) schedule(dynamic)
for (i = sqrt_limit+1; i < limit; i++)
{
check = 1;
for (j = 2; j <= sqrt_limit; j++)
{
if ( !(j&1) && (i&(j-1)) == 0 )
{
check = 0;
break;
}
if ( j&1 && i%j == 0 )
{
check = 0;
break;
}
}
if (check)
pcount++;
}
私はそれをGPUに移植しようとしていますが、上記のOpenMPの例で行ったようにカウントを減らしたいと思います。以下は私のコードですが、間違った結果を出すことは別として、遅いです:
__global__ void sieve ( int *flags, int *o_flags, long int sqrootN, long int N)
{
long int gid = blockIdx.x*blockDim.x+threadIdx.x, tid = threadIdx.x, j;
__shared__ int s_flags[NTHREADS];
if (gid > sqrootN && gid < N)
s_flags[tid] = flags[gid];
else
return;
__syncthreads();
s_flags[tid] = 1;
for (j = 2; j <= sqrootN; j++)
{
if ( gid%j == 0 )
{
s_flags[tid] = 0;
break;
}
}
//reduce
for(unsigned int s=1; s < blockDim.x; s*=2)
{
if( tid % (2*s) == 0 )
{
s_flags[tid] += s_flags[tid + s];
}
__syncthreads();
}
//write results of this block to the global memory
if (tid == 0)
o_flags[blockIdx.x] = s_flags[0];
}
まず、このカーネルを高速化するにはどうすればよいですか。ボトルネックはforループだと思いますが、どのように置き換えるかはわかりません。そして次に、私のカウントは正しくありません。'%'演算子を変更しましたが、いくつかの利点に気づきました。
配列ではflags
、2からsqroot(N)までの素数をマークしました。このカーネルでは、sqroot(N)からNまでの素数を計算していますが、{sqroot(N)、N}の各数値を確認する必要があります。 {2、sqroot(N)}の素数で割り切れる。o_flags
配列には、各ブロックの部分和が格納されます。
編集:提案に従って、コードを変更しました(syncthreadsに関するコメントについて理解が深まりました)。私はflags配列は必要なく、私の場合はグローバルインデックスだけが機能することに気づきました。この時点で私が懸念しているのは、forループに起因する可能性のあるコードの遅さ(正確さ以上)です。また、特定のデータサイズ(100000)の後、カーネルは後続のデータサイズに対して誤った結果を生成していました。データサイズが100000未満の場合でも、GPU削減の結果は正しくありません(NVidiaフォーラムのメンバーは、データサイズが2の累乗ではないことが原因である可能性があると指摘しました)。したがって、まだ3つの(関連している可能性がある)質問があります-
どうすればこのカーネルを高速化できますか?各tidをループする必要がある場合は、共有メモリを使用することをお勧めしますか?
特定のデータサイズに対してのみ正しい結果が生成されるのはなぜですか?
どうすれば削減を変更できますか?
__global__ void sieve ( int *o_flags, long int sqrootN, long int N ) { unsigned int gid = blockIdx.x*blockDim.x+threadIdx.x, tid = threadIdx.x; volatile __shared__ int s_flags[NTHREADS]; s_flags[tid] = 1; for (unsigned int j=2; j<=sqrootN; j++) { if ( gid % j == 0 ) s_flags[tid] = 0; } __syncthreads(); //reduce reduce(s_flags, tid, o_flags); }