1

tanhTensorFlow のデフォルト以外の伝達関数をいくつか試してみたいと思いますBasicRNNCell

元の実装は次のようになります。

class BasicRNNCell(RNNCell):
(...)
def __call__(self, inputs, state, scope=None):
    """Most basic RNN: output = new_state = tanh(W * input + U * state + B)."""
    with vs.variable_scope(scope or type(self).__name__):  # "BasicRNNCell"
      output = tanh(linear([inputs, state], self._num_units, True))
    return output, output

...そして私はそれを次のように変更しました:

class MyRNNCell(BasicRNNCell):
(...)
def __call__(self, inputs, state, scope=None):
    """Most basic RNN: output = new_state = tanh(W * input + U * state + B)."""
    with tf.variable_scope(scope or type(self).__name__):  # "BasicRNNCell"
      output = my_transfer_function(linear([inputs, state], self._num_units, True))
    return output, output

vs.variable_scopeへの変更tf.variable_scopeは成功しましたlinearが、 > rnn_cell.py <の実装であり、tfそれ自体では使用できません。

どうすればこれを機能させることができますか?

linear完全に再実装する必要がありますか? (私はすでにコードをチェックしましたが、そこでも依存関係の問題に遭遇すると思います...)

4

1 に答える 1