File size: 3,868 Bytes
59e3834
 
 
 
5d42eda
3697a24
 
 
5d42eda
0da0787
59e3834
3697a24
59e3834
3697a24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
739f5f0
02cddd1
3697a24
 
 
 
59e3834
3697a24
59e3834
3697a24
 
63643d0
3697a24
59e3834
 
bde65d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59e3834
ed744f1
 
 
bde65d6
ed744f1
 
 
 
 
bde65d6
 
5457d9a
59e3834
ed744f1
59e3834
 
 
 
 
 
 
 
 
3f6c140
59e3834
 
6bbd1a8
59e3834
1312c32
59e3834
6bbd1a8
 
 
 
 
 
 
 
59e3834
 
 
 
6bbd1a8
59e3834
6bbd1a8
 
91e30ca
6bbd1a8
 
3f6c140
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from threading import Thread
from typing import Iterator
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
import transformers
from torch import cuda, bfloat16
from peft import PeftModel, PeftConfig

# Hugging Face access token taken from the environment (None when HF_API_TOKEN is
# unset); passed to every from_pretrained call below, presumably because the
# meta-llama repo is gated — confirm.
token = os.environ.get("HF_API_TOKEN")

# Base checkpoint that the PEFT adapter below is loaded on top of. The tokenizer
# is also loaded from this repo.
base_model_id = 'meta-llama/Llama-2-7b-chat-hf'

# Current CUDA device when CUDA is available, otherwise CPU. Inputs and the
# adapted model are both moved here.
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'

# NOTE(review): only fp32 CPU offload is enabled here; neither load_in_8bit nor
# load_in_4bit is set, so this config presumably does not quantize the model by
# itself — confirm against the installed transformers/bitsandbytes versions.
bnb_config = transformers.BitsAndBytesConfig(
    llm_int8_enable_fp32_cpu_offload = True
)

# Model configuration for the base checkpoint, passed explicitly to the model load.
# NOTE(review): newer transformers releases deprecate `use_auth_token` in favour
# of `token`; switch once the pinned version supports it.
model_config = transformers.AutoConfig.from_pretrained(
    base_model_id,
    use_auth_token=token
)

# Load the base causal-LM weights.
# NOTE(review): trust_remote_code=True permits executing code shipped in the model
# repo; Llama-2 is likely supported natively by transformers, so this flag is
# probably unnecessary — consider removing it.
# device_map='auto' is disabled; placement is done with .to(device) after the
# adapter is attached.
model = transformers.AutoModelForCausalLM.from_pretrained(
    base_model_id,
    trust_remote_code=True,
    config=model_config,
    quantization_config=bnb_config,
    # device_map='auto',
    use_auth_token=token
)

# NOTE(review): `config` is not referenced anywhere else in this view; the call
# still fetches the adapter's config from the Hub.
config = PeftConfig.from_pretrained("Ashishkr/llama-2-medical-consultation")
# Wrap the base model with the PEFT adapter weights and move the result to `device`.
model = PeftModel.from_pretrained(model, "Ashishkr/llama-2-medical-consultation").to(device)

# Evaluation mode: disables training-only behaviour such as dropout.
model.eval()

# Tokenizer comes from the base model repo, not from the adapter repo.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    base_model_id,
    use_auth_token=token
)


# def get_prompt(message: str, chat_history: list[tuple[str, str]],
#                system_prompt: str) -> str:
#     texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
#     # The first user input is _not_ stripped
#     do_strip = False
#     for user_input, response in chat_history:
#         user_input = user_input.strip() if do_strip else user_input
#         do_strip = True
#         texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
#     message = message.strip() if do_strip else message
#     texts.append(f'{message} [/INST]')
#     return ''.join(texts)

def get_prompt(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> str:
    """Build the plain-text prompt fed to the model.

    Layout:
        ``{system_prompt}\\n``
        one ``{user} {reply}\\n`` line per history turn except the last, then
        `` input: {last_user} {last_reply} {message} Response: ``
    With no history the tail is simply `` input: {message} Response: ``.

    NOTE(review): only the final segment carries the ``input:``/``Response:``
    labels, and the last reply runs straight into the new message. This may
    match the adapter's fine-tuning format — confirm before changing it.

    Args:
        message: The new user message.
        chat_history: Earlier ``(user_input, response)`` turns, oldest first.
        system_prompt: Text placed at the very top of the prompt.

    Returns:
        The assembled prompt string.
    """
    if not chat_history:
        return f'{system_prompt}\n input: {message} Response: '

    *earlier_turns, (prev_user, prev_reply) = chat_history
    history_lines = ''.join(f'{user} {reply}\n' for user, reply in earlier_turns)
    return (
        f'{system_prompt}\n'
        f'{history_lines}'
        f' input: {prev_user} {prev_reply} {message} Response: '
    )



def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
    """Return the number of tokens in the prompt ``get_prompt`` would build.

    Special tokens are not added, matching how ``run`` tokenizes the prompt.
    """
    encoded = tokenizer(
        [get_prompt(message, chat_history, system_prompt)],
        return_tensors='np',
        add_special_tokens=False,
    )
    token_ids = encoded['input_ids']
    # Batch of one prompt: the last axis is the token count.
    return token_ids.shape[-1]


def run(message: str,
        chat_history: list[tuple[str, str]],
        system_prompt: str,
        max_new_tokens: int = 1024,
        temperature: float = 0.8,
        top_p: float = 0.95,
        top_k: int = 50) -> Iterator[str]:
    """Stream a sampled reply to *message*, yielding the cumulative text so far.

    The prompt is built with ``get_prompt``. ``model.generate`` runs on a
    background thread that feeds a ``TextIteratorStreamer``. Each yielded value
    is the whole reply accumulated so far, not just the newest chunk.

    Streaming stops as soon as the accumulated reply contains
    ``"instruction:"``; the text before that marker is yielded as the final
    value. Generation on the background thread is then halted through a
    stopping criterion. The same happens if the caller closes the generator
    early.

    Args:
        message: The new user message.
        chat_history: Earlier ``(user_input, response)`` turns.
        system_prompt: Text placed at the top of the prompt.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Top-k sampling cutoff.

    Yields:
        The reply text accumulated so far.
    """
    from threading import Event

    stop_marker = "instruction:"
    # Set once we stop consuming the stream, so the background generate() call
    # ends instead of running on to max_new_tokens.
    stop_requested = Event()

    class _StopWhenRequested(transformers.StoppingCriteria):
        """Stopping criterion that ends generation once ``stop_requested`` is set."""

        def __call__(self, input_ids, scores, **kwargs):
            return stop_requested.is_set()

    prompt = get_prompt(message, chat_history, system_prompt)
    inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to(device)

    # Each wait for a new chunk is bounded by a 10 s timeout.
    streamer = TextIteratorStreamer(tokenizer,
                                    timeout=10.,
                                    skip_prompt=True,
                                    skip_special_tokens=True)
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        stopping_criteria=transformers.StoppingCriteriaList([_StopWhenRequested()]),
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    try:
        for text in streamer:
            outputs.append(text)
            reply = ''.join(outputs)
            # Search the accumulated reply rather than the single chunk, so a
            # marker split across two streamer chunks is still detected.
            cut = reply.find(stop_marker)
            if cut != -1:
                # Emit the text before the marker before stopping, so the final
                # chunk's content is not lost.
                yield reply[:cut]
                break
            yield reply
    finally:
        # Runs on normal exhaustion, after the marker break, on a streamer
        # timeout, and when the consumer closes this generator early.
        stop_requested.set()