次の例で観察した動作に困惑しています。
import tensorflow as tf
@tf.function
def f(a):
c = a * 2
b = tf.reduce_sum(c ** 2 + 2 * c)
return b, c
def fplain(a):
c = a * 2
b = tf.reduce_sum(c ** 2 + 2 * c)
return b, c
a = tf.Variable([[0., 1.], [1., 0.]])
with tf.GradientTape() as tape:
b, c = f(a)
print('tf.function gradient: ', tape.gradient([b], [c]))
# outputs: tf.function gradient: [None]
with tf.GradientTape() as tape:
b, c = fplain(a)
print('plain gradient: ', tape.gradient([b], [c]))
# outputs: plain gradient: [<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
# array([[2., 6.],
# [6., 2.]], dtype=float32)>]
より低い動作は、私が期待するものです。@tf.function のケースを理解するにはどうすればよいですか?
事前にどうもありがとうございました!
(この問題は、すべての計算が関数内にあるため、tf.function を使用する場合の勾配の欠落とは異なることに注意してください。)