---
license: mit
datasets:
- b-mc2/sql-create-context
- gretelai/synthetic_text_to_sql
language:
- en
base_model: google-t5/t5-base
metrics:
- exact_match
model-index:
- name: juanfra218/text2sql
  results:
  - task:
      type: text-to-sql
    metrics:
    - name: exact_match
      type: exact_match
      value: 0.4326836917562724
    - name: bleu
      type: bleu
      value: 0.6687
tags:
- sql
library_name: transformers
---

# Fine-Tuned Google T5 Model for Text to SQL Translation

A fine-tuned version of Google's T5 model, trained to translate natural-language questions into SQL statements.

## Model Details

- **Architecture**: Google T5 Base (Text-to-Text Transfer Transformer)
- **Task**: Text to SQL Translation
- **Fine-Tuning Datasets**:
  - [sql-create-context Dataset](https://huggingface.co/datasets/b-mc2/sql-create-context)
  - [Synthetic-Text-To-SQL Dataset](https://huggingface.co/datasets/gretelai/synthetic_text_to_sql)

## Training Parameters

```python
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",  # renamed to eval_strategy in newer transformers releases
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    predict_with_generate=True,
    fp16=True,  # mixed-precision training; requires a CUDA GPU
    push_to_hub=False,
)
```

## Usage

```python
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the tokenizer and model
model_path = "juanfra218/text2sql"
tokenizer = T5Tokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path)
model.to(device)


# Generate a SQL query from a natural-language prompt and a schema
# (for example, the relevant CREATE TABLE statements)
def generate_sql(prompt, schema):
    input_text = "translate English to SQL: " + prompt + " " + schema
    inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True, padding="max_length")
    inputs = {key: value.to(device) for key, value in inputs.items()}

    max_output_length = 1024
    outputs = model.generate(**inputs, max_length=max_output_length)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# Interactive loop
print("Enter 'quit' to exit.")
while True:
    prompt = input("Insert prompt: ")
    if prompt.lower() == "quit":  # exit before asking for a schema
        break
    schema = input("Insert schema: ")

    sql_query = generate_sql(prompt, schema)
    print(f"Generated SQL query: {sql_query}")
    print()
```

## Files

- `optimizer.pt`: State of the optimizer.
- `training_args.bin`: Training arguments and hyperparameters.
- `tokenizer.json`: Tokenizer vocabulary and settings.
- `spiece.model`: SentencePiece model file.
- `special_tokens_map.json`: Special tokens mapping.
- `tokenizer_config.json`: Tokenizer configuration settings.
- `model.safetensors`: Trained model weights.
- `generation_config.json`: Configuration for text generation.
- `config.json`: Model architecture configuration.
- `test_results.csv`: Results on the test set, with the columns `prompt`, `context`, `true_answer`, `predicted_answer`, and `exact_match`.
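The card reports exact match as its main metric, and `test_results.csv` stores each reference query (`true_answer`) next to the model's prediction (`predicted_answer`). The sketch below shows one way to recompute an exact-match rate from those two columns. The light normalization it uses, collapsing whitespace and ignoring a trailing semicolon, is an assumption for illustration only. The card does not say how the published score was computed, so this may not reproduce the reported value exactly. The function names are hypothetical, and the sample rows are toy data, not taken from the real file.

```python
import csv
import io
from typing import Iterable, Mapping, TextIO


def normalize_sql(sql: str) -> str:
    """Collapse runs of whitespace and drop a trailing semicolon.

    Assumption: this normalization is illustrative; the card does not state
    how the published exact_match score was computed.
    """
    return " ".join(sql.split()).rstrip(";").strip()


def exact_match_rate(rows: Iterable[Mapping[str, str]]) -> float:
    """Fraction of rows whose normalized prediction equals the normalized reference."""
    total = 0
    hits = 0
    for row in rows:
        total += 1
        if normalize_sql(row["predicted_answer"]) == normalize_sql(row["true_answer"]):
            hits += 1
    return hits / total if total else 0.0


def exact_match_from_csv(handle: TextIO) -> float:
    """Score a CSV whose column names follow test_results.csv (hypothetical helper)."""
    return exact_match_rate(csv.DictReader(handle))


if __name__ == "__main__":
    # Toy rows for illustration only; they are not taken from test_results.csv
    sample = io.StringIO(
        "prompt,context,true_answer,predicted_answer,exact_match\n"
        '"How many heads?","CREATE TABLE head (age INTEGER)","SELECT COUNT(*) FROM head","SELECT COUNT(*) FROM head;",True\n'
        '"List names","CREATE TABLE t (name TEXT)","SELECT name FROM t","SELECT * FROM t",False\n'
    )
    print(exact_match_from_csv(sample))  # 0.5
```

To score the real file, open it with `open("test_results.csv", newline="", encoding="utf-8")` and pass the handle to `exact_match_from_csv`. Reading through `csv.DictReader` keeps the scoring tied to column names rather than positions. Stricter or looser normalization, such as case-folding keywords, would change the score.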