---
license: mit
datasets:
- b-mc2/sql-create-context
- gretelai/synthetic_text_to_sql
language:
- en
base_model: google-t5/t5-base
metrics:
- exact_match
model-index:
- name: juanfra218/text2sql
  results:
  - task:
      type: text-to-sql
    metrics:
    - name: exact_match
      type: exact_match
      value: 0.4322
---

# Fine-Tuned Google T5 Model for Text-to-SQL Translation

A fine-tuned version of Google's T5 model, trained to translate natural-language questions into SQL statements.

## Model Details

- **Architecture**: Google T5 Base (Text-to-Text Transfer Transformer)
- **Task**: Text-to-SQL translation
- **Fine-Tuning Datasets**:
  - [sql-create-context Dataset](https://huggingface.co/datasets/b-mc2/sql-create-context)
  - [Synthetic-Text-To-SQL Dataset](https://huggingface.co/datasets/gretelai/synthetic_text_to_sql)

## Ongoing Work

Work is under way to implement PICARD (Parsing Incrementally for Constrained Auto-Regressive Decoding from Language Models) to improve this model's results. PICARD constrains decoding by incrementally parsing the output and rejecting tokens that would lead to invalid SQL. More details are in the original [PICARD paper](https://arxiv.org/abs/2109.05093).

## Results

A full evaluation is in progress and results will be posted here. So far, the only reported figure is the exact-match score of 0.4322 in the metadata above.
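The metadata above reports an `exact_match` score of 0.4322, but the card does not say how queries were normalized before comparison. As a hedged illustration only, the sketch below computes exact match after one common normalization (trimming, collapsing whitespace, dropping a trailing semicolon, and lower-casing). The helper names `normalize_sql` and `exact_match` are hypothetical, and this is not the evaluation script behind the reported number.

```python
import re


def normalize_sql(query: str) -> str:
    """Trim, drop a trailing semicolon, collapse whitespace, and lower-case.

    This normalization is an assumption for illustration; the card does not
    specify the preprocessing behind its reported exact_match score.
    """
    query = query.strip().rstrip(";").strip()
    query = re.sub(r"\s+", " ", query)
    return query.lower()


def exact_match(predictions: list[str], references: list[str]) -> float:
    """Fraction of predictions equal to their reference after normalization."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    if not predictions:
        return 0.0
    hits = sum(
        normalize_sql(p) == normalize_sql(r)
        for p, r in zip(predictions, references)
    )
    return hits / len(predictions)


# Hypothetical predictions and gold queries, for illustration only
preds = [
    "SELECT name FROM head WHERE age > 56",
    "select count(*) from head;",
]
refs = [
    "SELECT name FROM head WHERE age > 56",
    "SELECT COUNT(*) FROM head WHERE age > 30",
]
print(exact_match(preds, refs))  # 0.5
```

Lower-casing also folds string literals such as `'Alice'`, so this variant can count a query as correct even when a literal's case differs; a stricter evaluation would compare literals case-sensitively or measure execution accuracy instead.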
## Usage

```python
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

# Load the tokenizer and model (model_path is a local directory or a Hub repo ID)
model_path = 'text2sql_model_path'
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path)

# Put the model on the GPU if one is available; the inputs must be on the same device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Generate a SQL query from a natural-language prompt and a table schema
def generate_sql(prompt, schema):
    input_text = "translate English to SQL: " + prompt + " " + schema
    # A single input needs no padding; truncate anything beyond 512 tokens
    inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
    inputs = {key: value.to(device) for key, value in inputs.items()}

    max_output_length = 1024
    outputs = model.generate(**inputs, max_length=max_output_length)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Interactive loop
print("Enter 'quit' to exit.")
while True:
    prompt = input("Insert prompt: ")
    # Check for 'quit' before asking for a schema
    if prompt.lower() == 'quit':
        break
    schema = input("Insert schema: ")
    sql_query = generate_sql(prompt, schema)
    print(f"Generated SQL query: {sql_query}")
    print()
```

## Files

- `optimizer.pt`: Optimizer state saved during training.
- `training_args.bin`: Training arguments and hyperparameters.
- `tokenizer.json`: Tokenizer vocabulary and settings.
- `spiece.model`: SentencePiece model file.
- `special_tokens_map.json`: Special tokens mapping.
- `tokenizer_config.json`: Tokenizer configuration settings.
- `model.safetensors`: Trained model weights.
- `generation_config.json`: Configuration for text generation.
- `config.json`: Model architecture configuration.
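Until PICARD-style constrained decoding is available, one cheap safeguard is to check a query returned by `generate_sql` after generation. The sketch below is a stand-in under stated assumptions, not PICARD: it creates the schema in an empty in-memory SQLite database with Python's standard `sqlite3` module and asks SQLite to compile the query with `EXPLAIN`, so syntax errors and unknown tables or columns are caught without running the query on real data. It assumes the schema is one or more `CREATE TABLE` statements separated by semicolons, and the helper name `is_valid_sql` is hypothetical.

```python
import sqlite3


def is_valid_sql(query: str, schema: str) -> bool:
    """Return True if SQLite can compile `query` against `schema`.

    Assumes `schema` holds CREATE TABLE statements. A malformed schema raises
    sqlite3.Error instead of returning False, so schema problems are not
    mistaken for bad queries.
    """
    conn = sqlite3.connect(":memory:")
    try:
        conn.executescript(schema)  # build the (empty) tables
        try:
            # EXPLAIN compiles the statement to bytecode without executing it
            conn.execute("EXPLAIN " + query)
        except (sqlite3.Error, sqlite3.Warning):
            return False
        return True
    finally:
        conn.close()


# Hypothetical schema and queries, for illustration only
schema = "CREATE TABLE head (name VARCHAR, age INTEGER)"
print(is_valid_sql("SELECT name FROM head WHERE age > 56", schema))  # True
print(is_valid_sql("SELECT salary FROM head", schema))  # False: no such column
```

PICARD rejects invalid tokens during decoding and can therefore steer generation toward valid SQL, whereas this check can only flag a bad query after the fact. SQLite's dialect also differs from other engines, so a query may fail here and still be valid elsewhere.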