1

TensorFlow で CTC 実装がどのように機能するかを理解しようとしています。CTC 機能をテストするためだけに簡単な例を書きましたが、何らかの理由でinf、いくつかのターゲット/入力値を取得しています。

コード:

import tensorflow as tf
import numpy as np

# https://github.com/philipperemy/tensorflow-ctc-speech-recognition/blob/master/utils.py
def sparse_tuple_from(sequences, dtype=np.int32):
    """Create a sparse representention of x.
    Args:
        sequences: a list of lists of type dtype where each element is a sequence
    Returns:
        A tuple with (indices, values, shape)
    """
    indices = []
    values = []

    for n, seq in enumerate(sequences):
        indices.extend(zip([n] * len(seq), range(len(seq))))
        values.extend(seq)

    indices = np.asarray(indices, dtype=np.int64)
    values = np.asarray(values, dtype=dtype)
    shape = np.asarray([len(sequences), np.asarray(indices).max(0)[1] + 1], dtype=np.int64)

    return indices, values, shape

batch_size = 1
seq_length = 2
n_labels = 2

seq_len = tf.placeholder(tf.int32, [None])
targets = tf.sparse_placeholder(tf.int32)
logits = tf.constant(np.random.random((batch_size, seq_length, n_labels+1)),dtype=tf.float32) # +1 for the blank label
loss = tf.reduce_mean(tf.nn.ctc_loss(targets, logits, seq_len, time_major = False))


with tf.Session() as sess:
    for it in range(10):
        rand_target = np.random.randint(n_labels, size=(seq_length))
        sample_target = sparse_tuple_from([rand_target])

        logitsval = sess.run(logits)
        lossval = sess.run(loss, feed_dict={seq_len: [seq_length], targets: sample_target})
        print('******* Iter: %d *******'%it)
        print('logits:', logitsval)
        print('rand_target:', rand_target)
        print('rand_sparse_target:', sample_target)
        print('loss:', lossval)
        print()

サンプル出力:

******* Iter: 0 *******
logits: [[[ 0.10151503  0.88581538  0.56466645]
  [ 0.76043415  0.52718711  0.01166286]]]
rand_target: [0 1]
rand_sparse_target: (array([[0, 0],
       [0, 1]]), array([0, 1], dtype=int32), array([1, 2]))
loss: 2.61521

******* Iter: 1 *******
logits: [[[ 0.10151503  0.88581538  0.56466645]
  [ 0.76043415  0.52718711  0.01166286]]]
rand_target: [1 1]
rand_sparse_target: (array([[0, 0],
       [0, 1]]), array([1, 1], dtype=int32), array([1, 2]))
loss: inf

******* Iter: 2 *******
logits: [[[ 0.10151503  0.88581538  0.56466645]
  [ 0.76043415  0.52718711  0.01166286]]]
rand_target: [0 1]
rand_sparse_target: (array([[0, 0],
       [0, 1]]), array([0, 1], dtype=int32), array([1, 2]))
loss: 2.61521

******* Iter: 3 *******
logits: [[[ 0.10151503  0.88581538  0.56466645]
  [ 0.76043415  0.52718711  0.01166286]]]
rand_target: [1 0]
rand_sparse_target: (array([[0, 0],
       [0, 1]]), array([1, 0], dtype=int32), array([1, 2]))
loss: 1.59766

******* Iter: 4 *******
logits: [[[ 0.10151503  0.88581538  0.56466645]
  [ 0.76043415  0.52718711  0.01166286]]]
rand_target: [0 0]
rand_sparse_target: (array([[0, 0],
       [0, 1]]), array([0, 0], dtype=int32), array([1, 2]))
loss: inf

******* Iter: 5 *******
logits: [[[ 0.10151503  0.88581538  0.56466645]
  [ 0.76043415  0.52718711  0.01166286]]]
rand_target: [0 1]
rand_sparse_target: (array([[0, 0],
       [0, 1]]), array([0, 1], dtype=int32), array([1, 2]))
loss: 2.61521

******* Iter: 6 *******
logits: [[[ 0.10151503  0.88581538  0.56466645]
  [ 0.76043415  0.52718711  0.01166286]]]
rand_target: [1 0]
rand_sparse_target: (array([[0, 0],
       [0, 1]]), array([1, 0], dtype=int32), array([1, 2]))
loss: 1.59766

******* Iter: 7 *******
logits: [[[ 0.10151503  0.88581538  0.56466645]
  [ 0.76043415  0.52718711  0.01166286]]]
rand_target: [1 1]
rand_sparse_target: (array([[0, 0],
       [0, 1]]), array([1, 1], dtype=int32), array([1, 2]))
loss: inf

******* Iter: 8 *******
logits: [[[ 0.10151503  0.88581538  0.56466645]
  [ 0.76043415  0.52718711  0.01166286]]]
rand_target: [0 1]
rand_sparse_target: (array([[0, 0],
       [0, 1]]), array([0, 1], dtype=int32), array([1, 2]))
loss: 2.61521

******* Iter: 9 *******
logits: [[[ 0.10151503  0.88581538  0.56466645]
  [ 0.76043415  0.52718711  0.01166286]]]
rand_target: [0 0]
rand_sparse_target: (array([[0, 0],
       [0, 1]]), array([0, 0], dtype=int32), array([1, 2]))
loss: inf

私はそこに何が欠けているのですか!?

4

1 に答える 1

2

入力テキスト (rand_target) をよく見てください。inf 損失値と相関するいくつかの単純なパターンが見られると確信しています ;-)

何が起こっているかの簡単な説明: CTC は、各文字を繰り返すことによってテキストをエンコードし、文字の間に非文字マーカー (「CTC ブランク ラベル」と呼ばれる) を挿入することもできます。このエンコーディング (またはデコーディング) を元に戻すということは、単純に繰り返し文字を破棄し、すべての空白を破棄することを意味します。いくつかの例を挙げます (「...」はテキスト、「...」はエンコーディング、「-」は空白のラベルに対応します):

  • "to" -> 'tttooo'、または 'to' または 't-oo'、または 'to' など...
  • 「too」 -> 「to-o」、または「tttoo---oo」、または「---too--」、ただし「too」ではありません (デコードされた「too」がどのように見えるかを考えてください)

これで、一部のサンプルが失敗した理由を確認するのに十分なことがわかりました。

  • 入力テキストの長さは 2 です
  • エンコーディングの長さは 2 です
  • 入力文字が繰り返される場合 (例: '11'、または python リスト: [1, 1])、これをエンコードする唯一の方法は、間に空白を配置することです ('11' と '1 のデコードが多いと考えてください)。 -1')。しかし、エンコーディングの長さは 3 になります。
  • そのため、文字が繰り返される長さ 2 のテキストを長さ 2 のエンコーディングにエンコードする方法はありません。したがって、TF 損失の実装は inf を返します。

エンコーディングをステート マシンとして想像することもできます。下の図を参照してください。テキスト「11」は、開始状態 (左端の 2 つの状態) で始まり、最終状態 (右端の 2 つの状態) で終わるすべての可能なパスで表すことができます。ご覧のとおり、最短経路は「1-1」です。

ここに画像の説明を入力

結論として、入力テキストで繰り返される文字ごとに、少なくとも 1 つの追加の空白を挿入する必要があります。この記事は、CTC を理解するのに役立つかもしれません: https://towardsdatascience.com/3797e43a86c

于 2018-09-30T16:01:13.923 に答える