hysts HF staff commited on
Commit
ef87020
1 Parent(s): 473b850

Use different base model and allow changing it

Browse files
Files changed (4) hide show
  1. README.md +2 -2
  2. app.py +39 -4
  3. model.py +28 -35
  4. requirements.txt +2 -0
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
- title: ControlNet
3
- emoji: 🌖
4
  colorFrom: pink
5
  colorTo: blue
6
  sdk: gradio
 
1
  ---
2
+ title: ControlNet with other models
3
+ emoji: 😻
4
  colorFrom: pink
5
  colorTo: blue
6
  sdk: gradio
app.py CHANGED
@@ -41,14 +41,23 @@ from gradio_scribble2image import create_demo as create_demo_scribble
41
  from gradio_scribble2image_interactive import \
42
  create_demo as create_demo_scribble_interactive
43
  from gradio_seg2image import create_demo as create_demo_seg
44
- from model import Model
 
45
 
46
  MAX_IMAGES = 1
47
- DESCRIPTION = '''# ControlNet
48
 
49
- This is an unofficial demo for [https://github.com/lllyasviel/ControlNet](https://github.com/lllyasviel/ControlNet).
 
50
  '''
51
- if (SPACE_ID := os.getenv('SPACE_ID')) is not None:
 
 
 
 
 
 
 
52
  DESCRIPTION += f'''<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.<br/>
53
  <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true">
54
  <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
@@ -59,6 +68,7 @@ model = Model()
59
 
60
  with gr.Blocks(css='style.css') as demo:
61
  gr.Markdown(DESCRIPTION)
 
62
  with gr.Tabs():
63
  with gr.TabItem('Canny'):
64
  create_demo_canny(model.process_canny, max_images=MAX_IMAGES)
@@ -83,4 +93,29 @@ with gr.Blocks(css='style.css') as demo:
83
  with gr.TabItem('Normal map'):
84
  create_demo_normal(model.process_normal, max_images=MAX_IMAGES)
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  demo.queue(api_open=False).launch()
 
41
  from gradio_scribble2image_interactive import \
42
  create_demo as create_demo_scribble_interactive
43
  from gradio_seg2image import create_demo as create_demo_seg
44
+ from model import (DEFAULT_BASE_MODEL_FILENAME, DEFAULT_BASE_MODEL_REPO,
45
+ DEFAULT_BASE_MODEL_URL, Model)
46
 
47
  MAX_IMAGES = 1
48
+ DESCRIPTION = '''# [ControlNet](https://github.com/lllyasviel/ControlNet)
49
 
