import torch, re, shutil, tempfile, os
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from torch.nn import Softmax
import huggingface_hub
from PIL import Image
from torchvision import transforms, models
from torch import nn
from collections import Counter
from typing import List, Dict
import concurrent.futures


class BaseModel:
    def inference(self, *, image: Image.Image = None, prompt: str = None):
        pass


class ImageRaterModel(BaseModel):
    """
    A class representing an image rating model.

    This class encapsulates a deep learning model for rating images into predefined categories.
    It provides methods for loading the model, preprocessing images, and making predictions.

    Attributes:
        repo_id (str): The identifier of the Hugging Face repository containing the model.
        model_id (str): The identifier of the specific model to be loaded.
        image_transform (torchvision.transforms.Compose): A sequence of image transformations to be applied to input images.
        num_classes (int): The number of rating classes/categories.
        class_names (List[str]): A list of human-readable names corresponding to each rating class.
        device (torch.device): The device (CPU or GPU) on which the model will be loaded and inference will be performed.

    Methods:
        __init__: Initializes the image rating model.
        get_architecture: Returns the architecture name of the loaded model. Currently supports resnet18 and resnet50.
        preprocess_image_object: Preprocesses an input image for model inference.
        inference: Performs inference on a single input image and returns the predicted rating class.
        load_model: Loads the deep learning model from the Hugging Face repository.
    """

    def __init__(self, repo_id: str, model_id: str,
                 image_transform: transforms.Compose = transforms.Compose([
                     transforms.Resize((256, 256)),
                     transforms.CenterCrop((224, 224)),
                     transforms.ToTensor(),
                     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]),
                 num_classes: int = 5, class_names: List[str] = ["PG", "PG13", "R", "X", "XXX"],
                 device: torch.device = torch.device('cpu')) -> None:
        self.repo_id = repo_id
        self.model_id = model_id
        self.num_classes = num_classes
        self.transform = image_transform
        self.device = device
        self.class_names = class_names  # use the supplied names rather than a hard-coded list
        self.model = self.load_model()
        self.model.to(device)

    def get_architecture(self) -> str:
        """
        Returns the architecture of the loaded model as a string.
        """
        if 'resnet18' in self.model_id.lower():
            return 'resnet18'
        elif 'resnet50' in self.model_id.lower():
            return 'resnet50'
        else:
            raise ValueError("Unsupported architecture. Please specify 'resnet18' or 'resnet50'")

    def preprocess_image_object(self, imageObject: Image.Image) -> torch.Tensor:
        """
        Applies the same preprocessing as the validation dataset used for model training.
        NOTE: THIS IS FOR RESNET18_100EPOCHS_MAXV2
        """
        if imageObject.mode == 'RGBA':
            imageObject = imageObject.convert("RGB")
        # Add a batch dimension so the tensor shape is (1, C, H, W)
        image = self.transform(imageObject).unsqueeze(0)
        return image

    def inference(self, *, image: Image.Image = None, prompt: str = None) -> str:
        """
        Similar to batch inference, but for a single image object.
        """
        if image is None:
            raise ValueError("Image must be defined")
        self.model.eval()  # Set model to evaluation mode
        image = self.preprocess_image_object(image)
        image = image.to(self.device)
        with torch.no_grad():  # No need to compute gradients during inference
            output = self.model(image)
            _, prediction = torch.max(output, 1)
        predicted_class = self.class_names[prediction.item()]
        return predicted_class

    def load_model(self) -> nn.Module:  # Keep load_model
        """
        Loads the model-specific architecture and its fine-tuned weights.
        """
        dl_file = huggingface_hub.hf_hub_download(
            repo_id=self.repo_id,
            filename='best_model_params.pt',
            subfolder=f'models/{self.model_id}'
        )
        temp_dir = tempfile.TemporaryDirectory()
        path_to_weights = os.path.join(temp_dir.name, "best_model_params.pt")
        shutil.copy(dl_file, path_to_weights)
        if 'resnet18' in self.model_id.lower():
            model = models.resnet18(weights='IMAGENET1K_V1')
        elif 'resnet50' in self.model_id.lower():
            model = models.resnet50(weights='IMAGENET1K_V1')
        else:
            raise ValueError("Unsupported architecture. Please specify 'resnet18' or 'resnet50'")
        # Replace the final fully connected layer to match the number of rating classes
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, self.num_classes)
        model.load_state_dict(torch.load(path_to_weights, map_location=self.device))
        return model


class PromptTransformerRaterModel(BaseModel):
    """
    A class representing a transformer-based model for rating prompts into PG, PG13, R, X, and XXX categories.

    This class encapsulates a transformer-based model for rating prompts or text inputs into predefined categories.
    It provides methods for loading the model, preprocessing text inputs, and making predictions.

    Attributes:
        repo_id (str): The identifier of the Hugging Face repository containing the model.
        model_id (str): The identifier of the specific model to be loaded.
        device (torch.device): The device (CPU or GPU) on which the model will be loaded and inference will be performed.

    Methods:
        __init__: Initializes the transformer-based rating model.
        load_model: Downloads and loads the pre-trained transformer model from the Hugging Face repository.
        clean_text: Cleans input text data by removing extraneous characters and spaces.
        inference: Performs inference on input text data using the transformer model and returns the predicted rating.
    """

    def __init__(self, repo_id: str, model_id: str, model_directory: str | None = None,
                 device: torch.device = torch.device('cpu')):
        self.repo_id = repo_id
        self.model_id = model_id
        if model_directory is None:
            # Keep a reference to the TemporaryDirectory so it is not cleaned up
            # while the downloaded model files are still needed.
            self._temp_dir = tempfile.TemporaryDirectory()
            self.model_directory = self._temp_dir.name
        else:
            self.model_directory = model_directory
        self.load_model()
        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_directory)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_directory)
        self.device = device
        self.model.to(device)
        self.softmax = Softmax(dim=1)

    def load_model(self) -> None:
        """
        Downloads the files for the transformer model.
        - may end up neglecting this and creating custom repos on HF for prompt models
          so we don't need to save files locally
        """
        for file in ['config.json', 'model.safetensors', 'tokenizer_config.json', 'special_tokens_map.json',
                     'vocab.txt', 'vocab.json', 'merges.txt', 'tokenizer.json']:
            try:
                dl_file = huggingface_hub.hf_hub_download(
                    repo_id=self.repo_id,
                    filename=file,
                    subfolder=f'models/{self.model_id}'
                )
                shutil.copy(dl_file, os.path.join(self.model_directory, file))
            except Exception as e:
                # Not every tokenizer ships every file (e.g. BERT uses vocab.txt, while
                # RoBERTa uses vocab.json/merges.txt), so missing files are skipped.
                # raise LookupError(f"file error {file} raised exception {e}")
                continue
        return None

    @staticmethod
    def clean_text(text: str) -> str:
        """
        Cleans prompt data, removing extraneous punctuation meant to denote blending, loras, or models
        without removing names or tags. Extraneous spaces and line breaks are also removed to reduce
        tokens while keeping as much semantic content as possible.
        """
        text = str(text)
        # Remove additional characters: ( ) : < > [ ]
        cleaned_text = re.sub(r'[():<>[\]]', ' ', text)
        cleaned_text = cleaned_text.replace('\n', ' ')
        # Replace multiple spaces with a single space
        cleaned_text = re.sub(r'\s+', ' ', cleaned_text)
        cleaned_text = re.sub(r'\s*,\s*', ', ', cleaned_text)
        return cleaned_text.strip()

    def inference(self, *, image: Image.Image = None, prompt: str = None) -> str:
        """
        Runs inference on prompt data using the transformer model.
        """
        if prompt is None:
            raise ValueError("Prompt must be defined")
        text = self.clean_text(prompt)
        tokens = self.tokenizer(text, max_length=512, truncation=True, padding='max_length', return_tensors='pt')
        with torch.no_grad():
            # Move input tensors to the same device as the model
            for key in tokens:
                tokens[key] = tokens[key].to(self.device)
            outputs = self.model(**tokens)
            logits = outputs.logits
            probs = self.softmax(logits)
            _, pred = torch.max(probs, 1)
        pred = pred.item()
        return self.model.config.id2label[pred]


class MovieRaterModel(BaseModel):
    """
    A class representing a movie rating model that combines multiple sub-models.

    This class combines multiple sub-models, including image-based and text-based rating models,
    to provide a comprehensive rating system for movies. It allows for the integration of various
    rating models into a single interface and provides methods for making predictions based on
    input prompts and images.

    Attributes:
        repo_id (str): The identifier of the Hugging Face repository containing the sub-models.
        models (List[str]): A list of identifiers for the sub-models to be loaded.
        device (torch.device): The device (CPU or GPU) on which the sub-models will be loaded and inference will be performed.
        mixtureDict (Dict[str, nn.Module]): A dictionary containing the loaded sub-models.

    Methods:
        __init__: Initializes the movie rating model and loads the sub-models.
        load_model: Loads the sub-models specified in the models list and populates the mixtureDict.
        inference_voting: Performs voting-based inference to determine the most common prediction among the sub-models.
        inference: Makes predictions for movie ratings based on input prompts and images using the loaded sub-models.
    """

    def __init__(self, repo_id: str, mixtureDict: dict | None = None,
                 models: List[str] = ['baseresNet18', 'baseresNet50', 'bestresNet50', 'promptMovieBert', 'promptMovieRoberta'],
                 device: torch.device = torch.device('cpu')):
        self.repo_id = repo_id
        self.models = models
        self.device = device
        # Avoid a shared mutable default argument: create a fresh dict when none is supplied
        self.mixtureDict = {} if mixtureDict is None else mixtureDict
        self.mixtureDict = self.load_model()

    def load_model(self) -> Dict[str, nn.Module]:
        """
        Use the established classes to load their models and populate the mixtureDict.
        """
        for model in self.models:
            if 'resnet' in model.lower():
                self.mixtureDict[model] = ImageRaterModel(self.repo_id, model, device=self.device)
            elif 'prompt' in model.lower():
                self.mixtureDict[model] = PromptTransformerRaterModel(self.repo_id, model, device=self.device)
        return self.mixtureDict

    @staticmethod
    def inference_voting(mylist: List[int]) -> int:
        """
        Determines the most common prediction among the N (odd) models.
        In case of a tie, returns the most conservative (most restrictive) answer.
        """
        # Sort (element, count) pairs ascending by (count, label id) so the last entry is the
        # most frequent label, with ties broken toward the higher (more restrictive) label id.
        most_common_element = sorted(Counter(mylist).most_common(), key=lambda x: (x[1], x[0]))[-1][0]
        return most_common_element

    @staticmethod
    def inference_worker(model, *, image: Image.Image = None, prompt: str = None) -> str:
        """
        Worker function to perform inference using a single model.
        """
        if isinstance(model, (ImageRaterModel, PromptTransformerRaterModel)):
            return model.inference(image=image, prompt=prompt)
        raise TypeError(f"Unsupported model type: {type(model).__name__}")

    def inference(self, *, image: Image.Image = None, prompt: str = None) -> str:
        """
        Uses each sub-model's own inference for an individual prediction and then
        calls inference_voting to return the most common prediction.
        """
        if image is None or prompt is None:
            raise ValueError("Image AND Prompt must be defined")
        with concurrent.futures.ThreadPoolExecutor() as executor:
            # Submit inference tasks for all models
            futures = [executor.submit(self.inference_worker, model, image=image, prompt=prompt)
                       for model in self.mixtureDict.values()]
            # Collect results as they become available
            preds = [future.result() for future in concurrent.futures.as_completed(futures)]
        # Borrow the label <-> id mappings from the first prompt model so the string
        # predictions can be voted on as integer ids.
        label2id = {}
        id2label = {}
        for name, model in self.mixtureDict.items():
            if 'prompt' in name.lower():
                label2id = model.model.config.label2id
                id2label = model.model.config.id2label
                break
        return id2label[self.inference_voting([label2id[i] for i in preds])]
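

# Minimal usage sketch (illustrative only): the repo id, image path, and prompt below are
# placeholders, assuming a Hugging Face repo laid out as models/<model_id>/... as above.
if __name__ == "__main__":
    rater = MovieRaterModel(repo_id="your-org/movie-rater-models")
    rating = rater.inference(image=Image.open("poster.jpg"),
                             prompt="A heartwarming animated family adventure")
    print(f"Predicted rating: {rating}")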