sarang-shrivastava commited on
Commit
68637ef
1 Parent(s): 6b0e4e6

Add handler and requirements.txt

Browse files
Files changed (2) hide show
  1. handler.py +103 -0
  2. requirements.txt +5 -0
handler.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ # import transformers
3
+ # from transformers import AutoTokenizer
4
+ # import torch
5
+ from datetime import datetime
6
+
7
+
8
+
9
+
10
+ import requests
11
+ from PIL import Image
12
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration
13
+
14
+
15
+ class EndpointHandler():
16
+
17
+ def __init__(self, path=""):
18
+
19
+ self.processor = Blip2Processor.from_pretrained(path)
20
+ self.model = Blip2ForConditionalGeneration.from_pretrained(path, device_map="auto")
21
+
22
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ # self.model.eval()
24
+ # self.model.to(device=device, dtype=self.torch_dtype)
25
+
26
+ # self.generate_kwargs = {
27
+ # 'max_new_tokens': 512,
28
+ # 'temperature': 0.0001,
29
+ # 'top_p': 1.0,
30
+ # 'top_k': 0,
31
+ # 'use_cache': True,
32
+ # 'do_sample': True,
33
+ # 'eos_token_id': self.tokenizer.eos_token_id,
34
+ # 'pad_token_id': self.tokenizer.pad_token_id,
35
+ # "repetition_penalty": 1.1
36
+ # }
37
+
38
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
39
+ """
40
+ data args:
41
+ inputs (:obj: `str` | `PIL.Image` | `np.array`)
42
+ kwargs
43
+ Return:
44
+ A :obj:`list` | `dict`: will be serialized and returned
45
+ """
46
+
47
+ # streamer = TextIteratorStreamer(
48
+ # self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
49
+ # )
50
+
51
+ ## Model Parameters
52
+ # self.generate_kwargs['max_new_tokens'] = data['max_new_tokens'] if 'max_new_tokens' in data else self.generate_kwargs['max_new_tokens']
53
+ # self.generate_kwargs['temperature'] = data['temperature'] if 'temperature' in data else self.generate_kwargs['temperature']
54
+ # self.generate_kwargs['top_p'] = data['top_p'] if 'top_p' in data else self.generate_kwargs['top_p']
55
+ # self.generate_kwargs['top_k'] = data['top_k'] if 'top_k' in data else self.generate_kwargs['top_k']
56
+ # self.generate_kwargs['do_sample'] = data['do_sample'] if 'do_sample' in data else self.generate_kwargs['do_sample']
57
+ # self.generate_kwargs['repetition_penalty'] = data['repetition_penalty'] if 'repetition_penalty' in data else self.generate_kwargs['repetition_penalty']
58
+
59
+
60
+ ## Prepare the inputs
61
+ # inputs = data.pop("inputs",data)
62
+ # input_ids = self.tokenizer(inputs, return_tensors="pt").input_ids
63
+ # input_ids = input_ids.to(self.model.device)
64
+
65
+
66
+ # pip install accelerate
67
+
68
+ img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
69
+
70
+ now = datetime.now()
71
+
72
+ raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
73
+
74
+ question = "how many dogs are in the picture?"
75
+ inputs = self.processor(raw_image, question, return_tensors="pt").to("cuda")
76
+
77
+ out = self.model.generate(**inputs)
78
+ output_text = self.processor.decode(out[0], skip_special_tokens=True)
79
+
80
+ current = datetime.now()
81
+
82
+ # encoded_inp = self.tokenizer(inputs, return_tensors='pt', padding=True)
83
+ # for key, value in encoded_inp.items():
84
+ # encoded_inp[key] = value.to('cuda:0')
85
+
86
+ ## Invoke the model
87
+ # with torch.no_grad():
88
+ # gen_tokens = self.model.generate(
89
+ # input_ids=encoded_inp['input_ids'],
90
+ # attention_mask=encoded_inp['attention_mask'],
91
+ # **generate_kwargs,
92
+ # )
93
+
94
+ # ## Decode using tokenizer
95
+ # decoded_gen = self.tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
96
+
97
+ # with torch.no_grad():
98
+ # output_ids = self.model.generate(input_ids, **self.generate_kwargs)
99
+ # # Slice the output_ids tensor to get only new tokens
100
+ # new_tokens = output_ids[0, len(input_ids[0]) :]
101
+ # output_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
102
+
103
+ return [{"gen_text":output_text, "time_elapsed": str(current-now)}]
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ Pillow
2
+ requests
3
+ accelerate
4
+ torch
5
+ transformers