|
import os |
|
import gradio as gr |
|
import numpy as np |
|
import glob |
|
import warnings |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
|
|
from utils import OrthogonalRegularizer |
|
from huggingface_hub.keras_mixin import from_pretrained_keras |
|
|
|
|
|
# Pretrained PointNet part-segmentation model pulled from the Hugging Face Hub.
# OrthogonalRegularizer comes from the local utils module, so it has to be
# registered through custom_objects for Keras to deserialize the saved model.
model = from_pretrained_keras(
    "keras-io/pointnet_segmentation",
    custom_objects=dict(OrthogonalRegularizer=OrthogonalRegularizer),
)
|
|
|
|
|
samples = []  # NOTE(review): not referenced by any code visible in this file

# Example point-cloud CSVs bundled with the app. glob() yields paths in
# arbitrary filesystem order, so sort them to keep the examples gallery stable.
input_images = sorted(glob.glob("asset/source/*.csv"))
examples = [[path] for path in input_images]

# Part names indexed by predicted class id; inference() appends "none" for the
# class index just past the end of this list.
LABELS = ["wing", "body", "tail", "engine"]
# Plot colour for each entry of LABELS, matched by position.
COLORS = ["blue", "green", "red", "pink"]
|
|
|
|
|
def visualize_data(point_cloud, labels, output_path=None):
    """Scatter-plot a segmented point cloud in 3D, one colour per part label.

    Args:
        point_cloud: Array indexed as ``point_cloud[:, 0]``, ``[:, 1]``,
            ``[:, 2]`` for x, y, z (presumably shape ``(num_points, 3)`` --
            confirm against callers).
        labels: One label string per point. Only labels listed in ``LABELS``
            are drawn; any other value (e.g. ``"none"``) is left out.
        output_path: If truthy, the figure is saved to this path (parent
            directories are created as needed) and then closed.

    With ``output_path`` unset the figure is left open, so an interactive or
    notebook backend can still display it.
    """
    df = pd.DataFrame(
        data={
            "x": point_cloud[:, 0],
            "y": point_cloud[:, 1],
            "z": point_cloud[:, 2],
            "label": labels,
        }
    )

    fig = plt.figure(figsize=(15, 10))
    ax = fig.add_subplot(projection="3d")
    # zip() stops at the shorter list, so a label without a matching colour
    # is simply not plotted.
    for label, color in zip(LABELS, COLORS):
        c_df = df[df["label"] == label]
        ax.scatter(c_df["x"], c_df["y"], c_df["z"], label=label, alpha=0.5, c=color)
    ax.legend()

    if output_path:
        parent = os.path.dirname(output_path)
        # dirname() is "" for a bare filename, and os.makedirs("") raises.
        if parent:
            os.makedirs(parent, exist_ok=True)
        fig.savefig(output_path)
        # This runs once per request; pyplot keeps every open figure alive,
        # so close it once it is saved to avoid unbounded memory growth.
        plt.close(fig)
|
|
|
|
|
def inference(
    csv_file,
    output_path="asset/output",
    cpu=False,
):
    """Segment an uploaded point-cloud CSV and return the path of the rendered plot.

    Args:
        csv_file: The uploaded file from Gradio's ``"file"`` input, either an
            object exposing its on-disk path as ``.name`` or the path string
            itself. ``None`` (nothing uploaded) returns ``None``.
        output_path: Directory that receives ``<csv stem>.png``.
        cpu: Unused; kept so the signature stays backward compatible.

    Returns:
        Path of the saved PNG, or ``None`` when no file was given or the file
        does not exist (a warning is emitted in the latter case).

    The CSV must contain ``x``, ``y`` and ``z`` columns.
    """
    if csv_file is None:
        return None
    # Some Gradio versions pass a tempfile wrapper, others a plain path.
    csv_path = getattr(csv_file, "name", csv_file)
    # basename/splitext is portable across path separators and keeps dots
    # inside the stem ("a.b.csv" -> "a.b").
    im_name = os.path.splitext(os.path.basename(csv_path))[0]

    if not os.path.exists(csv_path):
        warnings.warn(f"{csv_path} not found for {im_name}")
        return None

    df = pd.read_csv(csv_path, index_col=None)
    inputs = df[["x", "y", "z"]].values

    # Predict on a batch of one and take its single result. Presumably
    # (num_points, num_classes) scores; argmax over the last axis picks each
    # point's class id.
    preds = model.predict(np.expand_dims(inputs, 0))[0]
    label_map = LABELS + ["none"]
    point_labels = [label_map[class_id] for class_id in np.argmax(preds, axis=-1)]

    png_path = f"{output_path}/{im_name}.png"
    visualize_data(inputs, point_labels, png_path)
    return png_path
|
|
|
|
|
# HTML footer shown under the demo, crediting the Space and the Keras example.
article = (
    "<div style='text-align: center;'>"
    "<a href='https://nouamanetazi.me/' target='_blank'>Space by Nouamane Tazi</a>"
    "<br>"
    "<a href='https://keras.io/examples/vision/pointnet_segmentation' target='_blank'>"
    "Keras example by Soumik Rakshit, Sayak Paul</a>"
    "</div>"
)

# Build the UI, then launch it as a separate statement so that ``iface``
# refers to the Interface object rather than to launch()'s return value.
iface = gr.Interface(
    inference,
    inputs=["file"],
    outputs=[gr.outputs.Image(label="result")],
    title="Point cloud segmentation with PointNet",
    article=article,
    # With cache_examples=True Gradio precomputes the outputs for these
    # examples by calling inference() on them.
    examples=examples,
    cache_examples=True,
)
iface.launch(enable_queue=True)
|
|