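# Gradio chat demo for AzLlama-152M-Alpaca, a small Azerbaijani
# text-generation model served through the transformers pipeline API.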
import gradio as gr
import torch
from transformers import pipeline

# Use the GPU when available; on CPU, keep float32, since float16 inference
# is slow there and some CPU kernels do not support it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model_name = "eljanmahammadli/AzLlama-152M-Alpaca"
model = pipeline(
    "text-generation",
    model=model_name,
    torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
    device=device,
)
logo_path = "AzLlama-logo.webp"


def get_prompt(question):
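    # Alpaca-style prompt template the model was fine-tuned on. The
    # Azerbaijani instruction translates to: "Below is an instruction that
    # describes a task, paired with an input that provides further context.
    # Write a response that appropriately completes the request."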
    base_instruction = "Aşağıda tapşırığı təsvir edən təlimat və əlavə kontekst təmin edən giriş verilmiştir. Sorğunu uyğun şəkildə tamamlayan cavab yazın."
    prompt = f"""{base_instruction}

### Təlimat:
{question}

### Cavab:
"""
    return prompt


def get_answer(llm_output):
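    # The pipeline echoes the prompt before the completion; keep only the
    # text after the "### Cavab:" ("### Answer:") marker.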
    return llm_output.split("### Cavab:")[1].strip()


def answer_question(history, temperature, top_p, repetition_penalty, top_k, question):
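    # Run one generation round and append the exchange to the chat transcript.
    # do_sample=True lets the temperature/top-p/top-k sliders take effect;
    # max_length counts prompt plus completion tokens.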
    model_params = {
        "temperature": temperature,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "top_k": top_k,
        "max_length": 512,
        "do_sample": True,
    }
    prompt = get_prompt(question)
    llm_output = model(prompt, **model_params)[0]
    answer = get_answer(llm_output["generated_text"])
    divider = "\n\n" if history else ""
    print(answer)
    new_history = history + divider + f"USER: {question}\nASSISTANT: {answer}\n"
    return new_history, ""


with gr.Blocks() as app:
    gr.Markdown("# AzLlama-152M Chatbot")

    with gr.Row():
        with gr.Column(scale=1, min_width=200):  # integer scale ratio (1:3:1)
            gr.Markdown("### Model Logo")
            gr.Image(value=logo_path)
            gr.Markdown(
                "### Model Info\n"
                "This is a 152M-parameter LLaMA-2 model trained from scratch on "
                "Azerbaijani text. It generates a continuation of the given prompt.\n\n"
                "Note that this is a very small model; scaling up the data and the "
                "model size would give better results.\n\n"
                "Ask the model general-knowledge questions, and experiment with the "
                "model settings, especially the temperature."
            )
        with gr.Column(scale=3):
            gr.Markdown("### Chat with the Assistant")
            history = gr.Textbox(
                label="Chat History",
                value="",
                lines=20,
                interactive=False,
            )
            question = gr.Textbox(
                label="Your question",
                placeholder="Type your question and press enter",
            )
            send_button = gr.Button("Send")
        with gr.Column(scale=1, min_width=200):
            gr.Markdown("### Model Settings")
            temperature = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.9, label="Temperature"
            )
            gr.Markdown(
                "Controls the randomness of predictions. Lower values make the model more deterministic."
            )
            top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, label="Top P")
            gr.Markdown(
                "Nucleus sampling. Lower values focus on more likely predictions."
            )
            repetition_penalty = gr.Slider(
                minimum=1.0, maximum=2.0, value=1.2, label="Repetition Penalty"
            )
            gr.Markdown(
                "Penalizes repeated words. Higher values discourage repetition."
            )
            top_k = gr.Slider(minimum=0, maximum=100, value=50, label="Top K")
            gr.Markdown("Keeps only the top k predictions. Set to 0 for no limit.")

    # Pressing Enter in the question box runs the same handler as the Send
    # button.
    question.submit(
        fn=answer_question,
        inputs=[history, temperature, top_p, repetition_penalty, top_k, question],
        outputs=[history, question],
    )

    send_button.click(
        fn=answer_question,
        inputs=[history, temperature, top_p, repetition_penalty, top_k, question],
        outputs=[history, question],
    )

    # Azerbaijani example prompts paired with suggested temperatures;
    # selecting one fills the question box and the temperature slider.
    gr.Examples(
        examples=[
            ["İş mühitində effektiv olmaq üçün nə etmək lazımdır?", 0.3],
            ["Körpənin qayğısına necə qalmaq lazımdır?", 0.45],
            ["Sağlam yaşamaq üçün nə etmək lazımdır?", 0.5],
            ["Bərpa olunan enerjidən istifadənin əhəmiyyəti", 0.5],
            ["Pizza hazırlamağın qaydası nədir?", 0.5],
            ["Çox pul qazanmaq üçün nə etmək lazımdır?", 0.2],
            ["İkinci dünya müharibəsi nə vaxt olmuşdur?", 0.2],
            ["Amerikanın paytaxtı hansı şəhərdir?", 0.3],
            ["Fransanın paytaxtı hansı şəhərdir?", 0.3],
        ],
        inputs=[question, temperature],
    )


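# launch() starts a local Gradio server; pass share=True to also expose a
# temporary public URL.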
app.launch()