import gradio as gr
from transformers import AutoProcessor, Pix2StructForConditionalGeneration
import torch
from PIL import Image
import json
import vl_convert as vlc  # renders Vega-Lite specs to PNG without a headless browser
from io import BytesIO

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the MatCha processor; disable the VQA prompt since we only pass images
processor = AutoProcessor.from_pretrained("google/matcha-base")
processor.image_processor.is_vqa = False

# Load the fine-tuned Pix2Struct checkpoint and switch to inference mode
model = Pix2StructForConditionalGeneration.from_pretrained("martinsinnona/visdecode_B").to(device)
model.eval()

def generate(image):
    """Run the model on a chart image; return the decoded spec string and a rendered preview."""
    inputs = processor(images=image, return_tensors="pt", max_patches=1024).to(device)
    generated_ids = model.generate(
        flattened_patches=inputs.flattened_patches,
        attention_mask=inputs.attention_mask,
        max_length=600,
    )
    generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # Parse the decoded string into a Vega-Lite spec and render it as an image
    vega = string_to_vega(generated_caption)
    vega_image = draw_vega(vega)

    return generated_caption, vega_image
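
# Quick local sanity check (illustrative only; "chart.png" is a hypothetical file):
#   caption, preview = generate(Image.open("chart.png"))
#   print(caption)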

def draw_vega(vega, scale=3):
    """Render a Vega-Lite spec (dict) to a PIL image via vl-convert."""
    spec = json.dumps(vega, indent=4)
    png_data = vlc.vegalite_to_png(vl_spec=spec, scale=scale)

    return Image.open(BytesIO(png_data))

def string_to_vega(string):
    """Parse the model's Python-style dict string into a valid Vega-Lite dict."""
    # The model emits single-quoted keys/values; swap quotes so json.loads accepts it
    string = string.replace("'", "\"")
    vega = json.loads(string)

    # The model names data keys generically ("x"/"y"); rename them to the
    # declared encoding fields, or fall back to the axis name if no field is set
    for axis in ["x", "y"]:
        field = vega["encoding"][axis]["field"]
        if field == "":
            vega["encoding"][axis]["field"] = axis
            vega["encoding"][axis]["title"] = ""
        else:
            for entry in vega["data"]["values"]:
                entry[field] = entry.pop(axis)
    return vega
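
# Example of the transformation (hypothetical model output):
#   in:  "{'mark': 'bar', 'encoding': {'x': {'field': 'year', ...}, ...},
#         'data': {'values': [{'x': 2020, 'y': 5}, ...]}}"
#   out: the same spec with double quotes, and each data row rekeyed to
#        {'year': 2020, 'y': 5} so it matches the encoding's field name.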

# Create the Gradio interface
iface = gr.Interface(
    fn=generate,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Textbox(label="Generated Vega-Lite Spec"),
             gr.Image(type="pil", label="Generated Vega Image")],
    title="Image to Vega-Lite",
    description="Upload a chart image to generate its Vega-Lite specification",
)

# Launch the interface
if __name__ == "__main__":
    iface.launch(share=True)