Spaces:
Running
on
Zero
Running
on
Zero
import asyncio | |
import os | |
import json | |
import shutil | |
import inspect | |
import aiohttp | |
from server import PromptServer | |
from tqdm import tqdm | |
config = None | |
def is_logging_enabled():
    """Return the ``"logging"`` entry of the extension config.

    Returns ``False`` when the config has no ``"logging"`` key; otherwise the
    stored value is returned unchanged.
    """
    return get_extension_config().get("logging", False)
def log(message, type=None, always=False, name=None):
    """Print a message prefixed with the extension tag.

    Args:
        message: Text to print.
        type: Optional label rendered as ``[label]`` in front of the message.
        always: When true, print even if logging is disabled in the config.
        name: Name shown in the ``(pysssss:<name>)`` prefix; when ``None`` the
            ``"name"`` entry of the extension config is used.
    """
    if not (always or is_logging_enabled()):
        return

    label = "" if type is None else f"[{type}] "
    if name is None:
        name = get_extension_config()["name"]

    print(f"(pysssss:{name}) {label}{message}")
def get_ext_dir(subpath=None, mkdir=False):
    """Return an absolute path inside the directory containing this module.

    Args:
        subpath: Optional path joined onto the module's directory.
        mkdir: When true, create the resulting path as a directory if nothing
            exists there yet.

    Returns:
        The absolute path as a string.
    """
    base = os.path.dirname(__file__)
    target = base if subpath is None else os.path.join(base, subpath)
    target = os.path.abspath(target)

    if mkdir and not os.path.exists(target):
        os.makedirs(target)
    return target
def get_comfy_dir(subpath=None, mkdir=False):
    """Return an absolute path inside the directory that defines ``PromptServer``.

    The base directory is the folder of the source file in which
    ``PromptServer`` is defined (located via ``inspect.getfile``).

    Args:
        subpath: Optional path joined onto that directory.
        mkdir: When true, create the resulting path as a directory if nothing
            exists there yet.

    Returns:
        The absolute path as a string.
    """
    base = os.path.dirname(inspect.getfile(PromptServer))
    target = base if subpath is None else os.path.join(base, subpath)
    target = os.path.abspath(target)

    if mkdir and not os.path.exists(target):
        os.makedirs(target)
    return target
def get_web_ext_dir():
    """Return the path of this extension's folder under ``web/extensions/pysssss``.

    The parent ``web/extensions/pysssss`` directory is created when it does
    not exist; the returned extension folder itself is not created here.
    """
    ext_name = get_extension_config()["name"]
    parent = get_comfy_dir("web/extensions/pysssss")
    if not os.path.exists(parent):
        os.makedirs(parent)
    return os.path.join(parent, ext_name)
def get_extension_config(reload=False):
    """Load the extension configuration from ``pysssss.json``.

    The parsed config is cached in the module-level ``config`` variable and
    reused on later calls unless ``reload`` is true. When ``pysssss.json`` is
    missing it is created by copying ``pysssss.default.json``.

    Args:
        reload: When true, re-read the file even if a cached config exists.

    Returns:
        The parsed config dict. If the config file cannot be provided (the
        default file is missing, or the copy did not produce a file), an error
        is printed and a placeholder ``{"name": "Unknown", "version": -1}`` is
        returned; the placeholder is not cached.
    """
    global config
    if not reload and config is not None:
        return config

    config_path = get_ext_dir("pysssss.json")
    if not os.path.exists(config_path):
        default_config_path = get_ext_dir("pysssss.default.json")
        if not os.path.exists(default_config_path):
            # name="???" is passed explicitly so that log() does not call back
            # into this function while the config is unavailable.
            log("Missing pysssss.default.json, this extension may not work correctly. Please reinstall the extension.",
                type="ERROR", always=True, name="???")
            print(f"Extension path: {get_ext_dir()}")
            return {"name": "Unknown", "version": -1}

        shutil.copy(default_config_path, config_path)
        if not os.path.exists(config_path):
            log(f"Failed to create config at {config_path}", type="ERROR", always=True, name="???")
            print(f"Extension path: {get_ext_dir()}")
            return {"name": "Unknown", "version": -1}

    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    return config
