# train_for_fun/app.py
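
"""Gradio app that fine-tunes bert-base-uncased for sequence classification on an
uploaded CSV, training in short ZeroGPU bursts and checkpointing so a paused run
can resume from a saved sample index."""
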
import os
import time

import gradio as gr
import pandas as pd
import spaces
import torch
from spaces.zero.gradio import HTMLError
# transformers' AdamW is deprecated/removed in recent releases; use torch's.
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertForSequenceClassification, BertTokenizer
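
# HTMLError appears to be raised by the ZeroGPU runtime when a GPU allocation
# fails or the time quota runs out; it is caught in train_step below so training
# can checkpoint and report where to resume.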

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Module-level model, tokenizer, and optimizer shared across Gradio calls.
# Note: the checkpoints below save only model weights, so AdamW's moment
# estimates reset whenever training resumes from a checkpoint.
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model.to(device)
optimizer = AdamW(model.parameters(), lr=1e-5)

# Tokenized dataset, cached between calls so a paused run can resume.
global_data = None


def load_data(file):
    """Read the uploaded CSV and cache it as a TensorDataset in global_data."""
    global global_data
    df = pd.read_csv(file)
    # Tokenize the text column
    inputs = tokenizer(df['text'].tolist(), padding=True, truncation=True, return_tensors="pt")
    # The label column must be named 'label'
    labels = torch.tensor(df['label'].tolist()).long()
    global_data = TensorDataset(inputs['input_ids'], inputs['attention_mask'], labels)
    print(global_data)
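
# The uploaded CSV is expected to provide 'text' and 'label' columns, e.g.
# (hypothetical sample rows):
#
#   text,label
#   "a great movie",1
#   "a terrible plot",0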


def get_dataloader(start, end, batch_size=8):
    global global_data
    # Slice [start, end) out of the cached dataset so training can resume
    # from an arbitrary sample index.
    subset = torch.utils.data.Subset(global_data, range(start, end))
    return DataLoader(subset, batch_size=batch_size)
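
# For example, get_dataloader(0, 80, batch_size=8) yields 10 batches covering
# samples 0..79, matching the 10-batches-per-loop chunking in train_step.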


@spaces.GPU(duration=20)
def train_batch(dataloader):
    model.train()
    start_time = time.time()
    for step, batch in enumerate(dataloader):
        input_ids, attention_mask, labels = batch
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        # Checkpoint and stop well before the 20 s ZeroGPU window closes.
        elapsed_time = time.time() - start_time
        if elapsed_time > 10:
            print('Save checkpoint')
            os.makedirs('./checkpoint', exist_ok=True)
            torch.save(model.state_dict(), "./checkpoint/model.pt")
            return False, "Checkpoint saved. Training paused."
    return True, "Batch training completed."


def train_step(file=None, start_idx=0):
    if file:
        load_data(file)
    print(global_data)
    if global_data is None:
        return "Please upload a CSV file first.", None
    start_idx = int(start_idx) if start_idx else 0

    # Reload the checkpoint if one exists
    if os.path.exists("./checkpoint/model.pt"):
        print("Loading checkpoint...")
        model.load_state_dict(torch.load("./checkpoint/model.pt", map_location=device))
    else:
        print("Checkpoint not found, starting fresh...")
        os.makedirs('./checkpoint', exist_ok=True)
        torch.save(model.state_dict(), "./checkpoint/model.pt")

    batch_size = 8
    total_samples = len(global_data)
    counting = 0
    while start_idx < total_samples:
        print("Step:", counting)
        print("Percent:", start_idx / total_samples * 100, "%")
        counting += 1
        end_idx = min(start_idx + (batch_size * 10), total_samples)  # 10 batches per loop
        dataloader = get_dataloader(start_idx, end_idx, batch_size)
        try:
            success, message = train_batch(dataloader)
            if not success:
                # Return start_idx so the caller knows where to resume
                return start_idx, "./checkpoint/model.pt"
        except HTMLError as e:
            # GPU quota exhausted: checkpoint and return the resume position
            print(e)
            os.makedirs('./checkpoint', exist_ok=True)
            print('Save checkpoint')
            torch.save(model.state_dict(), "./checkpoint/model.pt")
            return start_idx, "./checkpoint/model.pt"
        start_idx = end_idx

    os.makedirs('./checkpoint', exist_ok=True)
    torch.save(model.state_dict(), "./checkpoint/model.pt")
    return start_idx, "./checkpoint/model.pt"
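
# A minimal resume sketch (hypothetical index): if a run paused and returned
# start_idx=160, a later call within the same session can continue without
# re-uploading, since the tokenized dataset is still cached in global_data:
#
#   next_idx, ckpt_path = train_step(file=None, start_idx=160)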


if __name__ == "__main__":
    iface = gr.Interface(
        fn=train_step,
        inputs=[gr.File(label="Upload CSV"), gr.Textbox(label="Start index", value="0")],
        outputs=["text", gr.File(label="Checkpoint")],
    )
    iface.launch()
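
# Assumed environment: a Hugging Face Space with `torch`, `transformers`,
# `pandas`, `gradio`, and the ZeroGPU `spaces` package installed.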