# Spaces: Runtime error
import os
import time

import gradio as gr

from strings import TITLE, ABSTRACT, BOTTOM_LINE
from strings import DEFAULT_EXAMPLES
from strings import SPECIAL_STRS
from styles import PARENT_BLOCK_CSS
from model import load_model
from gen import get_output_batch, StreamModel
from utils import generate_prompt, post_processes_batch, post_process_stream, get_generation_config, common_post_process
# Load the base model plus the fine-tuned adapter once at import time.
# Both repo ids can be overridden through the environment (e.g. to point at a
# mirror) without editing the code; the defaults are the ids this app was
# written against.
# NOTE(review): the default base repo may no longer be downloadable from the
# Hub — confirm, and set BASE_MODEL if it is gone.
model, tokenizer = load_model(
    base=os.environ.get("BASE_MODEL", "decapoda-research/llama-13b-hf"),
    finetuned=os.environ.get("FINETUNED_MODEL", "chansung/alpaca-lora-13b"),
)
# `model` is rebound here, so every later reference in this file sees the
# StreamModel wrapper rather than the object returned by load_model.
model = StreamModel(model, tokenizer)
def chat_stream( | |
context, | |
instruction, | |
state_chatbot, | |
): | |
# print(instruction) | |
# user input should be appropriately formatted (don't be confused by the function name) | |
instruction_display = common_post_process(instruction) | |
instruction_prompt = generate_prompt(instruction, state_chatbot, context) | |
bot_response = model( | |
instruction_prompt, | |
max_tokens=128, | |
temperature=1, | |
top_p=0.9 | |
) | |
instruction_display = None if instruction_display == SPECIAL_STRS["continue"] else instruction_display | |
state_chatbot = state_chatbot + [(instruction_display, None)] | |
prev_index = 0 | |
agg_tokens = "" | |
cutoff_idx = 0 | |
for tokens in bot_response: | |
tokens = tokens.strip() | |
cur_token = tokens[prev_index:] | |
if "#" in cur_token and agg_tokens == "": | |
cutoff_idx = tokens.find("#") | |
agg_tokens = tokens[cutoff_idx:] | |
if agg_tokens != "": | |
if len(agg_tokens) < len("### Instruction:") : | |
agg_tokens = agg_tokens + cur_token | |
elif len(agg_tokens) >= len("### Instruction:"): | |
if tokens.find("### Instruction:") > -1: | |
processed_response, _ = post_process_stream(tokens[:tokens.find("### Instruction:")].strip()) | |
state_chatbot[-1] = ( | |
instruction_display, | |
processed_response | |
) | |
yield (state_chatbot, state_chatbot, context) | |
break | |
else: | |
agg_tokens = "" | |
cutoff_idx = 0 | |
if agg_tokens == "": | |
processed_response, to_exit = post_process_stream(tokens) | |
state_chatbot[-1] = (instruction_display, processed_response) | |
yield (state_chatbot, state_chatbot, context) | |
if to_exit: | |
break | |
prev_index = len(tokens) | |
yield ( | |
state_chatbot, | |
state_chatbot, | |
gr.Textbox.update(value=tokens) if instruction_display == SPECIAL_STRS["summarize"] else context | |
) | |
def chat_batch(
    contexts,
    instructions,
    state_chatbots,
    generation_config=None,
):
    """Answer several independent conversations with one batched generation call.

    Args:
        contexts: per-conversation context strings, forwarded to ``generate_prompt``.
        instructions: per-conversation user messages; ``SPECIAL_STRS["continue"]``
            and ``SPECIAL_STRS["summarize"]`` are treated as helper commands.
        state_chatbots: per-conversation chat histories (lists of ``(user, bot)`` pairs).
        generation_config: configuration forwarded to ``get_output_batch``. When
            omitted, a module-level ``generation_config`` is used if one exists.

    Returns:
        ``(state_results, state_results, ctx_results)``: the updated histories
        (twice, one per history output) and, per conversation, either a textbox
        update carrying the summary or the unchanged context. Conversations are
        paired up with ``zip``, so the shortest input list wins; empty input
        yields three empty lists without calling the model.

    Raises:
        ValueError: if no generation config is passed and none is defined at
            module level.
    """
    if generation_config is None:
        # Fall back to a module-level config for callers that relied on one.
        # This file does not define it itself, so fail with a clear message
        # instead of a NameError deep inside the call.
        generation_config = globals().get("generation_config")
    if generation_config is None:
        raise ValueError(
            "chat_batch needs a generation_config: pass one explicitly or define "
            "a module-level `generation_config` (e.g. via get_generation_config)."
        )

    instruct_prompts = [
        generate_prompt(instruct, histories, ctx)
        for ctx, instruct, histories in zip(contexts, instructions, state_chatbots)
    ]
    if not instruct_prompts:
        return ([], [], [])

    # NOTE(review): `model` is the StreamModel wrapper created at load time;
    # confirm get_output_batch expects the wrapper rather than the raw model.
    bot_responses = get_output_batch(model, tokenizer, instruct_prompts, generation_config)
    bot_responses = post_processes_batch(bot_responses)

    state_results = []
    ctx_results = []
    for ctx, instruction, bot_response, state_chatbot in zip(
        contexts, instructions, bot_responses, state_chatbots
    ):
        # Like chat_stream, a "continue" turn shows no user bubble.
        user_display = None if instruction == SPECIAL_STRS["continue"] else instruction
        state_results.append(state_chatbot + [(user_display, bot_response)])
        ctx_results.append(
            gr.Textbox.update(value=bot_response)
            if instruction == SPECIAL_STRS["summarize"]
            else ctx
        )
    return (state_results, state_results, ctx_results)
