"""Wrapper classes for image-captioning models (GIT-base-COCO and BLIP-base) built on Hugging Face Transformers."""

from transformers import AutoProcessor, AutoModelForCausalLM, BlipForConditionalGeneration


class ImageCaptionModel:
def __init__(
self,
device,
processor,
model,
) -> None:
self.device = device
self.processor = processor
self.model = model
self.model.to(self.device)

    def generate(
self,
image,
num_captions=1,
max_length=50,
num_beam_groups=1,
temperature=1.0,
top_k=50,
top_p=1.0,
repetition_penalty=1.0,
diversity_penalty=0.0,
):
"""
Generates captions for the given image.
-----
Parameters:
preprocessor: transformers.PreTrainedTokenizerFast
The preprocessor to use for the model.
model: transformers.PreTrainedModel
The model to use for generating captions.
image: PIL.Image
The image to generate captions for.
num_captions: int
The number of captions to generate.
num_beam_groups: int
The number of beam groups to use for beam search in order to maintain diversity. Must be between 1 and num_beams. 1 means no group_beam_search..
temperature: float
The temperature to use for sampling. The value used to module the next token probabilities that will be used by default in the generate method of the model. Must be strictly positive. Defaults to 1.0.
top_k: int
The number of highest probability vocabulary tokens to keep for top-k-filtering. A large value of top_k will keep more probabilities for each token leading to a better but slower generation. Defaults to 50.
top_p: float
The value that will be used by default in the generate method of the model for top_p. If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
repetition_penalty: float
The parameter for repetition penalty. 1.0 means no penalty. Defaults to 1.0.
diversity_penalty: float
The parameter for diversity penalty. 0.0 means no penalty. Defaults to 0.0.
"""
        pixel_values = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device)
        # Remember how many captions were actually requested so the slice at
        # the end is unaffected by the rounding below.
        requested_captions = num_captions
        if diversity_penalty != 0.0:
            # Diverse (group) beam search requires num_beam_groups > 1 and a
            # beam count divisible by the number of groups, so round the
            # caption count up to the next even number.
            num_beam_groups = 2
            num_captions = num_captions if num_captions % 2 == 0 else num_captions + 1
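        # Generate caption token ids with the requested decoding parameters.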
generated_ids = self.model.generate(
pixel_values=pixel_values,
max_length=max_length,
num_beams=num_captions,
num_beam_groups=num_beam_groups,
num_return_sequences=num_captions,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
diversity_penalty=diversity_penalty,
)
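        # Decode the generated token ids into plain caption strings.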
generated_caption = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
        return generated_caption[:requested_captions]


class GitBaseCocoModel(ImageCaptionModel):
    def __init__(self, device):
        """
        A wrapper class for the GIT-base-COCO model, a pretrained image-captioning model. The "microsoft/git-base-coco" checkpoint is loaded automatically.
        -----
        Parameters:
        device: torch.device
            The device to run the model on.
        -----
        Returns:
        None
        """
checkpoint = "microsoft/git-base-coco"
processor = AutoProcessor.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)
super().__init__(device, processor, model)

    def generate(self, image, max_length=50, num_captions=1, **kwargs):
        """
        Generates captions for the given image.
        -----
        Parameters:
        image: PIL.Image
            The image to generate captions for.
        max_length: int
            The maximum length of each caption.
        num_captions: int
            The number of captions to generate.
        """
        # Forward by keyword: the parent signature is generate(image, num_captions, max_length, ...),
        # so positional arguments would swap max_length and num_captions.
        captions = super().generate(image, num_captions=num_captions, max_length=max_length, **kwargs)
        return captions


class BlipBaseModel(ImageCaptionModel):
    def __init__(self, device):
        """
        A wrapper class for the BLIP base image-captioning model. The "Salesforce/blip-image-captioning-base" checkpoint is loaded automatically.
        -----
        Parameters:
        device: torch.device
            The device to run the model on.
        -----
        Returns:
        None
        """
        self.checkpoint = "Salesforce/blip-image-captioning-base"
processor = AutoProcessor.from_pretrained(self.checkpoint)
model = BlipForConditionalGeneration.from_pretrained(self.checkpoint)
super().__init__(device, processor, model)

    def generate(self, image, max_length=50, num_captions=1, **kwargs):
        """
        Generates captions for the given image.
        -----
        Parameters:
        image: PIL.Image
            The image to generate captions for.
        max_length: int
            The maximum length of each caption.
        num_captions: int
            The number of captions to generate.
        """
        # Forward by keyword so the arguments map onto the parent signature
        # generate(image, num_captions, max_length, ...) correctly.
        captions = super().generate(image, num_captions=num_captions, max_length=max_length, **kwargs)
        return captions
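

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module):
    # assumes torch and Pillow are installed and that a local file named
    # "example.jpg" (hypothetical) exists; GitBaseCocoModel(device) can be
    # swapped in to compare checkpoints.
    import torch
    from PIL import Image

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    image = Image.open("example.jpg").convert("RGB")

    model = BlipBaseModel(device)
    captions = model.generate(image, max_length=50, num_captions=3)
    for caption in captions:
        print(caption)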