from transformers import AutoProcessor, AutoModelForCausalLM, BlipForConditionalGeneration


class ImageCaptionModel:
    """
    Base wrapper that pairs a Hugging Face processor with a captioning model.
    -----
    Attributes:
        device: torch.device or str
            The device the model was moved to and inputs are sent to.
        processor:
            Used to turn images into ``pixel_values`` and to decode generated ids.
        model:
            The model whose ``generate`` method produces caption token ids.
    """

    def __init__(
        self,
        device,
        processor,
        model,
    ) -> None:
        """
        Stores the processor and model and moves the model to ``device``.
        -----
        Parameters:
            device: torch.device or str
                The device to run the model on.
            processor:
                Callable as ``processor(images=..., return_tensors="pt")`` and
                providing ``batch_decode``.
            model:
                A model exposing ``to`` and ``generate``.
        -----
        Returns:
            None
        """
        self.device = device
        self.processor = processor
        self.model = model
        self.model.to(self.device)

    def generate(
        self,
        image,
        num_captions=1,
        max_length=50,
        num_beam_groups=1,
        temperature=1.0,
        top_k=50,
        top_p=1.0,
        repetition_penalty=1.0,
        diversity_penalty=0.0,
    ):
        """
        Generates captions for the given image using beam search.
        -----
        Parameters:
            image: PIL.Image
                The image to generate captions for.
            num_captions: int
                The number of captions to return. Must be at least 1.
            max_length: int
                The maximum length (in tokens) of each generated caption.
            num_beam_groups: int
                The number of beam groups for diverse (group) beam search.
                1 means no group beam search. When ``diversity_penalty`` is
                non-zero at least 2 groups are used. The beam count is rounded
                up to a multiple of this value, and surplus captions are
                discarded.
            temperature: float
                Sampling temperature; must be strictly positive. Defaults to 1.0.
            top_k: int
                Number of highest-probability tokens kept for top-k filtering.
                Defaults to 50.
            top_p: float
                If < 1, only the smallest set of most probable tokens whose
                probabilities add up to ``top_p`` is kept. Defaults to 1.0.
            repetition_penalty: float
                Repetition penalty; 1.0 means no penalty. Defaults to 1.0.
            diversity_penalty: float
                Penalty that encourages different beam groups to diverge;
                0.0 means no penalty. Defaults to 0.0.

            NOTE(review): ``do_sample`` is not passed to ``model.generate``, so
            ``temperature``/``top_k``/``top_p`` presumably only take effect if
            the model's generation config enables sampling. Confirm.
        -----
        Returns:
            list[str]
                At most ``num_captions`` decoded captions.
        -----
        Raises:
            ValueError
                If ``num_captions`` is less than 1.
        """
        if num_captions < 1:
            raise ValueError("num_captions must be at least 1")

        pixel_values = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device)

        # Remember what the caller asked for: the beam count may be rounded up below.
        requested_captions = num_captions
        num_beams = num_captions

        if diversity_penalty != 0.0:
            # A diversity penalty only has an effect with group beam search,
            # which needs at least two groups. Keep a larger caller-supplied value.
            num_beam_groups = max(num_beam_groups, 2)

        if num_beam_groups > 1:
            # Group beam search requires num_beams to be a multiple of
            # num_beam_groups; round up and trim the extra captions afterwards.
            remainder = num_beams % num_beam_groups
            if remainder:
                num_beams += num_beam_groups - remainder

        generated_ids = self.model.generate(
            pixel_values=pixel_values,
            max_length=max_length,
            num_beams=num_beams,
            num_beam_groups=num_beam_groups,
            num_return_sequences=num_beams,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            diversity_penalty=diversity_penalty,
        )
        captions = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
        return captions[:requested_captions]


class GitBaseCocoModel(ImageCaptionModel):
    """
    A wrapper class for the Git-Base-COCO model, a pretrained image-captioning model.
    """

    # Hugging Face Hub checkpoint the processor and model are loaded from.
    checkpoint = "microsoft/git-base-coco"

    def __init__(self, device):
        """
        Loads the Git-Base-COCO processor and model and moves the model to ``device``.
        -----
        Parameters:
            device: torch.device or str
                The device to run the model on.
        -----
        Returns:
            None
        """
        processor = AutoProcessor.from_pretrained(self.checkpoint)
        model = AutoModelForCausalLM.from_pretrained(self.checkpoint)
        super().__init__(device, processor, model)

    def generate(self, image, max_length=50, num_captions=1, **kwargs):
        """
        Generates captions for the given image.
        -----
        Parameters:
            image: PIL.Image
                The image to generate captions for.
            max_length: int
                The maximum length of each caption.
            num_captions: int
                The number of captions to generate.
            **kwargs:
                Forwarded to ``ImageCaptionModel.generate``.
        -----
        Returns:
            list[str]
                The generated captions.
        """
        # Forward by keyword: the base class orders num_captions/max_length the
        # other way round, so positional forwarding would swap them.
        return super().generate(image, num_captions=num_captions, max_length=max_length, **kwargs)


class BlipBaseModel(ImageCaptionModel):
    """
    A wrapper class for the BLIP base image-captioning model.
    """

    # Hugging Face Hub checkpoint the processor and model are loaded from.
    checkpoint = "Salesforce/blip-image-captioning-base"

    def __init__(self, device):
        """
        Loads the BLIP base processor and model and moves the model to ``device``.
        -----
        Parameters:
            device: torch.device or str
                The device to run the model on.
        -----
        Returns:
            None
        """
        processor = AutoProcessor.from_pretrained(self.checkpoint)
        model = BlipForConditionalGeneration.from_pretrained(self.checkpoint)
        super().__init__(device, processor, model)

    def generate(self, image, max_length=50, num_captions=1, **kwargs):
        """
        Generates captions for the given image.
        -----
        Parameters:
            image: PIL.Image
                The image to generate captions for.
            max_length: int
                The maximum length of each caption.
            num_captions: int
                The number of captions to generate.
            **kwargs:
                Forwarded to ``ImageCaptionModel.generate``.
        -----
        Returns:
            list[str]
                The generated captions.
        """
        # Forward by keyword: the base class orders num_captions/max_length the
        # other way round, so positional forwarding would swap them.
        return super().generate(image, num_captions=num_captions, max_length=max_length, **kwargs)