Spaces:
Running
Running
File size: 4,343 Bytes
01e655b f83b1b7 02e90e4 374f426 01e655b 02e90e4 374f426 01e655b bf13828 01e655b 02e90e4 bf13828 02e90e4 da8d589 374f426 01e655b c59d697 bf13828 01e655b 374f426 02e90e4 01e655b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import os
import logging
# logging.basicConfig(
# level=os.getenv("LOG_LEVEL", "INFO"),
# format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
# )
from modules.devices import devices
from modules.utils import env
from modules.webui import webui_config
from modules.webui.app import webui_init, create_interface
from modules import generate_audio
from modules import config
if __name__ == "__main__":
    # Script entry point: parse CLI flags (with env-file / env-var fallbacks
    # resolved through modules.utils.env), record the effective settings in
    # config.runtime_env_vars, build the Gradio interface and launch it.
    import argparse

    import dotenv

    # Load environment overrides first so env.get_env_or_arg can see them.
    # ENV_FILE selects an alternative file; ".env.webui" is the default.
    dotenv.load_dotenv(
        dotenv_path=os.getenv("ENV_FILE", ".env.webui"),
    )

    parser = argparse.ArgumentParser(description="Gradio App")
    parser.add_argument("--server_name", type=str, help="server name")
    parser.add_argument("--server_port", type=int, help="server port")
    parser.add_argument(
        "--share", action="store_true", help="share the gradio interface"
    )
    parser.add_argument("--debug", action="store_true", help="enable debug mode")
    parser.add_argument("--auth", type=str, help="username:password for authentication")
    parser.add_argument(
        "--half",
        action="store_true",
        help="Enable half precision for model inference",
    )
    parser.add_argument(
        "--off_tqdm",
        action="store_true",
        help="Disable tqdm progress bar",
    )
    parser.add_argument(
        "--tts_max_len",
        type=int,
        help="Max length of text for TTS",
    )
    parser.add_argument(
        "--ssml_max_len",
        type=int,
        help="Max length of text for SSML",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        help="Max batch size for TTS",
    )
    # NOTE(review): unlike most flags here, --lru_size and --language carry an
    # argparse default; whether that prevents an env-var override depends on
    # env.get_env_or_arg (not visible here) — confirm intended precedence.
    parser.add_argument(
        "--lru_size",
        type=int,
        default=64,
        help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
    )
    parser.add_argument(
        "--device_id",
        type=str,
        help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)",
        default=None,
    )
    parser.add_argument(
        "--use_cpu",
        nargs="+",
        help="use CPU as torch device for specified modules",
        default=[],
        type=str.lower,
    )
    parser.add_argument("--compile", action="store_true", help="Enable model compile")
    # webui_Experimental
    parser.add_argument(
        "--webui_experimental",
        action="store_true",
        help="Enable webui_experimental features",
    )
    parser.add_argument(
        "--language",
        type=str,
        default="zh-CN",
        help="Set the default language for the webui",
    )

    args = parser.parse_args()

    def get_and_update_env(namespace, key, default, type_):
        """Resolve one setting and record it in ``config.runtime_env_vars``.

        The value is obtained from ``env.get_env_or_arg(namespace, key,
        default, type_)`` and stored under ``key`` so other modules can read
        the effective runtime configuration.

        Returns:
            The resolved value.
        """
        val = env.get_env_or_arg(namespace, key, default, type_)
        config.runtime_env_vars[key] = val
        return val

    server_name = get_and_update_env(args, "server_name", "0.0.0.0", str)
    server_port = get_and_update_env(args, "server_port", 7860, int)
    share = get_and_update_env(args, "share", False, bool)
    debug = get_and_update_env(args, "debug", False, bool)
    auth = get_and_update_env(args, "auth", None, str)

    # The following settings are not used directly in this script; resolving
    # them records the effective values in config.runtime_env_vars.
    half = get_and_update_env(args, "half", False, bool)
    off_tqdm = get_and_update_env(args, "off_tqdm", False, bool)
    lru_size = get_and_update_env(args, "lru_size", 64, int)
    device_id = get_and_update_env(args, "device_id", None, str)
    use_cpu = get_and_update_env(args, "use_cpu", [], list)
    # Named compile_model to avoid shadowing the builtin compile().
    compile_model = get_and_update_env(args, "compile", False, bool)
    # Language is a locale string (e.g. "zh-CN"); a bool cast would collapse
    # every non-empty value to True.
    language = get_and_update_env(args, "language", "zh-CN", str)

    webui_config.experimental = get_and_update_env(
        args, "webui_experimental", False, bool
    )
    webui_config.tts_max = get_and_update_env(args, "tts_max_len", 1000, int)
    webui_config.ssml_max = get_and_update_env(args, "ssml_max_len", 5000, int)
    webui_config.max_batch_size = get_and_update_env(args, "max_batch_size", 8, int)

    # Honour the resolved language instead of a hard-coded value.
    config.runtime_env_vars.language = language

    if auth:
        # Split on the first ":" only so the password may itself contain ":".
        # Validate before the expensive UI/model initialisation below.
        username, sep, password = auth.partition(":")
        if not sep:
            parser.error("--auth must be in the form username:password")
        auth = (username, password)

    webui_init()
    demo = create_interface()

    generate_audio.setup_lru_cache()
    devices.reset_device()
    devices.first_time_calculation()

    demo.queue().launch(
        server_name=server_name,
        server_port=server_port,
        share=share,
        debug=debug,
        auth=auth,
        show_api=False,
    )
|