---

# KUET CHATBOT

A chatbot for KUET (Khulna University of Engineering & Technology), built by attaching the fine-tuned PEFT adapter `shahidul034/kuet_chatbot` to a 4-bit quantized Llama 2 7B base model.

Contributed by Jesiara Khatun and Sadia Islam (CSE 2K19, KUET).
Install the required packages:

```
!pip install peft -q
!pip install transformers[sentencepiece] -q
!pip install sentencepiece -q
!pip install accelerate -q
!pip install bitsandbytes -q
```
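Note: the loading code below passes `load_in_4bit=True` directly to `from_pretrained`, which assumes a `transformers`/`bitsandbytes` pairing from when that argument was accepted (newer releases expect a `BitsAndBytesConfig` instead, and have renamed `GreedySearchDecoderOnlyOutput`). If the unpinned install pulls newer releases and the code errors, pinning roughly contemporary versions is one option; the exact pins here are an assumption, not from the original:

```
# Assumed, roughly contemporary version pins; adjust if dependency resolution fails.
!pip install "transformers==4.31.0" "peft==0.4.0" "accelerate==0.21.0" "bitsandbytes==0.40.2" -q
```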

Then load the base model, attach the adapter, and ask it questions:

```
import textwrap

import torch
from peft import PeftModel
from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
from transformers.generation.utils import GreedySearchDecoderOnlyOutput

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# The tokenizer comes from the base Llama 2 checkpoint.
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Load the Llama 2 7B base weights in 4-bit so they fit on a single GPU.
model = LlamaForCausalLM.from_pretrained(
    "abhishek/llama-2-7b-hf-small-shards",
    load_in_4bit=True,
    device_map="auto",
)

# Attach the fine-tuned KUET adapter on top of the base weights.
model = PeftModel.from_pretrained(model, "shahidul034/kuet_chatbot", torch_dtype=torch.float16)
model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
model.config.bos_token_id = 1
model.config.eos_token_id = 2

model = model.eval()
model = torch.compile(model)

PROMPT_TEMPLATE = """
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
[INSTRUCTION]

### Response:
"""


def create_prompt(instruction: str) -> str:
    return PROMPT_TEMPLATE.replace("[INSTRUCTION]", instruction)


# print(create_prompt("What is (are) Glaucoma ?"))


def generate_response(prompt: str, model: PeftModel) -> GreedySearchDecoderOnlyOutput:
    encoding = tokenizer(prompt, return_tensors="pt")
    input_ids = encoding["input_ids"].to(DEVICE)

    generation_config = GenerationConfig(
        temperature=0.1,
        top_p=0.75,
        repetition_penalty=1.1,
    )
    with torch.inference_mode():
        return model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=256,
        )


def format_response(response: GreedySearchDecoderOnlyOutput) -> str:
    # Keep only the text after the "### Response:" marker and wrap it for display.
    decoded_output = tokenizer.decode(response.sequences[0])
    answer = decoded_output.split("### Response:")[1].strip()
    return "\n".join(textwrap.wrap(answer))


def ask_alpaca(prompt: str, model: PeftModel = model) -> str:
    prompt = create_prompt(prompt)
    response = generate_response(prompt, model)
    answer = format_response(response)
    print(answer)
    return answer


ask_alpaca("where is kuet located?")
```
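
For quick experimentation, `ask_alpaca` can be wrapped in a small question loop. This is a minimal sketch, not part of the original notebook, and assumes the cells above have already run:

```
# Minimal sketch: interactive question loop around ask_alpaca.
while True:
    question = input("Ask about KUET (or type 'quit'): ")
    if question.strip().lower() == "quit":
        break
    ask_alpaca(question)
```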