amazonaws-la committed on
Commit
034a9f5
1 Parent(s): 8b0b6ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -25
app.py CHANGED
@@ -50,41 +50,22 @@ def generate(
50
  apply_refiner: bool = False,
51
  model = 'stabilityai/stable-diffusion-xl-base-1.0',
52
  vaecall = 'madebyollin/sdxl-vae-fp16-fix',
53
- lora = 'pierroromeu/lora-trained-xl-folder',
54
  ) -> PIL.Image.Image:
55
  if torch.cuda.is_available():
56
- vae = AutoencoderKL.from_pretrained(vaecall, torch_dtype=torch.float16)
57
- # Substitua o valor dinâmico antes de chamar from_pretrained
58
- pipe_model_name = model
59
- pipe = DiffusionPipeline.from_pretrained(
60
- pipe_model_name,
61
- vae=vae,
62
- torch_dtype=torch.float16,
63
- use_safetensors=True,
64
- )
65
- if ENABLE_REFINER:
66
- refiner_model_name = model
67
- refiner = DiffusionPipeline.from_pretrained(
68
- refiner_model_name,
69
- vae=vae,
70
- torch_dtype=torch.float16,
71
- use_safetensors=True,
72
- )
73
 
74
  if ENABLE_CPU_OFFLOAD:
75
  pipe.enable_model_cpu_offload()
76
- if ENABLE_REFINER:
77
- refiner.enable_model_cpu_offload()
78
  else:
79
  pipe.to(device)
80
- if ENABLE_REFINER:
81
- refiner.to(device)
82
 
83
  if USE_TORCH_COMPILE:
84
  pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
85
- if ENABLE_REFINER:
86
- refiner.unet = torch.compile(refiner.unet, mode="reduce-overhead", fullgraph=True)
87
-
88
  generator = torch.Generator().manual_seed(seed)
89
 
90
  if not use_negative_prompt:
 
50
  apply_refiner: bool = False,
51
  model = 'stabilityai/stable-diffusion-xl-base-1.0',
52
  vaecall = 'madebyollin/sdxl-vae-fp16-fix',
53
+ lora = 'amazonaws-la/juliette',
54
  ) -> PIL.Image.Image:
55
  if torch.cuda.is_available():
56
+ pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.float16)
57
+
58
+ pipe.load_lora_weights(lora)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  if ENABLE_CPU_OFFLOAD:
61
  pipe.enable_model_cpu_offload()
62
+
 
63
  else:
64
  pipe.to(device)
 
 
65
 
66
  if USE_TORCH_COMPILE:
67
  pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
68
+
 
 
69
  generator = torch.Generator().manual_seed(seed)
70
 
71
  if not use_negative_prompt: