Joseph Catrambone committed
Commit: 568dc2c
Parent(s): f0855ac
First import. Add ControlNetSD21 Laion Face (full, pruned, and safetensors). Add README and samples. Add surrounding tools for example use.
Browse files:
- README.md +140 -3
- gradio_face2image.py +105 -0
- laion_face_common.py +180 -0
- laion_face_dataset.py +55 -0
- models/cldm_v21.yaml +85 -0
- models/controlnet_sd21_laion_face_v2_full.ckpt +3 -0
- models/controlnet_sd21_laion_face_v2_pruned.pth +3 -0
- models/controlnet_sd21_laion_face_v2_pruned.safetensors +3 -0
- samples_laion_face_dataset/happy_annotation.png +3 -0
- samples_laion_face_dataset/happy_result.png +3 -0
- samples_laion_face_dataset/happy_source.jpg +3 -0
- samples_laion_face_dataset/neutral_annotation.png +3 -0
- samples_laion_face_dataset/neutral_result.png +3 -0
- samples_laion_face_dataset/neutral_source.jpg +3 -0
- samples_laion_face_dataset/sad_annotation.png +3 -0
- samples_laion_face_dataset/sad_result.png +3 -0
- samples_laion_face_dataset/sad_source.jpg +3 -0
- samples_laion_face_dataset/screaming_annotation.png +3 -0
- samples_laion_face_dataset/screaming_result.png +3 -0
- samples_laion_face_dataset/screaming_source.jpg +3 -0
- samples_laion_face_dataset/sideways_annotation.png +3 -0
- samples_laion_face_dataset/sideways_result.png +3 -0
- samples_laion_face_dataset/sideways_source.jpg +3 -0
- samples_laion_face_dataset/surprised_annotation.png +3 -0
- samples_laion_face_dataset/surprised_result.png +3 -0
- samples_laion_face_dataset/surprised_source.jpg +3 -0
- tool_download_face_targets.py +86 -0
- tool_generate_face_poses.py +180 -0
- train_laion_face.py +46 -0
- train_laion_face_sd15.py +42 -0
README.md
CHANGED
@@ -1,3 +1,140 @@
# ControlNet LAION Face Dataset

## Table of Contents:
- Overview: Samples, Contents, and Construction
- Usage: Downloading, Training, and Inference
- License
- Credits and Thanks

# Overview:

This dataset is designed to train a ControlNet with human facial expressions. It includes keypoints for pupils to allow gaze direction. Training has been tested on Stable Diffusion v2.1 base (512) and Stable Diffusion v1.5.

## Samples:

Cherry-picked from ControlNet + Stable Diffusion v2.1 Base

|Prompt|Input|Face Detection|Output|
|---:|:---:|:---:|:---:|
|Happy| <img src="samples_laion_face_dataset/happy_source.jpg" width="256" height="256"> | <img src="samples_laion_face_dataset/happy_annotation.png" width="256" height="256"> | <img src="samples_laion_face_dataset/happy_result.png" width="256" height="256"> |
|Neutral| <img src="samples_laion_face_dataset/neutral_source.jpg" width="256" height="256"> | <img src="samples_laion_face_dataset/neutral_annotation.png" width="256" height="256"> | <img src="samples_laion_face_dataset/neutral_result.png" width="256" height="256"> |
|Sad| <img src="samples_laion_face_dataset/sad_source.jpg" width="256" height="256"> | <img src="samples_laion_face_dataset/sad_annotation.png" width="256" height="256"> | <img src="samples_laion_face_dataset/sad_result.png" width="256" height="256"> |
|Screaming| <img src="samples_laion_face_dataset/screaming_source.jpg" width="256" height="256"> | <img src="samples_laion_face_dataset/screaming_annotation.png" width="256" height="256"> | <img src="samples_laion_face_dataset/screaming_result.png" width="256" height="256"> |
|Sideways| <img src="samples_laion_face_dataset/sideways_source.jpg" width="256" height="256"> | <img src="samples_laion_face_dataset/sideways_annotation.png" width="256" height="256"> | <img src="samples_laion_face_dataset/sideways_result.png" width="256" height="256"> |
|Surprised| <img src="samples_laion_face_dataset/surprised_source.jpg" width="256" height="256"> | <img src="samples_laion_face_dataset/surprised_annotation.png" width="256" height="256"> | <img src="samples_laion_face_dataset/surprised_result.png" width="256" height="256"> |

## Dataset Contents:

- train_laion_face.py - Entrypoint for ControlNet training.
- laion_face_dataset.py - Code for performing dataset iteration. Cropping and resizing happens here.
- tool_download_face_targets.py - A tool to read metadata.json and populate the target folder.
- tool_generate_face_poses.py - The original file used to generate the source images. Included for reproducibility, but not required for training.
- training/laion-face-processed/prompt.jsonl - Read by laion_face_dataset. Includes prompts for the images. (A sketch of the format appears after this list.)
- training/laion-face-processed/metadata.json - Excerpts from LAION for the relevant data. Also used for downloading the target dataset.
- training/laion-face-processed/source/xxxxxxxxx.jpg - Images with detections performed. Generated from the target images.
- training/laion-face-processed/target/xxxxxxxxx.jpg - Selected images from LAION Face.
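
As a quick illustration, each line of prompt.jsonl is a standalone JSON object in the layout described by ControlNet's training documentation. A minimal sketch of reading one line (the filenames and caption here are made up):

```python
import json

# One illustrative line of training/laion-face-processed/prompt.jsonl:
line = '{"source": "source/000000001.jpg", "target": "target/000000001.jpg", "prompt": "a woman smiling at the camera"}'

record = json.loads(line)
print(record["source"], record["target"], record["prompt"])
# laion_face_dataset.py parses one such object per line and keeps only the basename of each path.
```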

## Dataset Construction:

Source images were generated by pulling slice 00000 from LAION Face and passing them through MediaPipe's face detector with special configuration parameters.

The colors and line thicknesses used for MediaPipe are as follows:

```
f_thick = 2
f_rad = 1
right_iris_draw = DrawingSpec(color=(10, 200, 250), thickness=f_thick, circle_radius=f_rad)
right_eye_draw = DrawingSpec(color=(10, 200, 180), thickness=f_thick, circle_radius=f_rad)
right_eyebrow_draw = DrawingSpec(color=(10, 220, 180), thickness=f_thick, circle_radius=f_rad)
left_iris_draw = DrawingSpec(color=(250, 200, 10), thickness=f_thick, circle_radius=f_rad)
left_eye_draw = DrawingSpec(color=(180, 200, 10), thickness=f_thick, circle_radius=f_rad)
left_eyebrow_draw = DrawingSpec(color=(180, 220, 10), thickness=f_thick, circle_radius=f_rad)
mouth_draw = DrawingSpec(color=(10, 180, 10), thickness=f_thick, circle_radius=f_rad)
head_draw = DrawingSpec(color=(10, 200, 10), thickness=f_thick, circle_radius=f_rad)

iris_landmark_spec = {468: right_iris_draw, 473: left_iris_draw}
```

We have implemented a method named `draw_pupils` which modifies some functionality from MediaPipe. It exists as a stopgap until some pending changes are merged.
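
For reference, the detection images are rendered roughly as follows. This is a condensed sketch of `generate_annotation` in `laion_face_common.py` (included in this repository), not a separate implementation:

```python
import numpy
from laion_face_common import (
    mp_drawing, mp_face_mesh, face_connection_spec, iris_landmark_spec, draw_pupils
)


def render_face_annotation(img_rgb: numpy.ndarray, max_faces: int = 1) -> numpy.ndarray:
    """Draw the face-mesh contours and pupils on a black canvas, mirroring generate_annotation()."""
    with mp_face_mesh.FaceMesh(static_image_mode=True, max_num_faces=max_faces,
                               refine_landmarks=True, min_detection_confidence=0.5) as facemesh:
        landmarks = facemesh.process(img_rgb).multi_face_landmarks or []
    empty = numpy.zeros_like(img_rgb)  # Annotations go on a black canvas, not on the photo.
    for face in landmarks:
        mp_drawing.draw_landmarks(
            empty,
            face,
            connections=face_connection_spec.keys(),
            landmark_drawing_spec=None,
            connection_drawing_spec=face_connection_spec,
        )
        draw_pupils(empty, face, iris_landmark_spec, 2)  # Pupils come from the refined iris landmarks.
    return empty[:, :, ::-1]  # Drawing happens in BGR; flip back to RGB.
```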

# Usage:

The containing ZIP file should be decompressed into the root of the ControlNet directory. The `train_laion_face.py`, `laion_face_dataset.py`, and other `.py` files should sit adjacent to `tutorial_train.py` and `tutorial_train_sd21.py`. We are assuming a checkout of the ControlNet repo at 0acb7e5, but there is no hard dependency on that exact commit.

## Downloading:

For copyright reasons, we cannot include the original target files. We have provided a script (tool_download_face_targets.py) which will read from training/laion-face-processed/metadata.json and populate the target folder; the shape it expects from metadata.json is sketched below. The script has no third-party requirements, but will use tqdm if it is installed.
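
A minimal sketch of that structure (the image id and URL below are made up; the downloader only relies on the `url` field of each entry):

```python
import json

# metadata.json maps a LAION image id to an excerpt of its LAION record.
metadata = json.loads('{"000000001": {"url": "https://example.com/some_face.jpg"}}')

for image_id, image_data in metadata.items():
    # Each fetched file is written to training/laion-face-processed/target/<image_id>.jpg.
    print(image_data["url"], "->", f"training/laion-face-processed/target/{image_id}.jpg")
```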

## Training:

When the targets folder is fully populated, training can be run on a machine with at least 24 gigabytes of VRAM. Our model was trained for 200 hours (four epochs) on an A6000. For the Stable Diffusion v1.5 variant:

```bash
python tool_add_control.py ./models/v1-5-pruned-emaonly.ckpt ./models/controlnet_sd15_laion_face.ckpt
python ./train_laion_face_sd15.py
```

The SD2.1 model in this repository was trained analogously, with `train_laion_face.py` as the entrypoint.

## Inference:

We have provided `gradio_face2image.py`. Update the following two lines to point them to your trained model.

```
model = create_model('./models/cldm_v21.yaml').cpu()  # If you fine-tuned on SD2.1 base, this does not need to change.
model.load_state_dict(load_state_dict('./models/control_sd21_openpose.pth', location='cuda'))
```
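
For example, to use the full SD2.1 checkpoint shipped in this repository (assuming it has been copied into ControlNet's `models/` directory), those two lines would become:

```python
model = create_model('./models/cldm_v21.yaml').cpu()  # Unchanged for an SD2.1-base fine-tune.
model.load_state_dict(load_state_dict('./models/controlnet_sd21_laion_face_v2_full.ckpt', location='cuda'))
```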

The model has some limitations: while it is empirically better at tracking gaze and mouth poses than previous attempts, it may still ignore controls. Adding details to the prompt like "looking right" can abate bad behavior.


# License:

### Source Images: (/training/laion-face-processed/source/)
This work is marked with CC0 1.0. To view a copy of this license, visit http://creativecommons.org/publicdomain/zero/1.0

### Trained Models:
Our trained ControlNet checkpoints are released under CreativeML Open RAIL-M.

### Source Code:
lllyasviel/ControlNet is licensed under the Apache License 2.0.

Our modifications are released under the same license.


# Credits and Thanks:

Greatest thanks to Zhang et al. for ControlNet, Rombach et al. (StabilityAI) for Stable Diffusion, and Schuhmann et al. for LAION.

Sample images for this document were obtained from Unsplash and are CC0.

```
@misc{zhang2023adding,
  title={Adding Conditional Control to Text-to-Image Diffusion Models},
  author={Lvmin Zhang and Maneesh Agrawala},
  year={2023},
  eprint={2302.05543},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}

@misc{rombach2021highresolution,
  title={High-Resolution Image Synthesis with Latent Diffusion Models},
  author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer},
  year={2021},
  eprint={2112.10752},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}

@misc{schuhmann2022laion5b,
  title={LAION-5B: An open large-scale dataset for training next generation image-text models},
  author={Christoph Schuhmann and Romain Beaumont and Richard Vencu and Cade Gordon and Ross Wightman and Mehdi Cherti and Theo Coombes and Aarush Katta and Clayton Mullis and Mitchell Wortsman and Patrick Schramowski and Srivatsa Kundurthy and Katherine Crowson and Ludwig Schmidt and Robert Kaczmarczyk and Jenia Jitsev},
  year={2022},
  eprint={2210.08402},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```

This project was made possible by Crucible AI.
gradio_face2image.py
ADDED
@@ -0,0 +1,105 @@
import os
from typing import Mapping

import gradio as gr
import numpy
import torch
import random
from PIL import Image

from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler
from laion_face_common import generate_annotation
from share import *


model = create_model('./models/cldm_v21.yaml').cpu()
model.load_state_dict(load_state_dict('./models/controlnet_face_condition_epoch_4_0percent.ckpt', location='cuda'))
model = model.cuda()
ddim_sampler = DDIMSampler(model)  # ControlNet _only_ works with DDIM.


def process(input_image: Image.Image, prompt, a_prompt, n_prompt, max_faces, num_samples, ddim_steps, guess_mode, strength, scale, seed, eta):
    with torch.no_grad():
        empty = generate_annotation(input_image, max_faces)
        visualization = Image.fromarray(empty)  # Save to help debug.

        empty = numpy.moveaxis(empty, 2, 0)  # h, w, c -> c, h, w
        control = torch.from_numpy(empty.copy()).float().cuda() / 255.0
        control = torch.stack([control for _ in range(num_samples)], dim=0)
        # control = einops.rearrange(control, 'b h w c -> b c h w').clone()

        # Sanity check the dimensions.
        B, C, H, W = control.shape
        assert C == 3
        assert B == num_samples

        if seed != -1:
            random.seed(seed)
            os.environ['PYTHONHASHSEED'] = str(seed)
            numpy.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.backends.cudnn.deterministic = True

        if config.save_memory:
            model.low_vram_shift(is_diffusing=False)

        cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
        un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
        shape = (4, H // 8, W // 8)

        if config.save_memory:
            model.low_vram_shift(is_diffusing=True)

        model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)  # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
        samples, intermediates = ddim_sampler.sample(
            ddim_steps,
            num_samples,
            shape,
            cond,
            verbose=False,
            eta=eta,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=un_cond
        )

        if config.save_memory:
            model.low_vram_shift(is_diffusing=False)

        x_samples = model.decode_first_stage(samples)
        # x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(numpy.uint8)
        x_samples = numpy.moveaxis((x_samples * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(numpy.uint8), 1, -1)  # b, c, h, w -> b, h, w, c
        results = [visualization] + [x_samples[i] for i in range(num_samples)]

    return results


block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## Control Stable Diffusion with a Facial Pose")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(source='upload', type="numpy")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button(label="Run")
            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
                max_faces = gr.Slider(label="Max Faces", minimum=1, maximum=5, value=1, step=1)
                strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
                guess_mode = gr.Checkbox(label='Guess Mode', value=False)
                ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                eta = gr.Number(label="eta (DDIM)", value=0.0)
                a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
                n_prompt = gr.Textbox(label="Negative Prompt",
                                      value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
    ips = [input_image, prompt, a_prompt, n_prompt, max_faces, num_samples, ddim_steps, guess_mode, strength, scale, seed, eta]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery])


block.launch(server_name='0.0.0.0')
laion_face_common.py
ADDED
@@ -0,0 +1,180 @@
from typing import Mapping

import mediapipe as mp
import numpy
from PIL import Image


mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
mp_face_detection = mp.solutions.face_detection  # Only for counting faces.
mp_face_mesh = mp.solutions.face_mesh
mp_face_connections = mp.solutions.face_mesh_connections.FACEMESH_TESSELATION
mp_hand_connections = mp.solutions.hands_connections.HAND_CONNECTIONS
mp_body_connections = mp.solutions.pose_connections.POSE_CONNECTIONS

DrawingSpec = mp.solutions.drawing_styles.DrawingSpec
PoseLandmark = mp.solutions.drawing_styles.PoseLandmark

f_thick = 2
f_rad = 1
right_iris_draw = DrawingSpec(color=(10, 200, 250), thickness=f_thick, circle_radius=f_rad)
right_eye_draw = DrawingSpec(color=(10, 200, 180), thickness=f_thick, circle_radius=f_rad)
right_eyebrow_draw = DrawingSpec(color=(10, 220, 180), thickness=f_thick, circle_radius=f_rad)
left_iris_draw = DrawingSpec(color=(250, 200, 10), thickness=f_thick, circle_radius=f_rad)
left_eye_draw = DrawingSpec(color=(180, 200, 10), thickness=f_thick, circle_radius=f_rad)
left_eyebrow_draw = DrawingSpec(color=(180, 220, 10), thickness=f_thick, circle_radius=f_rad)
mouth_draw = DrawingSpec(color=(10, 180, 10), thickness=f_thick, circle_radius=f_rad)
head_draw = DrawingSpec(color=(10, 200, 10), thickness=f_thick, circle_radius=f_rad)

# mp_face_mesh.FACEMESH_CONTOURS has all the items we care about.
face_connection_spec = {}
for edge in mp_face_mesh.FACEMESH_FACE_OVAL:
    face_connection_spec[edge] = head_draw
for edge in mp_face_mesh.FACEMESH_LEFT_EYE:
    face_connection_spec[edge] = left_eye_draw
for edge in mp_face_mesh.FACEMESH_LEFT_EYEBROW:
    face_connection_spec[edge] = left_eyebrow_draw
# for edge in mp_face_mesh.FACEMESH_LEFT_IRIS:
#    face_connection_spec[edge] = left_iris_draw
for edge in mp_face_mesh.FACEMESH_RIGHT_EYE:
    face_connection_spec[edge] = right_eye_draw
for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW:
    face_connection_spec[edge] = right_eyebrow_draw
# for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS:
#    face_connection_spec[edge] = right_iris_draw
for edge in mp_face_mesh.FACEMESH_LIPS:
    face_connection_spec[edge] = mouth_draw
iris_landmark_spec = {468: right_iris_draw, 473: left_iris_draw}


def draw_pupils(image, landmark_list, drawing_spec, halfwidth: int = 2):
    """We have a custom function to draw the pupils because the mp.draw_landmarks method requires a parameter for all
    landmarks. Until our PR is merged into mediapipe, we need this separate method."""
    if len(image.shape) != 3:
        raise ValueError("Input image must be H,W,C.")
    image_rows, image_cols, image_channels = image.shape
    if image_channels != 3:  # BGR channels
        raise ValueError('Input image must contain three channel bgr data.')
    for idx, landmark in enumerate(landmark_list.landmark):
        if (
            (landmark.HasField('visibility') and landmark.visibility < 0.9) or
            (landmark.HasField('presence') and landmark.presence < 0.5)
        ):
            continue
        if landmark.x >= 1.0 or landmark.x < 0 or landmark.y >= 1.0 or landmark.y < 0:
            continue
        image_x = int(image_cols * landmark.x)
        image_y = int(image_rows * landmark.y)
        draw_color = None
        if isinstance(drawing_spec, Mapping):
            if drawing_spec.get(idx) is None:
                continue
            else:
                draw_color = drawing_spec[idx].color
        elif isinstance(drawing_spec, DrawingSpec):
            draw_color = drawing_spec.color
        image[image_y-halfwidth:image_y+halfwidth, image_x-halfwidth:image_x+halfwidth, :] = draw_color


def reverse_channels(image):
    """Given a numpy array in RGB form, convert to BGR. Will also convert from BGR to RGB."""
    # im[:,:,::-1] is a neat hack to convert BGR to RGB by reversing the indexing order.
    # im[:,:,::[2,1,0]] would also work but makes a copy of the data.
    return image[:, :, ::-1]


def generate_annotation(
        input_image: Image.Image,
        max_faces: int,
        min_face_size_pixels: int = 0,
        return_annotation_data: bool = False
):
    """
    Find up to 'max_faces' inside the provided input image.
    If min_face_size_pixels is provided and nonzero it will be used to filter faces that occupy less than this many
    pixels in the image.
    If return_annotation_data is TRUE (default: false) then in addition to returning the 'detected face' image, three
    additional parameters will be returned: faces before filtering, faces after filtering, and an annotation image.
    The faces_before_filtering return value is the number of faces detected in an image with no filtering.
    faces_after_filtering is the number of faces remaining after filtering small faces.

    :return:
      If 'return_annotation_data==True', returns (numpy array, numpy array, int, int).
      If 'return_annotation_data==False' (default), returns a numpy array.
    """
    with mp_face_mesh.FaceMesh(
            static_image_mode=True,
            max_num_faces=max_faces,
            refine_landmarks=True,
            min_detection_confidence=0.5,
    ) as facemesh:
        img_rgb = numpy.asarray(input_image)
        # multi_face_landmarks is None when no face is detected; treat that as an empty list.
        results = facemesh.process(img_rgb).multi_face_landmarks or []

        faces_found_before_filtering = len(results)

        # Filter faces that are too small
        filtered_landmarks = []
        for lm in results:
            landmarks = lm.landmark
            face_rect = [
                landmarks[0].x,
                landmarks[0].y,
                landmarks[0].x,
                landmarks[0].y,
            ]  # Left, up, right, down.
            for i in range(len(landmarks)):
                face_rect[0] = min(face_rect[0], landmarks[i].x)
                face_rect[1] = min(face_rect[1], landmarks[i].y)
                face_rect[2] = max(face_rect[2], landmarks[i].x)
                face_rect[3] = max(face_rect[3], landmarks[i].y)
            if min_face_size_pixels > 0:
                face_width = abs(face_rect[2] - face_rect[0])
                face_height = abs(face_rect[3] - face_rect[1])
                face_width_pixels = face_width * input_image.size[0]
                face_height_pixels = face_height * input_image.size[1]
                face_size = min(face_width_pixels, face_height_pixels)
                if face_size >= min_face_size_pixels:
                    filtered_landmarks.append(lm)
            else:
                filtered_landmarks.append(lm)

        faces_remaining_after_filtering = len(filtered_landmarks)

        # Annotations are drawn in BGR for some reason, but we don't need to flip a zero-filled image at the start.
        empty = numpy.zeros_like(img_rgb)

        # Draw detected faces:
        for face_landmarks in filtered_landmarks:
            mp_drawing.draw_landmarks(
                empty,
                face_landmarks,
                connections=face_connection_spec.keys(),
                landmark_drawing_spec=None,
                connection_drawing_spec=face_connection_spec
            )
            draw_pupils(empty, face_landmarks, iris_landmark_spec, 2)

        # Flip BGR back to RGB.
        empty = reverse_channels(empty)

        # We might have to generate a composite.
        if return_annotation_data:
            # Note that we're copying the input image AND flipping the channels so we can draw on top of it.
            annotated = reverse_channels(numpy.asarray(input_image)).copy()
            for face_landmarks in filtered_landmarks:
                mp_drawing.draw_landmarks(
                    annotated,
                    face_landmarks,
                    connections=face_connection_spec.keys(),
                    landmark_drawing_spec=None,
                    connection_drawing_spec=face_connection_spec
                )
                draw_pupils(annotated, face_landmarks, iris_landmark_spec, 2)
            annotated = reverse_channels(annotated)

        if not return_annotation_data:
            return empty
        else:
            return empty, annotated, faces_found_before_filtering, faces_remaining_after_filtering
laion_face_dataset.py
ADDED
@@ -0,0 +1,55 @@
import json
import numpy
import os
from PIL import Image
from torch.utils.data import Dataset


class LaionDataset(Dataset):
    def __init__(self):
        self.data = []
        with open('./training/laion-face-processed/prompt.jsonl', 'rt') as f:
            for line in f:
                self.data.append(json.loads(line))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        source_filename = os.path.split(item['source'])[-1]
        target_filename = os.path.split(item['target'])[-1]
        prompt = item['prompt']

        # If prompt is "" or null, make it something simple.
        if not prompt:
            print(f"Image with index {idx} / {source_filename} has no text.")
            prompt = "an image"

        source_image = Image.open('./training/laion-face-processed/source/' + source_filename).convert("RGB")
        target_image = Image.open('./training/laion-face-processed/target/' + target_filename).convert("RGB")
        # Resize the image so that the minimum edge is bigger than 512x512, then crop center.
        # This may cut off some parts of the face image, but in general they're smaller than 512x512 and we still want
        # to cover the literal edge cases.
        img_size = source_image.size
        scale_factor = 512 / min(img_size)
        source_image = source_image.resize((1 + int(img_size[0] * scale_factor), 1 + int(img_size[1] * scale_factor)))
        target_image = target_image.resize((1 + int(img_size[0] * scale_factor), 1 + int(img_size[1] * scale_factor)))
        img_size = source_image.size
        left_padding = (img_size[0] - 512) // 2
        top_padding = (img_size[1] - 512) // 2
        source_image = source_image.crop((left_padding, top_padding, left_padding + 512, top_padding + 512))
        target_image = target_image.crop((left_padding, top_padding, left_padding + 512, top_padding + 512))

        source = numpy.asarray(source_image)
        target = numpy.asarray(target_image)

        # Normalize source images to [0, 1].
        source = source.astype(numpy.float32) / 255.0

        # Normalize target images to [-1, 1].
        target = (target.astype(numpy.float32) / 127.5) - 1.0

        return dict(jpg=target, txt=prompt, hint=source)
models/cldm_v21.yaml
ADDED
@@ -0,0 +1,85 @@
model:
  target: cldm.cldm.ControlLDM
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    control_key: "hint"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False
    only_mid_control: False

    control_stage_config:
      target: cldm.cldm.ControlNet
      params:
        use_checkpoint: True
        image_size: 32 # unused
        in_channels: 4
        hint_channels: 3
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    unet_config:
      target: cldm.cldm.ControlledUnetModel
      params:
        use_checkpoint: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"
models/controlnet_sd21_laion_face_v2_full.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ef9ecdb1479a0ddf391d8f5388cd8bffbe27c772bf4c2e466d60feeb7c4bcf5a
size 9586176211
models/controlnet_sd21_laion_face_v2_pruned.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:57178eca39962fbfbae1a1d6f838499e749fd626a6b67894f20bbf9af77464db
size 1457029033
models/controlnet_sd21_laion_face_v2_pruned.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:53633d3a2b27b5c00c94a3caf11110274e370ed581a7f3a171551ea9e66cae9b
size 1456951266
samples_laion_face_dataset/happy_annotation.png
ADDED (Git LFS)
samples_laion_face_dataset/happy_result.png
ADDED (Git LFS)
samples_laion_face_dataset/happy_source.jpg
ADDED (Git LFS)
samples_laion_face_dataset/neutral_annotation.png
ADDED (Git LFS)
samples_laion_face_dataset/neutral_result.png
ADDED (Git LFS)
samples_laion_face_dataset/neutral_source.jpg
ADDED (Git LFS)
samples_laion_face_dataset/sad_annotation.png
ADDED (Git LFS)
samples_laion_face_dataset/sad_result.png
ADDED (Git LFS)
samples_laion_face_dataset/sad_source.jpg
ADDED (Git LFS)
samples_laion_face_dataset/screaming_annotation.png
ADDED (Git LFS)
samples_laion_face_dataset/screaming_result.png
ADDED (Git LFS)
samples_laion_face_dataset/screaming_source.jpg
ADDED (Git LFS)
samples_laion_face_dataset/sideways_annotation.png
ADDED (Git LFS)
samples_laion_face_dataset/sideways_result.png
ADDED (Git LFS)
samples_laion_face_dataset/sideways_source.jpg
ADDED (Git LFS)
samples_laion_face_dataset/surprised_annotation.png
ADDED (Git LFS)
samples_laion_face_dataset/surprised_result.png
ADDED (Git LFS)
samples_laion_face_dataset/surprised_source.jpg
ADDED (Git LFS)
tool_download_face_targets.py
ADDED
@@ -0,0 +1,86 @@
#!/usr/bin/python3
"""
tool_download_face_targets.py

Reads in the metadata from the LAION images and begins downloading all images.
"""

import json
import os
import sys
import time
import urllib
import urllib.request
try:
    from tqdm import tqdm
except ImportError:
    # Wrap this method into the identity.
    print("TQDM not found. Progress will be quiet without 'verbose'.")
    def tqdm(x):
        return x


def main(logfile_path: str, verbose: bool = False, pause_between_fetches: float = 0.0):
    """Open the metadata.json file from the training directory and fetch all target images."""
    # Toggle a function pointer so we don't have to check verbosity everywhere.
    def out(x):
        pass
    if verbose:
        out = print

    log = open(logfile_path, 'at')
    skipped_image_count = 0
    errored_image_count = 0
    successful_image_count = 0
    if not os.path.exists("training"):
        print("ERROR: training directory does not exist in the current directory.")
        print("Has the archive been unzipped?")
        print("Are you running from the project root?")
        return 2  # BASH: No such directory.
    if not os.path.exists("training/laion-face-processed/metadata.json"):
        print("ERROR: metadata.json was not found in training/laion-face-processed.")
        return 2
    with open("training/laion-face-processed/metadata.json", 'rt') as md_in:
        metadata = json.load(md_in)
    # Create the directory for targets if it does not exist.
    if not os.path.exists("training/laion-face-processed/target"):
        os.mkdir("training/laion-face-processed/target")
    for image_id, image_data in tqdm(metadata.items()):
        filename = f"training/laion-face-processed/target/{image_id}.jpg"
        if os.path.exists(filename):
            out(f"Skipping {image_id}: file exists.")
            skipped_image_count += 1
            continue
        if not download_file(image_data['url'], filename, verbose):
            error_message = f"Problem downloading {image_id}"
            out(error_message)
            log.write(error_message + "\n")
            log.flush()  # Flush often in case we crash.
            errored_image_count += 1
        else:
            # Only count images that actually downloaded as successes.
            successful_image_count += 1
        if pause_between_fetches > 0.0:
            time.sleep(pause_between_fetches)
    log.close()
    print("Run success.")
    print(f"{skipped_image_count} images skipped")
    print(f"{errored_image_count} images failed to download")
    print(f"{successful_image_count} images downloaded")


def download_file(url: str, output_path: str, verbose: bool = False) -> bool:
    """Download the file with the given URL and save it to the specified path. Return true on success."""
    try:
        r = urllib.request.urlopen(url)
        if not r.status == 200:
            return False
        with open(output_path, 'wb') as fout:
            fout.write(r.read())
        return True
    except Exception as e:
        if verbose:
            print(e)
        return False


if __name__ == "__main__":
    main("downloads.log", verbose="-v" in sys.argv)
tool_generate_face_poses.py
ADDED
@@ -0,0 +1,180 @@
import json
import os
import sys
from dataclasses import dataclass, field
from glob import glob
from typing import Mapping

from PIL import Image
from tqdm import tqdm

from laion_face_common import generate_annotation


@dataclass
class RunProgress:
    pending: list = field(default_factory=list)
    success: list = field(default_factory=list)
    skipped_size: list = field(default_factory=list)
    skipped_nsfw: list = field(default_factory=list)
    skipped_noface: list = field(default_factory=list)
    skipped_smallface: list = field(default_factory=list)


def main(
        status_filename: str,
        prompt_filename: str,
        input_glob: str,
        output_directory: str,
        annotated_output_directory: str = "",
        min_image_size: int = 384,
        max_image_size: int = 32766,
        min_face_size_pixels: int = 64,
        prompt_mapping: dict = None,  # If present, maps a filename to a text prompt.
):
    status = RunProgress()

    if os.path.exists(status_filename):
        print("Continuing from checkpoint.")
        # Restore a saved state:
        status_temp = json.load(open(status_filename, 'rt'))
        for k in status.__dict__.keys():
            status.__setattr__(k, status_temp[k])
        # Output label file:
        pout = open(prompt_filename, 'at')
    else:
        print("Starting run.")
        status = RunProgress()
        status.pending = list(glob(input_glob))
        # Output label file:
        pout = open(prompt_filename, 'wt')
        with open(status_filename, 'wt') as fout:
            json.dump(status.__dict__, fout)

    print(f"{len(status.pending)} images remaining")

    # If we don't have a preexisting set of labels (like for ImageNet/MSCOCO), just null-fill the mapping.
    # We will try on a per-image basis to see if there's a metadata .json.
    if prompt_mapping is None:
        prompt_mapping = dict()

    step = 0
    with tqdm(total=len(status.pending)) as pbar:
        while len(status.pending) > 0:
            full_filename = status.pending.pop()
            pbar.update(1)
            step += 1

            if step % 100 == 0:
                # Checkpoint save:
                with open(status_filename, 'wt') as fout:
                    json.dump(status.__dict__, fout)

            _fpath, fname = os.path.split(full_filename)

            # Make our output filenames.
            # We used to do this here so we could check if a file existed before writing, then skip it, but since we
            # have a 'status' that we cache and update, we no longer have to do this check.
            annotation_filename = ""
            if annotated_output_directory:
                annotation_filename = os.path.join(annotated_output_directory, fname)
            output_filename = os.path.join(output_directory, fname)

            # The LAION dataset has accompanying .json files with each image.
            partial_filename, extension = os.path.splitext(full_filename)
            candidate_json_fullpath = partial_filename + ".json"
            image_metadata = {}
            if os.path.exists(candidate_json_fullpath):
                try:
                    image_metadata = json.load(open(candidate_json_fullpath, 'rt'))
                except Exception as e:
                    print(e)
            if "NSFW" in image_metadata:
                nsfw_marker = image_metadata.get("NSFW")  # This can be "", None, or other weird things.
                if nsfw_marker is not None and nsfw_marker.lower() != "unlikely":
                    # Skip NSFW images.
                    status.skipped_nsfw.append(full_filename)
                    continue

            # Try to get a prompt/caption from the metadata or the prompt mapping.
            image_prompt = image_metadata.get("caption", prompt_mapping.get(fname, ""))

            # Load image:
            img = Image.open(full_filename).convert("RGB")
            img_width = img.size[0]
            img_height = img.size[1]
            img_size = min(img.size[0], img.size[1])
            if img_size < min_image_size or max(img_width, img_height) > max_image_size:
                status.skipped_size.append(full_filename)
                continue

            # We re-initialize the detector every time because it has a habit of triggering weird race conditions.
            empty, annotated, faces_before_filtering, faces_after_filtering = generate_annotation(
                img,
                max_faces=5,
                min_face_size_pixels=min_face_size_pixels,
                return_annotation_data=True
            )
            if faces_before_filtering == 0:
                # Skip images with no faces.
                status.skipped_noface.append(full_filename)
                continue
            if faces_after_filtering == 0:
                # Skip images with no faces large enough
                status.skipped_smallface.append(full_filename)
                continue

            Image.fromarray(empty).save(output_filename)
            if annotation_filename:
                Image.fromarray(annotated).save(annotation_filename)

            # See https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md for the training file format.
            # prompt.json
            # a JSONL file with {"source": "source/0.jpg", "target": "target/0.jpg", "prompt": "..."}.
            # a source/xxxxx.jpg or source/xxxx.png file for each of the inputs.
            # a target/xxxxx.jpg for each of the outputs.
            pout.write(json.dumps({
                "source": os.path.join(output_directory, fname),
                "target": full_filename,
                "prompt": image_prompt,
            }) + "\n")
            pout.flush()
            status.success.append(full_filename)

    # We do save every 100 iterations, but it's good to save on completion, too.
    with open(status_filename, 'wt') as fout:
        json.dump(status.__dict__, fout)

    pout.close()
    print("Done!")
    print(f"{len(status.success)} images added to dataset.")
    print(f"{len(status.skipped_size)} images rejected for size.")
    print(f"{len(status.skipped_smallface)} images rejected for having faces too small.")
    print(f"{len(status.skipped_noface)} images rejected for not having faces.")
    print(f"{len(status.skipped_nsfw)} images rejected for NSFW.")


if __name__ == "__main__":
    # Need at least prompt.jsonl, an input glob, and an output directory.
    if len(sys.argv) >= 4 and "-h" not in sys.argv:
        prompt_jsonl = sys.argv[1]
        in_glob = sys.argv[2]  # Should probably be in a directory called "target/*.jpg".
        output_dir = sys.argv[3]  # Should probably be a directory called "source".
        annotation_dir = ""
        if len(sys.argv) > 4:
            annotation_dir = sys.argv[4]
        main("generate_face_poses_checkpoint.json", prompt_jsonl, in_glob, output_dir, annotation_dir)
    else:
        print(f"""Usage:
python {sys.argv[0]} prompt.jsonl target/*.jpg source/ [annotated/]
source and target are slightly confusing in this context. We are writing the image names to prompt.jsonl, so
the naming system has to be consistent with what ControlNet expects. In ControlNet, the source is the input and
target is the output. We are generating source images from targets in this application, so the second argument
should be a folder full of images. The third argument should be 'source', where the images should be placed.
Optionally, an 'annotated' directory can be provided. Augmented images will be placed here.

A checkpoint file named 'generate_face_poses_checkpoint.json' will be created in the place where the script is
run. If a run is cancelled, it can be resumed from this checkpoint.

If invoking the script from bash, do not forget to enclose globs with quotes. Example usage:
`python ./tool_generate_face_poses.py ./face_prompt.jsonl "/home/josephcatrambone/training_data/data-mscoco/images/train2017/*" /home/josephcatrambone/training_data/data-mscoco/images/source_2017/`
""")
train_laion_face.py
ADDED
@@ -0,0 +1,46 @@
from share import *

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from laion_face_dataset import LaionDataset
from cldm.logger import ImageLogger
from cldm.model import create_model, load_state_dict


# Configs
resume_path = './models/controlnet_sd21_laion_face.ckpt'
batch_size = 4
logger_freq = 2500
learning_rate = 1e-5
sd_locked = True
only_mid_control = False


# First use cpu to load models. Pytorch Lightning will automatically move it to GPUs.
model = create_model('./models/cldm_v21.yaml').cpu()
model.load_state_dict(load_state_dict(resume_path, location='cpu'))
model.learning_rate = learning_rate
model.sd_locked = sd_locked
model.only_mid_control = only_mid_control


# Save every so often:
ckpt_callback = pl.callbacks.ModelCheckpoint(
    dirpath="./checkpoints/",
    filename="ckpt_controlnet_sd21_{epoch}_{step}_{loss}",
    monitor='train/loss_simple_step',
    save_top_k=5,
    every_n_train_steps=5000,
    save_last=True,
)


# Misc
dataset = LaionDataset()
dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True)
logger = ImageLogger(batch_frequency=logger_freq)
trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger, ckpt_callback])


# Train!
trainer.fit(model, dataloader)
train_laion_face_sd15.py
ADDED
@@ -0,0 +1,42 @@
from share import *

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from laion_face_dataset import LaionDataset
from cldm.logger import ImageLogger
from cldm.model import create_model, load_state_dict


# Configs
resume_path = './models/controlnet_sd15_laion_face.ckpt'
batch_size = 8
logger_freq = 2500
learning_rate = 1e-5
sd_locked = True
only_mid_control = False

# First use cpu to load models. Pytorch Lightning will automatically move it to GPUs.
model = create_model('./models/cldm_v15.yaml').cpu()
model.load_state_dict(load_state_dict(resume_path, location='cpu'))
model.learning_rate = learning_rate
model.sd_locked = sd_locked
model.only_mid_control = only_mid_control

# Save every so often:
ckpt_callback = pl.callbacks.ModelCheckpoint(
    dirpath="./checkpoints/",
    filename="controlnet_sd15_laion_face_{epoch}_{step}_{loss}.ckpt",
    monitor='train/loss_simple_step',
    save_top_k=5,
    every_n_train_steps=5000,
    save_last=True,
)

# Misc
dataset = LaionDataset()
dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True)
logger = ImageLogger(batch_frequency=logger_freq)
trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger, ckpt_callback])

# Train!
trainer.fit(model, dataloader)