waleko committed on
Commit
755b6ea
1 Parent(s): 6786a25

change branding

Browse files
pix2tikz/config.yaml ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ backbone_layers:
2
+ - 2
3
+ - 3
4
+ - 7
5
+ batchsize: 12
6
+ betas:
7
+ - 0.9
8
+ - 0.999
9
+ bos_token: 1
10
+ channels: 1
11
+ config: colab.yaml
12
+ data: dataset/data/simple_train.pkl
13
+ debug: false
14
+ decoder_args:
15
+ attn_on_attn: true
16
+ cross_attend: true
17
+ ff_glu: true
18
+ rel_pos_bias: false
19
+ use_scalenorm: false
20
+ device: cuda:0
21
+ dim: 256
22
+ encoder_depth: 4
23
+ encoder_structure: hybrid
24
+ eos_token: 2
25
+ epoch: 429
26
+ epochs: 500
27
+ gamma: 0.9995
28
+ gpu_devices:
29
+ - 0
30
+ heads: 8
31
+ id: v9h46w6a
32
+ load_chkpt: /home/coder/project/LaTeX-OCR/weights.pth
33
+ lr: 0.001
34
+ lr_step: 30
35
+ max_dimensions:
36
+ - 336
37
+ - 336
38
+ max_height: 336
39
+ max_seq_len: 2048
40
+ max_width: 336
41
+ min_dimensions:
42
+ - 32
43
+ - 32
44
+ min_height: 32
45
+ min_width: 32
46
+ model_path: simple_checkpoints
47
+ name: mixed
48
+ no_cuda: false
49
+ num_layers: 4
50
+ num_tokens: 8000
51
+ optimizer: Adam
52
+ output_path: simple_outputs
53
+ pad: false
54
+ pad_token: 0
55
+ patch_size: 16
56
+ resume: false
57
+ sample_freq: 201
58
+ save_freq: 1
59
+ scheduler: StepLR
60
+ seed: 42
61
+ temperature: 0.2
62
+ test_samples: 5
63
+ testbatchsize: 8
64
+ tokenizer: dataset/tokenizer.json
65
+ valbatches: 1
66
+ valdata: dataset/data/simple_val.pkl
67
+ wandb: true
pix2tikz/mixed_e362_step201.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2eaae4b58f528da4eb5090f6addc717f1b135509f7e763100b90468ad9bfe1a8
3
+ size 103619970
requirements.txt CHANGED
@@ -7,3 +7,5 @@ transformers
7
  gradio
8
  accelerate
9
  bitsandbytes
 
 
 
7
  gradio
8
  accelerate
9
  bitsandbytes
10
+ altair<5
11
+ pix2tex[api]
webui.py CHANGED
@@ -1,5 +1,6 @@
1
 
2
  #!/usr/bin/env python
 
3
 
4
  from argparse import ArgumentParser
5
  from functools import lru_cache
@@ -15,10 +16,16 @@ import fitz
15
  import gradio as gr
16
  from transformers import TextIteratorStreamer, pipeline, ImageToTextPipeline, AutoModelForPreTraining, AutoProcessor
17
 
 
 
 
 
 
18
  from infer import TikzDocument, TikzGenerator
19
 
20
  # assets = files(__package__) / "assets" if __package__ else files("assets") / "."
21
  models = {
 
22
  "llava-1.5-7b-hf": "waleko/TikZ-llava-1.5-7b"
23
  }
24
 
@@ -43,6 +50,24 @@ def convert_to_svg(pdf):
43
  return doc[0].get_svg_image()
