|
import spaces
|
|
import os
|
|
import gradio as gr
|
|
import json
|
|
import logging
|
|
# Quiet the diffusers logger before the package is imported, so messages
# logged while it loads are filtered as well.
logging.getLogger("diffusers").setLevel(logging.ERROR)

import warnings

# Install the warning filters *before* importing diffusers: a filter added
# after an import cannot suppress warnings emitted while that import runs.
warnings.filterwarnings(action="ignore", category=FutureWarning, module="diffusers")
warnings.filterwarnings(action="ignore", category=UserWarning, module="diffusers")
warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers")

import diffusers

# Also lower verbosity through diffusers' own logging utility
# (logging.ERROR == 40, the value used previously).
diffusers.utils.logging.set_verbosity(logging.ERROR)
|
|
from pathlib import Path
|
|
from env import (
|
|
hf_token,
|
|
hf_read_token,
|
|
CIVITAI_API_KEY,
|
|
HF_LORA_PRIVATE_REPOS1,
|
|
HF_LORA_PRIVATE_REPOS2,
|
|
HF_LORA_ESSENTIAL_PRIVATE_REPO,
|
|
HF_VAE_PRIVATE_REPO,
|
|
directory_models,
|
|
directory_loras,
|
|
directory_vaes,
|
|
download_model_list,
|
|
download_lora_list,
|
|
download_vae_list,
|
|
)
|
|
from modutils import (
|
|
to_list,
|
|
list_uniq,
|
|
list_sub,
|
|
get_lora_model_list,
|
|
download_private_repo,
|
|
safe_float,
|
|
escape_lora_basename,
|
|
to_lora_key,
|
|
to_lora_path,
|
|
get_local_model_list,
|
|
get_private_lora_model_lists,
|
|
get_valid_lora_name,
|
|
get_valid_lora_path,
|
|
get_valid_lora_wt,
|
|
get_lora_info,
|
|
normalize_prompt_list,
|
|
get_civitai_info,
|
|
search_lora_on_civitai,
|
|
)
|
|
|
|
|
|
def download_things(directory, url, hf_token="", civitai_api_key=""):
    """Download ``url`` into ``directory`` using an external command-line tool.

    Google Drive links are fetched with ``gdown``. Hugging Face, Civitai and
    all other URLs are fetched with ``aria2c``. Failures are printed and
    otherwise ignored (best effort).

    Commands are run through :func:`subprocess.run` with an argument list, not
    through an interpolated shell string. URLs can come from user input (see
    ``download_lora``), and a shell string built from them allowed shell
    injection and broke on URLs containing ``&``, spaces or quotes.

    Args:
        directory: Destination directory.
        url: Source URL. Surrounding whitespace is ignored, and an empty URL
            is a no-op.
        hf_token: Optional Hugging Face token, sent as a Bearer header.
        civitai_api_key: Civitai API key. Without it, Civitai downloads are
            skipped and a message is printed.
    """
    import subprocess

    url = (url or "").strip()
    if not url:
        return

    directory = str(directory)
    log_opts = ["--console-log-level=error", "--summary-interval=10"]
    # Resume partial downloads (-c) and fetch in up to 16 parallel segments.
    seg_opts = ["-c", "-x", "16", "-k", "1M", "-s", "16"]
    cwd = None

    if "drive.google.com" in url:
        # gdown saves into its working directory. Pass ``cwd`` rather than
        # calling os.chdir(), which changes the directory for the whole process.
        cmd = ["gdown", "--fuzzy", url]
        cwd = directory
    elif "huggingface.co" in url:
        url = url.replace("?download=true", "")
        if "/blob/" in url:
            # Convert the web-view URL into the direct-download form.
            url = url.replace("/blob/", "/resolve/")
        target = ["-d", directory, "-o", url.split("/")[-1]]
        if hf_token:
            cmd = ["aria2c", *log_opts,
                   f"--header=Authorization: Bearer {hf_token}",
                   *seg_opts, url, *target]
        else:
            cmd = ["aria2c", "--optimize-concurrent-downloads", *log_opts,
                   *seg_opts, url, *target]
    elif "civitai.com" in url:
        if not civitai_api_key:
            print("\033[91mYou need an API key to download Civitai models.\033[0m")
            return
        # Replace any existing query string with the API-key token.
        url = url.split("?")[0] + f"?token={civitai_api_key}"
        cmd = ["aria2c", *log_opts, *seg_opts, "-d", directory, url]
    else:
        cmd = ["aria2c", *log_opts, *seg_opts, "-d", directory, url]

    try:
        subprocess.run(cmd, cwd=cwd, check=False)
    except OSError as err:
        # For example, the downloader is not installed. Stay best-effort, and
        # do not echo the full command because it may contain a token.
        print(f"\033[91mFailed to run {cmd[0]}: {err}\033[0m")
