__global__
static void find_groups(int *locs, int *sorted, int num)
{
int bid = blockIdx.y * gridDim.x + blockIdx.x;
int tid = bid * blockDim.x + threadIdx.x;
if (tid < num) {
int curr = sorted[tid];
if (tid == 0 || curr != sorted[tid - 1]) locs[curr] = tid;
}
}
int main()
{
int h_P0[N] = {0, 0, 1, 2, 1, 1, 0, 2, 0, 0};
int h_P1[N] = {0, 1, 1, 2, 1, 2, 0, 2, 1, 0};
thrust::host_vector<int> th_P0(h_P0, h_P0 + N);
thrust::host_vector<int> th_P1(h_P1, h_P1 + N);
thrust::device_vector<int> td_P0 = th_P0;
thrust::device_vector<int> td_P1 = th_P1;
thrust::device_vector<int> td_S0(N);
thrust::device_vector<int> td_S1(N);
thrust::sequence(td_S0.begin(), td_S0.end());
thrust::sequence(td_S1.begin(), td_S1.end());
thrust::stable_sort_by_key(td_P0.begin(), td_P0.end(), td_S0.begin());
thrust::stable_sort_by_key(td_P1.begin(), td_P1.end(), td_S1.begin());
thrust::device_vector<int> td_l0(3, -1); // Changed here
thrust::device_vector<int> td_l1(3, -1); // And here
int threads = 256;
int blocks_x = (N + 256) / 256;
int blocks_y = (blocks_x + 65535) / 65535;
dim3 blocks(blocks_x, blocks_y);
int *d_l0 = thrust::raw_pointer_cast(td_l0.data());
int *d_l1 = thrust::raw_pointer_cast(td_l1.data());
int *d_P0 = thrust::raw_pointer_cast(td_P0.data());
int *d_P1 = thrust::raw_pointer_cast(td_P1.data());
find_groups<<<blocks, threads>>>(d_l0, d_P0, N);
find_groups<<<blocks, threads>>>(d_l1, d_P1, N);
return 0;
}
アルゴリズムは簡単な手順で説明できます。
- P0 をキーで並べ替える
- P1 をキーで並べ替える
- キーには 2 番目のテーブルが含まれるようになりました
P0 と P1 を find_groups カーネルに渡します。グループが 3 つしかないことがわかっているので、グループ番号が n-1 から n に変わるスレッドは、グローバル メモリに書き込みます。スレッド 0 は常に 0 を書き込みます。これは、すべてのベクトルの最初のグループの始まりであるためです。
それらを印刷してみました。これは私が得るものです。すべてのインデックスが 0 であることに注意してください。
Sorted
t t+1
0 0
1 6
6 9
8 1
9 2
2 4
4 8
5 3
3 5
7 7
Ranges
Groups t t + 1
S [0-4] [0-2]
I [5-7] [3-6]
R [8-9] [7-9]
完全なコード (印刷用のコードを含む) にアクセスする必要がある場合は、このリンクにアクセスしてください。
これで十分かどうかはわかりません。しかし、ここで何かを見逃した場合はお知らせください。
編集
クラスが欠落している場所を処理するようにコードを変更しました。関連するベクトルを -1 で初期化します。したがって、-1 の開始点に遭遇した場合、そのクラスはその反復で表示されないことを意味します。