Spaces:
Runtime error
Runtime error
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, TrainingArguments, Trainer, BertLMHeadModel, BertForSequenceClassification | |
from datasets import Dataset | |
import pandas as pd | |
import csv | |
from transformers import TrainingArguments, Trainer | |
import tensorflow as tf | |
# Report which accelerators PyTorch can see. PyTorch is the framework this
# script actually trains with; TensorFlow's GPU list says nothing about the
# devices torch will use (e.g. Apple MPS).
# NOTE(review): this check was the only visible use of the `tensorflow`
# import; consider dropping that import if nothing else needs it.
print(
    "CUDA available: ", torch.cuda.is_available(),
    "| MPS available: ", torch.backends.mps.is_available(),
)
import os | |
# Remove PyTorch's upper bound on MPS (Apple GPU) memory allocations: a
# high-watermark ratio of 0.0 disables the limit. PyTorch's MPS docs warn
# this can exhaust system memory.
# NOTE(review): presumably read when the MPS allocator initialises, so it
# must be set before any MPS tensor is created — confirm.
os.environ.update({'PYTORCH_MPS_HIGH_WATERMARK_RATIO': '0.0'})
def get_device():
    """Return the preferred torch device.

    Preference order is CUDA, then Apple MPS, then CPU as the fallback.
    """
    if torch.cuda.is_available():
        device_name = 'cuda'
    elif torch.backends.mps.is_available():
        device_name = 'mps'
    else:
        device_name = 'cpu'
    return torch.device(device_name)
def load_data_and_config(data_path, text_column='description', delimiter=';'):
    """Load training texts from a delimited file with a header row.

    Each row's ``text_column`` value becomes one ``{'text': value}`` record.
    Rows are skipped in two cases:
    - the value is missing (a short row, which ``csv.DictReader`` fills
      with ``None``);
    - the value is blank.
    A ``None`` would crash tokenization, and a blank value only adds an
    empty training example.

    Args:
        data_path: path to the file to read.
        text_column: header name of the column holding the text.
        delimiter: field separator used in the file.

    Returns:
        A list of ``{'text': str}`` dicts, in file order.

    Raises:
        KeyError: if a data row is read and the header lacks ``text_column``.
    """
    data = []
    # 'utf-8-sig' decodes plain UTF-8 unchanged but also strips a leading
    # BOM (common in spreadsheet exports). Without that, the first header
    # would become '\ufeffdescription'.
    with open(data_path, newline='', encoding='utf-8-sig') as csvfile:
        for row in csv.DictReader(csvfile, delimiter=delimiter):
            text = row[text_column]
            if text is None or not text.strip():
                continue
            data.append({'text': text})
    return data
def train_model(model, tokenizer, data, device):
    """Fine-tune ``model`` on ``data`` with the Hugging Face ``Trainer``.

    Args:
        model: causal language model, trained in place.
        tokenizer: tokenizer matching ``model``. If it has no pad token, its
            EOS token is assigned as the pad token, because fixed-length
            padding requires one.
        data: list of ``{'text': str}`` records (as produced by
            ``load_data_and_config``).
        device: torch device (or device string) selected by the caller.
            fp16 mixed precision is enabled only when it is a CUDA device.

    Raises:
        ValueError: if ``data`` is empty.
    """
    if not data:
        raise ValueError("No training examples to train on.")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # One batched tokenizer call; plain Python lists are what
    # ``Dataset.from_dict`` stores anyway.
    encodings = tokenizer(
        [example['text'] for example in data],
        max_length=512,
        truncation=True,
        padding='max_length',
    )
    # Labels mirror input_ids, except that padding positions are set to -100.
    # -100 is the label value the Hugging Face causal-LM loss ignores, so the
    # model is not trained to predict pad tokens.
    labels = [
        [token if attended else -100 for token, attended in zip(ids, mask)]
        for ids, mask in zip(encodings['input_ids'], encodings['attention_mask'])
    ]
    dataset = Dataset.from_dict({
        'input_ids': encodings['input_ids'],
        'attention_mask': encodings['attention_mask'],
        'labels': labels,
    })

    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=3,
        per_device_train_batch_size=8,
        gradient_accumulation_steps=4,
        # fp16 AMP is CUDA-only; requesting it on CPU or MPS raises at setup.
        fp16=torch.device(device).type == 'cuda',
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
    )
    trainer.train()
    # With no argument, save_model writes to training_args.output_dir.
    trainer.save_model()
def main(api_name, base_url):
    """Fine-tune CodeGemma-2B on ``train2.csv``, save it, and print a sample.

    Args:
        api_name: API name inserted into the generation prompt.
        base_url: URL prefixed to the generated query.
    """
    model_id = 'google/codegemma-2b'
    device = get_device()
    data = load_data_and_config("train2.csv")
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    config = AutoConfig.from_pretrained(model_id)
    # Request the tanh-approximate GeLU. ``hidden_act`` is the generic config
    # field. Newer transformers releases read ``hidden_activation`` for Gemma
    # instead (NOTE(review): confirm against the installed version), so set
    # both fields.
    config.hidden_act = 'gelu_pytorch_tanh'
    config.hidden_activation = 'gelu_pytorch_tanh'
    # The edited config must be passed explicitly; otherwise the hub config
    # is loaded and the change above has no effect. No ``is_decoder=`` kwarg
    # is passed: Gemma is decoder-only already, and when ``config=`` is
    # given, extra kwargs are forwarded to the model constructor, not the
    # config.
    model = AutoModelForCausalLM.from_pretrained(model_id, config=config)
    model.to(device)

    train_model(model, tokenizer, data, device)
    model.save_pretrained("./fine_tuned_model")
    tokenizer.save_pretrained("./fine_tuned_model")

    prompt = "I need to retrieve the latest block on chain using a python script"
    api_query = generate_api_query(model, tokenizer, prompt, "latest block on chain", api_name, base_url)
    print(f"Generated code: {api_query}")
def generate_api_query(model, tokenizer, prompt, desired_output, api_name, base_url,
                       max_new_tokens=128, temperature=0.1):
    """Sample an API query from ``model`` and prefix it with ``base_url``.

    Args:
        model: causal language model used for generation.
        tokenizer: tokenizer matching ``model``.
        prompt: user request; an instruction sentence is appended to it.
        desired_output: what the query should retrieve.
        api_name: name of the target API, inserted into the instruction.
        base_url: URL placed in front of the generated text.
        max_new_tokens: cap on generated tokens (prompt tokens not counted).
        temperature: sampling temperature; sampling is always enabled.

    Returns:
        ``"<base_url>/<generated text>"``. The generated text excludes the
        prompt and has surrounding whitespace and leading slashes stripped.
    """
    full_prompt = f"{prompt} Write an API query to {api_name} to get {desired_output}"
    # Keep the attention mask alongside input_ids and move both to the
    # model's device.
    encoded = tokenizer(full_prompt, return_tensors="pt")
    encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    # Only generation arguments go to generate(): truncation/padding are
    # tokenizer options, and generate() rejects unused model kwargs.
    output = model.generate(
        **encoded,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=True,
    )
    # For decoder-only models the output starts with the prompt tokens;
    # decode only the newly generated part.
    prompt_length = encoded['input_ids'].shape[-1]
    query = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    return f"{base_url.rstrip('/')}/{query.strip().lstrip('/')}"
if __name__ == "__main__":
    # Demo target: the Koios REST API, used for the sample generation that
    # runs after fine-tuning.
    api_name, base_url = "Koios", "https://api.koios.rest/v1"
    main(api_name, base_url)