SunderAli17 committed on
Commit
9b06330
1 Parent(s): 5ac1546

Create cli.py

Browse files
Files changed (1) hide show
  1. flux/cli.py +259 -0
flux/cli.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import time
4
+ from dataclasses import dataclass
5
+ from glob import iglob
6
+
7
+ import torch
8
+ from einops import rearrange
9
+ from fire import Fire
10
+ from PIL import ExifTags, Image
11
+ from transformers import pipeline
12
+
13
+ from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
14
+ from flux.util import (
15
+ configs,
16
+ load_ae,
17
+ load_clip,
18
+ load_flow_model,
19
+ load_t5,
20
+ )
21
+
22
+ NSFW_THRESHOLD = 0.85
23
+
24
+
25
+ @dataclass
26
+ class SamplingOptions:
27
+ prompt: str
28
+ width: int
29
+ height: int
30
+ num_steps: int
31
+ guidance: float
32
+ seed: int
33
+
34
+
35
def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
    """Interactively read the next prompt or slash-command from stdin.

    Slash-commands (/w, /h, /g, /s, /n) mutate `options` in place and the
    loop continues; a non-command line replaces `options.prompt` (empty
    input repeats the previous prompt).

    Args:
        options: current sampling options, updated in place.

    Returns:
        The (possibly updated) options, or None when the user quits with /q.
    """
    user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Either write your prompt directly, leave this field empty "
        "to repeat the prompt or write a command starting with a slash:\n"
        "- '/w <width>' will set the width of the generated image\n"
        "- '/h <height>' will set the height of the generated image\n"
        "- '/s <seed>' sets the next seed\n"
        "- '/g <guidance>' sets the guidance (flux-dev only)\n"
        "- '/n <steps>' sets the number of steps\n"
        "- '/q' to quit"
    )

    while (prompt := input(user_question)).startswith("/"):
        if prompt.startswith("/w"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, width = prompt.split()
            # Snap to a multiple of 16 so latent packing in main() stays valid.
            options.width = 16 * (int(width) // 16)
            print(
                f"Setting resolution to {options.width} x {options.height} "
                f"({options.height * options.width / 1e6:.2f}MP)"
            )
        elif prompt.startswith("/h"):
            # NOTE(review): user_question advertises /h as "help", but it is
            # handled here as "height"; the /h check in the else branch below
            # is therefore unreachable. Kept as-is for behavior compatibility.
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, height = prompt.split()
            options.height = 16 * (int(height) // 16)
            print(
                f"Setting resolution to {options.width} x {options.height} "
                f"({options.height * options.width / 1e6:.2f}MP)"
            )
        elif prompt.startswith("/g"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, guidance = prompt.split()
            options.guidance = float(guidance)
            print(f"Setting guidance to {options.guidance}")
        elif prompt.startswith("/s"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, seed = prompt.split()
            options.seed = int(seed)
            print(f"Setting seed to {options.seed}")
        elif prompt.startswith("/n"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, steps = prompt.split()
            options.num_steps = int(steps)
            # Fix: the original message wrongly said "Setting seed to ...".
            print(f"Setting number of steps to {options.num_steps}")
        elif prompt.startswith("/q"):
            print("Quitting")
            return None
        else:
            if not prompt.startswith("/h"):
                print(f"Got invalid command '{prompt}'\n{usage}")
            print(usage)
    if prompt != "":
        options.prompt = prompt
    return options
100
+
101
+
102
@torch.inference_mode()
def main(
    name: str = "flux-schnell",
    width: int = 1360,
    height: int = 768,
    seed: int | None = None,
    prompt: str = (
        "a photo of a forest with mist swirling around the tree trunks. The word "
        '"FLUX" is painted over it in big, red brush strokes with visible texture'
    ),
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    num_steps: int | None = None,
    loop: bool = False,
    guidance: float = 3.5,
    offload: bool = False,
    output_dir: str = "output",
    add_sampling_metadata: bool = True,
):
    """
    Sample the flux model. Either interactively (set `--loop`) or run for a
    single image.

    Args:
        name: Name of the model to load
        height: height of the sample in pixels (should be a multiple of 16)
        width: width of the sample in pixels (should be a multiple of 16)
        seed: Set a seed for sampling; None draws a fresh seed per image
        prompt: Prompt used for sampling
        device: Pytorch device
        num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
        loop: start an interactive session and sample multiple times
        guidance: guidance value used for guidance distillation
        offload: keep models on CPU and move each to GPU only while it is needed
        output_dir: directory where images are saved as img_{idx}.jpg; `{idx}`
            continues from the highest existing index
        add_sampling_metadata: Add the prompt to the image Exif metadata
    """
    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")

    if name not in configs:
        available = ", ".join(configs.keys())
        raise ValueError(f"Got unknown model name: {name}, chose from {available}")

    torch_device = torch.device(device)
    if num_steps is None:
        num_steps = 4 if name == "flux-schnell" else 50

    # allow for packing and conversion to latent space
    height = 16 * (height // 16)
    width = 16 * (width // 16)

    output_name = os.path.join(output_dir, "img_{idx}.jpg")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        idx = 0
    else:
        # Fix: the original pattern r"img_[0-9]\.jpg$" only matched single-digit
        # indices, so numbering restarted (overwriting files) after img_9.jpg.
        fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
        if len(fns) > 0:
            idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
        else:
            idx = 0

    # init all components
    t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
    clip = load_clip(torch_device)
    model = load_flow_model(name, device="cpu" if offload else torch_device)
    ae = load_ae(name, device="cpu" if offload else torch_device)

    # CPU generator only supplies fresh seeds; the actual noise is drawn in get_noise.
    rng = torch.Generator(device="cpu")
    opts = SamplingOptions(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=seed,
    )

    if loop:
        opts = parse_prompt(opts)

    while opts is not None:
        if opts.seed is None:
            opts.seed = rng.seed()
        print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
        t0 = time.perf_counter()

        # prepare input
        x = get_noise(
            1,
            opts.height,
            opts.width,
            device=torch_device,
            dtype=torch.bfloat16,
            seed=opts.seed,
        )
        # Consume the seed so the next loop iteration draws a fresh one.
        opts.seed = None
        if offload:
            ae = ae.cpu()
            torch.cuda.empty_cache()
            t5, clip = t5.to(torch_device), clip.to(torch_device)
        inp = prepare(t5, clip, x, prompt=opts.prompt)
        # schnell is trained without the timestep shift used by the dev model.
        timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))

        # offload TEs to CPU, load model to gpu
        if offload:
            t5, clip = t5.cpu(), clip.cpu()
            torch.cuda.empty_cache()
            model = model.to(torch_device)

        # denoise initial noise
        x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)

        # offload model, load autoencoder to gpu
        if offload:
            model.cpu()
            torch.cuda.empty_cache()
            ae.decoder.to(x.device)

        # decode latents to pixel space
        x = unpack(x.float(), opts.height, opts.width)
        with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
            x = ae.decode(x)
        t1 = time.perf_counter()

        fn = output_name.format(idx=idx)
        print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
        # bring into PIL format and save
        x = x.clamp(-1, 1)
        # x = embed_watermark(x.float())
        x = rearrange(x[0], "c h w -> h w c")

        img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
        nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0]

        if nsfw_score < NSFW_THRESHOLD:
            exif_data = Image.Exif()
            exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
            exif_data[ExifTags.Base.Make] = "Black Forest Labs"
            exif_data[ExifTags.Base.Model] = name
            if add_sampling_metadata:
                # NOTE(review): records the initial `prompt` argument; in --loop
                # mode the current prompt is opts.prompt — confirm which is intended.
                exif_data[ExifTags.Base.ImageDescription] = prompt
            img.save(fn, exif=exif_data, quality=95, subsampling=0)
            idx += 1
        else:
            print("Your generated image may contain NSFW content.")

        if loop:
            print("-" * 80)
            opts = parse_prompt(opts)
        else:
            opts = None
252
+
253
+
254
+ def app():
255
+ Fire(main)
256
+
257
+
258
+ if __name__ == "__main__":
259
+ app()