Spaces: Running on Zero

MohamedRashad committed
Commit • f2d6ac6
Parent(s): 5204b6c

Add App

Files changed:
- app.py +123 -0
- live_preview_helpers.py +166 -0
- requirements.txt +6 -0
app.py
ADDED
@@ -0,0 +1,123 @@
from datasets import load_dataset
import gradio as gr
from gradio_client import Client
import json
import torch
from diffusers import FluxPipeline, AutoencoderKL
from live_preview_helpers import flux_pipe_call_that_returns_an_iterable_of_images
import spaces

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(device)
good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device)
# pipe.enable_sequential_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
# pipe.to(torch.float16)
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)

llm_client = Client("Qwen/Qwen2.5-72B-Instruct")
# t2i_client = Client("black-forest-labs/FLUX.1-dev")
# t2i_client = Client("black-forest-labs/FLUX.1-schnell")

ds = load_dataset("MohamedRashad/FinePersonas-Lite", split="train")

prompt_template = """Generate a character with this persona description: {persona_description}
In a world with this description: {world_description}

Write the character in json format with the following fields:
- name: The name of the character
- background: The background of the character
- appearance: The appearance of the character
- personality: The personality of the character
- skills_and_abilities: The skills and abilities of the character
- goals: The goals of the character
- conflicts: The conflicts of the character
- backstory: The backstory of the character
- current_situation: The current situation of the character
- spoken_lines: The spoken lines of the character (list of strings)

Don't write anything else except the character description in json format and don't include '```'.
"""

world_description_prompt = "Generate a unique and random world description (Don't Write anything else except the world description)."

def get_random_world_description():
    result = llm_client.predict(
        query=world_description_prompt,
        history=[],
        system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
        api_name="/model_chat",
    )
    return result[1][0][-1]

def get_random_persona_description():
    return ds.shuffle().select([100])[0]["persona"]

@spaces.GPU()
def generate_character(world_description, persona_description, progress=gr.Progress(track_tqdm=True)):
    result = llm_client.predict(
        query=prompt_template.format(
            persona_description=persona_description, world_description=world_description
        ),
        history=[],
        system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
        api_name="/model_chat",
    )
    output = json.loads(result[1][0][-1])
    print("Character generated")
    print(json.dumps(output, indent=4))

    for image in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
        prompt=output["appearance"],
        guidance_scale=3.5,
        num_inference_steps=28,
        width=1024,
        height=1024,
        generator=torch.Generator("cpu").manual_seed(0),
        output_type="pil",
        good_vae=good_vae,
    ):
        yield image, output
    print("Character and image generated")


with gr.Blocks(title="Character Generator") as app:
    with gr.Column():
        gr.HTML("<center><h1>Character Generator</h1></center>")
        gr.HTML(
            "<center><h3>Generate a character with a persona description and a world description.</h3></center>"
        )
    with gr.Column():
        with gr.Row():
            world_description = gr.Textbox(lines=10, label="World Description", scale=4)
            persona_description = gr.Textbox(lines=10, label="Persona Description", value=get_random_persona_description(), scale=1)
        with gr.Row():
            random_world_button = gr.Button(value="Get Random World Description", variant="secondary", scale=1)
            submit_button = gr.Button(value="Generate Interesting Character!", variant="primary", scale=5)
            random_persona_button = gr.Button(value="Get Random Persona Description", variant="secondary", scale=1)
        with gr.Row():
            character_image = gr.Image(label="Character Image")
            character = gr.JSON(label="Character Description")

    examples = gr.Examples(
        [
            "In a world where magic is real and dragons roam the skies, a group of adventurers set out to find the legendary sword of the dragon king.",
            "Welcome to Aethoria, a vast and mysterious realm where the laws of physics bend to the will of ancient magic. This world is comprised of countless floating islands suspended in an endless sky, each one a unique ecosystem teeming with life and secrets. The islands of Aethoria range from lush, verdant jungles to barren, crystalline deserts. Some are no larger than a city block, while others span hundreds of miles. Connecting these disparate landmasses are shimmering bridges of pure energy, and those brave enough to venture off the beaten path can find hidden portals that instantly transport them across great distances. Aethoria's inhabitants are as diverse as its landscapes. Humans coexist with ethereal beings of light, rock-skinned giants, and shapeshifting creatures that defy classification. Ancient ruins dot the islands, hinting at long-lost civilizations and forgotten technologies that blur the line between science and sorcery. The world is powered by Aether, a mystical substance that flows through everything. Those who can harness its power become formidable mages, capable of manipulating reality itself. However, Aether is a finite resource, and its scarcity has led to conflicts between the various factions vying for control. In the skies between the islands, magnificent airships sail on currents of magic, facilitating trade and exploration. Pirates and sky raiders lurk in the cloudy depths, always on the lookout for unsuspecting prey. Deep beneath the floating lands lies the Undervoid, a dark and treacherous realm filled with nightmarish creatures and untold riches. Only the bravest adventurers dare to plumb its depths, and fewer still return to tell the tale. As an ever-present threat, the Chaos Storms rage at the edges of the known world, threatening to consume everything in their path. It falls to the heroes of Aethoria to uncover the secrets of their world and find a way to push back the encroaching darkness before it's too late. In Aethoria, every island holds a story, every creature has a secret, and every adventure could change the fate of this wondrous, imperiled world.",
            "In a world from my imagination, there is a city called 'Orakis'. floating in the sky on pillars of pure light. The walls of the city are made of crystal glass, constantly reflecting the colors of dawn and dusk, giving it an eternal celestial glow. The buildings breathe and change their shapes according to the seasons—they grow in spring, strengthen in summer, and begin to fade in autumn until they become mist in winter.",
        ],
        world_description,
    )

    submit_button.click(
        generate_character, [world_description, persona_description], outputs=[character_image, character]
    )
    random_world_button.click(
        get_random_world_description, outputs=[world_description]
    )
    random_persona_button.click(
        get_random_persona_description, outputs=[persona_description]
    )

app.queue().launch(share=True)
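A note on the `flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)` line above: calling `__get__` on a plain function uses Python's descriptor protocol to bind it to an instance, so the helper behaves like a method with `self` set to the pipeline. A minimal sketch of the same trick (the `Greeter` class and `say_hello` function here are illustrative, not part of this Space):

class Greeter:
    def __init__(self, name):
        self.name = name

def say_hello(self):
    # A free function written as if it were a method.
    return f"Hello from {self.name}"

g = Greeter("pipe")
g.say_hello = say_hello.__get__(g)  # descriptor protocol: bind to this instance
print(g.say_hello())                # -> Hello from pipe

This is why the helper in live_preview_helpers.py can take `self` as its first parameter and call pipeline internals such as `self.encode_prompt` and `self.prepare_latents`.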
live_preview_helpers.py
ADDED
@@ -0,0 +1,166 @@
import torch
import numpy as np
from diffusers import FluxPipeline, AutoencoderTiny, FlowMatchEulerDiscreteScheduler
from typing import Any, Dict, List, Optional, Union

# Helper functions
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu

def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps

# FLUX pipeline function
@torch.inference_mode()
def flux_pipe_call_that_returns_an_iterable_of_images(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 28,
    timesteps: List[int] = None,
    guidance_scale: float = 3.5,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    max_sequence_length: int = 512,
    good_vae: Optional[Any] = None,
):
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs
    self.check_inputs(
        prompt,
        prompt_2,
        height,
        width,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        max_sequence_length=max_sequence_length,
    )

    self._guidance_scale = guidance_scale
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    batch_size = 1 if isinstance(prompt, str) else len(prompt)
    device = self._execution_device

    # 3. Encode prompt
    lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
    prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )
    # 4. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels // 4
    latents, latent_image_ids = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )
    # 5. Prepare timesteps
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    image_seq_len = latents.shape[1]
    mu = calculate_shift(
        image_seq_len,
        self.scheduler.config.base_image_seq_len,
        self.scheduler.config.max_image_seq_len,
        self.scheduler.config.base_shift,
        self.scheduler.config.max_shift,
    )
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        timesteps,
        sigmas,
        mu=mu,
    )
    self._num_timesteps = len(timesteps)

    # Handle guidance
    guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None

    # 6. Denoising loop
    for i, t in enumerate(timesteps):
        if self.interrupt:
            continue

        timestep = t.expand(latents.shape[0]).to(latents.dtype)

        noise_pred = self.transformer(
            hidden_states=latents,
            timestep=timestep / 1000,
            guidance=guidance,
            pooled_projections=pooled_prompt_embeds,
            encoder_hidden_states=prompt_embeds,
            txt_ids=text_ids,
            img_ids=latent_image_ids,
            joint_attention_kwargs=self.joint_attention_kwargs,
            return_dict=False,
        )[0]
        # Yield intermediate result
        latents_for_image = self._unpack_latents(latents, height, width, self.vae_scale_factor)
        latents_for_image = (latents_for_image / self.vae.config.scaling_factor) + self.vae.config.shift_factor
        image = self.vae.decode(latents_for_image, return_dict=False)[0]
        yield self.image_processor.postprocess(image, output_type=output_type)[0]

        latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
        torch.cuda.empty_cache()

    # Final image using good_vae
    latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
    latents = (latents / good_vae.config.scaling_factor) + good_vae.config.shift_factor
    image = good_vae.decode(latents, return_dict=False)[0]
    self.maybe_free_model_hooks()
    torch.cuda.empty_cache()
    yield self.image_processor.postprocess(image, output_type=output_type)[0]
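For intuition on `calculate_shift`: it linearly interpolates the scheduler's shift parameter `mu` between `base_shift` and `max_shift` as the packed latent token count grows from `base_seq_len` to `max_seq_len`. A worked example, under the assumption (which holds for FLUX's 8x VAE downsample and 2x2 latent packing) that a 1024x1024 image yields `image_seq_len = (1024 // 16) ** 2 = 4096` tokens:

m = (1.16 - 0.5) / (4096 - 256)  # slope of the linear map
b = 0.5 - m * 256                # intercept, so that mu(256) == 0.5
print(4096 * m + b)              # 1.16 -> a 1024x1024 image sits at max_shift
print(1024 * m + b)              # ~0.632 -> a 512x512 image (32*32 = 1024 tokens)

`retrieve_timesteps` then forwards this `mu` as a keyword argument to `scheduler.set_timesteps`, together with the custom sigma schedule `np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)` built in step 5 above.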
requirements.txt
ADDED
@@ -0,0 +1,6 @@
spaces
torch
diffusers
gradio_client
datasets
numpy
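To exercise the streaming helper outside the Gradio UI, the generator can be consumed directly. A minimal sketch, assuming `pipe` and `good_vae` are initialized exactly as in app.py (the prompt string and output filenames are illustrative):

import torch

frames = pipe.flux_pipe_call_that_returns_an_iterable_of_images(
    prompt="portrait of a sky-ship captain",  # illustrative prompt
    guidance_scale=3.5,
    num_inference_steps=28,
    width=1024,
    height=1024,
    generator=torch.Generator("cpu").manual_seed(0),
    output_type="pil",
    good_vae=good_vae,
)
for step, image in enumerate(frames):
    # Each yield is a progressively denoised preview; the final frame is
    # the full-quality decode through good_vae.
    image.save(f"preview_{step:02d}.png")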