# Copyright 2024 the LlamaFactory team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import random import subprocess import sys from enum import Enum, unique from . import launcher from .api.app import run_api from .chat.chat_model import run_chat from .eval.evaluator import run_eval from .extras.env import VERSION, print_env from .extras.logging import get_logger from .extras.misc import get_device_count from .train.tuner import export_model, run_exp from .webui.interface import run_web_demo, run_web_ui USAGE = ( "-" * 70 + "\n" + "| Usage: |\n" + "| llamafactory-cli api -h: launch an OpenAI-style API server |\n" + "| llamafactory-cli chat -h: launch a chat interface in CLI |\n" + "| llamafactory-cli eval -h: evaluate models |\n" + "| llamafactory-cli export -h: merge LoRA adapters and export model |\n" + "| llamafactory-cli train -h: train models |\n" + "| llamafactory-cli webchat -h: launch a chat interface in Web UI |\n" + "| llamafactory-cli webui: launch LlamaBoard |\n" + "| llamafactory-cli version: show version info |\n" + "-" * 70 ) WELCOME = ( "-" * 58 + "\n" + "| Welcome to LLaMA Factory, version {}".format(VERSION) + " " * (21 - len(VERSION)) + "|\n|" + " " * 56 + "|\n" + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n" + "-" * 58 ) logger = get_logger(__name__) @unique class Command(str, Enum): API = "api" CHAT = "chat" ENV = "env" EVAL = "eval" EXPORT = "export" TRAIN = "train" WEBDEMO = "webchat" WEBUI = "webui" VER = "version" HELP = "help" def main(): command = sys.argv.pop(1) if command == Command.API: run_api() elif command == Command.CHAT: run_chat() elif command == Command.ENV: print_env() elif command == Command.EVAL: run_eval() elif command == Command.EXPORT: export_model() elif command == Command.TRAIN: force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"] if force_torchrun or get_device_count() > 1: master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1") master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999))) logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port)) subprocess.run( ( "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} " "--master_addr {master_addr} --master_port {master_port} {file_name} {args}" ).format( nnodes=os.environ.get("NNODES", "1"), node_rank=os.environ.get("RANK", "0"), nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())), master_addr=master_addr, master_port=master_port, file_name=launcher.__file__, args=" ".join(sys.argv[1:]), ), shell=True, ) else: run_exp() elif command == Command.WEBDEMO: run_web_demo() elif command == Command.WEBUI: run_web_ui() elif command == Command.VER: print(WELCOME) elif command == Command.HELP: print(USAGE) else: raise NotImplementedError("Unknown command: {}".format(command))