Spaces:
Running
Running
Joshua Lochner
commited on
Commit
•
086ca93
1
Parent(s):
583f4cf
Move `load_datasets` to train script
Browse files- src/train.py +20 -1
src/train.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
-
from
|
|
|
2 |
from predict import ClassifierArguments, SEGMENT_MATCH_RE, CATEGORIES
|
3 |
from shared import CustomTokens, GeneralArguments, OutputArguments
|
4 |
from model import ModelArguments
|
@@ -42,6 +43,24 @@ logging.basicConfig(
|
|
42 |
)
|
43 |
|
44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
@dataclass
|
46 |
class DataTrainingArguments:
|
47 |
"""
|
|
|
1 |
+
from datasets import load_dataset
|
2 |
+
from preprocess import DatasetArguments
|
3 |
from predict import ClassifierArguments, SEGMENT_MATCH_RE, CATEGORIES
|
4 |
from shared import CustomTokens, GeneralArguments, OutputArguments
|
5 |
from model import ModelArguments
|
|
|
43 |
)
|
44 |
|
45 |
|
46 |
+
def load_datasets(dataset_args):
|
47 |
+
|
48 |
+
print('Reading datasets')
|
49 |
+
data_files = {}
|
50 |
+
|
51 |
+
if dataset_args.train_file is not None:
|
52 |
+
data_files['train'] = os.path.join(
|
53 |
+
dataset_args.data_dir, dataset_args.train_file)
|
54 |
+
if dataset_args.validation_file is not None:
|
55 |
+
data_files['validation'] = os.path.join(
|
56 |
+
dataset_args.data_dir, dataset_args.validation_file)
|
57 |
+
if dataset_args.test_file is not None:
|
58 |
+
data_files['test'] = os.path.join(
|
59 |
+
dataset_args.data_dir, dataset_args.test_file)
|
60 |
+
|
61 |
+
return load_dataset('json', data_files=data_files, cache_dir=dataset_args.dataset_cache_dir)
|
62 |
+
|
63 |
+
|
64 |
@dataclass
|
65 |
class DataTrainingArguments:
|
66 |
"""
|