adasdimchom committed · Commit 427113a · Parent(s): c37d664
Upload handler.py

handler.py CHANGED (+9 -29)
@@ -1,4 +1,4 @@
-from transformers import Blip2Processor,
+from transformers import Blip2Processor, Blip2ForConditionalGeneration
 from typing import Dict, List, Any
 from PIL import Image
 from transformers import pipeline
@@ -12,11 +12,7 @@ class EndpointHandler():
         """
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.processor = Blip2Processor.from_pretrained(path)
-        self.
-        self.generate_model.to(self.device)
-
-        #self.feature_model = Blip2Model.from_pretrained(path, torch_dtype=torch.float16)
-        #self.feature_model.to(self.device)
+        self.model = Blip2ForConditionalGeneration.from_pretrained(path, torch_dtype=torch.float16).to(self.device)
 
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         """
@@ -33,28 +29,12 @@ class EndpointHandler():
             prompt = inputs["prompt"]
         else:
             prompt = None
-
-        # extract_feature = inputs["extract_feature"]
-        #else:
-        #    extract_feature = False
-
-        image = Image.open(requests.get(image_url, stream=True).raw)
-        processed_image = self.processor(images=image, return_tensors="pt").to(self.device, torch.float16)
-        generated_ids = self.generate_model.generate(**processed_image)
-        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
-        result["image_caption"] = generated_text
-
-        #if extract_feature:
-        #    caption_feature = self.feature_model(**processed_image)
-        #    result["caption_feature"] = caption_feature
-
+        image = Image.open(requests.get(image_url, stream=True).raw).convert('RGB')
         if prompt:
-
-
-
-
-
-
-            # result["prompt_feature"] = prompt_feature
-
+            processed_image = self.processor(images=image, text=prompt, return_tensors="pt").to(self.device, torch.float16)
+        else:
+            processed_image = self.processor(images=image, return_tensors="pt").to(self.device, torch.float16)
+        output = self.model.generate(**processed_image)
+        text_output = self.processor.decode(output[0], skip_special_tokens=True)
+        result["text_output"] = text_output
         return result
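For reference, a minimal local smoke test of the updated handler might look like the sketch below. It is not part of the commit: the checkpoint path, the {"inputs": {...}} payload shape, and the image URL are assumptions, since the code that unpacks data into inputs and image_url falls outside the diff context. The visible hunks only establish that inputs carries "image_url" and an optional "prompt", and that the handler returns result["text_output"].

# Hypothetical smoke test for the updated EndpointHandler (a sketch, not part of the commit).
# Assumed: the payload shape {"inputs": {...}} and the BLIP-2 checkpoint path;
# the example image URL is a placeholder.
from handler import EndpointHandler

handler = EndpointHandler(path="Salesforce/blip2-opt-2.7b")  # assumed checkpoint path

# Unconditional captioning: no prompt, so the processor receives only the image.
print(handler({"inputs": {"image_url": "https://example.com/cat.jpg"}})["text_output"])

# Prompt-conditioned generation: the prompt is tokenized together with the image,
# so the model answers/continues it (visual question answering style).
payload = {
    "inputs": {
        "image_url": "https://example.com/cat.jpg",
        "prompt": "Question: what animal is in the picture? Answer:",
    }
}
print(handler(payload)["text_output"])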