Slep commited on
Commit
7fe54e9
1 Parent(s): 014beea

Clean and add model link

Browse files
Files changed (2) hide show
  1. app.py +33 -12
  2. src/custom_js.py +0 -27
app.py CHANGED
@@ -37,9 +37,11 @@ examples = [
37
  ["examples/1811.jpg", "Bags"],
38
  ]
39
 
 
40
  @torch.inference_mode()
41
  def retrieval(image, category):
42
- if image is None or category is None: return
 
43
 
44
  q_emb = m(tfs(image).unsqueeze(0), torch.tensor([category]))
45
 
@@ -48,42 +50,61 @@ def retrieval(image, category):
48
  imgs = [process_img(idx, gal_imgs) for idx in r[1][0]]
49
 
50
  html = [make_img_html(i) for i in imgs]
51
- html += ["<p></p>"] # Avoid Gradio's last-child{margin-bottom:0!important;}
52
 
53
  return "\n".join(html)
54
 
55
 
56
-
57
  JavaScriptLoader("src/custom_functions.js")
58
  with gr.Blocks(css="src/style.css") as demo:
59
  with gr.Column():
60
- gr.Markdown("""
 
61
  # Conditional ViT Demo
62
  [[`Paper`](https://arxiv.org/abs/2306.02928)]
63
  [[`Code`](https://github.com/Simon-Lepage/CondViT-LRVSF)]
64
  [[`Dataset`](https://huggingface.co/datasets/Slep/LAION-RVS-Fashion)]
 
 
65
 
66
  *Running on 2 vCPU, 16Go RAM.*
67
 
68
  - **Model :** Categorical CondViT-B/16
69
  - **Gallery :** 93K images.
70
- """)
 
71
 
72
  # Input section
73
  with gr.Row():
74
  img = gr.Image(label="Query Image", type="pil", elem_id="query_img")
75
  with gr.Column():
76
- cat = gr.Dropdown(choices = categories, label="Category", value="Upper Body", type='index', elem_id="dropdown")
 
 
 
 
 
 
77
  submit = gr.Button("Submit")
78
-
79
  # Examples
80
- gr.Examples(examples, inputs=[img, cat], fn=retrieval, elem_id = "preset_examples", examples_per_page=100)
81
- gr.HTML(value=ExamplesHandler(examples).to_html(), label = "examples", elem_id = "html_examples")
82
-
 
 
 
 
 
 
 
 
 
 
83
  # Outputs
84
  gr.Markdown("# Retrieved Items")
85
- out = gr.HTML(label="Results", elem_id = "html_output")
86
 
87
  submit.click(fn=retrieval, inputs=[img, cat], outputs=out)
88
 
89
- demo.launch()
 
37
  ["examples/1811.jpg", "Bags"],
38
  ]
39
 
40
+
41
  @torch.inference_mode()
42
  def retrieval(image, category):
43
+ if image is None or category is None:
44
+ return
45
 
46
  q_emb = m(tfs(image).unsqueeze(0), torch.tensor([category]))
47
 
 
50
  imgs = [process_img(idx, gal_imgs) for idx in r[1][0]]
51
 
52
  html = [make_img_html(i) for i in imgs]
53
+ html += ["<p></p>"] # Avoid Gradio's last-child{margin-bottom:0!important;}
54
 
55
  return "\n".join(html)
56
 
57
 
 
58
  JavaScriptLoader("src/custom_functions.js")
59
  with gr.Blocks(css="src/style.css") as demo:
60
  with gr.Column():
61
+ gr.Markdown(
62
+ """
63
  # Conditional ViT Demo
64
  [[`Paper`](https://arxiv.org/abs/2306.02928)]
65
  [[`Code`](https://github.com/Simon-Lepage/CondViT-LRVSF)]
66
  [[`Dataset`](https://huggingface.co/datasets/Slep/LAION-RVS-Fashion)]
67
+ [[`Model`](https://huggingface.co/Slep/CondViT-B16-cat)]
68
+
69
 
70
  *Running on 2 vCPU, 16Go RAM.*
71
 
72
  - **Model :** Categorical CondViT-B/16
73
  - **Gallery :** 93K images.
74
+ """
75
+ )
76
 
77
  # Input section
78
  with gr.Row():
79
  img = gr.Image(label="Query Image", type="pil", elem_id="query_img")
80
  with gr.Column():
81
+ cat = gr.Dropdown(
82
+ choices=categories,
83
+ label="Category",
84
+ value="Upper Body",
85
+ type="index",
86
+ elem_id="dropdown",
87
+ )
88
  submit = gr.Button("Submit")
89
+
90
  # Examples
91
+ gr.Examples(
92
+ examples,
93
+ inputs=[img, cat],
94
+ fn=retrieval,
95
+ elem_id="preset_examples",
96
+ examples_per_page=100,
97
+ )
98
+ gr.HTML(
99
+ value=ExamplesHandler(examples).to_html(),
100
+ label="examples",
101
+ elem_id="html_examples",
102
+ )
103
+
104
  # Outputs
105
  gr.Markdown("# Retrieved Items")
106
+ out = gr.HTML(label="Results", elem_id="html_output")
107
 
108
  submit.click(fn=retrieval, inputs=[img, cat], outputs=out)
109
 
110
+ demo.launch()
src/custom_js.py DELETED
@@ -1,27 +0,0 @@
1
- import gradio
2
-
3
- # Adapted from https://github.com/gradio-app/gradio/discussions/2932
4
-
5
- class JavaScriptLoader:
6
- def __init__(self, target):
7
- #Copy the template response
8
- self.original_template = gradio.routes.templates.TemplateResponse
9
- #Prep the js files
10
- self.load_js(target)
11
- #reassign the template response to your method, so gradio calls your method instead
12
- gradio.routes.templates.TemplateResponse = self.template_response
13
-
14
- def load_js(self, target):
15
- with open(target, 'r', encoding="utf-8") as file:
16
- self.loaded_script = f"<script>\n{file.read()}\n</script>"
17
-
18
- def template_response(self, *args, **kwargs):
19
- """Once gradio calls your method, you call the original, you modify it to include
20
- your scripts and you return the modified version
21
- """
22
- response = self.original_template(*args, **kwargs)
23
- response.body = response.body.replace(
24
- '</head>'.encode('utf-8'), self.loaded_script + "\n</head>".encode("utf-8")
25
- )
26
- response.init_headers()
27
- return response