|
|
|
|
|
|
def get_model_list(directory_path):
    """Return paths of model weight files found directly in ``directory_path``.

    A file is included when its extension is ``.ckpt``, ``.pt``, ``.pth``,
    ``.safetensors`` or ``.bin`` (case-sensitive). Each path is
    ``os.path.join(directory_path, filename)``, listed in ``os.listdir``
    order, and printed as it is found.

    A missing directory yields an empty list instead of raising, so startup
    does not crash before anything has been downloaded.
    """
    valid_extensions = {'.ckpt', '.pt', '.pth', '.safetensors', '.bin'}
    if not os.path.isdir(directory_path):
        return []

    model_list = []
    for filename in os.listdir(directory_path):
        if os.path.splitext(filename)[1] not in valid_extensions:
            continue
        file_path = os.path.join(directory_path, filename)
        model_list.append(file_path)
        print('\033[34mFILE: ' + file_path + '\033[0m')
    return model_list
|
|
|
|
|
|
|
|
# Comma-joined download lists, kept as module-level strings for importers.
download_model = ", ".join(download_model_list)
download_vae = ", ".join(download_vae_list)
# NOTE: ``download_lora`` is rebound to the ``download_lora()`` function later
# in this module, so this string is only meaningful until then.
download_lora = ", ".join(download_lora_list)

# Environment variables take precedence. When they are unset, keep the values
# imported from ``env`` instead of replacing them with None.
CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY", CIVITAI_API_KEY)
hf_token = os.environ.get("HF_TOKEN", hf_token)


def _download_missing(url_csv, directory):
    """Download each comma-separated URL whose file is not already in ``directory``."""
    for url in (u.strip() for u in url_csv.split(",")):
        if not url:
            continue
        # Check the configured directory, as ``download_lora`` does, rather
        # than a hard-coded ./models, ./vaes or ./loras path.
        if not os.path.exists(os.path.join(directory, url.split("/")[-1])):
            download_things(directory, url, hf_token, CIVITAI_API_KEY)


