# Commit 17bf45d (aZhaoT): add total collection period display
# Local data directory.
# NOTE(review): not referenced anywhere else in this script — possibly stale.
data_path = './data/'
import pandas as pd
import datasets
# The motion capture data is loaded from Hugging Face datasets (not CSV files) in load_data().
import streamlit as st
# Dataset repositories under the Hugging Face "cyberorigin/" namespace; each one
# is downloaded by load_data() and offered in the dataset selector.
dataset_names = [
    'Fold_towels',
    'Pipette',
    'Take_the_item',
    'Twist_the_tube',
]
@st.cache_data(show_spinner="Loading datasets from the Hugging Face Hub...")
def load_data():
    """Download every dataset in ``dataset_names`` and total their durations.

    Each name is fetched from the ``cyberorigin/<name>`` repository on the
    Hugging Face Hub and its ``train`` split is converted to a DataFrame.

    The result is memoised with ``st.cache_data``, so Streamlit's full-script
    reruns reuse it instead of re-downloading every dataset.  The ``print``
    calls therefore only run on a cache miss.

    Returns:
        tuple: ``(all_datasets, total_period)`` where ``all_datasets`` maps each
        dataset name to its DataFrame and ``total_period`` is the sum over all
        datasets of (last ``timestamp`` - first ``timestamp``).  Datasets with
        no rows contribute 0.
    """
    print("Loading data")
    # load the motion capture data
    all_datasets = {}
    for name in dataset_names:
        print("Loading dataset: ", name)
        train_split = datasets.load_dataset("cyberorigin/" + name)['train']
        # Arrow-backed conversion; avoids building one Python dict per row,
        # which is what pd.DataFrame(train_split) does.
        all_datasets[name] = train_split.to_pandas()

    total_period = 0.0
    for frame in all_datasets.values():
        timestamps = frame["timestamp"]
        if timestamps.empty:
            # iloc[0] / iloc[-1] would raise IndexError on an empty split.
            continue
        # float() accepts numbers and numeric strings alike and keeps the
        # returned total a plain float.  Presumably the timestamps are in
        # seconds (the UI labels the total that way) — TODO confirm.
        total_period += float(timestamps.iloc[-1]) - float(timestamps.iloc[0])
    return all_datasets, total_period
# Skeleton joints plotted for every frame; the selected dataset must provide
# "<name>_x", "<name>_y" and "<name>_z" columns for each of them.
BODY_PART_NAMES = (
    'Left Shoulder',
    'Right Upper Arm',
    'Left Lower Leg',
    'Spine1',
    'Right Upper Leg',
    'Spine3',
    'Right Lower Arm',
    'Left Foot',
    'Right Lower Leg',
    'Right Shoulder',
    'Left Hand',
    'Left Upper Leg',
    'Right Foot',
    'Spine',
    'Spine2',
    'Left Lower Arm',
    'Left Toe',
    'Neck',
    'Right Hand',
    'Right Toe',
    'Head',
    'Left Upper Arm',
    'Hips',
)


def _animation_controls(n_frames):
    """Return the ``(updatemenus, sliders)`` layout entries for *n_frames* frames.

    Slider steps address frames by the name ``str(k)``, matching the names
    given to the ``go.Frame`` objects built in ``_build_figure``.
    """
    # Plotly only tweens a few 2-D trace types; scatter3d frames are redrawn,
    # so every transition is 0.  Previously the slider used len(frames)/30 ms,
    # which grew with the clip length for no reason.
    no_transition = {'duration': 0}
    updatemenus = [{
        'buttons': [
            {
                'args': [None, {'frame': {'duration': 1, 'redraw': True},
                                'fromcurrent': True,
                                'transition': no_transition}],
                'label': 'Play',
                'method': 'animate',
            },
            {
                'args': [[None], {'frame': {'duration': 0, 'redraw': True},
                                  'mode': 'immediate',
                                  'transition': no_transition}],
                'label': 'Pause',
                'method': 'animate',
            },
        ],
        'direction': 'left',
        'pad': {'r': 10, 't': 87},
        'showactive': True,
        'type': 'buttons',
        'x': 0.1,
        'xanchor': 'right',
        'y': 0,
        'yanchor': 'top',
    }]
    sliders = [{
        'active': 0,
        'steps': [
            {
                'label': str(k),
                'method': 'animate',
                'args': [
                    [str(k)],
                    {'mode': 'immediate',
                     'frame': {'duration': 300, 'redraw': True},
                     'transition': no_transition},
                ],
            }
            for k in range(n_frames)
        ],
        # Step labels are frame indices, so label them as such.
        'currentvalue': {
            'prefix': 'Frame: ',
            'visible': True,
            'xanchor': 'right',
        },
        'pad': {'b': 10},
        'len': 0.9,
        'x': 0.1,
        'y': 0,
    }]
    return updatemenus, sliders


