AndranikSargsyan commited on
Commit
f1cc496
β€’
1 Parent(s): 073105a

add support for diffusers checkpoint loading

Browse files
This view is limited to 50 files because it contains too many changes. Β  See raw diff
Files changed (50) hide show
  1. .gitignore +3 -1
  2. LICENSE +21 -0
  3. {assets β†’ __assets__/demo}/config/ddpm/v1.yaml +0 -0
  4. {assets β†’ __assets__/demo}/config/ddpm/v2-upsample.yaml +0 -0
  5. {assets β†’ __assets__/demo}/config/encoders/clip.yaml +0 -0
  6. {assets β†’ __assets__/demo}/config/encoders/openclip.yaml +0 -0
  7. {assets β†’ __assets__/demo}/config/unet/inpainting/v1.yaml +0 -0
  8. {assets β†’ __assets__/demo}/config/unet/inpainting/v2.yaml +0 -0
  9. {assets β†’ __assets__/demo}/config/unet/upsample/v2.yaml +0 -0
  10. {assets β†’ __assets__/demo}/config/vae-upsample.yaml +0 -0
  11. {assets β†’ __assets__/demo}/config/vae.yaml +0 -0
  12. {assets β†’ __assets__/demo}/examples/images_1024/a19.jpg +0 -0
  13. {assets β†’ __assets__/demo}/examples/images_1024/a2.jpg +0 -0
  14. {assets β†’ __assets__/demo}/examples/images_1024/a4.jpg +0 -0
  15. {assets β†’ __assets__/demo}/examples/images_1024/a40.jpg +0 -0
  16. {assets β†’ __assets__/demo}/examples/images_1024/a46.jpg +0 -0
  17. {assets β†’ __assets__/demo}/examples/images_1024/a51.jpg +0 -0
  18. {assets β†’ __assets__/demo}/examples/images_1024/a54.jpg +0 -0
  19. {assets β†’ __assets__/demo}/examples/images_1024/a65.jpg +0 -0
  20. {assets β†’ __assets__/demo}/examples/images_2048/a19.jpg +0 -0
  21. {assets β†’ __assets__/demo}/examples/images_2048/a2.jpg +0 -0
  22. {assets β†’ __assets__/demo}/examples/images_2048/a4.jpg +0 -0
  23. {assets β†’ __assets__/demo}/examples/images_2048/a40.jpg +0 -0
  24. {assets β†’ __assets__/demo}/examples/images_2048/a46.jpg +0 -0
  25. {assets β†’ __assets__/demo}/examples/images_2048/a51.jpg +0 -0
  26. {assets β†’ __assets__/demo}/examples/images_2048/a54.jpg +0 -0
  27. {assets β†’ __assets__/demo}/examples/images_2048/a65.jpg +0 -0
  28. {assets β†’ __assets__/demo}/examples/sbs/a19.png +0 -0
  29. {assets β†’ __assets__/demo}/examples/sbs/a2.png +0 -0
  30. {assets β†’ __assets__/demo}/examples/sbs/a4.png +0 -0
  31. {assets β†’ __assets__/demo}/examples/sbs/a40.png +0 -0
  32. {assets β†’ __assets__/demo}/examples/sbs/a46.png +0 -0
  33. {assets β†’ __assets__/demo}/examples/sbs/a51.png +0 -0
  34. {assets β†’ __assets__/demo}/examples/sbs/a54.png +0 -0
  35. {assets β†’ __assets__/demo}/examples/sbs/a65.png +0 -0
  36. {assets β†’ __assets__/demo}/sr_info.png +0 -0
  37. app.py +59 -98
  38. assets/.gitignore +0 -1
  39. config/ddpm/v1.yaml +14 -0
  40. config/ddpm/v2-upsample.yaml +24 -0
  41. config/encoders/clip.yaml +1 -0
  42. config/encoders/openclip.yaml +4 -0
  43. config/unet/inpainting/v1.yaml +15 -0
  44. config/unet/inpainting/v2.yaml +16 -0
  45. config/unet/upsample/v2.yaml +19 -0
  46. config/vae-upsample.yaml +16 -0
  47. config/vae.yaml +17 -0
  48. lib/models/__init__.py +0 -1
  49. lib/models/common.py +0 -49
  50. lib/models/ds_inp.py +0 -51
