shahidul034 committed • Commit 1cd3cbb • Parent(s): d9b8ec7

Update README.md

README.md CHANGED
@@ -7,4 +7,82 @@ widget:
---

# KUET CHATBOT

Contributed by Jesiara Khatun and Sadia Islam (CSE 2K19, KUET)
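Install the required packages (PEFT for the fine-tuned adapter, bitsandbytes for 4-bit loading, sentencepiece for the Llama tokenizer):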
```
!pip install peft -q
!pip install transformers[sentencepiece] -q
!pip install sentencepiece
!pip install accelerate
!pip install bitsandbytes
```
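Then load the base Llama-2-7B weights in 4-bit, attach the shahidul034/kuet_chatbot PEFT adapter, and query the model: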
```
import textwrap

import torch
from peft import PeftModel
from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
from transformers.generation.utils import GreedySearchDecoderOnlyOutput

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Load the base model in 4-bit and spread it across available devices.
model = LlamaForCausalLM.from_pretrained(
    "abhishek/llama-2-7b-hf-small-shards",
    load_in_4bit=True,
    device_map="auto",
)

# Attach the fine-tuned KUET chatbot adapter on top of the base model.
model = PeftModel.from_pretrained(model, "shahidul034/kuet_chatbot", torch_dtype=torch.float16)
model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
model.config.bos_token_id = 1
model.config.eos_token_id = 2

model = model.eval()
model = torch.compile(model)

PROMPT_TEMPLATE = """
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
[INSTRUCTION]

### Response:
"""


def create_prompt(instruction: str) -> str:
    return PROMPT_TEMPLATE.replace("[INSTRUCTION]", instruction)


# print(create_prompt("What is (are) Glaucoma ?"))


def generate_response(prompt: str, model: PeftModel) -> GreedySearchDecoderOnlyOutput:
    encoding = tokenizer(prompt, return_tensors="pt")
    input_ids = encoding["input_ids"].to(DEVICE)

    generation_config = GenerationConfig(
        temperature=0.1,
        top_p=0.75,
        repetition_penalty=1.1,
    )
    with torch.inference_mode():
        return model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=256,
        )


def format_response(response: GreedySearchDecoderOnlyOutput) -> str:
    decoded_output = tokenizer.decode(response.sequences[0])
    # Keep only the text after the "### Response:" marker and wrap it for display.
    answer = decoded_output.split("### Response:")[1].strip()
    return "\n".join(textwrap.wrap(answer))


def ask_alpaca(prompt: str, model: PeftModel = model) -> None:
    prompt = create_prompt(prompt)
    response = generate_response(prompt, model)
    print(format_response(response))


ask_alpaca("where is kuet located?")
```
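One caveat: the `GenerationConfig` above does not set `do_sample`, which defaults to `False`, so decoding is greedy and `temperature`/`top_p` are effectively ignored (`repetition_penalty` is a logits processor and still applies). If sampled output is preferred, a minimal variant of the config, assuming the rest of the setup stays the same, would be:

```
generation_config = GenerationConfig(
    do_sample=True,  # enable sampling so temperature/top_p actually take effect
    temperature=0.1,
    top_p=0.75,
    repetition_penalty=1.1,
)
```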