waleko committed on
Commit
53f2284
1 Parent(s): b54f700
Files changed (2) hide show
  1. requirements.txt +2 -0
  2. webui.py +11 -3
requirements.txt CHANGED
@@ -5,3 +5,5 @@ PyMuPDF~=1.22
5
  peft>=0.2.0
6
  transformers
7
  gradio
 
 
 
5
  peft>=0.2.0
6
  transformers
7
  gradio
8
+ accelerate
9
+ bitsandbytes
webui.py CHANGED
@@ -13,7 +13,7 @@ from typing import Optional
13
  from PIL import Image
14
  import fitz
15
  import gradio as gr
16
- from transformers import TextIteratorStreamer, pipeline, ImageToTextPipeline
17
 
18
  from infer import TikzDocument, TikzGenerator
19
 
@@ -23,11 +23,19 @@ models = {
23
  }
24
 
25
 
 
 
 
 
26
  @lru_cache(maxsize=1)
27
  def cached_load(model_name, **kwargs) -> ImageToTextPipeline:
28
  gr.Info("Instantiating model. Could take a while...") # type: ignore
29
- # noinspection PyTypeChecker
30
- return pipeline("image-to-text", model=model_name, **kwargs)
 
 
 
 
31
 
32
 
33
  def convert_to_svg(pdf):
 
13
  from PIL import Image
14
  import fitz
15
  import gradio as gr
16
+ from transformers import TextIteratorStreamer, pipeline, ImageToTextPipeline, AutoModelForPreTraining, AutoProcessor
17
 
18
  from infer import TikzDocument, TikzGenerator
19
 
 
23
  }
24
 
25
 
26
+ def is_8bit(model_name):
27
+ return "waleko/TikZ-llava" in model_name
28
+
29
+
30
  @lru_cache(maxsize=1)
31
  def cached_load(model_name, **kwargs) -> ImageToTextPipeline:
32
  gr.Info("Instantiating model. Could take a while...") # type: ignore
33
+ if not is_8bit(model_name):
34
+ return pipeline("image-to-text", model=model_name, **kwargs)
35
+ else:
36
+ model = AutoModelForPreTraining.from_pretrained(model_name, load_in_8bit=True, **kwargs)
37
+ processor = AutoProcessor.from_pretrained(model_name)
38
+ return pipeline(task="image-to-text", model=model, tokenizer=processor.tokenizer, image_processor=processor.image_processor)
39
 
40
 
41
  def convert_to_svg(pdf):