# map-room/mappingservice/ms/model_loader.py
# Author: Calin Rada
# Commit: init (f006f31, unverified)
import logging
import time
import torch
from optimum.bettertransformer import BetterTransformer
from transformers import (
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoTokenizer,
pipeline,
)
from mappingservice.config import Settings
class ModelLoader:
    """Loads Hugging Face models/tokenizers and caches the resulting
    inference pipelines, keyed by logical model name and language.

    ``model_names`` maps a logical model name to a dict of the form
    ``{"classification": <task>, <language>: <hf-model-id>, ...}``
    (inferred from usage in ``get_model`` — confirm against callers).
    """

    def __init__(self, settings: Settings, model_names):
        self.settings = settings
        # Logical model name -> {"classification": task, <language>: hf model id}
        self.model_names = model_names
        # Cache: logical model name -> {language: pipeline}
        self.model_cache = {}
        self.token = self.settings.huggingface_access_token.get_secret_value()
        # Default task; updated per model in get_model() before loading.
        self.classification = "token-classification"

    def load_model(self, model_name, classification=None):
        """Download and prepare a model + tokenizer for the given HF model id.

        :param model_name: Hugging Face model id to load.
        :param classification: task name ("token-classification" or other);
            falls back to ``self.classification`` when omitted, preserving
            the original call signature.
        :returns: ``(model, tokenizer)`` tuple, model wrapped by
            BetterTransformer for faster inference.
        """
        if classification is None:
            classification = self.classification
        start_time = time.time()
        if classification == "token-classification":
            model = AutoModelForTokenClassification.from_pretrained(model_name, token=self.token)  # noqa: E501
        else:
            model = AutoModelForSequenceClassification.from_pretrained(model_name, token=self.token)  # noqa: E501
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=self.token)
        model = BetterTransformer.transform(model)
        end_time = time.time()
        logging.info(f"Model {model_name} loaded in {end_time - start_time:.2f} seconds.")  # noqa: E501
        return model, tokenizer

    def get_model(self, model_name, language):
        """Return a cached or freshly built pipeline for (model_name, language).

        :param model_name: logical model name (key of ``self.model_names``).
        :param language: language code selecting the concrete HF model id.
        :returns: a transformers pipeline, or ``None`` when the model name
            or language is not configured.
        """
        cached = self.model_cache.get(model_name) or {}
        if language in cached:
            logging.info(f"Using cached model for language: {language}")
            return cached[language]

        model_config = self.model_names.get(model_name)
        if not model_config:
            # BUG FIX: previously raised AttributeError on an unknown model;
            # degrade gracefully like the unsupported-language path.
            logging.warning(f"Unknown model: {model_name}")
            return None

        self.classification = model_config.get("classification")
        # BUG FIX: resolve the HF model id into a separate variable instead of
        # rebinding `model_name` — the original rebinding made the cache store
        # under the HF id while lookups used the logical key, so the cache
        # never hit and models were reloaded on every call.
        hf_model_name = model_config.get(language)
        if not hf_model_name:
            logging.warning(f"Unsupported language: {language}")
            return None

        model, tokenizer = self.load_model(hf_model_name, self.classification)
        pipeline_dict = {
            'task': self.classification,
            'model': model,
            'tokenizer': tokenizer,
            'token': self.token,
            'device': 0 if torch.cuda.is_available() else -1,
        }
        if self.classification == "token-classification":
            pipeline_dict.update({'framework': 'pt'})
        if self.classification == "text-classification":
            pipeline_dict.update({'top_k': 1})
        model_pipeline = pipeline(**pipeline_dict)
        # BUG FIX: setdefault keeps pipelines cached for other languages of the
        # same model instead of overwriting the whole per-model dict.
        self.model_cache.setdefault(model_name, {})[language] = model_pipeline
        return model_pipeline