Upload 2 files
Browse files
mod.py
CHANGED
@@ -71,15 +71,17 @@ def get_repo_safetensors(repo_id: str):
|
|
71 |
# Initialize the base model
|
72 |
base_model = models[0]
|
73 |
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
|
74 |
-
|
75 |
|
76 |
def change_base_model(repo_id: str, progress=gr.Progress(track_tqdm=True)):
|
77 |
global pipe
|
|
|
78 |
try:
|
79 |
-
if not is_repo_name(repo_id) or not is_repo_exists(repo_id): return
|
80 |
progress(0, f"Loading model: {repo_id}")
|
81 |
clear_cache()
|
82 |
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
|
|
|
83 |
progress(1, f"Model loaded: {repo_id}")
|
84 |
except Exception as e:
|
85 |
print(e)
|
@@ -135,6 +137,7 @@ def fuse_loras(pipe, lorajson: list[dict]):
|
|
135 |
#pipe.unload_lora_weights()
|
136 |
|
137 |
|
|
|
138 |
fuse_loras.zerogpu = True
|
139 |
|
140 |
|
|
|
71 |
# Initialize the base model
|
72 |
base_model = models[0]
|
73 |
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
|
74 |
+
last_model = models[0]
|
75 |
|
76 |
def change_base_model(repo_id: str, progress=gr.Progress(track_tqdm=True)):
|
77 |
global pipe
|
78 |
+
global last_model
|
79 |
try:
|
80 |
+
if repo_id == last_model or not is_repo_name(repo_id) or not is_repo_exists(repo_id): return
|
81 |
progress(0, f"Loading model: {repo_id}")
|
82 |
clear_cache()
|
83 |
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
|
84 |
+
last_model = repo_id
|
85 |
progress(1, f"Model loaded: {repo_id}")
|
86 |
except Exception as e:
|
87 |
print(e)
|
|
|
137 |
#pipe.unload_lora_weights()
|
138 |
|
139 |
|
140 |
+
change_base_model.zerogpu = True
|
141 |
fuse_loras.zerogpu = True
|
142 |
|
143 |
|