Commit 320a2ba by Joshua Lochner
Parent(s): 3af0cd0
Change to multiclass classifier

src/train.py CHANGED (+32 -22)

@@ -1,7 +1,7 @@
 from preprocess import load_datasets, DatasetArguments
-from predict import ClassifierArguments,
+from predict import ClassifierArguments, SEGMENT_MATCH_RE, CATEGORIES
 from shared import CustomTokens, device, GeneralArguments, OutputArguments
-from model import ModelArguments
+from model import ModelArguments
 import transformers
 import logging
 import os
@@ -14,15 +14,17 @@ from transformers import (
     DataCollatorForSeq2Seq,
     HfArgumentParser,
     Seq2SeqTrainer,
-    Seq2SeqTrainingArguments
+    Seq2SeqTrainingArguments,
+    AutoTokenizer,
+    AutoModelForSeq2SeqLM
 )
+
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
 from sklearn.linear_model import LogisticRegression
 from sklearn.feature_extraction.text import TfidfVectorizer
 from utils import re_findall
-import re
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
 check_min_version('4.13.0.dev0')
@@ -256,8 +258,9 @@ def main():
 
            ngram_range=(1, 2), # best so far
            # max_features=8000 # remove for higher accuracy?
-
-           max_features=10000
+           max_features=20000
+           # max_features=10000
+           # max_features=1000
        )
 
        train_test_data = {
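
For context on the hunk above: the TF-IDF vocabulary cap moves from 10,000 to 20,000 features, while the unigram-plus-bigram range marked "best so far" stays. A self-contained sketch of that vectorizer configuration, using a toy corpus rather than the project's transcript data:

from sklearn.feature_extraction.text import TfidfVectorizer

# Toy stand-ins for the extracted transcript segments
corpus = [
    'thanks to our sponsor for supporting this video',
    'do not forget to like and subscribe',
    'this video is brought to you by an online course platform',
]

# Same settings the commit lands on: unigrams + bigrams, capped vocabulary
vectorizer = TfidfVectorizer(ngram_range=(1, 2), max_features=20000)
X = vectorizer.fit_transform(corpus)

print(X.shape)  # sparse matrix: one row per document, one column per kept n-gram
print(vectorizer.get_feature_names_out()[:5])  # a few of the learned n-grams

A larger max_features keeps more rare bigrams at the cost of a wider (but still sparse) feature matrix, which matches the accuracy-versus-size trade-off the commented-out values record.
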
@@ -276,17 +279,17 @@
            dataset = raw_datasets[ds_type]
 
            for row in dataset:
-
-               matches
-
-
+               matches = re_findall(SEGMENT_MATCH_RE, row['extracted'])
+               if matches:
+                   for match in matches:
+                       train_test_data[ds_type]['X'].append(match['text'])
 
-
-
+                       class_index = CATEGORIES.index(match['category'])
+                       train_test_data[ds_type]['y'].append(class_index)
 
-
-               train_test_data[ds_type]['X'].append(
-               train_test_data[ds_type]['y'].append(
+               else:
+                   train_test_data[ds_type]['X'].append(row['text'])
+                   train_test_data[ds_type]['y'].append(0)
 
        print('Fitting')
        _X_train = vectorizer.fit_transform(train_test_data['train']['X'])
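
This hunk is the heart of the multiclass change: each segment matched in row['extracted'] contributes its category's index as the label, and rows with no match become class 0 ("no segment"). SEGMENT_MATCH_RE, CATEGORIES, and re_findall come from the repo's predict.py and utils.py; the sketch below reproduces the same labelling logic with stand-in definitions (the regex and category list here are illustrative only):

import re

# Stand-ins for the repo's CATEGORIES and SEGMENT_MATCH_RE (predict.py);
# index 0 is the implicit "no segment" class.
CATEGORIES = [None, 'sponsor', 'selfpromo', 'interaction']
SEGMENT_MATCH_RE = r'BEGIN_(?P<category>\w+)\s+(?P<text>.*?)\s+END'

def re_findall(pattern, string):
    # Mirrors utils.re_findall: a list of groupdicts, one per match
    return [m.groupdict() for m in re.finditer(pattern, string)]

X, y = [], []
row = {'text': 'plain transcript text',
       'extracted': 'BEGIN_sponsor this video is sponsored by acme END'}

matches = re_findall(SEGMENT_MATCH_RE, row['extracted'])
if matches:
    for match in matches:
        X.append(match['text'])
        y.append(CATEGORIES.index(match['category']))  # e.g. 'sponsor' -> 1
else:
    X.append(row['text'])
    y.append(0)  # no matched segment -> class 0

print(list(zip(X, y)))

Keeping the unmatched class at index 0 of the category list makes the else-branch's hard-coded 0 and CATEGORIES.index agree on the same label space.
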
@@ -296,10 +299,10 @@
        y_test = train_test_data['test']['y']
 
        # 2. Create classifier
-       classifier = LogisticRegression(max_iter=
+       classifier = LogisticRegression(max_iter=2000, class_weight='balanced')
 
        # 3. Fit data
-       print('
+       print('Fit classifier')
        classifier.fit(_X_train, y_train)
 
        # 4. Measure accuracy
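
Two changes in the classifier hunk are worth unpacking: max_iter rises to 2000, since the default lbfgs solver often needs more iterations to converge on a large sparse TF-IDF matrix, and class_weight='balanced' compensates for the skew the extraction step introduces, where unmatched rows (class 0) dominate. In scikit-learn, 'balanced' weights each class by n_samples / (n_classes * bincount(y)). A small runnable sketch on toy data:

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression

# Toy imbalanced data: class 0 ('no segment') dominates, as in the real dataset
texts = ['regular chat'] * 8 + ['sponsored by acme', 'check out my merch']
labels = [0] * 8 + [1, 2]

X = TfidfVectorizer(ngram_range=(1, 2)).fit_transform(texts)

# class_weight='balanced' reweights classes by n_samples / (n_classes * bincount(y)),
# so the rare segment classes are not drowned out by class 0
classifier = LogisticRegression(max_iter=2000, class_weight='balanced')
classifier.fit(X, labels)

print(classifier.score(X, labels))  # training accuracy on the toy data
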
@@ -336,9 +339,15 @@
    )
 
    # Load pretrained model and tokenizer
-
-
+   model = AutoModelForSeq2SeqLM.from_pretrained(
+       model_args.model_name_or_path)
    model.to(device())
+
+   tokenizer = AutoTokenizer.from_pretrained(
+       model_args.model_name_or_path)
+
+   # Ensure model and tokenizer contain the custom tokens
+   CustomTokens.add_custom_tokens(tokenizer)
    model.resize_token_embeddings(len(tokenizer))
 
    if model.config.decoder_start_token_id is None:
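
The model-loading hunk replaces whatever previously produced model (the removed lines did not survive extraction) with direct loading from model_args.model_name_or_path via the newly imported Auto classes, then re-registers the project's custom tokens before resizing the embeddings. CustomTokens.add_custom_tokens is the repo's own helper; a generic sketch of the same pattern using plain tokenizer.add_tokens (the checkpoint name and token strings below are stand-ins):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# 't5-small' stands in for model_args.model_name_or_path
checkpoint = 't5-small'
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Generic equivalent of CustomTokens.add_custom_tokens: register extra tokens
# (hypothetical markers here), then grow the embedding matrix to match
num_added = tokenizer.add_tokens(['EXTRACT_SEGMENTS:', 'URL_TOKEN'])
model.resize_token_embeddings(len(tokenizer))

print(f'Added {num_added} tokens; vocab size is now {len(tokenizer)}')

Resizing must happen after the tokens are added; otherwise the embedding matrix and the tokenizer's vocabulary disagree and the new token ids would index out of range.
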
@@ -479,9 +488,10 @@
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model() # Saves the tokenizer too for easy upload
    except KeyboardInterrupt:
-
-
-
+       # TODO add option to save model on interrupt?
+       # print('Saving model')
+       # trainer.save_model(os.path.join(
+       #     training_args.output_dir, 'checkpoint-latest')) # TODO use dir
        raise
 
    metrics = train_result.metrics
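
The commented-out block records an idea rather than implementing it: on Ctrl+C, save a usable checkpoint before re-raising so the interrupted run is not lost. A runnable sketch of that pattern, with a dummy trainer standing in for the script's Seq2SeqTrainer ('checkpoint-latest' is the name the TODO suggests):

import os

class DummyTrainer:
    # Stand-in for the Seq2SeqTrainer used in train.py
    def train(self, resume_from_checkpoint=None):
        raise KeyboardInterrupt  # simulate the user interrupting mid-run
    def save_model(self, output_dir=None):
        print(f'saving model to {output_dir or "<default output dir>"}')

trainer = DummyTrainer()
output_dir = 'out'

try:
    trainer.train(resume_from_checkpoint=None)
    trainer.save_model()
except KeyboardInterrupt:
    # Persist progress before re-raising, so an interrupted run still
    # leaves a loadable checkpoint behind
    trainer.save_model(os.path.join(output_dir, 'checkpoint-latest'))
    raise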