# MMGP_demo / app.py — uploaded by fabiencasenave (commit f4d7da3, "initial commit", 8.01 kB)
import os
import pickle

# The OpenGL platform must be chosen before loading pyrender (or anything that
# may load it), otherwise this setting arrives too late.
# Switch to "osmesa" or "egl" here.
os.environ["PYOPENGL_PLATFORM"] = "egl"

import gradio as gr  # noqa: E402
import matplotlib as mpl  # noqa: E402
import matplotlib.cm as cm  # noqa: E402
import numpy as np  # noqa: E402
import pyrender  # noqa: E402
from datasets import load_dataset  # noqa: E402
from plaid.containers.sample import Sample  # noqa: E402
from trimesh import Trimesh  # noqa: E402

from utils_inference import infer  # noqa: E402

hf_dataset = load_dataset("PLAID-datasets/AirfRANS_remeshed", split="all_samples")

# Local training artefact handed to `infer` at every evaluation.
with open('training_data.pkl', 'rb') as file:
    training_data = pickle.load(file)

# Dataset metadata used by the UI and the inference summary.
train_ids = hf_dataset.description['split']['ML4PhySim_Challenge_train']
out_fields_names = hf_dataset.description['out_fields_names']
in_scalars_names = hf_dataset.description['in_scalars_names']
out_scalars_names = hf_dataset.description['out_scalars_names']
nb_samples = len(hf_dataset)
# <h2><b><a href='https://arxiv.org/abs/2305.12871' target='_blank'><b>MMGP</b> demo on the <a href='https://huggingface.co/datasets/PLAID-datasets/AirfRANS_remeshed' target='_blank'><b>AirfRANS_remeshed dataset</b></b></h2>
# <a href='https://arxiv.org/abs/2305.12871' target='_blank'><b>MMGP paper</b>,
_HEADER_ = '''
<h2><b>MMGP demo on the <a href='https://huggingface.co/datasets/PLAID-datasets/AirfRANS_remeshed' target='_blank'><b>AirfRANS_remeshed dataset</b></b></h2>
'''
_HEADER_2 = '''
The model is already trained. The morphing is the same as the one used in the [MMGP paper](https://arxiv.org/abs/2305.12871),
but is much less involved than the one used in the winning entry of the [ML4PhySim competition](https://www.codabench.org/competitions/1534/).
The training set has 103 samples and is the one used in this competition (some evaluations are out-of-distribution).
The inference takes approx 5 seconds, and is done from scratch (no precomputation is used during the inference when evaluating samples).
This means that the morphing and the finite element interpolations are re-done at each evaluation.
After choosing a sample id, please change the field name in the dropdown menu to update the visualization.
'''
def round_num(num) -> str:
    """Return *num* rounded to three significant digits, rendered as a float string."""
    three_significant = '%.3g' % num
    return str(float(three_significant))
def compute_inference(sample_id_str):
    """Run the inference on one dataset sample and cache the result for the viewers.

    The triangle mesh, the reference output fields and the predicted output
    fields are pickled, in that order, as a list to ``computed_inference.pkl``.

    Args:
        sample_id_str: index of the sample in ``hf_dataset``, as delivered by
            the Gradio slider (any value accepted by ``int()``).

    Returns:
        str: a textual summary of the sample (training-split membership, the
        PLAID sample description, its input scalars and its node count).
    """
    sample_id = int(sample_id_str)
    # NOTE(review): this unpickles bytes downloaded from the Hugging Face Hub;
    # only acceptable as long as the dataset repository is trusted.
    plaid_sample = Sample.model_validate(pickle.loads(hf_dataset[sample_id]["sample"]))
    prediction = infer(hf_dataset, sample_id, training_data)
    reference = {fieldn: plaid_sample.get_field(fieldn) for fieldn in out_fields_names}

    nodes = plaid_sample.get_nodes()
    if nodes.shape[1] == 2:
        # Pad 2D coordinates with z = 0 so the mesh is built with 3D vertices.
        nodes_3d = np.zeros((nodes.shape[0], nodes.shape[1] + 1))
        nodes_3d[:, :-1] = nodes
        nodes = nodes_3d
    triangles = plaid_sample.get_elements()['TRI_3']
    trimesh = Trimesh(vertices=nodes, faces=triangles)

    # NOTE(review): a single file shared by every session, so concurrent users
    # overwrite each other's results; per-session state would avoid this.
    with open('computed_inference.pkl', 'wb') as file:
        pickle.dump([trimesh, reference, prediction], file)

    # Use the parsed int id: the slider may deliver a float such as 12.0.
    membership = "in" if sample_id in train_ids else "not in"
    parts = [
        f"Training sample {sample_id} ({membership} the training set)\n\n",
        str(plaid_sample) + "\n",
    ]
    if len(in_scalars_names) > 0:
        parts.append("\nInput scalars:\n")
        parts.extend(
            f"- {sname}: {round_num(plaid_sample.get_scalar(sname))}\n"
            for sname in in_scalars_names
        )
    parts.append(f"\nNumber of nodes in the mesh: {nodes.shape[0]}")
    return "".join(parts)
