juanfra218 commited on
Commit
afcf4b8
1 Parent(s): 4241584

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +36 -0
README.md CHANGED
@@ -27,6 +27,42 @@ Currently working to implement PICARD (Parsing Incrementally for Constrained Aut
27
 
28
  Results are currently being evaluated and will be posted here soon.
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  ## Files
31
 
32
  - `optimizer.pt`: State of the optimizer.
 
27
 
28
  Results are currently being evaluated and will be posted here soon.
29
 
30
+ ## Usage
31
+
32
+ ```
33
+ import torch
34
+ from transformers import AutoTokenizer, T5ForConditionalGeneration
35
+
36
+ # Load the tokenizer and model
37
+ model_path = 'text2sql_model_path'
38
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
39
+ model = T5ForConditionalGeneration.from_pretrained(model_path)
40
+
41
+ # Function to generate SQL queries
42
+ def generate_sql(prompt, schema):
43
+ input_text = "translate English to SQL: " + prompt + " " + schema
44
+ inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True, padding="max_length")
45
+
46
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
47
+ inputs = {key: value.to(device) for key, value in inputs.items()}
48
+
49
+ max_output_length = 1024
50
+ outputs = model.generate(**inputs, max_length=max_output_length)
51
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
52
+
53
+ # Interactive loop
54
+ print("Enter 'quit' to exit.")
55
+ while True:
56
+ prompt = input("Insert prompt: ")
57
+ schema = input("Insert schema: ")
58
+ if prompt.lower() == 'quit':
59
+ break
60
+
61
+ sql_query = generate_sql(prompt, schema)
62
+ print(f"Generated SQL query: {sql_query}")
63
+ print()
64
+ ```
65
+
66
  ## Files
67
 
68
  - `optimizer.pt`: State of the optimizer.