Балаганский Никита Николаевич commited on
Commit
9dc62b8
1 Parent(s): e7e873a
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -7,7 +7,6 @@ import torch
7
 
8
  import transformers
9
  import tokenizers
10
- from torch import autocast
11
 
12
  from sampling import CAIFSampler, TopKWithTemperatureSampler
13
  from generator import Generator
@@ -57,6 +56,10 @@ def inference(lm_model_name: str, cls_model_name: str, prompt: str, fp16: bool =
57
  generator.set_caif_sampler(caif_sampler)
58
  ordinary_sampler = TopKWithTemperatureSampler()
59
  generator.set_ordinary_sampler(ordinary_sampler)
 
 
 
 
60
  with autocast(fp16):
61
  sequences, tokens = generator.sample_sequences(
62
  num_samples=1,
 
7
 
8
  import transformers
9
  import tokenizers
 
10
 
11
  from sampling import CAIFSampler, TopKWithTemperatureSampler
12
  from generator import Generator
 
56
  generator.set_caif_sampler(caif_sampler)
57
  ordinary_sampler = TopKWithTemperatureSampler()
58
  generator.set_ordinary_sampler(ordinary_sampler)
59
+ if device == "cpu":
60
+ autocast = torch.cpu.amp.autocast
61
+ else:
62
+ autocast = torch.cuda.amp.autocast
63
  with autocast(fp16):
64
  sequences, tokens = generator.sample_sequences(
65
  num_samples=1,