File size: 2,493 Bytes
b8c24aa
3a82207
7dc3087
c8fdb3b
3a82207
4e81072
7dc3087
08c1bd3
4e81072
7dc3087
 
 
 
 
 
fca3d9e
7dc3087
4e81072
7dc3087
 
64d8a64
 
 
 
 
 
7dc3087
fccbbf3
 
08c1bd3
4e81072
3a82207
7dc3087
3a82207
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7dc3087
3a82207
7dc3087
 
3a82207
 
 
7dc3087
 
 
3a82207
7dc3087
 
 
3a82207
7dc3087
3a82207
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
import os
from threading import Thread
import spaces
import time

# Hub access token taken from the environment; a missing HF_TOKEN raises
# KeyError right here, before anything is loaded.
token = os.environ["HF_TOKEN"]

# Load the weights in 4-bit, using float16 as the compute dtype.
quantization_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.float16,
    load_in_4bit=True,
)

# Model and tokenizer come from the same checkpoint and use the same token.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-1.1-7b-it",
    token=token,
    quantization_config=quantization_config,
)
tok = AutoTokenizer.from_pretrained("google/gemma-1.1-7b-it", token=token)

# Pick the device that tokenized inputs are moved to inside chat().
if not torch.cuda.is_available():
    device = torch.device("cpu")
    print("Using CPU")
else:
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")

# NOTE: the model itself is deliberately not moved with model.to(device);
# the original author's note records that doing so led to dispatch errors.

@spaces.GPU
def chat(message, history):
    """Stream a model reply to ``message``, then append timing statistics.

    Args:
        message: The latest user message.
        history: Previous turns. Each item is read as ``item[0]`` (the user
            text) and ``item[1]`` (the assistant reply, or ``None`` when
            there is no reply for that turn).

    Yields:
        The reply accumulated so far after every chunk produced by the
        streamer, followed by one final value: the complete reply with the
        time to the first chunk and the tokens-per-second rate appended.
    """
    # perf_counter is monotonic, so the measured intervals cannot be skewed
    # by wall-clock adjustments.
    start_time = time.perf_counter()

    # Named ``conversation`` so the local does not shadow this function.
    conversation = []
    for item in history:
        conversation.append({"role": "user", "content": item[0]})
        if item[1] is not None:
            conversation.append({"role": "assistant", "content": item[1]})
    conversation.append({"role": "user", "content": message})

    prompt = tok.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True)
    model_inputs = tok([prompt], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(
        tok, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=0.75,
        num_beams=1,
    )

    # Generation runs in a background thread so this generator can consume
    # the streamer while text is still being produced.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    partial_text = ""
    first_token_time = None
    for new_text in streamer:
        # ``is None`` rather than truthiness: only the first chunk sets it.
        if first_token_time is None:
            first_token_time = time.perf_counter() - start_time
        partial_text += new_text
        yield partial_text

    # The streamer is exhausted, so generation is finishing; wait for the
    # worker so no generate() call is left running after this turn ends.
    worker.join()

    total_time = time.perf_counter() - start_time
    tokens = len(tok.tokenize(partial_text))
    tokens_per_second = tokens / total_time if total_time > 0 else 0

    # If the streamer never produced a chunk there is no first-token time;
    # formatting None with ``:.2f`` would raise TypeError, so report "n/a".
    if first_token_time is None:
        first_token_info = "n/a"
    else:
        first_token_info = f"{first_token_time:.2f} seconds"

    # Append the timing information to the final output
    timing_info = f"\nTime taken to first token: {first_token_info}\nTokens per second: {tokens_per_second:.2f}"
    yield partial_text + timing_info

# Wrap the streaming chat() generator in a Gradio chat UI and start serving it.
demo = gr.ChatInterface(
    fn=chat,
    title="Chat With LLMS",
    examples=[["Write me a poem about Machine Learning."]],
)
demo.launch()