File size: 8,871 Bytes
673cd4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53f2284
673cd4d
 
 
 
 
 
 
 
 
53f2284
 
 
 
673cd4d
 
 
53f2284
 
 
 
 
 
673cd4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c4ca38
 
 
 
 
 
 
 
 
 
673cd4d
 
 
 
 
 
 
 
 
 
 
 
 
 
b54f700
 
673cd4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24dcd1f
673cd4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209

#!/usr/bin/env python

from argparse import ArgumentParser
from functools import lru_cache
from importlib.resources import files
from inspect import signature
from multiprocessing.pool import ThreadPool
from tempfile import NamedTemporaryFile
from textwrap import dedent
from typing import Optional

from PIL import Image
import fitz
import gradio as gr
from transformers import TextIteratorStreamer, pipeline, ImageToTextPipeline, AutoModelForPreTraining, AutoProcessor

from infer import TikzDocument, TikzGenerator

# Disabled: would locate the example assets for the gr.Examples TODO in build_ui.
# assets = files(__package__) / "assets" if __package__ else files("assets") / "."
# Models offered in the UI dropdown: display label -> model id. The id (dict
# value) is what inference()/cached_load() receive as `model_name`; build_ui's
# `model` argument is a key of this dict.
models = {
    "Fine-tuned Llava": "waleko/TikZ-llava-1.5-7b"
}


def is_8bit(model_name):
    """Tell whether *model_name* names a fine-tuned TikZ-LLaVA checkpoint.

    Checkpoints matching this check are loaded with 8-bit weights by
    ``cached_load``; everything else goes through a plain ``pipeline`` call.
    """
    marker = "waleko/TikZ-llava"
    return marker in model_name


@lru_cache(maxsize=1)
def cached_load(model_name, **kwargs) -> ImageToTextPipeline:
    """Build and memoize an image-to-text pipeline for *model_name*.

    Only the most recently requested model is cached (``maxsize=1``), so
    asking for a different model evicts the previous pipeline from the cache.
    Extra keyword arguments go to ``pipeline`` for regular models and to
    ``from_pretrained`` for the 8-bit (``is_8bit``) checkpoints.
    """
    gr.Info("Instantiating model. Could take a while...") # type: ignore
    if is_8bit(model_name):
        # Quantized checkpoints: load model and processor separately so the
        # weights can be loaded in 8-bit before wrapping them in a pipeline.
        quantized = AutoModelForPreTraining.from_pretrained(model_name, load_in_8bit=True, **kwargs)
        proc = AutoProcessor.from_pretrained(model_name)
        return pipeline(
            task="image-to-text",
            model=quantized,
            tokenizer=proc.tokenizer,
            image_processor=proc.image_processor,
        )
    return pipeline("image-to-text", model=model_name, **kwargs)


def convert_to_svg(pdf):
    """Render the first page of a compiled PDF as an SVG string.

    ``pdf.raw`` is opened as an in-memory PDF stream with PyMuPDF (assumes
    ``raw`` holds the PDF bytes -- TODO confirm against TikzDocument.pdf).
    The document is closed as soon as the SVG markup has been produced
    instead of being left open until garbage collection.
    """
    with fitz.open("pdf", pdf.raw) as doc: # type: ignore
        return doc[0].get_svg_image()


def inference(
    model_name: str,
    image: Image.Image,
    temperature: float,
    top_p: float,
    top_k: int,
    expand_to_square: bool,
):
    """Stream TikZ code generated from *image* by the model *model_name*.

    Yields ``(code, None, finished)`` triples matching the UI outputs
    (code view, compiled image, hidden "finished" flag): the accumulated
    partial text with ``finished=False`` while tokens arrive, then the final
    ``.code`` of the generation result with ``finished=True``.

    *image* may be a PIL image or the dict produced by ``gr.ImageEditor``
    (keys ``background``/``layers``/``composite``); in the latter case the
    flattened ``composite`` image is used.

    Any exception raised by the generation call is re-raised here instead
    of leaving the consumer blocked on the token stream forever.
    """
    from queue import Empty  # raised by the streamer when its poll timeout elapses

    if isinstance(image, dict):
        image = image.get("composite")

    generate = TikzGenerator(
        cached_load(model_name, device_map="auto"),
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        expand_to_square=expand_to_square,
    )
    streamer = TextIteratorStreamer(
        generate.pipeline.tokenizer, # type: ignore
        # Poll instead of blocking forever: if generation dies in the worker
        # thread, the end-of-stream signal never arrives.
        timeout=1.0,
        skip_prompt=True,
        skip_special_tokens=True
    )

    pool = ThreadPool(processes=1)
    try:
        async_result = pool.apply_async(generate, kwds=dict(image=image, streamer=streamer))

        generated_text = ""
        while True:
            try:
                new_text = next(streamer)
            except StopIteration:
                break
            except Empty:
                if async_result.ready():
                    # Generation ended without closing the stream (typically
                    # because it raised); get() re-raises that error here.
                    async_result.get()
                    break
                continue
            generated_text += new_text
            yield generated_text, None, False
        yield async_result.get().code, None, True
    finally:
        # close() (not terminate(), which would join a still-running worker)
        # lets the pool's threads exit once the task completes, even when
        # this generator is cancelled mid-stream.
        pool.close()

