Spaces:
Runtime error
Runtime error
stakelovelace
commited on
Commit
•
acc7015
1
Parent(s):
2094fe7
commit from tesla
Browse files
app.py
CHANGED
@@ -42,8 +42,9 @@ def train_model(model, tokenizer, data, device):
|
|
42 |
training_args = TrainingArguments(
|
43 |
output_dir='./results',
|
44 |
num_train_epochs=3,
|
45 |
-
per_device_train_batch_size=
|
46 |
-
gradient_accumulation_steps=
|
|
|
47 |
warmup_steps=500,
|
48 |
weight_decay=0.01,
|
49 |
logging_dir='./logs',
|
@@ -89,7 +90,7 @@ def generate_api_query(model, tokenizer, prompt, desired_output, api_name, base_
|
|
89 |
input_ids = input_ids.to(model.device)
|
90 |
|
91 |
# Generate query using model with temperature for randomness
|
92 |
-
output = model.generate(input_ids, max_length=
|
93 |
|
94 |
# Decode the generated query tokens
|
95 |
query = tokenizer.decode(output[0], skip_special_tokens=True)
|
|
|
42 |
training_args = TrainingArguments(
|
43 |
output_dir='./results',
|
44 |
num_train_epochs=3,
|
45 |
+
per_device_train_batch_size=8,
|
46 |
+
gradient_accumulation_steps=4,
|
47 |
+
fp16=True, # Enable mixed precision
|
48 |
warmup_steps=500,
|
49 |
weight_decay=0.01,
|
50 |
logging_dir='./logs',
|
|
|
90 |
input_ids = input_ids.to(model.device)
|
91 |
|
92 |
# Generate query using model with temperature for randomness
|
93 |
+
output = model.generate(input_ids, max_length=128, truncation=True, padding='max_length', temperature=0.1, do_sample=True)
|
94 |
|
95 |
# Decode the generated query tokens
|
96 |
query = tokenizer.decode(output[0], skip_special_tokens=True)
|