stakelovelace committed on
Commit acc7015
1 Parent(s): 2094fe7

commit from tesla

Files changed (1)
  1. app.py +4 -3
app.py CHANGED
@@ -42,8 +42,9 @@ def train_model(model, tokenizer, data, device):
     training_args = TrainingArguments(
         output_dir='./results',
         num_train_epochs=3,
-        per_device_train_batch_size=1,
-        gradient_accumulation_steps=1,
+        per_device_train_batch_size=8,
+        gradient_accumulation_steps=4,
+        fp16=True,  # Enable mixed precision
         warmup_steps=500,
         weight_decay=0.01,
         logging_dir='./logs',
@@ -89,7 +90,7 @@ def generate_api_query(model, tokenizer, prompt, desired_output, api_name, base_
     input_ids = input_ids.to(model.device)
 
     # Generate query using model with temperature for randomness
-    output = model.generate(input_ids, max_length=256, temperature=0.1, do_sample=True)
+    output = model.generate(input_ids, max_length=128, truncation=True, padding='max_length', temperature=0.1, do_sample=True)
 
     # Decode the generated query tokens
     query = tokenizer.decode(output[0], skip_special_tokens=True)
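
For context, a minimal sketch of how the updated TrainingArguments would typically be wired into a Hugging Face Trainer. This is not part of the commit; `model`, `tokenizer`, and `train_dataset` are assumed to be prepared elsewhere in app.py.

    # Sketch only (not in the commit): using the new arguments with a Trainer.
    # `model`, `tokenizer`, and `train_dataset` are assumed to exist in app.py.
    from transformers import Trainer, TrainingArguments

    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=3,
        per_device_train_batch_size=8,   # 8 samples per forward/backward pass
        gradient_accumulation_steps=4,   # step the optimizer every 4 passes
        fp16=True,                       # mixed precision; needs a CUDA-capable GPU
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
    )
    trainer.train()

With these values the optimizer sees an effective batch of 8 × 4 = 32 samples per device, while peak activation memory per step corresponds to only 8 samples.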