waleko committed on
Commit
77f4da6
1 Parent(s): 7ab5a30

fix generation

Browse files
Files changed (2) hide show
  1. infer.py +2 -15
  2. webui.py +41 -39
infer.py CHANGED
@@ -196,34 +196,23 @@ class TikzGenerator:
196
  self.expand_to_square = expand_to_square
197
  self.clean_up_output = clean_up_output
198
  self.pipeline = pipe
199
- # self.pipeline.model = torch.compile(self.pipeline.model)
200
 
201
  self.default_kwargs = dict(
202
  temperature=temperature,
203
  top_p=top_p,
204
  top_k=top_k,
205
- num_return_sequences=1,
206
- # max_length=self.pipeline.tokenizer.model_max_length, # type: ignore
207
  do_sample=True,
208
- return_full_text=False,
209
- streamer=TextStreamer(self.pipeline.tokenizer, # type: ignore
210
- skip_prompt=True,
211
- skip_special_tokens=True
212
- ),
213
  max_new_tokens=1024,
214
  )
215
 
216
- if not stream:
217
- self.default_kwargs.pop("streamer")
218
 
219
  def generate(self, image: Image.Image, **generate_kwargs):
220
  prompt = "Assistant helps to write down the TikZ code for the user's image. USER: <image>\nWrite down the TikZ code to draw the diagram shown in the lol. ASSISTANT:"
221
  tokenizer = self.pipeline.tokenizer
222
- print('starting generation')
223
  text = self.pipeline(image, prompt=prompt, generate_kwargs=(self.default_kwargs | generate_kwargs))[0]["generated_text"] # type: ignore
224
 
225
- print('text generated: ', text) # TODO: remove
226
-
227
  if self.clean_up_output:
228
  for token in reversed(tokenizer.tokenize(prompt)): # type: ignore
229
  # remove leading characters because skip_special_tokens in pipeline
@@ -240,8 +229,6 @@ class TikzGenerator:
240
  for artifact, replacement in artifacts.items():
241
  text = sub(artifact, replacement, text) # type: ignore
242
 
243
- print('cleaned text: ', text)
244
-
245
  return TikzDocument(text)
246
 
247
 
 
196
  self.expand_to_square = expand_to_square
197
  self.clean_up_output = clean_up_output
198
  self.pipeline = pipe
 
199
 
200
  self.default_kwargs = dict(
201
  temperature=temperature,
202
  top_p=top_p,
203
  top_k=top_k,
 
 
204
  do_sample=True,
 
 
 
 
 
205
  max_new_tokens=1024,
206
  )
207
 
208
+ # if not stream:
209
+ # self.default_kwargs.pop("streamer")
210
 
211
  def generate(self, image: Image.Image, **generate_kwargs):
212
  prompt = "Assistant helps to write down the TikZ code for the user's image. USER: <image>\nWrite down the TikZ code to draw the diagram shown in the lol. ASSISTANT:"
213
  tokenizer = self.pipeline.tokenizer
 
214
  text = self.pipeline(image, prompt=prompt, generate_kwargs=(self.default_kwargs | generate_kwargs))[0]["generated_text"] # type: ignore
215
 
 
 
216
  if self.clean_up_output:
217
  for token in reversed(tokenizer.tokenize(prompt)): # type: ignore
218
  # remove leading characters because skip_special_tokens in pipeline
 
229
  for artifact, replacement in artifacts.items():
230
  text = sub(artifact, replacement, text) # type: ignore
231
 
 
 
232
  return TikzDocument(text)
233
 
234
 
webui.py CHANGED
@@ -15,22 +15,22 @@ import fitz
15
  import gradio as gr
16
  from transformers import TextIteratorStreamer, pipeline, ImageToTextPipeline, AutoModelForPreTraining, AutoProcessor
17
 
18
- from infer import TikzDocument, TikzGenerator
19
 
20
  # assets = files(__package__) / "assets" if __package__ else files("assets") / "."
21
  models = {
22
- "Fine-tuned Llava": "waleko/TikZ-llava-1.5-7b"
23
  }
24
 
25
 
26
- def is_8bit(model_name):
27
  return "waleko/TikZ-llava" in model_name
28
 
29
 
30
  @lru_cache(maxsize=1)
