Joshua Lochner commited on
Commit
086ca93
1 Parent(s): 583f4cf

Move `load_datasets` to train script

Browse files
Files changed (1) hide show
  1. src/train.py +20 -1
src/train.py CHANGED
@@ -1,4 +1,5 @@
1
- from preprocess import load_datasets, DatasetArguments
 
2
  from predict import ClassifierArguments, SEGMENT_MATCH_RE, CATEGORIES
3
  from shared import CustomTokens, GeneralArguments, OutputArguments
4
  from model import ModelArguments
@@ -42,6 +43,24 @@ logging.basicConfig(
42
  )
43
 
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  @dataclass
46
  class DataTrainingArguments:
47
  """
 
1
+ from datasets import load_dataset
2
+ from preprocess import DatasetArguments
3
  from predict import ClassifierArguments, SEGMENT_MATCH_RE, CATEGORIES
4
  from shared import CustomTokens, GeneralArguments, OutputArguments
5
  from model import ModelArguments
 
43
  )
44
 
45
 
46
+ def load_datasets(dataset_args):
47
+
48
+ print('Reading datasets')
49
+ data_files = {}
50
+
51
+ if dataset_args.train_file is not None:
52
+ data_files['train'] = os.path.join(
53
+ dataset_args.data_dir, dataset_args.train_file)
54
+ if dataset_args.validation_file is not None:
55
+ data_files['validation'] = os.path.join(
56
+ dataset_args.data_dir, dataset_args.validation_file)
57
+ if dataset_args.test_file is not None:
58
+ data_files['test'] = os.path.join(
59
+ dataset_args.data_dir, dataset_args.test_file)
60
+
61
+ return load_dataset('json', data_files=data_files, cache_dir=dataset_args.dataset_cache_dir)
62
+
63
+
64
  @dataclass
65
  class DataTrainingArguments:
66
  """