J
は、いくつかのパラメーターに関する関数のヤコビアンであると仮定f
します。(PyTorch またはおそらく Jax で) 2 つの入力 (x1
および) を取り、メモリ内の行列全体をインスタンス化せずx2
に計算する関数を持つ効率的な方法はありますか?J(x1)*J(x2).transpose()
J
私は何かに出くわしましjvp(f, input, v=vjp(f, input))
たが、それをよく理解していません。私が欲しいものかどうかはわかりません。
J
は、いくつかのパラメーターに関する関数のヤコビアンであると仮定f
します。(PyTorch またはおそらく Jax で) 2 つの入力 (x1
および) を取り、メモリ内の行列全体をインスタンス化せずx2
に計算する関数を持つ効率的な方法はありますか?J(x1)*J(x2).transpose()
J
私は何かに出くわしましjvp(f, input, v=vjp(f, input))
たが、それをよく理解していません。私が欲しいものかどうかはわかりません。