Spaces:
Paused
Paused
# Copied from https://github.com/abetlen/llama-cpp-python/blob/main/llama_cpp/server/__main__.py
# pyright: reportAssignmentType=false | |
from __future__ import annotations | |
import os | |
import sys | |
import argparse | |
import uvicorn | |
from llama_cpp.server.app import create_app | |
from llama_cpp.server.settings import ( | |
Settings, | |
ServerSettings, | |
ModelSettings, | |
ConfigFileSettings, | |
) | |
from llama_cpp.server.cli import add_args_from_model, parse_model_from_args | |
def main():
    """Build the llama.cpp server app from CLI/config settings and serve it.

    Settings come from one of two places:

    * a JSON config file, whose path is taken from the ``CONFIG_FILE``
      environment variable if set, otherwise from ``--config_file``;
    * the command-line arguments generated from ``Settings``.

    Any exception raised while loading the settings is printed to stderr,
    followed by the argparse help text, and the process exits with status 1.

    When serving, the ``HOST`` and ``PORT`` environment variables take
    precedence over the host and port found in the server settings.
    """
    cli = argparse.ArgumentParser(
        description="🦙 Llama.cpp python server. Host your own LLMs!🚀"
    )
    add_args_from_model(cli, Settings)
    cli.add_argument(
        "--config_file",
        type=str,
        help="Path to a config file to load.",
    )

    server_settings: ServerSettings | None = None
    model_settings: list[ModelSettings] = []
    args = cli.parse_args()

    try:
        # The environment variable wins over the command-line flag.
        config_path = os.environ.get("CONFIG_FILE", args.config_file)
        if not config_path:
            # No config file: everything comes from the parsed CLI arguments.
            server_settings = parse_model_from_args(ServerSettings, args)
            model_settings = [parse_model_from_args(ModelSettings, args)]
        else:
            if not os.path.exists(config_path):
                raise ValueError(f"Config file {config_path} not found!")
            with open(config_path, "rb") as handle:
                raw_config = handle.read()
            loaded = ConfigFileSettings.model_validate_json(raw_config)
            server_settings = ServerSettings.model_validate(loaded)
            model_settings = loaded.models
    except Exception as e:
        # Top-level boundary: report the problem, show usage, and bail out.
        print(e, file=sys.stderr)
        cli.print_help()
        sys.exit(1)

    assert server_settings is not None
    assert model_settings is not None

    app = create_app(
        server_settings=server_settings,
        model_settings=model_settings,
    )
    # Remove X-Request-Id plugin: https://github.com/abetlen/llama-cpp-python/issues/1337
    # NOTE(review): this drops whatever middleware is last in the list; it
    # presumably relies on the registration order inside create_app - confirm.
    app.user_middleware.pop()

    # Environment variables override the configured bind address and port.
    bind_host = os.getenv("HOST", server_settings.host)
    bind_port = int(os.getenv("PORT", server_settings.port))
    uvicorn.run(
        app,
        host=bind_host,
        port=bind_port,
        ssl_keyfile=server_settings.ssl_keyfile,
        ssl_certfile=server_settings.ssl_certfile,
    )
# Script entry point: start the server only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()