.gitignore CHANGED
@@ -4,4 +4,6 @@
4
 
5
  outputs/
6
  gradio_tmp/
7
- __pycache__/
 
 
 
4
 
5
  outputs/
6
  gradio_tmp/
7
+ __pycache__/
8
+
9
+ checkpoints/
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Picsart AI Research (PAIR)
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
{assets β†’ __assets__/demo}/config/ddpm/v1.yaml RENAMED
File without changes
{assets β†’ __assets__/demo}/config/ddpm/v2-upsample.yaml RENAMED
File without changes
{assets β†’ __assets__/demo}/config/encoders/clip.yaml RENAMED
File without changes
{assets β†’ __assets__/demo}/config/encoders/openclip.yaml RENAMED
File without changes
{assets β†’ __assets__/demo}/config/unet/inpainting/v1.yaml RENAMED
File without changes
{assets β†’ __assets__/demo}/config/unet/inpainting/v2.yaml RENAMED
File without changes
{assets β†’ __assets__/demo}/config/unet/upsample/v2.yaml RENAMED
File without changes
{assets β†’ __assets__/demo}/config/vae-upsample.yaml RENAMED
File without changes
{assets β†’ __assets__/demo}/config/vae.yaml RENAMED
File without changes
{assets β†’ __assets__/demo}/examples/images_1024/a19.jpg RENAMED
File without changes
{assets β†’ __assets__/demo}/examples/images_1024/a2.jpg RENAMED
File without changes
{assets β†’ __assets__/demo}/examples/images_1024/a4.jpg RENAMED
File without changes
{assets β†’ __assets__/demo}/examples/images_1024/a40.jpg RENAMED
File without changes
{assets β†’ __assets__/demo}/examples/images_1024/a46.jpg RENAMED
File without changes
{assets β†’ __assets__/demo}/examples/images_1024/a51.jpg RENAMED
File without changes
{assets β†’ __assets__/demo}/examples/images_1024/a54.jpg RENAMED
File without changes
{assets β†’ __assets__/demo}/examples/images_1024/a65.jpg RENAMED
File without changes
{assets β†’ __assets__/demo}/examples/images_2048/a19.jpg RENAMED
File without changes
{assets β†’ __assets__/demo}/examples/images_2048/a2.jpg RENAMED
File without changes
{assets β†’ __assets__/demo}/examples/images_2048/a4.jpg RENAMED
File without changes
{assets β†’ __assets__/demo}/examples/images_2048/a40.jpg RENAMED
File without changes
{assets β†’ __assets__/demo}/examples/images_2048/a46.jpg RENAMED
File without changes
{assets β†’ __assets__/demo}/examples/images_2048/a51.jpg RENAMED
File without changes
{assets β†’ __assets__/demo}/examples/images_2048/a54.jpg RENAMED
File without changes
{assets β†’ __assets__/demo}/examples/images_2048/a65.jpg RENAMED
File without changes
{assets β†’ __assets__/demo}/examples/sbs/a19.png RENAMED
File without changes
{assets β†’ __assets__/demo}/examples/sbs/a2.png RENAMED
File without changes
{assets β†’ __assets__/demo}/examples/sbs/a4.png RENAMED
File without changes
{assets β†’ __assets__/demo}/examples/sbs/a40.png RENAMED
File without changes
{assets β†’ __assets__/demo}/examples/sbs/a46.png RENAMED
File without changes
{assets β†’ __assets__/demo}/examples/sbs/a51.png RENAMED
File without changes
{assets β†’ __assets__/demo}/examples/sbs/a54.png RENAMED
File without changes
{assets β†’ __assets__/demo}/examples/sbs/a65.png RENAMED
File without changes
{assets β†’ __assets__/demo}/sr_info.png RENAMED
File without changes
app.py CHANGED
@@ -1,40 +1,44 @@
1
  import os
 
 
2
  from collections import OrderedDict
3
 
4
  import gradio as gr
5
  import shutil
6
  import uuid
7
  import torch
8
- from pathlib import Path
9
- from lib.utils.iimage import IImage
10
  from PIL import Image
11
 
12
- from lib import models
13
- from lib.methods import rasg, sd, sr
14
- from lib.utils import poisson_blend, image_from_url_text
 
 
 
