File size: 2,994 Bytes
c3a1897
 
 
 
 
 
 
 
 
 
 
51f8a02
c3a1897
 
 
 
 
 
 
 
 
eb902b3
c3a1897
eb902b3
c3a1897
 
 
 
 
eb902b3
 
 
 
 
c3a1897
eb902b3
8381241
b510b75
 
eb902b3
c3a1897
 
 
 
 
51f8a02
c3a1897
eb902b3
 
 
 
 
 
 
 
 
 
 
 
 
c3a1897
44a0c32
c3a1897
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from models.blip2_model import ImageCaptioning
from models.grit_model import DenseCaptioning
from models.gpt_model import ImageToText
from models.controlnet_model import TextToImage
from models.region_semantic import RegionSemantic
from utils.util import read_image_width_height, display_images_and_text
import argparse
from PIL import Image
import base64
from io import BytesIO
import os
from utils.util import resize_long_edge

def pil_image_to_base64(image):
    """Encode a PIL image as a base64 JPEG string.

    JPEG cannot store alpha or palette data, so an image in any mode the
    JPEG writer does not accept (e.g. ``RGBA``, ``LA``, ``P``) is converted
    to ``RGB`` first; saving it directly would make PIL raise ``OSError``.
    The caller's image object is not modified.

    Args:
        image: A ``PIL.Image.Image`` to encode.

    Returns:
        str: The JPEG bytes, base64-encoded and decoded to text.
    """
    # Modes Pillow's JPEG encoder can write without conversion.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if image.mode not in jpeg_modes:
        # convert() returns a new image, leaving the caller's untouched.
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode()


class ImageTextTransformation:
    """Image <-> text pipeline built from captioning models, GPT and ControlNet.

    ``image_to_text`` produces captions plus a GPT-written paragraph for an
    image; ``text_to_image`` renders text back into an image using the most
    recently processed image as the ControlNet reference.
    """

    def __init__(self, args):
        """Store the options object and load every model.

        Args:
            args: Options namespace (presumably produced by ``argparse``).
                The attributes read are listed in ``init_models`` and
                ``image_to_text``.
        """
        # Load your big model here
        self.args = args
        self.init_models()
        # Set by image_to_text(); used as the reference by text_to_image().
        self.ref_image = None

    def init_models(self):
        """Instantiate all models on the devices configured in ``self.args``.

        Reads the OpenAI API key from the ``OPENAI_KEY`` environment variable
        and prints coloured progress banners to stdout. The region-semantic
        model is only loaded when ``args.semantic_segment`` is truthy.

        Raises:
            KeyError: If ``OPENAI_KEY`` is not set. This is checked before
                any model is loaded, so the failure is immediate.
        """
        openai_key = os.environ.get('OPENAI_KEY')
        if openai_key is None:
            raise KeyError(
                "OPENAI_KEY environment variable is not set; "
                "export your OpenAI API key before starting the toolbox."
            )
        print('\033[1;34m' + "Welcome to the Image2Paragraph toolbox...".center(50, '-') + '\033[0m')
        print('\033[1;33m' + "Initializing models...".center(50, '-') + '\033[0m')
        print('\033[1;31m' + "This is time-consuming, please wait...".center(50, '-') + '\033[0m')
        self.image_caption_model = ImageCaptioning(device=self.args.image_caption_device)
        self.dense_caption_model = DenseCaptioning(device=self.args.dense_caption_device)
        self.gpt_model = ImageToText(openai_key)
        # The attribute is spelled 'contolnet' — presumably matching the CLI
        # option defined by the caller; rename both together if ever fixed.
        self.controlnet_model = TextToImage(device=self.args.contolnet_device)
        # Time-consuming on CPU; enable only when running locally on a GPU.
        if self.args.semantic_segment:
            self.region_semantic_model = RegionSemantic(device=self.args.semantic_segment_device)
        print('\033[1;32m' + "Model initialization finished!".center(50, '-') + '\033[0m')

    def image_to_text(self, img_src):
        """Describe an image and summarise the description with GPT.

        A resized copy of the image is also stored in ``self.ref_image`` for
        a later ``text_to_image`` call.

        Args:
            img_src: Image source, presumably a file path. It is opened with
                ``Image.open`` and also passed unchanged to each captioning
                model.

        Returns:
            tuple: ``(image_caption, dense_caption, region_semantic,
            generated_text)``. Each component disabled in ``self.args`` is
            the placeholder string ``" "``.
        """
        # Open in a context manager so the file handle is released here;
        # copy() loads the pixels into an independent image first.
        with Image.open(img_src) as opened:
            self.ref_image = resize_long_edge(opened.copy())
        width, height = read_image_width_height(img_src)
        print(self.args)
        if self.args.image_caption:
            image_caption = self.image_caption_model.image_caption(img_src)
        else:
            image_caption = " "
        if self.args.dense_caption:
            dense_caption = self.dense_caption_model.image_dense_caption(img_src)
        else:
            dense_caption = " "
        if self.args.semantic_segment:
            region_semantic = self.region_semantic_model.region_semantic(img_src)
        else:
            region_semantic = " "
        generated_text = self.gpt_model.paragraph_summary_with_gpt(image_caption, dense_caption, region_semantic, width, height)
        return image_caption, dense_caption, region_semantic, generated_text

    def text_to_image(self, text):
        """Generate an image from ``text`` using ``self.ref_image`` as reference.

        NOTE(review): ``self.ref_image`` is ``None`` until ``image_to_text``
        has run — confirm whether ``TextToImage.text_to_image`` accepts that.
        """
        generated_image = self.controlnet_model.text_to_image(text, self.ref_image)
        return generated_image

    def text_to_image_retrieval(self, text):
        """Placeholder: not implemented yet; returns ``None``."""
        pass

    def image_to_text_retrieval(self, image):
        """Placeholder: not implemented yet; returns ``None``."""
        pass