Spaces:
Build error
Build error
File size: 2,670 Bytes
675f5c9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils.logger import setup_logger
from utils.model_loader import ModelLoader
from api.shopify_client import ShopifyClient
# Module-level logger from the project's `setup_logger` helper; SQLGenerator uses
# it to report initialization and query-generation failures.
logger = setup_logger(__name__)
class SQLGenerator:
    """Translate natural-language questions into SQL using a causal language model.

    Also exposes a thin pass-through to ``ShopifyClient`` for fetching data.
    """

    # Table description embedded in every prompt so the model knows which
    # columns it may reference.
    # NOTE(review): `id` and `admin_graphql_api_id` are declared DECIMAL(8,2), and
    # `created_at` is VARCHAR while `updated_at` is DATE. These look unlikely to
    # match the real data -- confirm against the actual table, because the model
    # writes SQL against exactly this schema.
    SCHEMA_INFO = """
CREATE TABLE products (
    id DECIMAL(8,2) PRIMARY KEY,
    title VARCHAR(255),
    body_html VARCHAR(255),
    vendor VARCHAR(255),
    product_type VARCHAR(255),
    created_at VARCHAR(255),
    handle VARCHAR(255),
    updated_at DATE,
    published_at VARCHAR(255),
    template_suffix VARCHAR(255),
    published_scope VARCHAR(255),
    tags VARCHAR(255),
    status VARCHAR(255),
    admin_graphql_api_id DECIMAL(8,2),
    variants VARCHAR(255),
    options VARCHAR(255),
    images VARCHAR(255),
    image VARCHAR(255)
);
"""

    def __init__(self):
        """Load the tokenizer and model and create a Shopify client.

        Both the tokenizer and the model are loaded through
        ``ModelLoader.load_model_with_retry``.

        Raises:
            Exception: any error raised while loading or constructing the
                client is logged and re-raised unchanged.
        """
        try:
            self.model_name = "premai-io/prem-1B-SQL"
            self.tokenizer = ModelLoader.load_model_with_retry(
                self.model_name,
                AutoTokenizer,
            )
            self.model = ModelLoader.load_model_with_retry(
                self.model_name,
                AutoModelForCausalLM,
            )
            self.shopify_client = ShopifyClient()
        except Exception as e:
            logger.error(f"Failed to initialize SQLGenerator: {e}")
            raise

    def generate_query(self, natural_language_query, max_new_tokens=256):
        """Generate a SQL query that answers *natural_language_query*.

        Args:
            natural_language_query: The question to translate into SQL.
            max_new_tokens: Maximum number of tokens to generate, not counting
                the prompt. Defaults to 256.

        Returns:
            The generated SQL text, with the prompt removed and surrounding
            whitespace stripped. On any failure the error is logged and the fixed
            string "Failed to generate SQL query due to an error." is returned.
            NOTE(review): callers cannot tell that string apart from SQL by type.
        """
        try:
            prompt = (
                "### Task: Generate a SQL query to answer the following question.\n"
                f"### Database Schema: {self.SCHEMA_INFO}\n"
                f"### Question: {natural_language_query}\n"
                "### SQL Query:"
            )
            inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False)

            # Keep the inputs on the same device as the model (no-op on CPU).
            device = getattr(self.model, "device", None)
            if device is not None:
                inputs = inputs.to(device)
            input_ids = inputs["input_ids"]

            pad_token_id = self.tokenizer.pad_token_id
            if pad_token_id is None:
                # Many causal-LM tokenizers define no pad token; use EOS explicitly.
                pad_token_id = self.tokenizer.eos_token_id

            outputs = self.model.generate(
                input_ids,
                attention_mask=inputs.get("attention_mask"),
                # max_new_tokens budgets only the answer. The old max_length also
                # counted the long schema prompt, which could leave no room for SQL.
                max_new_tokens=max_new_tokens,
                do_sample=False,
                num_return_sequences=1,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=pad_token_id,
            )

            # For a causal LM, generate() returns prompt + continuation, so decode
            # only the newly generated tokens.
            prompt_length = len(input_ids[0])
            generated_ids = outputs[0][prompt_length:]
            return self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        except Exception as e:
            logger.error(f"Query generation error: {e}")
            return "Failed to generate SQL query due to an error."

    def fetch_shopify_data(self, endpoint):
        """Return ``self.shopify_client.fetch_data(endpoint)`` unchanged."""
        return self.shopify_client.fetch_data(endpoint)
|