import json
from traceback import format_exc

import flask_sock
import hivemind
import torch

import config
from app import sock, models
from utils import safe_decode

logger = hivemind.get_logger(__file__)
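
# Handles a WebSocket connection speaking a simple JSON protocol: the client
# opens with {"type": "open_inference_session", "model": ..., "max_length": ...},
# then sends {"type": "generate", ...} requests; the server streams back
# {"ok": True, "outputs": <text chunk>, "stop": <bool>} messages until a stop
# sequence is produced or the connection is closed.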
@sock.route("/api/v2/generate")  # registration was missing (`sock` was imported but unused); route path is an assumption
def ws_api_generate(ws):
    try:
        request = json.loads(ws.receive(timeout=config.STEP_TIMEOUT))
        assert request["type"] == "open_inference_session"
        model_name = request.get("model")
        if model_name is None:
            model_name = config.DEFAULT_MODEL_NAME
        logger.info(f"ws.generate.open(), model={repr(model_name)}, max_length={repr(request['max_length'])}")

        model, tokenizer = models[model_name]

        with model.inference_session(max_length=request["max_length"]) as session:
            ws.send(json.dumps({"ok": True}))
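
            # Each iteration of this loop serves one "generate" request from the
            # client, reusing the same inference session so that attention caches
            # (and thus the conversation so far) carry over between steps.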
            while True:
                request = json.loads(ws.receive(timeout=config.STEP_TIMEOUT))
                assert request["type"] == "generate"
                inputs = request.get("inputs")
                logger.info(f"ws.generate.step(), inputs={repr(inputs)}")

                if inputs is not None:
                    inputs = tokenizer(inputs, return_tensors="pt")["input_ids"].to(config.DEVICE)
                    n_input_tokens = inputs.shape[1]
                else:
                    n_input_tokens = 0

                stop_sequence = request.get("stop_sequence")
                extra_stop_sequences = request.get("extra_stop_sequences")
                if extra_stop_sequences is not None:
                    # cont_token is the single-token id of stop_sequence; it is written
                    # back as session.last_token_id when an extra stop sequence fires,
                    # so generation can resume as if stop_sequence itself had been produced.
                    cont_token = tokenizer(stop_sequence, return_tensors="pt")["input_ids"].to(config.DEVICE)
                    assert cont_token.shape == (1, 1), \
                        "extra_stop_sequences require stop_sequence to be exactly 1 token"

                all_outputs = ''
                delta_q = []
                stop = False
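
                # Generate chunks until a stop sequence appears. all_outputs
                # accumulates the decoded text so far (for stop-sequence matching);
                # delta_q buffers raw token ids whose decoding is still incomplete,
                # i.e. ends in the U+FFFD replacement character.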
                while not stop:
                    outputs = model.generate(
                        inputs=inputs,
                        do_sample=request.get("do_sample", False),
                        temperature=request.get("temperature", 1.0),
                        top_k=request.get("top_k"),
                        top_p=request.get("top_p"),
                        max_length=request.get("max_length"),
                        max_new_tokens=request.get("max_new_tokens"),
                        session=session,
                    )
                    delta = outputs[0, n_input_tokens:].tolist()
                    # Decode the queued ids as a plain list of ints: wrapping them in
                    # torch.Tensor(...) would yield a float tensor and break decoding.
                    outputs = safe_decode(tokenizer, delta_q + delta)
                    inputs = None  # Inputs are passed only for the 1st token of the bot's response
                    n_input_tokens = 0
                    combined = all_outputs + outputs
                    stop = stop_sequence is None or combined.endswith(stop_sequence)
                    if extra_stop_sequences is not None:
                        for seq in extra_stop_sequences:
                            if combined.endswith(seq):
                                stop = True
                                session.last_token_id = cont_token
                    if not stop and outputs[-10:].find(u'\ufffd') > -1:
                        # If there's a replacement character, keep getting more tokens
                        # until we can decode properly
                        delta_q = delta_q + delta
                        logger.info(f"ws.generate.append_retry(), all_outputs={repr(combined)}")
                    else:
                        all_outputs = combined
                        delta_q = []
                        logger.info(f"ws.generate.step(), all_outputs={repr(all_outputs)}, stop={stop}")
                        ws.send(json.dumps({"ok": True, "outputs": outputs, "stop": stop}))
    except flask_sock.ConnectionClosed:
        pass
    except Exception:
        logger.warning("ws.generate failed:", exc_info=True)
        ws.send(json.dumps({"ok": False, "traceback": format_exc()}))
    finally:
        logger.info("ws.generate.close()")