juanfra218 commited on
Commit
c06880d
1 Parent(s): 1f24039

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +5 -3
README.md CHANGED
@@ -59,19 +59,21 @@ training_args = Seq2SeqTrainingArguments(
59
 
60
  ```
61
  import torch
62
- from transformers import AutoTokenizer, T5ForConditionalGeneration
 
 
63
 
64
  # Load the tokenizer and model
65
  model_path = 'juanfra218/text2sql'
66
- tokenizer = AutoTokenizer.from_pretrained(model_path)
67
  model = T5ForConditionalGeneration.from_pretrained(model_path)
 
68
 
69
  # Function to generate SQL queries
70
  def generate_sql(prompt, schema):
71
  input_text = "translate English to SQL: " + prompt + " " + schema
72
  inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True, padding="max_length")
73
 
74
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
75
  inputs = {key: value.to(device) for key, value in inputs.items()}
76
 
77
  max_output_length = 1024
 
59
 
60
  ```
61
  import torch
62
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
63
+
64
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
65
 
66
  # Load the tokenizer and model
67
  model_path = 'juanfra218/text2sql'
68
+ tokenizer = T5Tokenizer.from_pretrained(model_path)
69
  model = T5ForConditionalGeneration.from_pretrained(model_path)
70
+ model.to(device)
71
 
72
  # Function to generate SQL queries
73
  def generate_sql(prompt, schema):
74
  input_text = "translate English to SQL: " + prompt + " " + schema
75
  inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True, padding="max_length")
76
 
 
77
  inputs = {key: value.to(device) for key, value in inputs.items()}
78
 
79
  max_output_length = 1024