File size: 10,669 Bytes
673cd4d
 
755b6ea
673cd4d
 
 
 
 
 
 
 
 
 
 
 
 
53f2284
673cd4d
755b6ea
ea58d06
755b6ea
77f8033
755b6ea
 
6ae92ca
673cd4d
 
 
ea58d06
cd9a590
 
673cd4d
 
 
77f4da6
53f2284
 
 
673cd4d
cd9a590
 
 
 
 
 
673cd4d
77f4da6
cd9a590
53f2284
cd9a590
53f2284
 
673cd4d
 
 
 
 
 
 
ea58d06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
755b6ea
 
673cd4d
cd9a590
77f4da6
673cd4d
 
 
 
 
77f4da6
755b6ea
1ea18f2
ea58d06
755b6ea
 
77f4da6
cd9a590
77f4da6
 
 
 
 
 
 
 
 
 
 
 
77f8033
77f4da6
 
 
 
 
 
 
673cd4d
 
 
 
 
 
 
 
 
 
 
 
 
77f4da6
 
673cd4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4dc09ef
673cd4d
 
6786a25
 
 
 
 
 
 
 
 
 
4dc09ef
673cd4d
 
 
 
 
 
 
 
 
77f4da6
 
 
 
 
 
 
 
 
673cd4d
 
 
4dc09ef
673cd4d
 
 
 
 
 
 
 
 
 
b54f700
77f4da6
673cd4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24dcd1f
673cd4d
 
 
755b6ea
7bb8428
 
 
 
755b6ea
673cd4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261

#!/usr/bin/env python
import re

from argparse import ArgumentParser
from functools import lru_cache
from importlib.resources import files
from inspect import signature
from multiprocessing.pool import ThreadPool
from tempfile import NamedTemporaryFile
from textwrap import dedent
from typing import Optional

from PIL import Image
import fitz
import gradio as gr
from transformers import TextIteratorStreamer, pipeline, ImageToTextPipeline, AutoModelForPreTraining, AutoProcessor

import os
# from pix2tex.cli import LatexOCR
from munch import Munch
import spaces


from infer import TikzDocument, TikzGenerator

# assets = files(__package__) / "assets" if __package__ else files("assets") / "."
# Display name -> Hugging Face Hub repo id. A value may carry a revision after
# a single space (e.g. "... v2"); cached_load splits it off before loading.
# (A "pix2tikz" local checkpoint entry used to live here and is disabled.)
models = dict([
    ("llava-1.5-7b-hf", "waleko/TikZ-llava-1.5-7b"),
    ("new llava-1.5-7b-hf", "waleko/TikZ-llava-1.5-7b v2"),
])


def is_quantization(model_name):
    """Return True if *model_name* names a waleko/TikZ-llava checkpoint.

    Such checkpoints take the 4-bit quantized loading path in ``cached_load``.
    """
    quantized_family = "waleko/TikZ-llava"
    return quantized_family in model_name


@lru_cache(maxsize=1)
def cached_load(model_name, **kwargs) -> ImageToTextPipeline:
    """Load an image-to-text pipeline for *model_name*, caching the last one.

    *model_name* is a Hub repo id, optionally followed by one space and a
    revision (e.g. ``"waleko/TikZ-llava-1.5-7b v2"``). The revision defaults
    to ``"main"``. The cache holds only a single entry (``maxsize=1``), keyed
    on *model_name* and *kwargs*, so requesting a different model reloads.

    Args:
        model_name: ``"<repo id>"`` or ``"<repo id> <revision>"``.
        **kwargs: Forwarded to ``pipeline(...)`` for regular models, or to
            ``AutoModelForPreTraining.from_pretrained`` for quantized ones.

    Returns:
        An ``image-to-text`` pipeline.

    Raises:
        ValueError: if *model_name* is empty or has more than one
            space-separated suffix.
    """
    parts = model_name.split()
    if len(parts) not in (1, 2):
        raise ValueError(f"Expected '<repo id>' or '<repo id> <revision>', got {model_name!r}")
    repo_id, revision = (parts[0], parts[1]) if len(parts) == 2 else (parts[0], "main")

    gr.Info("Instantiating model. Could take a while...") # type: ignore
    if not is_quantization(repo_id):
        return pipeline("image-to-text", model=repo_id, revision=revision, **kwargs)

    model = AutoModelForPreTraining.from_pretrained(repo_id, load_in_4bit=True, revision=revision, **kwargs)
    # Load the processor from the same revision as the weights, so the
    # tokenizer and image preprocessing match the selected checkpoint.
    processor = AutoProcessor.from_pretrained(repo_id, revision=revision)
    return pipeline(task="image-to-text", model=model, tokenizer=processor.tokenizer, image_processor=processor.image_processor)


