zhiweili committed on
Commit
01b766d
1 Parent(s): 21ef723

change base model

Browse files
Files changed (1) hide show
  1. app_haircolor.py +30 -9
app_haircolor.py CHANGED
@@ -11,13 +11,15 @@ from segment_utils import(
11
  from diffusers import (
12
  DiffusionPipeline,
13
  T2IAdapter,
 
14
  )
15
 
16
  from controlnet_aux import (
17
  LineartDetector,
 
18
  )
19
 
20
- BASE_MODEL = "stabilityai/sdxl-turbo"
21
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
22
 
23
  DEFAULT_EDIT_PROMPT = "a woman, blue hair, high detailed"
@@ -28,18 +30,30 @@ DEFAULT_CATEGORY = "hair"
28
  lineart_detector = LineartDetector.from_pretrained("lllyasviel/Annotators")
29
  lineart_detector = lineart_detector.to(DEVICE)
30
 
31
- adapter = T2IAdapter.from_pretrained(
32
- "TencentARC/t2i-adapter-lineart-sdxl-1.0",
33
- torch_dtype=torch.float16,
34
- varient="fp16",
 
 
 
 
 
 
 
 
 
 
 
35
  )
 
36
 
37
 
38
  basepipeline = DiffusionPipeline.from_pretrained(
39
  BASE_MODEL,
40
  torch_dtype=torch.float16,
41
  use_safetensors=True,
42
- adapter=adapter,
43
  custom_pipeline="./pipelines/pipeline_sdxl_adapter_img2img.py",
44
  )
45
 
@@ -55,13 +69,20 @@ def image_to_image(
55
  num_steps: int,
56
  guidance_scale: float,
57
  generate_size: int,
58
- adapter_weight: float = 1.0,
 
59
  ):
60
  run_task_time = 0
61
  time_cost_str = ''
62
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
63
  lineart_image = lineart_detector(input_image, int(generate_size*0.375), generate_size)
64
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
 
 
 
 
 
 
65
  generator = torch.Generator(device=DEVICE).manual_seed(seed)
66
  generated_image = basepipeline(
67
  generator=generator,
@@ -72,8 +93,8 @@ def image_to_image(
72
  width=generate_size,
73
  guidance_scale=guidance_scale,
74
  num_inference_steps=num_steps,
75
- adapter_image=lineart_image,
76
- adapter_conditioning_scale=adapter_weight,
77
  ).images[0]
78
 
79
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
 
11
  from diffusers import (
12
  DiffusionPipeline,
13
  T2IAdapter,
14
+ MultiAdapter,
15
  )
16
 
17
  from controlnet_aux import (
18
  LineartDetector,
19
+ CannyDetector,
20
  )
21
 
22
+ BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
23
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
24
 
25
  DEFAULT_EDIT_PROMPT = "a woman, blue hair, high detailed"
 
30
  lineart_detector = LineartDetector.from_pretrained("lllyasviel/Annotators")
31
  lineart_detector = lineart_detector.to(DEVICE)
32
 
33
+ canndy_detector = CannyDetector()
34
+
35
+ adapters = MultiAdapter(
36
+ [
37
+ T2IAdapter.from_pretrained(
38
+ "TencentARC/t2i-adapter-lineart-sdxl-1.0",
39
+ torch_dtype=torch.float16,
40
+ varient="fp16",
41
+ ),
42
+ T2IAdapter.from_pretrained(
43
+ "TencentARC/t2i-adapter-canny-sdxl-1.0",
44
+ torch_dtype=torch.float16,
45
+ varient="fp16",
46
+ ),
47
+ ]
48
  )
49
+ adapters = adapters.to(torch.float16)
50
 
51
 
52
  basepipeline = DiffusionPipeline.from_pretrained(
53
  BASE_MODEL,
54
  torch_dtype=torch.float16,
55
  use_safetensors=True,
56
+ adapter=adapters,
57
  custom_pipeline="./pipelines/pipeline_sdxl_adapter_img2img.py",
58
  )
59
 
 
69
  num_steps: int,
70
  guidance_scale: float,
71
  generate_size: int,
72
+ lineart_scale: float = 1.0,
73
+ canny_scale: float = 0.5,
74
  ):
75
  run_task_time = 0
76
  time_cost_str = ''
77
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
78
  lineart_image = lineart_detector(input_image, int(generate_size*0.375), generate_size)
79
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
80
+ canny_image = canndy_detector(input_image, int(generate_size*0.375), generate_size)
81
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
82
+
83
+ cond_image = [lineart_image, canny_image]
84
+ cond_scale = [lineart_scale, canny_scale]
85
+
86
  generator = torch.Generator(device=DEVICE).manual_seed(seed)
87
  generated_image = basepipeline(
88
  generator=generator,
 
93
  width=generate_size,
94
  guidance_scale=guidance_scale,
95
  num_inference_steps=num_steps,
96
+ adapter_image=cond_image,
97
+ adapter_conditioning_scale=cond_scale,
98
  ).images[0]
99
 
100
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)