import os
import time

import gradio as gr
import pandas as pd
import spaces
import torch
from torch.optim import AdamW  # transformers.AdamW is deprecated; use the torch implementation
from torch.utils.data import DataLoader, Subset, TensorDataset
from transformers import BertForSequenceClassification, BertTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model.to(device)
optimizer = AdamW(model.parameters(), lr=1e-5)

os.makedirs("./checkpoint", exist_ok=True)  # torch.save fails if the directory does not exist

global_data = None


def load_data(file):
    """Read the uploaded CSV and build a TensorDataset.

    Expects a 'text' column with raw strings and a 'label' column
    with integer class ids.
    """
    global global_data
    df = pd.read_csv(file)
    inputs = tokenizer(df['text'].tolist(), padding=True, truncation=True, return_tensors="pt")  # tokenize the text
    labels = torch.tensor(df['label'].tolist()).long()
    global_data = TensorDataset(inputs['input_ids'], inputs['attention_mask'], labels)
    print(f"Loaded {len(global_data)} samples.")


def get_dataloader(start, end, batch_size=8):
    subset = Subset(global_data, range(start, end))
    return DataLoader(subset, batch_size=batch_size)


@spaces.GPU(duration=120)
def train_batch(dataloader):
    model.train()
    start_time = time.time()
    for step, batch in enumerate(dataloader):
        input_ids, attention_mask, labels = batch
        input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        elapsed_time = time.time() - start_time
        if elapsed_time > 10:  # pause well before the GPU time window expires and save a checkpoint
            print("saving checkpoint")
            torch.save(model.state_dict(), "./checkpoint/model.pt")
            return False, "Checkpoint saved. Training paused."
    return True, "Batch training completed."


def train_step(file=None):
    if file:
        load_data(file)
    if global_data is None:
        return "No data loaded. Please upload a CSV file first."
    start_idx = 0
    batch_size = 8
    total_samples = len(global_data)
    while start_idx < total_samples:
        print(start_idx)
        end_idx = min(start_idx + (batch_size * 10), total_samples)  # process the data in small chunks
        dataloader = get_dataloader(start_idx, end_idx, batch_size)
        start_time = time.time()
        success, message = train_batch(dataloader)
        elapsed_time = time.time() - start_time
        if not success:  # train_batch hit the time limit and already saved a checkpoint
            return f"{message} Training paused after {elapsed_time:.2f}s."
        start_idx = end_idx
    torch.save(model.state_dict(), "./checkpoint/model.pt")
    return "Training completed and model saved."


if __name__ == "__main__":
    iface = gr.Interface(
        fn=train_step,
        inputs=gr.File(label="Upload CSV"),
        outputs="text"
    )
    iface.launch()
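

# --- Resuming from a saved checkpoint (sketch) ---
# The script writes ./checkpoint/model.pt but never reloads it, so a restarted
# Space always begins from the pretrained weights. Below is a minimal, hedged
# sketch of the restore step; `load_checkpoint` is an illustrative helper, not
# part of the original script. To take effect at startup it would need to be
# called right after the optimizer is created, above the Gradio launch.
def load_checkpoint(path="./checkpoint/model.pt"):
    """Restore weights written by train_batch()/train_step(), if present."""
    if os.path.exists(path):
        model.load_state_dict(torch.load(path, map_location=device))
        return f"Resumed from {path}."
    return "No checkpoint found; using pretrained weights."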