File size: 3,999 Bytes
4f12561
 
 
 
 
 
 
 
 
a913549
4f12561
 
 
 
 
 
070d008
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52abc92
070d008
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f12561
 
a913549
070d008
 
 
 
 
7f9ed76
 
 
 
 
4f12561
 
 
 
a913549
4f12561
a913549
070d008
a913549
 
4f12561
 
a913549
 
 
ea6d602
4f12561
a913549
77542e4
83e3068
 
16908cf
 
20127c2
4f12561
 
070d008
56b1815
 
4f12561
ea6d602
070d008
4f12561
 
 
 
070d008
197324c
4f12561
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import time
import torch
from transformers import BertForSequenceClassification, AdamW
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertTokenizer
import gradio as gr
import pandas as pd
import os
import spaces
from spaces.zero.gradio import HTMLError

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Pretrained BERT encoder with a sequence-classification head. num_labels is not
# passed, so the library default applies — NOTE(review): confirm it matches the
# set of values in the CSV's 'label' column.
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model.to(device)

# Use PyTorch's AdamW instead of the deprecated transformers.AdamW. eps and
# weight_decay are pinned to transformers.AdamW's defaults (1e-6 and 0.0) so the
# optimisation matches the previous behaviour; torch's own defaults differ
# (eps=1e-8, weight_decay=0.01).
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, eps=1e-6, weight_decay=0.0)

# Training dataset built by load_data(); stays None until a CSV has been loaded.
global_data = None

def load_data(file):
    """Read a labelled CSV and store it as the module-level training dataset.

    The CSV must provide a ``text`` column (the inputs to classify) and a
    ``label`` column (the integer class of each row).  Texts are tokenized in
    one call, padded to the longest text and truncated to the tokenizer's
    maximum length.  The result is stored in ``global_data`` as a
    ``TensorDataset`` of (input_ids, attention_mask, labels).

    Args:
        file: Anything ``pd.read_csv`` accepts.  Here it is the value of the
            Gradio File input; the exact type depends on the Gradio version
            (NOTE(review): path string vs. tempfile wrapper — confirm).
    """
    global global_data
    frame = pd.read_csv(file)

    # Encode every text at once; padding=True pads to the longest sequence.
    texts = frame['text'].tolist()
    encoded = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

    # Class ids as int64, which the classification loss expects.
    targets = torch.tensor(frame['label'].tolist()).long()

    global_data = TensorDataset(
        encoded['input_ids'],
        encoded['attention_mask'],
        targets,
    )
    print(global_data)

def get_dataloader(start, end, batch_size=8):
    """Return an unshuffled DataLoader over samples ``start .. end-1`` of ``global_data``.

    Args:
        start: Index of the first sample to include.
        end: One past the index of the last sample to include.
        batch_size: Number of samples per yielded batch (the last may be smaller).
    """
    indices = range(start, end)
    window = torch.utils.data.Subset(global_data, indices)
    return DataLoader(window, batch_size=batch_size)

@spaces.GPU(duration=20)
def train_batch(dataloader, time_budget=10):
    """Train the global model on every batch of ``dataloader``, pausing if time runs out.

    Runs under ``@spaces.GPU(duration=20)``. The default 10-second budget
    presumably leaves headroom inside that allocation for the checkpoint write
    (assumed intent — keep ``time_budget`` below the decorator's duration).

    After each batch except the last, the elapsed time is checked. If it
    exceeds ``time_budget`` seconds, the model weights are saved to
    ``./checkpoint/model.pt`` and the function stops early. A chunk whose final
    batch finishes after the budget counts as completed, not paused. Otherwise
    the caller would resume from the chunk's start and train it twice.

    Args:
        dataloader: Yields ``(input_ids, attention_mask, labels)`` batches.
        time_budget: Seconds of training allowed before pausing.

    Returns:
        ``(True, message)`` if every batch was trained, or ``(False, message)``
        if training paused after saving a checkpoint.
    """
    model.train()
    start_time = time.time()
    num_batches = len(dataloader)

    for step, batch in enumerate(dataloader):
        input_ids, attention_mask, labels = (t.to(device) for t in batch)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        outputs.loss.backward()
        optimizer.step()

        # Only pause when there is still work left in this chunk.
        is_last = step == num_batches - 1
        if not is_last and time.time() - start_time > time_budget:
            print('Save checkpoint')
            os.makedirs('./checkpoint', exist_ok=True)
            torch.save(model.state_dict(), "./checkpoint/model.pt")
            return False, "Checkpoint saved. Training paused."

    return True, "Batch training completed."


