Safetensors
wolfgangblack committed on
Commit
4a6aa21
1 Parent(s): ae230f8

Upload utils.py

Browse files
Files changed (1) hide show
  1. models/utils.py +327 -0
models/utils.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, re, shutil, tempfile, os
2
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
+ from torch.nn import Softmax
4
+ import huggingface_hub
5
+ from PIL import Image
6
+ from torchvision import transforms, models
7
+ from torch import nn
8
+ from collections import Counter
9
+ from typing import List, Dict
10
+ import concurrent.futures
11
+
12
class BaseModel:
    """
    Common interface for the rating models defined in this file.

    Subclasses override `inference`; the base implementation is a no-op
    placeholder that returns None.
    """

    def inference(self, *, image: "Image.Image | None" = None,
                  prompt: "str | None" = None) -> None:
        """
        Placeholder inference entry point.

        Args:
            image: Optional PIL image (keyword-only).
            prompt: Optional text prompt (keyword-only).

        Returns:
            None. The base class performs no work.
        """
        return None
15
+
16
class ImageRaterModel(BaseModel):
    """
    A class representing an image rating model.

    Wraps a torchvision ResNet whose final fully-connected layer is replaced
    by a `num_classes`-way linear layer, with weights downloaded from a
    Hugging Face repository. It provides methods for loading the model,
    preprocessing images, and making predictions.

    Attributes:
        repo_id (str): The identifier of the Hugging Face repository containing the model.
        model_id (str): The identifier of the specific model to be loaded. Must contain
            'resnet18' or 'resnet50' (case-insensitive).
        transform (torchvision.transforms.Compose): Transformations applied to input images.
        num_classes (int): The number of rating classes/categories.
        class_names (List[str]): Human-readable names, indexed by predicted class id.
        device (torch.device): The device on which the model is loaded and inference is performed.
        model (nn.Module): The loaded network.

    Methods:
        __init__: Initializes the image rating model.
        get_architecture: Returns the architecture name of the loaded model. Currently supports resnet18 and resnet50
        preprocess_image_object: Preprocesses an input image for model inference.
        inference: Performs inference on a single input image and returns the predicted rating class.
        load_model: Loads the deep learning model from the Hugging Face repository.
    """

    # Labels used when the caller does not supply `class_names`.
    DEFAULT_CLASS_NAMES = ("PG", "PG13", "R", "X", "XXX")

    def __init__(self, repo_id: str, model_id: str,
                 image_transform: transforms.Compose | None = None,
                 num_classes: int = 5,
                 class_names: List[str] | None = None,
                 device: torch.device = torch.device('cpu')) -> None:
        """
        Download the weights and build the model on `device`.

        Args:
            repo_id: Hugging Face repository holding the weights.
            model_id: Sub-folder name under `models/` in the repository.
            image_transform: Preprocessing pipeline; when None, the default
                resize/center-crop/normalize pipeline is used.
            num_classes: Size of the classification head.
            class_names: Label for each class id; when None,
                `DEFAULT_CLASS_NAMES` is used.
            device: Device the model is moved to.

        Raises:
            ValueError: If `model_id` names an unsupported architecture.
        """
        self.repo_id = repo_id
        self.model_id = model_id
        self.num_classes = num_classes
        self.transform = (image_transform if image_transform is not None
                          else self._default_transform())
        self.device = device
        # Previously the argument was ignored and a hard-coded list was used.
        self.class_names = (list(class_names) if class_names is not None
                            else list(self.DEFAULT_CLASS_NAMES))
        self.model = self.load_model()
        self.model.to(device)

    @staticmethod
    def _default_transform() -> transforms.Compose:
        """Build the default preprocessing pipeline (resize, crop, tensor, normalize)."""
        return transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.CenterCrop((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def get_architecture(self) -> str:
        """
        Return the architecture of the loaded model as a string.

        Raises:
            ValueError: If `model_id` contains neither 'resnet18' nor 'resnet50'.
        """
        lowered_id = self.model_id.lower()
        if 'resnet18' in lowered_id:
            return 'resnet18'
        if 'resnet50' in lowered_id:
            return 'resnet50'
        raise ValueError("Unsupported architecture. Please specifiy 'resnet18' or 'resnet50'")

    def preprocess_image_object(self, imageObject: Image.Image) -> torch.Tensor:
        """
        Does the same preprocessing as the validation dataset for model training
        NOTE: THIS IS FOR RESNET18_100EPOCHS_MAXV2

        Returns:
            The transformed image with a leading batch dimension added.
        """
        # The default pipeline normalizes three channels, so any non-RGB mode
        # (RGBA, L, P, LA, CMYK, ...) is converted, not only RGBA.
        if imageObject.mode != 'RGB':
            imageObject = imageObject.convert("RGB")

        return self.transform(imageObject).unsqueeze(0)

    def inference(self, *, image: Image.Image | None = None, prompt: str | None = None) -> str:
        """
        Classify a single image object.

        Args:
            image: The PIL image to rate (required).
            prompt: Unused; accepted so all models share one call signature.

        Returns:
            The entry of `class_names` at the highest-scoring class index.

        Raises:
            ValueError: If `image` is None.
        """
        if image is None:
            raise ValueError("Image must be defined")

        self.model.eval()  # Set model to evaluation mode
        batch = self.preprocess_image_object(image).to(self.device)

        with torch.no_grad():  # No need to compute gradients during inference
            output = self.model(batch)
            _, prediction = torch.max(output, 1)

        return self.class_names[prediction.item()]

    def load_model(self) -> nn.Module:
        """
        Build the ResNet for `model_id` and load the fine-tuned weights.

        Returns:
            The network with its final layer resized to `num_classes` and the
            downloaded state dict loaded.

        Raises:
            ValueError: If `model_id` names an unsupported architecture.
        """
        # Validate the architecture before any network access.
        architecture = self.get_architecture()

        weights_path = huggingface_hub.hf_hub_download(
            repo_id=self.repo_id,
            filename='best_model_params.pt',
            subfolder=f'models/{self.model_id}',
        )

        # weights=None: the strict load_state_dict below replaces every
        # parameter, so fetching ImageNet weights first would be wasted work.
        builder = models.resnet18 if architecture == 'resnet18' else models.resnet50
        model = builder(weights=None)
        model.fc = nn.Linear(model.fc.in_features, self.num_classes)

        # NOTE(review): torch.load unpickles the downloaded file; only use
        # trusted repositories, or pass weights_only=True if the installed
        # torch version supports it.
        model.load_state_dict(torch.load(weights_path, map_location=self.device))

        return model
127
+
128
class PromptTransformerRaterModel(BaseModel):
    """
    A class representing a transformer-based model for rating prompts into PG, PG13, R, X, and XXX categories

    This class encapsulates a transformer-based sequence-classification model for rating prompts or text inputs.
    It provides methods for downloading the model files, cleaning text inputs, and making predictions.

    Attributes:
        repo_id (str): The identifier of the Hugging Face repository containing the model.
        model_id (str): The identifier of the specific model to be loaded.
        model_directory (str): Directory the model files were copied into. When no directory
            was supplied this was a temporary directory that is removed once loading finishes.
        model: The loaded sequence-classification model.
        tokenizer: The tokenizer loaded from the same directory.
        device (torch.device): The device on which the model is loaded and inference is performed.
        softmax (Softmax): Softmax over the class dimension of the logits.

    Methods:
        __init__: Initializes the transformer-based rating model.
        load_model: Downloads the model files from the Hugging Face repository.
        clean_text: Cleans input text data by removing extraneous characters and spaces.
        inference: Performs inference on input text data using the transformer model and returns the predicted rating.
    """

    # Files without which the model cannot be loaded by this class.
    REQUIRED_FILES = ('config.json', 'model.safetensors')
    # Tokenizer files; which ones exist depends on the tokenizer type.
    OPTIONAL_FILES = ('tokenizer_config.json', 'special_tokens_map.json', 'vocab.txt',
                      'vocab.json', 'merges.txt', 'tokenizer.json')

    def __init__(self, repo_id: str, model_id: str, model_directory: str | None = None,
                 device: torch.device = torch.device('cpu')) -> None:
        """
        Download the model files and load the model and tokenizer.

        Args:
            repo_id: Hugging Face repository holding the model files.
            model_id: Sub-folder name under `models/` in the repository.
            model_directory: Where to copy the files. When None, a temporary
                directory is used and deleted after loading.
            device: Device the model is moved to.

        Raises:
            LookupError: If a required model file cannot be downloaded.
        """
        self.repo_id = repo_id
        self.model_id = model_id
        self.device = device
        self.softmax = Softmax(dim=1)

        if model_directory is None:
            # Explicit scope instead of relying on garbage collection of a
            # local TemporaryDirectory to remove the files.
            with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as scratch_dir:
                self.model_directory = scratch_dir
                self._load_from_directory()
        else:
            # A missing directory used to make every copy fail silently.
            os.makedirs(model_directory, exist_ok=True)
            self.model_directory = model_directory
            self._load_from_directory()

        self.model.to(device)

    def _load_from_directory(self) -> None:
        """Download the files into `model_directory` and load model and tokenizer from it."""
        self.load_model()
        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_directory)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_directory)

    def load_model(self) -> None:
        """
        Downloads the files for the transformer model into `model_directory`
        - may end up neglecting this and creating custom
        repos on HF for prompt models so we don't need to save
        files locally

        Raises:
            LookupError: If one of `REQUIRED_FILES` cannot be fetched or copied.
        """
        for filename in self.REQUIRED_FILES + self.OPTIONAL_FILES:
            try:
                downloaded = huggingface_hub.hf_hub_download(
                    repo_id=self.repo_id,
                    filename=filename,
                    subfolder=f'models/{self.model_id}',
                )
                shutil.copy(downloaded, os.path.join(self.model_directory, filename))
            except Exception as e:
                if filename in self.REQUIRED_FILES:
                    # Fail here with the real cause instead of later inside
                    # from_pretrained with a less specific error.
                    raise LookupError(f"file error {filename} raised exception {e}") from e
                # Optional tokenizer files are best-effort: not every
                # tokenizer ships every file in the list.
                continue

        return None

    @staticmethod
    def clean_text(text: str) -> str:
        """
        This method cleans prompt data, removing extraneous punctuation meant to denote blending, loras, or models without removing names or tags.
        We also get rid of extraneous spaces or line breaks to reduce tokens and maintain as much semantic logic as possible

        Returns:
            The cleaned text with surrounding whitespace stripped.
        """
        text = str(text)
        # Remove additional characters: ( ) : < > [ ]
        cleaned_text = re.sub(r'[():<>[\]]', ' ', text)
        cleaned_text = cleaned_text.replace('\n', ' ')
        # Replace multiple spaces with a single space
        cleaned_text = re.sub(r'\s+', ' ', cleaned_text)
        # Normalize spacing around commas to ", "
        cleaned_text = re.sub(r'\s*,\s*', ', ', cleaned_text)

        return cleaned_text.strip()

    def inference(self, *, image: Image.Image | None = None, prompt: str | None = None) -> str:
        """
        Does inference on prompt data using the transformer model

        Args:
            image: Unused; accepted so all models share one call signature.
            prompt: The text to rate (required).

        Returns:
            The label from the model config's `id2label` for the top class.

        Raises:
            ValueError: If `prompt` is None.
        """
        if prompt is None:
            raise ValueError("Prompt must be defined")

        text = self.clean_text(prompt)
        tokens = self.tokenizer(text, max_length=512, truncation=True,
                                padding='max_length', return_tensors='pt')

        with torch.no_grad():
            for key in tokens:
                tokens[key] = tokens[key].to(self.device)

            logits = self.model(**tokens).logits
            probs = self.softmax(logits)
            _, pred = torch.max(probs, 1)

        return self.model.config.id2label[pred.item()]