31
  def cached_load(model_name, **kwargs) -> ImageToTextPipeline:
32
  gr.Info("Instantiating model. Could take a while...") # type: ignore
33
- if not is_8bit(model_name):
34
  return pipeline("image-to-text", model=model_name, **kwargs)
35
  else:
36
  model = AutoModelForPreTraining.from_pretrained(model_name, load_in_8bit=True, **kwargs)
@@ -45,33 +45,35 @@ def convert_to_svg(pdf):
45
 
46
  def inference(
47
  model_name: str,
48
- image: Image.Image,
49
  temperature: float,
50
  top_p: float,
51
  top_k: int,
52
  expand_to_square: bool,
53
  ):
54
- generate = TikzGenerator(
55
- cached_load(model_name, device_map="auto"),
56
- temperature=temperature,
57
- top_p=top_p,
58
- top_k=top_k,
59
- expand_to_square=expand_to_square,
60
- )
61
- streamer = TextIteratorStreamer(
62
- generate.pipeline.tokenizer, # type: ignore
63
- skip_prompt=True,
64
- skip_special_tokens=True
65
- )
66
-
67
- thread = ThreadPool(processes=1)
68
- async_result = thread.apply_async(generate, kwds=dict(image=image, streamer=streamer))
69
-
70
- generated_text = ""
71
- for new_text in streamer:
72
- generated_text += new_text
73
- yield generated_text, None, False
74
- yield async_result.get().code, None, True
 
 
75
 
