Update README.md
README.md
CHANGED
@@ -31,10 +31,10 @@ model = GPT2LMHeadModel.from_pretrained("skytnt/gpt2-japanese-lyric-small")
 
 
 def gen_lyric(prompt_text: str):
-    prompt_text =
+    prompt_text = "<s>" + prompt_text
     prompt_tokens = tokenizer.tokenize(prompt_text)
     prompt_token_ids = tokenizer.convert_tokens_to_ids(prompt_tokens)
-    prompt_tensor = torch.LongTensor(prompt_token_ids)
+    prompt_tensor = torch.LongTensor(prompt_token_ids).to(device)
     prompt_tensor = prompt_tensor.view(1, -1)
     # model forward
     output_sequences = model.generate(
@@ -55,8 +55,7 @@ def gen_lyric(prompt_text: str):
     generated_sequence = output_sequences.tolist()[0]
     generated_tokens = tokenizer.convert_ids_to_tokens(generated_sequence)
     generated_text = tokenizer.convert_tokens_to_string(generated_tokens)
-    generated_text = "\n".join([s.strip() for s in generated_text.split('\\n')]).replace(' ', '\u3000').replace(
-        '</s>', '\n\n---end---')
+    generated_text = "\n".join([s.strip() for s in generated_text.split('\\n')]).replace(' ', '\u3000').replace('<s>', '').replace('</s>', '\n\n---end---')
     return generated_text
 
 
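For context, a minimal usage sketch under stated assumptions: the updated `.to(device)` line presumes a `device` variable defined earlier in the README, and the model line mirrors the context shown in the first hunk header; the `AutoTokenizer` choice and the example prompt are assumptions, not part of the diff.

```python
import torch
from transformers import AutoTokenizer, GPT2LMHeadModel

# Assumed setup around the snippet being patched: the new
# `.to(device)` call only works if a device is chosen up front.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("skytnt/gpt2-japanese-lyric-small")
model = GPT2LMHeadModel.from_pretrained("skytnt/gpt2-japanese-lyric-small").to(device)

# gen_lyric (defined in the README, see the diff above) now prepends
# the "<s>" start token itself and strips it from the decoded output,
# so the caller passes only the raw prompt text, e.g.:
# lyric = gen_lyric("夏の日の")  # hypothetical prompt
```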