ShynBui commited on
Commit
83e3068
1 Parent(s): abe15ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -34,7 +34,7 @@ def get_dataloader(start, end, batch_size=8):
34
  subset = torch.utils.data.Subset(global_data, range(start, end))
35
  return DataLoader(subset, batch_size=batch_size)
36
 
37
- @spaces.GPU(duration=20)
38
  def train_batch(dataloader):
39
  model.train()
40
  start_time = time.time()
@@ -88,8 +88,10 @@ def train_step(file=None, start_idx=0):
88
  return start_idx # Trả về start_idx nếu lỗi xảy ra
89
 
90
  except HTMLError as e:
91
- print("Exceeded GPU quota, retrying in 10 seconds...")
92
- time.sleep(10)
 
 
93
  return start_idx # Trả về start_idx để lưu lại vị trí
94
 
95
  start_idx = end_idx
 
34
  subset = torch.utils.data.Subset(global_data, range(start, end))
35
  return DataLoader(subset, batch_size=batch_size)
36
 
37
+ @spaces.GPU(duration=120)
38
  def train_batch(dataloader):
39
  model.train()
40
  start_time = time.time()
 
88
  return start_idx # Trả về start_idx nếu lỗi xảy ra
89
 
90
  except HTMLError as e:
91
+ print("Exceeded GPU quota.")
92
+ if not os.path.exists('./checkpoint'):
93
+ os.makedirs('./checkpoint')
94
+ torch.save(model.state_dict(), "./checkpoint/model.pt")
95
  return start_idx # Trả về start_idx để lưu lại vị trí
96
 
97
  start_idx = end_idx