def reset_textbox():
    """Build a Gradio update that empties the target textbox."""
    cleared = gr.Textbox.update(value='')
    return cleared
with gr.Blocks(css=PARENT_BLOCK_CSS) as demo:
    # Per-session chat history: list of (user message, bot reply) pairs.
    state_chatbot = gr.State([])

    with gr.Column(elem_id='col_container'):
        gr.Markdown(f"## {TITLE}\n\n\n{ABSTRACT}")

        with gr.Accordion("Context Setting", open=False):
            context_txtbox = gr.Textbox(placeholder="Surrounding information to AI", label="Enter Context")
            # Never shown; only used as the first input column of the examples below.
            hidden_txtbox = gr.Textbox(placeholder="", label="Order", visible=False)

        chatbot = gr.Chatbot(elem_id='chatbot', label="Alpaca-LoRA")
        instruction_txtbox = gr.Textbox(placeholder="What do you want to say to AI?", label="Instruction")
        send_prompt_btn = gr.Button(value="Send Prompt")

        with gr.Accordion("Helper Buttons", open=False):
            gr.Markdown("`Continue` lets AI to complete the previous incomplete answers. `Summarize` lets AI to summarize the conversations so far.")
            # Hidden boxes holding the command strings the helper buttons send
            # to chat_stream in place of user-typed text.
            continue_txtbox = gr.Textbox(value=SPECIAL_STRS["continue"], visible=False)
            summarize_txtbox = gr.Textbox(value=SPECIAL_STRS["summarize"], visible=False)
            continue_btn = gr.Button(value="Continue")
            summarize_btn = gr.Button(value="Summarize")

        gr.Markdown("#### Examples")
        for examples in DEFAULT_EXAMPLES:
            with gr.Accordion(examples["title"], open=False):
                gr.Examples(
                    examples=examples["examples"],
                    inputs=[hidden_txtbox, instruction_txtbox],
                    label=None,
                )

        gr.Markdown(f"{BOTTOM_LINE}")

    # Each button streams a reply for its text source (the user's input or a
    # hidden command string) and, via a second listener, clears the
    # instruction box. Registration order matches one button at a time:
    # chat_stream first, then reset_textbox.
    for trigger_btn, instruction_source in (
        (send_prompt_btn, instruction_txtbox),
        (continue_btn, continue_txtbox),
        (summarize_btn, summarize_txtbox),
    ):
        trigger_btn.click(
            chat_stream,
            [context_txtbox, instruction_source, state_chatbot],
            [state_chatbot, chatbot, context_txtbox],
        )
        trigger_btn.click(
            reset_textbox,
            [],
            [instruction_txtbox],
        )
# Route events through Gradio's queue: one worker processes events at a time
# (concurrency_count=1) and at most 100 events may wait. Then serve the app on
# all network interfaces.
demo.queue(max_size=100, concurrency_count=1)
demo.launch(server_name="0.0.0.0", max_threads=5)