from transformers import AutoProcessor, AutoModelForCausalLM


class GitBaseCocoModel:
    def __init__(self, device, checkpoint="microsoft/git-base-coco"):
        """
        A wrapper class for the Git-Base-COCO model, a pretrained model
        for image captioning.

        -----
        Parameters:
        device: torch.device
            The device to run the model on.
        checkpoint: str
            The checkpoint to load the model from.

        -----
        Returns:
        None
        """
        self.checkpoint = checkpoint
        self.device = device
        self.processor = AutoProcessor.from_pretrained(self.checkpoint)
        self.model = AutoModelForCausalLM.from_pretrained(self.checkpoint).to(self.device)

    def generate(self, image, max_len=50, num_captions=1):
        """
        Generates captions for the given image.

        -----
        Parameters:
        image: PIL.Image
            The image to generate captions for.
        max_len: int
            The maximum length of the caption.
        num_captions: int
            The number of captions to generate.

        -----
        Returns:
        list of str
            The generated captions.
        """
        # Preprocess the image into pixel values and move them to the target device.
        pixel_values = self.processor(
            images=image, return_tensors="pt"
        ).pixel_values.to(self.device)

        # Use beam search so that multiple distinct captions can be returned.
        generated_ids = self.model.generate(
            pixel_values=pixel_values,
            max_length=max_len,
            num_beams=num_captions,
            num_return_sequences=num_captions,
        )
        return self.processor.batch_decode(generated_ids, skip_special_tokens=True)
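

# Minimal usage sketch. The image path "example.jpg" is a placeholder and the
# device selection is an assumption, not part of the original module.
if __name__ == "__main__":
    import torch
    from PIL import Image

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    captioner = GitBaseCocoModel(device)

    image = Image.open("example.jpg")  # replace with a real image path
    captions = captioner.generate(image, max_len=50, num_captions=3)
    for caption in captions:
        print(caption)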