from diffusers import UNet2DConditionModel, AutoencoderKL, DDIMScheduler, AutoencoderTiny
from transformers import AutoTokenizer, CLIPTextModel, CLIPTextModelWithProjection
from accelerate import Accelerator
from huggingface_hub import hf_hub_download
import spaces
import gradio as gr
import numpy as np
import torch
import time
import PIL.Image

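# DMD2 distilled SDXL generator checkpoint (4-step) from the tianweiy/DMD2 repository on the Hugging Face Hub.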
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo_id = "tianweiy/DMD2"
subfolder = "model/sdxl/sdxl_cond999_8node_lr5e-7_denoising4step_diffusion1000_gan5e-3_guidance8_noinit_noode_backsim_scratch_checkpoint_model_019000"
filename = "pytorch_model.bin"


class ModelWrapper:
    def __init__(self, model_id, checkpoint_path, precision, image_resolution, latent_resolution, num_train_timesteps, conditioning_timestep, num_step, revision, accelerator):
        super().__init__()
        torch.set_grad_enabled(False)
        
        self.DTYPE = getattr(torch, precision)
        self.device = accelerator.device

        self.tokenizer_one = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, use_fast=False)
        self.tokenizer_two = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2", revision=revision, use_fast=False)

        self.text_encoder = SDXLTextEncoder(model_id, revision, accelerator, dtype=self.DTYPE)

        self.vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").float().to(self.device)
        self.vae_dtype = torch.float32

        self.tiny_vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=self.DTYPE).to(self.device)
        self.tiny_vae_dtype = self.DTYPE

        self.model = self.create_generator(model_id, checkpoint_path).to(dtype=self.DTYPE).to(self.device)

        self.accelerator = accelerator
        self.image_resolution = image_resolution
        self.latent_resolution = latent_resolution
        self.num_train_timesteps = num_train_timesteps
        self.vae_downsample_ratio = image_resolution // latent_resolution
        self.conditioning_timestep = conditioning_timestep

        self.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
        self.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)
        self.num_step = num_step

    def create_generator(self, model_id, checkpoint_path):
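        # Instantiate the base SDXL UNet, then overwrite its weights with the distilled DMD2 generator checkpoint.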
        generator = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(self.DTYPE)
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        generator.load_state_dict(state_dict, strict=True)
        generator.requires_grad_(False)
        return generator

    def build_condition_input(self, height, width):
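        # SDXL micro-conditioning: pack original size, crop offsets, and target size into a single add_time_ids tensor.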
        original_size = (height, width)
        target_size = (height, width)
        crop_top_left = (0, 0)

        add_time_ids = list(original_size + crop_top_left + target_size)
        add_time_ids = torch.tensor([add_time_ids], device=self.device, dtype=self.DTYPE)
        return add_time_ids

    def _encode_prompt(self, prompt):
        text_input_ids_one = self.tokenizer_one([prompt], padding="max_length", max_length=self.tokenizer_one.model_max_length, truncation=True, return_tensors="pt").input_ids
        text_input_ids_two = self.tokenizer_two([prompt], padding="max_length", max_length=self.tokenizer_two.model_max_length, truncation=True, return_tensors="pt").input_ids

        prompt_dict = {
            'text_input_ids_one': text_input_ids_one.unsqueeze(0).to(self.device),
            'text_input_ids_two': text_input_ids_two.unsqueeze(0).to(self.device)
        }
        return prompt_dict

    @staticmethod
    def _get_time():
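        # Synchronize CUDA before reading the clock so timings include all queued GPU work.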
        torch.cuda.synchronize()
        return time.time()

    def sample(self, noise, unet_added_conditions, prompt_embed, fast_vae_decode):
        alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)

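        # DMD2 samples on a fixed timestep schedule: a single conditioning timestep for the
        # 1-step generator, or [999, 749, 499, 249] (spaced 250 apart) for the 4-step generator.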
        if self.num_step == 1:
            all_timesteps = [self.conditioning_timestep]
            step_interval = 0 
        elif self.num_step == 4:
            all_timesteps = [999, 749, 499, 249]
            step_interval = 250 
        else:
            raise NotImplementedError(f"Unsupported num_step={self.num_step}; only 1-step and 4-step sampling are implemented.")
        
        DTYPE = prompt_embed.dtype
        
        for constant in all_timesteps:
            current_timesteps = torch.ones(len(prompt_embed), device=self.device, dtype=torch.long) * constant
            eval_images = self.model(noise, current_timesteps, prompt_embed, added_cond_kwargs=unet_added_conditions).sample

            eval_images = get_x0_from_noise(noise, eval_images, alphas_cumprod, current_timesteps).to(self.DTYPE)

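            # Re-noise the x0 prediction down to the next timestep; the result from the final step goes unused.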
            next_timestep = current_timesteps - step_interval 
            noise = self.scheduler.add_noise(eval_images, torch.randn_like(eval_images), next_timestep).to(DTYPE)  

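        # Decode the final latents: the tiny TAESD VAE (half precision) trades some fidelity for speed,
        # while the full SDXL VAE runs in float32.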
        if fast_vae_decode:
            eval_images = self.tiny_vae.decode(eval_images.to(self.tiny_vae_dtype) / self.tiny_vae.config.scaling_factor, return_dict=False)[0]
        else:
            eval_images = self.vae.decode(eval_images.to(self.vae_dtype) / self.vae.config.scaling_factor, return_dict=False)[0]
        eval_images = ((eval_images + 1.0) * 127.5).clamp(0, 255).to(torch.uint8).permute(0, 2, 3, 1)
        return eval_images 
        
    @spaces.GPU(enable_queue=True)
    @torch.no_grad()
    def inference(self, prompt, seed, height, width, num_images, fast_vae_decode):
        print("Running model inference...")

        if seed == -1:
            seed = np.random.randint(0, 1000000)

        generator = torch.manual_seed(seed)

        add_time_ids = self.build_condition_input(height, width).repeat(num_images, 1)

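        # Draw initial latents at 1/8 of the image resolution (4 latent channels for the SDXL VAE).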
        noise = torch.randn(num_images, 4, height // self.vae_downsample_ratio, width // self.vae_downsample_ratio, generator=generator).to(device=self.device, dtype=self.DTYPE) 

        prompt_inputs = self._encode_prompt(prompt)
        
        start_time = self._get_time()

        prompt_embeds, pooled_prompt_embeds = self.text_encoder(prompt_inputs)

        batch_prompt_embeds, batch_pooled_prompt_embeds = (
            prompt_embeds.repeat(num_images, 1, 1),
            pooled_prompt_embeds.repeat(num_images, 1, 1)
        )

        unet_added_conditions = {
            "time_ids": add_time_ids,
            "text_embeds": batch_pooled_prompt_embeds.squeeze(1)
        }

        eval_images = self.sample(noise=noise, unet_added_conditions=unet_added_conditions, prompt_embed=batch_prompt_embeds, fast_vae_decode=fast_vae_decode)

        end_time = self._get_time()

        output_image_list = [] 
        for image in eval_images:
            output_image_list.append(PIL.Image.fromarray(image.cpu().numpy()))

        return output_image_list, f"Ran successfully in {(end_time-start_time):.2f} seconds"

def get_x0_from_noise(sample, model_output, alphas_cumprod, timestep):
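    # Recover the x0 prediction from an epsilon-prediction output: x0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t).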
    alpha_prod_t = alphas_cumprod[timestep].reshape(-1, 1, 1, 1)
    beta_prod_t = 1 - alpha_prod_t

    pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
    return pred_original_sample

class SDXLTextEncoder(torch.nn.Module):
    def __init__(self, model_id, revision, accelerator, dtype=torch.float32):
        super().__init__()

        self.text_encoder_one = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", revision=revision).to(accelerator.device).to(dtype=dtype)
        self.text_encoder_two = CLIPTextModelWithProjection.from_pretrained(model_id, subfolder="text_encoder_2", revision=revision).to(accelerator.device).to(dtype=dtype)

        self.accelerator = accelerator

    def forward(self, batch):
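        # SDXL conditioning: concatenate the penultimate hidden states of both CLIP encoders along the
        # feature dimension; the pooled embedding comes from the second encoder's projection head.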
        text_input_ids_one = batch['text_input_ids_one'].to(self.accelerator.device).squeeze(1)
        text_input_ids_two = batch['text_input_ids_two'].to(self.accelerator.device).squeeze(1)
        prompt_embeds_list = []

        for text_input_ids, text_encoder in zip([text_input_ids_one, text_input_ids_two], [self.text_encoder_one, self.text_encoder_two]):
            prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device), output_hidden_states=True)

            pooled_prompt_embeds = prompt_embeds[0]

            prompt_embeds = prompt_embeds.hidden_states[-2]
            bs_embed, seq_len, _ = prompt_embeds.shape
            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
            prompt_embeds_list.append(prompt_embeds)

        prompt_embeds = torch.cat(prompt_embeds_list, dim=-1)
        pooled_prompt_embeds = pooled_prompt_embeds.view(len(text_input_ids_one), -1)
        
        return prompt_embeds, pooled_prompt_embeds

def create_demo():
    TITLE = "# DMD2-SDXL Demo"
    model_id = "stabilityai/stable-diffusion-xl-base-1.0"
    checkpoint_path = hf_hub_download(repo_id=repo_id, subfolder=subfolder, filename=filename)
    precision = "float16"
    image_resolution = 1024
    latent_resolution = 128
    num_train_timesteps = 1000
    conditioning_timestep = 999
    num_step = 4
    revision = None

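    # Allow TF32 matmuls and convolutions for faster inference on Ampere and newer GPUs.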
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True 

    accelerator = Accelerator()

    model = ModelWrapper(model_id, checkpoint_path, precision, image_resolution, latent_resolution, num_train_timesteps, conditioning_timestep, num_step, revision, accelerator)

    with gr.Blocks() as demo:
        gr.Markdown(TITLE)
        with gr.Row():
            with gr.Column():
                prompt = gr.Text(value="An oil painting of two rabbits in the style of American Gothic, wearing the same clothes as in the original.", label="Prompt")
                run_button = gr.Button("Run")
                with gr.Accordion(label="Advanced options", open=True):
                    seed = gr.Slider(label="Seed", minimum=-1, maximum=1000000, step=1, value=0)
                    num_images = gr.Slider(label="Number of generated images", minimum=1, maximum=16, step=1, value=16)
                    fast_vae_decode = gr.Checkbox(label="Use Tiny VAE for faster decoding", value=True)
                    height = gr.Slider(label="Image Height", minimum=512, maximum=1536, step=64, value=1024)
                    width = gr.Slider(label="Image Width", minimum=512, maximum=1536, step=64, value=1024)
            with gr.Column():
                result = gr.Gallery(label="Generated Images", show_label=False, elem_id="gallery", height=1024)
                error_message = gr.Text(label="Job Status")

        inputs = [prompt, seed, height, width, num_images, fast_vae_decode]
        run_button.click(fn=model.inference, inputs=inputs, outputs=[result, error_message], concurrency_limit=1)
    return demo

if __name__ == "__main__":
    demo = create_demo()
    demo.queue(api_open=False)
    demo.launch(show_error=True, share=True)