Spaces:
Runtime error
Runtime error
Use different base model and allow changing it
Browse files
README.md
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
---
|
2 |
-
title: ControlNet
|
3 |
-
emoji:
|
4 |
colorFrom: pink
|
5 |
colorTo: blue
|
6 |
sdk: gradio
|
|
|
1 |
---
|
2 |
+
title: ControlNet with other models
|
3 |
+
emoji: 😻
|
4 |
colorFrom: pink
|
5 |
colorTo: blue
|
6 |
sdk: gradio
|
app.py
CHANGED
@@ -41,14 +41,23 @@ from gradio_scribble2image import create_demo as create_demo_scribble
|
|
41 |
from gradio_scribble2image_interactive import \
|
42 |
create_demo as create_demo_scribble_interactive
|
43 |
from gradio_seg2image import create_demo as create_demo_seg
|
44 |
-
from model import
|
|
|
45 |
|
46 |
MAX_IMAGES = 1
|
47 |
-
DESCRIPTION = '''# ControlNet
|
48 |
|
49 |
-
This is
|
|
|
50 |
'''
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
DESCRIPTION += f'''<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.<br/>
|
53 |
<a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true">
|
54 |
<img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
|
@@ -59,6 +68,7 @@ model = Model()
|
|
59 |
|
60 |
with gr.Blocks(css='style.css') as demo:
|
61 |
gr.Markdown(DESCRIPTION)
|
|
|
62 |
with gr.Tabs():
|
63 |
with gr.TabItem('Canny'):
|
64 |
create_demo_canny(model.process_canny, max_images=MAX_IMAGES)
|
@@ -83,4 +93,29 @@ with gr.Blocks(css='style.css') as demo:
|
|
83 |
with gr.TabItem('Normal map'):
|
84 |
create_demo_normal(model.process_normal, max_images=MAX_IMAGES)
|
85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
demo.queue(api_open=False).launch()
|
|
|
41 |
from gradio_scribble2image_interactive import \
|
42 |
create_demo as create_demo_scribble_interactive
|
43 |
from gradio_seg2image import create_demo as create_demo_seg
|
44 |
+
from model import (DEFAULT_BASE_MODEL_FILENAME, DEFAULT_BASE_MODEL_REPO,
|
45 |
+
DEFAULT_BASE_MODEL_URL, Model)
|
46 |
|
47 |
MAX_IMAGES = 1
|
48 |
+
DESCRIPTION = '''# [ControlNet](https://github.com/lllyasviel/ControlNet)
|
49 |
|
50 |
+
This Space is a modified version of [this Space](https://huggingface.co/spaces/hysts/ControlNet).
|
51 |
+
The original Space uses [Stable Diffusion v1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) as the base model, but [Anything v4.0](https://huggingface.co/andite/anything-v4.0) is used in this Space.
|
52 |
'''
|
53 |
+
|
54 |
+
SPACE_ID = os.getenv('SPACE_ID')
|
55 |
+
ALLOW_CHANGING_BASE_MODEL = SPACE_ID != 'hysts/ControlNet-with-other-models'
|
56 |
+
|
57 |
+
if not ALLOW_CHANGING_BASE_MODEL:
|
58 |
+
DESCRIPTION += 'In this Space, the base model is not allowed to be changed so as not to slow down the demo, but it can be changed if you duplicate the Space.'
|
59 |
+
|
60 |
+
if SPACE_ID is not None:
|
61 |
DESCRIPTION += f'''<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.<br/>
|
62 |
<a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true">
|
63 |
<img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
|
|
|
68 |
|
69 |
with gr.Blocks(css='style.css') as demo:
|
70 |
gr.Markdown(DESCRIPTION)
|
71 |
+
|
72 |
with gr.Tabs():
|
73 |
with gr.TabItem('Canny'):
|
74 |
create_demo_canny(model.process_canny, max_images=MAX_IMAGES)
|
|
|
93 |
with gr.TabItem('Normal map'):
|
94 |
create_demo_normal(model.process_normal, max_images=MAX_IMAGES)
|
95 |
|
96 |
+
with gr.Accordion(label='Base model', open=False):
|
97 |
+
current_base_model = gr.Text(label='Current base model',
|
98 |
+
value=DEFAULT_BASE_MODEL_URL)
|
99 |
+
with gr.Row():
|
100 |
+
base_model_repo = gr.Text(label='Base model repo',
|
101 |
+
max_lines=1,
|
102 |
+
placeholder=DEFAULT_BASE_MODEL_REPO,
|
103 |
+
interactive=ALLOW_CHANGING_BASE_MODEL)
|
104 |
+
base_model_filename = gr.Text(
|
105 |
+
label='Base model file',
|
106 |
+
max_lines=1,
|
107 |
+
placeholder=DEFAULT_BASE_MODEL_FILENAME,
|
108 |
+
interactive=ALLOW_CHANGING_BASE_MODEL)
|
109 |
+
change_base_model_button = gr.Button('Change base model')
|
110 |
+
gr.Markdown(
|
111 |
+
'''- You can use other base models by specifying the repository name and filename.
|
112 |
+
The base model must be compatible with Stable Diffusion v1.5.''')
|
113 |
+
|
114 |
+
change_base_model_button.click(fn=model.set_base_model,
|
115 |
+
inputs=[
|
116 |
+
base_model_repo,
|
117 |
+
base_model_filename,
|
118 |
+
],
|
119 |
+
outputs=current_base_model)
|
120 |
+
|
121 |
demo.queue(api_open=False).launch()
|
model.py
CHANGED
@@ -12,6 +12,7 @@ import cv2
|
|
12 |
import einops
|
13 |
import numpy as np
|
14 |
import torch
|
|
|
15 |
from pytorch_lightning import seed_everything
|
16 |
|
17 |
sys.path.append('ControlNet')
|
@@ -28,19 +29,7 @@ from cldm.model import create_model, load_state_dict
|
|
28 |
from ldm.models.diffusion.ddim import DDIMSampler
|
29 |
from share import *
|
30 |
|
31 |
-
|
32 |
-
'canny': 'control_sd15_canny.pth',
|
33 |
-
'hough': 'control_sd15_mlsd.pth',
|
34 |
-
'hed': 'control_sd15_hed.pth',
|
35 |
-
'scribble': 'control_sd15_scribble.pth',
|
36 |
-
'pose': 'control_sd15_openpose.pth',
|
37 |
-
'seg': 'control_sd15_seg.pth',
|
38 |
-
'depth': 'control_sd15_depth.pth',
|
39 |
-
'normal': 'control_sd15_normal.pth',
|
40 |
-
}
|
41 |
-
ORIGINAL_WEIGHT_ROOT = 'https://huggingface.co/lllyasviel/ControlNet/resolve/main/models/'
|
42 |
-
|
43 |
-
LIGHTWEIGHT_MODEL_NAMES = {
|
44 |
'canny': 'control_canny-fp16.safetensors',
|
45 |
'hough': 'control_mlsd-fp16.safetensors',
|
46 |
'hed': 'control_hed-fp16.safetensors',
|
@@ -50,34 +39,42 @@ LIGHTWEIGHT_MODEL_NAMES = {
|
|
50 |
'depth': 'control_depth-fp16.safetensors',
|
51 |
'normal': 'control_normal-fp16.safetensors',
|
52 |
}
|
53 |
-
|
|
|
|
|
|
|
|
|
54 |
|
55 |
|
56 |
class Model:
|
57 |
def __init__(self,
|
58 |
model_config_path: str = 'ControlNet/models/cldm_v15.yaml',
|
59 |
-
model_dir: str = 'models'
|
60 |
-
use_lightweight: bool = True):
|
61 |
self.device = torch.device(
|
62 |
'cuda:0' if torch.cuda.is_available() else 'cpu')
|
63 |
self.model = create_model(model_config_path).to(self.device)
|
64 |
self.ddim_sampler = DDIMSampler(self.model)
|
65 |
self.task_name = ''
|
66 |
|
|
|
67 |
self.model_dir = pathlib.Path(model_dir)
|
|
|
68 |
|
69 |
-
self.use_lightweight = use_lightweight
|
70 |
-
if use_lightweight:
|
71 |
-
self.model_names = LIGHTWEIGHT_MODEL_NAMES
|
72 |
-
self.weight_root = LIGHTWEIGHT_WEIGHT_ROOT
|
73 |
-
base_model_url = 'https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors'
|
74 |
-
self.load_base_model(base_model_url)
|
75 |
-
else:
|
76 |
-
self.model_names = ORIGINAL_MODEL_NAMES
|
77 |
-
self.weight_root = ORIGINAL_WEIGHT_ROOT
|
78 |
self.download_models()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
def download_base_model(self, model_url: str) -> pathlib.Path:
|
|
|
81 |
model_name = model_url.split('/')[-1]
|
82 |
out_path = self.model_dir / model_name
|
83 |
if not out_path.exists():
|
@@ -94,27 +91,23 @@ class Model:
|
|
94 |
if task_name == self.task_name:
|
95 |
return
|
96 |
weight_path = self.get_weight_path(task_name)
|
97 |
-
|
98 |
-
self.
|
99 |
-
load_state_dict(weight_path, location=self.device))
|
100 |
-
else:
|
101 |
-
self.model.control_model.load_state_dict(
|
102 |
-
load_state_dict(weight_path, location=self.device.type))
|
103 |
self.task_name = task_name
|
104 |
|
105 |
def get_weight_path(self, task_name: str) -> str:
|
106 |
if 'scribble' in task_name:
|
107 |
task_name = 'scribble'
|
108 |
-
return f'{self.model_dir}/{
|
109 |
|
110 |
def download_models(self) -> None:
|
111 |
self.model_dir.mkdir(exist_ok=True, parents=True)
|
112 |
-
for name in
|
113 |
out_path = self.model_dir / name
|
114 |
if out_path.exists():
|
115 |
continue
|
116 |
-
|
117 |
-
|
118 |
|
119 |
@torch.inference_mode()
|
120 |
def process_canny(self, input_image, prompt, a_prompt, n_prompt,
|
|
|
12 |
import einops
|
13 |
import numpy as np
|
14 |
import torch
|
15 |
+
from huggingface_hub import hf_hub_url
|
16 |
from pytorch_lightning import seed_everything
|
17 |
|
18 |
sys.path.append('ControlNet')
|
|
|
29 |
from ldm.models.diffusion.ddim import DDIMSampler
|
30 |
from share import *
|
31 |
|
32 |
+
MODEL_NAMES = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
'canny': 'control_canny-fp16.safetensors',
|
34 |
'hough': 'control_mlsd-fp16.safetensors',
|
35 |
'hed': 'control_hed-fp16.safetensors',
|
|
|
39 |
'depth': 'control_depth-fp16.safetensors',
|
40 |
'normal': 'control_normal-fp16.safetensors',
|
41 |
}
|
42 |
+
MODEL_REPO = 'webui/ControlNet-modules-safetensors'
|
43 |
+
|
44 |
+
DEFAULT_BASE_MODEL_REPO = 'andite/anything-v4.0'
|
45 |
+
DEFAULT_BASE_MODEL_FILENAME = 'anything-v4.0-pruned.safetensors'
|
46 |
+
DEFAULT_BASE_MODEL_URL = 'https://huggingface.co/andite/anything-v4.0/resolve/main/anything-v4.0-pruned.safetensors'
|
47 |
|
48 |
|
49 |
class Model:
|
50 |
def __init__(self,
|
51 |
model_config_path: str = 'ControlNet/models/cldm_v15.yaml',
|
52 |
+
model_dir: str = 'models'):
|
|
|
53 |
self.device = torch.device(
|
54 |
'cuda:0' if torch.cuda.is_available() else 'cpu')
|
55 |
self.model = create_model(model_config_path).to(self.device)
|
56 |
self.ddim_sampler = DDIMSampler(self.model)
|
57 |
self.task_name = ''
|
58 |
|
59 |
+
self.base_model_url = ''
|
60 |
self.model_dir = pathlib.Path(model_dir)
|
61 |
+
self.model_dir.mkdir(exist_ok=True, parents=True)
|
62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
self.download_models()
|
64 |
+
self.set_base_model(DEFAULT_BASE_MODEL_REPO,
|
65 |
+
DEFAULT_BASE_MODEL_FILENAME)
|
66 |
+
|
67 |
+
def set_base_model(self, model_id: str, filename: str) -> str:
|
68 |
+
if not model_id or not filename:
|
69 |
+
return self.base_model_url
|
70 |
+
base_model_url = hf_hub_url(model_id, filename)
|
71 |
+
if base_model_url != self.base_model_url:
|
72 |
+
self.load_base_model(base_model_url)
|
73 |
+
self.base_model_url = base_model_url
|
74 |
+
return self.base_model_url
|
75 |
|
76 |
def download_base_model(self, model_url: str) -> pathlib.Path:
|
77 |
+
self.model_dir.mkdir(exist_ok=True, parents=True)
|
78 |
model_name = model_url.split('/')[-1]
|
79 |
out_path = self.model_dir / model_name
|
80 |
if not out_path.exists():
|
|
|
91 |
if task_name == self.task_name:
|
92 |
return
|
93 |
weight_path = self.get_weight_path(task_name)
|
94 |
+
self.model.control_model.load_state_dict(
|
95 |
+
load_state_dict(weight_path, location=self.device.type))
|
|
|
|
|
|
|
|
|
96 |
self.task_name = task_name
|
97 |
|
98 |
def get_weight_path(self, task_name: str) -> str:
|
99 |
if 'scribble' in task_name:
|
100 |
task_name = 'scribble'
|
101 |
+
return f'{self.model_dir}/{MODEL_NAMES[task_name]}'
|
102 |
|
103 |
def download_models(self) -> None:
|
104 |
self.model_dir.mkdir(exist_ok=True, parents=True)
|
105 |
+
for name in MODEL_NAMES.values():
|
106 |
out_path = self.model_dir / name
|
107 |
if out_path.exists():
|
108 |
continue
|
109 |
+
model_url = hf_hub_url(MODEL_REPO, name)
|
110 |
+
subprocess.run(shlex.split(f'wget {model_url} -O {out_path}'))
|
111 |
|
112 |
@torch.inference_mode()
|
113 |
def process_canny(self, input_image, prompt, a_prompt, n_prompt,
|
requirements.txt
CHANGED
@@ -2,6 +2,7 @@ addict==2.4.0
|
|
2 |
albumentations==1.3.0
|
3 |
einops==0.6.0
|
4 |
gradio==3.18.0
|
|
|
5 |
imageio==2.25.0
|
6 |
imageio-ffmpeg==0.4.8
|
7 |
kornia==0.6.9
|
@@ -16,4 +17,5 @@ timm==0.6.12
|
|
16 |
torch==1.13.1
|
17 |
torchvision==0.14.1
|
18 |
transformers==4.26.1
|
|
|
19 |
yapf==0.32.0
|
|
|
2 |
albumentations==1.3.0
|
3 |
einops==0.6.0
|
4 |
gradio==3.18.0
|
5 |
+
huggingface-hub==0.12.0
|
6 |
imageio==2.25.0
|
7 |
imageio-ffmpeg==0.4.8
|
8 |
kornia==0.6.9
|
|
|
17 |
torch==1.13.1
|
18 |
torchvision==0.14.1
|
19 |
transformers==4.26.1
|
20 |
+
xformers==0.0.16
|
21 |
yapf==0.32.0
|