# Source: Hugging Face Space by JohnTan38, commit 630f532
# ("Duplicate from chansung/Alpaca-LoRA-Serve").
from strings import TITLE, ABSTRACT, BOTTOM_LINE
from strings import DEFAULT_EXAMPLES
from strings import SPECIAL_STRS
from styles import PARENT_BLOCK_CSS
import time
import gradio as gr
from model import load_model
from gen import get_output_batch, StreamModel
from utils import generate_prompt, post_processes_batch, post_process_stream, get_generation_config, common_post_process
# Base checkpoint and LoRA fine-tune handed to load_model (presumably Hugging
# Face Hub repo ids -- confirm against model.load_model).
_MODEL_SOURCES = {
    "base": "decapoda-research/llama-13b-hf",
    "finetuned": "chansung/alpaca-lora-13b",
}
model, tokenizer = load_model(**_MODEL_SOURCES)

# Replace the raw model with the StreamModel wrapper; chat_stream calls this
# wrapper and iterates over what it returns.
model = StreamModel(model, tokenizer)
STOP_MARKER = "### Instruction:"


def iter_visible_text(stream, marker=STOP_MARKER):
    """Filter a streamed response so a trailing prompt marker is never shown.

    Assumes every item of ``stream`` is the full text generated so far, with
    each item extending the previous one. Each item is whitespace-stripped.
    Once a ``"#"`` appears in the newly generated part, output is withheld
    until at least ``len(marker)`` characters from that ``"#"`` onward exist:

    * if ``marker`` then occurs in the text, the text before it is yielded
      (stripped) with ``hit_marker=True`` and iteration stops;
    * otherwise the ``"#"`` was ordinary content and streaming resumes.

    If the stream ends while output is withheld, the withheld text is still
    yielded. The held tail is dropped only if it is the beginning of
    ``marker``.

    Args:
        stream: iterable of cumulative response strings.
        marker: text that must not appear in the visible output.

    Yields:
        ``(text, hit_marker)`` tuples.
    """
    cutoff = None  # index of the '#' that may start the marker, if holding
    seen = 0  # length of the text already scanned for '#'
    text = None
    for raw in stream:
        text = raw.strip()
        if cutoff is None:
            # Only scan the newly generated part. An earlier '#' (e.g. "C#")
            # was already shown and must not decide where the hold starts.
            pos = text.find("#", seen)
            if pos != -1:
                cutoff = pos
        seen = len(text)
        if cutoff is not None:
            if len(text) - cutoff < len(marker):
                continue  # too short to tell yet; keep holding back
            marker_pos = text.find(marker)
            if marker_pos != -1:
                yield text[:marker_pos].strip(), True
                return
            cutoff = None  # the '#' was ordinary content
        yield text, False
    if cutoff is not None:
        # Stream ended mid-hold: flush what was withheld, unless it can only
        # be the start of the marker.
        held = text[cutoff:]
        yield (text[:cutoff].strip() if marker.startswith(held) else text), False


def chat_stream(
    context,
    instruction,
    state_chatbot,
):
    """Stream the model's reply to ``instruction`` into the chat history.

    Generator used as a Gradio event handler. Each step yields
    ``(state_chatbot, state_chatbot, context_value)`` for the outputs
    ``[state, chatbot, context textbox]``.

    Args:
        context: text from the context textbox, passed to ``generate_prompt``.
        instruction: the user's message, or ``SPECIAL_STRS["continue"]`` /
            ``SPECIAL_STRS["summarize"]`` sent by the helper buttons.
        state_chatbot: list of ``(user, bot)`` tuples so far. The list is
            not mutated; a new list is built.

    The final yield behaves as follows:
    - For a "summarize" request, the context textbox is replaced with the
      generated text (stop marker removed).
    - For any other request, the context is left unchanged.
    - If nothing was generated, the context is also left unchanged.
    """
    # User input is formatted for display here (despite the helper's name).
    instruction_display = common_post_process(instruction)
    instruction_prompt = generate_prompt(instruction, state_chatbot, context)
    bot_response = model(
        instruction_prompt,
        max_tokens=128,
        temperature=1,
        top_p=0.9,
    )

    # A "continue" request has no user message of its own.
    if instruction_display == SPECIAL_STRS["continue"]:
        instruction_display = None
    state_chatbot = state_chatbot + [(instruction_display, None)]

    last_text = None  # stays None if the model streams nothing
    for text, hit_marker in iter_visible_text(bot_response):
        last_text = text
        processed_response, to_exit = post_process_stream(text)
        state_chatbot[-1] = (instruction_display, processed_response)
        yield (state_chatbot, state_chatbot, context)
        if hit_marker or to_exit:
            break

    if instruction_display == SPECIAL_STRS["summarize"] and last_text is not None:
        new_context = gr.Textbox.update(value=last_text)
    else:
        new_context = context
    yield (state_chatbot, state_chatbot, new_context)