233
+
234
class MovieRaterModel(BaseModel):
    """
    A class representing a movie rating model that combines multiple sub-models.

    This class combines image-based and text-based rating models and returns the rating chosen by a vote across them.

    Attributes:
        repo_id (str): The identifier of the Hugging Face repository containing the sub-models.
        models (List[str]): A list of identifiers for the sub-models to be loaded.
        device (torch.device): The device on which the sub-models are loaded and inference is performed.
        mixtureDict (Dict[str, BaseModel]): A dictionary mapping model identifier to loaded sub-model.

    Methods:
        __init__: Initializes the movie rating model and loads the sub-models.
        load_model: Loads the sub-models specified in the models list and populates the mixtureDict.
        inference_voting: Performs voting-based inference to determine the most common prediction among the sub-models.
        inference_worker: Runs a single sub-model.
        inference: Makes predictions for movie ratings based on input prompts and images using the loaded sub-models.
    """

    # Sub-models loaded when the caller does not supply `models`.
    DEFAULT_MODELS = ('baseresNet18', 'baseresNet50', 'bestresNet50',
                      'promptMovieBert', 'promptMovieRoberta')

    def __init__(self, repo_id: str, mixtureDict: dict | None = None,
                 models: List[str] | None = None,
                 device: torch.device = torch.device('cpu')) -> None:
        """
        Args:
            repo_id: Hugging Face repository holding the sub-models.
            mixtureDict: Optional dictionary to populate; a fresh one is
                created per instance when None.
            models: Identifiers of the sub-models to load; `DEFAULT_MODELS`
                when None.
            device: Device the sub-models are loaded on.
        """
        self.repo_id = repo_id
        self.models = list(models) if models is not None else list(self.DEFAULT_MODELS)
        self.device = device
        # A `{}` default would be shared (and mutated) across every instance.
        self.mixtureDict = {} if mixtureDict is None else mixtureDict

        self.mixtureDict = self.load_model()

    def load_model(self) -> Dict[str, BaseModel]:
        """
        Use established classes to load their models and populate the mixtureDict

        Identifiers containing 'resnet' load an ImageRaterModel, identifiers
        containing 'prompt' load a PromptTransformerRaterModel
        (case-insensitive); any other identifier is skipped.
        """
        for model_name in self.models:
            lowered = model_name.lower()
            if 'resnet' in lowered:
                self.mixtureDict[model_name] = ImageRaterModel(
                    self.repo_id, model_name, device=self.device)
            elif 'prompt' in lowered:
                self.mixtureDict[model_name] = PromptTransformerRaterModel(
                    self.repo_id, model_name, device=self.device)

        return self.mixtureDict

    @staticmethod
    def inference_voting(mylist: List[int]) -> int:
        """
        A function used to determine the most common pred among the N-odd models.

        Ties are broken in favor of the largest id, which the original author
        describes as the most conservative answer (assumes ids grow with
        severity; confirm against the label mapping).

        Raises:
            ValueError: If `mylist` is empty.
        """
        if not mylist:
            raise ValueError("Cannot vote on an empty list of predictions")

        counts = Counter(mylist)
        # Highest count wins; among equal counts the larger id wins.
        return max(counts, key=lambda label_id: (counts[label_id], label_id))

    @staticmethod
    def inference_worker(model, *, image: Image.Image | None = None,
                         prompt: str | None = None) -> str:
        """
        Worker function to perform inference using a single model.

        Every sub-model receives both inputs and uses the one it needs.
        """
        return model.inference(image=image, prompt=prompt)

    def _label_maps(self):
        """
        Return `(label2id, id2label)` used to vote on label ids.

        The mapping comes from the first sub-model whose name contains
        'prompt'. If none exists, it is derived from the `class_names` of the
        first sub-model that has that attribute.

        Raises:
            ValueError: If no sub-model can supply a label mapping.
        """
        for name, model in self.mixtureDict.items():
            if 'prompt' in name.lower():
                config = model.model.config
                return config.label2id, config.id2label

        for model in self.mixtureDict.values():
            class_names = getattr(model, 'class_names', None)
            if class_names:
                id2label = dict(enumerate(class_names))
                label2id = {label: idx for idx, label in id2label.items()}
                return label2id, id2label

        raise ValueError("No sub-model provides a label mapping")

    def inference(self, *, image: Image.Image | None = None, prompt: str | None = None) -> str:
        """
        Uses class specific inference for individual preds and then
        calls inference_voting to return the most common pred

        Args:
            image: The PIL image to rate (required).
            prompt: The text prompt to rate (required).

        Returns:
            The label selected by the vote.

        Raises:
            ValueError: If `image` or `prompt` is None, or if no label
                mapping or prediction is available.
        """
        if image is None or prompt is None:
            raise ValueError("Image AND Prompt must be defined")

        # Resolve the mapping first so a misconfigured ensemble fails before
        # any inference work is done.
        label2id, id2label = self._label_maps()

        with concurrent.futures.ThreadPoolExecutor() as executor:
            # Submit inference tasks for all models
            futures = [
                executor.submit(self.inference_worker, model, image=image, prompt=prompt)
                for model in self.mixtureDict.values()
            ]
            # Voting ignores order, so results are collected in submission order.
            preds = [future.result() for future in futures]

        return id2label[self.inference_voting([label2id[label] for label in preds])]
327
+