def tex_compile(
    code: str,
    timeout: int,
    rasterize: bool
):
    """Compile TikZ *code* and yield a single displayable result.

    When *rasterize* is set, yields ``TikzDocument.rasterize()``. Otherwise
    yields the path of a temporary ``.svg`` file converted from the compiled
    PDF, or ``None`` when no PDF was produced. The temporary file is deleted
    once the generator resumes or is closed, so the consumer must read it
    while the generator is suspended at that yield.

    Raises gr.Error when compilation failed without producing content; emits
    a gr.Warning for an empty image or for output compiled with errors.
    """
    doc = TikzDocument(code, timeout=timeout)
    has_content = doc.has_content
    had_errors = doc.compiled_with_errors

    if has_content:
        if had_errors:
            gr.Warning("TikZ code compiled with errors!") # type: ignore
    elif had_errors:
        raise gr.Error("TikZ code did not compile!") # type: ignore
    else:
        gr.Warning("TikZ code compiled to an empty image!") # type: ignore

    if rasterize:
        yield doc.rasterize()
        return

    with NamedTemporaryFile(suffix=".svg", buffering=0) as svg_file:
        pdf = doc.pdf
        if pdf:
            svg_file.write(convert_to_svg(pdf).encode())
            yield svg_file.name
        else:
            yield None

def check_inputs(image: "Image.Image | dict | None"):
    """Reject a run that has no input image by raising ``gr.Error``.

    Accepts a PIL image (``gr.Image``) or the dict produced by
    ``gr.ImageEditor`` (keys ``background``/``layers``/``composite``). The
    editor sends such a dict even when nothing was uploaded, with
    ``composite`` set to ``None``, so the dict itself is always truthy and
    must be unwrapped before checking.
    """
    if isinstance(image, dict):
        image = image.get("composite")
    if image is None:
        raise gr.Error("Image is required")

def get_banner():
    """Return the Markdown/HTML header shown at the top of the app.

    The literal is indented along with the code and passed through ``dedent``
    so the Markdown heading ends up at column 0.
    """
    markdown = '''\
    # AutomaTi*k*Z: Text-Guided Synthesis of Scientific Vector Graphics with Ti*k*Z

    <p>
      <a style="display:inline-block" href="https://github.com/potamides/AutomaTikZ">
        <img src="https://img.shields.io/badge/View%20on%20GitHub-green?logo=github&labelColor=gray" alt="View on GitHub">
      </a>
      <a style="display:inline-block" href="https://arxiv.org/abs/2310.00367">
        <img src="https://img.shields.io/badge/View%20on%20arXiv-B31B1B?logo=arxiv&labelColor=gray" alt="View on arXiv">
      </a>
      <a style="display:inline-block" href="https://colab.research.google.com/drive/14S22x_8VohMr9pbnlkB4FqtF4n81khIh">
        <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab">
      </a>
      <a style="display:inline-block" href="https://huggingface.co/spaces/nllg/AutomaTikZ">
        <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg" alt="Open in HF Spaces">
      </a>
    </p>
    '''
    return dedent(markdown)

