czl committed on
Commit 8ffeacd
1 Parent(s): f288d63

test img2img

Files changed (3)
  1. app.py +79 -46
  2. requirements.txt +9 -2
  3. tools/synth.py +935 -0
app.py CHANGED
@@ -1,49 +1,72 @@
+import random
+
 import gradio as gr
 import numpy as np
-import random
-from diffusers import DiffusionPipeline
 import torch
 
+from tools import synth
+
 device = "cuda" if torch.cuda.is_available() else "cpu"
+model_path = "runwayml/stable-diffusion-v1-5"
 
 if torch.cuda.is_available():
     torch.cuda.max_memory_allocated(device=device)
-    pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
-    pipe.enable_xformers_memory_efficient_attention()
-    pipe = pipe.to(device)
+    pipe = synth.pipe_img(
+        model_path=model_path,
+        device=device,
+        use_torchcompile=False,
+        use_safetensors=True,
+    )
 else:
-    pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
-    pipe = pipe.to(device)
+    pipe = synth.pipe_img(
+        model_path=model_path,
+        device=device,
+        use_torchcompile=False,
+        use_safetensors=True,
+    )
 
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024
 
-def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
+
+def infer(
+    input_image,
+    prompt,
+    negative_prompt,
+    seed,
+    randomize_seed,
+    width,
+    height,
+    guidance_scale,
+    num_inference_steps,
+):
 
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
-
+
     generator = torch.Generator().manual_seed(seed)
-
+
     image = pipe(
-        prompt = prompt,
-        negative_prompt = negative_prompt,
-        guidance_scale = guidance_scale,
-        num_inference_steps = num_inference_steps,
-        width = width,
-        height = height,
-        generator = generator
-    ).images[0]
-
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        guidance_scale=guidance_scale,
+        num_inference_steps=num_inference_steps,
+        width=width,
+        height=height,
+        generator=generator,
+        image=input_image,
+    ).images[0]
+
     return image
 
+
 examples = [
     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
     "An astronaut riding a green horse",
     "A delicious ceviche cheesecake slice",
 ]
 
-css="""
+css = """
 #col-container {
     margin: 0 auto;
     max-width: 520px;
@@ -56,15 +79,17 @@ else:
     power_device = "CPU"
 
 with gr.Blocks(css=css) as demo:
-    
+
     with gr.Column(elem_id="col-container"):
-        gr.Markdown(f"""
+        gr.Markdown(
+            f"""
         # Text-to-Image Gradio Template
         Currently running on {power_device}.
-        """)
-        
+        """
+        )
+
         with gr.Row():
-            
+
             prompt = gr.Text(
                 label="Prompt",
                 show_label=False,
@@ -72,20 +97,21 @@ with gr.Blocks(css=css) as demo:
                 placeholder="Enter your prompt",
                 container=False,
             )
-            
+            input_image = gr.Image(type="pil", label="Input Image")
+
             run_button = gr.Button("Run", scale=0)
-        
+
         result = gr.Image(label="Result", show_label=False)
 
         with gr.Accordion("Advanced Settings", open=False):
-            
+
             negative_prompt = gr.Text(
                 label="Negative prompt",
                 max_lines=1,
                 placeholder="Enter a negative prompt",
                 visible=False,
             )
-            
+
             seed = gr.Slider(
                 label="Seed",
                 minimum=0,
@@ -93,11 +119,11 @@ with gr.Blocks(css=css) as demo:
                 step=1,
                 value=0,
             )
-            
+
             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-            
+
             with gr.Row():
-                
+
                 width = gr.Slider(
                     label="Width",
                     minimum=256,
@@ -105,7 +131,7 @@ with gr.Blocks(css=css) as demo:
                     step=32,
                     value=512,
                 )
-                
+
                 height = gr.Slider(
                     label="Height",
                     minimum=256,
@@ -113,9 +139,9 @@ with gr.Blocks(css=css) as demo:
                     step=32,
                     value=512,
                 )
-            
+
             with gr.Row():
-                
+
                 guidance_scale = gr.Slider(
                     label="Guidance scale",
                     minimum=0.0,
@@ -123,7 +149,7 @@ with gr.Blocks(css=css) as demo:
                     step=0.1,
                     value=0.0,
                 )
-                
+
                 num_inference_steps = gr.Slider(
                     label="Number of inference steps",
                     minimum=1,
@@ -131,16 +157,23 @@ with gr.Blocks(css=css) as demo:
                     step=1,
                     value=2,
                 )
-        
-        gr.Examples(
-            examples = examples,
-            inputs = [prompt]
-        )
+
+        gr.Examples(examples=examples, inputs=[prompt])
 
     run_button.click(
-        fn = infer,
-        inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
-        outputs = [result]
+        fn=infer,
+        inputs=[
+            input_image,
+            prompt,
+            negative_prompt,
+            seed,
+            randomize_seed,
+            width,
+            height,
+            guidance_scale,
+            num_inference_steps,
+        ],
+        outputs=[result],
     )
 
-demo.queue().launch()
+demo.queue().launch()
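Since infer now takes the uploaded image as its first argument and forwards it to the img2img pipeline as image=input_image, the change can be smoke-tested without the UI. A minimal sketch, assuming app.py's module-level pipe has loaded and some 512x512 image exists at the made-up path init.png:

# Hypothetical smoke test for the new signature; not part of this commit.
from PIL import Image

init = Image.open("init.png").convert("RGB")  # placeholder input image
out = infer(
    input_image=init,
    prompt="a photo of a cat",
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    width=512,
    height=512,
    guidance_scale=7.5,
    num_inference_steps=25,
)
out.save("init_img2img.png")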
requirements.txt CHANGED
@@ -1,6 +1,13 @@
 accelerate
 diffusers
 invisible_watermark
-torch
+torch==2.1.2
+torchaudio==2.1.2
+torchvision==0.16.2
 transformers
-xformers
+xformers==0.0.23.post1
+DeepCache
+pandas
+numpy
+torchmetrics[image]
+gradio
tools/synth.py ADDED
@@ -0,0 +1,935 @@
+"""
+Helper scripts for generating synthetic images using a diffusion model.
+
+Functions:
+    - get_top_misclassified
+    - get_class_list
+    - generateClassPairs
+    - outputDirectory
+    - pipe_img
+    - createPrompts
+    - interpolatePrompts
+    - slerp
+    - get_middle_elements
+    - remove_middle
+    - genClassImg
+    - getMetadata
+    - groupbyInterpolation
+    - ungroupInterpolation
+    - groupAllbyInterpolation
+    - getPairIndices
+    - generateImagesFromDataset
+    - generateTrace
+"""
+
+import json
+import os
+
+import numpy as np
+import pandas as pd
+import torch
+from DeepCache import DeepCacheSDHelper
+from diffusers import (
+    LMSDiscreteScheduler,
+    StableDiffusionImg2ImgPipeline,
+)
+from torch import nn
+from torchmetrics.functional.image import structural_similarity_index_measure as ssim
+from torchvision import transforms
+
+
+def get_top_misclassified(val_classifier_json):
+    """
+    Retrieves the top misclassified classes from a validation classifier JSON file.
+
+    Args:
+        val_classifier_json (str): The path to the validation classifier JSON file.
+
+    Returns:
+        dict: A dictionary containing the top misclassified classes, where the keys are the class names
+            and the values are the number of misclassifications.
+    """
+    with open(val_classifier_json) as f:
+        val_output = json.load(f)
+    val_metrics_df = pd.DataFrame.from_dict(
+        val_output["val_metrics_details"], orient="index"
+    )
+    class_dict = dict()
+    for k, v in val_metrics_df["top_n_classes"].items():
+        class_dict[k] = v
+    return class_dict
+
+
+def get_class_list(val_classifier_json):
+    """
+    Retrieves the list of classes from the given validation classifier JSON file.
+
+    Args:
+        val_classifier_json (str): The path to the validation classifier JSON file.
+
+    Returns:
+        list: A sorted list of class names extracted from the JSON file.
+    """
+    with open(val_classifier_json, "r") as f:
+        data = json.load(f)
+    return sorted(list(data["val_metrics_details"].keys()))
+
+
+def generateClassPairs(val_classifier_json):
+    """
+    Generate pairs of misclassified classes from the given validation classifier JSON.
+
+    Args:
+        val_classifier_json (str): The path to the validation classifier JSON file.
+
+    Returns:
+        list: A sorted list of pairs of misclassified classes.
+    """
+    pairs = set()
+    misclassified_classes = get_top_misclassified(val_classifier_json)
+    for key, value in misclassified_classes.items():
+        for v in value:
+            pairs.add(tuple(sorted([key, v])))
+    return sorted(list(pairs))
+
+
+def outputDirectory(class_pairs, synth_path, metadata_path):
+    """
+    Creates the output directory structure for the synthesized data.
+
+    Args:
+        class_pairs (list): A list of class pairs.
+        synth_path (str): The path to the directory where the synthesized data will be stored.
+        metadata_path (str): The path to the directory where the metadata will be stored.
+
+    Returns:
+        None
+    """
+    for id in class_pairs:
+        class_folder = f"{synth_path}/{id}"
+        if not os.path.exists(class_folder):
+            os.makedirs(class_folder)
+    if not os.path.exists(metadata_path):
+        os.makedirs(metadata_path)
+    print("Info: Output directory ready.")
+
+
+def pipe_img(
+    model_path,
+    device="cuda",
+    apply_optimization=True,
+    use_torchcompile=False,
+    ci_cb=(5, 1),
+    use_safetensors=None,
+    cpu_offload=False,
+    scheduler=None,
+):
+    """
+    Creates and returns an image-to-image pipeline for stable diffusion.
+
+    Args:
+        model_path (str): The path to the pretrained model.
+        device (str, optional): The device to use for computation. Defaults to "cuda".
+        apply_optimization (bool, optional): Whether to apply optimization techniques. Defaults to True.
+        use_torchcompile (bool, optional): Whether to use torch.compile for model compilation. Defaults to False.
+        ci_cb (tuple, optional): A tuple containing the cache interval and cache branch ID. Defaults to (5, 1).
+        use_safetensors (bool, optional): Whether to use safetensors. Defaults to None.
+        cpu_offload (bool, optional): Whether to enable CPU offloading. Defaults to False.
+        scheduler (LMSDiscreteScheduler, optional): The scheduler for the pipeline. Defaults to None.
+
+    Returns:
+        StableDiffusionImg2ImgPipeline: The image-to-image pipeline for stable diffusion.
+    """
+    ###############################
+    # Reference:
+    # Akimov, R. (2024) Images Interpolation with Stable Diffusion - Hugging Face Open-Source AI Cookbook. Available at: https://huggingface.co/learn/cookbook/en/stable_diffusion_interpolation (Accessed: 4 June 2024).
+    ###############################
+    if scheduler is None:
+        scheduler = LMSDiscreteScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            num_train_timesteps=1000,
+            steps_offset=1,
+        )
+    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
+        model_path,
+        scheduler=scheduler,
+        torch_dtype=torch.float32,
+        use_safetensors=use_safetensors,
+        safety_checker=None,
+    ).to(device)
+    if cpu_offload:
+        pipe.enable_model_cpu_offload()
+    if apply_optimization:
+        # tomesd.apply_patch(pipe, ratio=0.5)
+        helper = DeepCacheSDHelper(pipe=pipe)
+        cache_interval, cache_branch_id = ci_cb
+        helper.set_params(
+            cache_interval=cache_interval, cache_branch_id=cache_branch_id
+        )  # lower is faster but lower quality
+        helper.enable()
+        pipe.enable_xformers_memory_efficient_attention()
+    if use_torchcompile:
+        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+    return pipe
+
+
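A minimal usage sketch for pipe_img, mirroring how app.py builds its pipeline; the model path is the one pinned in app.py, everything else is the function's own defaults:

# Builds an SD 1.5 img2img pipeline with DeepCache and xformers enabled.
# Assumes a CUDA device; on CPU, pass device="cpu" and apply_optimization=False,
# since the xformers attention path requires CUDA.
pipe = pipe_img(
    model_path="runwayml/stable-diffusion-v1-5",
    device="cuda",
    use_safetensors=True,
)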
+def createPrompts(
+    class_name_pairs,
+    prompt_structure=None,
+    use_default_negative_prompt=False,
+    negative_prompt=None,
+):
+    """
+    Create prompts for image generation.
+
+    Args:
+        class_name_pairs (list): A list of two class names.
+        prompt_structure (str, optional): The structure of the prompt. Defaults to "a photo of a <class_name>".
+        use_default_negative_prompt (bool, optional): Whether to use the default negative prompt. Defaults to False.
+        negative_prompt (str, optional): The negative prompt to steer the generation away from certain features.
+
+    Returns:
+        tuple: A tuple containing two lists - prompts and negative_prompts.
+            prompts (list): Text prompts that describe the desired output image.
+            negative_prompts (list): Negative prompts that can be used to steer the generation away from certain features.
+    """
+    if prompt_structure is None:
+        prompt_structure = "a photo of a <class_name>"
+    elif "<class_name>" not in prompt_structure:
+        raise ValueError(
+            "The prompt structure must contain the <class_name> placeholder."
+        )
+    if use_default_negative_prompt:
+        default_negative_prompt = (
+            "blurry image, disfigured, deformed, distorted, cartoon, drawings"
+        )
+        negative_prompt = default_negative_prompt
+
+    class1 = class_name_pairs[0]
+    class2 = class_name_pairs[1]
+    prompt1 = prompt_structure.replace("<class_name>", class1)
+    prompt2 = prompt_structure.replace("<class_name>", class2)
+    prompts = [prompt1, prompt2]
+    if negative_prompt is None:
+        print("Info: Negative prompt not provided, returning as None.")
+        return prompts, None
+    else:
+        # Negative prompts that can be used to steer the generation away from certain features.
+        negative_prompts = [negative_prompt] * len(prompts)
+        return prompts, negative_prompts
+
+
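For example, with the default template this expands a class pair into one positive prompt per class; the "cat"/"dog" pair is illustrative, and the expected outputs are shown in comments:

prompts, negative_prompts = createPrompts(
    ["cat", "dog"], use_default_negative_prompt=True
)
# prompts          == ["a photo of a cat", "a photo of a dog"]
# negative_prompts == the default negative prompt string, repeated twice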
+def interpolatePrompts(
+    prompts,
+    pipeline,
+    num_interpolation_steps,
+    sample_mid_interpolation,
+    remove_n_middle=0,
+    device="cuda",
+):
+    """
+    Interpolates prompts by generating intermediate embeddings between pairs of prompts.
+
+    Args:
+        prompts (List[str]): A list of prompts to be interpolated.
+        pipeline: The pipeline object containing the tokenizer and text encoder.
+        num_interpolation_steps (int): The number of interpolation steps between each pair of prompts.
+        sample_mid_interpolation (int): The number of intermediate embeddings to sample from the middle of the interpolated prompts.
+        remove_n_middle (int, optional): The number of middle embeddings to remove from the interpolated prompts. Defaults to 0.
+        device (str, optional): The device to run the interpolation on. Defaults to "cuda".
+
+    Returns:
+        interpolated_prompt_embeds (torch.Tensor): The interpolated prompt embeddings.
+        prompt_metadata (dict): Metadata about the interpolation process, including similarity scores and nearest class information.
+
+    e.g. if num_interpolation_steps = 10, sample_mid_interpolation = 6, remove_n_middle = 2
+        Interpolated: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+        Sampled:            [2, 3, 4, 5, 6, 7]
+        Removed:                   x  x
+        Returns:            [2, 3,       6, 7]
+    """
+
+    ###############################
+    # Reference:
+    # Akimov, R. (2024) Images Interpolation with Stable Diffusion - Hugging Face Open-Source AI Cookbook. Available at: https://huggingface.co/learn/cookbook/en/stable_diffusion_interpolation (Accessed: 4 June 2024).
+    ###############################
+
+    def slerp(v0, v1, num, t0=0, t1=1):
+        """
+        Performs spherical linear interpolation between two vectors.
+
+        Args:
+            v0 (torch.Tensor): The starting vector.
+            v1 (torch.Tensor): The ending vector.
+            num (int): The number of interpolation points.
+            t0 (float, optional): The starting time. Defaults to 0.
+            t1 (float, optional): The ending time. Defaults to 1.
+
+        Returns:
+            torch.Tensor: The interpolated vectors.
+        """
+        ###############################
+        # Reference:
+        # Karpathy, A. (2022) hacky stablediffusion code for generating videos, Gist. Available at: https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355 (Accessed: 4 June 2024).
+        ###############################
+        v0 = v0.detach().cpu().numpy()
+        v1 = v1.detach().cpu().numpy()
+
+        def interpolation(t, v0, v1, DOT_THRESHOLD=0.9995):
+            """helper function to spherically interpolate two arrays v1 v2"""
+            dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
+            if np.abs(dot) > DOT_THRESHOLD:
+                v2 = (1 - t) * v0 + t * v1
+            else:
+                theta_0 = np.arccos(dot)
+                sin_theta_0 = np.sin(theta_0)
+                theta_t = theta_0 * t
+                sin_theta_t = np.sin(theta_t)
+                s0 = np.sin(theta_0 - theta_t) / sin_theta_0
+                s1 = sin_theta_t / sin_theta_0
+                v2 = s0 * v0 + s1 * v1
+            return v2
+
+        t = np.linspace(t0, t1, num)
+
+        v3 = torch.tensor(np.array([interpolation(t[i], v0, v1) for i in range(num)]))
+
+        return v3
+
+    def get_middle_elements(lst, n):
+        """
+        Returns a tuple containing a sublist of the middle elements of the given list `lst` and a range of indices of those elements.
+
+        Args:
+            lst (list): The list from which to extract the middle elements.
+            n (int): The number of middle elements to extract.
+
+        Returns:
+            tuple: A tuple containing the sublist of middle elements and a range of indices.
+
+        Examples:
+            lst = [1, 2, 3, 4, 5]
+            get_middle_elements(lst, 3)
+            ([2, 3, 4], range(1, 4))
+        """
+        if n % 2 == 0:  # Even number of elements
+            middle_index = len(lst) // 2 - 1
+            start = middle_index - n // 2 + 1
+            end = middle_index + n // 2 + 1
+            return lst[start:end], range(start, end)
+        else:  # Odd number of elements
+            middle_index = len(lst) // 2
+            start = middle_index - n // 2
+            end = middle_index + n // 2 + 1
+            return lst[start:end], range(start, end)
+
+    def remove_middle(data, n):
+        """
+        Remove the middle n elements from a list.
+
+        Args:
+            data (list): The input list.
+            n (int): The number of elements to remove from the middle of the list.
+
+        Returns:
+            list: The modified list with the middle n elements removed.
+
+        Raises:
+            ValueError: If n is negative or greater than the length of the list.
+        """
+        if n < 0 or n > len(data):
+            raise ValueError(
+                "Invalid value for n. It should be non-negative and no greater than the list length"
+            )
+
+        # Find the middle index
+        middle = len(data) // 2
+
+        # Create slices to exclude the middle n elements
+        if n == 1:
+            return data[:middle] + data[middle + 1 :]
+        elif n % 2 == 0:
+            return data[: middle - n // 2] + data[middle + n // 2 :]
+        else:
+            return data[: middle - n // 2] + data[middle + n // 2 + 1 :]
+
+    batch_size = len(prompts)
+
+    # Tokenizing and encoding prompts into embeddings.
+    prompts_tokens = pipeline.tokenizer(
+        prompts,
+        padding="max_length",
+        max_length=pipeline.tokenizer.model_max_length,
+        truncation=True,
+        return_tensors="pt",
+    )
+    prompts_embeds = pipeline.text_encoder(prompts_tokens.input_ids.to(device))[0]
+
+    # Interpolating between embedding pairs for the given number of interpolation steps.
+    interpolated_prompt_embeds = []
+
+    for i in range(batch_size - 1):
+        interpolated_prompt_embeds.append(
+            slerp(prompts_embeds[i], prompts_embeds[i + 1], num_interpolation_steps)
+        )
+
+    full_interpolated_prompt_embeds = interpolated_prompt_embeds[:]
+    interpolated_prompt_embeds[0], sample_range = get_middle_elements(
+        interpolated_prompt_embeds[0], sample_mid_interpolation
+    )
+
+    if remove_n_middle > 0:
+        interpolated_prompt_embeds[0] = remove_middle(
+            interpolated_prompt_embeds[0], remove_n_middle
+        )
+
+    prompt_metadata = dict()
+    similarity = nn.CosineSimilarity(dim=-1, eps=1e-6)
+    for i in range(num_interpolation_steps):
+        class1_sim = (
+            similarity(
+                full_interpolated_prompt_embeds[0][0],
+                full_interpolated_prompt_embeds[0][i],
+            )
+            .mean()
+            .item()
+        )
+        class2_sim = (
+            similarity(
+                full_interpolated_prompt_embeds[0][num_interpolation_steps - 1],
+                full_interpolated_prompt_embeds[0][i],
+            )
+            .mean()
+            .item()
+        )
+        relative_distance = class1_sim / (class1_sim + class2_sim)
+
+        prompt_metadata[i] = {
+            "selected": i in sample_range,
+            "similarity": {
+                "class1": class1_sim,
+                "class2": class2_sim,
+                "class1_relative_distance": relative_distance,
+                "class2_relative_distance": 1 - relative_distance,
+            },
+            "nearest_class": int(relative_distance < 0.5),
+        }
+
+    interpolated_prompt_embeds = torch.cat(interpolated_prompt_embeds, dim=0).to(device)
+    return interpolated_prompt_embeds, prompt_metadata
+
+
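Chaining the two helpers gives the embedding ladder that the generators below consume. A sketch, assuming the CUDA pipeline from the pipe_img example and the illustrative "cat"/"dog" pair, with 16 interpolation steps and the middle 8 kept:

prompts, _ = createPrompts(["cat", "dog"])
pos_embeds, pos_meta = interpolatePrompts(
    prompts, pipe, num_interpolation_steps=16, sample_mid_interpolation=8
)
# pos_embeds is a tensor of shape [8, 77, 768] for SD 1.5's text encoder;
# pos_meta records, per step, whether it was kept and which class is nearer.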
+def genClassImg(
+    pipeline,
+    pos_embed,
+    neg_embed,
+    input_image,
+    generator,
+    latents,
+    num_imgs=1,
+    height=512,
+    width=512,
+    num_inference_steps=25,
+    guidance_scale=7.5,
+):
+    """
+    Generate a class image using the given inputs.
+
+    Args:
+        pipeline: The pipeline object used for image generation.
+        pos_embed: The positive embedding for the class.
+        neg_embed: The negative embedding for the class (optional).
+        input_image: The input image for guidance (optional).
+        generator: The random number generator used for image generation.
+        latents: The latent vectors used for image generation.
+        num_imgs: The number of images to generate (default is 1).
+        height: The height of the generated images (default is 512).
+        width: The width of the generated images (default is 512).
+        num_inference_steps: The number of inference steps for image generation (default is 25).
+        guidance_scale: The scale factor for guidance (default is 7.5).
+
+    Returns:
+        The generated class image.
+    """
+    if neg_embed is not None:
+        npe = neg_embed[None, ...]
+    else:
+        npe = None
+
+    return pipeline(
+        height=height,
+        width=width,
+        num_images_per_prompt=num_imgs,
+        prompt_embeds=pos_embed[None, ...],
+        negative_prompt_embeds=npe,
+        num_inference_steps=num_inference_steps,
+        guidance_scale=guidance_scale,
+        generator=generator,
+        latents=latents,
+        image=input_image,
+    ).images[0]
+
+
+def getMetadata(
+    class_pairs,
+    path,
+    seed,
+    guidance_scale,
+    num_inference_steps,
+    num_interpolation_steps,
+    sample_mid_interpolation,
+    height,
+    width,
+    prompts,
+    negative_prompts,
+    pipeline,
+    prompt_metadata,
+    negative_prompt_metadata,
+    ssim_metadata=None,
+    save_json=True,
+    save_path=".",
+):
+    """
+    Generate metadata for the given parameters.
+
+    Args:
+        class_pairs (list): List of class pairs.
+        path (str): Path to the data.
+        seed (int): Seed value for randomization.
+        guidance_scale (float): Scale factor for guidance.
+        num_inference_steps (int): Number of inference steps.
+        num_interpolation_steps (int): Number of interpolation steps.
+        sample_mid_interpolation (int): Number of mid-interpolation steps sampled.
+        height (int): Height of the image.
+        width (int): Width of the image.
+        prompts (list): List of prompts.
+        negative_prompts (list): List of negative prompts.
+        pipeline (object): Pipeline object.
+        prompt_metadata (dict): Metadata for prompts.
+        negative_prompt_metadata (dict): Metadata for negative prompts.
+        ssim_metadata (dict, optional): SSIM scores metadata. Defaults to None.
+        save_json (bool, optional): Flag to save metadata as JSON. Defaults to True.
+        save_path (str, optional): Path to save the JSON file. Defaults to ".".
+
+    Returns:
+        dict: Generated metadata.
+    """
+    metadata = dict()
+
+    metadata["class_pairs"] = class_pairs
+    metadata["path"] = path
+    metadata["seed"] = seed
+    metadata["params"] = {
+        "CFG": guidance_scale,
+        "inferenceSteps": num_inference_steps,
+        "interpolationSteps": num_interpolation_steps,
+        "sampleMidInterpolation": sample_mid_interpolation,
+        "height": height,
+        "width": width,
+    }
+    for i in range(len(prompts)):
+        metadata[f"prompt_text_{i}"] = prompts[i]
+        if negative_prompts is not None:
+            metadata[f"negative_prompt_text_{i}"] = negative_prompts[i]
+    metadata["pipe_config"] = dict(pipeline.config)
+    metadata["prompt_embed_similarity"] = prompt_metadata
+    metadata["negative_prompt_embed_similarity"] = negative_prompt_metadata
+    if ssim_metadata is not None:
+        print("Info: SSIM scores are available.")
+        metadata["ssim_scores"] = ssim_metadata
+    if save_json:
+        with open(
+            os.path.join(save_path, f"{'_'.join(i for i in class_pairs)}_{seed}.json"),
+            "w",
+        ) as f:
+            json.dump(metadata, f, indent=4)
+    return metadata
+
+
+def groupbyInterpolation(dir_to_classfolder):
+    """
+    Group files in a directory by interpolation step.
+
+    Args:
+        dir_to_classfolder (str): The path to the directory containing the files.
+
+    Returns:
+        None
+    """
+    files = [
+        (f.split(sep="_")[1].split(sep=".")[0], os.path.join(dir_to_classfolder, f))
+        for f in os.listdir(dir_to_classfolder)
+    ]
+    # create a subfolder for each step of the interpolation
+    for interpolation_step, file_path in files:
+        new_dir = os.path.join(dir_to_classfolder, interpolation_step)
+        if not os.path.exists(new_dir):
+            os.makedirs(new_dir)
+        os.rename(file_path, os.path.join(new_dir, os.path.basename(file_path)))
+
+
+def ungroupInterpolation(dir_to_classfolder):
+    """
+    Moves all files from subdirectories within `dir_to_classfolder` back into
+    `dir_to_classfolder` itself, and then removes the subdirectories.
+
+    Args:
+        dir_to_classfolder (str): The path to the directory containing the subdirectories.
+
+    Returns:
+        None
+    """
+    for interpolation_step in os.listdir(dir_to_classfolder):
+        if os.path.isdir(os.path.join(dir_to_classfolder, interpolation_step)):
+            for f in os.listdir(os.path.join(dir_to_classfolder, interpolation_step)):
+                os.rename(
+                    os.path.join(dir_to_classfolder, interpolation_step, f),
+                    os.path.join(dir_to_classfolder, f),
+                )
+            os.rmdir(os.path.join(dir_to_classfolder, interpolation_step))
+
+
+def groupAllbyInterpolation(
+    data_path,
+    group=True,
+    fn_group=groupbyInterpolation,
+    fn_ungroup=ungroupInterpolation,
+):
+    """
+    Group or ungroup all data classes by interpolation step.
+
+    Args:
+        data_path (str): The path to the data.
+        group (bool, optional): Whether to group the data. Defaults to True.
+        fn_group (function, optional): The function to use for grouping. Defaults to groupbyInterpolation.
+        fn_ungroup (function, optional): The function to use for ungrouping. Defaults to ungroupInterpolation.
+    """
+    data_classes = sorted(os.listdir(data_path))
+    if group:
+        fn = fn_group
+    else:
+        fn = fn_ungroup
+    for c in data_classes:
+        c_path = os.path.join(data_path, c)
+        if os.path.isdir(c_path):
+            fn(c_path)
+            print(f"Processed {c}")
+
+
+def getPairIndices(subset_len, total_pair_count=1, seed=None):
+    """
+    Generate shuffled groups of indices covering a subset.
+
+    Args:
+        subset_len (int): The length of the subset.
+        total_pair_count (int, optional): The number of groups to generate. Defaults to 1.
+        seed (int, optional): The seed value for the random number generator. Defaults to None.
+
+    Returns:
+        list: A list of index groups.
+    """
+    rng = np.random.default_rng(seed)
+    group_size = (subset_len + total_pair_count - 1) // total_pair_count
+    numbers = list(range(subset_len))
+    numbers_selection = list(range(subset_len))
+    rng.shuffle(numbers)
+    for i in range(group_size - subset_len % group_size):
+        numbers.append(numbers_selection[i])
+    numbers = np.array(numbers)
+    groups = numbers[: group_size * total_pair_count].reshape(-1, group_size)
+    return groups.tolist()
+
+
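For instance, splitting a 10-image subset into two index groups (the exact ordering depends on the seeded generator):

groups = getPairIndices(10, total_pair_count=2, seed=0)
# -> two lists of five shuffled indices that together cover 0..9,
#    ready to be handed out one group at a time via class_iterables[class_id].pop()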
+def generateImagesFromDataset(
+    img_subsets,
+    class_iterables,
+    pipeline,
+    interpolated_prompt_embeds,
+    interpolated_negative_prompts_embeds,
+    num_inference_steps,
+    guidance_scale,
+    height=512,
+    width=512,
+    seed=None,
+    save_path=".",
+    class_pairs=("0", "1"),
+    save_image=True,
+    image_type="jpg",
+    interpolate_range="full",
+    device="cuda",
+    return_images=False,
+):
+    """
+    Generates images from a dataset using the given parameters.
+
+    Args:
+        img_subsets (dict): A dictionary containing image subsets for each class.
+        class_iterables (dict): A dictionary containing iterable objects for each class.
+        pipeline (object): The pipeline object used for image generation.
+        interpolated_prompt_embeds (list): A list of interpolated prompt embeddings.
+        interpolated_negative_prompts_embeds (list): A list of interpolated negative prompt embeddings.
+        num_inference_steps (int): The number of inference steps for image generation.
+        guidance_scale (float): The scale factor for guidance loss during image generation.
+        height (int, optional): The height of the generated images. Defaults to 512.
+        width (int, optional): The width of the generated images. Defaults to 512.
+        seed (int, optional): The seed value for random number generation. Defaults to None.
+        save_path (str, optional): The path to save the generated images. Defaults to ".".
+        class_pairs (tuple, optional): A tuple containing pairs of class identifiers. Defaults to ("0", "1").
+        save_image (bool, optional): Whether to save the generated images. Defaults to True.
+        image_type (str, optional): The file format of the saved images. Defaults to "jpg".
+        interpolate_range (str, optional): The range of interpolation for prompt embeddings.
+            Possible values are "full", "nearest", or "furthest". Defaults to "full".
+        device (str, optional): The device to use for image generation. Defaults to "cuda".
+        return_images (bool, optional): Whether to return the generated images. Defaults to False.
+
+    Returns:
+        dict or tuple: If return_images is True, returns a tuple of the generated images for each class
+            and the SSIM scores for each class and interpolation step.
+            If return_images is False, returns only the SSIM scores.
+    """
+    if interpolate_range == "nearest":
+        nearest_half = True
+        furthest_half = False
+    elif interpolate_range == "furthest":
+        nearest_half = False
+        furthest_half = True
+    else:
+        nearest_half = False
+        furthest_half = False
+
+    if seed is None:
+        seed = torch.Generator().seed()
+    generator = torch.manual_seed(seed)
+    rng = np.random.default_rng(seed)
+    # Generating initial U-Net latent vectors from a random normal distribution.
+    latents = torch.randn(
+        (1, pipeline.unet.config.in_channels, height // 8, width // 8),
+        generator=generator,
+    ).to(device)
+
+    embed_len = len(interpolated_prompt_embeds)
+    embed_pairs = zip(interpolated_prompt_embeds, interpolated_negative_prompts_embeds)
+    embed_pairs_list = list(embed_pairs)
+    if return_images:
+        class_images = dict()
+    class_ssim = dict()
+
+    if nearest_half or furthest_half:
+        if nearest_half:
+            steps_range = (range(0, embed_len // 2), range(embed_len // 2, embed_len))
+            multiplier = 2
+        elif furthest_half:
+            # uses the opposite class of images of the text interpolation
+            steps_range = (range(embed_len // 2, embed_len), range(0, embed_len // 2))
+            multiplier = 2
+    else:
+        steps_range = (range(embed_len), range(embed_len))
+        multiplier = 1
+
+    for class_iter, class_id in enumerate(class_pairs):
+        if return_images:
+            class_images[class_id] = list()
+        class_ssim[class_id] = {
+            i: {"ssim_sum": 0, "ssim_count": 0, "ssim_avg": 0} for i in range(embed_len)
+        }
+        subset_len = len(img_subsets[class_id])
+        # To efficiently randomize the steps to interpolate for each image in the class, group_map is used.
+        # group_map: index is the image id, element is the group id.
+        # steps_range[class_iter] determines the range of steps to interpolate for the class,
+        # so the first half of the steps are for the first class and so on, e.g. range(0, 8) and range(8, 16) for 16 steps;
+        # the rest multiplies the steps to cover the whole subset plus the remainder.
+        group_map = (
+            list(steps_range[class_iter]) * multiplier * (subset_len // embed_len + 1)
+        )
+        rng.shuffle(
+            group_map
+        )  # shuffle the steps to interpolate for each image; position in the group_map is mapped to the image id
+
+        iter_indices = class_iterables[class_id].pop()
+        # generate images for each image in the class, randomly selecting an interpolated step
+        for image_id in iter_indices:
+            img, trg = img_subsets[class_id][image_id]
+            input_image = img.unsqueeze(0)
+            interpolate_step = group_map[image_id]
+            prompt_embeds, negative_prompt_embeds = embed_pairs_list[interpolate_step]
+            generated_image = genClassImg(
+                pipeline,
+                prompt_embeds,
+                negative_prompt_embeds,
+                input_image,
+                generator,
+                latents,
+                num_imgs=1,
+                height=height,
+                width=width,
+                num_inference_steps=num_inference_steps,
+                guidance_scale=guidance_scale,
+            )
+            pred_image = transforms.ToTensor()(generated_image).unsqueeze(0)
+            ssim_score = ssim(pred_image, input_image).item()
+            class_ssim[class_id][interpolate_step]["ssim_sum"] += ssim_score
+            class_ssim[class_id][interpolate_step]["ssim_count"] += 1
+            if return_images:
+                class_images[class_id].append(generated_image)
+            if save_image:
+                if image_type == "jpg":
+                    generated_image.save(
+                        f"{save_path}/{class_id}/{seed}-{image_id}_{interpolate_step}.{image_type}",
+                        format="JPEG",
+                        quality=95,
+                    )
+                elif image_type == "png":
+                    generated_image.save(
+                        f"{save_path}/{class_id}/{seed}-{image_id}_{interpolate_step}.{image_type}",
+                        format="PNG",
+                    )
+                else:
+                    generated_image.save(
+                        f"{save_path}/{class_id}/{seed}-{image_id}_{interpolate_step}.{image_type}"
+                    )
+
+        # calculate the SSIM average for the class
+        for i_step in range(embed_len):
+            if class_ssim[class_id][i_step]["ssim_count"] > 0:
+                class_ssim[class_id][i_step]["ssim_avg"] = (
+                    class_ssim[class_id][i_step]["ssim_sum"]
+                    / class_ssim[class_id][i_step]["ssim_count"]
+                )
+
+    if return_images:
+        return class_images, class_ssim
+    else:
+        return class_ssim
+
+
+def generateTrace(
+    prompts,
+    img_subsets,
+    class_iterables,
+    interpolated_prompt_embeds,
+    interpolated_negative_prompts_embeds,
+    subset_indices,
+    seed=None,
+    save_path=".",
+    class_pairs=("0", "1"),
+    image_type="jpg",
+    interpolate_range="full",
+    save_prompt_embeds=False,
+):
+    """
+    Generate a trace dictionary containing information about the generated images.
+
+    Args:
+        prompts (list): List of prompt texts.
+        img_subsets (dict): Dictionary containing image subsets for each class.
+        class_iterables (dict): Dictionary containing iterable objects for each class.
+        interpolated_prompt_embeds (torch.Tensor): Tensor containing interpolated prompt embeddings.
+        interpolated_negative_prompts_embeds (torch.Tensor): Tensor containing interpolated negative prompt embeddings.
+        subset_indices (dict): Dictionary containing indices of subsets for each class.
+        seed (int, optional): Seed value for random number generation. Defaults to None.
+        save_path (str, optional): Path to save the generated images. Defaults to ".".
+        class_pairs (tuple, optional): Tuple containing class pairs. Defaults to ("0", "1").
+        image_type (str, optional): Type of the generated images. Defaults to "jpg".
+        interpolate_range (str, optional): Range of interpolation. Defaults to "full".
+        save_prompt_embeds (bool, optional): Flag to save prompt embeddings. Defaults to False.
+
+    Returns:
+        dict: Trace dictionary containing information about the generated images.
+    """
+    trace_dict = {
+        "class_pairs": list(),
+        "class_id": list(),
+        "image_id": list(),
+        "interpolation_step": list(),
+        "embed_len": list(),
+        "pos_prompt_text": list(),
+        "neg_prompt_text": list(),
+        "input_file_path": list(),
+        "output_file_path": list(),
+        "input_prompts_embed": list(),
+    }
+
+    if interpolate_range == "nearest":
+        nearest_half = True
+        furthest_half = False
+    elif interpolate_range == "furthest":
+        nearest_half = False
+        furthest_half = True
+    else:
+        nearest_half = False
+        furthest_half = False
+
+    if seed is None:
+        seed = torch.Generator().seed()
+    rng = np.random.default_rng(seed)
+
+    embed_len = len(interpolated_prompt_embeds)
+    embed_pairs = zip(
+        interpolated_prompt_embeds.cpu().numpy(),
+        interpolated_negative_prompts_embeds.cpu().numpy(),
+    )
+    embed_pairs_list = list(embed_pairs)
+
+    if nearest_half or furthest_half:
+        if nearest_half:
+            steps_range = (range(0, embed_len // 2), range(embed_len // 2, embed_len))
+            multiplier = 2
+        elif furthest_half:
+            # uses the opposite class of images of the text interpolation
+            steps_range = (range(embed_len // 2, embed_len), range(0, embed_len // 2))
+            multiplier = 2
+    else:
+        steps_range = (range(embed_len), range(embed_len))
+        multiplier = 1
+
+    for class_iter, class_id in enumerate(class_pairs):
+        subset_len = len(img_subsets[class_id])
+        # To efficiently randomize the steps to interpolate for each image in the class, group_map is used.
+        # group_map: index is the image id, element is the group id.
+        # steps_range[class_iter] determines the range of steps to interpolate for the class,
+        # so the first half of the steps are for the first class and so on, e.g. range(0, 8) and range(8, 16) for 16 steps;
+        # the rest multiplies the steps to cover the whole subset plus the remainder.
+        group_map = (
+            list(steps_range[class_iter]) * multiplier * (subset_len // embed_len + 1)
+        )
+        rng.shuffle(
+            group_map
+        )  # shuffle the steps to interpolate for each image; position in the group_map is mapped to the image id
+
+        iter_indices = class_iterables[class_id].pop()
+        # record the trace for each image in the class, randomly selecting an interpolated step
+        for image_id in iter_indices:
+            class_ds = img_subsets[class_id]
+            interpolate_step = group_map[image_id]
+            sample_count = subset_indices[class_id][0] + image_id
+            input_file = os.path.normpath(class_ds.dataset.samples[sample_count][0])
+            pos_prompt = prompts[0]
+            neg_prompt = prompts[1]
+            output_file = f"{save_path}/{class_id}/{seed}-{image_id}_{interpolate_step}.{image_type}"
+            if save_prompt_embeds:
+                input_prompts_embed = embed_pairs_list[interpolate_step]
+            else:
+                input_prompts_embed = None
+
+            trace_dict["class_pairs"].append(class_pairs)
+            trace_dict["class_id"].append(class_id)
+            trace_dict["image_id"].append(image_id)
+            trace_dict["interpolation_step"].append(interpolate_step)
+            trace_dict["embed_len"].append(embed_len)
+            trace_dict["pos_prompt_text"].append(pos_prompt)
+            trace_dict["neg_prompt_text"].append(neg_prompt)
+            trace_dict["input_file_path"].append(input_file)
+            trace_dict["output_file_path"].append(output_file)
+            trace_dict["input_prompts_embed"].append(input_prompts_embed)
+
+    return trace_dict
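
Taken together, a condensed sketch of how a caller might drive these helpers end to end. The JSON path, the img_subsets/class_iterables structures (per-class torchvision subsets plus index groups from getPairIndices), and all hyperparameters here are illustrative assumptions, not part of this commit:

# Hypothetical driver; img_subsets and class_iterables are caller-provided.
pipe = pipe_img(model_path="runwayml/stable-diffusion-v1-5", device="cuda")
for pair in generateClassPairs("val_classifier.json"):  # made-up JSON path
    prompts, negs = createPrompts(list(pair), use_default_negative_prompt=True)
    pos_embeds, pos_meta = interpolatePrompts(prompts, pipe, 16, 8)
    neg_embeds, neg_meta = interpolatePrompts(negs, pipe, 16, 8)
    class_ssim = generateImagesFromDataset(
        img_subsets,       # e.g. {class_id: torch.utils.data.Subset of an ImageFolder}
        class_iterables,   # e.g. {class_id: getPairIndices(len(subset), ...)}
        pipe,
        pos_embeds,
        neg_embeds,
        num_inference_steps=25,
        guidance_scale=7.5,
        class_pairs=pair,
        save_path="synth",
    )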