def link_js(src, dst):
    """Link ``dst`` to the directory ``src``.

    On Windows a directory junction is attempted first; if that fails (or on
    any other platform) a symbolic link is created instead.

    Args:
        src: Existing directory to link to.
        dst: Path at which the link is created.

    Returns:
        ``True`` when a junction or symlink was created, ``False`` when the
        symlink attempt failed (the failure is logged with its traceback).
    """
    src = os.path.abspath(src)
    dst = os.path.abspath(dst)

    if os.name == "nt":
        try:
            import _winapi
            _winapi.CreateJunction(src, dst)
            return True
        except Exception:
            # Best effort only: fall through to a regular symlink. Catching
            # Exception (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate.
            pass

    try:
        os.symlink(src, dst)
        return True
    except Exception:
        import logging
        logging.exception('')
        return False
def is_junction(path):
    """Report whether ``path`` is a link readable via ``os.readlink`` on Windows.

    Always returns ``False`` when ``os.name`` is not ``"nt"``, and also when
    ``os.readlink`` raises ``OSError`` for the path.
    """
    if os.name == "nt":
        try:
            link_target = os.readlink(path)
        except OSError:
            return False
        return bool(link_target)
    return False
def install_js():
    """Make the extension's ``web/js`` folder available to the web frontend.

    When ``should_install_js()`` is false, any previously created link or
    copy in the web extensions folder is removed and nothing is installed.
    Otherwise the folder is linked via ``link_js``; if linking fails the
    files are copied instead. An already existing link or folder at the
    destination is left in place.
    """
    js_source = get_ext_dir("web/js")
    if not os.path.exists(js_source):
        log("No JS")
        return

    needs_install = should_install_js()
    if needs_install:
        log("it looks like you're running an old version of ComfyUI that requires manual setup of web files, it is recommended you update your installation.", "warning", True)

    js_target = get_web_ext_dir()
    is_link = os.path.islink(js_target) or is_junction(js_target)

    # Destination is an existing link/junction.
    if is_link:
        if needs_install:
            log("JS already linked")
        else:
            os.unlink(js_target)
            log("JS unlinked, PromptServer will serve extension")
        return

    # Destination is an existing regular folder (an earlier copy).
    if os.path.exists(js_target):
        if not needs_install:
            shutil.rmtree(js_target)
            log("JS deleted, PromptServer will serve extension")
        return

    # Nothing at the destination yet.
    if not needs_install:
        log("JS skipped, PromptServer will serve extension")
        return

    if link_js(js_source, js_target):
        log("JS linked")
        return

    log("Copying JS files")
    shutil.copytree(js_source, js_target, dirs_exist_ok=True)
def should_install_js():
    """Return whether the JS files must be installed manually.

    True when ``PromptServer.instance`` has no ``supports`` attribute or when
    ``"custom_nodes_from_web"`` is not listed in it.
    """
    server = PromptServer.instance
    if not hasattr(server, "supports"):
        return True
    return "custom_nodes_from_web" not in server.supports
def init(check_imports=None):
    """Initialise the extension.

    Args:
        check_imports: Optional iterable of module names that must be
            importable. The first one that cannot be found is reported and
            initialisation stops.

    Returns:
        ``False`` if a required module is missing, otherwise ``True`` after
        ``install_js()`` has run.
    """
    log("Init")

    if check_imports is not None:
        import importlib.util

        missing = next(
            (mod for mod in check_imports if importlib.util.find_spec(mod) is None),
            None,
        )
        if missing is not None:
            log(f"{missing} is required, please check requirements are installed.",
                type="ERROR", always=True)
            return False

    install_js()
    return True
