amazonaws-la commited on
Commit
41548b6
1 Parent(s): a1777e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -34
app.py CHANGED
@@ -18,44 +18,12 @@ if not torch.cuda.is_available():
18
 
19
  MAX_SEED = np.iinfo(np.int32).max
20
  CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
21
- MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1024"))
22
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
23
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
24
  ENABLE_REFINER = os.getenv("ENABLE_REFINER", "1") == "1"
25
 
26
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
27
- if torch.cuda.is_available():
28
- vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
29
- pipe = DiffusionPipeline.from_pretrained(
30
- "stabilityai/stable-diffusion-xl-base-1.0",
31
- vae=vae,
32
- torch_dtype=torch.float16,
33
- use_safetensors=True,
34
- variant="fp16",
35
- )
36
- if ENABLE_REFINER:
37
- refiner = DiffusionPipeline.from_pretrained(
38
- "stabilityai/stable-diffusion-xl-refiner-1.0",
39
- vae=vae,
40
- torch_dtype=torch.float16,
41
- use_safetensors=True,
42
- variant="fp16",
43
- )
44
-
45
- if ENABLE_CPU_OFFLOAD:
46
- pipe.enable_model_cpu_offload()
47
- if ENABLE_REFINER:
48
- refiner.enable_model_cpu_offload()
49
- else:
50
- pipe.to(device)
51
- if ENABLE_REFINER:
52
- refiner.to(device)
53
-
54
- if USE_TORCH_COMPILE:
55
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
56
- if ENABLE_REFINER:
57
- refiner.unet = torch.compile(refiner.unet, mode="reduce-overhead", fullgraph=True)
58
-
59
 
60
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
61
  if randomize_seed:
@@ -80,7 +48,45 @@ def generate(
80
  num_inference_steps_base: int = 25,
81
  num_inference_steps_refiner: int = 25,
82
  apply_refiner: bool = False,
 
 
 
83
  ) -> PIL.Image.Image:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  generator = torch.Generator().manual_seed(seed)
85
 
86
  if not use_negative_prompt:
@@ -142,6 +148,9 @@ with gr.Blocks(css="style.css") as demo:
142
  visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
143
  )
144
  with gr.Group():
 
 
 
145
  with gr.Row():
146
  prompt = gr.Text(
147
  label="Prompt",
@@ -299,10 +308,13 @@ with gr.Blocks(css="style.css") as demo:
299
  num_inference_steps_base,
300
  num_inference_steps_refiner,
301
  apply_refiner,
 
 
 
302
  ],
303
  outputs=result,
304
  api_name="run",
305
  )
306
 
307
  if __name__ == "__main__":
308
- demo.queue(max_size=20).launch()
 
18
 
19
  MAX_SEED = np.iinfo(np.int32).max
20
  CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
21
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1824"))
22
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
23
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
24
  ENABLE_REFINER = os.getenv("ENABLE_REFINER", "1") == "1"
25
 
26
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
29
  if randomize_seed:
 
48
  num_inference_steps_base: int = 25,
49
  num_inference_steps_refiner: int = 25,
50
  apply_refiner: bool = False,
51
+ model = 'stabilityai/stable-diffusion-xl-base-1.0',
52
+ vaecall = 'madebyollin/sdxl-vae-fp16-fix',
53
+ lora = 'pierroromeu/lora-trained-xl-folder',
54
  ) -> PIL.Image.Image:
55
+ if torch.cuda.is_available():
56
+ vae = AutoencoderKL.from_pretrained(vaecall, torch_dtype=torch.float16)
57
+ # Substitua o valor dinâmico antes de chamar from_pretrained
58
+ pipe_model_name = model
59
+ pipe = DiffusionPipeline.from_pretrained(
60
+ pipe_model_name,
61
+ vae=vae,
62
+ torch_dtype=torch.float16,
63
+ use_safetensors=True,
64
+ )
65
+ if ENABLE_REFINER:
66
+ refiner_model_name = model
67
+ refiner = DiffusionPipeline.from_pretrained(
68
+ refiner_model_name,
69
+ vae=vae,
70
+ torch_dtype=torch.float16,
71
+ use_safetensors=True,
72
+ )
73
+
74
+ pipe.load_lora_weights(lora)
75
+
76
+ if ENABLE_CPU_OFFLOAD:
77
+ pipe.enable_model_cpu_offload()
78
+ if ENABLE_REFINER:
79
+ refiner.enable_model_cpu_offload()
80
+ else:
81
+ pipe.to(device)
82
+ if ENABLE_REFINER:
83
+ refiner.to(device)
84
+
85
+ if USE_TORCH_COMPILE:
86
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
87
+ if ENABLE_REFINER:
88
+ refiner.unet = torch.compile(refiner.unet, mode="reduce-overhead", fullgraph=True)
89
+
90
  generator = torch.Generator().manual_seed(seed)
91
 
92
  if not use_negative_prompt:
 
148
  visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
149
  )
150
  with gr.Group():
151
+ model = gr.Text(label='Modelo')
152
+ vaecall = gr.Text(label='VAE')
153
+ lora = gr.Text(label='LoRA')
154
  with gr.Row():
155
  prompt = gr.Text(
156
  label="Prompt",
 
308
  num_inference_steps_base,
309
  num_inference_steps_refiner,
310
  apply_refiner,
311
+ model,
312
+ vaecall,
313
+ lora,
314
  ],
315
  outputs=result,
316
  api_name="run",
317
  )
318
 
319
  if __name__ == "__main__":
320
+ demo.queue(max_size=20).launch()