MohamedRashad committed on
Commit f2d6ac6
1 Parent(s): 5204b6c
Files changed (3)
  1. app.py +123 -0
  2. live_preview_helpers.py +166 -0
  3. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,123 @@
+ from datasets import load_dataset
+ import gradio as gr
+ from gradio_client import Client
+ import json
+ import torch
+ from diffusers import FluxPipeline, AutoencoderKL
+ from live_preview_helpers import flux_pipe_call_that_returns_an_iterable_of_images
+ import spaces
+
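+ # Overview: this Space asks a remote Qwen2.5-72B Space for a character sheet (JSON) built
+ # from a world + persona description, then streams a FLUX.1-dev portrait of that character.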
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(device)
+ good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device)
+ # pipe.enable_sequential_cpu_offload()
+ pipe.vae.enable_slicing()
+ pipe.vae.enable_tiling()
+ # pipe.to(torch.float16)
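+ # `__get__` binds the standalone helper function to this pipeline instance, so that
+ # `self` inside it resolves to `pipe` (a lightweight alternative to subclassing FluxPipeline).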
+ pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
+
+ llm_client = Client("Qwen/Qwen2.5-72B-Instruct")
+ # t2i_client = Client("black-forest-labs/FLUX.1-dev")
+ # t2i_client = Client("black-forest-labs/FLUX.1-schnell")
+
+ ds = load_dataset("MohamedRashad/FinePersonas-Lite", split="train")
+
+ prompt_template = """Generate a character with this persona description: {persona_description}
+ In a world with this description: {world_description}
+
+ Write the character in JSON format with the following fields:
+ - name: The name of the character
+ - background: The background of the character
+ - appearance: The appearance of the character
+ - personality: The personality of the character
+ - skills_and_abilities: The skills and abilities of the character
+ - goals: The goals of the character
+ - conflicts: The conflicts of the character
+ - backstory: The backstory of the character
+ - current_situation: The current situation of the character
+ - spoken_lines: The spoken lines of the character (list of strings)
+
+ Don't write anything else except the character description in JSON format and don't include '```'.
+ """
+
+ world_description_prompt = "Generate a unique and random world description (don't write anything else except the world description)."
+
+ def get_random_world_description():
+     result = llm_client.predict(
+         query=world_description_prompt,
+         history=[],
+         system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
+         api_name="/model_chat",
+     )
+     # result[1] is the chat history; take the assistant reply from its first (only) turn.
+     return result[1][0][-1]
+
+ def get_random_persona_description():
+     # Shuffle the dataset and read one fixed index, i.e. a random persona on each call.
+     return ds.shuffle().select([100])[0]["persona"]
+
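+ # On Hugging Face ZeroGPU Spaces, `spaces.GPU()` allocates a GPU for the duration of each call.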
+ @spaces.GPU()
+ def generate_character(world_description, persona_description, progress=gr.Progress(track_tqdm=True)):
+     result = llm_client.predict(
+         query=prompt_template.format(
+             persona_description=persona_description, world_description=world_description
+         ),
+         history=[],
+         system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
+         api_name="/model_chat",
+     )
+     # Parse the model's reply as JSON (the prompt instructs it to return bare JSON).
+     output = json.loads(result[1][0][-1])
+     print("Character generated")
+     print(json.dumps(output, indent=4))
+
+     # Stream preview images as they are denoised; the character JSON accompanies each frame.
+     for image in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
+         prompt=output["appearance"],
+         guidance_scale=3.5,
+         num_inference_steps=28,
+         width=1024,
+         height=1024,
+         generator=torch.Generator("cpu").manual_seed(0),
+         output_type="pil",
+         good_vae=good_vae,
+     ):
+         yield image, output
+     print("Character and image generated")
+
+
+ with gr.Blocks(title="Character Generator") as app:
+     with gr.Column():
+         gr.HTML("<center><h1>Character Generator</h1></center>")
+         gr.HTML(
+             "<center><h3>Generate a character with a persona description and a world description.</h3></center>"
+         )
+     with gr.Column():
+         with gr.Row():
+             world_description = gr.Textbox(lines=10, label="World Description", scale=4)
+             persona_description = gr.Textbox(lines=10, label="Persona Description", value=get_random_persona_description(), scale=1)
+         with gr.Row():
+             random_world_button = gr.Button(value="Get Random World Description", variant="secondary", scale=1)
+             submit_button = gr.Button(value="Generate Interesting Character!", variant="primary", scale=5)
+             random_persona_button = gr.Button(value="Get Random Persona Description", variant="secondary", scale=1)
+         with gr.Row():
+             character_image = gr.Image(label="Character Image")
+             character = gr.JSON(label="Character Description")
+
+     examples = gr.Examples(
+         [
+             "In a world where magic is real and dragons roam the skies, a group of adventurers set out to find the legendary sword of the dragon king.",
+             "Welcome to Aethoria, a vast and mysterious realm where the laws of physics bend to the will of ancient magic. This world is composed of countless floating islands suspended in an endless sky, each one a unique ecosystem teeming with life and secrets. The islands of Aethoria range from lush, verdant jungles to barren, crystalline deserts. Some are no larger than a city block, while others span hundreds of miles. Connecting these disparate landmasses are shimmering bridges of pure energy, and those brave enough to venture off the beaten path can find hidden portals that instantly transport them across great distances. Aethoria's inhabitants are as diverse as its landscapes. Humans coexist with ethereal beings of light, rock-skinned giants, and shapeshifting creatures that defy classification. Ancient ruins dot the islands, hinting at long-lost civilizations and forgotten technologies that blur the line between science and sorcery. The world is powered by Aether, a mystical substance that flows through everything. Those who can harness its power become formidable mages, capable of manipulating reality itself. However, Aether is a finite resource, and its scarcity has led to conflicts between the various factions vying for control. In the skies between the islands, magnificent airships sail on currents of magic, facilitating trade and exploration. Pirates and sky raiders lurk in the cloudy depths, always on the lookout for unsuspecting prey. Deep beneath the floating lands lies the Undervoid, a dark and treacherous realm filled with nightmarish creatures and untold riches. Only the bravest adventurers dare to plumb its depths, and fewer still return to tell the tale. As an ever-present threat, the Chaos Storms rage at the edges of the known world, threatening to consume everything in their path. It falls to the heroes of Aethoria to uncover the secrets of their world and find a way to push back the encroaching darkness before it's too late. In Aethoria, every island holds a story, every creature has a secret, and every adventure could change the fate of this wondrous, imperiled world.",
+             "In a world from my imagination, there is a city called 'Orakis', floating in the sky on pillars of pure light. The walls of the city are made of crystal glass, constantly reflecting the colors of dawn and dusk, giving it an eternal celestial glow. The buildings breathe and change their shapes according to the seasons: they grow in spring, strengthen in summer, and begin to fade in autumn until they become mist in winter.",
+         ],
+         world_description,
+     )
+
+     submit_button.click(
+         generate_character, [world_description, persona_description], outputs=[character_image, character]
+     )
+     random_world_button.click(
+         get_random_world_description, outputs=[world_description]
+     )
+     random_persona_button.click(
+         get_random_persona_description, outputs=[persona_description]
+     )
+
+ app.queue().launch(share=True)
live_preview_helpers.py ADDED
@@ -0,0 +1,166 @@
+ import torch
+ import numpy as np
+ from diffusers import FluxPipeline, AutoencoderTiny, FlowMatchEulerDiscreteScheduler
+ from typing import Any, Dict, List, Optional, Union
+
+ # Helper functions
+ def calculate_shift(
+     image_seq_len,
+     base_seq_len: int = 256,
+     max_seq_len: int = 4096,
+     base_shift: float = 0.5,
+     max_shift: float = 1.16,
+ ):
+     m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+     b = base_shift - m * base_seq_len
+     mu = image_seq_len * m + b
+     return mu
+
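+ # Sanity check for calculate_shift with the defaults above: a 1024x1024 image packs to
+ # image_seq_len = (1024 / 16) ** 2 = 4096 latent tokens, giving mu = 1.16, while the
+ # base length 256 gives mu = 0.5; everything in between interpolates linearly.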
+ def retrieve_timesteps(
+     scheduler,
+     num_inference_steps: Optional[int] = None,
+     device: Optional[Union[str, torch.device]] = None,
+     timesteps: Optional[List[int]] = None,
+     sigmas: Optional[List[float]] = None,
+     **kwargs,
+ ):
+     if timesteps is not None and sigmas is not None:
+         raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+     if timesteps is not None:
+         scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+         timesteps = scheduler.timesteps
+         num_inference_steps = len(timesteps)
+     elif sigmas is not None:
+         scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+         timesteps = scheduler.timesteps
+         num_inference_steps = len(timesteps)
+     else:
+         scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+         timesteps = scheduler.timesteps
+     return timesteps, num_inference_steps
+
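+ # The generator below mirrors FluxPipeline.__call__, but after every denoising step it
+ # decodes the current latents with the pipeline's own VAE and yields a preview image;
+ # the final image is decoded once more with the higher-quality `good_vae` from the caller.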
+ # FLUX pipeline function
+ @torch.inference_mode()
+ def flux_pipe_call_that_returns_an_iterable_of_images(
+     self,
+     prompt: Union[str, List[str]] = None,
+     prompt_2: Optional[Union[str, List[str]]] = None,
+     height: Optional[int] = None,
+     width: Optional[int] = None,
+     num_inference_steps: int = 28,
+     timesteps: List[int] = None,
+     guidance_scale: float = 3.5,
+     num_images_per_prompt: Optional[int] = 1,
+     generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+     latents: Optional[torch.FloatTensor] = None,
+     prompt_embeds: Optional[torch.FloatTensor] = None,
+     pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+     output_type: Optional[str] = "pil",
+     return_dict: bool = True,
+     joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+     max_sequence_length: int = 512,
+     good_vae: Optional[Any] = None,
+ ):
+     height = height or self.default_sample_size * self.vae_scale_factor
+     width = width or self.default_sample_size * self.vae_scale_factor
+
+     # 1. Check inputs
+     self.check_inputs(
+         prompt,
+         prompt_2,
+         height,
+         width,
+         prompt_embeds=prompt_embeds,
+         pooled_prompt_embeds=pooled_prompt_embeds,
+         max_sequence_length=max_sequence_length,
+     )
+
+     self._guidance_scale = guidance_scale
+     self._joint_attention_kwargs = joint_attention_kwargs
+     self._interrupt = False
+
+     # 2. Define call parameters
+     batch_size = 1 if isinstance(prompt, str) else len(prompt)
+     device = self._execution_device
+
+     # 3. Encode prompt
+     lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
+     prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
+         prompt=prompt,
+         prompt_2=prompt_2,
+         prompt_embeds=prompt_embeds,
+         pooled_prompt_embeds=pooled_prompt_embeds,
+         device=device,
+         num_images_per_prompt=num_images_per_prompt,
+         max_sequence_length=max_sequence_length,
+         lora_scale=lora_scale,
+     )
+
+     # 4. Prepare latent variables
+     num_channels_latents = self.transformer.config.in_channels // 4
+     latents, latent_image_ids = self.prepare_latents(
+         batch_size * num_images_per_prompt,
+         num_channels_latents,
+         height,
+         width,
+         prompt_embeds.dtype,
+         device,
+         generator,
+         latents,
+     )
+
+     # 5. Prepare timesteps
+     sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+     image_seq_len = latents.shape[1]
+     mu = calculate_shift(
+         image_seq_len,
+         self.scheduler.config.base_image_seq_len,
+         self.scheduler.config.max_image_seq_len,
+         self.scheduler.config.base_shift,
+         self.scheduler.config.max_shift,
+     )
+     timesteps, num_inference_steps = retrieve_timesteps(
+         self.scheduler,
+         num_inference_steps,
+         device,
+         timesteps,
+         sigmas,
+         mu=mu,
+     )
+     self._num_timesteps = len(timesteps)
+
+     # Handle guidance
+     guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
+
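+     # Decode-then-step ordering: each iteration yields a preview of the *current* latents
+     # before advancing the scheduler, so the caller receives one frame per step plus the
+     # final `good_vae` frame after the loop.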
+     # 6. Denoising loop
+     for i, t in enumerate(timesteps):
+         if self.interrupt:
+             continue
+
+         timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+         noise_pred = self.transformer(
+             hidden_states=latents,
+             timestep=timestep / 1000,
+             guidance=guidance,
+             pooled_projections=pooled_prompt_embeds,
+             encoder_hidden_states=prompt_embeds,
+             txt_ids=text_ids,
+             img_ids=latent_image_ids,
+             joint_attention_kwargs=self.joint_attention_kwargs,
+             return_dict=False,
+         )[0]
+
+         # Yield intermediate result
+         latents_for_image = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+         latents_for_image = (latents_for_image / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+         image = self.vae.decode(latents_for_image, return_dict=False)[0]
+         yield self.image_processor.postprocess(image, output_type=output_type)[0]
+
+         latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+         torch.cuda.empty_cache()
+
+     # Final image using good_vae
+     latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+     latents = (latents / good_vae.config.scaling_factor) + good_vae.config.shift_factor
+     image = good_vae.decode(latents, return_dict=False)[0]
+     self.maybe_free_model_hooks()
+     torch.cuda.empty_cache()
+     yield self.image_processor.postprocess(image, output_type=output_type)[0]
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ spaces
+ torch
+ diffusers
+ gradio_client
+ datasets
+ numpy
+ transformers   # assumed missing dependency: FLUX.1-dev loads its CLIP/T5 text encoders via transformers
+ sentencepiece  # assumed missing dependency: tokenizer backend for the T5 encoder