Spaces:
Running
on
A100
Running
on
A100
Update app.py
Browse files
app.py
CHANGED
@@ -133,8 +133,8 @@ models_rbm.generator.eval().requires_grad_(False)
|
|
133 |
def infer(ref_style_file, style_description, caption, use_low_vram, progress):
|
134 |
global models_rbm, models_b, device
|
135 |
|
136 |
-
|
137 |
-
|
138 |
try:
|
139 |
|
140 |
caption = f"{caption} in {style_description}"
|
@@ -234,6 +234,8 @@ def infer(ref_style_file, style_description, caption, use_low_vram, progress):
|
|
234 |
return sampled_image # Return the sampled_image PIL image
|
235 |
|
236 |
finally:
|
|
|
|
|
237 |
# Clear CUDA cache
|
238 |
torch.cuda.empty_cache()
|
239 |
gc.collect()
|
@@ -241,10 +243,9 @@ def infer(ref_style_file, style_description, caption, use_low_vram, progress):
|
|
241 |
def infer_compo(style_description, ref_style_file, caption, ref_sub_file, use_low_vram, progress):
|
242 |
global models_rbm, models_b, device
|
243 |
sam_model = LangSAM()
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
models_to(sam_model.sam, device=device)
|
248 |
try:
|
249 |
caption = f"{caption} in {style_description}"
|
250 |
sam_prompt = f"{caption}"
|
@@ -361,6 +362,10 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file, use_lo
|
|
361 |
return sampled_image # Return the sampled_image PIL image
|
362 |
|
363 |
finally:
|
|
|
|
|
|
|
|
|
364 |
# Clear CUDA cache
|
365 |
torch.cuda.empty_cache()
|
366 |
gc.collect()
|
|
|
133 |
def infer(ref_style_file, style_description, caption, use_low_vram, progress):
|
134 |
global models_rbm, models_b, device
|
135 |
|
136 |
+
models_to(models_rbm, device=device)
|
137 |
+
|
138 |
try:
|
139 |
|
140 |
caption = f"{caption} in {style_description}"
|
|
|
234 |
return sampled_image # Return the sampled_image PIL image
|
235 |
|
236 |
finally:
|
237 |
+
if use_low_vram:
|
238 |
+
models_to(models_rbm, device=device)
|
239 |
# Clear CUDA cache
|
240 |
torch.cuda.empty_cache()
|
241 |
gc.collect()
|
|
|
243 |
def infer_compo(style_description, ref_style_file, caption, ref_sub_file, use_low_vram, progress):
|
244 |
global models_rbm, models_b, device
|
245 |
sam_model = LangSAM()
|
246 |
+
models_to(models_rbm, device=device)
|
247 |
+
models_to(sam_model, device=device)
|
248 |
+
models_to(sam_model.sam, device=device)
|
|
|
249 |
try:
|
250 |
caption = f"{caption} in {style_description}"
|
251 |
sam_prompt = f"{caption}"
|
|
|
362 |
return sampled_image # Return the sampled_image PIL image
|
363 |
|
364 |
finally:
|
365 |
+
if use_low_vram:
|
366 |
+
models_to(models_rbm, device=device, excepts=["generator", "previewer"])
|
367 |
+
models_to(sam_model, device=device)
|
368 |
+
models_to(sam_model.sam, device=device)
|
369 |
# Clear CUDA cache
|
370 |
torch.cuda.empty_cache()
|
371 |
gc.collect()
|