adasdimchom commited on
Commit
427113a
1 Parent(s): c37d664

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +9 -29
handler.py CHANGED
@@ -1,4 +1,4 @@
1
- from transformers import Blip2Processor, Blip2Model, Blip2ForConditionalGeneration
2
  from typing import Dict, List, Any
3
  from PIL import Image
4
  from transformers import pipeline
@@ -12,11 +12,7 @@ class EndpointHandler():
12
  """
13
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
14
  self.processor = Blip2Processor.from_pretrained(path)
15
- self.generate_model = Blip2ForConditionalGeneration.from_pretrained(path, torch_dtype=torch.float16)
16
- self.generate_model.to(self.device)
17
-
18
- #self.feature_model = Blip2Model.from_pretrained(path, torch_dtype=torch.float16)
19
- #self.feature_model.to(self.device)
20
 
21
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
22
  """
@@ -33,28 +29,12 @@ class EndpointHandler():
33
  prompt = inputs["prompt"]
34
  else:
35
  prompt = None
36
- #if "extract_feature" in inputs:
37
- # extract_feature = inputs["extract_feature"]
38
- #else:
39
- # extract_feature = False
40
-
41
- image = Image.open(requests.get(image_url, stream=True).raw)
42
- processed_image = self.processor(images=image, return_tensors="pt").to(self.device, torch.float16)
43
- generated_ids = self.generate_model.generate(**processed_image)
44
- generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
45
- result["image_caption"] = generated_text
46
-
47
- #if extract_feature:
48
- # caption_feature = self.feature_model(**processed_image)
49
- # result["caption_feature"] = caption_feature
50
-
51
  if prompt:
52
- prompt_image_processed = self.processor(images=image, text=prompt, return_tensors="pt").to(self.device, torch.float16)
53
- generated_ids = self.generate_model.generate(**prompt_image_processed)
54
- generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
55
- result["image_prompt"] = generated_text
56
- #if extract_feature:
57
- # prompt_feature = self.feature_model(**prompt_image_processed)
58
- # result["prompt_feature"] = prompt_feature
59
-
60
  return result
 
1
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration
2
  from typing import Dict, List, Any
3
  from PIL import Image
4
  from transformers import pipeline
 
12
  """
13
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
14
  self.processor = Blip2Processor.from_pretrained(path)
15
+ self.model = Blip2ForConditionalGeneration.from_pretrained(path, torch_dtype=torch.float16).to(self.device)
 
 
 
 
16
 
17
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
18
  """
 
29
  prompt = inputs["prompt"]
30
  else:
31
  prompt = None
32
+ image = Image.open(requests.get(image_url, stream=True).raw).convert('RGB')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  if prompt:
34
+ processed_image = self.processor(images=image, text=prompt, return_tensors="pt").to(self.device, torch.float16)
35
+ else:
36
+ processed_image = self.processor(images=image, return_tensors="pt").to(self.device, torch.float16)
37
+ output = self.model.generate(**processed_image)
38
+ text_output = self.processor.decode(output[0], skip_special_tokens=True)
39
+ result["text_output"] = text_output
 
 
40
  return result