Spaces:
Sleeping
Sleeping
trminhnam20082002
commited on
Commit
•
2b0ea98
1
Parent(s):
9922ab1
fix cuda
Browse files
app.py
CHANGED
@@ -7,6 +7,7 @@ from st_utils import (
|
|
7 |
)
|
8 |
from huggingface_hub import hf_hub_download
|
9 |
import os
|
|
|
10 |
|
11 |
# list_files(os.getcwd())
|
12 |
|
@@ -19,6 +20,8 @@ Simply select one of the sample Python functions from the dropdown menu below, a
|
|
19 |
"""
|
20 |
)
|
21 |
|
|
|
|
|
22 |
# Download the model from the Hugging Face Hub if it doesn't exist
|
23 |
download_model()
|
24 |
|
|
|
7 |
)
|
8 |
from huggingface_hub import hf_hub_download
|
9 |
import os
|
10 |
+
import torch
|
11 |
|
12 |
# list_files(os.getcwd())
|
13 |
|
|
|
20 |
"""
|
21 |
)
|
22 |
|
23 |
+
st.write(f"Has CUDA: {torch.cuda.is_available()}")
|
24 |
+
|
25 |
# Download the model from the Hugging Face Hub if it doesn't exist
|
26 |
download_model()
|
27 |
|
model.py
CHANGED
@@ -104,7 +104,10 @@ class Seq2Seq(nn.Module):
|
|
104 |
else:
|
105 |
# Predict
|
106 |
preds = []
|
107 |
-
|
|
|
|
|
|
|
108 |
for i in range(source_ids.shape[0]):
|
109 |
context = encoder_output[:, i : i + 1]
|
110 |
context_mask = source_mask[i : i + 1, :]
|
@@ -154,7 +157,10 @@ class Seq2Seq(nn.Module):
|
|
154 |
class Beam(object):
|
155 |
def __init__(self, size, sos, eos):
|
156 |
self.size = size
|
157 |
-
|
|
|
|
|
|
|
158 |
# The score for each translation on the beam.
|
159 |
self.scores = self.tt.FloatTensor(size).zero_()
|
160 |
# The backpointers at each time-step.
|
|
|
104 |
else:
|
105 |
# Predict
|
106 |
preds = []
|
107 |
+
try:
|
108 |
+
zero = torch.cuda.LongTensor(1).fill_(0)
|
109 |
+
except Exception as e:
|
110 |
+
zero = torch.LongTensor(1).fill_(0)
|
111 |
for i in range(source_ids.shape[0]):
|
112 |
context = encoder_output[:, i : i + 1]
|
113 |
context_mask = source_mask[i : i + 1, :]
|
|
|
157 |
class Beam(object):
|
158 |
def __init__(self, size, sos, eos):
|
159 |
self.size = size
|
160 |
+
if torch.cuda.is_available():
|
161 |
+
self.tt = torch.cuda
|
162 |
+
else:
|
163 |
+
self.tt = torch
|
164 |
# The score for each translation on the beam.
|
165 |
self.scores = self.tt.FloatTensor(size).zero_()
|
166 |
# The backpointers at each time-step.
|