import torch
import comfy.utils
from .Pytorch_Retinaface.pytorch_retinaface import Pytorch_RetinaFace
from comfy.model_management import get_torch_device
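# ComfyUI custom node: detects faces with RetinaFace, crops and rescales each face,
# and returns them as an image batch along with per-face crop data.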

class AutoCropFaces:
    def __init__(self):
        pass
    
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "number_of_faces": ("INT", {
                    "default": 5, 
                    "min": 1,
                    "max": 100,
                    "step": 1,
                }),
                "scale_factor": ("FLOAT", {
                    "default": 1.5,
                    "min": 0.5,
                    "max": 10,
                    "step": 0.5,
                    "display": "slider"
                }),
                "shift_factor": ("FLOAT", {
                    "default": 0.45,
                    "min": 0,
                    "max": 1,
                    "step": 0.01,
                    "display": "slider"
                }),
                "start_index": ("INT", {
                    "default": 0,
                    "step": 1,
                    "display": "number"
                }),
                "max_faces_per_image": ("INT", {
                    "default": 50,
                    "min": 1,
                    "max": 1000,
                    "step": 1,
                }),
                # "aspect_ratio": ("FLOAT", {
                #     "default": 1, 
                #     "min": 0.2,
                #     "max": 5,
                #     "step": 0.1,
                # }),
                "aspect_ratio": (["9:16", "2:3", "3:4", "4:5", "1:1", "5:4", "4:3", "3:2", "16:9"], {
                    "default": "1:1",
                }),
            },
        }

    RETURN_TYPES = ("IMAGE", "CROP_DATA")
    RETURN_NAMES = ("face", "crop_data")

    FUNCTION = "auto_crop_faces"

    CATEGORY = "Faces"

    def aspect_ratio_string_to_float(self, str_aspect_ratio="1:1"):
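        """Convert an aspect ratio string such as "4:5" into a float (4/5 == 0.8)."""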
        a, b = map(float, str_aspect_ratio.split(':'))
        return a / b

    def auto_crop_faces_in_image(self, image, max_number_of_faces, scale_factor, shift_factor, aspect_ratio, method='lanczos'):
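        """Detect faces in a single image tensor of shape (height, width, channels) and return a list
        of cropped face tensors (each with a batch dimension added) plus the corresponding bbox info."""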
        image_255 = image * 255 # RetinaFace expects pixel values in the 0-255 range, while ComfyUI images are 0-1 floats.
        rf = Pytorch_RetinaFace(top_k=50, keep_top_k=max_number_of_faces, device=get_torch_device())
        dets = rf.detect_faces(image_255)
        cropped_faces, bbox_info = rf.center_and_crop_rescale(image, dets, scale_factor=scale_factor, shift_factor=shift_factor, aspect_ratio=aspect_ratio)

        # Add a batch dimension to each cropped face
        cropped_faces_with_batch = [face.unsqueeze(0) for face in cropped_faces]
        return cropped_faces_with_batch, bbox_info

    def auto_crop_faces(self, image, number_of_faces, start_index, max_faces_per_image, scale_factor, shift_factor, aspect_ratio, method='lanczos'):
        """ 

        "image" - Input can be one image or a batch of images with shape (batch, width, height, channel count)

        "number_of_faces" - This is passed into PyTorch_RetinaFace which allows you to define a maximum number of faces to look for.

        "start_index" - The starting index of which face you select out of the set of detected faces.

        "scale_factor" - How much crop factor or padding do you want around each detected face.

        "shift_factor" - Pan up or down relative to the face, 0.5 should be right in the center.

        "aspect_ratio" - When we crop, you can have it crop down at a particular aspect ratio.

        "method" - Scaling pixel sampling interpolation method.

        """
        
        # Turn aspect ratio to float value
        aspect_ratio = self.aspect_ratio_string_to_float(aspect_ratio)

        selected_faces, detected_cropped_faces = [], []
        selected_crop_data, detected_crop_data = [], []
        original_images = []

        # Loop through the input batches. Even if there is only one input image, it's still considered a batch.
        for i in range(image.shape[0]):

            original_images.append(image[i].unsqueeze(0)) # Keep the original image, ensuring it still has a batch dimension.
            # Detect the faces in the image, this will return multiple images and crop data for it.
            cropped_images, infos = self.auto_crop_faces_in_image(
                image[i],
                max_faces_per_image,
                scale_factor,
                shift_factor,
                aspect_ratio,
                method)

            detected_cropped_faces.extend(cropped_images)
            detected_crop_data.extend(infos)

        # If nothing was detected, return the original images with default crop data spanning each full image.
        if not detected_cropped_faces:
            selected_crop_data = [(0, 0, img.shape[2], img.shape[1]) for img in original_images] # Width and height from the (1, H, W, C) tensor.
            return (image, selected_crop_data)

        # Circular index calculation: e.g. with 3 detected faces and start_index == 4, the effective start is 4 % 3 == 1.
        start_index = start_index % len(detected_cropped_faces)

        if number_of_faces >= len(detected_cropped_faces):
            selected_faces = detected_cropped_faces[start_index:] + detected_cropped_faces[:start_index]
            selected_crop_data = detected_crop_data[start_index:] + detected_crop_data[:start_index]
        else:
            end_index = (start_index + number_of_faces) % len(detected_cropped_faces)
            if start_index < end_index:
                selected_faces = detected_cropped_faces[start_index:end_index]
                selected_crop_data = detected_crop_data[start_index:end_index]
            else:
                selected_faces = detected_cropped_faces[start_index:] + detected_cropped_faces[:end_index]
                selected_crop_data = detected_crop_data[start_index:] + detected_crop_data[:end_index]

        # If nothing was selected, return the original images with default crop data spanning each full image.
        if len(selected_faces) == 0:
            selected_crop_data = [(0, 0, img.shape[2], img.shape[1]) for img in original_images]
            return (image, selected_crop_data)

        # If exactly one face was selected, return it directly.
        elif len(selected_faces) == 1:
            out = selected_faces[0]
            return (out, selected_crop_data)

        # Use the widest selected face as the reference size. Face tensors are (1, height, width, channels).
        max_width_index = max(range(len(selected_faces)), key=lambda i: selected_faces[i].shape[2])

        target_height = selected_faces[max_width_index].shape[1]
        target_width = selected_faces[max_width_index].shape[2]
        target_shape = (target_height, target_width)

        out = None
        # All faces must share the same height/width so they can be concatenated into a single image batch.
        for face_image in selected_faces:
            if target_shape != face_image.shape[1:3]: # Resize any face that does not match the reference size.
                face_image = comfy.utils.common_upscale( # This method expects (batch, channel, height, width)
                    face_image.movedim(-1, 1), # Move the channel dimension next to the batch dimension
                    target_width, # Width
                    target_height, # Height
                    method, # Pixel sampling (interpolation) method.
                    "" # Crop mode; only "center" is implemented, and we don't want cropping here.
                ).movedim(1, -1)
            # Append the resized face to the output batch.
            if out is None:
                out = face_image
            else:
                out = torch.cat((out, face_image), dim=0)

        return (out, selected_crop_data)

NODE_CLASS_MAPPINGS = {
    "AutoCropFaces": AutoCropFaces
}

# A dictionary that contains the friendly/human-readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
    "AutoCropFaces": "Auto Crop Faces"
}
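
# Example (sketch, not part of the node): calling the node directly outside a ComfyUI graph.
# `load_image_tensor` is a hypothetical helper that returns a (batch, height, width, channels)
# float tensor in the 0-1 range; any tensor in that format would work.
#
#   node = AutoCropFaces()
#   faces, crop_data = node.auto_crop_faces(
#       image=load_image_tensor("portrait.png"),
#       number_of_faces=5,
#       start_index=0,
#       max_faces_per_image=50,
#       scale_factor=1.5,
#       shift_factor=0.45,
#       aspect_ratio="1:1",
#   )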