File size: 2,008 Bytes
78011f3
 
 
bef7b02
78011f3
 
4dad4c8
78011f3
0c8a603
 
78011f3
b0c22d0
78011f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2deb7f0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
from functools import lru_cache

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hugging Face Hub checkpoint that this app loads and chats with.
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
# Forwarded verbatim as `device_map=` to from_pretrained: load the weights onto CUDA.
device_map = "cuda"

# Optional Hub access token taken from the environment; None when the variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")

@lru_cache(maxsize=1)
def load_model() -> AutoModelForCausalLM:
    """Load the causal LM for ``model_name`` once and return the cached instance.

    Every chat turn calls this, so without the cache the full checkpoint would
    be re-read and re-placed on ``device_map`` for each reply.

    ``HF_TOKEN`` (read from the environment at module level) is forwarded so
    that gated or private repositories can be downloaded. When it is None,
    from_pretrained falls back to its default credential lookup.
    """
    return AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map=device_map,
        token=HF_TOKEN,
    )

@lru_cache(maxsize=1)
def load_tokenizer() -> AutoTokenizer:
    """Load the tokenizer for ``model_name`` once and return the cached instance.

    Both prompt building and generation call this on every chat turn, several
    times per reply. Caching avoids re-instantiating the tokenizer each time.

    ``HF_TOKEN`` is forwarded so gated or private repositories can be fetched.
    """
    return AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)

def preprocess_messages(message: str, history: list, system_prompt: str) -> str:
    """Render the conversation into a prompt string using the model's chat template.

    Args:
        message: The new user message for this turn.
        history: Previous turns as supplied by Gradio. Two shapes are accepted:
            ``[user, assistant]`` pairs (Gradio's "tuples" chatbot format), in
            which either side may be None, or ``{"role": ..., "content": ...}``
            dicts ("messages" format).
        system_prompt: Text placed in the leading system message.

    Returns:
        The templated prompt as a string (``tokenize=False``), ending with the
        assistant generation header (``add_generation_prompt=True``).
    """
    messages = [{"role": "system", "content": system_prompt}]

    # Replay earlier turns so the model keeps conversational context.
    for turn in history or []:
        if isinstance(turn, dict):
            messages.append({"role": turn["role"], "content": turn["content"]})
            continue
        user_text, assistant_text = turn
        if user_text is not None:
            messages.append({"role": "user", "content": user_text})
        if assistant_text is not None:
            messages.append({"role": "assistant", "content": assistant_text})

    messages.append({"role": "user", "content": message})
    return load_tokenizer().apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

def generate_text(prompt: str, max_new_tokens: int, temperature: float) -> str:
    """Generate the assistant reply for an already-templated chat prompt.

    Args:
        prompt: Prompt string produced by ``apply_chat_template(...,
            tokenize=False, add_generation_prompt=True)``.
        max_new_tokens: Upper bound on generated tokens. Coerced to ``int``
            because UI sliders may deliver a float.
        temperature: Sampling temperature, used as given. Values ``<= 0``
            switch to greedy decoding, because sampling with a non-positive
            temperature is rejected by transformers.

    Returns:
        Only the newly generated text, with special tokens removed. The prompt
        itself is not echoed back.
    """
    model = load_model()
    tokenizer = load_tokenizer()

    # generate() needs token ids, not a str. The chat template already emits
    # the BOS token, so don't let the tokenizer prepend a second one.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    # Stop on the tokenizer's EOS or on <|eot_id|>. The latter is the Llama-3
    # Instruct end-of-turn marker, which may differ from eos_token. Unknown
    # tokens map to None and are dropped; duplicates are removed in order.
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    terminators = [t for t in dict.fromkeys((tokenizer.eos_token_id, eot_id)) if t is not None]

    gen_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "eos_token_id": terminators or None,
        # Fall back to EOS when the tokenizer defines no pad token.
        "pad_token_id": (
            tokenizer.pad_token_id
            if tokenizer.pad_token_id is not None
            else tokenizer.eos_token_id
        ),
    }
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=float(temperature), top_p=0.9)
    else:
        gen_kwargs["do_sample"] = False

    outputs = model.generate(**inputs, **gen_kwargs)
    # outputs[0] is prompt + continuation; decode only the continuation.
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

def chat_function(
    message: str,
    history: list,
    system_prompt: str,
    max_new_tokens: int,
    temperature: float
) -> str:
    """Gradio chat callback: turn one user message into one model reply.

    ``message`` and ``history`` come from the chat widget. The remaining
    arguments come from ``additional_inputs`` in declaration order: the system
    prompt textbox, then the max-new-tokens and temperature sliders.
    """
    return generate_text(
        preprocess_messages(message, history, system_prompt),
        max_new_tokens,
        temperature,
    )

# Chat UI. The interface is bound to `demo` (the name Gradio's reload mode
# looks for). Launching only under __main__ means importing this module has
# no side effects; `python app.py` still starts the server as before.
demo = gr.ChatInterface(
    chat_function,
    chatbot=gr.Chatbot(height=400),
    textbox=gr.Textbox(placeholder="Enter message here", container=False, scale=7),
    title="LLAMA3 Chat",
    description="""Chat with llama3""",
    theme="soft",
    # Passed to chat_function after (message, history), in this order.
    additional_inputs=[
        gr.Textbox("You shall answer to all the questions as very smart AI", label="System Prompt"),
        gr.Slider(512, 4096, label="Max New Tokens"),
        gr.Slider(0, 1, label="Temperature"),
    ],
)

if __name__ == "__main__":
    demo.launch(debug=True)