import base64
from io import BytesIO
from typing import Any, Dict, List

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration

class EndpointHandler:
    def __init__(self, path=""):
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        print("device:", self.device)
        self.model_base = "Salesforce/blip2-opt-2.7b"  # base checkpoint this model was fine-tuned from
        self.model_name = "sooh-j/VQA-for-VIP"
        self.processor = AutoProcessor.from_pretrained(self.model_name)
        # device_map="auto" lets accelerate place the weights itself; chaining
        # .to(self.device) on top of that can conflict with the device map,
        # so the model is left where accelerate puts it.
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            self.model_name,
            device_map="auto",
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`dict`) with keys:
                image (:obj:`str`): an http(s) URL or a base64-encoded image
                question (:obj:`str`): the question to ask about the image
        Return:
            A :obj:`list` of :obj:`dict`: will be serialized and returned
        """
        inputs = data.get("inputs")
        image_input = inputs.get("image")
        question = inputs.get("question")

        # The image may arrive either as a URL or as a base64 string
        # (possibly with a "data:image/...;base64," prefix).
        if image_input.startswith(("http://", "https://")):
            image = Image.open(requests.get(image_input, stream=True).raw)
        else:
            # Take the part after the comma so a data-URI prefix is stripped
            # before decoding; a bare base64 string is unaffected.
            image = Image.open(BytesIO(base64.b64decode(image_input.split(",")[-1])))
        image = image.convert("RGB")

        prompt = f"Question: {question}, Answer:"
        processed = self.processor(images=image, text=prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            out = self.model.generate(**processed, 
                                     max_new_tokens=20,
                                     # temperature = 0.5,
                                     # do_sample=True,
                                     # top_k=50,
                                     # top_p=0.9,
                                     repetition_penalty=1.2  
                                     ).to(self.device)
        
        result = {}
        text_output = self.processor.decode(out[0], skip_special_tokens=True)
        result["text_output"] = text_output
        score = 0
        
        return [{"answer":text_output,"score":score}]