ShynBui commited on
Commit
a913549
1 Parent(s): 4f12561

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -14
app.py CHANGED
@@ -7,6 +7,7 @@ import gradio as gr
7
  import pandas as pd
8
  import os
9
  import spaces
 
10
 
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
  print(device)
@@ -23,7 +24,7 @@ def load_data(file):
23
  global global_data
24
  df = pd.read_csv(file)
25
  inputs = tokenizer(df['text'].tolist(), padding=True, truncation=True, return_tensors="pt") # Mã hóa văn bản
26
- labels = torch.tensor(df['lable'].tolist()).long() #
27
  global_data = TensorDataset(inputs['input_ids'], inputs['attention_mask'], labels)
28
 
29
  print(global_data)
@@ -33,7 +34,7 @@ def get_dataloader(start, end, batch_size=8):
33
  subset = torch.utils.data.Subset(global_data, range(start, end))
34
  return DataLoader(subset, batch_size=batch_size)
35
 
36
- @spaces.GPU(duration=120)
37
  def train_batch(dataloader):
38
  model.train()
39
  start_time = time.time()
@@ -49,41 +50,45 @@ def train_batch(dataloader):
49
  optimizer.step()
50
 
51
  elapsed_time = time.time() - start_time
52
- if elapsed_time > 10: # Dừng trước 60 giây để lưu checkpoint
53
- print("save checkpoint")
54
  torch.save(model.state_dict(), "./checkpoint/model.pt")
55
  return False, "Checkpoint saved. Training paused."
56
 
57
  return True, "Batch training completed."
58
 
59
-
60
  def train_step(file=None):
61
  if file:
62
  load_data(file)
 
63
 
64
  start_idx = 0
65
  batch_size = 8
66
  total_samples = len(global_data)
67
 
 
68
  while start_idx < total_samples:
69
- print(start_idx)
70
- end_idx = min(start_idx + (batch_size * 10), total_samples) # Chia nhỏ dữ liệu để xử lý nhanh
 
 
71
  dataloader = get_dataloader(start_idx, end_idx, batch_size)
72
 
73
- start_time = time.time()
74
- success, message = train_batch(dataloader)
75
- elapsed_time = time.time() - start_time
 
76
 
77
- if elapsed_time >= 10: # Kết thúc trước khi hết 60 giây để lưu checkpoint
78
- torch.save(model.state_dict(), "./checkpoint/model.pt")
79
- return f"{message}. Training paused after {elapsed_time:.2f}s."
 
80
 
81
  start_idx = end_idx
 
82
 
83
  torch.save(model.state_dict(), "./checkpoint/model.pt")
84
  return "Training completed and model saved."
85
 
86
-
87
  if __name__ == "__main__":
88
  iface = gr.Interface(
89
  fn=train_step,
 
7
  import pandas as pd
8
  import os
9
  import spaces
10
+ from spaces.zero.gradio import HTMLError
11
 
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
  print(device)
 
24
  global global_data
25
  df = pd.read_csv(file)
26
  inputs = tokenizer(df['text'].tolist(), padding=True, truncation=True, return_tensors="pt") # Mã hóa văn bản
27
+ labels = torch.tensor(df['label'].tolist()).long() # Đảm bảo tên cột là 'label'
28
  global_data = TensorDataset(inputs['input_ids'], inputs['attention_mask'], labels)
29
 
30
  print(global_data)
 
34
  subset = torch.utils.data.Subset(global_data, range(start, end))
35
  return DataLoader(subset, batch_size=batch_size)
36
 
37
+ @spaces.GPU(duration=5)
38
  def train_batch(dataloader):
39
  model.train()
40
  start_time = time.time()
 
50
  optimizer.step()
51
 
52
  elapsed_time = time.time() - start_time
53
+ if elapsed_time > 50: # Dừng trước 59 giây để đảm bảo không vượt hạn ngạch
 
54
  torch.save(model.state_dict(), "./checkpoint/model.pt")
55
  return False, "Checkpoint saved. Training paused."
56
 
57
  return True, "Batch training completed."
58
 
 
59
  def train_step(file=None):
60
  if file:
61
  load_data(file)
62
+ print(global_data)
63
 
64
  start_idx = 0
65
  batch_size = 8
66
  total_samples = len(global_data)
67
 
68
+ counting = 0
69
  while start_idx < total_samples:
70
+ print("Step:", counting)
71
+ print("Percent:", total_samples/start_idx * 100, "%")
72
+ counting += 1
73
+ end_idx = min(start_idx + (batch_size * 10), total_samples) # 10 batches per loop
74
  dataloader = get_dataloader(start_idx, end_idx, batch_size)
75
 
76
+ try:
77
+ success, message = train_batch(dataloader)
78
+ if not success:
79
+ return message
80
 
81
+ except HTMLError as e:
82
+ print("Exceeded GPU quota, retrying in 10 seconds...")
83
+ time.sleep(10)
84
+ continue
85
 
86
  start_idx = end_idx
87
+ time.sleep(2) # Nghỉ 2 giây giữa các phiên huấn luyện
88
 
89
  torch.save(model.state_dict(), "./checkpoint/model.pt")
90
  return "Training completed and model saved."
91
 
 
92
  if __name__ == "__main__":
93
  iface = gr.Interface(
94
  fn=train_step,