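"""Fine-tune BERT for sequence classification on a Hugging Face ZeroGPU Space.

Training is split into short windows so each fits inside one @spaces.GPU
grant; progress is checkpointed to ./checkpoint/model.pt, and the UI returns
the dataset index to resume from.
"""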
import os
import time

import gradio as gr
import pandas as pd
import torch
from torch.optim import AdamW  # transformers.AdamW is deprecated; use torch's implementation
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertForSequenceClassification, BertTokenizer

import spaces
from spaces.zero.gradio import HTMLError
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model.to(device)
optimizer = AdamW(model.parameters(), lr=1e-5)

# Module-level holder for the uploaded dataset; Gradio handlers share it.
global_data = None

def load_data(file):
    """Read a CSV with 'text' and 'label' columns into a TensorDataset."""
    global global_data
    df = pd.read_csv(file)
    inputs = tokenizer(df['text'].tolist(), padding=True, truncation=True, return_tensors="pt")  # tokenize the text
    labels = torch.tensor(df['label'].tolist()).long()  # the label column must be named 'label'
    global_data = TensorDataset(inputs['input_ids'], inputs['attention_mask'], labels)
    print(global_data)
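
# Expected CSV layout (column names taken from load_data; rows illustrative only):
#   text,label
#   "a sentence to classify",0
#   "another sentence",1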

def get_dataloader(start, end, batch_size=8):
    """Wrap rows [start, end) of the global dataset in a DataLoader."""
    subset = torch.utils.data.Subset(global_data, range(start, end))
    return DataLoader(subset, batch_size=batch_size)
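
# e.g. get_dataloader(0, 80, batch_size=8) yields 10 batches covering rows 0..79.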

@spaces.GPU(duration=20)  # request a 20-second ZeroGPU grant for this call
def train_batch(dataloader):
    model.train()
    start_time = time.time()
    for step, batch in enumerate(dataloader):
        input_ids, attention_mask, labels = (t.to(device) for t in batch)
        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        elapsed_time = time.time() - start_time
        if elapsed_time > 10:  # stop well before the 20 s grant expires
            print('Save checkpoint')
            os.makedirs('./checkpoint', exist_ok=True)
            torch.save(model.state_dict(), "./checkpoint/model.pt")
            return False, "Checkpoint saved. Training paused."
    return True, "Batch training completed."
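
# When train_batch returns False, train_step hands back the window's start_idx,
# so a resumed run replays any batches already finished inside that window.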

def train_step(file=None, start_idx=0):
    if file:
        load_data(file)
    print(global_data)
    start_idx = int(start_idx)
    # Reload the checkpoint if one exists
    if os.path.exists("./checkpoint/model.pt"):
        print("Loading checkpoint...")
        model.load_state_dict(torch.load("./checkpoint/model.pt"))
    else:
        print("Checkpoint not found, starting fresh...")
        os.makedirs('./checkpoint', exist_ok=True)
        torch.save(model.state_dict(), "./checkpoint/model.pt")
    batch_size = 8
    total_samples = len(global_data)
    counting = 0
    while start_idx < total_samples:
        print("Step:", counting)
        print("Percent:", start_idx / total_samples * 100, "%")
        counting += 1
        end_idx = min(start_idx + (batch_size * 10), total_samples)  # 10 batches per window
        dataloader = get_dataloader(start_idx, end_idx, batch_size)
        try:
            success, message = train_batch(dataloader)
            if not success:
                return start_idx, "./checkpoint/model.pt"  # return start_idx so the run can resume here
        except HTMLError as e:
            # Raised when the ZeroGPU quota is exhausted; save and hand back the resume position
            print(e)
            os.makedirs('./checkpoint', exist_ok=True)
            print('Save checkpoint')
            torch.save(model.state_dict(), "./checkpoint/model.pt")
            return start_idx, "./checkpoint/model.pt"
        start_idx = end_idx
    os.makedirs('./checkpoint', exist_ok=True)
    torch.save(model.state_dict(), "./checkpoint/model.pt")
    return start_idx, "./checkpoint/model.pt"
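
# To resume a paused run, submit again with the returned start_idx in the
# start-index textbox; the checkpoint on disk is reloaded automatically.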

if __name__ == "__main__":
    iface = gr.Interface(
        fn=train_step,
        inputs=[gr.File(label="Upload CSV"), gr.Textbox(label="Start index", value="0")],
        outputs=["text", gr.File()],
    )
    iface.launch()