Tonic committed on
Commit
b884f75
1 Parent(s): c3ced67

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -10
app.py CHANGED
@@ -9,16 +9,18 @@ You can also use defog 🌬️🌁🌫️SqlCoder by cloning this space. 🧬🔬
9
  Join us : 🌟TeamTonic🌟 is always making cool demos! Join our active builder's🛠️community 👻[![Let's build the future of AI together! 🚀🤖](https://discordapp.com/api/guilds/1109943800132010065/widget.png)](https://discord.gg/GWpVpekp) On 🤗Huggingface: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Polytonic](https://github.com/tonic-ai) & contribute to 🌟 [Poly](https://github.com/tonic-ai/poly) 🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
10
  """
11
 
 
 
12
  def load_tokenizer_model(model_name):
13
- tokenizer = AutoTokenizer.from_pretrained(model_name)
14
- model = AutoModelForCausalLM.from_pretrained(
 
15
  model_name,
16
  trust_remote_code=True,
17
  torch_dtype=torch.float16,
18
  device_map="auto",
19
  use_cache=True,
20
  )
21
- return tokenizer, model
22
 
23
  def generate_prompt(question, prompt_file="prompt.md", metadata_file="metadata.sql"):
24
  with open(prompt_file, "r") as f:
@@ -33,14 +35,15 @@ def generate_prompt(question, prompt_file="prompt.md", metadata_file="metadata.s
33
  return prompt
34
 
35
  @spaces.GPU
36
- def run_inference(question, model, tokenizer):
37
- model.to('cuda')
 
38
  prompt = generate_prompt(question)
39
- eos_token_id = tokenizer.eos_token_id
40
  pipe = pipeline(
41
  "text-generation",
42
- model=model,
43
- tokenizer=tokenizer,
44
  max_new_tokens=300,
45
  do_sample=False,
46
  num_beams=5,
@@ -62,14 +65,14 @@ def run_inference(question, model, tokenizer):
62
 
63
  def main():
64
  model_name = "defog/sqlcoder2"
65
- tokenizer, model = load_tokenizer_model(model_name)
66
 
67
  with gr.Blocks() as demo:
68
  gr.Markdown(title)
69
  question = gr.Textbox(label="Enter your question")
70
  submit = gr.Button("Generate SQL Query")
71
  output = gr.Textbox(label="🌬️🌁🌫️SqlCoder-2")
72
- submit.click(fn=lambda x: run_inference(x, model, tokenizer), inputs=question, outputs=output)
73
 
74
  demo.launch()
75
 
 
9
  Join us : 🌟TeamTonic🌟 is always making cool demos! Join our active builder's🛠️community 👻[![Let's build the future of AI together! 🚀🤖](https://discordapp.com/api/guilds/1109943800132010065/widget.png)](https://discord.gg/GWpVpekp) On 🤗Huggingface: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Polytonic](https://github.com/tonic-ai) & contribute to 🌟 [Poly](https://github.com/tonic-ai/poly) 🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
10
  """
11
 
12
+ global_tokenizer, global_model = None, None
13
+
14
  def load_tokenizer_model(model_name):
15
+ global global_tokenizer, global_model
16
+ global_tokenizer = AutoTokenizer.from_pretrained(model_name)
17
+ global_model = AutoModelForCausalLM.from_pretrained(
18
  model_name,
19
  trust_remote_code=True,
20
  torch_dtype=torch.float16,
21
  device_map="auto",
22
  use_cache=True,
23
  )
 
24
 
25
  def generate_prompt(question, prompt_file="prompt.md", metadata_file="metadata.sql"):
26
  with open(prompt_file, "r") as f:
 
35
  return prompt
36
 
37
  @spaces.GPU
38
+ def run_inference(question):
39
+ global global_tokenizer, global_model
40
+ global_model.to('cuda')
41
  prompt = generate_prompt(question)
42
+ eos_token_id = global_tokenizer.eos_token_id
43
  pipe = pipeline(
44
  "text-generation",
45
+ model=global_model,
46
+ tokenizer=global_tokenizer,
47
  max_new_tokens=300,
48
  do_sample=False,
49
  num_beams=5,
 
65
 
66
  def main():
67
  model_name = "defog/sqlcoder2"
68
+ load_tokenizer_model(model_name)
69
 
70
  with gr.Blocks() as demo:
71
  gr.Markdown(title)
72
  question = gr.Textbox(label="Enter your question")
73
  submit = gr.Button("Generate SQL Query")
74
  output = gr.Textbox(label="🌬️🌁🌫️SqlCoder-2")
75
+ submit.click(fn=run_inference, inputs=question, outputs=output)
76
 
77
  demo.launch()
78