import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, AwqConfig

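# Generation limits; the input-length cap and the access token are read from the environment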
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
ACCESS_TOKEN = os.getenv("HF_TOKEN", "")

model_id = "hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4"

quantization_config = AwqConfig(
    bits=4,
    fuse_max_seq_len=512, # Note: Update this as per your use-case
    do_fuse=True,
)

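# Load the pre-quantized checkpoint across the available GPUs, then its tokenizer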
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    quantization_config=quantization_config,
    token=ACCESS_TOKEN)
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    trust_remote_code=True,
    token=ACCESS_TOKEN)
tokenizer.use_default_system_prompt = False


@spaces.GPU
def generate(
    message: str,
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.01,
    top_p: float = 0.01,
) -> Iterator[str]:
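    # Build the chat turns: optional system prompt followed by the user message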
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    conversation.append({"role": "user", "content": message})

    # add_generation_prompt appends the assistant header so the model answers as the assistant
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

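    # Stop on the standard EOS token or on Llama 3.1's end-of-turn token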
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>")
    ]
    
    streamer = TextIteratorStreamer(tokenizer, timeout=300.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        eos_token_id=terminators,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
        num_beams=1,
    )
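    # Run generation on a worker thread; the streamer yields decoded text chunks as they arrive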
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


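# Single-turn Gradio UI: a prompt box plus extra controls for the system prompt and sampling settings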
chat_interface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(lines=2, placeholder="Prompt", label="Prompt"),
    ],
    outputs="text",
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.01,
            maximum=4.0,
            step=0.01,
            value=0.01,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.01,
            maximum=1.0,
            step=0.01,
            value=0.01,
        ),
    ],
    title="Model testing",
    description="Provide system settings and a prompt to interact with the model.",
)

chat_interface.queue(max_size=20).launch()