from typing import Any, Dict, List

import torch
from transformers import AutoConfig, BertForTokenClassification, BertTokenizer


class EndpointHandler:
    def __init__(self, path: str = "dejanseo/LinkBERT"):
        # Load the configuration from the saved model;
        # num_labels is taken from this config, so no manual override is needed
        self.config = AutoConfig.from_pretrained(path)

        # Load the token-classification model with that configuration
        self.model = BertForTokenClassification.from_pretrained(
            path,
            config=self.config
        )
        self.model.eval()  # Set model to evaluation mode

        # Load the tokenizer for bert-large-cased, the base model
        self.tokenizer = BertTokenizer.from_pretrained("bert-large-cased")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Extract input text from the request payload
        inputs = data.get("inputs", "")

        # Tokenize the input, adding the [CLS] and [SEP] special tokens
        encoded = self.tokenizer(inputs, return_tensors="pt", add_special_tokens=True)

        # Run the model without gradient tracking, passing the attention mask
        # along with the input ids, and take the argmax label per token
        with torch.no_grad():
            outputs = self.model(**encoded)
        predictions = torch.argmax(outputs.logits, dim=-1)

        # Drop the [CLS] and [SEP] positions from tokens and predictions alike
        tokens = self.tokenizer.convert_ids_to_tokens(encoded["input_ids"][0])[1:-1]
        predictions = predictions[0][1:-1].tolist()

        # Reconstruct the text, annotating tokens the model labels as 1.
        # The f-string below is a pass-through placeholder; replace it with
        # whatever markup your application needs for positive tokens.
        result = []
        for token, pred in zip(tokens, predictions):
            if pred == 1:
                result.append(f"{token}")
            else:
                result.append(token)

        # Merge WordPiece continuations ("##...") back into whole words
        reconstructed_text = " ".join(result).replace(" ##", "")

        # Return the processed text in a structured format
        return [{"text": reconstructed_text}]

# Note: Ensure the path "dejanseo/LinkBERT" correctly points to your model's location.
# If the model is saved locally, adjust the path accordingly.
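
# A minimal local smoke test (a sketch, not part of the handler contract):
# it assumes the model weights are reachable at the default path and that
# requests arrive as {"inputs": "<text>"}, the payload shape Hugging Face
# Inference Endpoints sends to a custom handler's __call__.
if __name__ == "__main__":
    handler = EndpointHandler()
    response = handler({"inputs": "Read the documentation to learn more."})
    print(response)  # e.g. [{"text": "Read the documentation to learn more."}]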