teticio commited on
Commit
9529307
1 Parent(s): 3d84bdd

use base model if on cpu

Browse files
Files changed (1) hide show
  1. app.py +6 -8
app.py CHANGED
@@ -15,12 +15,11 @@ from transformers.generation_logits_process import TypicalLogitsWarper
15
 
16
  nltk.download('punkt')
17
 
18
- cuda = torch.cuda.is_available()
19
-
20
- tokenizer = RobertaTokenizer.from_pretrained("roberta-large")
21
- model = RobertaForMaskedLM.from_pretrained("roberta-large")
22
- if cuda:
23
- model = model.cuda()
24
 
25
  max_len = 20
26
  top_k = 100
@@ -99,8 +98,7 @@ def parallel_sequential_generation(seed_text: str,
99
  masked_tokens = np.where(
100
  inp['input_ids'][0].numpy() == tokenizer.mask_token_id)[0]
101
  seed_len = masked_tokens[0]
102
- if cuda:
103
- inp = inp.to('cuda')
104
 
105
  for ii in range(max_iter):
106
  kk = np.random.randint(0, max_len)
 
15
 
16
  nltk.download('punkt')
17
 
18
+ device = "cuda" if torch.cuda.is_available() else "cpu"
19
+ pretrained = "roberta-large" if device == "cuda" else "roberta-base"
20
+ tokenizer = RobertaTokenizer.from_pretrained(pretrained)
21
+ model = RobertaForMaskedLM.from_pretrained(pretrained)
22
+ model = model.to(device)
 
23
 
24
  max_len = 20
25
  top_k = 100
 
98
  masked_tokens = np.where(
99
  inp['input_ids'][0].numpy() == tokenizer.mask_token_id)[0]
100
  seed_len = masked_tokens[0]
101
+ inp = inp.to(device)
 
102
 
103
  for ii in range(max_iter):
104
  kk = np.random.randint(0, max_len)