def chat_batch(
    contexts,
    instructions,
    state_chatbots,
    generation_config=None,
):
    """Answer several chat turns with one batched, non-streaming generation.

    Inputs are paired with ``zip``, so items beyond the shortest list are
    ignored.

    Args:
        contexts: per-conversation context strings.
        instructions: per-conversation user messages, or
            ``SPECIAL_STRS["continue"]`` / ``SPECIAL_STRS["summarize"]``.
        state_chatbots: per-conversation lists of ``(user, bot)`` tuples.
            These lists are not mutated.
        generation_config: passed through to ``get_output_batch``. When
            omitted, a module-level ``generation_config`` is used if one is
            defined.

    Returns:
        ``(state_results, state_results, ctx_results)``, where:
        - ``state_results`` holds each updated history.
        - ``ctx_results`` holds either a Gradio update replacing the context
          with the reply (for "summarize") or the unchanged context.

    Raises:
        ValueError: if no generation config is available.
    """
    if generation_config is None:
        # This name was referenced but never bound in this file; honour a
        # module-level definition if one exists.
        generation_config = globals().get("generation_config")
    if generation_config is None:
        raise ValueError(
            "chat_batch needs a generation config: pass generation_config= "
            "or define a module-level `generation_config`."
        )

    instruct_prompts = [
        generate_prompt(instruct, histories, ctx)
        for ctx, instruct, histories in zip(contexts, instructions, state_chatbots)
    ]
    if not instruct_prompts:
        # Nothing to generate; avoid calling the model with an empty batch.
        return ([], [], [])

    # NOTE(review): `model` is rebound to the StreamModel wrapper at module
    # level; confirm get_output_batch expects that rather than the raw model.
    bot_responses = get_output_batch(
        model, tokenizer, instruct_prompts, generation_config
    )
    bot_responses = post_processes_batch(bot_responses)

    state_results = []
    ctx_results = []
    for ctx, instruction, bot_response, state_chatbot in zip(
        contexts, instructions, bot_responses, state_chatbots
    ):
        # Match chat_stream: a "continue" turn has no user message (None).
        user_turn = None if instruction == SPECIAL_STRS["continue"] else instruction
        state_results.append(state_chatbot + [(user_turn, bot_response)])
        if instruction == SPECIAL_STRS["summarize"]:
            ctx_results.append(gr.Textbox.update(value=bot_response))
        else:
            ctx_results.append(ctx)
    return (state_results, state_results, ctx_results)
def reset_textbox():
    """Build the Gradio update that empties a textbox.

    Used as a second click handler so the instruction box is cleared after a
    message or helper command is sent.
    """
    empty_text = ''
    return gr.Textbox.update(value=empty_text)
with gr.Blocks(css=PARENT_BLOCK_CSS) as demo:
    # Chat history as a list of (user, bot) pairs, read and written by chat_stream.
    state_chatbot = gr.State([])

    # Components render in creation order, so the order below is the page layout.
    with gr.Column(elem_id='col_container'):
        gr.Markdown(f"## {TITLE}\n\n\n{ABSTRACT}")

        with gr.Accordion("Context Setting", open=False):
            context_txtbox = gr.Textbox(placeholder="Surrounding information to AI", label="Enter Context")
            # Invisible; in this file it only serves as the first input column of the examples.
            hidden_txtbox = gr.Textbox(placeholder="", label="Order", visible=False)

        chatbot = gr.Chatbot(elem_id='chatbot', label="Alpaca-LoRA")
        instruction_txtbox = gr.Textbox(placeholder="What do you want to say to AI?", label="Instruction")
        send_prompt_btn = gr.Button(value="Send Prompt")

        with gr.Accordion("Helper Buttons", open=False):
            gr.Markdown("`Continue` lets AI to complete the previous incomplete answers. `Summarize` lets AI to summarize the conversations so far.")
            # Hidden textboxes carry the fixed command strings the helper buttons send.
            continue_txtbox = gr.Textbox(value=SPECIAL_STRS["continue"], visible=False)
            summrize_txtbox = gr.Textbox(value=SPECIAL_STRS["summarize"], visible=False)
            continue_btn = gr.Button(value="Continue")
            summarize_btn = gr.Button(value="Summarize")

        gr.Markdown("#### Examples")
        for example_group in DEFAULT_EXAMPLES:
            with gr.Accordion(example_group["title"], open=False):
                gr.Examples(
                    examples=example_group["examples"],
                    inputs=[hidden_txtbox, instruction_txtbox],
                    label=None,
                )

        gr.Markdown(f"{BOTTOM_LINE}")

    # Each button streams a reply through chat_stream using its own instruction
    # source, then clears the instruction box. Registration order matches the
    # original: chat handler first, then reset, for each button in turn.
    for trigger_btn, instruction_source in (
        (send_prompt_btn, instruction_txtbox),
        (continue_btn, continue_txtbox),
        (summarize_btn, summrize_txtbox),
    ):
        trigger_btn.click(
            chat_stream,
            [context_txtbox, instruction_source, state_chatbot],
            [state_chatbot, chatbot, context_txtbox],
        )
        trigger_btn.click(
            reset_textbox,
            [],
            [instruction_txtbox],
        )
# Enable the request queue:
# - concurrency_count=1: one worker thread processes queued events.
# - max_size=100: at most 100 events wait in the queue.
# Blocks.queue() configures `demo` in place (it returns the same Blocks).
demo.queue(concurrency_count=1, max_size=100)

# Serve on all network interfaces, with at most 5 threads for the app.
demo.launch(server_name="0.0.0.0", max_threads=5)