Spaces:
Sleeping
Sleeping
import os
import pickle

# PyOpenGL chooses its platform backend when it is first imported (pyrender
# imports it), so PYOPENGL_PLATFORM must be set BEFORE importing pyrender or
# any module that may import it. Switch to "osmesa" or "egl" as needed.
os.environ["PYOPENGL_PLATFORM"] = "egl"

import gradio as gr
import matplotlib as mpl
import matplotlib.cm as cm
import numpy as np
import pyrender
from datasets import load_dataset
from plaid.containers.sample import Sample
from trimesh import Trimesh

from utils_inference import infer
# Load the dataset and the training data consumed by `infer` once, at startup.
hf_dataset = load_dataset("PLAID-datasets/AirfRANS_remeshed", split="all_samples")

# NOTE: pickle can execute arbitrary code on load; this must stay a trusted,
# locally bundled file. The context manager closes it even if loading fails.
with open('training_data.pkl', 'rb') as file:
    training_data = pickle.load(file)

train_ids = hf_dataset.description['split']['ML4PhySim_Challenge_train']
out_fields_names = hf_dataset.description['out_fields_names']
in_scalars_names = hf_dataset.description['in_scalars_names']
out_scalars_names = hf_dataset.description['out_scalars_names']
nb_samples = len(hf_dataset)
# <h2><b><a href='https://arxiv.org/abs/2305.12871' target='_blank'><b>MMGP</b> demo on the <a href='https://huggingface.co/datasets/PLAID-datasets/AirfRANS_remeshed' target='_blank'><b>AirfRANS_remeshed dataset</b></b></h2> | |
# <a href='https://arxiv.org/abs/2305.12871' target='_blank'><b>MMGP paper</b>, | |
_HEADER_ = ''' | |
<h2><b>MMGP demo on the <a href='https://huggingface.co/datasets/PLAID-datasets/AirfRANS_remeshed' target='_blank'><b>AirfRANS_remeshed dataset</b></b></h2> | |
''' | |
_HEADER_2 = ''' | |
The model is already trained. The morphing is the same as the one used in the [MMGP paper](https://arxiv.org/abs/2305.12871), | |
but is much less involved than the one used in the winning entry of the [ML4PhySim competition](https://www.codabench.org/competitions/1534/). | |
The training set has 103 samples and is the one used in this competition (some evaluations are out-of-distribution). | |
The inference takes approx 5 seconds, and is done from scratch (no precomputation is used during the inference when evaluating samples). | |
This means that the morphing and the finite element interpolations are re-done at each evaluation. | |
After choosing a sample id, please change the field name in the dropdown menu to update the visualization. | |
''' | |
def round_num(num) -> str:
    """Round *num* to three significant digits and return it as text.

    The rounded value is passed through ``float`` before being stringified,
    so integers gain a trailing ``.0`` (``5 -> '5.0'``) and ``1234.5`` comes
    back as ``'1230.0'`` rather than in exponent notation.
    """
    rounded = float("%.3g" % num)
    return str(rounded)
def compute_inference(sample_id_str):
    """Run MMGP inference on one dataset sample and cache the result for the viewers.

    The sample's mesh (as a ``Trimesh``), its reference output fields and the
    predicted output fields are pickled to ``computed_inference.pkl``, which
    the field-rendering callbacks read back.

    Args:
        sample_id_str: Sample index chosen in the UI; anything ``int()`` accepts.

    Returns:
        A text description of the sample: whether it belongs to the training
        split, the sample summary, its input scalars and its node count.
    """
    sample_id = int(sample_id_str)

    # NOTE(review): pickle.loads on bytes downloaded from the Hub executes
    # arbitrary code if the dataset is ever tampered with; trusted data only.
    sample_bytes = hf_dataset[sample_id]["sample"]
    plaid_sample = Sample.model_validate(pickle.loads(sample_bytes))

    prediction = infer(hf_dataset, sample_id, training_data)
    reference = {fieldn: plaid_sample.get_field(fieldn) for fieldn in out_fields_names}

    nodes = plaid_sample.get_nodes()
    if nodes.shape[1] == 2:
        # Pad 2-D coordinates with z = 0 so the mesh can be built/rendered in 3-D.
        nodes = np.hstack([nodes, np.zeros((nodes.shape[0], 1))])
    triangles = plaid_sample.get_elements()['TRI_3']
    trimesh = Trimesh(vertices=nodes, faces=triangles)

    # NOTE(review): this cache file is shared by the whole process, so
    # concurrent users overwrite each other's results; per-session state
    # (e.g. gr.State) would avoid that.
    with open('computed_inference.pkl', 'wb') as file:
        pickle.dump([trimesh, reference, prediction], file)

    membership = "in" if sample_id in train_ids else "not in"
    parts = [
        f"Training sample {sample_id_str} ({membership} the training set)\n\n",
        str(plaid_sample) + "\n",
    ]
    if len(in_scalars_names) > 0:
        parts.append("\nInput scalars:\n")
        parts.extend(
            f"- {sname}: {round_num(plaid_sample.get_scalar(sname))}\n"
            for sname in in_scalars_names
        )
    parts.append(f"\nNumber of nodes in the mesh: {nodes.shape[0]}")
    return "".join(parts)