76
  def tex_compile(
77
  code: str,
@@ -85,7 +87,8 @@ def tex_compile(
85
  else:
86
  gr.Warning("TikZ code compiled to an empty image!") # type: ignore
87
  elif tikzdoc.compiled_with_errors:
88
- gr.Warning("TikZ code compiled with errors!") # type: ignore
 
89
 
90
  if rasterize:
91
  yield tikzdoc.rasterize()
@@ -123,16 +126,15 @@ def remove_darkness(stylable):
123
  """
124
  Patch gradio to only contain light mode colors.
125
  """
126
- pass # TODO: remove dark mode colors from the theme
127
- # if isinstance(stylable, gr.themes.Base): # remove dark variants from the entire theme
128
- # params = signature(stylable.set).parameters
129
- # colors = {color: getattr(stylable, color.removesuffix("_dark")) for color in dir(stylable) if color in params}
130
- # return stylable.set(**colors)
131
- # elif isinstance(stylable, gr.Blocks): # also handle components which do not use the theme (e.g. modals)
132
- # stylable.load(_js="() => document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'))")
133
- # return stylable
134
- # else:
135
- # raise ValueError
136
 
137
  def build_ui(model=list(models)[0], lock=False, rasterize=False, force_light=False, lock_reason="locked", timeout=120):
138
  theme = remove_darkness(gr.themes.Soft()) if force_light else gr.themes.Soft()
@@ -148,7 +150,7 @@ def build_ui(model=list(models)[0], lock=False, rasterize=False, force_light=Fal
148
  )
149
  # caption = gr.Textbox(label="Caption", info=info, placeholder="Type a caption...")
150
  # image = gr.Image(label="Image Input", type="pil")
151
- image = gr.ImageEditor(label="Image Input", type="pil")
152
  label = "Model" + (f" ({lock_reason})" if lock else "")
153
  model = gr.Dropdown(label=label, choices=list(models.items()), value=models[model], interactive=not lock) # type: ignore
154
  with gr.Accordion(label="Advanced Options", open=False):
@@ -168,7 +170,7 @@ def build_ui(model=list(models)[0], lock=False, rasterize=False, force_light=Fal
168
  with gr.TabItem(label:="Compiled Image", id=1):
169
  result_image = gr.Image(label=label, show_label=False, show_share_button=rasterize)
170
  clear_btn.add([tikz_code, result_image])
171
- # TODO: gr.Examples(examples=str(assets), inputs=[image, tikz_code, result_image])
172
 
173
  events = list()
174
  finished = gr.Textbox(visible=False) # hack to cancel compile on canceled inference
 
15
  import gradio as gr
16
  from transformers import TextIteratorStreamer, pipeline, ImageToTextPipeline, AutoModelForPreTraining, AutoProcessor
17
 
18
+ from .infer import TikzDocument, TikzGenerator
19
 
20
  # assets = files(__package__) / "assets" if __package__ else files("assets") / "."
21
  models = {
22
+ "llava-1.5-7b-hf": "waleko/TikZ-llava-1.5-7b"
23
  }
24
 
25
 
26
+ def is_quantization(model_name):
27
  return "waleko/TikZ-llava" in model_name
28
 
29
 
30
  @lru_cache(maxsize=1)
31
  def cached_load(model_name, **kwargs) -> ImageToTextPipeline:
32
  gr.Info("Instantiating model. Could take a while...") # type: ignore
33
+ if not is_quantization(model_name):
34
  return pipeline("image-to-text", model=model_name, **kwargs)
35
  else:
36
  model = AutoModelForPreTraining.from_pretrained(model_name, load_in_8bit=True, **kwargs)
 
45
 
46
  def inference(
47
  model_name: str,
48
+ image_dict: dict,
49
  temperature: float,
50
  top_p: float,
51
  top_k: int,
52
  expand_to_square: bool,
53
  ):
54
+ try:
55
+ generate = TikzGenerator(
56
+ cached_load(model_name, device_map="auto"),
57
+ temperature=temperature,
58
+ top_p=top_p,
59
+ top_k=top_k,
60
+ expand_to_square=expand_to_square,
61
+ )
62
+ streamer = TextIteratorStreamer(
63
+ generate.pipeline.tokenizer, # type: ignore
64
+ skip_prompt=True,
65
+ skip_special_tokens=True
66
+ )
67
+
68
+ thread = ThreadPool(processes=1)
69
+ async_result = thread.apply_async(generate, kwds=dict(image=image_dict['composite'], streamer=streamer))
70
+ generated_text = ""
71
+ for new_text in streamer:
72
+ generated_text += new_text
73
+ yield generated_text, None, False
74
+ yield async_result.get().code, None, True
75
+ except Exception as e:
76
+ raise gr.Error(f"Internal Error! {e}")
77
 
78
  def tex_compile(
79
  code: str,
 
87
  else:
88
  gr.Warning("TikZ code compiled to an empty image!") # type: ignore
89
  elif tikzdoc.compiled_with_errors:
90
+ # gr.Warning("TikZ code compiled with errors!") # type: ignore
91
+ print("TikZ code compiled with errors!")
92
 
93
  if rasterize:
94
  yield tikzdoc.rasterize()
 
126
  """
127
  Patch gradio to only contain light mode colors.
128
  """
129
+ if isinstance(stylable, gr.themes.Base): # remove dark variants from the entire theme
130
+ params = signature(stylable.set).parameters
131
+ colors = {color: getattr(stylable, color.removesuffix("_dark")) for color in dir(stylable) if color in params}
132
+ return stylable.set(**colors)
133
+ elif isinstance(stylable, gr.Blocks): # also handle components which do not use the theme (e.g. modals)
134
+ stylable.load(js="() => document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'))")
135
+ return stylable
136
+ else:
137
+ raise ValueError
 
138
 
139
  def build_ui(model=list(models)[0], lock=False, rasterize=False, force_light=False, lock_reason="locked", timeout=120):
140
  theme = remove_darkness(gr.themes.Soft()) if force_light else gr.themes.Soft()
 
150
  )
151
  # caption = gr.Textbox(label="Caption", info=info, placeholder="Type a caption...")
152
  # image = gr.Image(label="Image Input", type="pil")
153
+ image = gr.ImageEditor(label="Image Input", type="pil", sources=['upload', 'clipboard'], value=Image.new('RGB', (336, 336), (255, 255, 255)))
154
  label = "Model" + (f" ({lock_reason})" if lock else "")
155
  model = gr.Dropdown(label=label, choices=list(models.items()), value=models[model], interactive=not lock) # type: ignore
156
  with gr.Accordion(label="Advanced Options", open=False):
 
170
  with gr.TabItem(label:="Compiled Image", id=1):
171
  result_image = gr.Image(label=label, show_label=False, show_share_button=rasterize)
172
  clear_btn.add([tikz_code, result_image])
173
+ gr.Examples(examples=[["https://waleko.github.io/data/image.jpg"]], inputs=[image])
174
 
175
  events = list()
176
  finished = gr.Textbox(visible=False) # hack to cancel compile on canceled inference