|
|
|
|
|
|
|
import gradio as gr |
|
|
|
import torch |
|
from utils import * |
|
from presets import * |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Identifier of the checkpoint handed to the loader below.
base_model = "project-baize/baize-v2-7b"

# Loaded once at import time; `predict` reads these three module-level names.
(tokenizer, model, device) = load_tokenizer_and_model(base_model)
|
|
|
|
|
|
|
|
|
def predict(text,
            chatbotGr,
            history,
            top_p,
            temperature,
            max_length_tokens,
            max_context_length_tokens,):
    """Stream a model reply for ``text`` as a generator of UI updates.

    Each yielded item is a 3-tuple ``(chatbot_rows, history, status)``
    matching the Gradio outputs this function is wired to.

    Args:
        text: The user's message. An empty string short-circuits with the
            "Testo vuoto." status.
        chatbotGr: Current chatbot rows; echoed back unchanged on the
            early-exit paths.
        history: List of ``[user_text, reply_text]`` pairs from earlier turns.
        top_p: Forwarded to ``greedy_search`` as ``top_p``.
        temperature: Forwarded to ``greedy_search`` as ``temperature``.
        max_length_tokens: Forwarded to ``greedy_search`` as ``max_length``.
        max_context_length_tokens: Forwarded to
            ``generate_prompt_with_history`` as ``max_length`` and used to keep
            only the trailing columns of ``input_ids``.

    Yields:
        ``(chatbot_rows, history, status)`` tuples: intermediate
        "Generating..." updates, then either "Stop: Success" (interrupted) or
        "Generate: Success".
    """
    # Imported locally because no explicit `import gc` is visible at file top.
    import gc

    if text == "":
        yield chatbotGr, history, "Testo vuoto."
        return

    # Probe for the module-level `model`; only a missing name is expected here.
    try:
        model
    except NameError:
        yield [[text, "Nessun modello trovato"]], [], "Nessun modello trovato"
        return

    inputs = generate_prompt_with_history(
        text, history, tokenizer, max_length=max_context_length_tokens
    )
    if inputs is None:
        yield chatbotGr, history, "Input troppo lungo."
        return
    # The helper returns (prompt, encoded inputs); the prompt text is unused.
    _, inputs = inputs

    input_ids = inputs["input_ids"][:, -max_context_length_tokens:].to(device)
    torch.cuda.empty_cache()

    stop_words = ["[|Human|]", "[|AI|]"]

    # Fallback values so the stop/success updates below are always well
    # defined, even when no displayable partial output was produced yet.
    a, b = chatbotGr, history

    try:
        with torch.no_grad():
            for x in greedy_search(
                input_ids,
                model,
                tokenizer,
                stop_words=stop_words,
                max_length=max_length_tokens,
                temperature=temperature,
                top_p=top_p,
            ):
                if is_stop_word_or_prefix(x, stop_words) is False:
                    # Cut the partial text at the first role marker, if any.
                    if "[|Human|]" in x:
                        x = x[:x.index("[|Human|]")].strip()
                    if "[|AI|]" in x:
                        x = x[:x.index("[|AI|]")].strip()
                    x = x.strip()
                    a = [[y[0], convert_to_markdown(y[1])] for y in history]
                    a = a + [[text, convert_to_markdown(x)]]
                    b = history + [[text, x]]
                    yield a, b, "Generating..."
                # NOTE(review): checked once per generated step; the source's
                # indentation was lost, so this placement is a reconstruction.
                if shared_state.interrupted:
                    shared_state.recover()
                    yield a, b, "Stop: Success"
                    return
    finally:
        # Release the input tensor and cached GPU memory on every exit path
        # (normal completion, interruption, or the consumer closing us).
        del input_ids
        gc.collect()
        torch.cuda.empty_cache()

    yield a, b, "Generate: Success"
|
|
|
|
|
def reset_chat():
    """Reset the chat input by delegating to ``reset_textbox``.

    Whatever ``reset_textbox`` returns is discarded; this function itself
    always returns ``None``.
    """
    reset_textbox()
    return None
|
|
|
|
|
|
|
|
|
def translate():
    """Placeholder for the translation feature; returns a fixed notice."""
    placeholder = "Kommt noch!"
    return placeholder
|
|
|
|
|
def coding():
    """Placeholder for the code-generation feature; returns a fixed notice."""
    placeholder = "Kommt noch!"
    return placeholder
|
|
|
|
|
|
|
|
|
# Read the stylesheet shipped next to this script into `customCSS`.
with open("custom.css", mode="r", encoding="utf-8") as css_file:
    customCSS = css_file.read()