44
 
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  def inference(
47
  model_name: str,
48
  image_dict: dict,
@@ -52,6 +77,11 @@ def inference(
52
  expand_to_square: bool,
53
  ):
54
  try:
 
 
 
 
 
55
  generate = TikzGenerator(
56
  cached_load(model_name, device_map="auto"),
57
  temperature=temperature,
@@ -66,7 +96,7 @@ def inference(
66
  )
67
 
68
  thread = ThreadPool(processes=1)
69
- async_result = thread.apply_async(generate, kwds=dict(image=image_dict['composite'], streamer=streamer))
70
  generated_text = ""
71
  for new_text in streamer:
72
  generated_text += new_text
@@ -171,7 +201,13 @@ def build_ui(model=list(models)[0], lock=False, rasterize=False, force_light=Fal
171
  with gr.TabItem(label:="Compiled Image", id=1):
172
  result_image = gr.Image(label=label, show_label=False, show_share_button=rasterize)
173
  clear_btn.add([tikz_code, result_image])
174
- gr.Examples(examples=[["https://waleko.github.io/data/image.jpg"]], inputs=[image])
 
 
 
 
 
 
175
 
176
  events = list()
177
  finished = gr.Textbox(visible=False) # hack to cancel compile on canceled inference
 
1
 
2
  #!/usr/bin/env python
3
+ import re
4
 
5
  from argparse import ArgumentParser
6
  from functools import lru_cache
 
16
  import gradio as gr
17
  from transformers import TextIteratorStreamer, pipeline, ImageToTextPipeline, AutoModelForPreTraining, AutoProcessor
18
 
19
+ import os
20
+ from pix2tex.cli import LatexOCR
21
+ from munch import Munch
22
+
23
+
24
  from infer import TikzDocument, TikzGenerator
25
 
26
  # assets = files(__package__) / "assets" if __package__ else files("assets") / "."
27
  models = {
28
+ "pix2tikz": "pix2tikz/mixed_e362_step201.pth",
29
  "llava-1.5-7b-hf": "waleko/TikZ-llava-1.5-7b"
30
  }
31
 
 
50
  return doc[0].get_svg_image()
51
 
52
 
53
+ def pix2tikz(
54
+ checkpoint: str,
55
+ image: Image.Image,
56
+ temperature: float,
57
+ _: float,
58
+ __: int,
59
+ ___: bool,
60
+ ):
61
+ args = Munch({'config': os.path.realpath(os.path.join(os.path.dirname(__file__), 'pix2tikz/config.yaml')),
62
+ 'checkpoint': os.path.realpath(os.path.join(os.path.dirname(__file__), checkpoint)),
63
+ 'no_resize': False,
64
+ 'temperature': temperature})
65
+ model = LatexOCR(args)
66
+ res = model(image)
67
+ text = re.sub(r'\\n(?=\W)', '\n', res)
68
+ return text, None, True
69
+
70
+
71
  def inference(
72
  model_name: str,
73
  image_dict: dict,
 
77
  expand_to_square: bool,
78
  ):
79
  try:
80
+ image = image_dict['composite']
81
+ if model_name == "pix2tikz":
82
+ yield pix2tikz(model_name, image, temperature, top_p, top_k, expand_to_square)
83
+ return
84
+
85
  generate = TikzGenerator(
86
  cached_load(model_name, device_map="auto"),
87
  temperature=temperature,
 
96
  )
97
 
98
  thread = ThreadPool(processes=1)
99
+ async_result = thread.apply_async(generate, kwds=dict(image=image, streamer=streamer))
100
  generated_text = ""
101
  for new_text in streamer:
102
  generated_text += new_text
 
201
  with gr.TabItem(label:="Compiled Image", id=1):
202
  result_image = gr.Image(label=label, show_label=False, show_share_button=rasterize)
203
  clear_btn.add([tikz_code, result_image])
204
+ gr.Examples(examples=[
205
+ ["https://waleko.github.io/data/image.jpg",
206
+ "https://waleko.github.io/data/image2.jpg",
207
+ "https://waleko.github.io/data/image3.jpg"
208
+ "https://waleko.github.io/data/image4.jpg",
209
+ ]
210
+ ], inputs=[image])
211
 
212
  events = list()
213
  finished = gr.Textbox(visible=False) # hack to cancel compile on canceled inference