def get_async_loop():
    """Return the current asyncio event loop, creating one if none is available.

    Returns:
        The loop from ``asyncio.get_event_loop()``; if that raises
        ``RuntimeError`` (no current loop for this thread), a new loop is
        created, installed as the current loop and returned.
    """
    try:
        return asyncio.get_event_loop()
    except RuntimeError:
        # Only the "no current event loop" failure is handled; a bare except
        # would also swallow KeyboardInterrupt and SystemExit.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop
def get_http_session():
    """Create a new ``aiohttp.ClientSession`` bound to the loop from ``get_async_loop()``.

    The caller is responsible for closing the returned session.
    """
    return aiohttp.ClientSession(loop=get_async_loop())
async def download(url, stream, update_callback=None, session=None):
    """Download ``url`` in chunks into ``stream`` while showing a tqdm progress bar.

    Args:
        url: Address to fetch; its last ``/``-separated part labels the bar.
        stream: Object with a ``write`` method that receives each chunk.
        update_callback: Optional coroutine function awaited with the
            completed fraction (rounded to two decimals) each time that
            fraction changes. Only used when the total size is known.
        session: Optional HTTP session to use. When omitted, a session is
            created for this call and closed afterwards.
    """
    owns_session = session is None
    if owns_session:
        session = get_http_session()

    try:
        async with session.get(url) as response:
            # A missing or zero content-length means the total is unknown.
            total_bytes = int(response.headers.get('content-length', 0)) or None
            bar = tqdm(
                unit='B', unit_scale=True, miniters=1, desc=url.split('/')[-1], total=total_bytes,
            )
            with bar as progressbar:
                reported = 0
                async for chunk in response.content.iter_chunked(2048):
                    stream.write(chunk)
                    progressbar.update(len(chunk))

                    if update_callback is None:
                        continue
                    if progressbar.total is None or progressbar.total == 0:
                        continue

                    fraction = round(progressbar.n / progressbar.total, 2)
                    if fraction != reported:
                        reported = fraction
                        await update_callback(fraction)
    finally:
        # Only close sessions that were created by this call.
        if owns_session and session is not None:
            await session.close()
async def download_to_file(url, destination, update_callback=None, is_ext_subpath=True, session=None):
    """Download ``url`` and write it to the file ``destination``.

    Args:
        url: Address to fetch.
        destination: Target file path. When ``is_ext_subpath`` is true it is
            resolved through ``get_ext_dir`` (relative to this module's
            directory).
        update_callback: Optional progress coroutine forwarded to ``download``.
        is_ext_subpath: Whether ``destination`` is resolved via ``get_ext_dir``.
        session: Optional HTTP session forwarded to ``download``.
    """
    if is_ext_subpath:
        destination = get_ext_dir(destination)
    with open(destination, mode='wb') as f:
        # The coroutine must be awaited: without ``await`` it is never run,
        # and the file is closed while still empty.
        await download(url, f, update_callback, session)
def wait_for_async(async_fn, loop=None):
    """Run ``async_fn()`` to completion on an event loop and return its result.

    Args:
        async_fn: Zero-argument callable returning an awaitable; it is called
            from inside the loop.
        loop: Event loop to run on. When ``None`` the current loop is used,
            and a new loop is created and installed if there is none.

    Returns:
        Whatever the awaited ``async_fn()`` returns.
    """
    async def run_async():
        return await async_fn()

    if loop is None:
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No current loop for this thread. Catching only RuntimeError
            # avoids swallowing KeyboardInterrupt/SystemExit as the bare
            # except did.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

    return loop.run_until_complete(run_async())
