# LinkBERT / handler.py
# Source: Hugging Face Hub file view (author: dejanseo), commit 65c7bf6
# "Update handler.py", 1.81 kB.
from transformers import AutoModelForTokenClassification, AutoTokenizer
import torch
from typing import Dict, List, Any
class EndpointHandler:
    """Inference Endpoints handler that marks tokens predicted as links.

    The request text is run through a token-classification model. The text
    is then rebuilt from the model's tokens, and every token whose predicted
    class id equals ``LINK_LABEL`` is wrapped in ``<u>...</u>``.
    """

    # Class id treated as the "class of interest" (link) in the model output.
    # NOTE(review): 1 is carried over from the original assumption; confirm it
    # against the model's config.id2label.
    LINK_LABEL = 1

    def __init__(self, path: str = "dejanseo/LinkBERT"):
        """Load the tokenizer and model from *path* and set eval mode.

        Args:
            path: Local directory or Hub model id passed to ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForTokenClassification.from_pretrained(path)
        self.model.eval()  # disable dropout etc. for inference

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Annotate the request text with predicted link tokens.

        Args:
            data: Request payload; the text is read from ``data["inputs"]``
                (an empty string is used when the key is missing). If a batch
                is passed, only the first sequence is annotated.

        Returns:
            ``[{"text": annotated}]``, where ``annotated`` is the text rebuilt
            from the tokenizer's tokens (tokens joined by single spaces, with
            ``##`` continuation pieces glued to the previous token) and every
            token predicted as ``LINK_LABEL`` wrapped in ``<u>...</u>``.
            Input longer than the tokenizer's maximum length is truncated
            rather than making the model raise.
        """
        text = data.get("inputs", "")
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            add_special_tokens=True,
            # Without truncation, over-long inputs exceed the model's position
            # embeddings and the forward pass raises.
            truncation=True,
            # Lets us drop [CLS]/[SEP]-style tokens without assuming exactly
            # one at each end.
            return_special_tokens_mask=True,
        )
        # Must be removed before the forward pass: the model does not accept it.
        special_mask = encoded.pop("special_tokens_mask")[0].tolist()

        with torch.no_grad():
            logits = self.model(**encoded).logits
        predictions = torch.argmax(logits, dim=-1)[0].tolist()
        tokens = self.tokenizer.convert_ids_to_tokens(encoded["input_ids"][0])

        kept_tokens: List[str] = []
        kept_preds: List[int] = []
        for token, pred, is_special in zip(tokens, predictions, special_mask):
            if not is_special:
                kept_tokens.append(token)
                kept_preds.append(pred)

        rendered = self._render_tokens(kept_tokens, kept_preds, self.LINK_LABEL)
        return [{"text": rendered}]

    @staticmethod
    def _render_tokens(
        tokens: List[str], predictions: List[int], link_label: int = 1
    ) -> str:
        """Join WordPiece tokens into text, underlining those labelled *link_label*.

        Tokens are separated by one space, except ``##``-prefixed continuation
        pieces, which lose the prefix and attach to the previous token. This
        holds whether or not either piece is underlined.
        """
        pieces: List[str] = []
        for token, label in zip(tokens, predictions):
            is_continuation = token.startswith("##")
            piece = token[2:] if is_continuation else token
            if label == link_label:
                piece = f"<u>{piece}</u>"
            if pieces and not is_continuation:
                pieces.append(" ")
            pieces.append(piece)
        return "".join(pieces)
# Note: `path` defaults to the "dejanseo/LinkBERT" Hub id; pass a different local directory or model id when initializing EndpointHandler to load another model.