Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -34,7 +34,7 @@ def get_dataloader(start, end, batch_size=8):
|
|
34 |
subset = torch.utils.data.Subset(global_data, range(start, end))
|
35 |
return DataLoader(subset, batch_size=batch_size)
|
36 |
|
37 |
-
@spaces.GPU(duration=
|
38 |
def train_batch(dataloader):
|
39 |
model.train()
|
40 |
start_time = time.time()
|
@@ -88,8 +88,10 @@ def train_step(file=None, start_idx=0):
|
|
88 |
return start_idx # Trả về start_idx nếu lỗi xảy ra
|
89 |
|
90 |
except HTMLError as e:
|
91 |
-
print("Exceeded GPU quota
|
92 |
-
|
|
|
|
|
93 |
return start_idx # Trả về start_idx để lưu lại vị trí
|
94 |
|
95 |
start_idx = end_idx
|
|
|
34 |
subset = torch.utils.data.Subset(global_data, range(start, end))
|
35 |
return DataLoader(subset, batch_size=batch_size)
|
36 |
|
37 |
+
@spaces.GPU(duration=120)
|
38 |
def train_batch(dataloader):
|
39 |
model.train()
|
40 |
start_time = time.time()
|
|
|
88 |
return start_idx # Trả về start_idx nếu lỗi xảy ra
|
89 |
|
90 |
except HTMLError as e:
|
91 |
+
print("Exceeded GPU quota.")
|
92 |
+
if not os.path.exists('./checkpoint'):
|
93 |
+
os.makedirs('./checkpoint')
|
94 |
+
torch.save(model.state_dict(), "./checkpoint/model.pt")
|
95 |
return start_idx # Trả về start_idx để lưu lại vị trí
|
96 |
|
97 |
start_idx = end_idx
|