skytnt commited on
Commit
3d1a488
1 Parent(s): 4e1316c

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +5 -1
README.md CHANGED
@@ -24,9 +24,13 @@ The model is used to generate Japanese lyrics.
24
  import torch
25
  from transformers import T5Tokenizer, GPT2LMHeadModel
26
 
 
 
 
 
27
  tokenizer = T5Tokenizer.from_pretrained("skytnt/gpt2-japanese-lyric-medium")
28
  model = GPT2LMHeadModel.from_pretrained("skytnt/gpt2-japanese-lyric-medium")
29
-
30
 
31
  def gen_lyric(title: str, prompt_text: str):
32
  if len(title)!= 0 or len(prompt_text)!= 0:
 
24
  import torch
25
  from transformers import T5Tokenizer, GPT2LMHeadModel
26
 
27
+ device = torch.device("cpu")
28
+ if torch.cuda.is_available():
29
+ device = torch.device("cuda")
30
+
31
  tokenizer = T5Tokenizer.from_pretrained("skytnt/gpt2-japanese-lyric-medium")
32
  model = GPT2LMHeadModel.from_pretrained("skytnt/gpt2-japanese-lyric-medium")
33
+ model = model.to(device)
34
 
35
  def gen_lyric(title: str, prompt_text: str):
36
  if len(title)!= 0 or len(prompt_text)!= 0: