各トークンのバート ベクトルに関心があります。bert ベクトルとは、berts 出力層の特定のトークンの単語ベクトルを意味します。したがって、どのトークンがどのバート ベクトルを生成するかを調べたいと思います。いくつかのコードを書きましたが、それが正しいかどうか、またはテストする方法がわかりません。
したがって、コードでは文を bert で処理します。位置 ID のリストを作成し、モデルに渡します。その後、同じ位置 ID を使用して、トークンを出力レイヤーにマップします。次に、入力文の各ベクトルの文字オフセットを計算するコードがあります。
これは、position_ids を使用して生成する正しい方法ですか?
from transformers import BertModel, BertConfig, BertTokenizer
import torch
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
def sentence_to_vector(input_sentence):
tokens_encoded = tokenizer.encode(input_sentence, add_special_tokens=True)
input_ids = torch.tensor(tokens_encoded).unsqueeze(0) # Batch size 1
seq_length = input_ids.size(1)
# code to construct position_ids from here:
# https://github.com/huggingface/transformers/blob/8da280ebbeca5ebd7561fd05af78c65df9161f92/pytorch_pretrained_bert/modeling.py#L188:L189
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
outputs = model(input_ids, position_ids=position_ids)
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
# from the BertModel documentation (example at the bottom):
# The last hidden-state is the first element of the output tuple
# https://huggingface.co/transformers/model_doc/bert.html#transformers.BertModel
#ttv = {} # token to vector
#for i in position_ids[0]:
# ttv[tokens[i]] = outputs[0][0][position_ids[0][i]]
data = []
last_offset = 0
for i in range(0, len(position_ids[0])):
token = tokens[position_ids[0][i]]
vector = outputs[0][0][position_ids[0][i]]
pos_begin = None
pos_end = None
if not token == "[CLS]" and not token == "[SEP]":
pos_begin = input_sentence.find(token, last_offset)
pos_end = pos_begin + len(token)
last_offset = pos_end
data.append({
"token": token,
"pos_begin": pos_begin,
"pos_end": pos_end,
"vector": vector
})
return data
input_sentence = "do the chicken dance!"
data = sentence_to_vector(input_sentence)
for token in data:
print(token["token"] + "\t" + str(token["pos_begin"]) + "\t" + str(token["pos_end"]) + "\t" + str(token["vector"][0:3]) + "..." )