import os
import time

import gradio as gr
import pandas as pd
import spaces
import torch
from spaces.zero.gradio import HTMLError
from torch.optim import AdamW
from torch.utils.data import DataLoader, Subset, TensorDataset
from transformers import BertForSequenceClassification, BertTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model.to(device)

# The AdamW class exported by `transformers` is deprecated in favour of
# torch.optim.AdamW.  eps and weight_decay are pinned to the values the
# transformers implementation used by default so that the optimisation
# behaviour stays as close as possible to what this script had before.
optimizer = AdamW(model.parameters(), lr=1e-5, eps=1e-6, weight_decay=0.0)

# Dataset shared between load_data() and get_dataloader(); None until a CSV
# has been loaded.
global_data = None


def load_data(file):
    """Read a CSV file and store it as a tokenized dataset in ``global_data``.

    The CSV must provide a ``text`` column and a ``label`` column.  Every row
    is tokenized with the module-level tokenizer (padded and truncated), and
    the result is stored as a ``TensorDataset`` of
    ``(input_ids, attention_mask, label)``.

    Args:
        file: Anything ``pandas.read_csv`` accepts (path or file-like object).
    """
    global global_data
    df = pd.read_csv(file)
    # Encode the text.
    inputs = tokenizer(
        df['text'].tolist(),
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    # Make sure the label column is named 'label'.
    labels = torch.tensor(df['label'].tolist()).long()
    global_data = TensorDataset(
        inputs['input_ids'], inputs['attention_mask'], labels
    )
    print(global_data)


def get_dataloader(start, end, batch_size=8):
    """Return a ``DataLoader`` over samples ``[start, end)`` of ``global_data``.

    Args:
        start: Index of the first sample to include.
        end: Index one past the last sample to include.
        batch_size: Number of samples per batch.

    Raises:
        RuntimeError: If no dataset has been loaded yet.
    """
    if global_data is None:
        # Fail here with a clear message instead of later, while iterating.
        raise RuntimeError("No dataset loaded; call load_data() first.")
    subset = Subset(global_data, range(start, end))
    return DataLoader(subset, batch_size=batch_size)


@spaces.GPU(duration=20)
def train_batch(dataloader):
    """Train the model on every batch of *dataloader*, within a time budget.

    After each optimizer step the elapsed wall-clock time is checked.  Once
    more than 10 seconds have passed, the model weights are written to
    ``./checkpoint/model.pt`` and training stops early.
    NOTE(review): the 10 s budget presumably exists to stay below the
    ``duration=20`` given to ``spaces.GPU`` -- confirm.

    Args:
        dataloader: Iterable yielding ``(input_ids, attention_mask, labels)``.

    Returns:
        A ``(finished, message)`` tuple.  ``finished`` is ``False`` when the
        time budget was exceeded and a checkpoint was saved, ``True`` when
        every batch was processed.
    """
    model.train()
    start_time = time.time()

    for input_ids, attention_mask, labels in dataloader:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        outputs.loss.backward()
        optimizer.step()

        elapsed_time = time.time() - start_time
        if elapsed_time > 10:
            print('Save checkpoint')
            # exist_ok avoids the check-then-create race of os.path.exists.
            os.makedirs('./checkpoint', exist_ok=True)
            torch.save(model.state_dict(), "./checkpoint/model.pt")
            return False, "Checkpoint saved. Training paused."

    return True, "Batch training completed."
def _coerce_start_idx(value):
    """Convert the UI-supplied start index to a non-negative ``int``."""
    if value is None:
        return 0
    if isinstance(value, (int, float)):
        return max(int(value), 0)
    text = str(value).strip()
    # An untouched Gradio textbox sends an empty string; treat it as 0.
    if not text:
        return 0
    # Negative values would make Subset index from the end of the dataset.
    return max(int(text), 0)


def _save_checkpoint(path="./checkpoint/model.pt"):
    """Save the model weights to *path*, creating its directory if needed."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(model.state_dict(), path)
    return path


def train_step(file=None, start_idx=0):
    """Train on the loaded dataset from *start_idx*, checkpointing as it goes.

    The dataset is processed in chunks of 10 batches (8 samples per batch).
    Training stops early when ``train_batch`` reports that it paused, or when
    an ``HTMLError`` is raised; in both cases the index of the chunk that was
    being processed is returned so a later call can resume from it.
    NOTE(review): on a pause the chunk start is returned, so batches of that
    chunk that had already been trained will be trained again on resume.

    Args:
        file: Optional CSV to load before training.  When omitted, the
            dataset loaded by a previous call is reused.
        start_idx: Sample index to start from.  Accepts an int or a numeric
            string; an empty string or ``None`` means 0.

    Returns:
        A ``(next_start_idx, checkpoint_path)`` tuple.

    Raises:
        gr.Error: If no dataset is available.
        ValueError: If *start_idx* is a non-empty, non-numeric string.
    """
    checkpoint_path = "./checkpoint/model.pt"

    if file:
        load_data(file)
    print(global_data)

    if global_data is None:
        raise gr.Error("No dataset loaded. Please upload a CSV file first.")

    start_idx = _coerce_start_idx(start_idx)

    # Reload the checkpoint if one exists.
    if os.path.exists(checkpoint_path):
        print("Loading checkpoint...")
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    else:
        print("Checkpoint not found, starting fresh...")
        _save_checkpoint(checkpoint_path)

    batch_size = 8
    total_samples = len(global_data)
    counting = 0

    while start_idx < total_samples:
        print("Step:", counting)
        print("Percent:", (start_idx) / total_samples * 100, "%")
        counting += 1

        # 10 batches per loop iteration.
        end_idx = min(start_idx + (batch_size * 10), total_samples)
        dataloader = get_dataloader(start_idx, end_idx, batch_size)

        try:
            success, message = train_batch(dataloader)
        except HTMLError as e:
            print(e)
            print('Save checkpoint')
            _save_checkpoint(checkpoint_path)
            # Return start_idx so the caller can record the position.
            return start_idx, checkpoint_path

        if not success:
            # train_batch paused; hand back the position to resume from.
            return start_idx, checkpoint_path

        start_idx = end_idx

    _save_checkpoint(checkpoint_path)
    return start_idx, checkpoint_path


if __name__ == "__main__":
    iface = gr.Interface(
        fn=train_step,
        inputs=[gr.File(label="Upload CSV"), gr.Textbox()],
        outputs=["text", gr.File()],
    )
    iface.launch()