# import subprocess # import re # from typing import List, Tuple, Optional # command = ["python", "setup.py", "build_ext", "--inplace"] # result = subprocess.run(command, capture_output=True, text=True) # print("Output:\n", result.stdout) # print("Errors:\n", result.stderr) # if result.returncode == 0: # print("Command executed successfully.") # else: # print("Command failed with return code:", result.returncode) import datetime import gc import hashlib import math import multiprocessing as mp import os import threading import time os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1" import shutil import ffmpeg from moviepy.editor import ImageSequenceClip import zipfile # import gradio as gr import torch import numpy as np import matplotlib.pyplot as plt from PIL import Image from sam2.build_sam import build_sam2 from sam2.sam2_image_predictor import SAM2ImagePredictor from sam2.build_sam import build_sam2_video_predictor import cv2 import uuid user_processes = {} PROCESS_TIMEOUT = datetime.timedelta(minutes=4) def reset(seg_tracker): if seg_tracker is not None: predictor, inference_state, image_predictor = seg_tracker predictor.reset_state(inference_state) del predictor del inference_state del image_predictor del seg_tracker gc.collect() torch.cuda.empty_cache() return None, ({}, {}), None, None, 0, None, None, None, 0 def extract_video_info(input_video): if input_video is None: return 4, 4, None, None, None, None, None cap = cv2.VideoCapture(input_video) fps = cap.get(cv2.CAP_PROP_FPS) total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) cap.release() return fps, total_frames, None, None, None, None, None def get_meta_from_video(session_id, input_video, scale_slider, checkpoint): output_dir = f'/tmp/output_frames/{session_id}' output_masks_dir = f'/tmp/output_masks/{session_id}' output_combined_dir = f'/tmp/output_combined/{session_id}' clear_folder(output_dir) clear_folder(output_masks_dir) clear_folder(output_combined_dir) if input_video is None: return None, ({}, {}), None, None, (4, 1, 4), None, None, None, 0 cap = cv2.VideoCapture(input_video) fps = cap.get(cv2.CAP_PROP_FPS) total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) cap.release() frame_interval = max(1, int(fps // scale_slider)) print(f"frame_interval: {frame_interval}") try: ffmpeg.input(input_video, hwaccel='cuda').output( os.path.join(output_dir, '%07d.jpg'), q=2, start_number=0, vf=rf'select=not(mod(n\,{frame_interval}))', vsync='vfr' ).run() except: print(f"ffmpeg cuda err") ffmpeg.input(input_video).output( os.path.join(output_dir, '%07d.jpg'), q=2, start_number=0, vf=rf'select=not(mod(n\,{frame_interval}))', vsync='vfr' ).run() first_frame_path = os.path.join(output_dir, '0000000.jpg') first_frame = cv2.imread(first_frame_path) first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB) torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() if torch.cuda.get_device_properties(0).major >= 8: torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True sam2_checkpoint = "segment-anything-2/checkpoints/sam2_hiera_tiny.pt" model_cfg = "sam2_hiera_t.yaml" if checkpoint == "samll": sam2_checkpoint = "segment-anything-2/checkpoints/sam2_hiera_small.pt" model_cfg = "sam2_hiera_s.yaml" elif checkpoint == "base-plus": sam2_checkpoint = "segment-anything-2/checkpoints/sam2_hiera_base_plus.pt" model_cfg = "sam2_hiera_b+.yaml" elif checkpoint == "large": sam2_checkpoint = "segment-anything-2/checkpoints/sam2_hiera_large.pt" model_cfg = "sam2_hiera_l.yaml" predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cuda") sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda") image_predictor = SAM2ImagePredictor(sam2_model) inference_state = predictor.init_state(video_path=output_dir) predictor.reset_state(inference_state) return (predictor, inference_state, image_predictor), ({}, {}), first_frame_rgb, first_frame_rgb, (fps, frame_interval, total_frames), None, None, None, 0 def mask2bbox(mask): if len(np.where(mask > 0)[0]) == 0: print(f'not mask') return np.array([0, 0, 0, 0]).astype(np.int64), False x_ = np.sum(mask, axis=0) y_ = np.sum(mask, axis=1) x0 = np.min(np.nonzero(x_)[0]) x1 = np.max(np.nonzero(x_)[0]) y0 = np.min(np.nonzero(y_)[0]) y1 = np.max(np.nonzero(y_)[0]) return np.array([x0, y0, x1, y1]).astype(np.int64), True def sam_stroke(session_id, seg_tracker, drawing_board, last_draw, frame_num, ann_obj_id): predictor, inference_state, image_predictor = seg_tracker image_path = f'/tmp/output_frames/{session_id}/{frame_num:07d}.jpg' image = cv2.imread(image_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) display_image = drawing_board["image"] image_predictor.set_image(image) input_mask = drawing_board["mask"] input_mask[input_mask != 0] = 255 if last_draw is not None: diff_mask = cv2.absdiff(input_mask, last_draw) input_mask = diff_mask bbox, hasMask = mask2bbox(input_mask[:, :, 0]) if not hasMask : return seg_tracker, display_image, display_image, None masks, scores, logits = image_predictor.predict( point_coords=None, point_labels=None, box=bbox[None, :], multimask_output=False,) mask = masks > 0.0 masked_frame = show_mask(mask, display_image, ann_obj_id) masked_with_rect = draw_rect(masked_frame, bbox, ann_obj_id) frame_idx, object_ids, masks = predictor.add_new_mask(inference_state, frame_idx=frame_num, obj_id=ann_obj_id, mask=mask[0]) last_draw = drawing_board["mask"] return seg_tracker, masked_with_rect, masked_with_rect, last_draw def draw_rect(image, bbox, obj_id): cmap = plt.get_cmap("tab10") color = np.array(cmap(obj_id)[:3]) rgb_color = tuple(map(int, (color[:3] * 255).astype(np.uint8))) inv_color = tuple(map(int, (255 - color[:3] * 255).astype(np.uint8))) x0, y0, x1, y1 = bbox image_with_rect = cv2.rectangle(image.copy(), (x0, y0), (x1, y1), rgb_color, thickness=2) return image_with_rect def sam_click(session_id, seg_tracker, frame_num, point_mode, click_stack, ann_obj_id, point): points_dict, labels_dict = click_stack predictor, inference_state, image_predictor = seg_tracker ann_frame_idx = frame_num # the frame index we interact with print(f'ann_frame_idx: {ann_frame_idx}') if point_mode == "Positive": label = np.array([1], np.int32) else: label = np.array([0], np.int32) if ann_frame_idx not in points_dict: points_dict[ann_frame_idx] = {} if ann_frame_idx not in labels_dict: labels_dict[ann_frame_idx] = {} if ann_obj_id not in points_dict[ann_frame_idx]: points_dict[ann_frame_idx][ann_obj_id] = np.empty((0, 2), dtype=np.float32) if ann_obj_id not in labels_dict[ann_frame_idx]: labels_dict[ann_frame_idx][ann_obj_id] = np.empty((0,), dtype=np.int32) points_dict[ann_frame_idx][ann_obj_id] = np.append(points_dict[ann_frame_idx][ann_obj_id], point, axis=0) labels_dict[ann_frame_idx][ann_obj_id] = np.append(labels_dict[ann_frame_idx][ann_obj_id], label, axis=0) click_stack = (points_dict, labels_dict) frame_idx, out_obj_ids, out_mask_logits = predictor.add_new_points( inference_state=inference_state, frame_idx=ann_frame_idx, obj_id=ann_obj_id, points=points_dict[ann_frame_idx][ann_obj_id], labels=labels_dict[ann_frame_idx][ann_obj_id], ) image_path = f'/tmp/output_frames/{session_id}/{ann_frame_idx:07d}.jpg' image = cv2.imread(image_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) masked_frame = image.copy() for i, obj_id in enumerate(out_obj_ids): mask = (out_mask_logits[i] > 0.0).cpu().numpy() masked_frame = show_mask(mask, image=masked_frame, obj_id=obj_id) masked_frame_with_markers = draw_markers(masked_frame, points_dict[ann_frame_idx], labels_dict[ann_frame_idx]) return seg_tracker, masked_frame_with_markers, masked_frame_with_markers, click_stack def draw_markers(image, points_dict, labels_dict): cmap = plt.get_cmap("tab10") image_h, image_w = image.shape[:2] marker_size = max(1, int(min(image_h, image_w) * 0.05)) for obj_id in points_dict: color = np.array(cmap(obj_id)[:3]) rgb_color = tuple(map(int, (color[:3] * 255).astype(np.uint8))) inv_color = tuple(map(int, (255 - color[:3] * 255).astype(np.uint8))) for point, label in zip(points_dict[obj_id], labels_dict[obj_id]): x, y = int(point[0]), int(point[1]) if label == 1: cv2.drawMarker(image, (x, y), inv_color, markerType=cv2.MARKER_CROSS, markerSize=marker_size, thickness=2) else: cv2.drawMarker(image, (x, y), inv_color, markerType=cv2.MARKER_TILTED_CROSS, markerSize=int(marker_size / np.sqrt(2)), thickness=2) return image def show_mask(mask, image=None, obj_id=None): cmap = plt.get_cmap("tab10") cmap_idx = 0 if obj_id is None else obj_id color = np.array([*cmap(cmap_idx)[:3], 0.6]) h, w = mask.shape[-2:] mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) mask_image = (mask_image * 255).astype(np.uint8) if image is not None: image_h, image_w = image.shape[:2] if (image_h, image_w) != (h, w): raise ValueError(f"Image dimensions ({image_h}, {image_w}) and mask dimensions ({h}, {w}) do not match") colored_mask = np.zeros_like(image, dtype=np.uint8) for c in range(3): colored_mask[..., c] = mask_image[..., c] alpha_mask = mask_image[..., 3] / 255.0 for c in range(3): image[..., c] = np.where(alpha_mask > 0, (1 - alpha_mask) * image[..., c] + alpha_mask * colored_mask[..., c], image[..., c]) return image return mask_image def show_res_by_slider(session_id, frame_per, click_stack): image_path = f'/tmp/output_frames/{session_id}' output_combined_dir = f'/tmp/output_combined/{session_id}' combined_frames = sorted([os.path.join(output_combined_dir, img_name) for img_name in os.listdir(output_combined_dir)]) if combined_frames: output_masked_frame_path = combined_frames else: original_frames = sorted([os.path.join(image_path, img_name) for img_name in os.listdir(image_path)]) output_masked_frame_path = original_frames total_frames_num = len(output_masked_frame_path) if total_frames_num == 0: print("No output results found") return None, None, 0 else: frame_num = math.floor(total_frames_num * frame_per / 100) if frame_per == 100: frame_num = frame_num - 1 chosen_frame_path = output_masked_frame_path[frame_num] print(f"{chosen_frame_path}") chosen_frame_show = cv2.imread(chosen_frame_path) chosen_frame_show = cv2.cvtColor(chosen_frame_show, cv2.COLOR_BGR2RGB) points_dict, labels_dict = click_stack if frame_num in points_dict and frame_num in labels_dict: chosen_frame_show = draw_markers(chosen_frame_show, points_dict[frame_num], labels_dict[frame_num]) return chosen_frame_show, chosen_frame_show, frame_num def clear_folder(folder_path): if os.path.exists(folder_path): shutil.rmtree(folder_path) os.makedirs(folder_path) def zip_folder(folder_path, output_zip_path): with zipfile.ZipFile(output_zip_path, 'w', zipfile.ZIP_STORED) as zipf: for root, _, files in os.walk(folder_path): for file in files: file_path = os.path.join(root, file) zipf.write(file_path, os.path.relpath(file_path, folder_path)) def tracking_objects(session_id, seg_tracker, frame_num, input_video): output_dir = f'/tmp/output_frames/{session_id}' output_masks_dir = f'/tmp/output_masks/{session_id}' output_combined_dir = f'/tmp/output_combined/{session_id}' output_files_dir = f'/tmp/output_files/{session_id}' output_video_path = f'{output_files_dir}/output_video.mp4' output_zip_path = f'{output_files_dir}/output_masks.zip' clear_folder(output_masks_dir) clear_folder(output_combined_dir) clear_folder(output_files_dir) video_segments = {} predictor, inference_state, image_predictor = seg_tracker for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state): video_segments[out_frame_idx] = { out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() for i, out_obj_id in enumerate(out_obj_ids) } frame_files = sorted([f for f in os.listdir(output_dir) if f.endswith('.jpg')]) # for frame_idx in sorted(video_segments.keys()): for frame_file in frame_files: frame_idx = int(os.path.splitext(frame_file)[0]) frame_path = os.path.join(output_dir, frame_file) image = cv2.imread(frame_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) masked_frame = image.copy() if frame_idx in video_segments: for obj_id, mask in video_segments[frame_idx].items(): masked_frame = show_mask(mask, image=masked_frame, obj_id=obj_id) mask_output_path = os.path.join(output_masks_dir, f'{obj_id}_{frame_idx:07d}.png') cv2.imwrite(mask_output_path, show_mask(mask)) combined_output_path = os.path.join(output_combined_dir, f'{frame_idx:07d}.png') combined_image_bgr = cv2.cvtColor(masked_frame, cv2.COLOR_RGB2BGR) cv2.imwrite(combined_output_path, combined_image_bgr) if frame_idx == frame_num: final_masked_frame = masked_frame cap = cv2.VideoCapture(input_video) fps = cap.get(cv2.CAP_PROP_FPS) total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) cap.release() # output_frames = int(total_frames * scale_slider) output_frames = len([name for name in os.listdir(output_combined_dir) if os.path.isfile(os.path.join(output_combined_dir, name)) and name.endswith('.png')]) out_fps = fps * output_frames / total_frames # ffmpeg.input(os.path.join(output_combined_dir, '%07d.png'), framerate=out_fps).output(output_video_path, vcodec='h264_nvenc', pix_fmt='yuv420p').run() # fourcc = cv2.VideoWriter_fourcc(*"mp4v") # out = cv2.VideoWriter(output_video_path, fourcc, out_fps, (frame_width, frame_height)) # for i in range(output_frames): # frame_path = os.path.join(output_combined_dir, f'{i:07d}.png') # frame = cv2.imread(frame_path) # out.write(frame) # out.release() image_files = [os.path.join(output_combined_dir, f'{i:07d}.png') for i in range(output_frames)] clip = ImageSequenceClip(image_files, fps=out_fps) clip.write_videofile(output_video_path, codec="libx264", fps=out_fps) zip_folder(output_masks_dir, output_zip_path) print("done") return final_masked_frame, final_masked_frame, output_video_path, output_video_path, output_zip_path def increment_ann_obj_id(ann_obj_id): ann_obj_id += 1 return ann_obj_id def drawing_board_get_input_first_frame(input_first_frame): return input_first_frame def process_video(queue, result_queue, session_id): seg_tracker = None click_stack = ({}, {}) frame_num = int(0) ann_obj_id =int(0) last_draw = None while True: task = queue.get() if task["command"] == "exit": print(f"Process for {session_id} exiting.") break elif task["command"] == "extract_video_info": input_video = task["input_video"] fps, total_frames, input_first_frame, drawing_board, output_video, output_mp4, output_mask = extract_video_info(input_video) result_queue.put({"fps": fps, "total_frames": total_frames, "input_first_frame": input_first_frame, "drawing_board": drawing_board, "output_video": output_video, "output_mp4": output_mp4, "output_mask": output_mask}) elif task["command"] == "get_meta_from_video": input_video = task["input_video"] scale_slider = task["scale_slider"] checkpoint = task["checkpoint"] seg_tracker, click_stack, input_first_frame, drawing_board, frame_per, output_video, output_mp4, output_mask, ann_obj_id = get_meta_from_video(session_id, input_video, scale_slider, checkpoint) result_queue.put({"input_first_frame": input_first_frame, "drawing_board": drawing_board, "frame_per": frame_per, "output_video": output_video, "output_mp4": output_mp4, "output_mask": output_mask, "ann_obj_id": ann_obj_id}) elif task["command"] == "sam_stroke": drawing_board = task["drawing_board"] last_draw = task["last_draw"] frame_num = task["frame_num"] ann_obj_id = task["ann_obj_id"] seg_tracker, input_first_frame, drawing_board, last_draw = sam_stroke(session_id, seg_tracker, drawing_board, last_draw, frame_num, ann_obj_id) result_queue.put({"input_first_frame": input_first_frame, "drawing_board": drawing_board, "last_draw": last_draw}) elif task["command"] == "sam_click": frame_num = task["frame_num"] point_mode = task["point_mode"] click_stack = task["click_stack"] ann_obj_id = task["ann_obj_id"] point = task["point"] seg_tracker, input_first_frame, drawing_board, last_draw = sam_click(session_id, seg_tracker, frame_num, point_mode, click_stack, ann_obj_id, point) result_queue.put({"input_first_frame": input_first_frame, "drawing_board": drawing_board, "last_draw": last_draw}) elif task["command"] == "increment_ann_obj_id": ann_obj_id = task["ann_obj_id"] ann_obj_id = increment_ann_obj_id(ann_obj_id) result_queue.put({"ann_obj_id": ann_obj_id}) elif task["command"] == "drawing_board_get_input_first_frame": input_first_frame = task["input_first_frame"] input_first_frame = drawing_board_get_input_first_frame(input_first_frame) result_queue.put({"input_first_frame": input_first_frame}) elif task["command"] == "reset": seg_tracker, click_stack, input_first_frame, drawing_board, frame_per, output_video, output_mp4, output_mask, ann_obj_id = reset(seg_tracker) result_queue.put({"click_stack": click_stack, "input_first_frame": input_first_frame, "drawing_board": drawing_board, "frame_per": frame_per, "output_video": output_video, "output_mp4": output_mp4, "output_mask": output_mask, "ann_obj_id": ann_obj_id}) elif task["command"] == "show_res_by_slider": frame_per = task["frame_per"] click_stack = task["click_stack"] input_first_frame, drawing_board, frame_num = show_res_by_slider(session_id, frame_per, click_stack) result_queue.put({"input_first_frame": input_first_frame, "drawing_board": drawing_board, "frame_num": frame_num}) elif task["command"] == "tracking_objects": frame_num = task["frame_num"] input_video = task["input_video"] input_first_frame, drawing_board, output_video, output_mp4, output_mask = tracking_objects(session_id, seg_tracker, frame_num, input_video) result_queue.put({"input_first_frame": input_first_frame, "drawing_board": drawing_board, "output_video": output_video, "output_mp4": output_mp4, "output_mask": output_mask}) else: print(f"Unknown command {task['command']} for {session_id}") result_queue.put("Unknown command") def start_process(session_id): if session_id not in user_processes: queue = mp.Queue() result_queue = mp.Queue() process = mp.Process(target=process_video, args=(queue, result_queue, session_id)) process.start() user_processes[session_id] = { "process": process, "queue": queue, "result_queue": result_queue, "last_active": datetime.datetime.now() } else: user_processes[session_id]["last_active"] = datetime.datetime.now() return user_processes[session_id]["queue"] # def clean_up_processes(session_id, init_clean = False): # now = datetime.datetime.now() # to_remove = [] # for s_id, process_info in user_processes.items(): # if (now - process_info["last_active"] > PROCESS_TIMEOUT) or (s_id == session_id and init_clean): # process_info["queue"].put({"command": "exit"}) # process_info["process"].terminate() # process_info["process"].join() # to_remove.append(s_id) # for s_id in to_remove: # del user_processes[s_id] # print(f"Cleaned up process for session {s_id}.") def monitor_and_cleanup_processes(): while True: now = datetime.datetime.now() to_remove = [] for session_id, process_info in user_processes.items(): if now - process_info["last_active"] > PROCESS_TIMEOUT: process_info["queue"].put({"command": "exit"}) process_info["process"].terminate() process_info["process"].join() to_remove.append(session_id) for session_id in to_remove: del user_processes[session_id] print(f"Automatically cleaned up process for session {session_id}.") time.sleep(10) def seg_track_app(): import gradio as gr def extract_session_id_from_request(request: gr.Request): session_id = hashlib.sha256(f'{request.client.host}:{request.client.port}'.encode('utf-8')).hexdigest() # cookies = request.kwargs["headers"].get('cookie', '') # session_id = None # if '_gid=' in cookies: # session_id = cookies.split('_gid=')[1].split(';')[0] # else: # session_id = str(uuid.uuid4()) print(f"session_id {session_id}") return session_id def handle_extract_video_info(session_id, input_video): # clean_up_processes(session_id, init_clean=True) if input_video == None: return 0, 0, None, None, None, None, None queue = start_process(session_id) result_queue = user_processes[session_id]["result_queue"] queue.put({"command": "extract_video_info", "input_video": input_video}) result = result_queue.get() fps = result.get("fps") total_frames = result.get("total_frames") input_first_frame = result.get("input_first_frame") drawing_board = result.get("drawing_board") output_video = result.get("output_video") output_mp4 = result.get("output_mp4") output_mask = result.get("output_mask") scale_slider = gr.Slider.update(minimum=1.0, maximum=fps, step=1.0, value=fps,) frame_per = gr.Slider.update(minimum= 0.0, maximum= total_frames / fps, step=1.0/fps, value=0.0,) return scale_slider, frame_per, input_first_frame, drawing_board, output_video, output_mp4, output_mask def handle_get_meta_from_video(session_id, input_video, scale_slider, checkpoint): # clean_up_processes(session_id) queue = start_process(session_id) result_queue = user_processes[session_id]["result_queue"] queue.put({"command": "get_meta_from_video", "input_video": input_video, "scale_slider": scale_slider, "checkpoint": checkpoint}) result = result_queue.get() input_first_frame = result.get("input_first_frame") drawing_board = result.get("drawing_board") (fps, frame_interval, total_frames) = result.get("frame_per") output_video = result.get("output_video") output_mp4 = result.get("output_mp4") output_mask = result.get("output_mask") ann_obj_id = result.get("ann_obj_id") frame_per = gr.Slider.update(minimum= 0.0, maximum= total_frames / fps, step=frame_interval / fps, value=0.0,) return input_first_frame, drawing_board, frame_per, output_video, output_mp4, output_mask, ann_obj_id def handle_sam_stroke(session_id, drawing_board, last_draw, frame_num, ann_obj_id): # clean_up_processes(session_id) queue = start_process(session_id) result_queue = user_processes[session_id]["result_queue"] queue.put({"command": "sam_stroke", "drawing_board": drawing_board, "last_draw": last_draw, "frame_num": frame_num, "ann_obj_id": ann_obj_id}) result = result_queue.get() input_first_frame = result.get("input_first_frame") drawing_board = result.get("drawing_board") last_draw = result.get("last_draw") return input_first_frame, drawing_board, last_draw def handle_sam_click(session_id, frame_num, point_mode, click_stack, ann_obj_id, evt: gr.SelectData): # clean_up_processes(session_id) queue = start_process(session_id) result_queue = user_processes[session_id]["result_queue"] point = np.array([[evt.index[0], evt.index[1]]], dtype=np.float32) queue.put({"command": "sam_click", "frame_num": frame_num, "point_mode": point_mode, "click_stack": click_stack, "ann_obj_id": ann_obj_id, "point": point}) result = result_queue.get() input_first_frame = result.get("input_first_frame") drawing_board = result.get("drawing_board") last_draw = result.get("last_draw") return input_first_frame, drawing_board, last_draw def handle_increment_ann_obj_id(session_id, ann_obj_id): # clean_up_processes(session_id) queue = start_process(session_id) result_queue = user_processes[session_id]["result_queue"] queue.put({"command": "increment_ann_obj_id", "ann_obj_id": ann_obj_id}) result = result_queue.get() ann_obj_id = result.get("ann_obj_id") return ann_obj_id def handle_drawing_board_get_input_first_frame(session_id, input_first_frame): # clean_up_processes(session_id) queue = start_process(session_id) result_queue = user_processes[session_id]["result_queue"] queue.put({"command": "drawing_board_get_input_first_frame", "input_first_frame": input_first_frame}) result = result_queue.get() input_first_frame = result.get("input_first_frame") return input_first_frame def handle_reset(session_id): # clean_up_processes(session_id) queue = start_process(session_id) result_queue = user_processes[session_id]["result_queue"] queue.put({"command": "reset"}) result = result_queue.get() click_stack = result.get("click_stack") input_first_frame = result.get("input_first_frame") drawing_board = result.get("drawing_board") frame_per = result.get("frame_per") output_video = result.get("output_video") output_mp4 = result.get("output_mp4") output_mask = result.get("output_mask") ann_obj_id = result.get("ann_obj_id") return click_stack, input_first_frame, drawing_board, frame_per, output_video, output_mp4, output_mask, ann_obj_id def handle_show_res_by_slider(session_id, frame_per, click_stack): # clean_up_processes(session_id) queue = start_process(session_id) result_queue = user_processes[session_id]["result_queue"] queue.put({"command": "show_res_by_slider", "frame_per": frame_per, "click_stack": click_stack}) result = result_queue.get() input_first_frame = result.get("input_first_frame") drawing_board = result.get("drawing_board") frame_num = result.get("frame_num") return input_first_frame, drawing_board, frame_num def handle_tracking_objects(session_id, frame_num, input_video): # clean_up_processes(session_id) queue = start_process(session_id) result_queue = user_processes[session_id]["result_queue"] queue.put({"command": "tracking_objects", "frame_num": frame_num, "input_video": input_video}) result = result_queue.get() input_first_frame = result.get("input_first_frame") drawing_board = result.get("drawing_board") output_video = result.get("output_video") output_mp4 = result.get("output_mp4") output_mask = result.get("output_mask") return input_first_frame, drawing_board, output_video, output_mp4, output_mask ########################################################## ###################### Front-end ######################## ########################################################## css = """ #input_output_video video { max-height: 550px; max-width: 100%; height: auto; } """ app = gr.Blocks(css=css) with app: session_id = gr.State() app.load(extract_session_id_from_request, None, session_id) gr.Markdown( '''