import os
import time

import gradio as gr
import pandas as pd
import spaces
import torch
from torch.optim import AdamW  # transformers no longer exports AdamW; use the torch optimizer
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertForSequenceClassification, BertTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# bert-base-uncased with a fresh sequence-classification head (2 labels by default)
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model.to(device)

optimizer = AdamW(model.parameters(), lr=1e-5)
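
# Optionally resume weights saved by an interrupted run (a minimal sketch added
# for illustration, not part of the original flow; the path mirrors the
# torch.save calls below).
if os.path.exists("./checkpoint/model.pt"):
    model.load_state_dict(torch.load("./checkpoint/model.pt", map_location=device))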

global_data = None
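
# load_data expects a CSV shaped roughly like this (an assumption inferred from
# the column names used in the function below):
#
#   text,label
#   "great product, works as described",1
#   "arrived broken and late",0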

def load_data(file):
    global global_data
    df = pd.read_csv(file)
    # Tokenize the text column into padded/truncated tensors
    inputs = tokenizer(df['text'].tolist(), padding=True, truncation=True, return_tensors="pt")
    labels = torch.tensor(df['label'].tolist()).long()
    global_data = TensorDataset(inputs['input_ids'], inputs['attention_mask'], labels)

    print(global_data)

def get_dataloader(start, end, batch_size=8):
    global global_data
    # Serve only the [start, end) slice so training can run in short bursts
    subset = torch.utils.data.Subset(global_data, range(start, end))
    return DataLoader(subset, batch_size=batch_size)

@spaces.GPU
def train_batch(dataloader):
    model.train()
    start_time = time.time()

    for step, batch in enumerate(dataloader):
        input_ids, attention_mask, labels = batch
        input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        elapsed_time = time.time() - start_time
        if elapsed_time > 55:  # stop before the 60-second GPU allocation expires and checkpoint
            os.makedirs("./checkpoint", exist_ok=True)
            torch.save(model.state_dict(), "./checkpoint/model.pt")
            return False, "Checkpoint saved. Training paused."

    return True, "Batch training completed."


def train_step(file=None):
    if file:
        load_data(file)
    print(global_data)

    start_idx = 0
    batch_size = 8
    total_samples = len(global_data)

    counting = 0
    while start_idx < total_samples:
        print("Step:", counting)
        counting += 1
        end_idx = min(start_idx + (batch_size * 10), total_samples)  # 10 batches per loop
        dataloader = get_dataloader(start_idx, end_idx, batch_size)

        success, message = train_batch(dataloader)
        if not success:
            return message

        start_idx = end_idx
        time.sleep(5)  # rest 5 seconds between training runs

    os.makedirs("./checkpoint", exist_ok=True)
    torch.save(model.state_dict(), "./checkpoint/model.pt")
    return "Training completed and model saved."


if __name__ == "__main__":
    iface = gr.Interface(
        fn=train_step,
        inputs=gr.File(label="Upload CSV"),
        outputs="text"
    )
    iface.launch()