3

次の例で観察した動作に困惑しています。

import tensorflow as tf

@tf.function
def f(a):
    c = a * 2
    b = tf.reduce_sum(c ** 2 + 2 * c)
    return b, c

def fplain(a):
    c = a * 2
    b = tf.reduce_sum(c ** 2 + 2 * c)
    return b, c


a = tf.Variable([[0., 1.], [1., 0.]])

with tf.GradientTape() as tape:
    b, c = f(a)
    
print('tf.function gradient: ', tape.gradient([b], [c]))

# outputs: tf.function gradient:  [None]

with tf.GradientTape() as tape:
    b, c = fplain(a)
    
print('plain gradient: ', tape.gradient([b], [c]))

# outputs: plain gradient:  [<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
# array([[2., 6.],
#        [6., 2.]], dtype=float32)>]

より低い動作は、私が期待するものです。@tf.function のケースを理解するにはどうすればよいですか?

事前にどうもありがとうございました!

(この問題は、すべての計算が関数内にあるため、tf.function を使用する場合の勾配の欠落とは異なることに注意してください。)

4

1 に答える 1