Spaces:
Runtime error
Runtime error
import gradio as gr | |
import pandas as pd | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from mpl_toolkits.mplot3d import Axes3D | |
from gradio.components import Slider, Number, Image, Dataframe, Textbox, Checkbox, Radio, Label, Plot | |
import plotly.graph_objects as go | |
from aaaaa import get_datXYZ, plot_3d_points, genS | |
from scipy.stats import multivariate_normal | |
import PIL.Image | |
# Folder with the recorded reference data (handed to get_datXYZ), and folder
# holding one picture per hand configuration, named "<configuration number>.png".
data_path, image_path = "reference_data/", "images_configs/"
# To generate distorsed covariance matrixes | |
def genS(sds, R, distor=1): | |
# sds: standard deviations | |
# R: correlation matrix | |
# distor: level of distorsion | |
sdsdis = sds * distor | |
S = R * np.outer(sdsdis, sdsdis) | |
np.fill_diagonal(S, sdsdis ** 2) | |
return S | |
def _sample_global_distortion(datXYZ, distortion):
    """Draw one synthetic hand from a Gaussian fitted to all recorded attempts.

    Each attempt is flattened to one row; the rows' mean, per-coordinate
    standard deviation and correlation matrix define a multivariate normal whose
    standard deviations are scaled by ``distortion`` (see ``genS``). Returns a
    flat array with one value per coordinate.
    """
    flat = datXYZ.reshape(datXYZ.shape[0], -1)
    obsmean = flat.mean(axis=0)
    obssd = flat.std(axis=0)
    R = np.corrcoef(flat, rowvar=False)
    Sberria = genS(obssd, R, distor=distortion)
    return multivariate_normal.rvs(mean=obsmean, cov=Sberria, size=1)


def _load_config_image(gesto):
    """Open ``<image_path><gesto>.png`` and return an in-memory copy of it.

    Copying inside ``with`` reads the pixels and releases the file handle right
    away instead of leaving it open until the image is garbage-collected.
    """
    image_filename = f"{image_path}{int(gesto)}.png"
    with PIL.Image.open(image_filename) as img:
        return img.copy()


def show_hand_3d(gesto, attempt, distortion, mean, std):
    """Gradio callback: show a recorded hand configuration and two distortions of it.

    Args:
        gesto: Hand-configuration number; selects the data loaded by
            ``get_datXYZ`` and the reference picture ``<image_path><gesto>.png``.
        attempt: Index of the recorded attempt to display.
        distortion: Scale applied to the fitted standard deviations when
            sampling the globally distorted hand.
        mean: Mean of the independent Gaussian noise added to every coordinate
            of the recorded attempt.
        std: Standard deviation of that per-coordinate noise.

    Returns:
        ``(image, hand_plot, distorted_hand_plot, noisy_hand_plot)``: the
        reference picture followed by three figures from ``plot_3d_points``.
    """
    datXYZ = get_datXYZ(gesto, data_path)

    # Global distortion: a new hand sampled from the distribution of all attempts.
    z = _sample_global_distortion(datXYZ, distortion)

    # The chosen attempt as 21 points with 3 coordinates each.
    recorded_hand = datXYZ[int(attempt)].reshape(21, 3)

    # Individual distortion: independent Gaussian noise on every coordinate.
    noise = np.random.normal(mean, std, recorded_hand.shape)
    noisy_hand = recorded_hand + noise

    distorted_hand = z.reshape(recorded_hand.shape)

    hand_plot = plot_3d_points(recorded_hand, "Recorded Hand Configuration")
    distorted_hand_plot = plot_3d_points(distorted_hand, "Distorted Hand Configuration (global)")
    noisy_hand_plot = plot_3d_points(noisy_hand, "Distorted Hand Configuration (individual)")

    image = _load_config_image(gesto)
    return image, hand_plot, distorted_hand_plot, noisy_hand_plot
# Input widgets, in the same order as show_hand_3d's parameters
# (gesto, attempt, distortion, mean, std).
inputs = [
    Number(value=42, precision=0, label="Number of Kind of Hand Configuration"),
    Number(value=0, precision=0, label="Number of Concrete Hand Attempt"),
    Number(value=1.0, label="Distortion Level"),
    Number(value=0.0, label="Mean Distortion Level (one hand)"),
    Number(value=0.006, label="Standard Deviation Distortion Level (one hand)"),
]
# One output per element of show_hand_3d's return tuple:
# the reference picture followed by three 3-D plots.
outputs = [Image(type="pil"), Plot(), Plot(), Plot()]
title = "Hand 3D Visualization"
# Describe only the inputs the UI actually offers; the data folder is fixed
# by the module-level ``data_path``.
description = (
    "Enter the hand configuration number, the attempt to display and the "
    "distortion settings to visualize the recorded hand and its distorted "
    "versions in 3D."
)
iface = gr.Interface(
    fn=show_hand_3d,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    allow_flagging="never",
)
iface.launch()