chats-bug committed on
Commit
a95ba86
1 Parent(s): 1d4f82c

Added fine-tuning options

Files changed (2)
  1. app.py +28 -21
  2. model.py +76 -61
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
 import torch
 from PIL import Image
 
-from model import GitBaseCocoModel, BlipBaseModel
+from model import BlipBaseModel, GitBaseCocoModel
 
 MODELS = {
     "Git-Base-COCO": GitBaseCocoModel,
@@ -12,33 +12,38 @@ MODELS = {
 def generate_captions(
     image,
     num_captions,
+    model_name,
     max_length,
     temperature,
     top_k,
     top_p,
     repetition_penalty,
     diversity_penalty,
-    model_name,
 ):
     """
     Generates captions for the given image.
-
+
     -----
     Parameters:
     image: PIL.Image
         The image to generate captions for.
-    max_len: int
-        The maximum length of the caption.
    num_captions: int
        The number of captions to generate.
-
+    The rest of the parameters are the same as in the model.generate method.
     -----
     Returns:
     list[str]
     """
+    # Cast the numerical values to their expected types: a Gradio Slider
+    # returns a float, except when the value is a whole number, in which
+    # case it returns an int. Only the float-typed parameters need this.
+    temperature = float(temperature)
+    top_p = float(top_p)
+    repetition_penalty = float(repetition_penalty)
+    diversity_penalty = float(diversity_penalty)
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
-
+
     model = MODELS[model_name](device)
 
     captions = model.generate(
@@ -56,32 +61,34 @@ def generate_captions(
     captions = "\n".join(captions)
     return captions
 
-title = "Git-Base-COCO Image Captioning"
-description = "A model for generating captions for images."
+title = "AI tool for generating captions for images"
+description = "This tool uses pretrained models to generate captions for images."
 
 interface = gr.Interface(
     fn=generate_captions,
     inputs=[
-        gr.inputs.Image(type="pil", label="Image"),
-        gr.inputs.Slider(minimum=1, maximum=10, step=1, default=1, label="Number of Captions to Generate"),
-        gr.inputs.Slider(minimum=20, maximum=100, step=5, default=50, label="Maximum Caption Length"),
-        gr.inputs.Slider(minimum=0.1, maximum=10.0, step=0.1, default=1.0, label="Temperature"),
-        gr.inputs.Slider(minimum=1, maximum=100, step=1, default=50, label="Top K"),
-        gr.inputs.Slider(minimum=-5.0, maximum=5.0, step=0.1, default=1.0, label="Top P"),
-        gr.inputs.Slider(minimum=1.0, maximum=10.0, step=0.1, default=1.0, label="Repetition Penalty"),
-        gr.inputs.Slider(minimum=0.0, maximum=10.0, step=0.1, default=0.0, label="Diversity Penalty"),
-        gr.inputs.Dropdown(MODELS.keys(), label="Model"),
+        gr.components.Image(type="pil", label="Image"),
+        gr.components.Slider(minimum=1, maximum=10, step=1, value=1, label="Number of Captions to Generate"),
+        gr.components.Dropdown(MODELS.keys(), label="Model", value=list(MODELS.keys())[1]),  # Default to Blip Base.
+        gr.components.Slider(minimum=20, maximum=100, step=5, value=50, label="Maximum Caption Length"),
+        gr.components.Slider(minimum=0.1, maximum=10.0, step=0.1, value=1.0, label="Temperature"),
+        gr.components.Slider(minimum=1, maximum=100, step=1, value=50, label="Top K"),
+        gr.components.Slider(minimum=0.1, maximum=5.0, step=0.1, value=1.0, label="Top P"),
+        gr.components.Slider(minimum=1.0, maximum=10.0, step=0.1, value=2.0, label="Repetition Penalty"),
+        gr.components.Slider(minimum=0.0, maximum=10.0, step=0.1, value=2.0, label="Diversity Penalty"),
     ],
     outputs=[
-        gr.outputs.Textbox(label="Caption"),
+        gr.components.Textbox(label="Caption"),
    ],
     title=title,
     description=description,
-)
+    allow_flagging="never",
+)
 
 
 if __name__ == "__main__":
+    # Launch the interface.
     interface.launch(
         enable_queue=True,
-        debug=True
+        debug=True,
     )
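Note on the casts at the top of generate_captions: a Gradio Slider normally returns a float, but a whole-number position can come back as an int, and the asserts in model.generate reject int values for the float-typed parameters. A minimal standalone sketch of the guard (plain Python, no Gradio required; the values are illustrative):

# Simulate the two shapes a Slider value can take: a whole-number
# position may arrive as an int, a fractional one as a float.
for raw_value in (2, 2.5):
    temperature = float(raw_value)  # no-op for a float, fixes the int case
    assert type(temperature) == float  # mirrors the check in model.generate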
model.py CHANGED
@@ -7,26 +7,41 @@ class ImageCaptionModel:
         processor,
         model,
     ) -> None:
+        """
+        Initializes the model for generating captions for images.
+
+        -----
+        Parameters:
+        device: str
+            The device to use for the model. Must be either "cpu" or "cuda".
+        processor: transformers.AutoProcessor
+            The preprocessor to use for the model.
+        model: transformers.AutoModelForCausalLM or transformers.BlipForConditionalGeneration
+            The model to use for generating captions.
+
+        -----
+        Returns:
+        None
+        """
         self.device = device
         self.processor = processor
         self.model = model
         self.model.to(self.device)
-
+
     def generate(
         self,
         image,
-        num_captions=1,
-        max_length=50,
-        num_beam_groups=1,
-        temperature=1.0,
-        top_k=50,
-        top_p=1.0,
-        repetition_penalty=1.0,
-        diversity_penalty=0.0,
+        num_captions: int = 1,
+        max_length: int = 50,
+        temperature: float = 1.0,
+        top_k: int = 50,
+        top_p: float = 1.0,
+        repetition_penalty: float = 1.0,
+        diversity_penalty: float = 0.0,
     ):
         """
         Generates captions for the given image.
-
+
         -----
         Parameters:
         preprocessor: transformers.PreTrainedTokenizerFast
@@ -37,8 +52,6 @@ class ImageCaptionModel:
             The image to generate captions for.
         num_captions: int
             The number of captions to generate.
-        num_beam_groups: int
-            The number of beam groups to use for beam search in order to maintain diversity. Must be between 1 and num_beams. 1 means no group beam search.
         temperature: float
             The temperature to use for sampling. The value used to modulate the next token probabilities that will be used by default in the generate method of the model. Must be strictly positive. Defaults to 1.0.
         top_k: int
@@ -49,25 +62,45 @@ class ImageCaptionModel:
             The parameter for repetition penalty. 1.0 means no penalty. Defaults to 1.0.
         diversity_penalty: float
             The parameter for diversity penalty. 0.0 means no penalty. Defaults to 0.0.
-
+
         """
-        pixel_values = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device)
-
-        generated_ids = self.model.generate(
-            pixel_values=pixel_values,
-            max_length=max_length,
-            num_beams=num_captions*2,
-            num_beam_groups=num_beam_groups,
-            num_return_sequences=num_captions*2,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            repetition_penalty=repetition_penalty,
-            diversity_penalty=diversity_penalty,
-        )
+        # Type-check the arguments and make sure the values are valid.
+        assert type(num_captions) == int and num_captions > 0, "num_captions must be a positive integer."
+        assert type(max_length) == int and max_length > 0, "max_length must be a positive integer."
+        assert type(temperature) == float and temperature > 0.0, "temperature must be a positive float."
+        assert type(top_k) == int and top_k > 0, "top_k must be a positive integer."
+        assert type(top_p) == float and top_p > 0.0, "top_p must be a positive float."
+        assert type(repetition_penalty) == float and repetition_penalty >= 1.0, "repetition_penalty must be a float greater than or equal to 1.0."
+        assert type(diversity_penalty) == float and diversity_penalty >= 0.0, "diversity_penalty must be a non-negative float."
 
+        pixel_values = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device)  # Convert the image to pixel values.
+
+        # Generate caption ids.
+        if num_captions == 1:
+            generated_ids = self.model.generate(
+                pixel_values=pixel_values,
+                max_length=max_length,
+                num_return_sequences=1,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+            )
+        else:
+            generated_ids = self.model.generate(
+                pixel_values=pixel_values,
+                max_length=max_length,
+                num_beams=num_captions,  # num_beams must be greater than or equal to num_return_sequences and divisible by num_beam_groups.
+                num_beam_groups=num_captions,  # num_beam_groups is set equal to num_captions so that the captions stay diverse.
+                num_return_sequences=num_captions,  # Without grouping, beam search would return captions that are very similar to each other.
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                diversity_penalty=diversity_penalty,
+            )
+
+        # Decode the generated ids to get the captions.
         generated_caption = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
-        generated_caption = [generated_caption[i] for i in range(0, num_captions*2, 2)]
 
         return generated_caption
 
@@ -79,8 +112,8 @@ class GitBaseCocoModel(ImageCaptionModel):
 
     -----
     Parameters:
-    device: torch.device
-        The device to run the model on.
+    device: str
+        The device to run the model on, either "cpu" or "cuda".
     checkpoint: str
         The checkpoint to load the model from.
 
@@ -93,42 +126,24 @@ class GitBaseCocoModel(ImageCaptionModel):
         model = AutoModelForCausalLM.from_pretrained(checkpoint)
         super().__init__(device, processor, model)
 
-    def generate(self, image, max_length=50, num_captions=1, **kwargs):
-        """
-        Generates captions for the given image.
-
-        -----
-        Parameters:
-        image: PIL.Image
-            The image to generate captions for.
-        max_len: int
-            The maximum length of the caption.
-        num_captions: int
-            The number of captions to generate.
-        """
-        captions = super().generate(image, max_length, num_captions, **kwargs)
-        return captions
-
 
 class BlipBaseModel(ImageCaptionModel):
     def __init__(self, device):
-        self.checkpoint = "Salesforce/blip-image-captioning-base"
-        processor = AutoProcessor.from_pretrained(self.checkpoint)
-        model = BlipForConditionalGeneration.from_pretrained(self.checkpoint)
-        super().__init__(device, processor, model)
-
-    def generate(self, image, max_length=50, num_captions=1, **kwargs):
         """
-        Generates captions for the given image.
+        A wrapper class for the Blip-Base model. It is a pretrained model for image captioning.
 
         -----
         Parameters:
-        image: PIL.Image
-            The image to generate captions for.
-        max_len: int
-            The maximum length of the caption.
-        num_captions: int
-            The number of captions to generate.
+        device: str
+            The device to run the model on, either "cpu" or "cuda".
+        checkpoint: str
+            The checkpoint to load the model from.
+
+        -----
+        Returns:
+        None
         """
-        captions = super().generate(image, max_length, num_captions, **kwargs)
-        return captions
+        self.checkpoint = "Salesforce/blip-image-captioning-base"
+        processor = AutoProcessor.from_pretrained(self.checkpoint)
+        model = BlipForConditionalGeneration.from_pretrained(self.checkpoint)
+        super().__init__(device, processor, model)
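With both files applied, requesting more than one caption takes the diverse (group) beam search branch: num_beams, num_beam_groups, and num_return_sequences are all set to num_captions, so each group contributes one caption and diversity_penalty keeps the groups apart. A minimal usage sketch, assuming model.py is on the import path and a local image file exists (the file name cat.jpg is illustrative):

from PIL import Image

from model import BlipBaseModel

device = "cpu"  # or "cuda" when a GPU is available
model = BlipBaseModel(device)
image = Image.open("cat.jpg").convert("RGB")

# num_captions=3 exercises the group-beam-search branch: 3 beams in
# 3 groups returning 3 sequences, pushed apart by diversity_penalty.
captions = model.generate(
    image,
    num_captions=3,
    max_length=50,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
    repetition_penalty=2.0,
    diversity_penalty=2.0,
)
for caption in captions:
    print(caption)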