""" | |
A model worker that executes the model based on vLLM. | |
See documentations at docs/vllm_integration.md | |
""" | |
import argparse | |
import asyncio | |
import json | |
from typing import List | |
from fastapi import FastAPI, Request, BackgroundTasks | |
from fastapi.responses import StreamingResponse, JSONResponse | |
import torch | |
import uvicorn | |
from vllm import AsyncLLMEngine | |
from vllm.engine.arg_utils import AsyncEngineArgs | |
from vllm.sampling_params import SamplingParams | |
from vllm.utils import random_uuid | |
from fastchat.serve.model_worker import ( | |
BaseModelWorker, | |
logger, | |
worker_id, | |
global_counter, | |
model_semaphore, | |
) | |
from fastchat.utils import get_context_length | |


class VLLMWorker(BaseModelWorker):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        no_register: bool,
        llm_engine: AsyncLLMEngine,
    ):
        super().__init__(
            controller_addr, worker_addr, worker_id, model_path, model_names
        )

        logger.info(
            f"Loading the model {self.model_names} on worker {worker_id}, worker type: vLLM worker..."
        )
        self.tokenizer = llm_engine.engine.tokenizer
        self.context_len = get_context_length(llm_engine.engine.model_config.hf_config)

        if not no_register:
            self.init_heart_beat()

    async def generate_stream(self, params):
        context = params.pop("prompt")
        request_id = params.pop("request_id")
        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        max_new_tokens = params.get("max_new_tokens", 256)
        stop_str = params.get("stop", None)
        echo = params.get("echo", True)

        # Handle stop_str
        if stop_str is None:
            stop_str = []

        # TODO(Hao): handle stop token IDs
        # stop_token_ids = params.get("stop_token_ids", None) or []

        # Make sampling params for vLLM
        top_p = max(top_p, 1e-5)
        if temperature <= 1e-5:
            top_p = 1.0
        sampling_params = SamplingParams(
            n=1,
            temperature=temperature,
            top_p=top_p,
            use_beam_search=False,
            stop=stop_str,
            max_tokens=max_new_tokens,
        )
        # Uses the module-level `engine` created in __main__.
        results_generator = engine.generate(context, sampling_params, request_id)

        async for request_output in results_generator:
            prompt = request_output.prompt
            if echo:
                text_outputs = [
                    prompt + output.text for output in request_output.outputs
                ]
            else:
                text_outputs = [output.text for output in request_output.outputs]
            text_outputs = " ".join(text_outputs)
            text_outputs = text_outputs.replace("</s>", "")
            # Note: usage is not supported yet
            ret = {"text": text_outputs, "error_code": 0, "usage": {}}
            yield (json.dumps(ret) + "\0").encode()

    async def generate(self, params):
        # Drain the stream and return the final (cumulative) output.
        async for x in self.generate_stream(params):
            pass
        return json.loads(x[:-1].decode())
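

# generate_stream() emits one JSON object per chunk, each terminated by a NUL
# byte ("\0"). A minimal client sketch, for illustration only (not part of this
# module): it assumes the worker is reachable at localhost:21002 and that the
# streaming route is mounted at /worker_generate_stream as below.
#
#   import json
#   import requests
#   resp = requests.post(
#       "http://localhost:21002/worker_generate_stream",
#       json={"prompt": "Hello", "temperature": 0.7, "max_new_tokens": 64},
#       stream=True,
#   )
#   for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
#       if chunk:
#           # Each chunk carries the cumulative text generated so far.
#           print(json.loads(chunk.decode())["text"])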


app = FastAPI()


def release_model_semaphore():
    model_semaphore.release()


def acquire_model_semaphore():
    # Lazily create the semaphore that caps concurrent generations at
    # --limit-model-concurrency.
    global model_semaphore, global_counter
    global_counter += 1
    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
    return model_semaphore.acquire()


def create_background_tasks(request_id):
    # Release the semaphore and abort the vLLM request once the streaming
    # response has finished (or the client disconnected).
    async def abort_request() -> None:
        await engine.abort(request_id)

    background_tasks = BackgroundTasks()
    background_tasks.add_task(release_model_semaphore)
    background_tasks.add_task(abort_request)
    return background_tasks


@app.post("/worker_generate_stream")
async def api_generate_stream(request: Request):
    params = await request.json()
    await acquire_model_semaphore()
    request_id = random_uuid()
    params["request_id"] = request_id
    generator = worker.generate_stream(params)
    background_tasks = create_background_tasks(request_id)
    return StreamingResponse(generator, background=background_tasks)


@app.post("/worker_generate")
async def api_generate(request: Request):
    params = await request.json()
    await acquire_model_semaphore()
    request_id = random_uuid()
    params["request_id"] = request_id
    output = await worker.generate(params)
    release_model_semaphore()
    await engine.abort(request_id)
    return JSONResponse(output)


@app.post("/worker_get_status")
async def api_get_status(request: Request):
    return worker.get_status()


@app.post("/count_token")
async def api_count_token(request: Request):
    params = await request.json()
    return worker.count_token(params)


@app.post("/worker_get_conv_template")
async def api_get_conv(request: Request):
    return worker.get_conv_template()


@app.post("/model_details")
async def api_model_details(request: Request):
    return {"context_length": worker.context_len}
if __name__ == "__main__": | |
parser = argparse.ArgumentParser() | |
parser.add_argument("--host", type=str, default="localhost") | |
parser.add_argument("--port", type=int, default=21002) | |
parser.add_argument("--worker-address", type=str, default="http://localhost:21002") | |
parser.add_argument( | |
"--controller-address", type=str, default="http://localhost:21001" | |
) | |
parser.add_argument("--model-path", type=str, default="lmsys/vicuna-7b-v1.3") | |
parser.add_argument( | |
"--model-names", | |
type=lambda s: s.split(","), | |
help="Optional display comma separated names", | |
) | |
parser.add_argument("--limit-model-concurrency", type=int, default=1024) | |
parser.add_argument("--no-register", action="store_true") | |
parser.add_argument("--num-gpus", type=int, default=1) | |
parser = AsyncEngineArgs.add_cli_args(parser) | |
args = parser.parse_args() | |
if args.model_path: | |
args.model = args.model_path | |
if args.num_gpus > 1: | |
args.tensor_parallel_size = args.num_gpus | |
engine_args = AsyncEngineArgs.from_cli_args(args) | |
engine = AsyncLLMEngine.from_engine_args(engine_args) | |
worker = VLLMWorker( | |
args.controller_address, | |
args.worker_address, | |
worker_id, | |
args.model_path, | |
args.model_names, | |
args.no_register, | |
engine, | |
) | |
uvicorn.run(app, host=args.host, port=args.port, log_level="info") | |
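
# Example launch (a sketch, not exercised here): assumes a FastChat controller
# is already running at the default --controller-address and vLLM is installed.
#
#   python3 -m fastchat.serve.vllm_worker --model-path lmsys/vicuna-7b-v1.3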