File size: 2,388 bytes
d9b8ec7 1cd3cbb
# KUET CHATBOT
Contributed by Jesiara Khatun and Sadia Islam (CSE 2K19, KUET)
```
# Notebook (IPython) shell escapes that install this notebook's dependencies.
# peft provides PeftModel (used to attach the adapter); transformers provides
# LlamaTokenizer, LlamaForCausalLM and GenerationConfig. accelerate and
# bitsandbytes are presumably needed for device_map="auto" and 4-bit loading.
# NOTE(review): versions are unpinned; newer transformers releases may deprecate
# or drop GreedySearchDecoderOnlyOutput (imported below) — pin if that import fails.
# NOTE(review): transformers[sentencepiece] should already pull in sentencepiece,
# so the third install looks redundant.
!pip install peft -q
!pip install transformers[sentencepiece] -q
!pip install sentencepiece
!pip install accelerate
!pip install bitsandbytes
```
```
import torch
from peft import PeftModel
import transformers
import textwrap
from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
from transformers.generation.utils import GreedySearchDecoderOnlyOutput
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE  # bare expression: echoes the selected device in a notebook cell

tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Load the sharded Llama-2-7B base weights quantized to 4 bits and let
# device_map="auto" place them. Quantization settings go through a
# BitsAndBytesConfig; passing load_in_4bit=True directly to from_pretrained
# is deprecated in transformers.
model = LlamaForCausalLM.from_pretrained(
    "abhishek/llama-2-7b-hf-small-shards",
    quantization_config=transformers.BitsAndBytesConfig(load_in_4bit=True),
    device_map="auto",
)
# Wrap the quantized base model with the KUET chatbot PEFT adapter weights.
model = PeftModel.from_pretrained(model, "shahidul034/kuet_chatbot", torch_dtype=torch.float16)

# Special-token ids: 0 (<unk>) doubles as the padding id, 1 = BOS, 2 = EOS.
model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
model.config.bos_token_id = 1
model.config.eos_token_id = 2
model = model.eval()
# No torch.compile here on purpose: this notebook only calls model.generate(),
# which a torch.compile wrapper forwards to the original, uncompiled module.
# Compiling therefore gave no speed-up, and it replaced the PeftModel that the
# helper functions below are annotated to accept.
# Instruction/response prompt layout used to query the fine-tuned model.
# create_prompt() substitutes the literal "[INSTRUCTION]" placeholder, so this
# is a plain string: it has no {} fields and needs no f-prefix.
PROMPT_TEMPLATE = """
Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
[INSTRUCTION]
### Response:
"""
def create_prompt(instruction: str) -> str:
    """Return ``PROMPT_TEMPLATE`` with its ``[INSTRUCTION]`` slot filled in.

    Args:
        instruction: The user's question or instruction, inserted verbatim.

    Returns:
        The complete prompt text to feed to the tokenizer.

    Example::

        print(create_prompt("where is kuet located?"))
    """
    placeholder = "[INSTRUCTION]"
    filled = PROMPT_TEMPLATE.replace(placeholder, instruction)
    return filled
def generate_response(prompt: str, model: PeftModel) -> GreedySearchDecoderOnlyOutput:
    """Generate a continuation of *prompt* and return the raw ``generate`` output.

    Args:
        prompt: Fully formatted prompt text (e.g. from ``create_prompt``).
        model: The adapter-wrapped causal LM to generate with.

    Returns:
        The dict-style result of ``model.generate`` (``return_dict_in_generate=True``).
        Its ``sequences`` hold prompt + generated token ids, and ``scores`` hold
        the per-step scores (``output_scores=True``). At most 256 new tokens are
        generated.
    """
    encoding = tokenizer(prompt, return_tensors="pt")
    input_ids = encoding["input_ids"].to(DEVICE)
    # Pass the tokenizer's mask explicitly. Without it, generate() infers a mask
    # from the model's pad_token_id, which this notebook sets to 0 (<unk>), so
    # any <unk> token in the prompt would be silently masked out.
    attention_mask = encoding["attention_mask"].to(DEVICE)
    # do_sample is not enabled, so decoding is greedy (as the return annotation
    # says); temperature/top_p only matter under sampling and are kept as-is for
    # compatibility. repetition_penalty does apply to greedy decoding.
    generation_config = GenerationConfig(
        temperature=0.1,
        top_p=0.75,
        repetition_penalty=1.1,
    )
    with torch.inference_mode():
        return model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=256,
        )
def format_response(response: GreedySearchDecoderOnlyOutput) -> str:
    """Extract the model's answer from a ``generate`` output and wrap it for display.

    Decodes the first returned sequence, keeps the text after the first
    ``### Response:`` marker (stopping at any repeated marker), and wraps each
    line of the answer at textwrap's default width.

    Args:
        response: Output of ``generate_response`` (must expose ``sequences``).

    Returns:
        The answer text, line-wrapped, with the model's own line breaks kept.
    """
    # skip_special_tokens drops <s> and, when generation stopped on EOS, the
    # trailing </s> that would otherwise be printed at the end of the answer.
    decoded_output = tokenizer.decode(response.sequences[0], skip_special_tokens=True)
    answer = decoded_output.split("### Response:")[1].strip()
    # Wrap line by line: textwrap.wrap() over the whole answer turns newlines
    # into spaces, merging the model's list items and paragraphs together.
    # textwrap.fill("") returns "", so blank lines are preserved as well.
    return "\n".join(textwrap.fill(line) for line in answer.splitlines())
def ask_alpaca(prompt: str, model: PeftModel = model) -> str:
    """Answer *prompt* with the fine-tuned model, print the answer, and return it.

    Args:
        prompt: The user's question/instruction as plain text (not yet templated).
        model: Model to generate with. Defaults to the module-level ``model`` as
            it was bound when this function was defined.

    Returns:
        The formatted answer text (the same text that is printed).
    """
    full_prompt = create_prompt(prompt)
    raw_output = generate_response(full_prompt, model)
    answer = format_response(raw_output)
    print(answer)
    # Return the text, as the ``-> str`` annotation promises, so callers can reuse it.
    return answer


# Bind the result so the notebook cell does not echo the returned string again.
kuet_answer = ask_alpaca("where is kuet located?")
```