File size: 4,300 Bytes
df766f8
0d08077
df766f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d08077
 
 
 
 
 
 
 
 
 
 
 
 
 
df766f8
 
 
 
0d08077
df766f8
0d08077
 
 
 
 
 
 
 
 
 
 
 
df766f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from transformers import AutoProcessor, AutoModelForCausalLM, BlipForConditionalGeneration

class ImageCaptionModel:
	"""
	Base wrapper that pairs a processor with a generative captioning model.

	-----
	Parameters:
	device: torch.device or str
		The device passed to ``model.to()``; processed pixel values are moved there too.
	processor:
		Callable that converts images into model inputs (the result must expose
		``pixel_values``) and provides ``batch_decode`` for turning token ids into text.
	model:
		Model exposing ``to()`` and ``generate()``.
	"""

	def __init__(
		self,
		device,
		processor,
		model,
	) -> None:
		self.device = device
		self.processor = processor
		self.model = model
		self.model.to(self.device)

	def generate(
		self,
		image,
		num_captions=1,
		max_length=50,
		num_beam_groups=1,
		temperature=1.0,
		top_k=50,
		top_p=1.0,
		repetition_penalty=1.0,
		diversity_penalty=0.0,
	):
		"""
		Generates captions for the given image using beam search.

		-----
		Parameters:
		image: PIL.Image
			The image to generate captions for.
		num_captions: int
			The number of captions to return. Also used as the number of beams
			(rounded up to an even number when diversity_penalty is non-zero).
		max_length: int
			The maximum length of each generated caption, in tokens.
		num_beam_groups: int
			The number of beam groups to use for beam search in order to maintain diversity.
			Must be between 1 and num_beams. 1 means no group beam search. Overridden to 2
			when diversity_penalty is non-zero.
		temperature: float
			The value used to modulate the next-token probabilities. Must be strictly
			positive. Defaults to 1.0.
		top_k: int
			The number of highest-probability vocabulary tokens to keep for top-k filtering.
			Defaults to 50.
		top_p: float
			If set to a float < 1, only the most probable tokens whose probabilities add up
			to top_p or higher are kept for generation. Defaults to 1.0.
		repetition_penalty: float
			The parameter for repetition penalty. 1.0 means no penalty. Defaults to 1.0.
		diversity_penalty: float
			The parameter for diversity penalty. 0.0 means no penalty. Defaults to 0.0.

		-----
		Returns:
		list[str]
			At most ``num_captions`` decoded captions.
		"""
		pixel_values = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device)

		# Keep the caller's requested count separate from the beam count: the beam
		# count may be bumped below, but the caller should still get what they asked for.
		requested_captions = num_captions
		num_beams = num_captions

		if diversity_penalty != 0.0:
			# Group (diverse) beam search needs num_beams divisible by num_beam_groups,
			# so use two groups and round the beam count up to an even number.
			num_beam_groups = 2
			num_beams = num_beams if num_beams % 2 == 0 else num_beams + 1

		generated_ids = self.model.generate(
			pixel_values=pixel_values,
			max_length=max_length,
			num_beams=num_beams,
			num_beam_groups=num_beam_groups,
			num_return_sequences=num_beams,
			temperature=temperature,
			top_k=top_k,
			top_p=top_p,
			repetition_penalty=repetition_penalty,
			diversity_penalty=diversity_penalty,
		)

		generated_captions = self.processor.batch_decode(generated_ids, skip_special_tokens=True)

		# Drop any extra sequence produced by rounding the beam count up.
		return generated_captions[:requested_captions]

class GitBaseCocoModel(ImageCaptionModel):
	"""
	A wrapper class for the "microsoft/git-base-coco" checkpoint, loaded with
	AutoProcessor and AutoModelForCausalLM, for image captioning.
	"""

	def __init__(self, device):
		"""
		Loads the Git-Base-COCO processor and model and moves the model to ``device``.

		-----
		Parameters:
		device: torch.device or str
			The device to run the model on.

		-----
		Returns:
		None
		"""
		self.checkpoint = "microsoft/git-base-coco"
		processor = AutoProcessor.from_pretrained(self.checkpoint)
		model = AutoModelForCausalLM.from_pretrained(self.checkpoint)
		super().__init__(device, processor, model)

	def generate(self, image, max_length=50, num_captions=1, **kwargs):
		"""
		Generates captions for the given image.

		-----
		Parameters:
		image: PIL.Image
			The image to generate captions for.
		max_length: int
			The maximum length of the caption.
		num_captions: int
			The number of captions to generate.
		**kwargs:
			Extra generation options forwarded to ImageCaptionModel.generate.

		-----
		Returns:
		list[str]
			The generated captions.
		"""
		# Pass by keyword: the base class orders these as (num_captions, max_length),
		# the opposite of this method's signature.
		captions = super().generate(
			image,
			num_captions=num_captions,
			max_length=max_length,
			**kwargs,
		)
		return captions
	

class BlipBaseModel(ImageCaptionModel):
	"""
	A wrapper class for the "Salesforce/blip-image-captioning-base" checkpoint, loaded
	with AutoProcessor and BlipForConditionalGeneration, for image captioning.
	"""

	def __init__(self, device):
		"""
		Loads the BLIP base processor and model and moves the model to ``device``.

		-----
		Parameters:
		device: torch.device or str
			The device to run the model on.

		-----
		Returns:
		None
		"""
		self.checkpoint = "Salesforce/blip-image-captioning-base"
		processor = AutoProcessor.from_pretrained(self.checkpoint)
		model = BlipForConditionalGeneration.from_pretrained(self.checkpoint)
		super().__init__(device, processor, model)

	def generate(self, image, max_length=50, num_captions=1, **kwargs):
		"""
		Generates captions for the given image.

		-----
		Parameters:
		image: PIL.Image
			The image to generate captions for.
		max_length: int
			The maximum length of the caption.
		num_captions: int
			The number of captions to generate.
		**kwargs:
			Extra generation options forwarded to ImageCaptionModel.generate.

		-----
		Returns:
		list[str]
			The generated captions.
		"""
		# Pass by keyword: the base class orders these as (num_captions, max_length),
		# the opposite of this method's signature.
		captions = super().generate(
			image,
			num_captions=num_captions,
			max_length=max_length,
			**kwargs,
		)
		return captions