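# Gradio chat demo for eljanmahammadli/AzLlama-152M-Alpaca, a small LLaMA2-style
# model trained from scratch on Azerbaijani text and fine-tuned on Alpaca-format
# instructions.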
import gradio as gr
import torch
from transformers import pipeline
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model_name = "eljanmahammadli/AzLlama-152M-Alpaca"
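# Note: float16 halves memory use and targets GPU inference; on a CPU-only
# machine, torch.float32 is generally the safer dtype choice.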
model = pipeline(
"text-generation", model=model_name, torch_dtype=torch.float16, device=device
)
logo_path = "AzLlama-logo.webp"
def get_prompt(question):
    # Alpaca-style Azerbaijani instruction; in English: "Below is an instruction
    # that describes a task, paired with an input that provides further context.
    # Write a response that appropriately completes the request."
    base_instruction = "Aşağıda tapşırığı təsvir edən təlimat və əlavə kontekst təmin edən giriş verilmiştir. Sorğunu uyğun şəkildə tamamlayan cavab yazın."
    prompt = f"""{base_instruction}

### Təlimat:
{question}

### Cavab:
"""
    return prompt
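# For example, get_prompt("Fransanın paytaxtı hansı şəhərdir?") yields the base
# instruction, a "### Təlimat:" section containing the question, and an empty
# "### Cavab:" section for the model to complete.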
def get_answer(llm_output):
    # The pipeline echoes the prompt in its output; keep only the completion.
    return llm_output.split("### Cavab:")[1].strip()
def answer_question(history, temperature, top_p, repetition_penalty, top_k, question):
    model_params = {
        "temperature": temperature,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "top_k": top_k,
        "max_length": 512,
        "do_sample": True,
    }
    prompt = get_prompt(question)
    llm_output = model(prompt, **model_params)[0]
    answer = get_answer(llm_output["generated_text"])
    divider = "\n\n" if history else ""
    print(answer)
    new_history = history + divider + f"USER: {question}\nASSISTANT: {answer}\n"
    # Return the updated transcript and clear the question box.
    return new_history, ""
with gr.Blocks() as app:
    gr.Markdown("# AzLlama-150M Chatbot\n\n")
    with gr.Row():
        with gr.Column(scale=0.2, min_width=200):
            gr.Markdown("### Model Logo")
            gr.Image(value=logo_path)
            gr.Markdown(
                "### Model Info\n"
                "This model is a 150M-parameter LLaMA2 model trained from scratch on Azerbaijani text. It can be used to generate text from a given prompt.\n\n"
                "Note that this is a very small model; scaling it up in terms of data and model size would give better results.\n\n"
                "Please ask the model general-knowledge questions.\n\n"
                "Also play with the model settings, especially the temperature."
            )
        with gr.Column(scale=0.6):
            gr.Markdown("### Chat with the Assistant")
            history = gr.Textbox(
                label="Chat History",
                value="",
                lines=20,
                interactive=False,
            )
            question = gr.Textbox(
                label="Your question",
                placeholder="Type your question and press enter",
            )
            send_button = gr.Button("Send")
        with gr.Column(scale=0.2, min_width=200):
            gr.Markdown("### Model Settings")
            temperature = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.9, label="Temperature"
            )
            gr.Markdown(
                "Controls the randomness of predictions. Lower values make the model more deterministic."
            )
            top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, label="Top P")
            gr.Markdown(
                "Nucleus sampling. Lower values focus on more likely predictions."
            )
            repetition_penalty = gr.Slider(
                minimum=1.0, maximum=2.0, value=1.2, label="Repetition Penalty"
            )
            gr.Markdown(
                "Penalizes repeated words. Higher values discourage repetition."
            )
            top_k = gr.Slider(minimum=0, maximum=100, value=50, label="Top K")
            gr.Markdown("Keeps only the top k predictions. Set to 0 for no limit.")
    # Pressing Enter in the question box and clicking Send run the same handler.
    question.submit(
        fn=answer_question,
        inputs=[history, temperature, top_p, repetition_penalty, top_k, question],
        outputs=[history, question],
    )
    send_button.click(
        fn=answer_question,
        inputs=[history, temperature, top_p, repetition_penalty, top_k, question],
        outputs=[history, question],
    )
    # Example Azerbaijani prompts (general-knowledge questions such as "Which city
    # is the capital of France?"); selecting a row fills the question box and a
    # suggested temperature.
    gr.Examples(
        examples=[
            ["İş mühitində effektiv olmaq üçün nə etmək lazımdır?", 0.3],
            ["Körpənin qayğısına necə qalmaq lazımdır?", 0.45],
            ["Sağlam yaşamaq üçün nə etmək lazımdır?", 0.5],
            ["Bərpa olunan enerjidən istifadənin əhəmiyyəti", 0.5],
            ["Pizza hazırlamağın qaydası nədir?", 0.5],
            ["Çox pul qazanmaq üçün nə etmək lazımdır?", 0.2],
            ["İkinci dünya müharibəsi nə vaxt olmuşdur?", 0.2],
            ["Amerikanın paytaxtı hansı şəhərdir?", 0.3],
            ["Fransanın paytaxtı hansı şəhərdir?", 0.3],
        ],
        inputs=[question, temperature],
    )
app.launch()
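# Quick sanity check without the UI (hypothetical; run before app.launch() or in
# a separate session):
# print(answer_question("", 0.9, 0.95, 1.2, 50, "Fransanın paytaxtı hansı şəhərdir?")[0])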