def update_node_status(client_id, node, text, progress=None):
    """Send a ``pysssss/update_status`` message through ``PromptServer.send_sync``.

    Args:
        client_id: Recipient client; when ``None`` the server's current
            ``client_id`` is used. Nothing is sent if that is ``None`` too.
        node: Value sent under the ``"node"`` key.
        text: Value sent under the ``"text"`` key.
        progress: Optional value sent under the ``"progress"`` key.
    """
    server = PromptServer.instance
    recipient = server.client_id if client_id is None else client_id
    if recipient is None:
        return

    payload = {
        "node": node,
        "progress": progress,
        "text": text
    }
    server.send_sync("pysssss/update_status", payload, recipient)
async def update_node_status_async(client_id, node, text, progress=None):
    """Send a ``pysssss/update_status`` message by awaiting ``PromptServer.send``.

    Args:
        client_id: Recipient client; when ``None`` the server's current
            ``client_id`` is used. Nothing is sent if that is ``None`` too.
        node: Value sent under the ``"node"`` key.
        text: Value sent under the ``"text"`` key.
        progress: Optional value sent under the ``"progress"`` key.
    """
    server = PromptServer.instance
    recipient = server.client_id if client_id is None else client_id
    if recipient is None:
        return

    payload = {
        "node": node,
        "progress": progress,
        "text": text
    }
    await server.send("pysssss/update_status", payload, recipient)
def get_config_value(key, default=None, throw=False):
    """Look up a dotted key (e.g. ``"a.b.c"``) in the extension config.

    Args:
        key: Dot-separated path of nested config keys.
        default: Value returned when the path cannot be resolved and
            ``throw`` is false.
        throw: When true, raise instead of returning ``default``.

    Returns:
        The value found at the path, or ``default``.

    Raises:
        KeyError: If the path cannot be resolved and ``throw`` is true.
    """
    node = get_extension_config()
    for part in key.split("."):
        # A non-dict intermediate value (number, bool, string, list) cannot
        # contain the next key; treat it as missing rather than letting
        # ``in`` / indexing raise TypeError or do a substring match.
        if isinstance(node, dict) and part in node:
            node = node[part]
            continue
        if throw:
            raise KeyError("Configuration key missing: " + key)
        return default
    return node
def is_inside_dir(root_dir, check_path):
    """Return whether ``check_path`` is ``root_dir`` itself or lies beneath it.

    Both paths are made absolute and normalised before comparing; a relative
    ``check_path`` is interpreted relative to ``root_dir``. Symbolic links are
    not resolved.

    Args:
        root_dir: Directory that must contain the path.
        check_path: Absolute path, or path relative to ``root_dir``.

    Returns:
        ``True`` if the normalised path is inside ``root_dir``, else ``False``.
    """
    root_dir = os.path.abspath(root_dir)
    # os.path.join keeps an absolute check_path unchanged; abspath then
    # collapses ".." segments. os.path.commonpath does not resolve "..", so
    # an un-normalised absolute path such as "<root>/../x" would otherwise
    # be accepted.
    check_path = os.path.abspath(os.path.join(root_dir, check_path))
    try:
        return os.path.commonpath([check_path, root_dir]) == root_dir
    except ValueError:
        # commonpath raises ValueError when the paths have nothing in common
        # to compare (e.g. different drives on Windows).
        return False
def get_child_dir(root_dir, child_path, throw_if_outside=True):
    """Resolve ``child_path`` against ``root_dir`` and ensure it stays inside it.

    Args:
        root_dir: Directory the result must be contained in.
        child_path: Path joined onto ``root_dir``.
        throw_if_outside: When true, raise if the resolved path escapes
            ``root_dir``; otherwise return ``None`` in that case.

    Returns:
        The absolute resolved path, or ``None`` when it is outside
        ``root_dir`` and ``throw_if_outside`` is false.

    Raises:
        NotADirectoryError: If the resolved path is outside ``root_dir`` and
            ``throw_if_outside`` is true.
    """
    resolved = os.path.abspath(os.path.join(root_dir, child_path))
    if not is_inside_dir(root_dir, resolved):
        if throw_if_outside:
            raise NotADirectoryError(
                "Saving outside the target folder is not allowed.")
        return None
    return resolved