2

ここにある F# reduce 関数の C# バージョン (C# スタイル) をコーディングしようとしています。

https://github.com/quantalea/AleaGPUTutorial/tree/master/src/fsharp/examples/generic_reduce

私の質問により具体的には、この関数を例にとります:

let multiReduce (opExpr:Expr<'T -> 'T -> 'T>) numWarps =
    let warpStride = WARP_SIZE + WARP_SIZE / 2 + 1
    let sharedSize = numwarps * warpStride

    <@ fun tid (x:'T) ->
        // stuff
    @>

私は主に F# を使用しているため、C# でこのような関数をコーディングする方法がよくわかりません。C# 版の場合、multiReduce 関数はクラス メンバーになります。したがって、F# コードをより直接的に変換したい場合は、MultiReduce メンバーから Func を返します。

もう 1 つのオプションは、multiReduce 関数を "フラット化" して、C# メンバー バージョンに 2 つの追加パラメーターを持たせることです。そう...

public T MultiReduce(Func<T,T,T> op, int numWarps, int tid, T x)
{
    // stuff
}

しかし、F# バージョンの引用符で囲まれた式はデバイス関数であるため、これがすべての場合に AleaGPU コーディングで機能するとは思いません。関数の実際の呼び出しから特定の変数の割り当てを分離できるようにするには、ネストされた関数構造が必要です。

もう 1 つの方法は、MultiReduce クラスを作成し、opExpr と numWarps をフィールドとして持ち、引用符内の関数をクラス メンバーにすることです。

では、これらのような高階関数は一般的に AleaGPU-C# でどのように実装されるのでしょうか? Func<..> をどこにでも返すのは良くないと思います。これは C# コーディングではあまり行われていないからです。AleaGPU は、これで問題ない特殊なケースですか?

基本的な AleaGPU C# の実装は次のようになります。

internal class TransformModule<T> : ILGPUModule
{
    private readonly Func<T, T> op;

    public TransformModule(GPUModuleTarget target, Func<T, T> opFunc)
        : base(target)
    {
        op = opFunc;
    }

    [Kernel]
    public void Kernel(int n, deviceptr<T> x, deviceptr<T> y)
    {
        var start = blockIdx.x * blockDim.x + threadIdx.x;
        var stride = gridDim.x * blockDim.x;
        for (var i = start; i < n; i += stride)
            y[i] = op(x[i]);
    }

    public void Apply(int n, deviceptr<T> x, deviceptr<T> y)
    {
        const int blockSize = 256;
        var numSm = this.GPUWorker.Device.Attributes.MULTIPROCESSOR_COUNT;
        var gridSize = Math.Min(16 * numSm, Common.divup(n, blockSize));
        var lp = new LaunchParam(gridSize, blockSize);
        GPULaunch(Kernel, lp, n, x, y);
    }

    public T[] Apply(T[] x)
    {
        using (var dx = GPUWorker.Malloc(x))
        using (var dy = GPUWorker.Malloc<T>(x.Length))
        {
            Apply(x.Length, dx.Ptr, dy.Ptr);
            return dy.Gather();
        }
    }
}
4

1 に答える 1