Samuel CHAINEAU
commited on
Commit
•
382e94b
1
Parent(s):
ce48a3d
QBGPT
Browse files- .streamlit/config.toml +2 -0
- pages.py +3 -18
- tools.py +78 -0
.streamlit/config.toml
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
[theme]
|
2 |
+
base="dark"
|
pages.py
CHANGED
@@ -2,7 +2,7 @@ import streamlit as st
|
|
2 |
import plotly.express as px
|
3 |
import pandas as pd
|
4 |
import numpy as np
|
5 |
-
from tools import generator
|
6 |
from PIL import Image
|
7 |
|
8 |
def set_app_title_and_logo():
|
@@ -63,25 +63,10 @@ def qb_gpt_page(ref_df, ref, tokenizer, model):
|
|
63 |
step1_true = QB_gen.prepare_for_plot(decoded_true)
|
64 |
plot_true = pd.DataFrame(step1_true)
|
65 |
|
66 |
-
fig_gen =
|
67 |
-
text="position_ids", title="Generated players' trajectories Over Time", line_shape="linear",
|
68 |
-
range_x=[0, 140], range_y=[0, 60], # Set X and Y axis ranges
|
69 |
-
render_mode="svg") # Render mode for smoother lines
|
70 |
-
|
71 |
-
# Customize the appearance of the plot
|
72 |
-
fig_gen.update_traces(marker=dict(size=10), selector=dict(mode='lines'))
|
73 |
-
fig_gen.update_layout(width=800, height=600)
|
74 |
st.plotly_chart(fig_gen)
|
75 |
|
76 |
-
fig_true =
|
77 |
-
text="position_ids", title="True players' trajectories Over Time",
|
78 |
-
range_x=[0, 140], range_y=[0, 60], # Set X and Y axis ranges
|
79 |
-
line_shape="linear", # Draw lines connecting points
|
80 |
-
render_mode="svg") # Render mode for smoother lines
|
81 |
-
|
82 |
-
# Customize the appearance of the plot
|
83 |
-
fig_true.update_traces(marker=dict(size=10), selector=dict(mode='lines'))
|
84 |
-
fig_true.update_layout(width=800, height=600)
|
85 |
st.plotly_chart(fig_true)
|
86 |
|
87 |
|
|
|
2 |
import plotly.express as px
|
3 |
import pandas as pd
|
4 |
import numpy as np
|
5 |
+
from tools import generator, get_plot
|
6 |
from PIL import Image
|
7 |
|
8 |
def set_app_title_and_logo():
|
|
|
63 |
step1_true = QB_gen.prepare_for_plot(decoded_true)
|
64 |
plot_true = pd.DataFrame(step1_true)
|
65 |
|
66 |
+
fig_gen = get_plot(plot, frames, "Generated")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
st.plotly_chart(fig_gen)
|
68 |
|
69 |
+
fig_true = get_plot(plot_true, frames, "True")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
st.plotly_chart(fig_true)
|
71 |
|
72 |
|
tools.py
CHANGED
@@ -2,6 +2,8 @@ import polars as pl
|
|
2 |
import numpy as np
|
3 |
import tensorflow as tf
|
4 |
import pandas as pd
|
|
|
|
|
5 |
|
6 |
class tokenizer:
|
7 |
def __init__(self,
|
@@ -373,3 +375,79 @@ class generator:
|
|
373 |
cutted["ids"] = [[i for e in range(len(cutted["input_ids"][i]))] for i in range(len(cutted["input_ids"]))]
|
374 |
merged = self.merge_cuts(cutted)
|
375 |
return merged
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import numpy as np
|
3 |
import tensorflow as tf
|
4 |
import pandas as pd
|
5 |
+
import plotly.graph_objects as go
|
6 |
+
|
7 |
|
8 |
class tokenizer:
|
9 |
def __init__(self,
|
|
|
375 |
cutted["ids"] = [[i for e in range(len(cutted["input_ids"][i]))] for i in range(len(cutted["input_ids"]))]
|
376 |
merged = self.merge_cuts(cutted)
|
377 |
return merged
|
378 |
+
|
379 |
+
def get_plot(df, n_frames, name):
|
380 |
+
fig = go.Figure(
|
381 |
+
layout=go.Layout(
|
382 |
+
updatemenus=[dict(type="buttons", direction="right", x=0.9, y=1.16), ],
|
383 |
+
xaxis=dict(range=[0, 120],
|
384 |
+
autorange=False, tickwidth=2,
|
385 |
+
title_text="X"),
|
386 |
+
yaxis=dict(range=[0, 60],
|
387 |
+
autorange=False,
|
388 |
+
title_text="Y")
|
389 |
+
))
|
390 |
+
|
391 |
+
# Add traces
|
392 |
+
i = 1
|
393 |
+
frames = {i: [] for i in df["pos_ids"].unique() if i !=0}
|
394 |
+
|
395 |
+
for id in df["ids"].unique():
|
396 |
+
spec = df[df["ids"] == id].reset_index(drop = True)
|
397 |
+
fig.add_trace(
|
398 |
+
go.Scatter(x=spec.input_ids_x[:i],
|
399 |
+
y=spec.input_ids_y[:i],
|
400 |
+
name= spec.position_ids.unique()[0],
|
401 |
+
text= spec.position_ids.unique()[0],
|
402 |
+
visible=True,
|
403 |
+
line=dict(color="#f47738", dash="solid")))
|
404 |
+
|
405 |
+
for k in range(i, spec.shape[0]):
|
406 |
+
current_frame = spec["pos_ids"][k]
|
407 |
+
frames[current_frame].append(go.Scatter(x=spec.input_ids_x[:k], y=spec.input_ids_y[:k]))
|
408 |
+
|
409 |
+
frames = list(frames.values())
|
410 |
+
frames = [go.Frame(data = v) for v in frames]
|
411 |
+
|
412 |
+
|
413 |
+
# Animation
|
414 |
+
fig.update(frames=frames)
|
415 |
+
|
416 |
+
fig.update_xaxes(ticks="outside", tickwidth=2, tickcolor='white', ticklen=10)
|
417 |
+
fig.update_yaxes(ticks="outside", tickwidth=2, tickcolor='white', ticklen=1)
|
418 |
+
fig.update_layout(yaxis_tickformat=',')
|
419 |
+
fig.update_layout(legend=dict(x=0, y=1.1), legend_orientation="h")
|
420 |
+
|
421 |
+
# Buttons
|
422 |
+
fig.update_layout(title=f"{name} play",
|
423 |
+
xaxis_title="X",
|
424 |
+
yaxis_title="Y",
|
425 |
+
legend_title="Legend Title",
|
426 |
+
showlegend=False,
|
427 |
+
font=dict(
|
428 |
+
family="Arial",
|
429 |
+
size=14
|
430 |
+
),
|
431 |
+
hovermode="x",
|
432 |
+
updatemenus=[
|
433 |
+
dict(
|
434 |
+
buttons=list(
|
435 |
+
[
|
436 |
+
dict(label="Play",
|
437 |
+
method="animate",
|
438 |
+
args=[None, {"frame": {"duration": n_frames}}])
|
439 |
+
]
|
440 |
+
),
|
441 |
+
type = "buttons",
|
442 |
+
direction="right",
|
443 |
+
pad={"r": 50, "t": 50},
|
444 |
+
showactive=False,
|
445 |
+
x=0.5,
|
446 |
+
yanchor="top")
|
447 |
+
])
|
448 |
+
|
449 |
+
fig.update_layout(template='plotly_dark'
|
450 |
+
)
|
451 |
+
|
452 |
+
fig.update_layout(width=1200, height=600)
|
453 |
+
return fig
|