1

各トークンのバート ベクトルに関心があります。bert ベクトルとは、berts 出力層の特定のトークンの単語ベクトルを意味します。したがって、どのトークンがどのバート ベクトルを生成するかを調べたいと思います。いくつかのコードを書きましたが、それが正しいかどうか、またはテストする方法がわかりません。

したがって、コードでは文を bert で処理します。位置 ID のリストを作成し、モデルに渡します。その後、同じ位置 ID を使用して、トークンを出力レイヤーにマップします。次に、入力文の各ベクトルの文字オフセットを計算するコードがあります。

これは、position_ids を使用して生成する正しい方法ですか?

from transformers import BertModel, BertConfig, BertTokenizer
import torch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

def sentence_to_vector(input_sentence):
    tokens_encoded = tokenizer.encode(input_sentence, add_special_tokens=True)
    input_ids = torch.tensor(tokens_encoded).unsqueeze(0)  # Batch size 1

    seq_length = input_ids.size(1)

    # code to construct position_ids from here: 
    # https://github.com/huggingface/transformers/blob/8da280ebbeca5ebd7561fd05af78c65df9161f92/pytorch_pretrained_bert/modeling.py#L188:L189
    position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  
    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    outputs = model(input_ids, position_ids=position_ids)

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    # from the BertModel documentation (example at the bottom):
    # The last hidden-state is the first element of the output tuple
    # https://huggingface.co/transformers/model_doc/bert.html#transformers.BertModel

    #ttv = {}  # token to vector
    #for i in position_ids[0]:
    #    ttv[tokens[i]] = outputs[0][0][position_ids[0][i]]

    data = []
    last_offset = 0
    for i in range(0, len(position_ids[0])):
        token = tokens[position_ids[0][i]]
        vector = outputs[0][0][position_ids[0][i]]
        pos_begin = None
        pos_end = None
        if not token == "[CLS]" and not token == "[SEP]":
            pos_begin = input_sentence.find(token, last_offset)
            pos_end = pos_begin + len(token)
            last_offset = pos_end
        data.append({
            "token": token,
            "pos_begin": pos_begin,
            "pos_end": pos_end,
            "vector": vector
        })
    return data

input_sentence = "do the chicken dance!"
data = sentence_to_vector(input_sentence)

for token in data:
    print(token["token"] + "\t" + str(token["pos_begin"]) + "\t" + str(token["pos_end"]) + "\t" + str(token["vector"][0:3]) + "..." )
4

0 に答える 0