yrshi commited on
Commit
031d775
1 Parent(s): f37ea64

fixed the device problem

Browse files
Files changed (1) hide show
  1. app.py +1 -0
app.py CHANGED
@@ -203,6 +203,7 @@ class InferenceRunner:
203
  @torch.no_grad()
204
  def predict(self, rxn_dict, temperature=1):
205
  graphs, prompt_tokens = self.tokenize(rxn_dict)
 
206
  result_dict = rxn_dict
207
  samples = {'graphs': graphs, 'prompt_tokens': prompt_tokens}
208
  prediction = self.model.blip2opt.generate(
 
203
  @torch.no_grad()
204
  def predict(self, rxn_dict, temperature=1):
205
  graphs, prompt_tokens = self.tokenize(rxn_dict)
206
+ self.model.blip2opt = self.model.blip2opt.to('cuda')
207
  result_dict = rxn_dict
208
  samples = {'graphs': graphs, 'prompt_tokens': prompt_tokens}
209
  prediction = self.model.blip2opt.generate(