ShynBui commited on
Commit
abe15ad
1 Parent(s): 070d008

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -109
app.py CHANGED
@@ -20,115 +20,6 @@ optimizer = AdamW(model.parameters(), lr=1e-5)
20
 
21
  global_data = None
22
 
23
- def load_data(file):
24
- global global_data
25
- df = pd.read_csv(file)
26
- inputs = tokenizer(df['text'].tolist(), padding=True, truncation=True, return_tensors="pt") # Mã hóa văn bản
27
- labels = torch.tensor(df['label'].tolist()).long() # Đảm bảo tên cột là 'label'
28
- global_data = TensorDataset(inputs['input_ids'], inputs['attention_mask'], labels)
29
-
30
- print(global_data)
31
-
32
- def get_dataloader(start, end, batch_size=8):
33
- global global_data
34
- subset = torch.utils.data.Subset(global_data, range(start, end))
35
- return DataLoader(subset, batch_size=batch_size)
36
-
37
- @spaces.GPU(duration=20)
38
- def train_batch(dataloader):
39
- model.train()
40
- start_time = time.time()
41
-
42
- for step, batch in enumerate(dataloader):
43
- input_ids, attention_mask, labels = batch
44
- input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
45
-
46
- optimizer.zero_grad()
47
- outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
48
- loss = outputs.loss
49
- loss.backward()
50
- optimizer.step()
51
-
52
- elapsed_time = time.time() - start_time
53
- if elapsed_time > 10:
54
- print('Save checkpoint')
55
- if not os.path.exists('./checkpoint'):
56
- os.makedirs('./checkpoint')
57
- torch.save(model.state_dict(), "./checkpoint/model.pt")
58
-
59
- return False, "Checkpoint saved. Training paused."
60
-
61
- return True, "Batch training completed."
62
-
63
- Hugging Face's logo
64
- Hugging Face
65
- Search models, datasets, users...
66
- Models
67
- Datasets
68
- Spaces
69
- Posts
70
- Docs
71
- Solutions
72
- Pricing
73
-
74
-
75
-
76
- Hugging Face is way more fun with friends and colleagues! 🤗 Join an organization
77
- Spaces:
78
-
79
- ShynBui
80
- /
81
- train_for_fun
82
-
83
- private
84
-
85
- Logs
86
- App
87
- Files
88
- Community
89
- Settings
90
- train_for_fun
91
- /
92
- app.py
93
-
94
- ShynBui's picture
95
- ShynBui
96
- Update app.py
97
- 07a2715
98
- verified
99
- 15 minutes ago
100
- raw
101
-
102
- Copy download link
103
- history
104
- blame
105
- edit
106
- delete
107
- No virus
108
-
109
- 3.25 kB
110
- import time
111
- import torch
112
- from transformers import BertForSequenceClassification, AdamW
113
- from torch.utils.data import DataLoader, TensorDataset
114
- from transformers import BertTokenizer
115
- import gradio as gr
116
- import pandas as pd
117
- import os
118
- import spaces
119
- from spaces.zero.gradio import HTMLError
120
-
121
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
122
- print(device)
123
-
124
- model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
125
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
126
- model.to(device)
127
-
128
- optimizer = AdamW(model.parameters(), lr=1e-5)
129
-
130
- global_data = None
131
-
132
  def load_data(file):
133
  global global_data
134
  df = pd.read_csv(file)
 
20
 
21
  global_data = None
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def load_data(file):
24
  global global_data
25
  df = pd.read_csv(file)