Spaces:
Running
Running
update localization
Browse files- localization.py +8 -8
localization.py
CHANGED
@@ -127,22 +127,22 @@ def get_heatmap(image_url, text, pixel_size=10, iterations=3):
|
|
127 |
# if full_sim is None:
|
128 |
# full_sim = sim
|
129 |
# sim = jax.nn.relu(sim - full_sim)
|
130 |
-
emb = jnp.expand_dims(e, axis=0)
|
131 |
|
132 |
if v:
|
133 |
vm = jnp.any(m, axis=0)
|
134 |
-
vertical_scores = vertical_scores + (emb * vm)
|
135 |
-
vertical_masks = vertical_masks + vm
|
136 |
if h:
|
137 |
hm = jnp.any(m, axis=1)
|
138 |
-
horizontal_scores = horizontal_scores + (emb * hm)
|
139 |
-
horizontal_masks = horizontal_masks + hm
|
140 |
|
141 |
|
142 |
embs_1 = jnp.expand_dims((vertical_scores), axis=0) * jnp.expand_dims(jnp.abs(horizontal_scores), axis=1)
|
143 |
embs_2 = jnp.expand_dims(jnp.abs(vertical_scores), axis=0) * jnp.expand_dims((horizontal_scores), axis=1)
|
144 |
full_embs = jnp.minimum(embs_1, embs_2)
|
145 |
-
mask_sum = jnp.expand_dims(vertical_masks, axis=0) * jnp.expand_dims(horizontal_masks, axis=1)
|
146 |
full_embs = (full_embs / mask_sum)
|
147 |
|
148 |
orig_shape = full_embs.shape
|
@@ -196,9 +196,9 @@ def app():
|
|
196 |
col1, col2 = st.columns([0.75, 0.25])
|
197 |
|
198 |
with col2:
|
199 |
-
pixel_size = st.selectbox("Pixel Size", options=range(
|
200 |
|
201 |
-
iterations = st.selectbox("Refinement Steps", options=range(
|
202 |
|
203 |
compute = st.button("LOCATE")
|
204 |
|
|
|
127 |
# if full_sim is None:
|
128 |
# full_sim = sim
|
129 |
# sim = jax.nn.relu(sim - full_sim)
|
130 |
+
emb = jnp.expand_dims(e, axis=0) #* n
|
131 |
|
132 |
if v:
|
133 |
vm = jnp.any(m, axis=0)
|
134 |
+
vertical_scores = vertical_scores + (emb * vm) / jnp.mean(vm)
|
135 |
+
vertical_masks = vertical_masks + vm / jnp.mean(vm)
|
136 |
if h:
|
137 |
hm = jnp.any(m, axis=1)
|
138 |
+
horizontal_scores = horizontal_scores + (emb * hm) / jnp.mean(hm)
|
139 |
+
horizontal_masks = horizontal_masks + hm / jnp.mean(hm)
|
140 |
|
141 |
|
142 |
embs_1 = jnp.expand_dims((vertical_scores), axis=0) * jnp.expand_dims(jnp.abs(horizontal_scores), axis=1)
|
143 |
embs_2 = jnp.expand_dims(jnp.abs(vertical_scores), axis=0) * jnp.expand_dims((horizontal_scores), axis=1)
|
144 |
full_embs = jnp.minimum(embs_1, embs_2)
|
145 |
+
mask_sum = jnp.expand_dims(vertical_masks + 1, axis=0) * jnp.expand_dims(horizontal_masks + 1, axis=1)
|
146 |
full_embs = (full_embs / mask_sum)
|
147 |
|
148 |
orig_shape = full_embs.shape
|
|
|
196 |
col1, col2 = st.columns([0.75, 0.25])
|
197 |
|
198 |
with col2:
|
199 |
+
pixel_size = st.selectbox("Pixel Size", options=range(5, 26, 5), index=3)
|
200 |
|
201 |
+
iterations = st.selectbox("Refinement Steps", options=range(0, 6, 1), index=0)
|
202 |
|
203 |
compute = st.button("LOCATE")
|
204 |
|