def _load_computed_inference():
    """Return ``(trimesh, reference, prediction)`` as cached by ``compute_inference``."""
    # The file is written by this app itself, so unpickling it is trusted.
    with open('computed_inference.pkl', 'rb') as file:
        data = pickle.load(file)
    return data[0], data[1], data[2]


def _render_colored_mesh(trimesh, values, norm):
    """Colour *trimesh* vertices by *values* and render it offscreen.

    Args:
        trimesh: Mesh whose vertex colours are overwritten in place.
        values: One scalar per vertex, mapped through *norm* and the seismic colormap.
        norm: ``matplotlib.colors.Normalize`` fixing the colour scale.

    Returns:
        The colour image produced by ``pyrender.OffscreenRenderer.render`` (1024 x 1024).
    """
    mapper = cm.ScalarMappable(norm=norm, cmap=cm.seismic)
    trimesh.visual.vertex_colors = mapper.to_rgba(values)[:, :3]  # drop alpha
    mesh = pyrender.Mesh.from_trimesh(trimesh, smooth=False)

    scene = pyrender.Scene(ambient_light=[.1, .1, .3], bg_color=[0, 0, 0])
    scene.add(mesh, pose=np.eye(4))
    scene.add(pyrender.DirectionalLight(color=[1, 1, 1], intensity=1000.), pose=np.eye(4))
    # Identity rotation, camera translated to (x=1, y=0, z=6).
    camera_pose = np.array([[1, 0, 0, 1],
                            [0, 1, 0, 0],
                            [0, 0, 1, 6],
                            [0, 0, 0, 1]], dtype=float)
    scene.add(pyrender.PerspectiveCamera(yfov=np.pi / 3.0), pose=camera_pose)

    renderer = pyrender.OffscreenRenderer(1024, 1024)
    try:
        color, _ = renderer.render(scene)
    finally:
        # Release the renderer's GL context and buffers; without this every
        # call leaks one.
        renderer.delete()
    return color


def show_pred(fieldn):
    """Render the MMGP prediction of *fieldn*, coloured on the reference's value range."""
    trimesh, reference, prediction = _load_computed_inference()
    ref = reference[fieldn]
    norm = mpl.colors.Normalize(vmin=np.min(ref), vmax=np.max(ref))
    return _render_colored_mesh(trimesh, prediction[fieldn], norm)


def show_ref(fieldn):
    """Render the reference (ground-truth) values of *fieldn*."""
    trimesh, reference, _ = _load_computed_inference()
    ref = reference[fieldn]
    norm = mpl.colors.Normalize(vmin=np.min(ref), vmax=np.max(ref))
    return _render_colored_mesh(trimesh, ref, norm)


def show_err(fieldn):
    """Render the signed error ``prediction - reference`` of *fieldn*.

    The colour scale spans the same width as the reference's range (so colours
    are comparable with the other two images) but is centred on zero, so a
    perfect prediction maps to the neutral midpoint of the diverging
    ``seismic`` colormap instead of an arbitrary hue.
    """
    trimesh, reference, prediction = _load_computed_inference()
    ref = reference[fieldn]
    error = prediction[fieldn] - ref
    half_span = 0.5 * (np.max(ref) - np.min(ref))
    norm = mpl.colors.Normalize(vmin=-half_span, vmax=half_span)
    return _render_colored_mesh(trimesh, error, norm)
if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.Markdown(_HEADER_)
        gr.Markdown(_HEADER_2)
        with gr.Row(variant="panel"):
            with gr.Column():
                d1 = gr.Slider(0, nb_samples - 1, value=0, label="Training sample id",
                               info=f"Choose between 0 and {nb_samples - 1}")
                output4 = gr.Text(label="Information on sample")
                output5 = gr.Image(label="Error")
            with gr.Column():
                d2 = gr.Dropdown(out_fields_names, value=out_fields_names[0], label="Field name")
                output2 = gr.Image(label="Reference")
                output3 = gr.Image(label="MMGP prediction")

        # A new sample id reruns inference, then refreshes the three images
        # so they never keep showing the previous sample's fields.
        (
            d1.input(compute_inference, [d1], [output4])
            .then(show_ref, [d2], [output2])
            .then(show_pred, [d2], [output3])
            .then(show_err, [d2], [output5])
        )
        # A new field name only needs re-rendering of the cached inference.
        d2.input(show_ref, [d2], [output2])
        d2.input(show_pred, [d2], [output3])
        d2.input(show_err, [d2], [output5])
    demo.launch()