nileshhanotia commited on
Commit
bf23bc0
1 Parent(s): d0e5d92

Create intent_classifier.py

Browse files
Files changed (1) hide show
  1. intent_classifier.py +35 -0
intent_classifier.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
+ from utils.logger import setup_logger
4
+ from utils.model_loader import ModelLoader
5
+
6
+ logger = setup_logger(__name__)
7
+
8
+ class IntentClassifier:
9
+ def __init__(self):
10
+ self.model_name = "distilbert-base-uncased-finetuned-sst-2-english"
11
+ try:
12
+ self.model = ModelLoader.load_model_with_retry(
13
+ self.model_name,
14
+ AutoModelForSequenceClassification,
15
+ num_labels=2
16
+ )
17
+ self.tokenizer = ModelLoader.load_model_with_retry(
18
+ self.model_name,
19
+ AutoTokenizer
20
+ )
21
+ self.intents = {0: "database_query", 1: "product_description"}
22
+ except Exception as e:
23
+ logger.error(f"Failed to initialize IntentClassifier: {str(e)}")
24
+ raise
25
+
26
+ def classify(self, query):
27
+ try:
28
+ inputs = self.tokenizer(query, return_tensors="pt", truncation=True, padding=True)
29
+ outputs = self.model(**inputs)
30
+ probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
31
+ predicted_class = torch.argmax(probabilities).item()
32
+ return self.intents[predicted_class], probabilities[0][predicted_class].item()
33
+ except Exception as e:
34
+ logger.error(f"Classification error: {str(e)}")
35
+ return "error", 0.0