def _load_computed_inference():
    """Return ``(trimesh, reference, prediction)`` as pickled by ``compute_inference``."""
    with open('computed_inference.pkl', 'rb') as file:
        data = pickle.load(file)
    return data[0], data[1], data[2]


def _reference_norm(ref):
    """Normalization mapping the reference field's ``[min, max]`` onto the colormap."""
    return mpl.colors.Normalize(vmin=np.min(ref), vmax=np.max(ref))


def _render_vertex_values(trimesh, values, norm):
    """Color the mesh vertices by *values* (seismic colormap) and render offscreen.

    Args:
        trimesh: the mesh to draw; its vertex colors are overwritten.
        values: one scalar per mesh vertex.
        norm: matplotlib normalization applied to *values* before the colormap.

    Returns:
        The color buffer returned by ``pyrender.OffscreenRenderer.render``
        for a 1024 x 1024 viewport.
    """
    mappable = cm.ScalarMappable(norm=norm, cmap=cm.seismic)
    trimesh.visual.vertex_colors = mappable.to_rgba(values)[:, :3]
    mesh = pyrender.Mesh.from_trimesh(trimesh, smooth=False)

    # compose scene
    scene = pyrender.Scene(ambient_light=[.1, .1, .3], bg_color=[0, 0, 0])
    camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0)
    light = pyrender.DirectionalLight(color=[1, 1, 1], intensity=1000.)
    scene.add(mesh, pose=np.eye(4))
    scene.add(light, pose=np.eye(4))
    # Camera translated by (1, 0, 6), no rotation.
    scene.add(camera, pose=[[1, 0, 0, 1],
                            [0, 1, 0, 0],
                            [0, 0, 1, 6],
                            [0, 0, 0, 1]])

    # render scene
    renderer = pyrender.OffscreenRenderer(1024, 1024)
    try:
        color, _ = renderer.render(scene)
    finally:
        # Free the renderer's OpenGL resources; without this each call leaks them.
        renderer.delete()
    return color


def show_pred(fieldn):
    """Render the predicted *fieldn*, colored on the reference field's value range."""
    trimesh, reference, prediction = _load_computed_inference()
    ref = reference[fieldn]
    return _render_vertex_values(trimesh, prediction[fieldn], _reference_norm(ref))


def show_ref(fieldn):
    """Render the reference *fieldn*, colored on its own value range."""
    trimesh, reference, _ = _load_computed_inference()
    ref = reference[fieldn]
    return _render_vertex_values(trimesh, ref, _reference_norm(ref))


def show_err(fieldn):
    """Render the signed error ``prediction - reference`` of *fieldn*.

    The color scale is as wide as the reference field's value range but is
    centered on zero. Zero error therefore lands on the neutral midpoint of
    the diverging seismic colormap, and over- and under-prediction appear in
    opposite colors.
    """
    trimesh, reference, prediction = _load_computed_inference()
    ref = reference[fieldn]
    err = prediction[fieldn] - ref
    half_span = 0.5 * (np.max(ref) - np.min(ref))
    norm = mpl.colors.Normalize(vmin=-half_span, vmax=half_span)
    return _render_vertex_values(trimesh, err, norm)
if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.Markdown(_HEADER_)
        gr.Markdown(_HEADER_2)
        with gr.Row(variant="panel"):
            with gr.Column():
                d1 = gr.Slider(0, nb_samples - 1, value=0, label="Training sample id",
                               info=f"Choose between 0 and {nb_samples - 1}")
                output4 = gr.Text(label="Information on sample")
                output5 = gr.Image(label="Error")
            with gr.Column():
                d2 = gr.Dropdown(out_fields_names, value=out_fields_names[0], label="Field name")
                output2 = gr.Image(label="Reference")
                output3 = gr.Image(label="MMGP prediction")

        # Each viewer is re-rendered when the field changes and also right after
        # a new inference, so the images never show a stale sample.
        viewers = [(show_ref, output2), (show_pred, output3), (show_err, output5)]
        inference_done = d1.input(compute_inference, [d1], [output4])
        for show, output in viewers:
            d2.input(show, [d2], [output])
            inference_done.then(show, [d2], [output])
    demo.launch()