File size: 1,412 Bytes
bf23bc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from utils.logger import setup_logger
from utils.model_loader import ModelLoader

# Module-level logger obtained from the project's setup_logger helper, keyed by this module's name.
logger = setup_logger(__name__)

class IntentClassifier:
    """Map a text query to one of two intents with a sequence-classification model.

    The model and tokenizer are loaded once at construction time through
    ``ModelLoader.load_model_with_retry``. Class index 0 is reported as
    ``"database_query"`` and class index 1 as ``"product_description"``.

    NOTE(review): the checkpoint name suggests an SST-2 sentiment fine-tune;
    confirm that its two output labels really correspond to these intents.
    """

    def __init__(self):
        """Load the classification model and its tokenizer.

        Raises:
            Exception: any failure while loading is logged and re-raised
                unchanged, so construction fails loudly.
        """
        self.model_name = "distilbert-base-uncased-finetuned-sst-2-english"
        try:
            self.model = ModelLoader.load_model_with_retry(
                self.model_name,
                AutoModelForSequenceClassification,
                num_labels=2
            )
            self.tokenizer = ModelLoader.load_model_with_retry(
                self.model_name,
                AutoTokenizer
            )
            # Class index -> intent label returned by classify().
            self.intents = {0: "database_query", 1: "product_description"}
        except Exception as e:
            logger.error(f"Failed to initialize IntentClassifier: {e}")
            raise

    def classify(self, query):
        """Classify a single query and report the winning intent.

        Args:
            query: Text passed to the tokenizer. If the tokenizer produces
                more than one sequence, only the first one is scored.

        Returns:
            A ``(intent, confidence)`` tuple, where ``confidence`` is the
            softmax probability of the predicted class. On any failure the
            error is logged and ``("error", 0.0)`` is returned instead of
            raising.
        """
        try:
            inputs = self.tokenizer(query, return_tensors="pt", truncation=True, padding=True)
            # Inference only: skip autograd bookkeeping so no graph is built
            # and the returned logits do not track gradients.
            with torch.no_grad():
                outputs = self.model(**inputs)
            # Take the first row before argmax so the predicted index and the
            # reported confidence always refer to the same sequence (a flat
            # argmax over a multi-row batch can yield an out-of-range index).
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
            predicted_class = int(torch.argmax(probabilities).item())
            return self.intents[predicted_class], probabilities[predicted_class].item()
        except Exception as e:
            logger.error(f"Classification error: {e}")
            return "error", 0.0