Spaces:
Build error
Build error
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from utils.logger import setup_logger | |
from utils.model_loader import ModelLoader | |
from api.shopify_client import ShopifyClient | |
# Module-wide logger obtained from the project's setup_logger helper, keyed by
# this module's __name__; used for the error logging in SQLGenerator below.
logger = setup_logger(__name__)
class SQLGenerator:
    """Translate natural-language questions into SQL with a causal language model.

    Loads the ``premai-io/prem-1B-SQL`` tokenizer and model through
    ``ModelLoader.load_model_with_retry`` and keeps a ``ShopifyClient`` that
    :meth:`fetch_shopify_data` delegates to.
    """

    # Hugging Face checkpoint used for both the tokenizer and the model.
    MODEL_NAME = "premai-io/prem-1B-SQL"

    # Returned by generate_query when generation fails. Kept identical to the
    # original message so callers that compare against it keep working.
    ERROR_MESSAGE = "Failed to generate SQL query due to an error."

    # Table definition given to the model as context in every prompt. Stored
    # once on the class instead of being rebuilt on each call.
    # NOTE(review): `id DECIMAL(8,2)` and `admin_graphql_api_id DECIMAL(8,2)`
    # look suspicious for Shopify products (ids are usually large integers and
    # the GraphQL id a "gid://..." string), and created_at/updated_at use
    # different types -- confirm against the actual table before relying on it.
    SCHEMA_INFO = """
CREATE TABLE products (
    id DECIMAL(8,2) PRIMARY KEY,
    title VARCHAR(255),
    body_html VARCHAR(255),
    vendor VARCHAR(255),
    product_type VARCHAR(255),
    created_at VARCHAR(255),
    handle VARCHAR(255),
    updated_at DATE,
    published_at VARCHAR(255),
    template_suffix VARCHAR(255),
    published_scope VARCHAR(255),
    tags VARCHAR(255),
    status VARCHAR(255),
    admin_graphql_api_id DECIMAL(8,2),
    variants VARCHAR(255),
    options VARCHAR(255),
    images VARCHAR(255),
    image VARCHAR(255)
);
"""

    def __init__(self):
        """Load the tokenizer, the model and the Shopify client.

        Raises:
            Exception: anything raised by the model loader or by
                ``ShopifyClient()``; it is logged and then re-raised unchanged.
        """
        try:
            self.model_name = self.MODEL_NAME
            self.tokenizer = ModelLoader.load_model_with_retry(
                self.model_name,
                AutoTokenizer,
            )
            self.model = ModelLoader.load_model_with_retry(
                self.model_name,
                AutoModelForCausalLM,
            )
            self.shopify_client = ShopifyClient()
        except Exception as e:
            logger.error(f"Failed to initialize SQLGenerator: {str(e)}")
            raise

    @classmethod
    def _build_prompt(cls, natural_language_query):
        """Return the instruction prompt asking the model to answer *natural_language_query*."""
        return (
            "### Task: Generate a SQL query to answer the following question.\n"
            f"### Database Schema: {cls.SCHEMA_INFO}\n"
            f"### Question: {natural_language_query}\n"
            "### SQL Query:"
        )

    @staticmethod
    def _completion_ids(output_ids, prompt_length):
        """Return only the generated tokens, dropping the echoed prompt.

        ``generate`` on a decoder-only model returns prompt + continuation, so
        the answer is everything after the first *prompt_length* tokens.
        """
        return output_ids[prompt_length:]

    def generate_query(self, natural_language_query, max_new_tokens=256):
        """Generate a SQL query answering *natural_language_query*.

        Decoding is greedy (``do_sample=False``), producing a single sequence.

        Args:
            natural_language_query: Question text, inserted verbatim into the prompt.
            max_new_tokens: Upper bound on generated tokens, excluding the prompt.

        Returns:
            The decoded, whitespace-stripped completion (the model's SQL), or
            ``ERROR_MESSAGE`` if anything fails; the error is logged.
        """
        try:
            prompt = self._build_prompt(natural_language_query)
            inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
            # Keep the tensors on the same device as the model (no-op on CPU).
            inputs = inputs.to(self.model.device)
            input_ids = inputs["input_ids"]

            # Some causal-LM tokenizers define no pad token; fall back to EOS
            # instead of handing generate() a None.
            pad_token_id = self.tokenizer.pad_token_id
            if pad_token_id is None:
                pad_token_id = self.tokenizer.eos_token_id

            outputs = self.model.generate(
                input_ids,
                attention_mask=inputs.get("attention_mask"),
                # max_new_tokens (not max_length) so the long schema prompt
                # cannot eat the generation budget.
                max_new_tokens=max_new_tokens,
                do_sample=False,
                num_return_sequences=1,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=pad_token_id,
            )
            completion = self._completion_ids(outputs[0], input_ids.shape[-1])
            return self.tokenizer.decode(completion, skip_special_tokens=True).strip()
        except Exception as e:
            logger.error(f"Query generation error: {str(e)}")
            return self.ERROR_MESSAGE

    def fetch_shopify_data(self, endpoint):
        """Return whatever ``ShopifyClient.fetch_data`` yields for *endpoint*."""
        return self.shopify_client.fetch_data(endpoint)