Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -439,7 +439,7 @@ class UniMCPredict:
|
|
439 |
batch = [self.data_model.train_data.encode(
|
440 |
sample) for sample in batch_data]
|
441 |
batch = self.data_model.collate_fn(batch)
|
442 |
-
batch = {k: v.cuda() for k, v in batch.items()}
|
443 |
_, _, logits = self.model.model(**batch)
|
444 |
soft_logits = torch.nn.functional.softmax(logits, dim=-1)
|
445 |
logits = torch.argmax(soft_logits, dim=-1).detach().cpu().numpy()
|
|
|
439 |
batch = [self.data_model.train_data.encode(
|
440 |
sample) for sample in batch_data]
|
441 |
batch = self.data_model.collate_fn(batch)
|
442 |
+
# batch = {k: v.cuda() for k, v in batch.items()}
|
443 |
_, _, logits = self.model.model(**batch)
|
444 |
soft_logits = torch.nn.functional.softmax(logits, dim=-1)
|
445 |
logits = torch.argmax(soft_logits, dim=-1).detach().cpu().numpy()
|