Ngit committed
Commit 8d5025a
1 Parent(s): 6920a89

setup inference

Files changed (2)
  1. handler.py +48 -0
  2. requirements.txt +3 -0
handler.py ADDED
@@ -0,0 +1,48 @@
+ import torch
+ from typing import Dict, Any
+ from transformers import AutoTokenizer, BitsAndBytesConfig
+ from peft import AutoPeftModelForCausalLM
+
+
+ def parse_output(text):
+     # Keep only the text after the response marker and strip leftover
+     # pad/EOS tokens.
+     marker = "### Response:"
+     if marker in text:
+         pos = text.find(marker) + len(marker)
+     else:
+         pos = 0
+     return text[pos:].replace("<pad>", "").replace("</s>", "").strip()
+
+
+ class EndpointHandler:
+     def __init__(self, path="./"):
+         # Load the PEFT adapter together with its base model, quantized to
+         # 4-bit NF4 to reduce GPU memory usage.
+         bnb_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_use_double_quant=True,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_compute_dtype=torch.bfloat16,
+         )
+         self.model = AutoPeftModelForCausalLM.from_pretrained(
+             path, quantization_config=bnb_config
+         )
+         self.tokenizer = AutoTokenizer.from_pretrained(path)
+
+     def __call__(self, data: Any) -> Dict[str, str]:
+         # Wrap the raw input in the Alpaca-style instruction prompt.
+         inputs = data.get("inputs", data)
+         prompt = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction: \n{inputs}\n\n### Response: \n"
+         parameters = data.get("parameters", {})
+
+         inputs = self.tokenizer(
+             prompt, return_tensors="pt", return_token_type_ids=False
+         ).to(self.model.device)
+         outputs = self.model.generate(**inputs, **parameters)
+
+         return {
+             "generated_text": parse_output(
+                 self.tokenizer.decode(outputs[0].tolist(), skip_special_tokens=True)
+             )
+         }
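
Inference Endpoints constructs EndpointHandler with the repository path and invokes it with the parsed request body, so the handler can be smoke-tested locally before deploying. A minimal sketch, assuming handler.py and the adapter weights sit in the current directory; the input text and generation parameters are illustrative values, not part of this commit:

from handler import EndpointHandler

# Hypothetical local smoke test; needs a GPU with bitsandbytes support.
handler = EndpointHandler(path="./")
result = handler(
    {
        "inputs": "Explain what a LoRA adapter is.",
        "parameters": {"max_new_tokens": 128, "do_sample": True, "temperature": 0.7},
    }
)
print(result["generated_text"])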
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ bitsandbytes
+ git+https://github.com/huggingface/accelerate.git
+ git+https://github.com/huggingface/peft.git
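
Once deployed, the endpoint accepts the same JSON payload over HTTPS. A hedged sketch of a client call using the requests library; the endpoint URL and access token below are placeholders to substitute, not values from this repository:

import requests

# Placeholders: substitute the deployed endpoint URL and a valid token.
ENDPOINT_URL = "https://<endpoint-name>.endpoints.huggingface.cloud"
HF_TOKEN = "hf_..."

response = requests.post(
    ENDPOINT_URL,
    headers={
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json",
    },
    json={"inputs": "Explain what a LoRA adapter is.", "parameters": {"max_new_tokens": 128}},
)
response.raise_for_status()
print(response.json()["generated_text"])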