Spaces:
Running
Running
import matplotlib.pyplot as plt | |
import numpy as np | |
from PIL import Image, ImageEnhance, ImageDraw | |
import torch | |
import streamlit as st | |
from model.inference_cpu import inference_case | |
# Initial canvas drawing: one 100x100 rectangle anchored at left=50, top=50,
# with a translucent orange fill and a 3px blue (#2909F1) outline.
# The key set looks like a Fabric.js object serialization ("version" 4.4.0);
# NOTE(review): presumably passed as the initial drawing of a canvas
# component used for box prompts — confirm against the caller.
initial_rectangle = {
    "version": "4.4.0",
    "objects": [
        {
            # Object kind and anchoring.
            "type": "rect",
            "version": "4.4.0",
            "originX": "left",
            "originY": "top",
            # Position and size.
            "left": 50,
            "top": 50,
            "width": 100,
            "height": 100,
            # Fill and outline styling.
            "fill": "rgba(255, 165, 0, 0.3)",
            "stroke": "#2909F1",
            "strokeWidth": 3,
            "strokeDashArray": None,
            "strokeLineCap": "butt",
            "strokeDashOffset": 0,
            "strokeLineJoin": "miter",
            "strokeUniform": True,
            "strokeMiterLimit": 4,
            # Transform: no scaling, rotation or flipping.
            "scaleX": 1,
            "scaleY": 1,
            "angle": 0,
            "flipX": False,
            "flipY": False,
            # Rendering flags.
            "opacity": 1,
            "shadow": None,
            "visible": True,
            "backgroundColor": "",
            "fillRule": "nonzero",
            "paintFirst": "fill",
            "globalCompositeOperation": "source-over",
            # No skew, square corners.
            "skewX": 0,
            "skewY": 0,
            "rx": 0,
            "ry": 0,
        }
    ],
}
def run():
    """Run inference on the current data item with the enabled prompts.

    Reads ``data_item`` and the prompt toggles from ``st.session_state``.
    Each enabled prompt is converted to model coordinates; a disabled prompt
    is passed as ``None``. The point prompt is only built when at least one
    point exists. The cached result of ``inference_case`` is cleared, and the
    returned pair is stored in ``st.session_state.preds_3D`` and
    ``st.session_state.preds_3D_ori``.
    """
    state = st.session_state
    data_item = state.data_item
    image = data_item["image"].float()
    image_zoom_out = data_item["zoom_out_image"].float()

    text_prompt = state.text_prompt if state.use_text_prompt else None

    wants_points = state.use_point_prompt and len(state.points) > 0
    point_prompt = reflect_points_into_model(state.points) if wants_points else None

    box_prompt = (
        reflect_box_into_model(state.rectangle_3Dbox)
        if state.use_box_prompt
        else None
    )

    # Drop any cached result so the new prompts are actually evaluated.
    inference_case.clear()
    # The underscore-prefixed keyword names are part of inference_case's
    # signature and must be passed exactly as spelled.
    preds, preds_ori = inference_case(
        image,
        image_zoom_out,
        text_prompt=text_prompt,
        _point_prompt=point_prompt,
        _box_prompt=box_prompt,
    )
    state.preds_3D = preds
    state.preds_3D_ori = preds_ori
def reflect_box_into_model(box_3d):
    """Rescale a 3D box from canvas coordinates to model prompt coordinates.

    ``box_3d`` is unpacked as six values ``(z1, y1, x1, z2, y2, x2)``.
    The y and x entries are multiplied by 256 and divided by 325. The z
    entries are multiplied by 32 and divided by 325. Each result is then
    truncated with ``int()``. (Presumably 325 is the canvas size and 256/32
    the model grid size — NOTE(review): confirm.)

    Returns:
        A 1-D torch tensor of the six rescaled integers in the same
        ``(z1, y1, x1, z2, y2, x2)`` order.
    """
    def _to_grid(value, extent):
        # Multiply first, then divide: reordering the float operations could
        # move a value across an integer boundary before int() truncates it.
        return int(value * extent / 325.0)

    z_lo, y_lo, x_lo, z_hi, y_hi, x_hi = box_3d
    corners = (z_lo, y_lo, x_lo, z_hi, y_hi, x_hi)
    extents = (32.0, 256.0, 256.0, 32.0, 256.0, 256.0)
    rescaled = [_to_grid(coord, extent) for coord, extent in zip(corners, extents)]
    return torch.tensor(np.array(rescaled))
