Spaces:
Runtime error
Runtime error
fixed the device problem
Browse files
app.py
CHANGED
@@ -203,6 +203,7 @@ class InferenceRunner:
|
|
203 |
@torch.no_grad()
|
204 |
def predict(self, rxn_dict, temperature=1):
|
205 |
graphs, prompt_tokens = self.tokenize(rxn_dict)
|
|
|
206 |
result_dict = rxn_dict
|
207 |
samples = {'graphs': graphs, 'prompt_tokens': prompt_tokens}
|
208 |
prediction = self.model.blip2opt.generate(
|
|
|
203 |
@torch.no_grad()
|
204 |
def predict(self, rxn_dict, temperature=1):
|
205 |
graphs, prompt_tokens = self.tokenize(rxn_dict)
|
206 |
+
self.model.blip2opt = self.model.blip2opt.to('cuda')
|
207 |
result_dict = rxn_dict
|
208 |
samples = {'graphs': graphs, 'prompt_tokens': prompt_tokens}
|
209 |
prediction = self.model.blip2opt.generate(
|