waleko committed on
Commit
cd9a590
1 Parent(s): ff6815d

fix selector

Browse files
Files changed (1) hide show
  1. webui.py +13 -10
webui.py CHANGED
@@ -25,9 +25,9 @@ from infer import TikzDocument, TikzGenerator
25
 
26
  # assets = files(__package__) / "assets" if __package__ else files("assets") / "."
27
  models = {
28
- "pix2tikz": {"model": "pix2tikz/mixed_e362_step201.pth"},
29
- "llava-1.5-7b-hf": {"model": "waleko/TikZ-llava-1.5-7b"},
30
- "new llava-1.5-7b-hf": {"model": "waleko/TikZ-llava-1.5-7b", "revision": "v2"},
31
  }
32
 
33
 
@@ -36,13 +36,17 @@ def is_quantization(model_name):
36
 
37
 
38
  @lru_cache(maxsize=1)
39
- def cached_load(model_dict, **kwargs) -> ImageToTextPipeline:
40
- model_name = model_dict["model"]
 
 
 
 
41
  gr.Info("Instantiating model. Could take a while...") # type: ignore
42
  if not is_quantization(model_name):
43
- return pipeline("image-to-text", **model_dict, **kwargs)
44
  else:
45
- model = AutoModelForPreTraining.from_pretrained(model_name, load_in_4bit=True, revision=model_dict.get("revision", "main"), **kwargs)
46
  processor = AutoProcessor.from_pretrained(model_name)
47
  return pipeline(task="image-to-text", model=model, tokenizer=processor.tokenizer, image_processor=processor.image_processor)
48
 
@@ -78,7 +82,7 @@ def pix2tikz(
78
 
79
 
80
  def inference(
81
- model_dict: dict,
82
  image_dict: dict,
83
  temperature: float,
84
  top_p: float,
@@ -87,13 +91,12 @@ def inference(
87
  ):
88
  try:
89
  image = image_dict['composite']
90
- model_name = model_dict["model"]
91
  if "pix2tikz" in model_name:
92
  yield pix2tikz(model_name, image, temperature, top_p, top_k, expand_to_square)
93
  return
94
 
95
  generate = TikzGenerator(
96
- cached_load(model_dict, device_map="auto"),
97
  temperature=temperature,
98
  top_p=top_p,
99
  top_k=top_k,
 
25
 
26
  # assets = files(__package__) / "assets" if __package__ else files("assets") / "."
27
  models = {
28
+ "pix2tikz": "pix2tikz/mixed_e362_step201.pth",
29
+ "llava-1.5-7b-hf": "waleko/TikZ-llava-1.5-7b",
30
+ "new llava-1.5-7b-hf": "waleko/TikZ-llava-1.5-7b v2"
31
  }
32
 
33
 
 
36
 
37
 
38
  @lru_cache(maxsize=1)
39
+ def cached_load(model_name, **kwargs) -> ImageToTextPipeline:
40
+ # split
41
+ model_dict = model_name.split(" ")
42
+ revision = "main"
43
+ if len(model_dict) > 1:
44
+ model_name, revision = model_dict
45
  gr.Info("Instantiating model. Could take a while...") # type: ignore
46
  if not is_quantization(model_name):
47
+ return pipeline("image-to-text", model=model_name, revision=revision, **kwargs)
48
  else:
49
+ model = AutoModelForPreTraining.from_pretrained(model_name, load_in_4bit=True, revision=revision, **kwargs)
50
  processor = AutoProcessor.from_pretrained(model_name)
51
  return pipeline(task="image-to-text", model=model, tokenizer=processor.tokenizer, image_processor=processor.image_processor)
52
 
 
82
 
83
 
84
  def inference(
85
+ model_name: str,
86
  image_dict: dict,
87
  temperature: float,
88
  top_p: float,
 
91
  ):
92
  try:
93
  image = image_dict['composite']
 
94
  if "pix2tikz" in model_name:
95
  yield pix2tikz(model_name, image, temperature, top_p, top_k, expand_to_square)
96
  return
97
 
98
  generate = TikzGenerator(
99
+ cached_load(model_name, device_map="auto"),
100
  temperature=temperature,
101
  top_p=top_p,
102
  top_k=top_k,