15
 
16
 
17
- TMP_DIR = 'gradio_tmp'
18
- if Path(TMP_DIR).exists():
19
- shutil.rmtree(TMP_DIR)
20
- Path(TMP_DIR).mkdir(exist_ok=True, parents=True)
21
 
22
- os.environ['GRADIO_TEMP_DIR'] = TMP_DIR
23
 
24
  on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"
25
 
26
  negative_prompt_str = "text, bad anatomy, bad proportions, blurry, cropped, deformed, disfigured, duplicate, error, extra limbs, gross proportions, jpeg artifacts, long neck, low quality, lowres, malformed, morbid, mutated, mutilated, out of frame, ugly, worst quality"
27
  positive_prompt_str = "Full HD, 4K, high quality, high resolution"
28
 
 
29
  example_inputs = [
30
- ['assets/examples/images_1024/a40.jpg', 'assets/examples/images_2048/a40.jpg', 'medieval castle'],
31
- ['assets/examples/images_1024/a4.jpg', 'assets/examples/images_2048/a4.jpg', 'parrot'],
32
- ['assets/examples/images_1024/a65.jpg', 'assets/examples/images_2048/a65.jpg', 'hoodie'],
33
- ['assets/examples/images_1024/a54.jpg', 'assets/examples/images_2048/a54.jpg', 'salad'],
34
- ['assets/examples/images_1024/a51.jpg', 'assets/examples/images_2048/a51.jpg', 'space helmet'],
35
- ['assets/examples/images_1024/a46.jpg', 'assets/examples/images_2048/a46.jpg', 'stack of books'],
36
- ['assets/examples/images_1024/a19.jpg', 'assets/examples/images_2048/a19.jpg', 'antique greek vase'],
37
- ['assets/examples/images_1024/a2.jpg', 'assets/examples/images_2048/a2.jpg', 'sunglasses'],
38
  ]
39
 