def convert_to_svg(pdf):
    """Render the first page of *pdf* to an SVG string.

    Args:
        pdf: Object exposing the raw PDF bytes as ``.raw`` (presumably the
            ``pdf`` attribute of a ``TikzDocument`` -- confirm).

    Returns:
        The SVG markup of page 0.
    """
    # Use the document as a context manager so the PyMuPDF handle is closed
    # instead of leaking on every compile.
    with fitz.open("pdf", pdf.raw) as doc: # type: ignore
        return doc[0].get_svg_image()


# def pix2tikz(
#     checkpoint: str,
#     image: Image.Image,
#     temperature: float,
#     _: float,
#     __: int,
#     ___: bool,
# ):
#     cur_pwd = os.path.dirname(os.path.abspath(__file__))
#     config_path = os.path.join(cur_pwd, 'pix2tikz/config.yaml')
#     model_path = os.path.join(cur_pwd, checkpoint)
#
#     print(cur_pwd, config_path, model_path, os.path.exists(config_path), os.path.exists(model_path))
#
#     args = Munch({'config': config_path,
#                   'checkpoint': model_path,
#                   'no_resize': False,
#                   'no_cuda': False,
#                   'temperature': temperature})
#     model = LatexOCR(args)
#     res = model(image)
#     text = re.sub(r'\\n(?=\W)', '\n', res)
#     return text, None, True


def inference(
    model_name: str,
    image_dict: dict,
    temperature: float,
    top_p: float,
    top_k: int,
    expand_to_square: bool,
):
    """Generate TikZ code for a sketch, streaming partial results.

    Args:
        model_name: Model spec passed to ``cached_load`` (repo id, optionally
            followed by a space and a revision).
        image_dict: Value of the image editor input. Its ``'composite'`` entry
            is the image handed to the generator.
        temperature: Sampling option forwarded to ``TikzGenerator``.
        top_p: Sampling option forwarded to ``TikzGenerator``.
        top_k: Sampling option forwarded to ``TikzGenerator``.
        expand_to_square: Preprocessing option forwarded to ``TikzGenerator``.

    Yields:
        ``(code, None, finished)`` tuples. While tokens stream in, ``code`` is
        the text accumulated so far and ``finished`` is False. A final tuple
        carries the generator result's ``.code`` with ``finished`` True.

    Raises:
        gr.Error: wrapping any exception raised while loading or generating.
    """
    try:
        image = image_dict['composite']
        if "pix2tikz" in model_name:
            # The pix2tikz backend is disabled; produce no output.
            # yield pix2tikz(model_name, image, temperature, top_p, top_k, expand_to_square)
            return

        generate = TikzGenerator(
            cached_load(model_name, device_map="auto"),
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            expand_to_square=expand_to_square,
        )
        streamer = TextIteratorStreamer(
            generate.pipeline.tokenizer, # type: ignore
            skip_prompt=True,
            skip_special_tokens=True
        )

        # Run generation on a worker thread so this generator can consume the
        # streamer concurrently. The pool is closed in `finally` so its threads
        # do not leak on every request, or when the consumer cancels.
        pool = ThreadPool(processes=1)
        try:
            async_result = pool.apply_async(
                spaces.GPU(generate),
                kwds=dict(image=image, streamer=streamer),
                # If generation raises, the streamer never receives its stop
                # signal and the loop below would block forever. Ending it lets
                # the loop finish so `get()` re-raises the worker's exception.
                error_callback=lambda _exc: streamer.end(),
            )
            generated_text = ""
            for new_text in streamer:
                generated_text += new_text
                yield generated_text, None, False
            yield async_result.get().code, None, True
        finally:
            pool.close()
    except Exception as e:
        raise gr.Error(f"Internal Error! {e}") from e

def tex_compile(
    code: str,
    timeout: int,
    rasterize: bool
):
    """Compile TikZ source and yield exactly one displayable result.

    Args:
        code: TikZ/LaTeX source handed to ``TikzDocument``.
        timeout: Forwarded to ``TikzDocument``.
        rasterize: If true, yield ``document.rasterize()``. Otherwise yield
            the path of a temporary ``.svg`` file converted from the compiled
            PDF, or ``None`` when the document has no PDF.

    Raises:
        gr.Error: when the document has no content and compiled with errors.

    Note:
        The temporary SVG exists only while this generator is suspended on
        its ``yield``. It is deleted once the generator resumes or is closed.
    """
    document = TikzDocument(code, timeout=timeout)

    # Only a failed compile that rendered nothing is fatal. Errors that still
    # produced content are logged, and an empty render is just a warning.
    if document.has_content:
        if document.compiled_with_errors:
            print("TikZ code compiled with errors!")
    elif document.compiled_with_errors:
        raise gr.Error("TikZ code did not compile!") # type: ignore
    else:
        gr.Warning("TikZ code compiled to an empty image!") # type: ignore

    if rasterize:
        yield document.rasterize()
        return

    with NamedTemporaryFile(suffix=".svg", buffering=0) as svg_file:
        pdf = document.pdf
        if not pdf:
            yield None
        else:
            svg_file.write(convert_to_svg(pdf).encode())
            yield svg_file.name

