Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -20,115 +20,6 @@ optimizer = AdamW(model.parameters(), lr=1e-5)
|
|
20 |
|
21 |
global_data = None
|
22 |
|
23 |
-
def load_data(file):
|
24 |
-
global global_data
|
25 |
-
df = pd.read_csv(file)
|
26 |
-
inputs = tokenizer(df['text'].tolist(), padding=True, truncation=True, return_tensors="pt") # Mã hóa văn bản
|
27 |
-
labels = torch.tensor(df['label'].tolist()).long() # Đảm bảo tên cột là 'label'
|
28 |
-
global_data = TensorDataset(inputs['input_ids'], inputs['attention_mask'], labels)
|
29 |
-
|
30 |
-
print(global_data)
|
31 |
-
|
32 |
-
def get_dataloader(start, end, batch_size=8):
|
33 |
-
global global_data
|
34 |
-
subset = torch.utils.data.Subset(global_data, range(start, end))
|
35 |
-
return DataLoader(subset, batch_size=batch_size)
|
36 |
-
|
37 |
-
@spaces.GPU(duration=20)
|
38 |
-
def train_batch(dataloader):
|
39 |
-
model.train()
|
40 |
-
start_time = time.time()
|
41 |
-
|
42 |
-
for step, batch in enumerate(dataloader):
|
43 |
-
input_ids, attention_mask, labels = batch
|
44 |
-
input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
|
45 |
-
|
46 |
-
optimizer.zero_grad()
|
47 |
-
outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
|
48 |
-
loss = outputs.loss
|
49 |
-
loss.backward()
|
50 |
-
optimizer.step()
|
51 |
-
|
52 |
-
elapsed_time = time.time() - start_time
|
53 |
-
if elapsed_time > 10:
|
54 |
-
print('Save checkpoint')
|
55 |
-
if not os.path.exists('./checkpoint'):
|
56 |
-
os.makedirs('./checkpoint')
|
57 |
-
torch.save(model.state_dict(), "./checkpoint/model.pt")
|
58 |
-
|
59 |
-
return False, "Checkpoint saved. Training paused."
|
60 |
-
|
61 |
-
return True, "Batch training completed."
|
62 |
-
|
63 |
-
Hugging Face's logo
|
64 |
-
Hugging Face
|
65 |
-
Search models, datasets, users...
|
66 |
-
Models
|
67 |
-
Datasets
|
68 |
-
Spaces
|
69 |
-
Posts
|
70 |
-
Docs
|
71 |
-
Solutions
|
72 |
-
Pricing
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
Hugging Face is way more fun with friends and colleagues! 🤗 Join an organization
|
77 |
-
Spaces:
|
78 |
-
|
79 |
-
ShynBui
|
80 |
-
/
|
81 |
-
train_for_fun
|
82 |
-
|
83 |
-
private
|
84 |
-
|
85 |
-
Logs
|
86 |
-
App
|
87 |
-
Files
|
88 |
-
Community
|
89 |
-
Settings
|
90 |
-
train_for_fun
|
91 |
-
/
|
92 |
-
app.py
|
93 |
-
|
94 |
-
ShynBui's picture
|
95 |
-
ShynBui
|
96 |
-
Update app.py
|
97 |
-
07a2715
|
98 |
-
verified
|
99 |
-
15 minutes ago
|
100 |
-
raw
|
101 |
-
|
102 |
-
Copy download link
|
103 |
-
history
|
104 |
-
blame
|
105 |
-
edit
|
106 |
-
delete
|
107 |
-
No virus
|
108 |
-
|
109 |
-
3.25 kB
|
110 |
-
import time
|
111 |
-
import torch
|
112 |
-
from transformers import BertForSequenceClassification, AdamW
|
113 |
-
from torch.utils.data import DataLoader, TensorDataset
|
114 |
-
from transformers import BertTokenizer
|
115 |
-
import gradio as gr
|
116 |
-
import pandas as pd
|
117 |
-
import os
|
118 |
-
import spaces
|
119 |
-
from spaces.zero.gradio import HTMLError
|
120 |
-
|
121 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
122 |
-
print(device)
|
123 |
-
|
124 |
-
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
|
125 |
-
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
126 |
-
model.to(device)
|
127 |
-
|
128 |
-
optimizer = AdamW(model.parameters(), lr=1e-5)
|
129 |
-
|
130 |
-
global_data = None
|
131 |
-
|
132 |
def load_data(file):
|
133 |
global global_data
|
134 |
df = pd.read_csv(file)
|
|
|
20 |
|
21 |
global_data = None
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
def load_data(file):
|
24 |
global global_data
|
25 |
df = pd.read_csv(file)
|