40
  thumbnails = [
@@ -60,27 +64,35 @@ example_previews = [
60
  ]
61
 
62
  # Load models
 
63
  inpainting_models = OrderedDict([
64
- ("Dreamshaper Inpainting V8", models.ds_inp.load_model()),
65
- ("Stable-Inpainting 2.0", models.sd2_inp.load_model()),
66
- ("Stable-Inpainting 1.5", models.sd15_inp.load_model())
67
  ])
68
  sr_model = models.sd2_sr.load_model(device='cuda:1')
69
  sam_predictor = models.sam.load_model(device='cuda:0')
70
 
71
- inp_model = inpainting_models[list(inpainting_models.keys())[0]]
72
- def set_model_from_name(inp_model_name):
 
 
 
 
73
  global inp_model
74
- print (f"Activating Inpaintng Model: {inp_model_name}")
75
- inp_model = inpainting_models[inp_model_name]
 
 
 
 
76
 
77
 
78
  def save_user_session(hr_image, hr_mask, lr_results, prompt, session_id=None):
79
  if session_id == '':
80
  session_id = str(uuid.uuid4())
81
 
82
- tmp_dir = Path(TMP_DIR)
83
- session_dir = tmp_dir / session_id
84
  session_dir.mkdir(exist_ok=True, parents=True)
85
 
86
  hr_image.save(session_dir / 'hr_image.png')
@@ -103,8 +115,7 @@ def recover_user_session(session_id):
103
  if session_id == '':
104
  return None, None, [], ''
105
 
106
- tmp_dir = Path(TMP_DIR)
107
- session_dir = tmp_dir / session_id
108
  lr_results_dir = session_dir / 'lr_results'
109
 
110
  hr_image = Image.open(session_dir / 'hr_image.png')
@@ -121,64 +132,22 @@ def recover_user_session(session_id):
121
  return hr_image, hr_mask, gallery, prompt
122
 
123
 
124
- def rasg_run(
125
- use_painta, prompt, imageMask, hr_image, seed, eta,
126
- negative_prompt, positive_prompt, ddim_steps,
127
- guidance_scale=7.5,
128
- batch_size=1, session_id=''
129
  ):
130
  torch.cuda.empty_cache()
 
131
 
132
- seed = int(seed)
133
- batch_size = max(1, min(int(batch_size), 4))
134
-
135
- image = IImage(hr_image).resize(512)
136
- mask = IImage(imageMask['mask']).rgb().resize(512)
137
-
138
- method = ['rasg']
139
  if use_painta: method.append('painta')
 
140
  method = '-'.join(method)
141
 
142
- inpainted_images = []
143
- blended_images = []
144
- for i in range(batch_size):
145
- seed = seed + i * 1000
146
-
147
- inpainted_image = rasg.run(
148
- ddim=inp_model,
149
- method=method,
150
- prompt=prompt,
151
- image=image,
152
- mask=mask,
153
- seed=seed,
154
- eta=eta,
155
- negative_prompt=negative_prompt,
156
- positive_prompt=positive_prompt,
157
- num_steps=ddim_steps,
158
- guidance_scale=guidance_scale
159
- ).crop(image.size)
160
-
161
- blended_image = poisson_blend(
162
- orig_img=image.data[0],
163
- fake_img=inpainted_image.data[0],
164
- mask=mask.data[0],
165
- dilation=12
166
- )
167
- blended_images.append(blended_image)
168
- inpainted_images.append(inpainted_image.pil())
169
-
170
- session_id = save_user_session(
171
- hr_image, imageMask['mask'], inpainted_images, prompt, session_id=session_id)
172
-
173
- return blended_images, session_id
174
-
175
-
176
- def sd_run(use_painta, prompt, imageMask, hr_image, seed, eta,
177
- negative_prompt, positive_prompt, ddim_steps,
178
- guidance_scale=7.5,
179
- batch_size=1, session_id=''
180
- ):
181
- torch.cuda.empty_cache()
182
 
183
  seed = int(seed)
184
  batch_size = max(1, min(int(batch_size), 4))
@@ -195,7 +164,7 @@ def sd_run(use_painta, prompt, imageMask, hr_image, seed, eta,
195
  for i in range(batch_size):
196
  seed = seed + i * 1000
197
 
198
- inpainted_image = sd.run(
199
  ddim=inp_model,
200
  method=method,
201
  prompt=prompt,
@@ -226,13 +195,12 @@ def sd_run(use_painta, prompt, imageMask, hr_image, seed, eta,
226
 
227
  def upscale_run(
228
  ddim_steps, seed, use_sam_mask, session_id, img_index,
229
- negative_prompt='',
230
- positive_prompt=', high resolution professional photo'
231
  ):
232
  hr_image, hr_mask, gallery, prompt = recover_user_session(session_id)
233
 
234
  if len(gallery) == 0:
235
- return Image.open('./assets/sr_info.png')
236
 
237
  torch.cuda.empty_cache()
238
 
@@ -249,7 +217,7 @@ def upscale_run(
249
  inpainted_image,
250
  hr_image,
251
  hr_mask,
252
- prompt=prompt + positive_prompt,
253
  noise_level=20,
254
  blend_trick=True,
255
  blend_output=True,
@@ -261,14 +229,7 @@ def upscale_run(
261
  return output_image
262
 
263
 
264
- def switch_run(use_rasg, model_name, *args):
265
- set_model_from_name(model_name)
266
- if use_rasg:
267
- return rasg_run(*args)
268
- return sd_run(*args)
269
-
270
-
271
- with gr.Blocks(css='style.css') as demo:
272
  gr.HTML(
273
  """
274
  <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
@@ -300,7 +261,7 @@ with gr.Blocks(css='style.css') as demo:
300
  <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
301
  </p>""")
302
 
303
- with open('script.js', 'r') as f:
304
  js_str = f.read()
305
 
306
  demo.load(_js=js_str)
@@ -380,10 +341,10 @@ with gr.Blocks(css='style.css') as demo:
380
  html_info = gr.HTML(elem_id=f'html_info', elem_classes="infotext")
381
 
382
  inpaint_btn.click(
383
- fn=switch_run,
384
  inputs=[
385
- use_rasg,
386
  model_picker,
 
387
  use_painta,
388
  prompt,
389
  imageMask,
@@ -415,4 +376,4 @@ with gr.Blocks(css='style.css') as demo:
415
  )
416
 
417
  demo.queue(max_size=20)
418
- demo.launch(share=True, allowed_paths=[TMP_DIR])
 
1
  import os
2
+ import sys
3
+ from pathlib import Path
4
  from collections import OrderedDict
5
 
6
  import gradio as gr
7
  import shutil
8
  import uuid
9
  import torch
 
 
10
  from PIL import Image
11
 
12
+ demo_path = Path(__file__).resolve().parent
13
+ root_path = demo_path
14
+ sys.path.append(str(root_path))
15
+ from src import models
16
+ from src.methods import rasg, sd, sr
17
+ from src.utils import IImage, poisson_blend, image_from_url_text
18
 
19
 
20
+ TMP_DIR = root_path / 'gradio_tmp'
21
+ if TMP_DIR.exists():
22
+ shutil.rmtree(str(TMP_DIR))
23
+ TMP_DIR.mkdir(exist_ok=True, parents=True)
24
 
25
+ os.environ['GRADIO_TEMP_DIR'] = str(TMP_DIR)
26
 
27
  on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"
28
 
29
  negative_prompt_str = "text, bad anatomy, bad proportions, blurry, cropped, deformed, disfigured, duplicate, error, extra limbs, gross proportions, jpeg artifacts, long neck, low quality, lowres, malformed, morbid, mutated, mutilated, out of frame, ugly, worst quality"
30
  positive_prompt_str = "Full HD, 4K, high quality, high resolution"
31
 
32
+ examples_path = root_path / '__assets__/demo/examples'
33
  example_inputs = [
34
+ [f'{examples_path}/images_1024/a40.jpg', f'{examples_path}/images_2048/a40.jpg', 'medieval castle'],
35
+ [f'{examples_path}/images_1024/a4.jpg', f'{examples_path}/images_2048/a4.jpg', 'parrot'],
36
+ [f'{examples_path}/images_1024/a65.jpg', f'{examples_path}/images_2048/a65.jpg', 'hoodie'],
37
+ [f'{examples_path}/images_1024/a54.jpg', f'{examples_path}/images_2048/a54.jpg', 'salad'],
38
+ [f'{examples_path}/images_1024/a51.jpg', f'{examples_path}/images_2048/a51.jpg', 'space helmet'],
39
+ [f'{examples_path}/images_1024/a46.jpg', f'{examples_path}/images_2048/a46.jpg', 'stack of books'],
40
+ [f'{examples_path}/images_1024/a19.jpg', f'{examples_path}/images_2048/a19.jpg', 'antique greek vase'],
41
+ [f'{examples_path}/images_1024/a2.jpg', f'{examples_path}/images_2048/a2.jpg', 'sunglasses'],
42
  ]
43
 
44
  thumbnails = [
 
64
  ]
65
 
66
  # Load models
67
+ models.pre_download_inpainting_models()
68
  inpainting_models = OrderedDict([
69
+ ("Dreamshaper Inpainting V8", 'ds8_inp'),
70
+ ("Stable-Inpainting 2.0", 'sd2_inp'),
71
+ ("Stable-Inpainting 1.5", 'sd15_inp')
72
  ])
73
  sr_model = models.sd2_sr.load_model(device='cuda:1')
74
  sam_predictor = models.sam.load_model(device='cuda:0')
75
 
76
+ inp_model_name = list(inpainting_models.keys())[0]
77
+ inp_model = models.load_inpainting_model(
78
+ inpainting_models[inp_model_name], device='cuda:0', cache=False)
79
+
80
+
81
+ def set_model_from_name(new_inp_model_name):
82
  global inp_model
83
+ global inp_model_name
84
+ if new_inp_model_name != inp_model_name:
85
+ print (f"Activating Inpaintng Model: {new_inp_model_name}")
86
+ inp_model = models.load_inpainting_model(
87
+ inpainting_models[new_inp_model_name], device='cuda:0', cache=False)
88
+ inp_model_name = new_inp_model_name
89
 
90
 
91
  def save_user_session(hr_image, hr_mask, lr_results, prompt, session_id=None):
92
  if session_id == '':
93
  session_id = str(uuid.uuid4())
94
 
95
+ session_dir = TMP_DIR / session_id
 
96
  session_dir.mkdir(exist_ok=True, parents=True)
97
 
98
  hr_image.save(session_dir / 'hr_image.png')
 
115
  if session_id == '':
116
  return None, None, [], ''
117
 
118
+ session_dir = TMP_DIR / session_id
 
119
  lr_results_dir = session_dir / 'lr_results'
120
 
121
  hr_image = Image.open(session_dir / 'hr_image.png')
 
132
  return hr_image, hr_mask, gallery, prompt
133
 
134
 
135
+ def inpainting_run(model_name, use_rasg, use_painta, prompt, imageMask,
136
+ hr_image, seed, eta, negative_prompt, positive_prompt, ddim_steps,
137
+ guidance_scale=7.5, batch_size=1, session_id=''
 
 
138
  ):
139
  torch.cuda.empty_cache()
140
+ set_model_from_name(model_name)
141
 
142
+ method = ['default']
 
 
 
 
 
 
143
  if use_painta: method.append('painta')
144
+ if use_rasg: method.append('rasg')
145
  method = '-'.join(method)
146
 
147
+ if use_rasg:
148
+ inpainting_f = rasg.run
149
+ else:
150
+ inpainting_f = sd.run
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
  seed = int(seed)
153
  batch_size = max(1, min(int(batch_size), 4))
 
164
  for i in range(batch_size):
165
  seed = seed + i * 1000
166
 
167
+ inpainted_image = inpainting_f(
168
  ddim=inp_model,
169
  method=method,
170
  prompt=prompt,
 
195
 
196
  def upscale_run(
197
  ddim_steps, seed, use_sam_mask, session_id, img_index,
198
+ negative_prompt='', positive_prompt='high resolution professional photo'
 
199
  ):
200
  hr_image, hr_mask, gallery, prompt = recover_user_session(session_id)
201
 
202
  if len(gallery) == 0:
203
+ return Image.open(root_path / '__assets__/sr_info.png')
204
 
205
  torch.cuda.empty_cache()
206
 
 
217
  inpainted_image,
218
  hr_image,
219
  hr_mask,
220
+ prompt=f'{prompt}, {positive_prompt}',
221
  noise_level=20,
222
  blend_trick=True,
223
  blend_output=True,
 
229
  return output_image
230
 
231
 
232
+ with gr.Blocks(css=demo_path / 'style.css') as demo:
 
 
 
 
 
 
 
233
  gr.HTML(
234
  """
235
  <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
 
261
  <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
262
  </p>""")
263
 
264
+ with open(demo_path / 'script.js', 'r') as f:
265
  js_str = f.read()
266
 
267
  demo.load(_js=js_str)
 
341
  html_info = gr.HTML(elem_id=f'html_info', elem_classes="infotext")
342
 
343
  inpaint_btn.click(
344
+ fn=inpainting_run,
345
  inputs=[
 
346
  model_picker,
347
+ use_rasg,
348
  use_painta,
349
  prompt,
350
  imageMask,
 
376
  )
377
 
378
  demo.queue(max_size=20)
379
+ demo.launch(share=True, allowed_paths=[str(TMP_DIR)])
assets/.gitignore DELETED
@@ -1 +0,0 @@
1
- models/
 
 
config/ddpm/v1.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ linear_start: 0.00085
2
+ linear_end: 0.0120
3
+ num_timesteps_cond: 1
4
+ log_every_t: 200
5
+ timesteps: 1000
6
+ first_stage_key: "jpg"
7
+ cond_stage_key: "txt"
8
+ image_size: 64
9
+ channels: 4
10
+ cond_stage_trainable: false
11
+ conditioning_key: crossattn
12
+ monitor: val/loss_simple_ema
13
+ scale_factor: 0.18215
14
+ use_ema: False # we set this to false because this is an inference only config
config/ddpm/v2-upsample.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ parameterization: "v"
2
+ low_scale_key: "lr"
3
+ linear_start: 0.0001
4
+ linear_end: 0.02
5
+ num_timesteps_cond: 1
6
+ log_every_t: 200
7
+ timesteps: 1000
8
+ first_stage_key: "jpg"
9
+ cond_stage_key: "txt"
10
+ image_size: 128
11
+ channels: 4
12
+ cond_stage_trainable: false
13
+ conditioning_key: "hybrid-adm"
14
+ monitor: val/loss_simple_ema
15
+ scale_factor: 0.08333
16
+ use_ema: False
17
+
18
+ low_scale_config:
19
+ target: ldm.modules.diffusionmodules.upscaling.ImageConcatWithNoiseAugmentation
20
+ params:
21
+ noise_schedule_config: # image space
22
+ linear_start: 0.0001
23
+ linear_end: 0.02
24
+ max_noise_level: 350
config/encoders/clip.yaml ADDED
@@ -0,0 +1 @@
 
 
1
+ __class__: smplfusion.models.encoders.clip_embedder.FrozenCLIPEmbedder
config/encoders/openclip.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ __class__: smplfusion.models.encoders.open_clip_embedder.FrozenOpenCLIPEmbedder
2
+ __init__:
3
+ freeze: True
4
+ layer: "penultimate"
config/unet/inpainting/v1.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __class__: smplfusion.models.unet.UNetModel
2
+ __init__:
3
+ image_size: 32 # unused
4
+ in_channels: 9 # 4 data + 4 downscaled image + 1 mask
5
+ out_channels: 4
6
+ model_channels: 320
7
+ attention_resolutions: [ 4, 2, 1 ]
8
+ num_res_blocks: 2
9
+ channel_mult: [ 1, 2, 4, 4 ]
10
+ num_heads: 8
11
+ use_spatial_transformer: True
12
+ transformer_depth: 1
13
+ context_dim: 768
14
+ use_checkpoint: False
15
+ legacy: False
config/unet/inpainting/v2.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __class__: smplfusion.models.unet.UNetModel
2
+ __init__:
3
+ use_checkpoint: False
4
+ image_size: 32 # unused
5
+ in_channels: 9
6
+ out_channels: 4
7
+ model_channels: 320
8
+ attention_resolutions: [ 4, 2, 1 ]
9
+ num_res_blocks: 2
10
+ channel_mult: [ 1, 2, 4, 4 ]
11
+ num_head_channels: 64 # need to fix for flash-attn
12
+ use_spatial_transformer: True
13
+ use_linear_in_transformer: True
14
+ transformer_depth: 1
15
+ context_dim: 1024
16
+ legacy: False
config/unet/upsample/v2.yaml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __class__: smplfusion.models.unet.UNetModel
2
+ __init__:
3
+ use_checkpoint: False
4
+ num_classes: 1000 # timesteps for noise conditioning (here constant, just need one)
5
+ image_size: 128
6
+ in_channels: 7
7
+ out_channels: 4
8
+ model_channels: 256
9
+ attention_resolutions: [ 2,4,8]
10
+ num_res_blocks: 2
11
+ channel_mult: [ 1, 2, 2, 4]
12
+ disable_self_attentions: [True, True, True, False]
13
+ disable_middle_self_attn: False
14
+ num_heads: 8
15
+ use_spatial_transformer: True
16
+ transformer_depth: 1
17
+ context_dim: 1024
18
+ legacy: False
19
+ use_linear_in_transformer: True
config/vae-upsample.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __class__: smplfusion.models.vae.AutoencoderKL
2
+ __init__:
3
+ embed_dim: 4
4
+ ddconfig:
5
+ double_z: True
6
+ z_channels: 4
7
+ resolution: 256
8
+ in_channels: 3
9
+ out_ch: 3
10
+ ch: 128
11
+ ch_mult: [ 1,2,4 ]
12
+ num_res_blocks: 2
13
+ attn_resolutions: [ ]
14
+ dropout: 0.0
15
+ lossconfig:
16
+ target: torch.nn.Identity
config/vae.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __class__: smplfusion.models.vae.AutoencoderKL
2
+ __init__:
3
+ embed_dim: 4
4
+ monitor: val/rec_loss
5
+ ddconfig:
6
+ double_z: true
7
+ z_channels: 4
8
+ resolution: 256
9
+ in_channels: 3
10
+ out_ch: 3
11
+ ch: 128
12
+ ch_mult: [1,2,4,4]
13
+ num_res_blocks: 2
14
+ attn_resolutions: []
15
+ dropout: 0.0
16
+ lossconfig:
17
+ target: torch.nn.Identity
lib/models/__init__.py DELETED
@@ -1 +0,0 @@
1
- from . import sd2_inp, ds_inp, sd15_inp, sd2_sr, sam
 
 
lib/models/common.py DELETED
@@ -1,49 +0,0 @@
1
- import importlib
2
- import requests
3
- from pathlib import Path
4
- from os.path import dirname
5
-
6
- from omegaconf import OmegaConf
7
- from tqdm import tqdm
8
-
9
-
10
- PROJECT_DIR = dirname(dirname(dirname(__file__)))
11
- CONFIG_FOLDER = f'{PROJECT_DIR}/assets/config'
12
- MODEL_FOLDER = f'{PROJECT_DIR}/assets/models'
13
-
14
-
15
- def download_file(url, save_path, chunk_size=1024):
16
- try:
17
- save_path = Path(save_path)
18
- if save_path.exists():
19
- print(f'{save_path.name} exists')
20
- return
21
- save_path.parent.mkdir(exist_ok=True, parents=True)
22
- resp = requests.get(url, stream=True)
23
- total = int(resp.headers.get('content-length', 0))
24
- with open(save_path, 'wb') as file, tqdm(
25
- desc=save_path.name,
26
- total=total,
27
- unit='iB',
28
- unit_scale=True,
29
- unit_divisor=1024,
30
- ) as bar:
31
- for data in resp.iter_content(chunk_size=chunk_size):
32
- size = file.write(data)
33
- bar.update(size)
34
- print(f'{save_path.name} download finished')
35
- except Exception as e:
36
- raise Exception(f"Download failed: {e}")
37
-
38
-
39
- def get_obj_from_str(string):
40
- module, cls = string.rsplit(".", 1)
41
- try:
42
- return getattr(importlib.import_module(module, package=None), cls)
43
- except:
44
- return getattr(importlib.import_module('lib.' + module, package=None), cls)
45
-
46
-
47
- def load_obj(path):
48
- objyaml = OmegaConf.load(path)
49
- return get_obj_from_str(objyaml['__class__'])(**objyaml.get("__init__", {}))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lib/models/ds_inp.py DELETED
@@ -1,51 +0,0 @@
1
- import importlib
2
- from omegaconf import OmegaConf
3
- import torch
4
- import safetensors
5
- import safetensors.torch
6
-
7
- from lib.smplfusion import DDIM, share, scheduler
8
- from .common import *
9
-
10
-
11
- MODEL_PATH = f'{MODEL_FOLDER}/dreamshaper/dreamshaper_8Inpainting.safetensors'
12
- DOWNLOAD_URL = 'https://civitai.com/api/download/models/131004'
13
-
14
- # pre-download
15
- download_file(DOWNLOAD_URL, MODEL_PATH)
16
-
17
-
18
- def load_model(dtype=torch.float16):
19
- print ("Loading model: Dreamshaper Inpainting V8")
20
-
21
- download_file(DOWNLOAD_URL, MODEL_PATH)
22
-
23
- state_dict = safetensors.torch.load_file(MODEL_PATH)
24
-
25
- config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v1.yaml')
26
- unet = load_obj(f'{CONFIG_FOLDER}/unet/inpainting/v1.yaml').eval().cuda()
27
- vae = load_obj(f'{CONFIG_FOLDER}/vae.yaml').eval().cuda()
28
- encoder = load_obj(f'{CONFIG_FOLDER}/encoders/clip.yaml').eval().cuda()
29
-
30
- extract = lambda state_dict, model: {x[len(model)+1:]:y for x,y in state_dict.items() if model in x}
31
- unet_state = extract(state_dict, 'model.diffusion_model')
32
- encoder_state = extract(state_dict, 'cond_stage_model')
33
- vae_state = extract(state_dict, 'first_stage_model')
34
-
35
- unet.load_state_dict(unet_state)
36
- encoder.load_state_dict(encoder_state)
37
- vae.load_state_dict(vae_state)
38
-
39
- if dtype == torch.float16:
40
- unet.convert_to_fp16()
41
- vae.to(dtype)
42
- encoder.to(dtype)
43
-
44
- unet = unet.requires_grad_(False)
45
- encoder = encoder.requires_grad_(False)
46
- vae = vae.requires_grad_(False)
47
-
48
- ddim = DDIM(config, vae, encoder, unet)
49
- share.schedule = scheduler.linear(config.timesteps, config.linear_start, config.linear_end)
50
-
51
- return ddim