shikunl commited on
Commit
818a4f8
1 Parent(s): 7617596
Files changed (3) hide show
  1. app.py +13 -7
  2. app_caption.py +3 -13
  3. prismer_model.py +8 -32
app.py CHANGED
@@ -11,25 +11,31 @@ import gradio as gr
11
  if os.getenv('SYSTEM') == 'spaces':
12
  with open('patch') as f:
13
  subprocess.run('patch -p1'.split(), cwd='prismer', stdin=f)
14
- shutil.copytree('prismer/helpers/images',
15
- 'prismer/images',
16
- dirs_exist_ok=True)
17
 
18
  from app_caption import create_demo as create_demo_caption
19
  from prismer_model import build_deformable_conv, download_models
20
 
 
21
  download_models()
22
  build_deformable_conv()
23
 
24
- DESCRIPTION = '# [Prismer](https://github.com/nvlabs/prismer)'
 
 
 
 
 
 
25
 
26
  if (SPACE_ID := os.getenv('SPACE_ID')) is not None:
27
- DESCRIPTION += f'<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
 
28
 
29
  with gr.Blocks(css='style.css') as demo:
30
- gr.Markdown(DESCRIPTION)
31
  with gr.Tabs():
32
- with gr.TabItem('Caption'):
33
  create_demo_caption()
34
 
35
  demo.queue(api_open=False).launch()
 
11
  if os.getenv('SYSTEM') == 'spaces':
12
  with open('patch') as f:
13
  subprocess.run('patch -p1'.split(), cwd='prismer', stdin=f)
14
+ shutil.copytree('prismer/helpers/images', 'prismer/images', dirs_exist_ok=True)
 
 
15
 
16
  from app_caption import create_demo as create_demo_caption
17
  from prismer_model import build_deformable_conv, download_models
18
 
19
+ # Prepare model checkpoints
20
  download_models()
21
  build_deformable_conv()
22
 
23
+
24
+ # Demo file here
25
+ description = """
26
+ # Prismer
27
+ The official demo for **Prismer: A Vision-Language Model with An Ensemble of Experts**.
28
+ Please refer to our [project page](https://shikun.io/projects/prismer) or [github](https://github.com/NVlabs/prismer) for more details.
29
+ """
30
 
31
  if (SPACE_ID := os.getenv('SPACE_ID')) is not None:
32
+ description += f'For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a>'
33
+
34
 
35
  with gr.Blocks(css='style.css') as demo:
36
+ gr.Markdown(description)
37
  with gr.Tabs():
38
+ with gr.TabItem('Zero-shot Image Captioning'):
39
  create_demo_caption()
40
 
41
  demo.queue(api_open=False).launch()
app_caption.py CHANGED
@@ -15,10 +15,8 @@ def create_demo():
15
 
16
  with gr.Row():
17
  with gr.Column():
18
- image = gr.Image(label='Input', type='filepath')
19
- model_name = gr.Dropdown(label='Model',
20
- choices=['prismer_base'],
21
- value='prismer_base')
22
  run_button = gr.Button('Run')
23
  with gr.Column(scale=1.5):
24
  caption = gr.Text(label='Caption')
@@ -32,15 +30,7 @@ def create_demo():
32
  ocr = gr.Image(label='OCR Detection')
33
 
34
  inputs = [image, model_name]
35
- outputs = [
36
- caption,
37
- depth,
38
- edge,
39
- normals,
40
- segmentation,
41
- object_detection,
42
- ocr,
43
- ]
44
 
45
  paths = sorted(pathlib.Path('prismer/images').glob('*'))
46
  examples = [[path.as_posix(), 'prismer_base'] for path in paths]
 
15
 
16
  with gr.Row():
17
  with gr.Column():
18
+ image = gr.Image(label='Input Image', type='filepath')
19
+ model_name = gr.Dropdown(label='Model Size', choices=['prismer_base'], value='prismer_base')
 
 
20
  run_button = gr.Button('Run')
21
  with gr.Column(scale=1.5):
22
  caption = gr.Text(label='Caption')
 
30
  ocr = gr.Image(label='OCR Detection')
31
 
32
  inputs = [image, model_name]
33
+ outputs = [caption, depth, edge, normals, segmentation, object_detection, ocr]
 
 
 
 
 
 
 
 
34
 
35
  paths = sorted(pathlib.Path('prismer/images').glob('*'))
