# Safetensors
# wolfgangblack's picture
# Upload utils.py
# 4a6aa21 verified
# raw
# history blame
# No virus
# 13.8 kB
import torch, re, shutil, tempfile, os
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from torch.nn import Softmax
import huggingface_hub
from PIL import Image
from torchvision import transforms, models
from torch import nn
from collections import Counter
from typing import List, Dict
import concurrent.futures
class BaseModel:
    """
    Common interface for the rater models in this module.

    Subclasses implement :meth:`inference`, which receives an image, a prompt,
    or both as keyword-only arguments and returns a predicted label.
    """

    def inference(self, *, image: "Image.Image | None" = None, prompt: "str | None" = None):
        """
        Predict a label for the given image and/or prompt.

        Args:
            image: Input image, for models that rate images.
            prompt: Input text, for models that rate prompts.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        # Fail loudly instead of silently returning None, which would only
        # surface much later (e.g. as a failed label lookup during voting).
        raise NotImplementedError(f"{type(self).__name__} must implement inference()")
class ImageRaterModel(BaseModel):
    """
    A class representing an image rating model.

    This class wraps a torchvision ResNet whose final fully connected layer is
    replaced by a ``num_classes``-way linear head, with fine-tuned weights
    downloaded from a Hugging Face repository. It provides methods for loading
    the model, preprocessing images, and making predictions.

    Attributes:
        repo_id (str): The identifier of the Hugging Face repository containing the model.
        model_id (str): The identifier of the specific model to be loaded. It must
            contain 'resnet18' or 'resnet50' (case-insensitive) and names the
            repository subfolder ``models/{model_id}`` holding the weights.
        transform (transforms.Compose): The transformations applied to input images.
        num_classes (int): The number of rating classes/categories.
        class_names (List[str]): Human-readable names, indexed by predicted class id.
        device (torch.device): The device (CPU or GPU) on which the model is loaded
            and inference is performed.
        model (nn.Module): The loaded classifier.

    Methods:
        __init__: Initializes the image rating model and loads its weights.
        get_architecture: Returns the architecture name; supports resnet18 and resnet50.
        preprocess_image_object: Preprocesses an input image for model inference.
        inference: Performs inference on a single input image and returns the predicted rating class.
        load_model: Builds the network and loads its weights from the Hugging Face repository.
    """

    # Architecture names recognised in ``model_id``, checked in this order.
    _SUPPORTED_ARCHITECTURES = ('resnet18', 'resnet50')
    # Class names used when the caller does not supply ``class_names``.
    _DEFAULT_CLASS_NAMES = ("PG", "PG13", "R", "X", "XXX")

    def __init__(self, repo_id: str, model_id: str,
                 image_transform: "transforms.Compose | None" = None,
                 num_classes: int = 5, class_names: "List[str] | None" = None,
                 device: torch.device = torch.device('cpu')) -> None:
        """
        Args:
            repo_id: Hugging Face repository holding the weights.
            model_id: Model name; selects the architecture and the ``models/{model_id}`` subfolder.
            image_transform: Preprocessing pipeline. Defaults to resize (256, 256),
                center-crop (224, 224), ToTensor, and normalization with mean
                [0.485, 0.456, 0.406] / std [0.229, 0.224, 0.225].
            num_classes: Number of outputs of the classification head.
            class_names: Names for each class index. Defaults to
                ["PG", "PG13", "R", "X", "XXX"].
            device: Device to load the model onto and run inference on.

        Raises:
            ValueError: If ``model_id`` names an unsupported architecture.
        """
        self.repo_id = repo_id
        self.model_id = model_id
        self.num_classes = num_classes
        self.transform = image_transform if image_transform is not None else self._default_transform()
        self.device = device
        # Use the caller's names when given; copy so the caller's list is not aliased.
        self.class_names = list(class_names if class_names is not None else self._DEFAULT_CLASS_NAMES)
        self.model = self.load_model()
        self.model.to(device)

    @staticmethod
    def _default_transform() -> "transforms.Compose":
        """Build the default image preprocessing pipeline."""
        return transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.CenterCrop((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def get_architecture(self) -> str:
        """
        Return the architecture of the model named by ``model_id``.

        Returns:
            'resnet18' or 'resnet50' (the first one found in ``model_id``, case-insensitive).

        Raises:
            ValueError: If ``model_id`` contains neither name.
        """
        lowered = self.model_id.lower()
        for architecture in self._SUPPORTED_ARCHITECTURES:
            if architecture in lowered:
                return architecture
        raise ValueError("Unsupported architecture. Please specifiy 'resnet18' or 'resnet50'")

    def preprocess_image_object(self, imageObject: "Image.Image") -> torch.Tensor:
        """
        Does the same preprocessing as the validation dataset for model training.
        NOTE: THIS IS FOR RESNET18_100EPOCHS_MAXV2

        Args:
            imageObject: A PIL image in any mode.

        Returns:
            The transformed image with a leading batch dimension of size 1.
        """
        # The normalization expects exactly three channels, so bring every mode
        # (RGBA, L, P, CMYK, ...) to RGB rather than only handling RGBA.
        if imageObject.mode != 'RGB':
            imageObject = imageObject.convert("RGB")
        return self.transform(imageObject).unsqueeze(0)

    def inference(self, *, image: "Image.Image | None" = None, prompt: str | None = None) -> str:
        """
        Rate a single image.

        Args:
            image: The PIL image to classify.
            prompt: Ignored; accepted for interface compatibility with BaseModel.

        Returns:
            The entry of ``class_names`` for the highest-scoring class.

        Raises:
            ValueError: If ``image`` is None.
        """
        if image is None:
            raise ValueError("Image must be defined")
        self.model.eval()  # Set model to evaluation mode
        batch = self.preprocess_image_object(image).to(self.device)
        with torch.no_grad():  # No need to compute gradients during inference
            output = self.model(batch)
        _, prediction = torch.max(output, 1)
        return self.class_names[prediction.item()]

    def load_model(self) -> nn.Module:
        """
        Build the network and load its fine-tuned weights.

        Validates the architecture, downloads
        ``models/{model_id}/best_model_params.pt`` from ``repo_id``, builds the
        architecture with a ``num_classes``-way ``fc`` head, and loads the
        downloaded state dict (mapped onto ``device``) into it.

        Returns:
            The model with its weights loaded.

        Raises:
            ValueError: If ``model_id`` names an unsupported architecture.
        """
        # Validate before downloading anything.
        architecture = self.get_architecture()
        weights_path = huggingface_hub.hf_hub_download(
            repo_id=self.repo_id,
            filename='best_model_params.pt',
            subfolder=f'models/{self.model_id}',
        )
        builders = {'resnet18': models.resnet18, 'resnet50': models.resnet50}
        # No pretrained ImageNet weights are needed: the strict load_state_dict
        # below overwrites every parameter and buffer of the network.
        model = builders[architecture](weights=None)
        model.fc = nn.Linear(model.fc.in_features, self.num_classes)
        # NOTE(security): torch.load unpickles the file, which can execute code.
        # Only point repo_id at trusted repositories, or pass weights_only=True
        # if the installed torch version supports it.
        state_dict = torch.load(weights_path, map_location=self.device)
        model.load_state_dict(state_dict)
        return model
class PromptTransformerRaterModel(BaseModel):
    """
    A transformer-based model for rating prompts into PG, PG13, R, X, and XXX categories.

    The model, its configuration, and its tokenizer files are downloaded from the
    ``models/{model_id}`` subfolder of a Hugging Face repository and staged in a
    local directory. The model and tokenizer are then loaded from there with
    ``AutoModelForSequenceClassification`` / ``AutoTokenizer``. Predicted labels
    come from the model config's ``id2label``.

    Attributes:
        repo_id (str): The identifier of the Hugging Face repository containing the model.
        model_id (str): The identifier of the specific model to be loaded.
        model_directory (str): Directory the model files were staged in. When no
            directory was supplied, this was a temporary directory that is deleted
            once construction finishes.
        model: The loaded sequence-classification model.
        tokenizer: The tokenizer loaded alongside ``model``.
        device (torch.device): The device (CPU or GPU) on which the model is loaded
            and inference is performed.
        softmax (Softmax): Softmax over dimension 1, applied to the logits.

    Methods:
        __init__: Downloads and initializes the transformer-based rating model.
        load_model: Downloads the model files into ``model_directory``.
        clean_text: Cleans input text by removing extraneous characters and spaces.
        inference: Rates a prompt and returns the predicted label.
    """

    # Files requested from the repository. Not every tokenizer ships every one of
    # these (presumably vocab.txt vs vocab.json/merges.txt depending on the model
    # family), so files that do not exist in the repository are skipped.
    _MODEL_FILES = ('config.json', 'model.safetensors', 'tokenizer_config.json',
                    'special_tokens_map.json', 'vocab.txt', 'vocab.json',
                    'merges.txt', 'tokenizer.json')

    def __init__(self, repo_id: str, model_id: str, model_directory: str | None = None,
                 device: torch.device = torch.device('cpu')):
        """
        Args:
            repo_id: Hugging Face repository holding the model files.
            model_id: Model name; files are read from ``models/{model_id}``.
            model_directory: Where to stage the files. Created if missing. When
                None, a temporary directory is used and removed after loading.
            device: Device to load the model onto and run inference on.
        """
        self.repo_id = repo_id
        self.model_id = model_id
        self.device = device
        if model_directory is None:
            # Stage the files in a scratch directory and remove it deterministically
            # once the model and tokenizer are loaded. Cleanup errors are ignored
            # so that failing to delete a staged file does not fail construction.
            with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as staging_dir:
                self.model_directory = staging_dir
                self._load_from_directory()
        else:
            self.model_directory = model_directory
            self._load_from_directory()
        self.model.eval()
        self.model.to(device)
        self.softmax = Softmax(dim=1)

    def _load_from_directory(self) -> None:
        """Download the files into ``model_directory``, then load the model and tokenizer from it."""
        self.load_model()
        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_directory)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_directory)

    def load_model(self) -> None:
        """
        Download the files for the transformer model into ``model_directory``.

        Each file in ``_MODEL_FILES`` is fetched from ``models/{model_id}`` in
        ``repo_id`` (through the Hugging Face cache) and copied into
        ``model_directory``, which is created if it does not exist. Files that
        are not present in the repository are skipped. Any other failure
        (network, authentication, copy errors) propagates instead of being
        silently ignored.

        May later be replaced by custom HF repos for prompt models, so that
        files need not be saved locally.
        """
        from huggingface_hub.utils import EntryNotFoundError

        os.makedirs(self.model_directory, exist_ok=True)
        for file in self._MODEL_FILES:
            try:
                dl_file = huggingface_hub.hf_hub_download(
                    repo_id=self.repo_id,
                    filename=file,
                    subfolder=f'models/{self.model_id}',
                )
            except EntryNotFoundError:
                # Optional file for this tokenizer family; from_pretrained will
                # report it if something essential is actually missing.
                continue
            shutil.copy(dl_file, os.path.join(self.model_directory, file))
        return None

    @staticmethod
    def clean_text(text: str) -> str:
        """
        This method cleans prompt data, removing extraneous punctuation meant to denote blending, loras, or models without removing names or tags.
        We also get rid of extraneous spaces or line breaks to reduce tokens and maintain as much semantic logic as possible.

        Characters ( ) : < > [ ] and newlines become spaces, whitespace runs collapse
        to one space, commas are normalized to ", ", and the result is stripped.
        """
        text = str(text)
        # Remove additional characters: ( ) : < > [ ]
        cleaned_text = re.sub(r'[():<>[\]]', ' ', text)
        cleaned_text = cleaned_text.replace('\n', ' ')
        # Replace multiple spaces with a single space
        cleaned_text = re.sub(r'\s+', ' ', cleaned_text)
        cleaned_text = re.sub(r'\s*,\s*', ', ', cleaned_text)
        return cleaned_text.strip()

    def inference(self, *, image: "Image.Image | None" = None, prompt: str | None = None) -> str:
        """
        Rate a prompt with the transformer model.

        Args:
            image: Ignored; accepted for interface compatibility with BaseModel.
            prompt: Text to classify; it is passed through :meth:`clean_text` first.

        Returns:
            ``self.model.config.id2label`` of the highest-scoring class.

        Raises:
            ValueError: If ``prompt`` is None.
        """
        if prompt is None:
            raise ValueError("Prompt must be defined")
        text = self.clean_text(prompt)
        tokens = self.tokenizer(text, max_length=512, truncation=True,
                                padding='max_length', return_tensors='pt')
        tokens = {key: value.to(self.device) for key, value in tokens.items()}
        with torch.no_grad():
            logits = self.model(**tokens).logits
        probs = self.softmax(logits)
        _, pred = torch.max(probs, 1)
        return self.model.config.id2label[pred.item()]
class MovieRaterModel(BaseModel):
    """
    A class representing a movie rating model that combines multiple sub-models.

    This class combines image-based (ImageRaterModel) and text-based
    (PromptTransformerRaterModel) rating models behind a single interface. It
    returns the majority vote of their predictions for an (image, prompt) pair.

    Attributes:
        repo_id (str): The identifier of the Hugging Face repository containing the sub-models.
        models (List[str]): Identifiers of the sub-models to be loaded. Names that
            contain 'resnet' load an ImageRaterModel; names that contain 'prompt'
            load a PromptTransformerRaterModel (both case-insensitive).
        device (torch.device): The device (CPU or GPU) on which the sub-models are
            loaded and inference is performed.
        mixtureDict (Dict[str, BaseModel]): The loaded sub-models, keyed by name.

    Methods:
        __init__: Initializes the movie rating model and loads the sub-models.
        load_model: Loads the sub-models listed in ``models`` into ``mixtureDict``.
        inference_voting: Returns the most common label id among the sub-model predictions.
        inference_worker: Runs a single sub-model's inference.
        inference: Rates an (image, prompt) pair using all loaded sub-models.
    """

    # Sub-models loaded when the caller does not pass ``models``.
    _DEFAULT_MODELS = ('baseresNet18', 'baseresNet50', 'bestresNet50',
                       'promptMovieBert', 'promptMovieRoberta')

    def __init__(self, repo_id: str, mixtureDict: dict | None = None,
                 models: List[str] | None = None,
                 device: torch.device = torch.device('cpu')):
        """
        Args:
            repo_id: Hugging Face repository holding the sub-models.
            mixtureDict: Optional pre-loaded sub-models keyed by name. It is copied,
                not mutated. Entries named in ``models`` are (re)loaded over it.
            models: Names of the sub-models to load. Defaults to
                ['baseresNet18', 'baseresNet50', 'bestresNet50', 'promptMovieBert', 'promptMovieRoberta'].
            device: Device to load the sub-models onto.

        Raises:
            ValueError: If a name in ``models`` contains neither 'resnet' nor 'prompt'.
        """
        self.repo_id = repo_id
        self.models = list(self._DEFAULT_MODELS) if models is None else list(models)
        self.device = device
        # Work on a private dict so that load_model() mutates neither a caller's
        # dict nor a default shared between instances.
        self.mixtureDict = {} if mixtureDict is None else dict(mixtureDict)
        self.mixtureDict = self.load_model()

    def load_model(self) -> Dict[str, BaseModel]:
        """
        Use the established classes to load each sub-model into ``mixtureDict``.

        Returns:
            ``self.mixtureDict`` after loading.

        Raises:
            ValueError: If a name contains neither 'resnet' nor 'prompt' (case-insensitive).
        """
        for name in self.models:
            lowered = name.lower()
            if 'resnet' in lowered:
                self.mixtureDict[name] = ImageRaterModel(self.repo_id, name, device=self.device)
            elif 'prompt' in lowered:
                self.mixtureDict[name] = PromptTransformerRaterModel(self.repo_id, name, device=self.device)
            else:
                raise ValueError(f"Unsupported model '{name}': expected 'resnet' or 'prompt' in its name")
        return self.mixtureDict

    @staticmethod
    def inference_voting(mylist: List[int]) -> int:
        """
        Return the most common label id among the sub-model predictions.

        Ties are broken in favour of the largest label id. This presumably picks
        the most conservative (most restrictive) rating, provided ids are ordered
        from least to most restrictive. Confirm against the model's label2id.

        Raises:
            ValueError: If ``mylist`` is empty.
        """
        if not mylist:
            raise ValueError("inference_voting requires at least one prediction")
        counts = Counter(mylist)
        return max(counts.items(), key=lambda item: (item[1], item[0]))[0]

    @staticmethod
    def inference_worker(model, *, image: "Image.Image | None" = None, prompt: str | None = None) -> str:
        """
        Run inference for a single sub-model.

        Every sub-model shares the BaseModel keyword interface, so both inputs are
        forwarded and each model uses the ones it needs.
        """
        return model.inference(image=image, prompt=prompt)

    def _label_maps(self) -> tuple[dict, dict]:
        """
        Return ``(label2id, id2label)`` used to encode predictions for voting.

        The maps come from the config of the first sub-model whose name contains
        'prompt'. If there is none, they are derived from the ``class_names`` of
        the first ImageRaterModel.
        NOTE(review): assumes every sub-model emits labels present in this
        mapping — confirm the prompt configs use the same label names as
        ImageRaterModel.class_names.

        Raises:
            ValueError: If no sub-model can supply a label mapping.
        """
        for name, model in self.mixtureDict.items():
            if 'prompt' in name.lower():
                return model.model.config.label2id, model.model.config.id2label
        for model in self.mixtureDict.values():
            if isinstance(model, ImageRaterModel):
                id2label = dict(enumerate(model.class_names))
                label2id = {label: idx for idx, label in id2label.items()}
                return label2id, id2label
        raise ValueError("No sub-model provides a label mapping")

    def inference(self, *, image: "Image.Image | None" = None, prompt: str | None = None) -> str:
        """
        Use class-specific inference for individual predictions, then return the
        majority vote computed by :meth:`inference_voting`.

        Raises:
            ValueError: If ``image`` or ``prompt`` is None, or no label mapping is available.
        """
        if image is None or prompt is None:
            raise ValueError("Image AND Prompt must be defined")
        with concurrent.futures.ThreadPoolExecutor() as executor:
            # Submit inference tasks for all models; the vote is order-independent.
            futures = [executor.submit(self.inference_worker, model, image=image, prompt=prompt)
                       for model in self.mixtureDict.values()]
            preds = [future.result() for future in futures]
        label2id, id2label = self._label_maps()
        return id2label[self.inference_voting([label2id[label] for label in preds])]