Spaces:
Build error
Build error
import torch | |
from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
from utils.logger import setup_logger | |
from utils.model_loader import ModelLoader | |
# Module-level logger used by IntentClassifier to report load and classification errors.
logger = setup_logger(__name__)
class IntentClassifier:
    """Classify a free-text query into one of two intents.

    Wraps a Hugging Face sequence-classification model and maps the predicted
    class index to an intent name through ``self.intents``.
    """

    def __init__(self):
        """Load the model and tokenizer named by ``self.model_name``.

        Raises:
            Exception: whatever ``ModelLoader.load_model_with_retry`` raises;
                the error is logged and re-raised unchanged.
        """
        self.model_name = "distilbert-base-uncased-finetuned-sst-2-english"
        try:
            self.model = ModelLoader.load_model_with_retry(
                self.model_name,
                AutoModelForSequenceClassification,
                num_labels=2,
            )
            self.tokenizer = ModelLoader.load_model_with_retry(
                self.model_name,
                AutoTokenizer,
            )
            # NOTE(review): the checkpoint name refers to an SST-2 fine-tune;
            # this index -> intent mapping is assigned here rather than learned
            # for these intents — confirm this is intended.
            self.intents = {0: "database_query", 1: "product_description"}
        except Exception as e:
            logger.error(f"Failed to initialize IntentClassifier: {str(e)}")
            raise

    def classify(self, query):
        """Return ``(intent, confidence)`` for a single query.

        Args:
            query: text passed to the tokenizer. Only the first sequence of
                the resulting batch is scored.

        Returns:
            A tuple of the intent name from ``self.intents`` and the softmax
            probability of that class as a Python float. If anything raises,
            the error is logged and ``("error", 0.0)`` is returned instead.
        """
        try:
            inputs = self.tokenizer(
                query, return_tensors="pt", truncation=True, padding=True
            )
            # Inference only: don't build an autograd graph for the forward pass.
            with torch.no_grad():
                outputs = self.model(**inputs)
            # Take the first sequence's class distribution, then argmax over the
            # class axis; an argmax over the flattened tensor would return an
            # index that is wrong (or out of range) once there is more than one row.
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
            predicted_class = torch.argmax(probabilities).item()
            return self.intents[predicted_class], probabilities[predicted_class].item()
        except Exception as e:
            logger.error(f"Classification error: {str(e)}")
            return "error", 0.0