jmercat commited on
Commit
9acc98b
β€’
1 Parent(s): 0bc5996

correction, sdk version, plot size

Browse files
README.md CHANGED
@@ -4,6 +4,7 @@ emoji: πŸš™
4
  colorFrom: red
5
  colorTo: gray
6
  sdk: gradio
 
7
  app_file: app.py
8
  pinned: false
9
  language:
 
4
  colorFrom: red
5
  colorTo: gray
6
  sdk: gradio
7
+ sdk_version: 3.8
8
  app_file: app.py
9
  pinned: false
10
  language:
scripts/scripts_utils/plotly_interface.py CHANGED
@@ -107,6 +107,8 @@ def build_data(
107
  mask_pred = numpy_data["mask_pred"][0]
108
  map_data = numpy_data["map_data"][0]
109
  mask_map = numpy_data["mask_map"][0]
 
 
110
 
111
  data_x = get_scatter_data(
112
  x,
@@ -119,14 +121,14 @@ def build_data(
119
  x=x[0:1, -1:],
120
  mask_x=mask_x[0:1, -1:],
121
  mode="markers",
122
- marker=dict(color="blue", size=20, opacity=0.5),
123
  name="Ego",
124
  )
125
  agent_present = get_scatter_data(
126
  x=x[1:2, -1:],
127
  mask_x=mask_x[1:2, -1:],
128
  mode="markers",
129
- marker=dict(color="green", size=20, opacity=0.5),
130
  name="Agent",
131
  )
132
 
@@ -161,7 +163,7 @@ def build_data(
161
  pred[:, i, -1:],
162
  mask_pred[:, -1:],
163
  mode="markers",
164
- marker=dict(color="red", size=10, opacity=0.5, symbol="x"),
165
  name="Forecast end",
166
  )
167
  forecasts_end += forecast_end
@@ -177,7 +179,7 @@ def build_data(
177
  y=x[mask_x[:, k], k, 1],
178
  mode="markers",
179
  opacity=animation_opacity,
180
- marker=dict(color="black", size=15),
181
  showlegend=False,
182
  ),
183
  go.Scatter(
@@ -185,7 +187,7 @@ def build_data(
185
  y=x[0:1, k, 1],
186
  mode="markers",
187
  opacity=animation_opacity,
188
- marker=dict(color="blue", size=15),
189
  showlegend=False,
190
  ),
191
  ]
@@ -200,14 +202,14 @@ def build_data(
200
  y=y[1:2][mask_y[1:2, k], k, 1],
201
  mode="markers",
202
  opacity=animation_opacity,
203
- marker=dict(color="green", size=15),
204
  )
205
  cur_gt_future_data = go.Scatter(
206
  x=y[2:][mask_y[2:, k], k, 0],
207
  y=y[2:][mask_y[2:, k], k, 1],
208
  mode="markers",
209
  opacity=animation_opacity,
210
- marker=dict(color="black", size=15),
211
  )
212
  cur_pred_data = []
213
  for i in range(n_samples):
@@ -217,7 +219,7 @@ def build_data(
217
  y=pred[mask_pred[:, k], i, k, 1],
218
  mode="markers",
219
  opacity=animation_opacity,
220
- marker=dict(color="red", size=15),
221
  showlegend=False,
222
  )
223
  )
@@ -226,7 +228,7 @@ def build_data(
226
  y=y[0:1, k, 1],
227
  mode="markers",
228
  opacity=animation_opacity,
229
- marker=dict(color="blue", size=15),
230
  )
231
  cur_data = [cur_gt_agent_data, cur_gt_future_data, *cur_pred_data, cur_ego_data]
232
  frame = go.Frame(data=cur_data)
@@ -262,8 +264,8 @@ def prediction_plot(
262
  ),
263
  title_text="Road Scene",
264
  hovermode="closest",
265
- width=600,
266
- height=300,
267
  updatemenus=[
268
  dict(
269
  type="buttons",
@@ -274,7 +276,6 @@ def prediction_plot(
274
  args=[
275
  None,
276
  dict(
277
- transition=dict(duration=100),
278
  frame=dict(duration=100, redraw=False),
279
  mode="immediate",
280
  fromcurrent=True,
 
107
  mask_pred = numpy_data["mask_pred"][0]
108
  map_data = numpy_data["map_data"][0]
109
  mask_map = numpy_data["mask_map"][0]
110
+
111
+ marker_size = 12
112
 
113
  data_x = get_scatter_data(
114
  x,
 
121
  x=x[0:1, -1:],
122
  mask_x=mask_x[0:1, -1:],
123
  mode="markers",
124
+ marker=dict(color="blue", size=marker_size, opacity=0.5),
125
  name="Ego",
126
  )
127
  agent_present = get_scatter_data(
128
  x=x[1:2, -1:],
129
  mask_x=mask_x[1:2, -1:],
130
  mode="markers",
131
+ marker=dict(color="green", size=marker_size, opacity=0.5),
132
  name="Agent",
133
  )
134
 
 
163
  pred[:, i, -1:],
164
  mask_pred[:, -1:],
165
  mode="markers",
166
+ marker=dict(color="red", size=marker_size/2, opacity=0.5, symbol="x"),
167
  name="Forecast end",
168
  )
169
  forecasts_end += forecast_end
 
179
  y=x[mask_x[:, k], k, 1],
180
  mode="markers",
181
  opacity=animation_opacity,
182
+ marker=dict(color="black", size=marker_size),
183
  showlegend=False,
184
  ),
185
  go.Scatter(
 
187
  y=x[0:1, k, 1],
188
  mode="markers",
189
  opacity=animation_opacity,
190
+ marker=dict(color="blue", size=marker_size),
191
  showlegend=False,
192
  ),
193
  ]
 
202
  y=y[1:2][mask_y[1:2, k], k, 1],
203
  mode="markers",
204
  opacity=animation_opacity,
205
+ marker=dict(color="green", size=marker_size),
206
  )
207
  cur_gt_future_data = go.Scatter(
208
  x=y[2:][mask_y[2:, k], k, 0],
209
  y=y[2:][mask_y[2:, k], k, 1],
210
  mode="markers",
211
  opacity=animation_opacity,
212
+ marker=dict(color="black", size=marker_size),
213
  )
214
  cur_pred_data = []
215
  for i in range(n_samples):
 
219
  y=pred[mask_pred[:, k], i, k, 1],
220
  mode="markers",
221
  opacity=animation_opacity,
222
+ marker=dict(color="red", size=marker_size),
223
  showlegend=False,
224
  )
225
  )
 
228
  y=y[0:1, k, 1],
229
  mode="markers",
230
  opacity=animation_opacity,
231
+ marker=dict(color="blue", size=marker_size),
232
  )
233
  cur_data = [cur_gt_agent_data, cur_gt_future_data, *cur_pred_data, cur_ego_data]
234
  frame = go.Frame(data=cur_data)
 
264
  ),
265
  title_text="Road Scene",
266
  hovermode="closest",
267
+ width=800,
268
+ height=400,
269
  updatemenus=[
270
  dict(
271
  type="buttons",
 
276
  args=[
277
  None,
278
  dict(
 
279
  frame=dict(duration=100, redraw=False),
280
  mode="immediate",
281
  fromcurrent=True,