def _build_figure(motion_capture_data):
    """Build an animated 3-D joint scatter with one animation frame per row.

    Args:
        motion_capture_data: DataFrame with ``<joint>_x/_y/_z`` columns for
            every name in ``BODY_PART_NAMES``.  Must contain at least one row,
            because row 0 is used as the figure's initial trace.

    Returns:
        go.Figure: row 0 as initial data, every row as a named frame, plus
        Play/Pause buttons and a frame slider.
    """
    import numpy as np
    import plotly.graph_objects as go

    def coords(axis):
        # Convert once to a (rows, joints) array instead of calling .iloc per frame.
        columns = [f"{part}_{axis}" for part in BODY_PART_NAMES]
        return motion_capture_data[columns].to_numpy()

    xs, ys, zs = coords('x'), coords('y'), coords('z')
    n_frames = len(motion_capture_data)

    def markers(k):
        # One blue marker per joint for row k.
        return go.Scatter3d(
            x=xs[k],
            y=ys[k],
            z=zs[k],
            mode='markers',
            marker=dict(size=5, color='blue'),
        )

    def axis_range(values):
        # Fix each axis to the whole clip's extent so the scene keeps the same
        # scale across frames instead of autoranging on every redraw.
        values = np.asarray(values, dtype=float)
        finite = values[np.isfinite(values)]
        if finite.size == 0:
            return None  # nothing to fit: leave Plotly's autorange in charge
        lo, hi = float(finite.min()), float(finite.max())
        margin = 0.05 * (hi - lo) or 0.5  # keep edge markers visible; avoid a zero-width axis
        return [lo - margin, hi + margin]

    frames = [go.Frame(data=[markers(k)], name=str(k)) for k in range(n_frames)]
    updatemenus, sliders = _animation_controls(n_frames)
    layout = go.Layout(
        title='Motion Capture Visualization',
        updatemenus=updatemenus,
        sliders=sliders,
        scene=dict(
            xaxis=dict(range=axis_range(xs)),
            yaxis=dict(range=axis_range(ys)),
            zaxis=dict(range=axis_range(zs)),
        ),
    )
    return go.Figure(data=[markers(0)], frames=frames, layout=layout)


@st.fragment
def visualize(data):
    """Render the dataset selector, the dataset's video and an animated skeleton.

    Decorated with ``st.fragment``, so interacting with the selector reruns only
    this function rather than the whole script.

    Args:
        data: mapping from each name in ``dataset_names`` to its motion capture
            DataFrame (the first element returned by ``load_data``).
    """
    dataset_option = st.selectbox(
        'Select a dataset:',
        dataset_names
    )
    # Each dataset repository serves its recording at Video/video.mp4.
    st.video(
        "https://huggingface.co/datasets/cyberorigin/" + dataset_option
        + "/resolve/main/Video/video.mp4"
    )
    motion_capture_data = data[dataset_option]
    if motion_capture_data.empty:
        # The figure's initial trace is row 0, so there is nothing to animate.
        st.warning(f"Dataset {dataset_option!r} contains no motion capture rows.")
        return
    st.plotly_chart(_build_figure(motion_capture_data))
st.title("CyberOrigin Data Visualization")
data, period = load_data()
# Display the total collection period summed over all datasets, formatted as a
# single message with exactly 2 decimal places.
st.write(f"Total period of data: {period:.2f} seconds")
visualize(data)