私は最近音声認識について学んでおり、接頭辞ビーム検索[1,1,_]
の考え方は、 andなどの同じ接頭辞を持つパスをマージすることであることを学びました[_,1,_]
(ご覧のとおり、_
空白マークを示します)。
この理解に基づいて、次のような擬似コードを使用して簡略化できる私のバージョンを実装しました。
def prefix_beam_search(y, beam_size, blank):
seq_len, n_class = y.shape
logY = np.log(y)
beam = [([], 0)]
for t in range(seq_len):
buff = []
for prefix, p in beam:
for i in range(n_class):
new_prefix = list(prefix) + [i]
new_p = p + logY[t][i]
buff.append((new_prefix, new_p))
# merge the paths with same prefix'
new_beam = defaultdict(lambda: ninf)
for prefix, p in buff:
# 'norm_prefix' can simplify the path, [1,1,_,2] ==> [1,2]
# However, the ending 'blank' is retained, [1,1,_] ==> [1,_]
prefix = norm_prefix(prefix, blank)
new_beam[prefix] = logsumexp(new_beam[prefix], p)
# choose the best paths
new_beam = sorted(new_beam.items(), key=lambda x: x[1], reverse=True)
beam = new_beam[: beam_size]
return beam
しかし、私がオンラインで見つけたほとんどのバージョン (紙によると) は次のようなものです。
def _prefix_beam_decode(y, beam_size, blank):
T, V = y.shape
log_y = np.log(y)
beam = [(tuple(), (0, ninf))]
for t in range(T):
new_beam = defaultdict(lambda: (ninf, ninf))
for prefix, (p_b, p_nb) in beam:
for i in range(V):
p = log_y[t, i]
if i == blank:
new_p_b, new_p_nb = new_beam[prefix]
new_p_b = logsumexp(new_p_b, p_b + p, p_nb + p)
new_beam[prefix] = (new_p_b, new_p_nb)
continue
end_t = prefix[-1] if prefix else None
new_prefix = prefix + (i,)
new_p_b, new_p_nb = new_beam[new_prefix]
if i != end_t:
new_p_nb = logsumexp(new_p_nb, p_b + p, p_nb + p)
else:
new_p_nb = logsumexp(new_p_nb, p_b + p)
new_beam[new_prefix] = (new_p_b, new_p_nb)
if i == end_t:
new_p_b, new_p_nb = new_beam[prefix]
new_p_nb = logsumexp(new_p_nb, p_nb + p)
new_beam[prefix] = (new_p_b, new_p_nb)
beam = sorted(new_beam.items(), key=lambda x: logsumexp(*x[1]), reverse=True)
beam = beam[:beam_size]
return beam
2 つの結果は異なり、私のバージョンはより長い文字列を返す傾向があります。そして、私は主な2つの側面をよく理解していません:
- 私のバージョンの詳細で、考慮されていないものはありますか?
new_prefix = prefix + (i,)
共通バージョンは、前の末尾が指定された「s」と同じであるかどうかに関係なく、新しいプレフィックスを生成します。たとえば、古いプレフィックスは[a,a,b]
and で、新しい文字 s が追加されると、両方とも[a,a,b]
保存[a,a,b,b]
されます。これだとしたら何の目的でしょうか?そして、それは二重カウントを引き起こしますか?
回答をお待ちしております。よろしくお願いします。