1

Numeric.AD から次の最小限の例をコンパイルしようとしています。

import Numeric.AD 
timeAndGrad f l = grad f l
main = putStrLn "hi"

そして、私はこのエラーに遭遇します:

test.hs:3:24:
    Couldn't match expected type ‘f (Numeric.AD.Internal.Reverse.Reverse
                                       s a)
                                  -> Numeric.AD.Internal.Reverse.Reverse s a’
                with actual type ‘t’
      because type variable ‘s’ would escape its scope
    This (rigid, skolem) type variable is bound by
      a type expected by the context:
        Data.Reflection.Reifies s Numeric.AD.Internal.Reverse.Tape =>
        f (Numeric.AD.Internal.Reverse.Reverse s a)
        -> Numeric.AD.Internal.Reverse.Reverse s a
      at test.hs:3:19-26
    Relevant bindings include
      l :: f a (bound at test.hs:3:15)
      f :: t (bound at test.hs:3:13)
      timeAndGrad :: t -> f a -> f a (bound at test.hs:3:1)
    In the first argument of ‘grad’, namely ‘f’
    In the expression: grad f l

なぜこれが起こっているのかについての手がかりはありますか?前の例を見ると、これは「平坦化」gradのタイプであることがわかります。

grad :: (Traversable f, Num a) => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> f a -> f a

しかし、実際にはコードでこのようなことをする必要があります。実際、これはコンパイルできない最も最小限の例です。私がやりたいより複雑なことは、次のようなものです:

example :: SomeType
example f x args = (do stuff with the gradient and gradient "function")
    where gradient = grad f x
          gradientFn = grad f
          (other where clauses involving gradient and gradient "function")

これは、コンパイルされる型シグネチャを使用した、もう少し複雑なバージョンです。

{-# LANGUAGE RankNTypes #-}

import Numeric.AD 
import Numeric.AD.Internal.Reverse

-- compiles but I can't figure out how to use it in code
grad2 :: (Show a, Num a, Floating a) => (forall s.[Reverse s a] -> Reverse s a) -> [a] -> [a]
grad2 f l = grad f l

-- compiles with the right type, but the resulting gradient is all 0s...
grad2' :: (Show a, Num a, Floating a) => ([a] -> a) -> [a] -> [a]
grad2' f l = grad f' l
       where f' = Lift . f . extractAll
       -- i've tried using the Reverse constructor with Reverse 0 _, Reverse 1 _, and Reverse 2 _, but those don't yield the correct gradient. Not sure how the modes work

extractAll :: [Reverse t a] -> [a]
extractAll xs = map extract xs
           where extract (Lift x) = x -- non-exhaustive pattern match

dist :: (Show a, Num a, Floating a) => [a] -> a
dist [x, y] = sqrt(x^2 + y^2)

-- incorrect output: [0.0, 0.0]
main = putStrLn $ show $ grad2' dist [1,2]

ただし、最初のバージョンのgrad2をコードで使用する方法がわかりません。 の処理方法がわからないためですReverse s a。2 番目のバージョン はgrad2'、内部コンストラクターを使用しLiftて を作成するため、適切な型を持っていますが、出力勾配がすべて 0 であるため、Reverse s a内部 (具体的にはパラメーター ) がどのように機能するかを理解していないに違いありません。s他のコンストラクターReverse(ここには示されていません) を使用すると、間違ったグラデーションが生成されます。

または、人々がコードを使用したライブラリ/コードの例はありadますか? 私のユースケースは非常に一般的なものだと思います。

4

1 に答える 1