"""Natural-language-to-SQL generation backed by the prem-1B-SQL model."""

from transformers import AutoModelForCausalLM, AutoTokenizer

from api.shopify_client import ShopifyClient
from utils.logger import setup_logger
from utils.model_loader import ModelLoader

logger = setup_logger(__name__)


class SQLGenerator:
    """Translate natural-language questions into SQL using a causal LM.

    Also exposes a thin passthrough to ``ShopifyClient.fetch_data``.
    """

    # Text returned by ``generate_query`` when no query can be produced.
    # Kept as a plain return value (rather than an exception) so existing
    # callers that compare against it keep working.
    GENERATION_ERROR_MESSAGE = "Failed to generate SQL query due to an error."

    # Upper bound on tokens generated *after* the prompt. ``max_length`` would
    # also count the (long, schema-bearing) prompt and could leave no room
    # for the answer.
    MAX_NEW_TOKENS = 256

    # DDL given to the model as context. Built once per class instead of on
    # every call. NOTE(review): column types (e.g. DECIMAL ids) are copied
    # verbatim from the original prompt; confirm they match the real data.
    SCHEMA_INFO = """
CREATE TABLE products (
    id DECIMAL(8,2) PRIMARY KEY,
    title VARCHAR(255),
    body_html VARCHAR(255),
    vendor VARCHAR(255),
    product_type VARCHAR(255),
    created_at VARCHAR(255),
    handle VARCHAR(255),
    updated_at DATE,
    published_at VARCHAR(255),
    template_suffix VARCHAR(255),
    published_scope VARCHAR(255),
    tags VARCHAR(255),
    status VARCHAR(255),
    admin_graphql_api_id DECIMAL(8,2),
    variants VARCHAR(255),
    options VARCHAR(255),
    images VARCHAR(255),
    image VARCHAR(255)
);
"""

    def __init__(self):
        """Load the tokenizer and model and create the Shopify client.

        Raises:
            Exception: whatever ``ModelLoader.load_model_with_retry`` or
                ``ShopifyClient()`` raises is logged and re-raised unchanged.
        """
        try:
            self.model_name = "premai-io/prem-1B-SQL"
            self.tokenizer = ModelLoader.load_model_with_retry(
                self.model_name, AutoTokenizer
            )
            self.model = ModelLoader.load_model_with_retry(
                self.model_name, AutoModelForCausalLM
            )
            self.shopify_client = ShopifyClient()
        except Exception as e:
            # ``exception`` keeps the traceback, which ``error`` dropped.
            logger.exception("Failed to initialize SQLGenerator: %s", e)
            raise

    def _build_prompt(self, natural_language_query):
        """Return the instruction prompt embedding the schema and question."""
        return (
            "### Task: Generate a SQL query to answer the following question.\n\n"
            f"### Database Schema:\n{self.SCHEMA_INFO}\n"
            f"### Question:\n{natural_language_query}\n\n"
            "### SQL Query:"
        )

    def generate_query(self, natural_language_query):
        """Generate a SQL query that answers ``natural_language_query``.

        Args:
            natural_language_query: The user's question in plain language.

        Returns:
            The text generated by the model after the prompt, stripped of
            surrounding whitespace. Returns ``GENERATION_ERROR_MESSAGE`` if
            the question is empty/``None`` or if generation raises.
        """
        if natural_language_query is None or not str(natural_language_query).strip():
            logger.warning("generate_query called with an empty question")
            return self.GENERATION_ERROR_MESSAGE

        try:
            prompt = self._build_prompt(natural_language_query)
            inputs = self.tokenizer(
                prompt, return_tensors="pt", add_special_tokens=False
            )

            # Keep input tensors on the same device as the model weights;
            # otherwise a GPU-resident model fails on CPU tensors.
            device = getattr(self.model, "device", None)
            if device is not None:
                inputs = inputs.to(device)

            input_ids = inputs["input_ids"]
            prompt_length = input_ids.shape[-1]

            # Some tokenizers define no pad token; fall back to EOS so
            # generate() does not receive ``None``.
            pad_token_id = self.tokenizer.pad_token_id
            if pad_token_id is None:
                pad_token_id = self.tokenizer.eos_token_id

            outputs = self.model.generate(
                input_ids,
                attention_mask=inputs.get("attention_mask"),
                max_new_tokens=self.MAX_NEW_TOKENS,
                do_sample=False,
                num_return_sequences=1,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=pad_token_id,
            )

            # generate() returns prompt + continuation for causal LMs; decode
            # only the continuation so the prompt is not echoed back as "SQL".
            generated_ids = outputs[0][prompt_length:]
            return self.tokenizer.decode(
                generated_ids, skip_special_tokens=True
            ).strip()
        except Exception as e:
            logger.exception("Query generation error: %s", e)
            return self.GENERATION_ERROR_MESSAGE

    def fetch_shopify_data(self, endpoint):
        """Delegate to ``ShopifyClient.fetch_data`` and return its result."""
        return self.shopify_client.fetch_data(endpoint)