def remove_darkness(stylable):
    """
    Patch gradio to only contain light mode colors.

    Placeholder: dark-mode removal is not implemented yet, so *stylable*
    (a gradio theme or Blocks instance) is returned unchanged. Returning it
    rather than ``None`` keeps call sites such as
    ``theme = remove_darkness(gr.themes.Soft())`` from silently replacing
    the chosen theme with ``None``.
    """
    # TODO: remove dark mode colors from the theme. Previously sketched
    # approach: for a gr.themes.Base, map every "<name>_dark" parameter of
    # stylable.set() to the value of its light "<name>" counterpart and call
    # stylable.set(**colors); for a gr.Blocks (components that ignore the
    # theme, e.g. modals), strip the "dark" CSS class from the page in a
    # load() JS hook; raise ValueError for anything else.
    return stylable

def build_ui(model=list(models)[0], lock=False, rasterize=False, force_light=False, lock_reason="locked", timeout=120):
    """Assemble the Gradio Blocks app for image-to-TikZ generation.

    Args:
        model: Key of ``models`` selecting the initially chosen model.
        lock: Make the model dropdown non-interactive.
        rasterize: Compile to a raster image (and enable the share button)
            instead of an SVG file.
        force_light: Pass the theme and the Blocks through ``remove_darkness``.
        lock_reason: Text appended to the model label when ``lock`` is set.
        timeout: Compilation timeout forwarded to ``tex_compile``.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    theme = gr.themes.Soft()
    if force_light:
        # Guard against remove_darkness returning None, which would silently
        # drop the Soft theme in favour of gradio's default.
        patched = remove_darkness(theme)
        if patched is not None:
            theme = patched
    with gr.Blocks(theme=theme, title="AutomaTikZ") as demo: # type: ignore
        if force_light:
            remove_darkness(demo)
        gr.Markdown(get_banner())
        with gr.Row(variant="panel"):
            with gr.Column():
                image = gr.ImageEditor(label="Image Input", type="pil")
                model_label = "Model" + (f" ({lock_reason})" if lock else "")
                model_dropdown = gr.Dropdown(label=model_label, choices=list(models.items()), value=models[model], interactive=not lock) # type: ignore
                with gr.Accordion(label="Advanced Options", open=False):
                    temperature = gr.Slider(minimum=0, maximum=2, step=0.05, value=0.8, label="Temperature")
                    top_p = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.95, label="Top-P")
                    top_k = gr.Slider(minimum=0, maximum=100, step=10, value=0, label="Top-K")
                    expand_to_square = gr.Checkbox(value=True, label="Expand image to square")
                with gr.Row():
                    run_btn = gr.Button("Run", variant="primary")
                    stop_btn = gr.Button("Stop")
                    clear_btn = gr.ClearButton([image])
            with gr.Column():
                with gr.Tabs() as tabs:
                    with gr.TabItem(label="TikZ Code", id=0):
                        tikz_code = gr.Code(label="TikZ Code", show_label=False, interactive=False)
                    with gr.TabItem(label="Compiled Image", id=1):
                        result_image = gr.Image(label="Compiled Image", show_label=False, show_share_button=rasterize)
                    clear_btn.add([tikz_code, result_image])
        # TODO: gr.Examples(examples=str(assets), inputs=[image, tikz_code, result_image])

        events = []
        finished = gr.Textbox(visible=False) # hack to cancel compile on canceled inference
        for listener in [run_btn.click]:
            # validate input -> switch to the code tab -> stream generation
            generate_event = listener(
                check_inputs,
                inputs=[image],
                queue=False
            ).success(
                lambda: gr.Tabs(selected=0),
                outputs=tabs, # type: ignore
                queue=False
            ).then(
                inference,
                inputs=[model_dropdown, image, temperature, top_p, top_k, expand_to_square],
                outputs=[tikz_code, result_image, finished]
            )

            def tex_compile_if_finished(finished, *args):
                # The hidden textbox holds the last `finished` flag yielded by
                # inference (compared as the string "True"); a cancelled run
                # never reaches True, so nothing is compiled.
                yield from (tex_compile(*args, timeout=timeout, rasterize=rasterize) if finished == "True" else [])

            # then: switch to the image tab (only if finished) -> compile
            compile_event = generate_event.then(
                lambda finished: gr.Tabs(selected=1) if finished == "True" else gr.Tabs(),
                inputs=finished,
                outputs=tabs, # type: ignore
                queue=False
            ).then(
                tex_compile_if_finished,
                inputs=[finished, tikz_code],
                outputs=result_image
            )
            events.extend([generate_event, compile_event])

        # Both Clear and Stop abort any in-flight generation/compilation.
        for btn in [clear_btn, stop_btn]:
            btn.click(fn=None, cancels=events, queue=False)
        return demo