tori29umai commited on
Commit
d472855
1 Parent(s): be583dc
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import spaces
2
  import gradio as gr
3
  import torch
4
- from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, DDIMScheduler
5
  from PIL import Image
6
  import os
7
  import time
@@ -29,7 +29,7 @@ def load_model(lora_dir, cn_dir):
29
  device = "cuda" if torch.cuda.is_available() else "cpu"
30
  dtype = torch.float16
31
  model = "cagliostrolab/animagine-xl-3.1"
32
- scheduler = DDIMScheduler.from_pretrained(model, subfolder="scheduler")
33
  controlnet = ControlNetModel.from_pretrained(cn_dir, torch_dtype=dtype, use_safetensors=True)
34
  pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
35
  model,
@@ -65,9 +65,9 @@ def predict(input_image_path, prompt, negative_prompt, controlnet_scale):
65
  strength=1.0,
66
  prompt=prompt,
67
  negative_prompt = negative_prompt,
68
- controlnet_conditioning_scale=[float(controlnet_scale)],
69
  generator=generator,
70
- num_inference_steps=50,
71
  eta=1.0,
72
  ).images[0]
73
  print(f"Time taken: {time.time() - last_time}")
 
1
  import spaces
2
  import gradio as gr
3
  import torch
4
+ from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, ControlNetModel, AutoencoderKL
5
  from PIL import Image
6
  import os
7
  import time
 
29
  device = "cuda" if torch.cuda.is_available() else "cpu"
30
  dtype = torch.float16
31
  model = "cagliostrolab/animagine-xl-3.1"
32
+ scheduler = AutoencoderKL.from_pretrained(model, subfolder="scheduler")
33
  controlnet = ControlNetModel.from_pretrained(cn_dir, torch_dtype=dtype, use_safetensors=True)
34
  pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
35
  model,
 
65
  strength=1.0,
66
  prompt=prompt,
67
  negative_prompt = negative_prompt,
68
+ controlnet_conditioning_scale=float(controlnet_scale),
69
  generator=generator,
70
+ num_inference_steps=30,
71
  eta=1.0,
72
  ).images[0]
73
  print(f"Time taken: {time.time() - last_time}")