0

LERメトリックを計算するための次のコードがあります:

def ctc_lambda_func(args):
    y_pred, labels, input_length, label_length = args
    return K.ctc_batch_cost(labels, y_pred, input_length, label_length)


def decode(inputs):
    y_pred, seq_len, y_true = inputs

    seq_len = tf.cast(seq_len[:, 0], tf.int32)
    y_pred = tf.transpose(y_pred, perm=[1, 0, 2])

    decoded = tf.nn.ctc_beam_search_decoder(inputs=y_pred, 
                                            sequence_length=seq_len, 
                                            beam_width=1,
                                            top_paths=1,
                                            )[0][0]

    y_true_sparse = tf.sparse.from_dense(tf.cast(y_true,dtype=tf.int64))
    diff =  tf.reduce_mean(tf.edit_distance(decoded, y_true_sparse))
    return diff

def add_ctc_loss(m):
    labels = Input(name='the_labels', shape=(None,), dtype='float32')
    input_length = Input(name='input_length', shape=(1,), dtype='int64')
    label_length = Input(name='label_length', shape=(1,), dtype='int64')

    output_length = Lambda(m.output_length)(input_length)


    decoded = Lambda(function=decode, name='decoded', output_shape=(1,))(
                    [m.output, input_length, labels])
    loss_out = Lambda(function=ctc_lambda_func, name='ctc', output_shape=(1,))(
                    [m.output, labels, output_length, label_length])

    model = Model(inputs=[m.input, labels, input_length, label_length], outputs=[loss_out,decoded])

    model.compile(loss={"ctc": lambda y_true, y_pred: y_pred,
                        "decoded": lambda y_true, y_pred: y_pred
                        },
                optimizer="adam",
                )

CTC損失関数を使用して、勾配とLERを「精度」メトリックの形式として更新したいと考えています。CTC 損失が機能し、正常に更新されている間、LER (decoded_loss) は常に 0.0000e+00 のままです。何が間違っているのかわかりませんが、これを修正しようとしてオンラインで例を見て丸一日を失いましたが、問題は同じままです。デコード関数内で値を出力すると、値が適切に生成されていることがわかりますが、進行状況バーは更新されません。トレーニングがエポックを通過するにつれて、LER がどのように変化するかを確認したいと思います。

Epoch 1/150
 36/683 [>.............................] - ETA: 59s - loss: 116.2132 - ctc_loss: 116.2132 - decoded_loss: 0.0000e+00
4

0 に答える 0