import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from threading import Thread

from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, ServiceContext, PromptTemplate
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

# Load the documents to index for retrieval (expects files under sft/dataset).
documents = SimpleDirectoryReader("sft/dataset").load_data()

# System prompt and per-query wrapper used to format each request with chat-style tags.
system_prompt = "You are a QA bot. Given a question, answer it correctly."

query_wrapper_prompt = PromptTemplate("<|user|>:{query_str}\n<|assistant|>:")

# Wrap TinyLlama as the LlamaIndex LLM. Greedy decoding keeps answers deterministic.
llm = HuggingFaceLLM(
    context_window=2048,
    max_new_tokens=256,
    generate_kwargs={"temperature": 0.0, "do_sample": False},
    system_prompt=system_prompt,
    query_wrapper_prompt=query_wrapper_prompt,
    tokenizer_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    device_map="cuda",
    model_kwargs={"torch_dtype": torch.bfloat16},
)
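# Optional sanity check (a sketch; the prompt is illustrative and not part of the app flow):
# confirm the model loads and responds before building the index.
# print(llm.complete("What does this QA bot answer questions about?"))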

# Embedding model used to embed document chunks and queries.
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")

# Bundle the LLM, embedding model, and chunking configuration for index construction.
service_context = ServiceContext.from_defaults(
    chunk_size=256,
    llm=llm,
    embed_model=embed_model
)
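# Note: newer llama_index releases deprecate ServiceContext in favor of the global
# Settings object; a rough equivalent (assuming such a version is installed) would be:
# from llama_index.core import Settings
# Settings.llm = llm
# Settings.embed_model = embed_model
# Settings.chunk_size = 256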

# Build the in-memory vector index over the loaded documents.
index = VectorStoreIndex.from_documents(documents, service_context=service_context)

query_engine = index.as_query_engine()
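# Quick smoke test (commented out so the app launches directly; the question is a
# placeholder that assumes the dataset contains QA-style documents):
# print(query_engine.query("Ask something covered by the indexed documents"))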


# Custom stopping criterion: halt generation once the model emits an end-of-sequence token.
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = [2]  # Token ID 2 is the </s> (EOS) token for the TinyLlama/Llama tokenizer.
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:  # Stop if the last generated token is a stop token.
                return True
        return False
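# If a tokenizer object is available (see the commented-out loading block at the bottom
# of this file), the hard-coded ID could instead be read from it, e.g.
# stop_ids = [tokenizer.eos_token_id]  -- assuming the standard Hugging Face tokenizer API.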

# Alternative streaming chat handler that generates directly from a locally fine-tuned
# checkpoint. It requires `tokenizer`, `model`, and `device` from the commented-out
# loading block at the bottom of this file and is not wired to the Gradio interface;
# the RAG-based `predict` below is used instead.
def predict_streaming(message, history):
    history_transformer_format = history + [[message, ""]]
    stop = StopOnTokens()

    # Formatting the input for the model.
    messages = "</s>".join(["</s>".join(["\n<|user|>:" + item[0], "\n<|assistant|>:" + item[1]])
                        for item in history_transformer_format])

    model_inputs = tokenizer([messages], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=50,
        temperature=0.5,
        num_beams=1,
        stopping_criteria=StoppingCriteriaList([stop])
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()  # Starting the generation in a separate thread.
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        if '</s>' in partial_message:  # Breaking the loop if the stop token is generated.
            break
        yield partial_message


# Chat handler used by the Gradio interface: answer each message with the RAG query engine.
def predict(message, history):
    response = query_engine.query(message)
    return str(response)
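# If token-by-token streaming is wanted for the RAG path as well, the query engine can be
# created with index.as_query_engine(streaming=True) and the handler can yield from
# response.response_gen (assuming a llama_index version that supports streaming responses).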

# Set up and launch the Gradio chat interface (share=True exposes a public link).
gr.ChatInterface(predict).launch(share=True)

# The commented-out block below is the original standalone (non-RAG) TinyLlama chatbot,
# kept for reference.
# # Loading the tokenizer and model from a local fine-tuned checkpoint.
# tokenizer = AutoTokenizer.from_pretrained("output/1T_FT_lr1e-5_ep5_top1_2024-03-04/checkpoint-575")
# model = AutoModelForCausalLM.from_pretrained("output/1T_FT_lr1e-5_ep5_top1_2024-03-04/checkpoint-575")

# # using CUDA for an optimal experience
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = model.to(device)

# # Defining a custom stopping criteria class for the model's text generation.
# class StopOnTokens(StoppingCriteria):
#     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
#         stop_ids = [2]  # IDs of tokens where the generation should stop.
#         for stop_id in stop_ids:
#             if input_ids[0][-1] == stop_id:  # Checking if the last generated token is a stop token.
#                 return True
#         return False

# # Function to generate model predictions.
# def predict(message, history):
#     history_transformer_format = history + [[message, ""]]
#     stop = StopOnTokens()

#     # Formatting the input for the model.
#     messages = "</s>".join(["</s>".join(["\n<|user|>:" + item[0], "\n<|assistant|>:" + item[1]])
#                         for item in history_transformer_format])

#     model_inputs = tokenizer([messages], return_tensors="pt").to(device)
#     streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
#     generate_kwargs = dict(
#         model_inputs,
#         streamer=streamer,
#         max_new_tokens=1024,
#         do_sample=True,
#         top_p=0.95,
#         top_k=50,
#         temperature=0.5,
#         num_beams=1,
#         stopping_criteria=StoppingCriteriaList([stop])
#     )
#     t = Thread(target=model.generate, kwargs=generate_kwargs)
#     t.start()  # Starting the generation in a separate thread.
#     partial_message = ""
#     for new_token in streamer:
#         partial_message += new_token
#         if '</s>' in partial_message:  # Breaking the loop if the stop token is generated.
#             break
#         yield partial_message


# # Setting up the Gradio chat interface.
# gr.ChatInterface(predict,
#                  title="Tinyllama_chatBot",
#                  description="Ask Tiny llama any questions",
#                  examples=['How to cook a fish?', 'Who is the president of US now?']
#                  ).launch(share=True)  # Launching the web interface.