---
license: mit
datasets:
- b-mc2/sql-create-context
- gretelai/synthetic_text_to_sql
language:
- en
base_model: google-t5/t5-base
metrics:
- exact_match
model-index:
- name: juanfra218/text2sql
  results:
  - task:
      type: text-to-sql
    metrics:
    - name: exact_match
      type: exact_match
      value: 0.4322
---

# Fine-Tuned Google T5 Model for Text-to-SQL Translation

A fine-tuned version of Google's T5-base model, trained to translate natural-language questions, paired with a table schema, into SQL queries.

## Model Details

- **Architecture**: Google T5 Base (Text-to-Text Transfer Transformer)
- **Task**: Text-to-SQL translation
- **Fine-Tuning Datasets**:
  - [sql-create-context Dataset](https://huggingface.co/datasets/b-mc2/sql-create-context)
  - [Synthetic-Text-To-SQL Dataset](https://huggingface.co/datasets/gretelai/synthetic_text_to_sql)
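
As an illustration, rows from the two fine-tuning datasets can be mapped to the same input layout used in the Usage section (`"translate English to SQL: " + question + " " + schema`). This is a minimal sketch: the column names in `COLUMN_MAP` are assumptions about each dataset's schema, `to_training_pair` is a hypothetical helper, the example row is illustrative, and the card does not state that training used exactly this layout.

```python
# Hypothetical preprocessing sketch: turn one dataset row into an
# (input_text, target_sql) pair using the prompt layout from the Usage section.
PREFIX = "translate English to SQL: "

# ASSUMED column names per dataset: (question, schema, target SQL)
COLUMN_MAP = {
    "b-mc2/sql-create-context": ("question", "context", "answer"),
    "gretelai/synthetic_text_to_sql": ("sql_prompt", "sql_context", "sql"),
}


def to_training_pair(row: dict, dataset_name: str) -> tuple[str, str]:
    """Build the model input and the target SQL string for one row."""
    question_col, schema_col, answer_col = COLUMN_MAP[dataset_name]
    input_text = PREFIX + row[question_col] + " " + row[schema_col]
    return input_text, row[answer_col]


# Illustrative row in the assumed sql-create-context format
example = {
    "question": "How many heads of the departments are older than 56 ?",
    "context": "CREATE TABLE head (age INTEGER)",
    "answer": "SELECT COUNT(*) FROM head WHERE age > 56",
}
src, tgt = to_training_pair(example, "b-mc2/sql-create-context")
print(src)
print(tgt)
```

Keeping the prefix in one constant makes it easy to verify that training inputs and inference inputs use the same format.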

## Ongoing Work

PICARD (Parsing Incrementally for Constrained Auto-Regressive Decoding from Language Models) is currently being implemented to improve this model's results. PICARD constrains decoding by incrementally parsing the partial output and rejecting tokens that would lead to invalid SQL. More details are in the original [PICARD paper](https://arxiv.org/abs/2109.05093).
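
PICARD's core idea is to reject, at each decoding step, candidate tokens whose resulting partial query can no longer be completed into valid SQL. The toy sketch below is not PICARD: its hypothetical `is_admissible_prefix` check only tracks parentheses and single-quoted strings instead of running an incremental SQL parser, and the candidate tokens and their ranking are made up for illustration.

```python
# Toy illustration of PICARD-style filtering, NOT the PICARD implementation.
# A crude "parser" that only checks parentheses outside single-quoted strings.


def is_admissible_prefix(sql_prefix: str) -> bool:
    """Return False if the prefix can never be completed into balanced SQL."""
    depth = 0
    in_string = False
    for ch in sql_prefix:
        if ch == "'":
            in_string = not in_string  # '' (escaped quote) toggles twice
        elif not in_string:
            if ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
                if depth < 0:  # closing paren with no matching open
                    return False
    return True


def filter_candidates(prefix: str, candidates: list[str], k: int = 3) -> list[str]:
    """Keep only the admissible continuations among the top-k candidates."""
    return [tok for tok in candidates[:k] if is_admissible_prefix(prefix + tok)]


prefix = "SELECT COUNT(*) FROM head WHERE age > 56"
# Hypothetical candidates, ranked best first by the model
print(filter_candidates(prefix, [")", " AND", ";"]))
```

Here the highest-ranked `")"` is discarded because it closes a parenthesis that was never opened, so decoding would continue with the next admissible candidate.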

## Results

The model metadata reports an exact-match score of 0.4322. A more detailed evaluation is in progress and will be posted here.
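
Exact match counts a prediction as correct only when it equals the reference query. A minimal sketch follows; the card does not say how queries were normalized before comparison, so collapsing whitespace and ignoring case in the hypothetical `normalize` helper is an assumption (note that lowercasing also folds the case of string literals).

```python
# Minimal exact-match sketch. ASSUMPTION: queries are compared after
# collapsing whitespace and lowercasing; the card does not specify this.


def normalize(sql: str) -> str:
    """Collapse runs of whitespace and lowercase the query."""
    return " ".join(sql.split()).lower()


def exact_match(predictions: list[str], references: list[str]) -> float:
    """Fraction of predictions that match their reference after normalization."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    if not predictions:
        return 0.0
    hits = sum(normalize(p) == normalize(r) for p, r in zip(predictions, references))
    return hits / len(predictions)


preds = ["SELECT COUNT(*) FROM head WHERE age > 56", "select name from head"]
refs = ["SELECT COUNT(*)  FROM head WHERE age > 56", "SELECT name FROM head WHERE age > 56"]
print(exact_match(preds, refs))  # 0.5
```

Exact match is strict: two queries that return the same rows but are written differently still count as a miss.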

## Usage

```python
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

# Load the tokenizer and model. The Hub ID comes from this card's metadata;
# replace it with a local directory if you have the files on disk.
model_path = "juanfra218/text2sql"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path)

# Use the GPU when available; the model and the inputs must be on the same device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()


# Generate a SQL query from a natural-language question and a table schema
def generate_sql(prompt, schema):
    input_text = "translate English to SQL: " + prompt + " " + schema
    # A single input needs no padding; truncate anything beyond 512 tokens
    inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
    inputs = {key: value.to(device) for key, value in inputs.items()}

    max_output_length = 1024
    outputs = model.generate(**inputs, max_length=max_output_length)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# Interactive loop
print("Enter 'quit' to exit.")
while True:
    prompt = input("Insert prompt: ")
    if prompt.lower() == "quit":  # exit before asking for a schema
        break
    schema = input("Insert schema: ")

    sql_query = generate_sql(prompt, schema)
    print(f"Generated SQL query: {sql_query}")
    print()
```

## Files

- `optimizer.pt`: State of the optimizer.
- `training_args.bin`: Training arguments and hyperparameters.
- `tokenizer.json`: Tokenizer vocabulary and settings.
- `spiece.model`: SentencePiece model file.
- `special_tokens_map.json`: Special tokens mapping.
- `tokenizer_config.json`: Tokenizer configuration settings.
- `model.safetensors`: Trained model weights.
- `generation_config.json`: Configuration for text generation.
- `config.json`: Model architecture configuration.
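
A small, hypothetical helper can confirm that a local copy of the repository contains every file listed in this section; the directory path is a placeholder. Loading the model for inference with `from_pretrained` does not need `optimizer.pt` or `training_args.bin`, which are training artifacts, so the helper can optionally skip them.

```python
from pathlib import Path

# Files listed in this model card
EXPECTED_FILES = [
    "optimizer.pt",
    "training_args.bin",
    "tokenizer.json",
    "spiece.model",
    "special_tokens_map.json",
    "tokenizer_config.json",
    "model.safetensors",
    "generation_config.json",
    "config.json",
]

# Training-only artifacts, not needed to load the model for inference
TRAINING_ONLY = {"optimizer.pt", "training_args.bin"}


def missing_files(model_dir: str, inference_only: bool = False) -> list[str]:
    """Return the expected files that are not present in model_dir."""
    root = Path(model_dir)
    wanted = [f for f in EXPECTED_FILES if not (inference_only and f in TRAINING_ONLY)]
    return [name for name in wanted if not (root / name).is_file()]


if __name__ == "__main__":
    # Placeholder path: point this at your local copy of the repository
    missing = missing_files("path/to/text2sql", inference_only=True)
    print("All files present." if not missing else f"Missing: {missing}")
```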