burtenshaw HF staff committed on
Commit
3fbd38b
1 Parent(s): 4ad32d0

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. main.py +1 -1
  2. src/args.py +3 -0
main.py CHANGED
@@ -64,7 +64,7 @@ class ORPO(object):
64
  test = self.data[test_split].filter(self.filter_dataset)
65
  self.test = test.map(self.preprocess_dataset, batched=True, num_proc=self.args.num_proc, remove_columns=self.data[test_split].column_names)
66
 
67
- train = self.data[train_split].filter(self.filter_dataset).select(range(self.args.max_samples))
68
  print(f"\n\n>>> {len(train)} / {len(self.data[train_split])} rows left after filtering by prompt length.")
69
  self.train = train.map(self.preprocess_dataset, batched=True, num_proc=self.args.num_proc, remove_columns=self.data[train_split].column_names)
70
 
 
64
  test = self.data[test_split].filter(self.filter_dataset)
65
  self.test = test.map(self.preprocess_dataset, batched=True, num_proc=self.args.num_proc, remove_columns=self.data[test_split].column_names)
66
 
67
+ train = self.data[train_split].filter(self.filter_dataset)[self.args.max_samples]
68
  print(f"\n\n>>> {len(train)} / {len(self.data[train_split])} rows left after filtering by prompt length.")
69
  self.train = train.map(self.preprocess_dataset, batched=True, num_proc=self.args.num_proc, remove_columns=self.data[train_split].column_names)
70
 
src/args.py CHANGED
@@ -4,6 +4,9 @@ def default_args(parser):
4
  parser.add_argument("--data_name", default='HuggingfaceH4/UltraFeedback', type=str)
5
  parser.add_argument("--model_name", default="gpt2", type=str)
6
 
 
 
 
7
  # Training Arguments
8
  parser.add_argument("--torch_compile", default=False, type=bool)
9
  parser.add_argument("--flash_attention_2", action='store_true')
 
4
  parser.add_argument("--data_name", default='HuggingfaceH4/UltraFeedback', type=str)
5
  parser.add_argument("--model_name", default="gpt2", type=str)
6
 
7
+ # Data Arguments
8
+ parser.add_argument("--max_samples", default=None, type=int)
9
+
10
  # Training Arguments
11
  parser.add_argument("--torch_compile", default=False, type=bool)
12
  parser.add_argument("--flash_attention_2", action='store_true')