私は自分のプロジェクトで VJP を頻繁に使用しています。ヤコビアン計算の対象となる関数を実行し、呼び出し可能な vjp 関数とともに primals_out を返します。たとえば、JAX ドキュメントのカスタム VJP 定義は次のようになります。
from jax import custom_vjp
@custom_vjp
def f(x, y):
return jnp.sin(x) * y
def f_fwd(x, y):
# Returns primal output and residuals to be used in backward pass by f_bwd.
return f(x, y), (jnp.cos(x), jnp.sin(x), y)
def f_bwd(res, g):
cos_x, sin_x, y = res # Gets residuals computed in f_fwd
return (cos_x * g * y, sin_x * g)
f.defvjp(f_fwd, f_bwd)
この例では、VJP を使用する場合に forward 関数の評価が必要であることがわかります。これは、カスタム定義の VJP の代わりに通常の VJP を使用する場合にも当てはまります。ただし、関数の評価コストが高く、その関数をコードのどこかで既に実行しているため、VJP にその関数をもう一度評価させたくありません。
では、VJP を計算するときに関数が評価されないことを示す方法はありますか?