p1atdev committed on
Commit
dfc1ef8
1 Parent(s): b97a649

Upload 2 files

Browse files
image_processing_tagger.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # copied from ViTImageProcessor (https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/vit/image_processing_vit.py)
2
+
3
+ """Image processor class for WD v14 Tagger."""
4
+
5
+ from typing import Optional, List, Dict, Union, Tuple
6
+
7
+ import numpy as np
8
+ import cv2
9
+ from PIL import Image
10
+
11
+ from transformers.image_processing_utils import (
12
+ BaseImageProcessor,
13
+ BatchFeature,
14
+ get_size_dict,
15
+ )
16
+ from transformers.image_transforms import (
17
+ rescale,
18
+ to_channel_dimension_format,
19
+ _rescale_for_pil_conversion,
20
+ to_pil_image,
21
+ )
22
+ from transformers.image_utils import (
23
+ IMAGENET_STANDARD_MEAN,
24
+ IMAGENET_STANDARD_STD,
25
+ ChannelDimension,
26
+ ImageInput,
27
+ PILImageResampling,
28
+ infer_channel_dimension_format,
29
+ is_scaled_image,
30
+ make_list_of_images,
31
+ to_numpy_array,
32
+ valid_images,
33
+ )
34
+ from transformers.utils import TensorType, logging
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+
39
+ # copied from transformers.image_transforms.resize
40
def resize_with_padding(
    image: np.ndarray,
    size: Tuple[int, int],
    color: Tuple[int, int, int],
    resample: PILImageResampling = None,
    reducing_gap: Optional[int] = None,
    data_format: Optional[ChannelDimension] = None,
    return_numpy: bool = True,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
    """
    Resizes `image` to fit inside `(height, width)` specified by `size` while keeping its aspect
    ratio, then pastes it centered onto a solid-`color` canvas of exactly that size.

    Args:
        image (`np.ndarray`):
            The image to resize.
        size (`Tuple[int, int]`):
            The `(height, width)` of the output image.
        color (`Tuple[int, int, int]`):
            The RGB color to use for padding the image.
        resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`):
            The filter to use for resampling.
        reducing_gap (`int`, *optional*):
            Apply optimization by resizing the image in two steps. The bigger `reducing_gap`, the closer the result to
            the fair resampling. See corresponding Pillow documentation for more details.
        data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the output image. If unset, will use the inferred format from the input.
        return_numpy (`bool`, *optional*, defaults to `True`):
            Whether or not to return the resized image as a numpy array. If False a `PIL.Image.Image` object is
            returned.
        input_data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the input image. If unset, will use the inferred format from the input.

    Returns:
        `np.ndarray`: The resized image (or `PIL.Image.Image` when `return_numpy=False`).

    Raises:
        ValueError: If `size` does not have exactly 2 elements.
    """
    resample = resample if resample is not None else PILImageResampling.BILINEAR

    if len(size) != 2:
        raise ValueError("size must have 2 elements")

    # For all transformations, we want to keep the same data format as the input image unless
    # otherwise specified. The resized image from PIL will always have channels last, so find
    # the input format first.
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)
    data_format = input_data_format if data_format is None else data_format

    # To maintain backwards compatibility with the resizing done in previous image feature
    # extractors, we use the pillow library to resize the image and then convert back to numpy.
    do_rescale = False
    if not isinstance(image, Image.Image):
        do_rescale = _rescale_for_pil_conversion(image)
        image = to_pil_image(
            image, do_rescale=do_rescale, input_data_format=input_data_format
        )

    assert isinstance(image, Image.Image)

    # NOTE: `size` is (height, width), but PIL uses (width, height) everywhere.
    height, width = size
    original_width, original_height = image.size

    # Scale so the image fits entirely inside the target box, preserving aspect ratio.
    ratio = min(width / original_width, height / original_height)
    new_width = int(original_width * ratio)
    new_height = int(original_height * ratio)

    resized_image = image.resize(
        (new_width, new_height), resample=resample, reducing_gap=reducing_gap
    )

    # Solid opaque background. BUGFIX: PIL's `Image.new` expects (width, height);
    # the previous code passed `size` (= (height, width)), which produced a
    # transposed canvas for any non-square target size.
    new_image = Image.new("RGBA", (width, height), tuple(color) + (255,))

    # Paste the resized image at the center, using its alpha channel as the mask.
    offset = ((width - new_width) // 2, (height - new_height) // 2)
    new_image.paste(
        resized_image.convert("RGBA"), offset, resized_image.convert("RGBA")
    )

    new_image = new_image.convert("RGB")

    if return_numpy:
        new_image = np.array(new_image)
        # If the input image channel dimension was of size 1, then it is dropped when
        # converting to a PIL image, so we need to add it back if necessary.
        new_image = (
            np.expand_dims(new_image, axis=-1) if new_image.ndim == 2 else new_image
        )
        # The image is always in channels last format after converting from a PIL image.
        new_image = to_channel_dimension_format(
            new_image, data_format, input_channel_dim=ChannelDimension.LAST
        )
        # If an image was rescaled to be in the range [0, 255] before converting to a PIL
        # image, then we need to rescale it back to the original range.
        new_image = rescale(new_image, 1 / 255) if do_rescale else new_image

    return new_image
141
+
142
+
143
class WDv14TaggerImageProcessor(BaseImageProcessor):
    r"""
    Constructs a WD v14 Tagger image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
            size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
        size (`dict`, *optional*, defaults to `{"height": 448, "width": 448}`):
            Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
            method.
        color (`List[int]`, *optional*, defaults to `[255, 255, 255]`):
            Color to use for padding the image after resizing. Can be overridden by the `color` parameter in the
            `preprocess` method.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
            Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
            `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
            parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
            `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Optional[Dict[str, int]] = None,
        color: Optional[List[int]] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        size = size if size is not None else {"height": 448, "width": 448}
        size = get_size_dict(size)
        # Default padding color is white; `None` sentinel avoids a shared mutable default.
        color = color if color is not None else [255, 255, 255]
        self.do_resize = do_resize
        self.do_rescale = do_rescale
        self.do_normalize = do_normalize
        self.size = size
        self.color = color
        self.resample = resample
        self.rescale_factor = rescale_factor
        self.image_mean = (
            image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        )
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        color: Optional[List[int]] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize an image to `(size["height"], size["width"])`, padding with `color` to preserve
        the aspect ratio.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
            color (`List[int]`, *optional*, defaults to `[255, 255, 255]`):
                RGB color used to pad the letterboxed area.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
            data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

        Returns:
            `np.ndarray`: The resized image.

        Raises:
            ValueError: If `size` does not contain both `height` and `width` keys.
        """
        # FIX: the previous signature used a mutable default (`color=[255, 255, 255]`);
        # a `None` sentinel gives identical behavior without the shared-list pitfall.
        color = color if color is not None else [255, 255, 255]

        size = get_size_dict(size)
        if "height" not in size or "width" not in size:
            raise ValueError(
                f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}"
            )

        output_size = (size["height"], size["width"])

        color = tuple(color)

        return resize_with_padding(
            image,
            size=output_size,
            color=color,
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )

    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        color: Optional[List[int]] = None,
        resample: PILImageResampling = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ):
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
                resizing.
            color (`List[int]`, *optional*, defaults to `self.color`):
                RGB padding color used when resizing with letterboxing.
            resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
                `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has
                an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image values between [0 - 1].
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Mean to use if normalizing the image.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Standard deviation to use if normalizing the image.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

        Returns:
            `BatchFeature`: A feature dict with key `"pixel_values"`.

        Raises:
            ValueError: If the images are of an invalid type, or a required parameter for an
                enabled transformation is missing.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        resample = resample if resample is not None else self.resample
        rescale_factor = (
            rescale_factor if rescale_factor is not None else self.rescale_factor
        )
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std

        size = size if size is not None else self.size
        size_dict = get_size_dict(size)

        color = color if color is not None else self.color

        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        if do_resize and size is None:
            raise ValueError("Size must be specified if do_resize is True.")

        if do_rescale and rescale_factor is None:
            raise ValueError("Rescale factor must be specified if do_rescale is True.")

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if is_scaled_image(images[0]) and do_rescale:
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        if do_resize:
            images = [
                self.resize(
                    image=image,
                    size=size_dict,
                    color=color,
                    resample=resample,
                    input_data_format=input_data_format,
                )
                for image in images
            ]

        if do_rescale:
            images = [
                self.rescale(
                    image=image,
                    scale=rescale_factor,
                    input_data_format=input_data_format,
                )
                for image in images
            ]

        if do_normalize:
            images = [
                self.normalize(
                    image=image,
                    mean=image_mean,
                    std=image_std,
                    input_data_format=input_data_format,
                )
                for image in images
            ]

        images = [
            to_channel_dimension_format(
                image, data_format, input_channel_dim=input_data_format
            )
            for image in images
        ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)
preprocessor_config.json CHANGED
@@ -19,7 +19,7 @@
19
  0.5
20
  ],
21
  "resample": 2,
22
- "rescale_factor": 0.00392156862745098,
23
  "size": {
24
  "height": 448,
25
  "width": 448
 
19
  0.5
20
  ],
21
  "resample": 2,
22
+ "rescale_factor": 0.0,
23
  "size": {
24
  "height": 448,
25
  "width": 448