# train_for_fun / app.py
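"""Gradio app that fine-tunes bert-base-uncased for sequence classification on an
uploaded CSV, checkpointing periodically so training fits within the Spaces GPU
time allotment."""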
import time
import torch
from torch.optim import AdamW  # transformers.AdamW is deprecated/removed; use torch's
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertForSequenceClassification, BertTokenizer
import gradio as gr
import pandas as pd
import os
import spaces

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')  # defaults to 2 labels
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model.to(device)
optimizer = AdamW(model.parameters(), lr=1e-5)

os.makedirs("./checkpoint", exist_ok=True)  # ensure the checkpoint directory exists before saving

global_data = None  # populated by load_data() once a CSV is uploaded
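# Resume from an earlier checkpoint if one exists (a hypothetical convenience,
# assuming the same save path that train_step uses below):
if os.path.exists("./checkpoint/model.pt"):
    model.load_state_dict(torch.load("./checkpoint/model.pt", map_location=device))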
def load_data(file):
    global global_data
    df = pd.read_csv(file)
    inputs = tokenizer(df['text'].tolist(), padding=True, truncation=True, return_tensors="pt")  # Tokenize the text
    labels = torch.tensor(df['label'].tolist()).long()  # expects an integer 'label' column
    global_data = TensorDataset(inputs['input_ids'], inputs['attention_mask'], labels)
    print(f"Loaded {len(global_data)} samples")
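# A minimal CSV the loader above would accept (illustrative values only):
#   text,label
#   "loved every minute",1
#   "would not recommend",0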
def get_dataloader(start, end, batch_size=8):
    global global_data
    subset = torch.utils.data.Subset(global_data, range(start, end))
    return DataLoader(subset, batch_size=batch_size)
@spaces.GPU(duration=120)
def train_batch(dataloader):
    model.train()
    start_time = time.time()
    for step, batch in enumerate(dataloader):
        input_ids, attention_mask, labels = batch
        input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        elapsed_time = time.time() - start_time
        if elapsed_time > 10:  # pause well before the GPU allotment expires so a checkpoint can be saved
            print("save checkpoint")
            torch.save(model.state_dict(), "./checkpoint/model.pt")
            return False, "Checkpoint saved. Training paused."
    return True, "Batch training completed."
def train_step(file=None):
    if file:
        load_data(file)
    start_idx = 0
    batch_size = 8
    total_samples = len(global_data)
    while start_idx < total_samples:
        print(start_idx)
        end_idx = min(start_idx + (batch_size * 10), total_samples)  # process the data in small chunks
        dataloader = get_dataloader(start_idx, end_idx, batch_size)
        start_time = time.time()
        success, message = train_batch(dataloader)
        elapsed_time = time.time() - start_time
        if not success:  # train_batch already saved a checkpoint and paused
            return f"{message} Training paused after {elapsed_time:.2f}s."
        start_idx = end_idx
    torch.save(model.state_dict(), "./checkpoint/model.pt")
    return "Training completed and model saved."
if __name__ == "__main__":
    iface = gr.Interface(
        fn=train_step,
        inputs=gr.File(label="Upload CSV"),
        outputs="text"
    )
    iface.launch()