_CHECKPOINT_DIR = './checkpoint'
_CHECKPOINT_PATH = "./checkpoint/model.pt"


def _save_checkpoint():
    """Write the current model weights to ``_CHECKPOINT_PATH``, creating its directory if needed."""
    os.makedirs(_CHECKPOINT_DIR, exist_ok=True)
    torch.save(model.state_dict(), _CHECKPOINT_PATH)


def _parse_start_idx(start_idx):
    """Convert the start-index input (int or text, possibly blank) to a non-negative int.

    A blank string or ``None`` means "start from the beginning" (0). Invalid or
    negative values raise ``gr.Error`` so the message is shown in the UI.
    """
    if start_idx is None:
        return 0
    if isinstance(start_idx, str):
        start_idx = start_idx.strip()
        if not start_idx:
            return 0
    try:
        value = int(start_idx)
    except (TypeError, ValueError) as err:
        raise gr.Error(f"Start index must be an integer, got {start_idx!r}.") from err
    if value < 0:
        raise gr.Error(f"Start index must be non-negative, got {value}.")
    return value


def train_step(file=None, start_idx=0):
    """Train on the loaded dataset from ``start_idx``, resuming from any saved checkpoint.

    Data is consumed in chunks of 10 batches of 8 samples. Each chunk is passed
    to ``train_batch``. When ``train_batch`` pauses, or ZeroGPU raises
    ``HTMLError``, the function returns early. The returned index can be fed
    back in to continue.

    NOTE(review): on an early return the index is the *start* of the
    interrupted chunk, but the checkpoint already includes the batches trained
    within it, so those batches are trained again on resume.

    Args:
        file: Optional CSV to (re)load via ``load_data``. If omitted, the
            previously loaded dataset is reused.
        start_idx: Index of the first sample to train on. Accepts an int or
            the text from the Gradio Textbox; blank means 0.

    Returns:
        ``(next_start_idx, checkpoint_path)``. ``next_start_idx`` equals the
        dataset length when training finished.

    Raises:
        gr.Error: if no dataset has been loaded or ``start_idx`` is invalid.
    """
    if file:
        load_data(file)
    print(global_data)
    if global_data is None:
        raise gr.Error("No training data loaded: upload a CSV with 'text' and 'label' columns.")
    start_idx = _parse_start_idx(start_idx)

    # Reload the checkpoint if it exists.
    if os.path.exists(_CHECKPOINT_PATH):
        print("Loading checkpoint...")
        # The checkpoint may have been written from CUDA tensors; loading onto the
        # CPU first works everywhere, and load_state_dict then copies the values
        # onto the model's own device.
        model.load_state_dict(torch.load(_CHECKPOINT_PATH, map_location="cpu"))
    else:
        print("Checkpoint not found, starting fresh...")
        _save_checkpoint()

    batch_size = 8
    chunk_size = batch_size * 10  # 10 batches per train_batch call
    total_samples = len(global_data)

    counting = 0
    while start_idx < total_samples:
        print("Step:", counting)
        print("Percent:", start_idx / total_samples * 100, "%")
        counting += 1
        end_idx = min(start_idx + chunk_size, total_samples)
        dataloader = get_dataloader(start_idx, end_idx, batch_size)

        try:
            success, _message = train_batch(dataloader)
        except HTMLError as e:
            # ZeroGPU interrupted the call: persist progress and report where to resume.
            print(e)
            print('Save checkpoint')
            _save_checkpoint()
            return start_idx, _CHECKPOINT_PATH
        if not success:
            # train_batch already saved a checkpoint before pausing; return
            # start_idx so the caller can resume from this position.
            return start_idx, _CHECKPOINT_PATH

        start_idx = end_idx

    _save_checkpoint()
    return start_idx, _CHECKPOINT_PATH


if __name__ == "__main__":
    # The start index defaults to "0" so a fresh run works without typing anything
    # (a blank box would otherwise reach int("") in train_step). After a pause,
    # enter the returned index here to resume.
    iface = gr.Interface(
        fn=train_step,
        inputs=[
            gr.File(label="Upload CSV"),
            gr.Textbox(value="0", label="Start index"),
        ],
        outputs=[
            gr.Textbox(label="Next start index"),
            gr.File(label="Checkpoint"),
        ],
    )
    iface.launch()