import torch, re, shutil, tempfile, os
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from torch.nn import Softmax
import huggingface_hub  
from PIL import Image
from torchvision import transforms, models
from torch import nn
from collections import Counter
from typing import List, Dict
import concurrent.futures

class BaseModel:
    """Minimal interface shared by all rater models."""
    def inference(self, *, image: Image = None, prompt: str = None):
        pass

class ImageRaterModel(BaseModel):
    """
    A class representing an image rating model.

    This class encapsulates a deep learning model for rating images into predefined categories. 
    It provides methods for loading the model, preprocessing images, and making predictions.

    Attributes:
        repo_id (str): The identifier of the Hugging Face repository containing the model.
        model_id (str): The identifier of the specific model to be loaded.
        image_transform (torchvision.transforms.Compose): A sequence of image transformations to be applied to input images.
        num_classes (int): The number of rating classes/categories.
        class_names (List[str]): A list of human-readable names corresponding to each rating class.
        device (torch.device): The device (CPU or GPU) on which the model will be loaded and inference will be performed.

    Methods:
        __init__: Initializes the image rating model.
        get_architecture: Returns the architecture name of the loaded model. Currently supports resnet18 and resnet50
        preprocess_image_object: Preprocesses an input image for model inference.
        inference: Performs inference on a single input image and returns the predicted rating class.
        load_model: Loads the deep learning model from the Hugging Face repository.
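
    Example (illustrative; the repo id below is a placeholder):
        >>> from PIL import Image
        >>> rater = ImageRaterModel(repo_id="<your-repo>", model_id="baseresNet18")  # doctest: +SKIP
        >>> rater.inference(image=Image.open("poster.png"))  # doctest: +SKIP
        'PG13'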
    """
    def __init__(self, repo_id: str, model_id: str, image_transform: transforms.Compose = 
                    transforms.Compose([transforms.Resize((256, 256)),  
                    transforms.CenterCrop((224, 224)), 
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]),
                num_classes: int = 5, class_names: List[str] = ["PG", "PG13", "R", "X", "XXX"],
                device: torch.device = torch.device('cpu')) -> None:
        
        self.repo_id = repo_id
        self.model_id = model_id
        self.num_classes = num_classes
        self.transform = image_transform
        self.device = device
        self.model = self.load_model()
        self.model.to(device)
        # Use the class names passed in rather than re-hardcoding them here.
        self.class_names = class_names
        
    def get_architecture(self) -> str:
        """
        Returns the architecture of the loaded model as a string
        """
        if 'resnet18' in self.model_id.lower():
            return 'resnet18'
        elif 'resnet50' in self.model_id.lower():
            return 'resnet50'
        else:
            raise ValueError("Unsupported architecture. Please specifiy 'resnet18' or 'resnet50'")

    def preprocess_image_object(self, imageObject: Image) -> torch.Tensor:
        """
        Does the same preprocessing as the validation dataset for model training
        NOTE: THIS IS FOR RESNET18_100EPOCHS_MAXV2
        """
        if imageObject.mode == 'RGBA':
            imageObject = imageObject.convert("RGB")

        image = self.transform(imageObject).unsqueeze(0)
        return image

    def inference(self, *, image: Image = None, prompt: str = None) -> str:
        """
        Performs inference on a single image object and returns the predicted rating class
        """

        if image is None:
            raise ValueError("Image must be defined")
        

        self.model.eval()  # Set model to evaluation mode
        image = self.preprocess_image_object(image)
        image = image.to(self.device)

        with torch.no_grad():  # No need to compute gradients during inference
            output = self.model(image)
            _, prediction = torch.max(output, 1)
            predicted_class = self.class_names[prediction.item()]

        return predicted_class
    
    def load_model(self) -> nn.Module: ##Keep load model
        """
        Loads model specific architecture
        """
        dl_file = huggingface_hub.hf_hub_download(
            repo_id = self.repo_id,
            filename = 'best_model_params.pt',
            subfolder = f'models/{self.model_id}'
        )

        tempDir = tempfile.TemporaryDirectory()
        temp_dir_path = tempDir.name

        path_to_weights = os.path.join(temp_dir_path, "best_model_params.pt")
        shutil.copy(dl_file, path_to_weights)
        

        if 'resnet18' in self.model_id.lower():
            model = models.resnet18(weights = 'IMAGENET1K_V1')
        elif 'resnet50' in self.model_id.lower():
            model = models.resnet50(weights = 'IMAGENET1K_V1')
        else:
            raise ValueError("Unsupported architecture. Please specifiy 'resnet18' or 'resnet50'")

        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, self.num_classes)

        model.load_state_dict(torch.load(path_to_weights, map_location = self.device))

        return model 

class PromptTransformerRaterModel(BaseModel):
    """
    A class representing a transformer-based model for rating prompts into PG, PG13, R, X, and XXX categories

    This class encapsulates a transformer-based model for rating prompts or text inputs into predefined categories. 
    It provides methods for loading the model, preprocessing text inputs, and making predictions.

    Attributes:
        repo_id (str): The identifier of the Hugging Face repository containing the model.
        model_id (str): The identifier of the specific model to be loaded.
        device (torch.device): The device (CPU or GPU) on which the model will be loaded and inference will be performed.

    Methods:
        __init__: Initializes the transformer-based rating model.
        load_model: Downloads and loads the pre-trained transformer model from the Hugging Face repository.
        clean_text: Cleans input text data by removing extraneous characters and spaces.
        inference: Performs inference on input text data using the transformer model and returns the predicted rating.
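
    Example (illustrative; the repo id below is a placeholder):
        >>> rater = PromptTransformerRaterModel(repo_id="<your-repo>", model_id="promptMovieBert")  # doctest: +SKIP
        >>> rater.inference(prompt="a quiet village at sunrise, watercolor style")  # doctest: +SKIP
        'PG'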
    """
    def __init__(self, repo_id: str, model_id: str, model_directory: str|None = None, 
                device: torch.device = torch.device('cpu')):
        
        self.repo_id = repo_id
        self.model_id = model_id
        if model_directory is None:
            # Keep a reference to the TemporaryDirectory so it is not cleaned
            # up while the model and tokenizer are being loaded from it.
            self._temp_dir = tempfile.TemporaryDirectory()
            self.model_directory = self._temp_dir.name
        else:
            self.model_directory = model_directory

        self.load_model()

        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_directory
            )
        
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_directory
            )

        self.device = device
        self.model.to(device)
        self.softmax = Softmax(dim=1)

    def load_model(self) -> None:
        """
        Downloads the files for the transformer model into the model directory
            - we may eventually drop this step and create dedicated HF repos
            for the prompt models so files don't need to be saved locally
        """

        for file in ['config.json', 'model.safetensors', 'tokenizer_config.json','special_tokens_map.json', 'vocab.txt', 'vocab.json', 'merges.txt', 'tokenizer.json',]:
            try:
                dl_file = huggingface_hub.hf_hub_download(
                    repo_id = self.repo_id,
                    filename = file,
                    subfolder = f'models/{self.model_id}'
                )

                shutil.copy(dl_file, os.path.join(self.model_directory,file))
            except Exception as e:
                # Not every tokenizer format ships every file listed above, so
                # missing files are skipped rather than treated as errors.
                # raise LookupError(f"file error {file} raised exception {e}")
                continue

        return None

    @staticmethod
    def clean_text(text: str) -> str:
        """
        This method cleans prompt data, removing extraneous punctuation meant to denote blending, loras, or models without removing names or tags. 
        We also get rid of extraneous spaces or line breaks to reduce tokens and maintain as much semantic logic as possible
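
        Example (what the cleaning produces for a typical generation prompt):
            >>> PromptTransformerRaterModel.clean_text("a knight , (cinematic:1.3) <lora:style:0.7>  riding at dawn")
            'a knight, cinematic 1.3 lora style 0.7 riding at dawn'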
        """
        text = str(text)
        # Remove additional characters: ( ) : < > [ ]
        cleaned_text = re.sub(r'[():<>[\]]', ' ', text)
        cleaned_text = cleaned_text.replace('\n', ' ')
        # Replace multiple spaces with a single space
        cleaned_text = re.sub(r'\s+', ' ', cleaned_text)
        cleaned_text = re.sub(r'\s*,\s*', ', ', cleaned_text)

        return cleaned_text.strip()
    
    def inference(self, *, image: Image = None, prompt: str = None) -> str:
        """
        Does inference on prompt data using the transformer model
        """
        if prompt is None:
            raise ValueError("Prompt must be defined")
        
        text = self.clean_text(prompt)
        tokens = self.tokenizer(text, max_length = 512, truncation = True, padding = 'max_length', return_tensors = 'pt')

        with torch.no_grad():

            for key in tokens:
                tokens[key]  = tokens[key].to(self.device)

            outputs = self.model(**tokens)
            logits = outputs.logits
            probs = self.softmax(logits)
            _, pred = torch.max(probs,1)
        
        pred = pred.item()
        
        return self.model.config.id2label[pred]

class MovieRaterModel(BaseModel):
    """
    A class representing a movie rating model that combines multiple sub-models.

    This class combines multiple sub-models, including image-based and text-based rating models, to provide a comprehensive rating system for movies. 
    It allows for the integration of various rating models into a single interface and provides methods for making predictions based on input prompts and images.

    Attributes:
        repo_id (str): The identifier of the Hugging Face repository containing the sub-models.
        models (List[str]): A list of identifiers for the sub-models to be loaded.
        device (torch.device): The device (CPU or GPU) on which the sub-models will be loaded and inference will be performed.
        mixtureDict (Dict[str, nn.Module]): A dictionary containing the loaded sub-models.

    Methods:
        __init__: Initializes the movie rating model and loads the sub-models.
        load_model: Loads the sub-models specified in the models list and populates the mixtureDict.
        inference_voting: Performs voting-based inference to determine the most common prediction among the sub-models.
        inference: Makes predictions for movie ratings based on input prompts and images using the loaded sub-models.
    """
    def __init__(self, repo_id: str, mixtureDict: Dict[str, nn.Module] | None = None, 
                models: List[str] = ['baseresNet18', 'baseresNet50', 'bestresNet50', 'promptMovieBert','promptMovieRoberta'],
                device: torch.device = torch.device('cpu')):
        
        self.repo_id = repo_id
        self.models = models
        self.device = device
        # Avoid a shared mutable default: each instance gets its own dict unless one is supplied.
        self.mixtureDict = mixtureDict if mixtureDict is not None else {}

        self.mixtureDict = self.load_model()

    def load_model(self) -> Dict[str,nn.Module]:
        """
        Use established classes to load their models and populate the mixtureDict
        """
        
        for model in self.models:
            if 'resnet' in model.lower():
                self.mixtureDict[model] = ImageRaterModel(self.repo_id, model, device = self.device)
            elif 'prompt' in model.lower():
                self.mixtureDict[model] = PromptTransformerRaterModel(self.repo_id, model, device = self.device)

        return self.mixtureDict

    @staticmethod
    def inference_voting(mylist: List[int]) -> int:
        """
        A function used to determine the most common pred among the N-odd models
        in cases of tie, returns the most conservative answer
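
        Example (assuming label ids increase with the rating's restrictiveness,
        so the 2-2 tie between ids 1 and 3 resolves to the stricter id 3):
            >>> MovieRaterModel.inference_voting([1, 1, 3, 3, 2])
            3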
        """
        counter = Counter(mylist)
        # Sort ascending by (count, label id); the last element is the most frequent
        # prediction, with ties broken toward the higher (stricter) id.
        most_common_element = sorted(counter.most_common(), key = lambda x: (x[1], x[0]))[-1][0]

        return most_common_element

    @staticmethod
    def inference_worker(model, *,image: Image = None, prompt: str = None) -> str:
        """
        Worker function to perform inference using a single model
        """
        # Both rater types expose the same keyword-only inference signature,
        # so they can be dispatched identically.
        if isinstance(model, (ImageRaterModel, PromptTransformerRaterModel)):
            return model.inference(image = image, prompt = prompt)

    def inference(self, *,image: Image = None, prompt: str = None) -> str:
        """
        Uses class specific inference for individual preds and then
        calls inference_voting to return the most common pred 
        """

        if image is None or prompt is None:
            raise ValueError("Image AND Prompt must be defined")
        
        with concurrent.futures.ThreadPoolExecutor() as executor:
            # Submit inference tasks for all models
            futures = [executor.submit(self.inference_worker, model, image = image, prompt = prompt) for model in self.mixtureDict.values()]
            
            # Get results as they become available
            results = [future.result() for future in concurrent.futures.as_completed(futures)]

        preds = results
        
        label2id = {}
        id2label = {}

        for name, model in self.mixtureDict.items():
            if 'prompt' in name.lower() and label2id == {}:
                label2id = model.model.config.label2id
                id2label = model.model.config.id2label 
                break

        return id2label[self.inference_voting([label2id[i] for i in preds])]
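

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative; not run on import). The repo id below is
# a placeholder and assumes a Hugging Face repo laid out as models/<model_id>/
# with the expected weight and tokenizer files; point it at a real repo before
# running.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    example_repo = "<your-hf-username>/<movie-rater-repo>"  # placeholder

    # Builds the full ensemble: three ResNet image raters and two prompt raters.
    rater = MovieRaterModel(repo_id=example_repo)

    poster = Image.open("example_poster.png")  # any local image file
    prompt = "two detectives chase a suspect through a rain-soaked city at night"

    # Each sub-model votes; ties resolve toward the stricter rating.
    print(rater.inference(image=poster, prompt=prompt))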