def check_inputs(image: Optional[dict]):
    """Reject a run before inference when no usable image was provided.

    Args:
        image: Value of the image editor input. ``inference`` later reads its
            ``'composite'`` entry.

    Raises:
        gr.Error: if *image* is falsy, or is a dict whose ``'composite'`` is
            ``None``.
    """
    if not image:
        raise gr.Error("Image is required")
    # An editor value is a dict and therefore truthy even when it holds no
    # image. Presumably a cleared editor yields composite=None; without this
    # check, inference would fail later with an opaque error.
    if isinstance(image, dict) and image.get("composite") is None:
        raise gr.Error("Image is required")

def get_banner():
    """Return the Markdown/HTML header rendered at the top of the UI."""
    banner_markdown = '''\
    # Ti*k*Z Assistant: Sketches to Vector Graphics with Ti*k*Z

    <p>
    
      <!--<a style="display:inline-block" href="https://github.com/potamides/AutomaTikZ">
        <img src="https://img.shields.io/badge/View%20on%20GitHub-green?logo=github&labelColor=gray" alt="View on GitHub">
      </a>
      <a style="display:inline-block" href="https://arxiv.org/abs/2310.00367">
        <img src="https://img.shields.io/badge/View%20on%20arXiv-B31B1B?logo=arxiv&labelColor=gray" alt="View on arXiv">
      </a>
      <a style="display:inline-block" href="https://colab.research.google.com/drive/14S22x_8VohMr9pbnlkB4FqtF4n81khIh">
        <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab">
      </a>-->
      <a style="display:inline-block" href="https://huggingface.co/spaces/waleko/TikZ-Assistant">
        <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg" alt="Open in HF Spaces">
      </a>
    </p>
    '''
    # Strip the source-code indentation so the Markdown heading starts at column 0.
    return dedent(banner_markdown)

def remove_darkness(stylable):
    """
    Patch gradio to only contain light mode colors.

    Args:
        stylable: A ``gr.themes.Base`` theme or a ``gr.Blocks`` app.

    Returns:
        For a theme: the result of ``stylable.set(...)``, where each settable
        attribute (including ``*_dark`` ones) is given the value of the
        attribute with the ``_dark`` suffix removed. For a Blocks app: the
        same app, after registering a load-time JS hook that removes the
        ``dark`` CSS class from every element.

    Raises:
        ValueError: if *stylable* is neither a theme nor a Blocks app.
    """
    if isinstance(stylable, gr.themes.Base): # remove dark variants from the entire theme
        params = signature(stylable.set).parameters
        # For non-"_dark" names removesuffix is a no-op, so those keep their
        # current value; "_dark" names get their light counterpart's value.
        colors = {color: getattr(stylable, color.removesuffix("_dark")) for color in dir(stylable) if color in params}
        return stylable.set(**colors)
    if isinstance(stylable, gr.Blocks): # also handle components which do not use the theme (e.g. modals)
        stylable.load(js="() => document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'))")
        return stylable
    raise ValueError(
        f"remove_darkness expects a gr.themes.Base or gr.Blocks, got {type(stylable).__name__}"
    )

