1

Jは、いくつかのパラメーターに関する関数のヤコビアンであると仮定fします。(PyTorch またはおそらく Jax で) 2 つの入力 (x1および) を取り、メモリ内の行列全体をインスタンス化せずx2に計算する関数を持つ効率的な方法はありますか?J(x1)*J(x2).transpose() J

私は何かに出くわしましjvp(f, input, v=vjp(f, input))たが、それをよく理解していません。私が欲しいものかどうかはわかりません。

4

1 に答える 1