import concurrent.futures
import os
import re
import shutil
import tempfile
from collections import Counter
from typing import Dict, List

import huggingface_hub
import torch
from PIL import Image
from torch import nn
from torch.nn import Softmax
from torchvision import models, transforms
from transformers import AutoModelForSequenceClassification, AutoTokenizer


class BaseModel:
    """Common interface: every rater exposes a keyword-only inference method."""

    def inference(self, *, image: Image.Image | None = None, prompt: str | None = None):
        pass


class ImageRaterModel(BaseModel):
    """
    A class representing an image rating model.

    This class encapsulates a deep learning model for rating images into predefined categories.
    It provides methods for loading the model, preprocessing images, and making predictions.

    Attributes:
        repo_id (str): The identifier of the Hugging Face repository containing the model.
        model_id (str): The identifier of the specific model to be loaded.
        transform (torchvision.transforms.Compose): The sequence of transformations applied to input images.
        num_classes (int): The number of rating classes/categories.
        class_names (List[str]): A list of human-readable names corresponding to each rating class.
        device (torch.device): The device (CPU or GPU) on which the model is loaded and inference is performed.

    Methods:
        __init__: Initializes the image rating model.
        get_architecture: Returns the architecture name of the loaded model. Currently supports resnet18 and resnet50.
        preprocess_image_object: Preprocesses an input image for model inference.
        inference: Performs inference on a single input image and returns the predicted rating class.
        load_model: Loads the deep learning model from the Hugging Face repository.
    """

    def __init__(self, repo_id: str, model_id: str,
                 image_transform: transforms.Compose | None = None,
                 num_classes: int = 5,
                 class_names: List[str] | None = None,
                 device: torch.device = torch.device('cpu')) -> None:

        self.repo_id = repo_id
        self.model_id = model_id
        self.num_classes = num_classes
        # Default to the ImageNet-style validation transform used during training.
        self.transform = image_transform if image_transform is not None else transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.CenterCrop((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.device = device
        self.model = self.load_model()
        self.model.to(device)
        # Use the caller-supplied class names rather than re-hardcoding the list.
        self.class_names = class_names if class_names is not None else ["PG", "PG13", "R", "X", "XXX"]

    def get_architecture(self) -> str:
        """
        Returns the architecture of the loaded model as a string.
        """
        if 'resnet18' in self.model_id.lower():
            return 'resnet18'
        elif 'resnet50' in self.model_id.lower():
            return 'resnet50'
        else:
            raise ValueError("Unsupported architecture. Please specify 'resnet18' or 'resnet50'")

    def preprocess_image_object(self, imageObject: Image.Image) -> torch.Tensor:
        """
        Applies the same preprocessing as the validation dataset used for model
        training, returning a batched tensor of shape (1, 3, 224, 224).
        NOTE: THIS IS FOR RESNET18_100EPOCHS_MAXV2
        """
        # Normalize expects three channels, so coerce RGBA (and any other
        # non-RGB mode, e.g. grayscale or palette) to RGB first.
        if imageObject.mode != 'RGB':
            imageObject = imageObject.convert("RGB")

        image = self.transform(imageObject).unsqueeze(0)
        return image

    def inference(self, *, image: Image.Image | None = None, prompt: str | None = None) -> str:
        """
        Similar to batch inference, but for a single image object.
        The prompt argument is accepted for interface compatibility and ignored.
        """
        if image is None:
            raise ValueError("Image must be defined")

        self.model.eval()
        image = self.preprocess_image_object(image)
        image = image.to(self.device)

        with torch.no_grad():
            output = self.model(image)
            _, prediction = torch.max(output, 1)
            predicted_class = self.class_names[prediction.item()]

        return predicted_class

    def load_model(self) -> nn.Module:
        """
        Loads the model-specific architecture and its fine-tuned weights.
        """
        # hf_hub_download caches the file locally and returns its path, so the
        # weights can be loaded directly; no temp-directory copy is needed.
        path_to_weights = huggingface_hub.hf_hub_download(
            repo_id=self.repo_id,
            filename='best_model_params.pt',
            subfolder=f'models/{self.model_id}'
        )

        arch = self.get_architecture()
        if arch == 'resnet18':
            model = models.resnet18(weights='IMAGENET1K_V1')
        else:  # 'resnet50'; get_architecture raises for anything else
            model = models.resnet50(weights='IMAGENET1K_V1')

        # Swap the ImageNet classification head for one sized to our rating classes.
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, self.num_classes)

        model.load_state_dict(torch.load(path_to_weights, map_location=self.device))

        return model
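
# Example usage for ImageRaterModel (a minimal sketch: the repo id below is a
# placeholder for any repo laid out as models/<model_id>/best_model_params.pt):
#
#   rater = ImageRaterModel(repo_id="some-user/movie-raters", model_id="baseresNet18")
#   rating = rater.inference(image=Image.open("poster.png"))
#   print(rating)  # one of "PG", "PG13", "R", "X", "XXX"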


class PromptTransformerRaterModel(BaseModel):
    """
    A class representing a transformer-based model for rating prompts into the PG, PG13, R, X, and XXX categories.

    This class encapsulates a transformer-based model for rating prompts or text inputs into predefined categories.
    It provides methods for loading the model, preprocessing text inputs, and making predictions.

    Attributes:
        repo_id (str): The identifier of the Hugging Face repository containing the model.
        model_id (str): The identifier of the specific model to be loaded.
        device (torch.device): The device (CPU or GPU) on which the model is loaded and inference is performed.

    Methods:
        __init__: Initializes the transformer-based rating model.
        load_model: Downloads the pre-trained transformer model files from the Hugging Face repository.
        clean_text: Cleans input text by removing extraneous characters and spaces.
        inference: Performs inference on input text using the transformer model and returns the predicted rating.
    """

    def __init__(self, repo_id: str, model_id: str, model_directory: str | None = None,
                 device: torch.device = torch.device('cpu')):

        self.repo_id = repo_id
        self.model_id = model_id
        if model_directory is None:
            # Keep a reference to the TemporaryDirectory object itself; if only
            # .name were stored, the object would be garbage-collected and the
            # directory deleted while the model files are still needed.
            self._temp_dir = tempfile.TemporaryDirectory()
            self.model_directory = self._temp_dir.name
        else:
            self.model_directory = model_directory

        self.load_model()

        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_directory
        )

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_directory
        )

        self.device = device
        self.model.to(device)
        self.softmax = Softmax(dim=1)

    def load_model(self) -> None:
        """
        Downloads the files for the transformer model.
        We may end up neglecting this and creating custom repos on HF for the
        prompt models so we don't need to save files locally.
        """
        # Not every tokenizer ships every file (e.g. BERT uses vocab.txt while
        # RoBERTa uses vocab.json/merges.txt), so missing files are skipped.
        for file in ['config.json', 'model.safetensors', 'tokenizer_config.json',
                     'special_tokens_map.json', 'vocab.txt', 'vocab.json',
                     'merges.txt', 'tokenizer.json']:
            try:
                dl_file = huggingface_hub.hf_hub_download(
                    repo_id=self.repo_id,
                    filename=file,
                    subfolder=f'models/{self.model_id}'
                )
                shutil.copy(dl_file, os.path.join(self.model_directory, file))
            except Exception:
                continue

        return None

    @staticmethod
    def clean_text(text: str) -> str:
        """
        Cleans prompt data, removing extraneous punctuation meant to denote
        blending, LoRAs, or models, without removing names or tags.
        Extraneous spaces and line breaks are also collapsed to reduce tokens
        while preserving as much semantic structure as possible.
        """
        text = str(text)

        # Strip weighting/embedding syntax characters, then normalize whitespace and commas.
        cleaned_text = re.sub(r'[():<>[\]]', ' ', text)
        cleaned_text = cleaned_text.replace('\n', ' ')

        cleaned_text = re.sub(r'\s+', ' ', cleaned_text)
        cleaned_text = re.sub(r'\s*,\s*', ', ', cleaned_text)

        return cleaned_text.strip()
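
    # For example:
    #
    #   clean_text("masterpiece, (best quality:1.2),\n<lora:detail:0.8> portrait")
    #   -> "masterpiece, best quality 1.2, lora detail 0.8 portrait"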

    def inference(self, *, image: Image.Image | None = None, prompt: str | None = None) -> str:
        """
        Runs inference on prompt data using the transformer model.
        The image argument is accepted for interface compatibility and ignored.
        """
        if prompt is None:
            raise ValueError("Prompt must be defined")

        text = self.clean_text(prompt)
        tokens = self.tokenizer(text, max_length=512, truncation=True,
                                padding='max_length', return_tensors='pt')

        with torch.no_grad():
            # Move every input tensor (input_ids, attention_mask, ...) to the model's device.
            tokens = {key: value.to(self.device) for key, value in tokens.items()}

            outputs = self.model(**tokens)
            logits = outputs.logits
            probs = self.softmax(logits)
            _, pred = torch.max(probs, 1)

        pred = pred.item()

        return self.model.config.id2label[pred]
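
# Example usage for PromptTransformerRaterModel (a minimal sketch: the repo id
# below is a placeholder, not guaranteed to exist):
#
#   prompt_rater = PromptTransformerRaterModel(repo_id="some-user/movie-raters",
#                                              model_id="promptMovieBert")
#   print(prompt_rater.inference(prompt="a quiet seaside town at dawn"))  # e.g. "PG"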


class MovieRaterModel(BaseModel):
    """
    A class representing a movie rating model that combines multiple sub-models.

    This class combines multiple sub-models, including image-based and text-based rating models, to provide a comprehensive rating system for movies.
    It integrates the individual rating models behind a single interface and provides methods for making predictions from an input prompt and image.

    Attributes:
        repo_id (str): The identifier of the Hugging Face repository containing the sub-models.
        models (List[str]): A list of identifiers for the sub-models to be loaded.
        device (torch.device): The device (CPU or GPU) on which the sub-models are loaded and inference is performed.
        mixtureDict (Dict[str, nn.Module]): A dictionary containing the loaded sub-models.

    Methods:
        __init__: Initializes the movie rating model and loads the sub-models.
        load_model: Loads the sub-models specified in the models list and populates mixtureDict.
        inference_voting: Performs voting-based inference to determine the most common prediction among the sub-models.
        inference: Predicts a movie rating from an input prompt and image using the loaded sub-models.
    """

    def __init__(self, repo_id: str, mixtureDict: Dict[str, nn.Module] | None = None,
                 models: List[str] = ['baseresNet18', 'baseresNet50', 'bestresNet50',
                                      'promptMovieBert', 'promptMovieRoberta'],
                 device: torch.device = torch.device('cpu')):

        self.repo_id = repo_id
        self.models = models
        self.device = device
        # Avoid the mutable-default-argument pitfall: a shared {} default would
        # leak loaded sub-models between instances.
        self.mixtureDict = mixtureDict if mixtureDict is not None else {}

        self.mixtureDict = self.load_model()

    def load_model(self) -> Dict[str, nn.Module]:
        """
        Uses the rater classes above to load their models and populate mixtureDict.
        """
        for model in self.models:
            if 'resnet' in model.lower():
                self.mixtureDict[model] = ImageRaterModel(self.repo_id, model, device=self.device)
            elif 'prompt' in model.lower():
                self.mixtureDict[model] = PromptTransformerRaterModel(self.repo_id, model, device=self.device)

        return self.mixtureDict

    @staticmethod
    def inference_voting(mylist: List[int]) -> int:
        """
        Determines the most common prediction among the N (odd) models.
        In case of a tie, returns the most conservative answer, i.e. the
        strictest (highest-id) rating.
        """
        # Sort the (label, count) pairs by (count, label id): the last entry has
        # the highest count, with ties broken toward the higher (stricter) id.
        most_common_element = sorted(Counter(mylist).most_common(), key=lambda x: (x[1], x[0]))[-1][0]

        return most_common_element
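
    # For example (assuming ids 0..4 correspond to PG..XXX, matching class_names):
    #
    #   inference_voting([1, 3, 3, 4, 4])  # -> 4 (tie between 3 and 4; stricter id wins)
    #   inference_voting([0, 0, 2])        # -> 0 (clear majority)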

    @staticmethod
    def inference_worker(model, *, image: Image.Image | None = None, prompt: str | None = None) -> str:
        """
        Worker function to perform inference using a single model.
        """
        # Both rater classes share the keyword-only inference signature, so the
        # worker simply delegates; each model ignores the input it doesn't use.
        if isinstance(model, (ImageRaterModel, PromptTransformerRaterModel)):
            return model.inference(image=image, prompt=prompt)
        raise TypeError(f"Unsupported model type: {type(model).__name__}")

    def inference(self, *, image: Image.Image | None = None, prompt: str | None = None) -> str:
        """
        Runs each sub-model's own inference for the individual predictions, then
        calls inference_voting to return the most common prediction.
        """
        if image is None or prompt is None:
            raise ValueError("Image AND Prompt must be defined")

        # Run the sub-models concurrently; voting is order-independent, so the
        # completion order of the futures does not matter.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [executor.submit(self.inference_worker, model, image=image, prompt=prompt)
                       for model in self.mixtureDict.values()]
            preds = [future.result() for future in concurrent.futures.as_completed(futures)]

        # Borrow the label <-> id mappings from the first prompt model so the
        # string predictions can be voted on as integer ids.
        label2id = {}
        id2label = {}
        for name, model in self.mixtureDict.items():
            if 'prompt' in name.lower():
                label2id = model.model.config.label2id
                id2label = model.model.config.id2label
                break

        return id2label[self.inference_voting([label2id[i] for i in preds])]
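

# A minimal end-to-end sketch, guarded so it only runs when this file is
# executed directly; "some-user/movie-raters" is a placeholder repo id and
# "poster.png" a hypothetical input image:
if __name__ == "__main__":
    rater = MovieRaterModel(repo_id="some-user/movie-raters")
    rating = rater.inference(image=Image.open("poster.png"),
                             prompt="a detective chases a smuggler through Lisbon")
    print(rating)  # majority vote across the five sub-models, e.g. "PG13"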