def reflect_json_data_to_3D_box(json_data, view):
    """Copy the first drawn rectangle into ``st.session_state.rectangle_3Dbox``.

    Only the ``'xy'`` view updates the box. The rectangle's ``top``/``left``
    corner is written to indices 1 and 2. The far corner
    (``top + height * scaleY``, ``left + width * scaleX``) is written to
    indices 4 and 5. Indices 0 and 3 are left unchanged. With the
    ``(z1, y1, x1, z2, y2, x2)`` layout that ``reflect_box_into_model``
    unpacks, those are the y/x bounds.

    If ``json_data`` is falsy or has no ``objects`` (e.g. the rectangle was
    deleted), the stored box is left as-is instead of raising.

    Args:
        json_data: Canvas state dict; ``json_data['objects'][0]`` is read as
            the rectangle and must provide ``top``, ``left``, ``width``,
            ``height``, ``scaleX`` and ``scaleY``.
        view: Name of the view the rectangle was drawn in.
    """
    objects = (json_data or {}).get('objects') or []
    if view == 'xy' and objects:
        rect = objects[0]
        top = rect['top']
        left = rect['left']
        box = st.session_state.rectangle_3Dbox
        box[1] = top
        box[2] = left
        # The canvas stores the un-scaled size plus a scale factor; the
        # on-screen extent is their product.
        box[4] = top + rect['height'] * rect['scaleY']
        box[5] = left + rect['width'] * rect['scaleX']
        # NOTE(review): views other than 'xy' do not update the box; confirm
        # whether an 'xz' branch (top -> z, left -> x) is intended.
    print(st.session_state.rectangle_3Dbox)
def reflect_points_into_model(points):
    """Rescale ``(z, y, x)`` canvas points into model prompt coordinates.

    y and x are multiplied by 256 and divided by 325; z is multiplied by 32
    and divided by 325. Each result is truncated with ``int()``. Every point
    gets label 1. The resulting arrays are printed before being returned.

    Returns:
        ``(points_tensor, labels_tensor)``: the integer coordinates as rows of
        ``[z, y, x]``, and a float tensor of ones with one entry per point.
    """
    def _scale(value, extent):
        # Multiply before dividing to keep the exact float rounding.
        return int(value * extent / 325.0)

    rows = [
        [_scale(z, 32.0), _scale(y, 256.0), _scale(x, 256.0)]
        for z, y, x in points
    ]
    points_prompt = np.array(rows)
    points_label = np.ones(points_prompt.shape[0])
    print(points_prompt, points_label)
    return torch.tensor(points_prompt), torch.tensor(points_label)
def show_points(points_ax, points_label, ax):
    """Draw one point on ``ax`` with a large circular marker.

    ``points_ax[0]`` and ``points_ax[1]`` are passed to ``ax.scatter`` as the
    horizontal and vertical coordinates. A label equal to 0 is drawn red; any
    other label is drawn blue.
    """
    if points_label == 0:
        color = 'red'
    else:
        color = 'blue'
    horizontal, vertical = points_ax[0], points_ax[1]
    ax.scatter(horizontal, vertical, c=color, marker='o', s=200)
def make_fig(image, preds, point_axs=None, current_idx=None, view=None, alpha=None):
    """Render a slice as a contrast-enhanced RGB image with optional overlays.

    Args:
        image: Float array accepted by ``Image.fromarray`` once scaled to
            uint8. It is multiplied by 255 and clipped to [0, 255] before the
            cast. NOTE(review): presumably normalised to [0, 1] by the caller.
        preds: Optional array; pixels equal to 1 are painted yellow and
            alpha-blended over the image. It must give an image of the same
            size as ``image`` (``Image.blend`` requires matching sizes).
        point_axs: Optional iterable of ``(z, y, x)`` points. A point is drawn
            as a blue dot when it lies on the displayed slice.
        current_idx: Index of the displayed slice. It is compared against
            ``z`` in the ``'xy'`` view and against ``y`` in the ``'xz'`` view.
        view: ``'xy'`` or ``'xz'``; any other value draws no points.
        alpha: Blend weight of the mask overlay. When None (the default) it
            is read from ``st.session_state.transparency``; that is only
            accessed when ``preds`` is given.

    Returns:
        The rendered ``PIL.Image.Image`` in RGB mode.
    """
    # Clip before casting: NumPy's float -> uint8 conversion of out-of-range
    # values is undefined and typically wraps around, producing speckles.
    pixels = np.clip(np.asarray(image) * 255, 0, 255).astype(np.uint8)
    rendered = Image.fromarray(pixels).convert("RGB")
    rendered = ImageEnhance.Contrast(rendered).enhance(2.0)

    if preds is not None:
        # Yellow overlay: full red and green channels, empty blue channel.
        mask = np.where(preds == 1, 255, 0).astype(np.uint8)
        blue = np.zeros_like(mask, dtype=np.uint8)
        overlay = Image.merge(
            "RGB",
            (Image.fromarray(mask), Image.fromarray(mask), Image.fromarray(blue)),
        )
        if alpha is None:
            alpha = st.session_state.transparency
        rendered = Image.blend(rendered, overlay, alpha=alpha)

    if point_axs is not None:
        draw = ImageDraw.Draw(rendered)
        radius = 5
        for point in point_axs:
            z, y, x = point
            if view == 'xy' and z == current_idx:
                draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill="blue")
            elif view == 'xz' and y == current_idx:
                draw.ellipse((x - radius, z - radius, x + radius, z + radius), fill="blue")
    return rendered