Commit
•
3fbd38b
1
Parent(s):
4ad32d0
Upload folder using huggingface_hub
Browse files- main.py +1 -1
- src/args.py +3 -0
main.py
CHANGED
@@ -64,7 +64,7 @@ class ORPO(object):
|
|
64 |
test = self.data[test_split].filter(self.filter_dataset)
|
65 |
self.test = test.map(self.preprocess_dataset, batched=True, num_proc=self.args.num_proc, remove_columns=self.data[test_split].column_names)
|
66 |
|
67 |
-
train = self.data[train_split].filter(self.filter_dataset)
|
68 |
print(f"\n\n>>> {len(train)} / {len(self.data[train_split])} rows left after filtering by prompt length.")
|
69 |
self.train = train.map(self.preprocess_dataset, batched=True, num_proc=self.args.num_proc, remove_columns=self.data[train_split].column_names)
|
70 |
|
|
|
64 |
test = self.data[test_split].filter(self.filter_dataset)
|
65 |
self.test = test.map(self.preprocess_dataset, batched=True, num_proc=self.args.num_proc, remove_columns=self.data[test_split].column_names)
|
66 |
|
67 |
+
train = self.data[train_split].filter(self.filter_dataset)[self.args.max_samples]
|
68 |
print(f"\n\n>>> {len(train)} / {len(self.data[train_split])} rows left after filtering by prompt length.")
|
69 |
self.train = train.map(self.preprocess_dataset, batched=True, num_proc=self.args.num_proc, remove_columns=self.data[train_split].column_names)
|
70 |
|
src/args.py
CHANGED
@@ -4,6 +4,9 @@ def default_args(parser):
|
|
4 |
parser.add_argument("--data_name", default='HuggingfaceH4/UltraFeedback', type=str)
|
5 |
parser.add_argument("--model_name", default="gpt2", type=str)
|
6 |
|
|
|
|
|
|
|
7 |
# Training Arguments
|
8 |
parser.add_argument("--torch_compile", default=False, type=bool)
|
9 |
parser.add_argument("--flash_attention_2", action='store_true')
|
|
|
4 |
parser.add_argument("--data_name", default='HuggingfaceH4/UltraFeedback', type=str)
|
5 |
parser.add_argument("--model_name", default="gpt2", type=str)
|
6 |
|
7 |
+
# Data Arguments
|
8 |
+
parser.add_argument("--max_samples", default=None, type=int)
|
9 |
+
|
10 |
# Training Arguments
|
11 |
parser.add_argument("--torch_compile", default=False, type=bool)
|
12 |
parser.add_argument("--flash_attention_2", action='store_true')
|