import logging
import time

import torch
from optimum.bettertransformer import BetterTransformer
from transformers import (
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoTokenizer,
    pipeline,
)

from mappingservice.config import Settings

logger = logging.getLogger(__name__)


class ModelLoader:
    def __init__(self, settings: Settings, model_names):
        self.settings = settings
        # Mapping of logical model name -> task type ("classification")
        # and per-language Hugging Face repo ids.
        self.model_names = model_names
        # Ready pipelines, keyed by logical model name, then by language.
        self.model_cache = {}
        self.token = self.settings.huggingface_access_token.get_secret_value()

    def load_model(self, model_name, task):
        """Load the model and tokenizer for the given task from the Hub."""
        start_time = time.time()
        if task == "token-classification":
            model = AutoModelForTokenClassification.from_pretrained(
                model_name, token=self.token
            )
        else:
            model = AutoModelForSequenceClassification.from_pretrained(
                model_name, token=self.token
            )
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=self.token)
        # Swap in BetterTransformer's fused kernels for faster inference.
        model = BetterTransformer.transform(model)
        elapsed = time.time() - start_time
        logger.info(f"Model {model_name} loaded in {elapsed:.2f} seconds.")
        return model, tokenizer

    def get_model(self, model_name, language):
        cached = self.model_cache.get(model_name, {})
        if language in cached:
            logger.info(f"Using cached model for language: {language}")
            return cached[language]

        entry = self.model_names.get(model_name)
        if not entry:
            logger.warning(f"Unknown model: {model_name}")
            return None
        task = entry.get("classification")
        # Resolve the logical name to the language-specific repo id in a
        # separate variable; reusing `model_name` here would shadow the
        # cache key and make the cache lookup above never hit.
        hf_model_name = entry.get(language)
        if not hf_model_name:
            logger.warning(f"Unsupported language: {language}")
            return None

        # Pass the task explicitly instead of stashing it on the instance,
        # so interleaved get_model() calls cannot race on shared state.
        model, tokenizer = self.load_model(hf_model_name, task)
        pipeline_kwargs = {
            "task": task,
            "model": model,
            "tokenizer": tokenizer,
            "token": self.token,
            "device": 0 if torch.cuda.is_available() else -1,
        }
        if task == "token-classification":
            pipeline_kwargs["framework"] = "pt"
        if task == "text-classification":
            pipeline_kwargs["top_k"] = 1
        model_pipeline = pipeline(**pipeline_kwargs)

        # setdefault preserves pipelines already cached for other languages
        # of the same model instead of overwriting the whole entry.
        self.model_cache.setdefault(model_name, {})[language] = model_pipeline
        return model_pipeline
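
if __name__ == "__main__":
    # Minimal usage sketch, not exercised by the service itself. It shows
    # the shape of `model_names` that get_model() expects: a task type under
    # the "classification" key plus one Hugging Face repo id per language.
    # Assumptions not confirmed by this module: Settings can be constructed
    # with no arguments (e.g. a pydantic BaseSettings that reads the
    # Hugging Face token from the environment), and the repo ids below are
    # hypothetical placeholders.
    example_model_names = {
        "room_type": {
            "classification": "token-classification",
            "en": "org/room-type-ner-en",  # hypothetical repo id
            "es": "org/room-type-ner-es",  # hypothetical repo id
        },
    }
    loader = ModelLoader(Settings(), example_model_names)
    nlp = loader.get_model("room_type", "en")
    if nlp is not None:
        # A second call for the same model/language pair is served
        # from model_cache without reloading from the Hub.
        print(nlp("Deluxe double room with sea view"))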