50
+ This Space is a modified version of [this Space](https://huggingface.co/spaces/hysts/ControlNet).
51
+ The original Space uses [Stable Diffusion v1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) as the base model, but [Anything v4.0](https://huggingface.co/andite/anything-v4.0) is used in this Space.
52
  '''
53
+
54
+ SPACE_ID = os.getenv('SPACE_ID')
55
+ ALLOW_CHANGING_BASE_MODEL = SPACE_ID != 'hysts/ControlNet-with-other-models'
56
+
57
+ if not ALLOW_CHANGING_BASE_MODEL:
58
+ DESCRIPTION += 'In this Space, the base model is not allowed to be changed so as not to slow down the demo, but it can be changed if you duplicate the Space.'
59
+
60
+ if SPACE_ID is not None:
61
  DESCRIPTION += f'''<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.<br/>
62
  <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true">
63
  <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
 
68
 
69
  with gr.Blocks(css='style.css') as demo:
70
  gr.Markdown(DESCRIPTION)
71
+
72
  with gr.Tabs():
73
  with gr.TabItem('Canny'):
74
  create_demo_canny(model.process_canny, max_images=MAX_IMAGES)
 
93
  with gr.TabItem('Normal map'):
94
  create_demo_normal(model.process_normal, max_images=MAX_IMAGES)
95
 
96
+ with gr.Accordion(label='Base model', open=False):
97
+ current_base_model = gr.Text(label='Current base model',
98
+ value=DEFAULT_BASE_MODEL_URL)
99
+ with gr.Row():
100
+ base_model_repo = gr.Text(label='Base model repo',
101
+ max_lines=1,
102
+ placeholder=DEFAULT_BASE_MODEL_REPO,
103
+ interactive=ALLOW_CHANGING_BASE_MODEL)
104
+ base_model_filename = gr.Text(
105
+ label='Base model file',
106
+ max_lines=1,
107
+ placeholder=DEFAULT_BASE_MODEL_FILENAME,
108
+ interactive=ALLOW_CHANGING_BASE_MODEL)
109
+ change_base_model_button = gr.Button('Change base model')
110
+ gr.Markdown(
111
+ '''- You can use other base models by specifying the repository name and filename.
112
+ The base model must be compatible with Stable Diffusion v1.5.''')
113
+
114
+ change_base_model_button.click(fn=model.set_base_model,
115
+ inputs=[
116
+ base_model_repo,
117
+ base_model_filename,
118
+ ],
119
+ outputs=current_base_model)
120
+
121
  demo.queue(api_open=False).launch()
model.py CHANGED
@@ -12,6 +12,7 @@ import cv2
12
  import einops
13
  import numpy as np
14
  import torch
 
15
  from pytorch_lightning import seed_everything
16
 
17
  sys.path.append('ControlNet')
@@ -28,19 +29,7 @@ from cldm.model import create_model, load_state_dict
28
  from ldm.models.diffusion.ddim import DDIMSampler
29
  from share import *
30
 
31
- ORIGINAL_MODEL_NAMES = {
32
- 'canny': 'control_sd15_canny.pth',
33
- 'hough': 'control_sd15_mlsd.pth',
34
- 'hed': 'control_sd15_hed.pth',
35
- 'scribble': 'control_sd15_scribble.pth',
36
- 'pose': 'control_sd15_openpose.pth',
37
- 'seg': 'control_sd15_seg.pth',
38
- 'depth': 'control_sd15_depth.pth',
39
- 'normal': 'control_sd15_normal.pth',
40
- }
41
- ORIGINAL_WEIGHT_ROOT = 'https://huggingface.co/lllyasviel/ControlNet/resolve/main/models/'
42
-
43
- LIGHTWEIGHT_MODEL_NAMES = {
44
  'canny': 'control_canny-fp16.safetensors',
45
  'hough': 'control_mlsd-fp16.safetensors',
46
  'hed': 'control_hed-fp16.safetensors',
@@ -50,34 +39,42 @@ LIGHTWEIGHT_MODEL_NAMES = {
50
  'depth': 'control_depth-fp16.safetensors',
51
  'normal': 'control_normal-fp16.safetensors',
52
  }
53
- LIGHTWEIGHT_WEIGHT_ROOT = 'https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/'
 
 
 
 
54
 
55
 
56
  class Model:
57
  def __init__(self,
58
  model_config_path: str = 'ControlNet/models/cldm_v15.yaml',
59
- model_dir: str = 'models',
60
- use_lightweight: bool = True):
61
  self.device = torch.device(
62
  'cuda:0' if torch.cuda.is_available() else 'cpu')
63
  self.model = create_model(model_config_path).to(self.device)
64
  self.ddim_sampler = DDIMSampler(self.model)
65
  self.task_name = ''
66
 
 
67
  self.model_dir = pathlib.Path(model_dir)
 
68
 
69
- self.use_lightweight = use_lightweight
70
- if use_lightweight:
71
- self.model_names = LIGHTWEIGHT_MODEL_NAMES
72
- self.weight_root = LIGHTWEIGHT_WEIGHT_ROOT
73
- base_model_url = 'https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors'
74
- self.load_base_model(base_model_url)
75
- else:
76
- self.model_names = ORIGINAL_MODEL_NAMES
77
- self.weight_root = ORIGINAL_WEIGHT_ROOT
78
  self.download_models()
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  def download_base_model(self, model_url: str) -> pathlib.Path:
 
81
  model_name = model_url.split('/')[-1]
82
  out_path = self.model_dir / model_name
83
  if not out_path.exists():
@@ -94,27 +91,23 @@ class Model:
94
  if task_name == self.task_name:
95
  return
96
  weight_path = self.get_weight_path(task_name)
97
- if not self.use_lightweight:
98
- self.model.load_state_dict(
99
- load_state_dict(weight_path, location=self.device))
100
- else:
101
- self.model.control_model.load_state_dict(
102
- load_state_dict(weight_path, location=self.device.type))
103
  self.task_name = task_name
104
 
105
  def get_weight_path(self, task_name: str) -> str:
106
  if 'scribble' in task_name:
107
  task_name = 'scribble'
108
- return f'{self.model_dir}/{self.model_names[task_name]}'
109
 
110
  def download_models(self) -> None:
111
  self.model_dir.mkdir(exist_ok=True, parents=True)
112
- for name in self.model_names.values():
113
  out_path = self.model_dir / name
114
  if out_path.exists():
115
  continue
116
- subprocess.run(
117
- shlex.split(f'wget {self.weight_root}{name} -O {out_path}'))
118
 
119
  @torch.inference_mode()
120
  def process_canny(self, input_image, prompt, a_prompt, n_prompt,
 
12
  import einops
13
  import numpy as np
14
  import torch
15
+ from huggingface_hub import hf_hub_url
16
  from pytorch_lightning import seed_everything
17
 
18
  sys.path.append('ControlNet')
 
29
  from ldm.models.diffusion.ddim import DDIMSampler
30
  from share import *
31
 
32
+ MODEL_NAMES = {
 
 
 
 
 
 
 
 
 
 
 
 
33
  'canny': 'control_canny-fp16.safetensors',
34
  'hough': 'control_mlsd-fp16.safetensors',
35
  'hed': 'control_hed-fp16.safetensors',
 
39
  'depth': 'control_depth-fp16.safetensors',
40
  'normal': 'control_normal-fp16.safetensors',
41
  }
42
+ MODEL_REPO = 'webui/ControlNet-modules-safetensors'
43
+
44
+ DEFAULT_BASE_MODEL_REPO = 'andite/anything-v4.0'
45
+ DEFAULT_BASE_MODEL_FILENAME = 'anything-v4.0-pruned.safetensors'
46
+ DEFAULT_BASE_MODEL_URL = 'https://huggingface.co/andite/anything-v4.0/resolve/main/anything-v4.0-pruned.safetensors'
47
 
48
 
49
  class Model:
50
  def __init__(self,
51
  model_config_path: str = 'ControlNet/models/cldm_v15.yaml',
52
+ model_dir: str = 'models'):
 
53
  self.device = torch.device(
54
  'cuda:0' if torch.cuda.is_available() else 'cpu')
55
  self.model = create_model(model_config_path).to(self.device)
56
  self.ddim_sampler = DDIMSampler(self.model)
57
  self.task_name = ''
58
 
59
+ self.base_model_url = ''
60
  self.model_dir = pathlib.Path(model_dir)
61
+ self.model_dir.mkdir(exist_ok=True, parents=True)
62
 
 
 
 
 
 
 
 
 
 
63
  self.download_models()
64
+ self.set_base_model(DEFAULT_BASE_MODEL_REPO,
65
+ DEFAULT_BASE_MODEL_FILENAME)
66
+
67
+ def set_base_model(self, model_id: str, filename: str) -> str:
68
+ if not model_id or not filename:
69
+ return self.base_model_url
70
+ base_model_url = hf_hub_url(model_id, filename)
71
+ if base_model_url != self.base_model_url:
72
+ self.load_base_model(base_model_url)
73
+ self.base_model_url = base_model_url
74
+ return self.base_model_url
75
 
76
  def download_base_model(self, model_url: str) -> pathlib.Path:
77
+ self.model_dir.mkdir(exist_ok=True, parents=True)
78
  model_name = model_url.split('/')[-1]
79
  out_path = self.model_dir / model_name
80
  if not out_path.exists():
 
91
  if task_name == self.task_name:
92
  return
93
  weight_path = self.get_weight_path(task_name)
94
+ self.model.control_model.load_state_dict(
95
+ load_state_dict(weight_path, location=self.device.type))
 
 
 
 
96
  self.task_name = task_name
97
 
98
  def get_weight_path(self, task_name: str) -> str:
99
  if 'scribble' in task_name:
100
  task_name = 'scribble'
101
+ return f'{self.model_dir}/{MODEL_NAMES[task_name]}'
102
 
103
  def download_models(self) -> None:
104
  self.model_dir.mkdir(exist_ok=True, parents=True)
105
+ for name in MODEL_NAMES.values():
106
  out_path = self.model_dir / name
107
  if out_path.exists():
108
  continue
109
+ model_url = hf_hub_url(MODEL_REPO, name)
110
+ subprocess.run(shlex.split(f'wget {model_url} -O {out_path}'))
111
 
112
  @torch.inference_mode()
113
  def process_canny(self, input_image, prompt, a_prompt, n_prompt,
requirements.txt CHANGED
@@ -2,6 +2,7 @@ addict==2.4.0
2
  albumentations==1.3.0
3
  einops==0.6.0
4
  gradio==3.18.0
 
5
  imageio==2.25.0
6
  imageio-ffmpeg==0.4.8
7
  kornia==0.6.9
@@ -16,4 +17,5 @@ timm==0.6.12
16
  torch==1.13.1
17
  torchvision==0.14.1
18
  transformers==4.26.1
 
19
  yapf==0.32.0
 
2
  albumentations==1.3.0
3
  einops==0.6.0
4
  gradio==3.18.0
5
+ huggingface-hub==0.12.0
6
  imageio==2.25.0
7
  imageio-ffmpeg==0.4.8
8
  kornia==0.6.9
 
17
  torch==1.13.1
18
  torchvision==0.14.1
19
  transformers==4.26.1
20
+ xformers==0.0.16
21
  yapf==0.32.0