|
|
|
# Build the Gradio UI (three tabs) and wire the chat events.
# NOTE(review): indentation was missing in the reviewed source; the nesting
# below is a reconstruction -- confirm against the intended layout.
# NOTE(review): `customCSS` is read at module level but is not passed to
# gr.Blocks here -- confirm whether `css=customCSS` was intended.
with gr.Blocks(theme=small_and_beautiful_theme) as demo:
    # Per-session state: past turns and the question currently being answered.
    history = gr.State([])
    user_question = gr.State("")
    gr.Markdown("Scegli cosa vuoi provare")
    with gr.Tabs():
        with gr.TabItem("Chat"):
            with gr.Row():
                gr.HTML(title)
                status_display = gr.Markdown("Erfolg", elem_id="status_display")
            gr.Markdown(description_top)
            with gr.Row(scale=1).style(equal_height=True):
                # Left side: conversation, input box and action buttons.
                with gr.Column(scale=5):
                    with gr.Row(scale=1):
                        chatbotGr = gr.Chatbot(elem_id="LI_chatbot").style(height="100%")
                    with gr.Row(scale=1):
                        with gr.Column(scale=12):
                            user_input = gr.Textbox(
                                show_label=False, placeholder="Gib deinen Text / Frage ein."
                            ).style(container=False)
                        with gr.Column(min_width=100, scale=1):
                            submitBtn = gr.Button("Absenden")
                        with gr.Column(min_width=100, scale=1):
                            cancelBtn = gr.Button("Stoppen")
                    with gr.Row(scale=1):
                        emptyBtn = gr.Button(
                            "🧹 Neuer Chat",
                        )
                # Right side: generation parameters passed to `predict`.
                with gr.Column():
                    with gr.Column(min_width=50, scale=1):
                        with gr.Tab(label="Parameter zum Model"):
                            gr.Markdown("# Parameters")
                            top_p = gr.Slider(
                                minimum=0,
                                maximum=1.0,
                                value=0.95,
                                step=0.05,
                                interactive=True,
                                label="Top-p",
                            )
                            temperature = gr.Slider(
                                minimum=0.1,
                                maximum=2.0,
                                value=1,
                                step=0.1,
                                interactive=True,
                                label="Temperature",
                            )
                            max_length_tokens = gr.Slider(
                                minimum=0,
                                maximum=512,
                                value=512,
                                step=8,
                                interactive=True,
                                label="Max Generation Tokens",
                            )
                            max_context_length_tokens = gr.Slider(
                                minimum=0,
                                maximum=4096,
                                value=2048,
                                step=128,
                                interactive=True,
                                label="Max History Tokens",
                            )
            gr.Markdown(description)

        with gr.TabItem("Traduzioni"):
            with gr.Row():
                gr.Textbox(
                    show_label=False, placeholder="In costruzione ..."
                ).style(container=False)
        with gr.TabItem("Generazione di codice"):
            with gr.Row():
                gr.Textbox(
                    show_label=False, placeholder="In costruzione ..."
                ).style(container=False)

    # Keyword arguments shared by both ways of submitting a question.
    predict_args = dict(
        fn=predict,
        inputs=[
            user_question,
            chatbotGr,
            history,
            top_p,
            temperature,
            max_length_tokens,
            max_context_length_tokens,
        ],
        outputs=[chatbotGr, history, status_display],
        show_progress=True,
    )

    reset_args = dict(
        fn=reset_textbox, inputs=[], outputs=[user_input, status_display]
    )

    # First step of a submission: move the typed text into `user_question`.
    transfer_input_args = dict(
        fn=transfer_input,
        inputs=[user_input],
        outputs=[user_question, user_input, submitBtn],
        show_progress=True,
    )

    # Enter key and the send button trigger the same two-step chain.
    predict_event1 = user_input.submit(**transfer_input_args).then(**predict_args)
    predict_event2 = submitBtn.click(**transfer_input_args).then(**predict_args)

    # The stop button was created but never connected to anything. Wire it to
    # cancel any in-flight prediction and show the same stop status `predict`
    # itself reports. Cancelling relies on the queue enabled at launch.
    cancelBtn.click(
        fn=lambda: "Stop: Success",
        inputs=[],
        outputs=[status_display],
        cancels=[predict_event1, predict_event2],
    )

    # "New chat": clear conversation state, then clear the input box.
    emptyBtn.click(
        reset_state,
        outputs=[chatbotGr, history, status_display],
        show_progress=True,
    )
    emptyBtn.click(**reset_args)
|
|
|
# Page title shown for the app.
demo.title = "Chat"

# Enable the request queue with a single concurrent worker, then start the
# server in debug mode.
queued_demo = demo.queue(concurrency_count=1)
queued_demo.launch(debug=True)