# chattts / webui.py
# Snapshot: commit da8d589 ("update") by zhzluke96, 3.92 kB.
import os
from modules.devices import devices
from modules.utils import env
from modules.webui import webui_config
from modules.webui.app import webui_init, create_interface
from modules import generate_audio
from modules import config
def build_parser():
    """Create the command-line parser for the web UI.

    Valued options intentionally have no argparse ``default`` so an omitted
    flag parses as ``None``. The effective fallback value is applied later in
    ``main()`` through ``get_and_update_env``. ``store_true`` flags parse as
    ``False`` when omitted.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Gradio App")
    parser.add_argument("--server_name", type=str, help="server name")
    parser.add_argument("--server_port", type=int, help="server port")
    parser.add_argument(
        "--share", action="store_true", help="share the gradio interface"
    )
    parser.add_argument("--debug", action="store_true", help="enable debug mode")
    parser.add_argument(
        "--auth", type=str, help="username:password for authentication"
    )
    parser.add_argument(
        "--half",
        action="store_true",
        help="Enable half precision for model inference",
    )
    parser.add_argument(
        "--off_tqdm",
        action="store_true",
        help="Disable tqdm progress bar",
    )
    parser.add_argument(
        "--tts_max_len",
        type=int,
        help="Max length of text for TTS",
    )
    parser.add_argument(
        "--ssml_max_len",
        type=int,
        help="Max length of text for SSML",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        help="Max batch size for TTS",
    )
    # No argparse default here (previously 64). A non-None parsed value looks
    # like an explicit CLI value and may keep env.get_env_or_arg from using an
    # environment setting -- NOTE(review): confirm against modules.utils.env.
    # The 64 fallback is applied in main(), like the other valued options.
    parser.add_argument(
        "--lru_size",
        type=int,
        help="Set the size of the request cache pool, set it to 0 will disable lru_cache (default: 64)",
    )
    parser.add_argument(
        "--device_id",
        type=str,
        help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)",
        default=None,
    )
    parser.add_argument(
        "--use_cpu",
        nargs="+",
        help="use CPU as torch device for specified modules",
        default=[],
        type=str.lower,
    )
    parser.add_argument("--compile", action="store_true", help="Enable model compile")
    # webui_Experimental
    parser.add_argument(
        "--webui_experimental",
        action="store_true",
        help="Enable webui_experimental features",
    )
    return parser


def parse_auth(auth):
    """Turn a ``username:password`` string into the tuple Gradio expects.

    Only the first ``:`` separates the fields, so passwords may contain ``:``.

    Args:
        auth: the raw credential string, or a falsy value for "no auth".

    Returns:
        ``None`` when *auth* is falsy, otherwise ``(username, password)``.

    Raises:
        ValueError: if there is no ``:`` separator or the username is empty.
    """
    if not auth:
        return None
    username, sep, password = auth.partition(":")
    if not sep or not username:
        raise ValueError("--auth must be in the form username:password")
    return (username, password)


def main():
    """Resolve settings, prepare runtime state and launch the Gradio web UI."""
    import dotenv

    # Load the env file (ENV_FILE, default ".env.webui") before resolving
    # settings, so environment fallbacks can see its values.
    dotenv.load_dotenv(
        dotenv_path=os.getenv("ENV_FILE", ".env.webui"),
    )

    parser = build_parser()
    args = parser.parse_args()

    def get_and_update_env(cli_args, key, default, value_type):
        """Resolve setting *key* and record it in config.runtime_env_vars.

        Resolution is delegated to env.get_env_or_arg (presumably CLI value,
        then environment, then *default* -- see modules.utils.env).
        """
        val = env.get_env_or_arg(cli_args, key, default, value_type)
        config.runtime_env_vars[key] = val
        return val

    server_name = get_and_update_env(args, "server_name", "0.0.0.0", str)
    server_port = get_and_update_env(args, "server_port", 7860, int)
    share = get_and_update_env(args, "share", False, bool)
    debug = get_and_update_env(args, "debug", False, bool)
    auth = get_and_update_env(args, "auth", None, str)

    # Not used directly in this function. Resolving them stores them in
    # config.runtime_env_vars, which is presumably read by the setup calls
    # below (cache/device setup) -- verify in those modules.
    get_and_update_env(args, "half", False, bool)
    get_and_update_env(args, "off_tqdm", False, bool)
    get_and_update_env(args, "lru_size", 64, int)
    get_and_update_env(args, "device_id", None, str)
    get_and_update_env(args, "use_cpu", [], list)
    get_and_update_env(args, "compile", False, bool)

    webui_config.experimental = get_and_update_env(
        args, "webui_experimental", False, bool
    )
    webui_config.tts_max = get_and_update_env(args, "tts_max_len", 1000, int)
    webui_config.ssml_max = get_and_update_env(args, "ssml_max_len", 5000, int)
    webui_config.max_batch_size = get_and_update_env(args, "max_batch_size", 8, int)

    # Validate credentials before any heavy setup. parser.error() prints the
    # message and exits with status 2.
    try:
        auth = parse_auth(auth)
    except ValueError as err:
        parser.error(str(err))

    demo = create_interface()

    generate_audio.setup_lru_cache()
    devices.reset_device()
    devices.first_time_calculation()

    webui_init()
    demo.queue().launch(
        server_name=server_name,
        server_port=server_port,
        share=share,
        debug=debug,
        auth=auth,
        show_api=False,
    )


if __name__ == "__main__":
    main()