# SQL_GEN_Context / app.py
# Author: alibidaran
# Commit: "Update app.py" (e53e706, verified)
# Page chrome from the hosting site: raw | history | blame | contribute | delete
# Scan status: No virus
# Size: 2.35 kB
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer
import gradio as gr
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts, sizes
from typing import Iterable
class SQLGEN(Base):
    """Gradio theme used by the SQLGEN demo.

    The subclass adds no behaviour of its own: every keyword argument is
    passed straight through to ``Base.__init__``. What it contributes is a
    set of defaults — stone/green/gray hues, medium spacing and radius,
    large text, and IBM Plex Mono as the first choice in both font stacks.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.stone,
        secondary_hue: colors.Color | str = colors.green,
        neutral_hue: colors.Color | str = colors.gray,
        spacing_size: sizes.Size | str = sizes.spacing_md,
        radius_size: sizes.Size | str = sizes.radius_md,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        """Initialise the theme; all arguments are keyword-only.

        Each argument is handed to ``Base.__init__`` under the same name.
        """
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
# Hub repository id of the checkpoint this app serves.
model_id = "alibidaran/Gemma2_SQLGEN"

# Optional 4-bit quantisation settings for GPU usage (currently disabled and
# not passed to from_pretrained below):
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16,
# )

# Tokenizer for the checkpoint, configured to pad on the right.
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.padding_side = "right"

# device_map="auto" leaves device placement of the weights to the library.
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
def generate_sql(query, context):
    """Generate SQL text for a question, given a textual context.

    Builds the ``##Question / ##Context / ##Answer`` prompt, samples up to
    100 new tokens from the module-level ``model`` and returns only the text
    produced after the prompt.

    Args:
        query: The question to answer; interpolated verbatim into the prompt.
        context: Accompanying context (presumably the table schema — it is
            interpolated verbatim and not validated here).

    Returns:
        The decoded completion as a string, with special tokens stripped.
    """
    text = f"<s>##Question: {query} \n ##Context: {context} \n ##Answer:"

    # The model is loaded with device_map='auto', so it is not guaranteed to
    # live on a GPU. Follow the model's own device instead of hard-coding
    # 'cuda', which raises on hosts without CUDA.
    inputs = tokenizer(text, return_tensors='pt').to(model.device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            do_sample=True,
            top_p=0.99,
            top_k=10,
            temperature=0.5,
        )

    # The generated sequence starts with the prompt tokens; drop them so only
    # the newly generated part is decoded.
    prompt_length = inputs.input_ids.shape[1]
    completion_ids = outputs[0, prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)
# Web UI: two free-text inputs are passed to generate_sql as (query, context)
# and the returned string is shown in a code component.
interface = gr.Interface(
    fn=generate_sql,
    inputs=['text', 'text'],
    outputs=gr.Code(),
    title='SQLGEN',
    theme=SQLGEN(),
)

# Start the server only when the file is run as a script.
if __name__ == '__main__':
    interface.launch()