zhiweili commited on
Commit
7892d1d
1 Parent(s): 885c0aa

change adapter

Browse files
Files changed (1) hide show
  1. app_haircolor.py +6 -10
app_haircolor.py CHANGED
@@ -38,20 +38,15 @@ pidinet_detector = pidinet_detector.to(DEVICE)
38
 
39
  canndy_detector = CannyDetector()
40
 
41
- midas_detector = MidasDetector.from_pretrained(
42
- "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
43
- )
44
- midas_detector = midas_detector.to(DEVICE)
45
-
46
  adapters = MultiAdapter(
47
  [
48
  T2IAdapter.from_pretrained(
49
- "TencentARC/t2i-adapter-lineart-sdxl-1.0",
50
  torch_dtype=torch.float16,
51
  varient="fp16",
52
  ),
53
  T2IAdapter.from_pretrained(
54
- "TencentARC/t2i-adapter-canny-sdxl-1.0",
55
  torch_dtype=torch.float16,
56
  varient="fp16",
57
  )
@@ -85,12 +80,13 @@ def image_to_image(
85
  run_task_time = 0
86
  time_cost_str = ''
87
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
88
- lineart_image = lineart_detector(input_image, int(generate_size*0.375), generate_size)
89
- run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
90
  canny_image = canndy_detector(input_image, int(generate_size*0.375), generate_size)
91
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
 
92
 
93
- cond_image = [lineart_image, canny_image]
94
  cond_scale = [lineart_scale, canny_scale]
95
 
96
  generator = torch.Generator(device=DEVICE).manual_seed(seed)
 
38
 
39
  canndy_detector = CannyDetector()
40
 
 
 
 
 
 
41
  adapters = MultiAdapter(
42
  [
43
  T2IAdapter.from_pretrained(
44
+ "TencentARC/t2iadapter_sketch_sd15v2",
45
  torch_dtype=torch.float16,
46
  varient="fp16",
47
  ),
48
  T2IAdapter.from_pretrained(
49
+ "TencentARC/t2iadapter_canny_sd15v2",
50
  torch_dtype=torch.float16,
51
  varient="fp16",
52
  )
 
80
  run_task_time = 0
81
  time_cost_str = ''
82
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
83
+ # lineart_image = lineart_detector(input_image, int(generate_size*0.375), generate_size)
84
+ # run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
85
  canny_image = canndy_detector(input_image, int(generate_size*0.375), generate_size)
86
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
87
+ sketch_image = pidinet_detector(input_image, int(generate_size*0.5), generate_size)
88
 
89
+ cond_image = [sketch_image, canny_image]
90
  cond_scale = [lineart_scale, canny_scale]
91
 
92
  generator = torch.Generator(device=DEVICE).manual_seed(seed)