def build_ui(model=list(models)[0], lock=False, rasterize=False, force_light=False, lock_reason="locked", timeout=120):
    """Build the Gradio Blocks interface for the TikZ assistant.

    Args:
        model: Key into ``models`` selecting the initially chosen model
            (defaults to the first key).
        lock: If true, the model dropdown is non-interactive and its label is
            suffixed with ``lock_reason``.
        rasterize: Forwarded to ``tex_compile``. Also enables the share button
            on the compiled-image output.
        force_light: If true, strip dark-mode styling from the theme and page
            via ``remove_darkness``.
        lock_reason: Text appended to the model label when ``lock`` is set.
        timeout: Forwarded to ``tex_compile`` (units not visible here --
            presumably seconds; confirm against ``TikzDocument``).

    Returns:
        The constructed ``gr.Blocks`` app. It is not launched here.
    """
    theme = remove_darkness(gr.themes.Soft()) if force_light else gr.themes.Soft()
    with gr.Blocks(theme=theme, title="TikZ Assistant") as demo: # type: ignore
        if force_light: remove_darkness(demo)
        gr.Markdown(get_banner())
        with gr.Row(variant="panel"):
            # Left column: inputs and controls.
            with gr.Column():
                # NOTE(review): `info` is never used; it is left over from the
                # commented-out caption textbox below.
                info = (
                    "Describe what you want to generate. "
                    "Scientific graphics benefit from captions with at least 30 tokens (see examples below), "
                    "while simple objects work best with shorter, 2-3 word captions."
                )
                # caption = gr.Textbox(label="Caption", info=info, placeholder="Type a caption...")
                # image = gr.Image(label="Image Input", type="pil")
                # Editor pre-filled with a blank white 336x336 canvas to sketch on.
                image = gr.ImageEditor(label="Image Input", type="pil", sources=['upload', 'clipboard'], value=Image.new('RGB', (336, 336), (255, 255, 255)))
                label = "Model" + (f" ({lock_reason})" if lock else "")
                # `model` is rebound here from the models-dict key to the
                # Dropdown component. Choices are (display name, repo id) pairs.
                model = gr.Dropdown(label=label, choices=list(models.items()), value=models[model], interactive=not lock) # type: ignore
                with gr.Accordion(label="Advanced Options", open=False):
                    temperature = gr.Slider(minimum=0, maximum=2, step=0.05, value=0.8, label="Temperature")
                    top_p = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.95, label="Top-P")
                    top_k = gr.Slider(minimum=0, maximum=100, step=10, value=0, label="Top-K")
                    expand_to_square = gr.Checkbox(value=True, label="Expand image to square")
                with gr.Row():
                    run_btn = gr.Button("Run", variant="primary")
                    stop_btn = gr.Button("Stop")
                    clear_btn = gr.ClearButton([image])
            # Right column: outputs, one tab for code and one for the render.
            with gr.Column():
                with gr.Tabs() as tabs:
                    # The walrus rebinds `label` to the tab title, which is then
                    # reused as the (hidden) label of the component in the tab.
                    with gr.TabItem(label:="TikZ Code", id=0):
                        info = "Source code of the generated image."
                        tikz_code = gr.Code(label=label, show_label=False, interactive=False)
                    with gr.TabItem(label:="Compiled Image", id=1):
                        result_image = gr.Image(label=label, show_label=False, show_share_button=rasterize)
                    clear_btn.add([tikz_code, result_image])
        gr.Examples(examples=[
            ["https://waleko.github.io/data/image.jpg"],
            ["https://waleko.github.io/data/image2.jpg"],
            ["https://waleko.github.io/data/image3.jpg"],
            ["https://waleko.github.io/data/image4.jpg"],
        ], inputs=[image])

        # Events that Stop/Clear cancel.
        events = list()
        # `inference` writes its True/False completion flag into this hidden
        # Textbox. The handlers below compare it with the string "True"
        # (presumably the value arrives stringified).
        finished = gr.Textbox(visible=False) # hack to cancel compile on canceled inference
        for listener in [run_btn.click]:
            # Validate input -> switch to the code tab -> stream generated code.
            generate_event = listener(
                check_inputs,
                inputs=[image],
                queue=False
            ).success(
                lambda: gr.Tabs(selected=0),
                outputs=tabs, # type: ignore
                queue=False
            ).then(
                inference,
                inputs=[model, image, temperature, top_p, top_k, expand_to_square],
                outputs=[tikz_code, result_image, finished]
            )

            # Compile only if inference reached its final yield. Otherwise
            # yield nothing, leaving the cleared result image untouched.
            def tex_compile_if_finished(finished, *args):
                yield from (tex_compile(*args, timeout=timeout, rasterize=rasterize) if finished == "True" else [])

            # After generation: switch to the image tab (if finished), then compile.
            compile_event = generate_event.then(
                lambda finished: gr.Tabs(selected=1) if finished == "True" else gr.Tabs(),
                inputs=finished,
                outputs=tabs, # type: ignore
                queue=False
            ).then(
                tex_compile_if_finished,
                inputs=[finished, tikz_code],
                outputs=result_image
            )
            events.extend([generate_event, compile_event])

        # model.select(lambda model_name: gr.Image(visible="clima" in model_name), inputs=model, outputs=image, queue=False)
        # Both Clear and Stop cancel any in-flight generation/compilation.
        for btn in [clear_btn, stop_btn]:
            btn.click(fn=None, cancels=events, queue=False)
        return demo