_download_missing(download_model, directory_models)
_download_missing(download_vae, directory_vaes)
_download_missing(download_lora, directory_loras)
|
|
|
|
# Initial LoRA choices, and the VAE files on disk with a "None" entry first.
lora_model_list = get_lora_model_list()
vae_model_list = ["None", *get_model_list(directory_vaes)]
|
|
|
|
|
|
def get_t2i_model_info(repo_id: str):
    """Build a one-line Markdown summary of a public diffusers repo on the Hub.

    Returns ``""`` in any of these cases:

    - ``repo_id`` is empty or contains a space.
    - The repo does not exist.
    - The Hub lookup fails.
    - The repo is private or gated.
    - The repo is not tagged ``diffusers``.

    Otherwise returns ``gr.update(value=md)``. ``md`` lists, in order:

    - The pipeline type, when tagged (SDXL or SD1.5).
    - The model-card tags, minus some generic ones.
    - Download and like counts.
    - The last-modified date.
    - A link to the repo.
    """
    # Reject obviously invalid ids before importing or contacting the Hub.
    if not repo_id or " " in repo_id:
        return ""

    from huggingface_hub import HfApi

    api = HfApi()
    try:
        if not api.repo_exists(repo_id):
            return ""
        model = api.model_info(repo_id=repo_id)
    except Exception as e:
        print(f"Error: Failed to get {repo_id}'s info. ")
        print(e)
        return ""

    if model.private or model.gated:
        return ""
    # ModelInfo.tags / last_modified are optional and may be None.
    tags = model.tags or []
    if 'diffusers' not in tags:
        return ""

    info = []
    if 'diffusers:StableDiffusionXLPipeline' in tags:
        info.append("SDXL")
    elif 'diffusers:StableDiffusionPipeline' in tags:
        info.append("SD1.5")
    if model.card_data and model.card_data.tags:
        info.extend(list_sub(model.card_data.tags, ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']))
    info.append(f"DLs: {model.downloads}")
    info.append(f"likes: {model.likes}")
    if model.last_modified:
        info.append(model.last_modified.strftime("lastmod: %Y-%m-%d"))

    url = f"https://huggingface.co/{repo_id}/"
    md = f"Model Info: {', '.join(info)}, [Model Repo]({url})"
    return gr.update(value=md)
|
|
|
|
|
|
# LoRA metadata loaded from the optional lora_dict.json, keyed by escaped
# LoRA basename. The "" key is a placeholder row of five empty strings.
private_lora_dict = {"": ["", "", "", "", ""]}
try:
    with open('lora_dict.json', encoding='utf-8') as f:
        d = json.load(f)
    for k, v in d.items():
        private_lora_dict[escape_lora_basename(k)] = v
except FileNotFoundError:
    # The metadata file is optional.
    pass
except Exception as e:
    # Stay best-effort at startup, but do not silently hide a malformed file.
    print(f"Failed to load lora_dict.json: {e}")
|
|
|
|
|
|
private_lora_model_list = get_private_lora_model_lists()

# Metadata per LoRA key. "None" and "" are placeholder rows; entries from
# private_lora_dict are merged last, so they override those placeholders.
loras_dict = {
    "None": ["", "", "", "", ""],
    "": ["", "", "", "", ""],
    **private_lora_dict,
}

# Download URL -> local path of each LoRA fetched by download_lora().
loras_url_to_path_dict = {}

# Most recent Civitai search results, keyed by download URL.
civitai_lora_last_results = {}

# Copy of the LoRA list from the last get_all_lora_list() call.
all_lora_list = []
|
|
|
|
|
|
def get_all_lora_list():
    """Fetch the current LoRA list, store a copy in ``all_lora_list``, and return it."""
    global all_lora_list

    current_loras = get_lora_model_list()
    # Cache a separate copy so callers mutating the returned list don't touch it.
    all_lora_list = current_loras.copy()
    return current_loras
|
|
|
|
|
|
def get_all_lora_tupled_list():
    """Return ``(label, path)`` dropdown choices for every known LoRA.

    Metadata comes from ``loras_dict``. A LoRA whose key is not yet in the
    dict is looked up with ``get_civitai_info``, and the result is cached
    when it is not None.

    If the metadata has a non-empty third field, the label is
    ``"<basename> (for <field 1>, <field 2>)"``, with a 🐴 after "Pony".
    Otherwise the label is the file's basename.
    """
    models = get_all_lora_list()
    if not models:
        return []

    tupled_list = []
    for model in models:
        basename = Path(model).stem
        key = to_lora_key(model)
        if key in loras_dict:
            items = loras_dict.get(key)
        else:
            items = get_civitai_info(model)
            if items is not None:
                loras_dict[key] = items

        name = basename
        # Rows from lora_dict.json are user data. Fall back to the plain
        # basename instead of raising IndexError on a row that is too short.
        if items and len(items) > 2 and items[2] != "":
            base_model = f"{items[1]}🐴" if items[1] == "Pony" else items[1]
            name = f"{basename} (for {base_model}, {items[2]})"
        tupled_list.append((name, model))
    return tupled_list
|
|
|
|
|
|
def update_lora_dict(path: str):
    """Cache Civitai metadata for the LoRA at ``path`` in ``loras_dict``.

    Does nothing when the LoRA's key is already present, or when
    ``get_civitai_info`` returns None.
    """
    # The dict is only mutated, never rebound, so no ``global`` is needed.
    key = to_lora_key(path)
    if key in loras_dict:
        return
    items = get_civitai_info(path)
    if items is None:
        return
    loras_dict[key] = items
|
|
|
|
|
|
def download_lora(dl_urls: str):
    """Download LoRA files from comma-separated URLs into ``directory_loras``.

    URLs whose target file (the last URL path segment) already exists in
    ``directory_loras`` are skipped. Each file that appears because of a
    download is:

    - renamed to its escaped basename (``escape_lora_basename``),
    - recorded in ``loras_url_to_path_dict`` under the URL that produced it,
    - registered with ``update_lora_dict``.

    Returns:
        Path of the last renamed file, or ``""`` when no new file appeared.
    """
    dl_path = ""
    for url in (u.strip() for u in dl_urls.split(",")):
        if not url:
            continue
        local_path = f"{directory_loras}/{url.split('/')[-1]}"
        if Path(local_path).exists():
            continue

        # Diff the directory around each download, so every new file is
        # attributed to the URL that produced it. Matching one overall diff to
        # the URL list by position mislabels files and can raise IndexError.
        before = get_local_model_list(directory_loras)
        download_things(directory_loras, url, hf_token, CIVITAI_API_KEY)
        after = get_local_model_list(directory_loras)

        for file in list_sub(after, before):
            path = Path(file)
            if not path.exists():
                continue
            # NOTE(review): new_path is built relative to the CWD from the
            # parent directory's *name* only. This assumes directory_loras is a
            # direct child of the working directory; confirm.
            new_path = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
            path.resolve().rename(new_path.resolve())
            loras_url_to_path_dict[url] = str(new_path)
            update_lora_dict(str(new_path))
            dl_path = str(new_path)
    return dl_path
|
|
|
|
|
|
def copy_lora(path: str, new_path: str):
    """Copy the LoRA file at ``path`` to ``new_path`` and register the copy.

    Returns ``new_path`` on success, or when both paths are equal (nothing to
    do). Returns None if the source does not exist or the copy fails.
    """
    import shutil

    if path == new_path:
        return new_path

    src, dst = Path(path), Path(new_path)
    if not src.exists():
        return None

    try:
        shutil.copy(str(src.resolve()), str(dst.resolve()))
    except Exception:
        return None

    update_lora_dict(str(dst))
    return new_path
|
|
|
|
|
|
def download_my_lora(dl_urls: str, lora: str):
    """Download LoRA URL(s) and refresh the LoRA dropdown.

    The dropdown value becomes the path returned by ``download_lora``. If that
    path is empty, the current ``lora`` selection is kept.
    """
    downloaded_path = download_lora(dl_urls)
    selected = downloaded_path if downloaded_path else lora
    return gr.update(value=selected, choices=get_all_lora_tupled_list())
|
|
|
|
|
|
def apply_lora_prompt(lora_info: str):
    """Turn a LoRA's tag string into a normalized, de-duplicated prompt.

    ``lora_info`` is split on both "," and "/". The pieces go through
    ``normalize_prompt_list``, duplicates are removed, and the result is
    joined with ", ". A None or empty value, or the literal ``"None"``,
    yields ``""``.
    """
    if not lora_info or str(lora_info) == "None":
        return ""
    lora_tags = str(lora_info).replace("/", ",").split(",")
    lora_prompts = normalize_prompt_list(lora_tags)
    return ", ".join(list_uniq(lora_prompts))
|
|
|
|
|
|
def update_loras(prompt, lora, lora_wt):
    """Rebuild the prompt's LoRA tags after the LoRA selection changed.

    Each comma-separated segment of ``prompt`` is handled as follows:

    - A segment without ``<lora`` is kept, stripped.
    - A segment with ``<lora`` keeps only its ``<lora:key:weight>`` tags.
      A tag is kept only if its key is in ``loras_dict`` and
      ``to_lora_path(key)`` points to an existing file. Kept tags are
      re-emitted with the canonical key and a 2-decimal weight.

    If ``get_lora_info(lora)`` reports the selected LoRA as on, its tag with
    ``lora_wt`` is appended. Duplicates are then removed.

    Returns:
        Six ``gr.update`` objects:

        1. The new prompt.
        2. ``lora`` with refreshed choices.
        3. ``lora_wt``.
        4. ``tag`` and ``label``, visible when on.
        5. Visibility only (visible when on).
        6. ``md``, visible when on.
    """
    import re

    on, label, tag, md = get_lora_info(lora)
    lora_tag_re = re.compile(r'<lora:(.+?):(.+?)>')

    output_prompts = []
    for p in (prompt.split(",") if prompt else []):
        p = str(p).strip()
        if "<lora" not in p:
            if p:
                output_prompts.append(p)
            continue
        # Process every tag in the segment, not only the first, so e.g.
        # "<lora:a:1><lora:b:1>" keeps both LoRAs.
        for key, wt in lora_tag_re.findall(p):
            path = to_lora_path(key)
            if key not in loras_dict or not path:
                continue
            if Path(path).exists():
                output_prompts.append(f"<lora:{to_lora_key(path)}:{safe_float(wt):.2f}>")

    lora_prompts = []
    if on:
        lora_prompts.append(f"<lora:{to_lora_key(lora)}:{lora_wt:.2f}>")
    output_prompt = ", ".join(list_uniq(output_prompts + lora_prompts))
    choices = get_all_lora_tupled_list()
    return (
        gr.update(value=output_prompt),
        gr.update(value=lora, choices=choices),
        gr.update(value=lora_wt),
        gr.update(value=tag, label=label, visible=on),
        gr.update(visible=on),
        gr.update(value=md, visible=on),
    )
|
|
|
|
|
|
def search_civitai_lora(query, base_model):
    """Search Civitai for LoRAs and populate the search-result widgets.

    When the search yields results, ``civitai_lora_last_results`` is replaced
    with them (keyed by download URL) and the first result is selected. When
    it yields nothing, the result widgets are hidden and the previous results
    are left untouched.

    Returns:
        Four ``gr.update`` objects:

        1. The result choices.
        2. The selected result's ``md`` text.
        3. Visibility only (always visible).
        4. Visibility only (always visible).
    """
    global civitai_lora_last_results

    def _no_results():
        # Hide the result list and its info box; keep the other two visible.
        return gr.update(choices=[("", "")], value="", visible=False),\
            gr.update(value="", visible=False), gr.update(visible=True), gr.update(visible=True)

    items = search_lora_on_civitai(query, base_model)
    if not items:
        return _no_results()

    civitai_lora_last_results = {}
    choices = []
    for item in items:
        base_model_name = "Pony🐴" if item['base_model'] == "Pony" else item['base_model']
        name = f"{item['name']} (for {base_model_name} / By: {item['creator']} / Tags: {', '.join(item['tags'])})"
        value = item['dl_url']
        choices.append((name, value))
        civitai_lora_last_results[value] = item
    # ``items`` may be a truthy but empty iterable, so re-check.
    if not choices:
        return _no_results()

    # Look up with a None default: the old "None" string default would make
    # ``result['md']`` raise TypeError on a miss.
    result = civitai_lora_last_results.get(choices[0][1])
    md = result.get('md', "") if result else ""
    return gr.update(choices=choices, value=choices[0][1], visible=True), gr.update(value=md, visible=True),\
        gr.update(visible=True), gr.update(visible=True)
|
|
|
|
|
|
def select_civitai_lora(search_result):
    """Show the download URL and info for the chosen Civitai search result.

    A selection that is empty or not a URL (e.g. the "" placeholder) clears
    the URL box and shows "None". A URL missing from
    ``civitai_lora_last_results`` (e.g. left over from an earlier search)
    shows empty info instead of raising.

    Returns:
        Two ``gr.update`` objects: the URL box value and the info text.
    """
    if not search_result or "http" not in search_result:
        return gr.update(value=""), gr.update(value="None", visible=True)
    # A None default is required: with the string "None" as default,
    # ``result['md']`` raised TypeError for unknown selections.
    result = civitai_lora_last_results.get(search_result)
    md = result.get('md', "") if result else ""
    return gr.update(value=search_result), gr.update(value=md, visible=True)
|
|
|
|
|
|
def search_civitai_lora_json(query, base_model):
    """Return Civitai LoRA search results as a ``gr.update`` value.

    The value maps each result's ``dl_url`` to the full result item. It is an
    empty dict when the search finds nothing.
    """
    found = search_lora_on_civitai(query, base_model)
    by_url = {item['dl_url']: item for item in found} if found else {}
    return gr.update(value=by_url)
|
|
|
|
|