Spaces:
Paused
Paused
import logging
import time

import torch
from optimum.bettertransformer import BetterTransformer
from transformers import (
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoTokenizer,
    pipeline,
)

from mappingservice.config import Settings
class ModelLoader:
    """Loads Hugging Face models and builds inference pipelines, caching one
    pipeline per (logical model name, language) pair.

    ``model_names`` maps a logical model name to a dict holding a
    ``"classification"`` task key plus one checkpoint name per language code
    — assumed from usage below; confirm against the caller's config.
    """

    def __init__(self, settings: "Settings", model_names):
        self.settings = settings
        self.model_names = model_names
        # Cache keyed by the *logical* model name (the key callers pass to
        # get_model), then by language code.
        self.model_cache = {}
        self.token = self.settings.huggingface_access_token.get_secret_value()
        # Default task; get_model overrides it per model from model_names.
        self.classification = "token-classification"

    def load_model(self, model_name, classification=None):
        """Load a checkpoint and its tokenizer from the Hugging Face Hub.

        ``classification`` selects the auto-model class ("token-classification"
        vs. anything else → sequence classification). When omitted it falls
        back to ``self.classification``, preserving the original call shape.
        Returns a ``(model, tokenizer)`` tuple.
        """
        if classification is None:
            classification = self.classification
        start_time = time.time()
        if classification == "token-classification":
            model = AutoModelForTokenClassification.from_pretrained(
                model_name, token=self.token
            )
        else:
            model = AutoModelForSequenceClassification.from_pretrained(
                model_name, token=self.token
            )
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=self.token)
        # NOTE(review): BetterTransformer is deprecated in recent optimum
        # releases (native torch SDPA replaces it) — confirm before upgrading.
        model = BetterTransformer.transform(model)
        end_time = time.time()
        logging.info(f"Model {model_name} loaded in {end_time - start_time:.2f} seconds.")  # noqa: E501
        return model, tokenizer

    def get_model(self, model_name, language):
        """Return a cached or freshly built pipeline for ``model_name`` in
        ``language``, or ``None`` when the model or language is unknown."""
        cached = self.model_cache.get(model_name) or {}
        if language in cached:
            logging.info(f"Using cached model for language: {language}")
            return cached[language]

        entry = self.model_names.get(model_name)
        # Robustness fix: the original dereferenced .get(model_name) without
        # a None check and raised AttributeError on an unknown model name.
        if not entry:
            logging.warning(f"Unsupported model: {model_name}")
            return None
        self.classification = entry.get("classification")
        checkpoint = entry.get(language)
        if not checkpoint:
            logging.warning(f"Unsupported language: {language}")
            return None

        model, tokenizer = self.load_model(checkpoint, self.classification)
        pipeline_kwargs = {
            'task': self.classification,
            'model': model,
            'tokenizer': tokenizer,
            'token': self.token,
            'device': 0 if torch.cuda.is_available() else -1,
        }
        if self.classification == "token-classification":
            pipeline_kwargs['framework'] = 'pt'
        if self.classification == "text-classification":
            pipeline_kwargs['top_k'] = 1
        model_pipeline = pipeline(**pipeline_kwargs)
        # Bug fix: cache under the *logical* model name used for lookups.
        # The original rebound model_name to the resolved checkpoint and
        # stored under that, so the cache check above never hit and every
        # call reloaded the model. setdefault also preserves pipelines
        # already cached for other languages instead of clobbering them.
        self.model_cache.setdefault(model_name, {})[language] = model_pipeline
        return model_pipeline