2

カーネル テンプレートで CUDA カーネル関数ポインターを受け入れる CUDA ランタイム API 関数を使用したいと考えています。

テンプレートなしで次のことができます。

__global__ myKernel()
{
  ...
}

void myFunc(const char* kernel_ptr)
{
  ...
  // use API functions like
  cudaFuncGetAttributes(&attrib, kernel_ptr);
  ...
}

int main()
{
  myFunc(myKernel);
}

ただし、カーネルがテンプレートの場合、上記は機能しません。

もう一つの例:

#include "cuda_runtime.h"
#include "device_launch_parameters.h"
#include <stdio.h>

template<typename T>
__global__ void addKernel(T *c, const T *a, const T *b)
{
    int i = threadIdx.x;
    c[i] = a[i] + b[i];
}

int main()
{
    cudaFuncAttributes attrib;
    cudaError_t err;

    //OK:
    err = cudaFuncGetAttributes(&attrib, addKernel<float>); // works fine
    printf("result: %s, reg1: %d\n", cudaGetErrorString(err), attrib.numRegs);

    //NOT OK:
    //try to get function ptr to pass as an argument:
    const char* ptr = addKernel<float>; // compile error
    err = cudaFuncGetAttributes(&attrib, ptr);
    printf("result: %s, reg2: %d\n", cudaGetErrorString(err), attrib.numRegs);
}

上記はコンパイル エラーになります。

エラー:関数テンプレート「addKernel」のインスタンスが必要なタイプと一致しません

編集: これまでに見つけた唯一の回避策は、myFunc (最初のコード例を参照) 内のものをマクロに入れることです。これは醜いですが、ポインター引数の受け渡しを必要とせず、正常に動作します:

#define MY_FUNC(kernel) \
  { \
     ...\
     cudaFuncGetAttributes( &attrib, kernel ); \
     ...\
  }

使用法:

MY_FUNC( myKernel<float> )
4

3 に答える 3

2

「別の例:」に含まれるコードを参照する

これを変える:

const char* ptr = addKernel<float>; // compile error

これに:

void (*ptr)(float *, const float *, const float *) = addKernel<float>;

そして、それが正しくコンパイルされて実行されると信じています。

あなたがやろうとしていることの全体的な範囲でそれが役立つかどうかはわかりません。

コメントの質問に答える編集:

関数からポインターを「抽出」したら、それを別の型にキャストできます。それを試してみてください。たとえば、次のコードも機能します。

void (*ptr)(float *, const float *, const float *) = addKernel<float>;
const char *ptr1 = (char *)ptr;
err = cudaFuncGetAttributes(&attrib, ptr1);

したがって、質問に答えるために、関数ポインターを取得したら、必要に応じて関数ポインターをキャストできます。const char*

ちなみに、回答として投稿したコードは、gcc 4.1.2 および gcc 4.4.6 でコンパイル エラーをスローします。

$ nvcc -arch=sm_20 -O3 -o t201 t201.cu
t201.cu: In function âint main()â:
t201.cu:25: error: address of overloaded function with no contextual type information
t201.cu:29: error: address of overloaded function with no contextual type information
$

&また、これらの 2 行でを削除すると、エラーが発生します。

$ nvcc -arch=sm_20 -O3 -o t201 t201.cu
t201.cu: In function âint main()â:
t201.cu:25: error: insufficient contextual information to determine type
t201.cu:29: error: insufficient contextual information to determine type
$

そのため、ポイント A からポイント B に到達するために必要な手順に関して、これの一部はコンパイラに依存する可能性があります。

于 2013-07-12T19:05:18.743 に答える