from typing import Any, Dict, List

import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer


class EndpointHandler:
    """Request handler that runs a token-classification model over input text.

    The tokenizer and model are loaded once in ``__init__`` and reused for
    every call.
    """

    def __init__(self, path: str = "dejanseo/LinkBERT"):
        """Load the tokenizer and the token-classification model.

        Args:
            path: Model identifier or local directory handed to
                ``from_pretrained`` for both the tokenizer and the model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForTokenClassification.from_pretrained(path)
        self.model.eval()  # Set model to evaluation mode

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify the tokens of the request text and return the rebuilt text.

        Args:
            data: Request payload. The text is read from the ``"inputs"`` key;
                a missing key is treated as an empty string.

        Returns:
            A single-element list ``[{"text": <reconstructed text>}]`` where
            the text is the model's tokens joined by spaces with the ``##``
            sub-word continuation markers merged back into the previous token.
            Text longer than the tokenizer's maximum length is truncated
            before inference, so the returned text may be shorter than the
            input.
        """
        text = data.get("inputs", "")

        # truncation=True keeps over-long inputs within the tokenizer's
        # configured maximum length instead of feeding the model a sequence
        # it cannot handle.
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            add_special_tokens=True,
            truncation=True,
        )
        input_ids = encoded["input_ids"]

        # Inference only: no gradients are needed.
        with torch.no_grad():
            outputs = self.model(input_ids)
        predictions = torch.argmax(outputs.logits, dim=-1)

        # Drop the first and last positions (the special tokens added by the
        # tokenizer) from both the tokens and their predicted labels.
        tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])[1:-1]
        labels = predictions[0][1:-1].tolist()

        # NOTE(review): both branches below append the bare token, so the
        # predicted label currently has no effect on the returned text. Some
        # markup around label-1 tokens was presumably intended -- confirm the
        # desired annotation before changing the output format.
        result = []
        for token, pred in zip(tokens, labels):
            if pred == 1:  # Assuming '1' is the label for the class of interest
                result.append(f"{token}")
            else:
                result.append(token)

        # Re-attach WordPiece continuation pieces ("##xyz") to the token
        # before them.
        reconstructed_text = " ".join(result).replace(" ##", "")

        return [{"text": reconstructed_text}]


# Note: Pass the actual path or identifier of your model as 'path' when
# initializing the EndpointHandler.