Tonic committed on
Commit
b6bd3b8
1 Parent(s): a513939

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -17
app.py CHANGED
@@ -9,14 +9,11 @@ You can also use efog 🌬️🌁🌫️SqlCoder by cloning this space. 🧬🔬
9
  Join us : 🌟TeamTonic🌟 is always making cool demos! Join our active builder's🛠️community 👻[![Let's build the future of AI together! 🚀🤖](https://discordapp.com/api/guilds/1109943800132010065/widget.png)](https://discord.gg/GWpVpekp) On 🤗Huggingface: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Polytonic](https://github.com/tonic-ai) & contribute to 🌟 [Poly](https://github.com/tonic-ai/poly) 🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
10
  """
11
 
12
- @spaces.GPU
13
- class SQLQueryGenerator:
14
- def __init__(self, model_name, prompt_file="prompt.md", metadata_file="metadata.sql"):
15
- self.tokenizer, self.model = self.get_tokenizer_model(model_name)
16
- self.prompt_file = prompt_file
17
- self.metadata_file = metadata_file
18
-
19
- def get_tokenizer_model(self, model_name):
20
  tokenizer = AutoTokenizer.from_pretrained(model_name)
21
  model = AutoModelForCausalLM.from_pretrained(
22
  model_name,
@@ -27,6 +24,12 @@ class SQLQueryGenerator:
27
  )
28
  return tokenizer, model
29
 
 
 
 
 
 
 
30
  def generate_prompt(self, question):
31
  with open(self.prompt_file, "r") as f:
32
  prompt = f.read()
@@ -39,14 +42,15 @@ class SQLQueryGenerator:
39
  )
40
  return prompt
41
 
 
42
  def run_inference(self, question):
43
- self.model.to('cuda')
44
  prompt = self.generate_prompt(question)
45
- eos_token_id = self.tokenizer.eos_token_id
46
  pipe = pipeline(
47
  "text-generation",
48
- model=self.model,
49
- tokenizer=self.tokenizer,
50
  max_new_tokens=300,
51
  do_sample=False,
52
  num_beams=5,
@@ -66,19 +70,17 @@ class SQLQueryGenerator:
66
  )
67
  return generated_query
68
 
69
- def generate_sql(question, sql_query_generator):
70
- return sql_query_generator.run_inference(question)
71
-
72
  def main():
73
  model_name = "defog/sqlcoder2"
74
- sql_query_generator = SQLQueryGenerator(model_name)
 
75
 
76
  with gr.Blocks() as demo:
77
  gr.Markdown(title)
78
  question = gr.Textbox(label="Enter your question")
79
  submit = gr.Button("Generate SQL Query")
80
  output = gr.Textbox(label="🌬️🌁🌫️SqlCoder-2")
81
- submit.click(fn=generate_sql, inputs=[question, gr.State(sql_query_generator)], outputs=output)
82
 
83
  demo.launch()
84
 
 
9
  Join us : 🌟TeamTonic🌟 is always making cool demos! Join our active builder's🛠️community 👻[![Let's build the future of AI together! 🚀🤖](https://discordapp.com/api/guilds/1109943800132010065/widget.png)](https://discord.gg/GWpVpekp) On 🤗Huggingface: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Polytonic](https://github.com/tonic-ai) & contribute to 🌟 [Poly](https://github.com/tonic-ai/poly) 🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
10
  """
11
 
12
+ class TokenizerModel:
13
+ def __init__(self, model_name):
14
+ self.tokenizer, self.model = self.load_model(model_name)
15
+
16
+ def load_model(self, model_name):
 
 
 
17
  tokenizer = AutoTokenizer.from_pretrained(model_name)
18
  model = AutoModelForCausalLM.from_pretrained(
19
  model_name,
 
24
  )
25
  return tokenizer, model
26
 
27
+ class SQLQueryGenerator:
28
+ def __init__(self, tokenizer_model, prompt_file="prompt.md", metadata_file="metadata.sql"):
29
+ self.tokenizer_model = tokenizer_model
30
+ self.prompt_file = prompt_file
31
+ self.metadata_file = metadata_file
32
+
33
  def generate_prompt(self, question):
34
  with open(self.prompt_file, "r") as f:
35
  prompt = f.read()
 
42
  )
43
  return prompt
44
 
45
+ @spaces.GPU
46
  def run_inference(self, question):
47
+ self.tokenizer_model.model.to('cuda')
48
  prompt = self.generate_prompt(question)
49
+ eos_token_id = self.tokenizer_model.tokenizer.eos_token_id
50
  pipe = pipeline(
51
  "text-generation",
52
+ model=self.tokenizer_model.model,
53
+ tokenizer=self.tokenizer_model.tokenizer,
54
  max_new_tokens=300,
55
  do_sample=False,
56
  num_beams=5,
 
70
  )
71
  return generated_query
72
 
 
 
 
73
  def main():
74
  model_name = "defog/sqlcoder2"
75
+ tokenizer_model = TokenizerModel(model_name)
76
+ sql_query_generator = SQLQueryGenerator(tokenizer_model)
77
 
78
  with gr.Blocks() as demo:
79
  gr.Markdown(title)
80
  question = gr.Textbox(label="Enter your question")
81
  submit = gr.Button("Generate SQL Query")
82
  output = gr.Textbox(label="🌬️🌁🌫️SqlCoder-2")
83
+ submit.click(fn=sql_query_generator.run_inference, inputs=question, outputs=output)
84
 
85
  demo.launch()
86