36
  examples = [[path.as_posix(), 'prismer_base'] for path in paths]
prismer_model.py CHANGED
@@ -20,32 +20,22 @@ from model.prismer_caption import PrismerCaption
20
 
21
  def download_models() -> None:
22
  if not pathlib.Path('prismer/experts/expert_weights/').exists():
23
- subprocess.run(shlex.split(
24
- 'python download_checkpoints.py --download_experts=True'),
25
- cwd='prismer')
26
  model_names = [
27
- 'vqa_prismer_base',
28
- 'vqa_prismer_large',
29
- 'vqa_prismerz_base',
30
- 'vqa_prismerz_large',
31
- 'caption_prismerz_base',
32
- 'caption_prismerz_large',
33
  'caption_prismer_base',
34
  'caption_prismer_large',
35
  ]
36
  for model_name in model_names:
37
  if pathlib.Path(f'prismer/logging/{model_name}').exists():
38
  continue
39
- subprocess.run(shlex.split(
40
- f'python download_checkpoints.py --download_models={model_name}'),
41
- cwd='prismer')
42
 
43
 
44
  def build_deformable_conv() -> None:
45
- subprocess.run(
46
- shlex.split('sh make.sh'),
47
- cwd=
48
- 'prismer/experts/segmentation/mask2former/modeling/pixel_decoder/ops')
49
 
50
 
51
  def run_experts(image_path: str) -> tuple[str | None, ...]:
@@ -56,14 +46,7 @@ def run_experts(image_path: str) -> tuple[str | None, ...]:
56
  out_path = image_dir / 'image.jpg'
57
  cv2.imwrite(out_path.as_posix(), cv2.imread(image_path))
58
 
59
- expert_names = [
60
- 'depth',
61
- 'edge',
62
- 'normal',
63
- 'objdet',
64
- 'ocrdet',
65
- 'segmentation',
66
- ]
67
  for expert_name in expert_names:
68
  env = os.environ.copy()
69
  if 'PYTHONPATH' in env:
@@ -76,14 +59,7 @@ def run_experts(image_path: str) -> tuple[str | None, ...]:
76
  env=env,
77
  check=True)
78
 
79
- keys = [
80
- 'depth',
81
- 'edge',
82
- 'normal',
83
- 'seg_coco',
84
- 'obj_detection',
85
- 'ocr_detection',
86
- ]
87
  results = [
88
  pathlib.Path('prismer/helpers/labels') / key /
89
  'helpers/images/image.png' for key in keys
 
20
 
21
  def download_models() -> None:
22
  if not pathlib.Path('prismer/experts/expert_weights/').exists():
23
+ subprocess.run(shlex.split('python download_checkpoints.py --download_experts=True'), cwd='prismer')
24
+
 
25
  model_names = [
26
+ # 'vqa_prismer_base',
27
+ # 'vqa_prismer_large',
 
 
 
 
28
  'caption_prismer_base',
29
  'caption_prismer_large',
30
  ]
31
  for model_name in model_names:
32
  if pathlib.Path(f'prismer/logging/{model_name}').exists():
33
  continue
34
+ subprocess.run(shlex.split(f'python download_checkpoints.py --download_models={model_name}'), cwd='prismer')
 
 
35
 
36
 
37
  def build_deformable_conv() -> None:
38
+ subprocess.run(shlex.split('sh make.sh'), cwd='prismer/experts/segmentation/mask2former/modeling/pixel_decoder/ops')
 
 
 
39
 
40
 
41
  def run_experts(image_path: str) -> tuple[str | None, ...]:
 
46
  out_path = image_dir / 'image.jpg'
47
  cv2.imwrite(out_path.as_posix(), cv2.imread(image_path))
48
 
49
+ expert_names = ['depth', 'edge', 'normal', 'objdet', 'ocrdet', 'segmentation']
 
 
 
 
 
 
 
50
  for expert_name in expert_names:
51
  env = os.environ.copy()
52
  if 'PYTHONPATH' in env:
 
59
  env=env,
60
  check=True)
61
 
62
+ keys = ['depth', 'edge', 'normal', 'seg_coco', 'obj_detection', 'ocr_detection']
 
 
 
 
 
 
 
63
  results = [
64
  pathlib.Path('prismer/helpers/labels') / key /
65
  'helpers/images/image.png' for key in keys