0
import coremltools as ct
import numpy as np
import torch
from transformers import GPTNeoForCausalLM, GPT2Tokenizer
# Prompt that seeds generation; the demos below re-declare it before use.
sentence_fragment = "The Oceans are"
# Tokenizer shared by both the custom and the stock generation demos.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

class NEO(torch.nn.Module):
    """Wrap a causal LM so that one forward call appends a single next token.

    By default the next token is picked greedily (argmax), which is fully
    deterministic: the same prompt always produces the same continuation.
    Pass ``do_sample=True`` to draw the token from the softmax distribution
    instead (the analogue of ``generate(do_sample=True)``), so repeated calls
    can yield different continuations.

    Args:
        model: Callable returning a tuple whose first element holds the
            logits; any remaining elements are ignored.
        do_sample: If True, sample the next token; otherwise take the argmax.
        temperature: Divisor applied to the logits before the softmax when
            sampling; must be > 0. Ignored when ``do_sample`` is False.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """

    def __init__(self, model, do_sample: bool = False, temperature: float = 1.0):
        super().__init__()
        # Written as `not >` so that NaN is rejected too.
        if not temperature > 0.0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        self.next_token_predictor = model
        # Plain bool/float attributes keep the branch in forward()
        # compilable by torch.jit.script.
        self.do_sample = bool(do_sample)
        self.temperature = float(temperature)

    def forward(self, x):
        """Return ``x`` with one predicted token id appended along dim 0.

        NOTE(review): assumes ``x`` is a 1-D tensor of token ids and that
        ``predictions[-1, :]`` is the vocabulary logits of the last position
        (i.e. the wrapped model yields 2-D logits) - confirm against the
        installed transformers version.
        """
        sentence = x
        predictions, _ = self.next_token_predictor(sentence)
        logits = predictions[-1, :]
        if self.do_sample:
            probs = torch.softmax(logits / self.temperature, dim=0)
            # multinomial on a 1-D distribution returns shape (1,), the same
            # as argmax(keepdim=True), so the concatenation works either way.
            token = torch.multinomial(probs, num_samples=1)
        else:
            token = torch.argmax(logits, dim=0, keepdim=True)
        return torch.cat((sentence, token), 0)

# Load GPT-Neo with torchscript=True (forward then returns a tuple, which
# NEO unpacks) and switch it to eval mode.
token_predictor = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M", torchscript=True).eval()

# Trace with dummy token ids; the prompt itself is tokenized later, inside the
# generation loop, so no prompt encoding is needed here.
# NOTE(review): tracing records ops for a length-5 input - confirm the traced
# graph behaves correctly for other sequence lengths.
random_tokens = torch.randint(10000, (5,))
traced_token_predictor = torch.jit.trace(token_predictor, random_tokens)

# Wrap the traced LM so each call appends one token, then compile the wrapper
# with TorchScript.
model = NEO(model=traced_token_predictor)
scripted_model = torch.jit.script(model)

# Custom model: extend the prompt token by token with the scripted wrapper.

sentence_fragment = "The Oceans are"

# Tokenize once and keep feeding the model's own token ids back in. Decoding
# to text and re-encoding on every step is wasteful, and BPE tokenization does
# not always round-trip, which could silently alter the context.
context = torch.tensor(tokenizer.encode(sentence_fragment))
with torch.no_grad():  # inference only: skip autograd bookkeeping
    for _ in range(10):
        context = scripted_model(context)
sentence_fragment = tokenizer.decode(context)
print(f"Custom model: {sentence_fragment}")

# Stock model: Hugging Face generate() with sampling enabled, so each run can
# produce a different continuation.

# Reuse the model already loaded above (same checkpoint, same torchscript=True
# flag, already in eval mode) instead of loading it a second time.
model = token_predictor

sentence_fragment = "The Oceans are"

inputs = tokenizer(sentence_fragment, return_tensors="pt")
input_ids = inputs.input_ids
# Pass the attention mask and an explicit pad id (the GPT-2 tokenizer has no
# pad token by default; EOS is the usual stand-in for open-ended generation)
# so generate() does not warn about them being unset.
gen_tokens = model.generate(
    input_ids,
    attention_mask=inputs.attention_mask,
    pad_token_id=tokenizer.eos_token_id,
    do_sample=True,
    max_length=20,
)
gen_text = tokenizer.batch_decode(gen_tokens)[0]
print("Stock model: " + gen_text)

実行 1

出力:


Custom model: The Oceans are the most important source of water for the entire world
Stock model: The Oceans are on the rise. The American Southwest is thriving, but the southern United States still

実行 2

出力:


Custom model: The Oceans are the most important source of water for the entire world. 
Stock model: The Oceans are the land of man

This is a short video of the Australian government

カスタム モデルは常に同じ出力を返します。一方、`do_sample=True` を指定したストック (標準の) モデルでは、`model.generate` の呼び出しごとに異なる結果が返されます。Transformers で `do_sample` がどのように機能するかを理解しようと多くの時間を費やしましたが分からなかったので、皆さんの助けが必要です。

呼び出しごとに異なる結果が得られるようにカスタム モデルをコーディングする方法は?

ありがとう!

4

回答 1 件