# Source: Hugging Face Spaces page (Space status at capture time: Paused).
import json
import math
import os
import subprocess

import cv2
import numpy as np
import torch
from sklearn.preprocessing import Normalizer

from align import align_filter
def merge_intervals_with_breaks(time_intervals, errors, max_break=1.5):
    """Merge time intervals that share the same error label and lie close together.

    Args:
        time_intervals: Sequence of ``(start, end)`` pairs, in seconds.
        errors: Sequence of error labels, parallel to ``time_intervals``.
        max_break: Largest gap (in seconds) between the end of one interval
            and the start of the next for the two to be merged.

    Returns:
        A list of ``((start, end), error)`` tuples ordered by start time, with
        both endpoints rounded to whole seconds via ``round()``.  Returns
        ``[]`` when there is nothing to merge.
    """
    print(f"时间区间: {time_intervals}")
    print(f"错误: {errors}")
    if not time_intervals:
        return []

    # Sort by start time; the input is usually sorted already, but this keeps
    # the merge correct when it is not.
    ordered = sorted(zip(time_intervals, errors), key=lambda pair: pair[0][0])
    if not ordered:
        # zip() truncated everything away because `errors` was empty.
        return []

    merged_intervals = []
    (cur_start, cur_end), current_error = ordered[0]
    for (start, end), error in ordered[1:]:
        # Compare against the raw (unrounded) end so that rounding to whole
        # seconds cannot distort the gap check.
        if error == current_error and start - cur_end <= max_break:
            cur_end = max(cur_end, end)
        else:
            merged_intervals.append(((round(cur_start), round(cur_end)), current_error))
            cur_start, cur_end, current_error = start, end, error

    # Flush the final interval, rounded the same way as every other one.
    merged_intervals.append(((round(cur_start), round(cur_end)), current_error))
    return merged_intervals
def findcos_single(k1, k2):
    """Score how similar two coordinate sets are, on a 0-100 scale.

    Both inputs are flattened into column vectors and their cosine similarity
    is mapped linearly from [-1, 1] to [0, 100].

    Args:
        k1: Array-like of numbers (any nesting); flattened before use.
        k2: Array-like with the same number of elements as ``k1``.

    Returns:
        A ``(score, 0)`` tuple.  ``score`` is a NumPy array of shape (1, 1)
        because it comes from a (1, n) @ (n, 1) matrix product; the second
        element is always the constant 0.
    """
    first = np.array(k1).reshape(-1, 1)
    second = np.array(k2).reshape(-1, 1)

    dot = np.matmul(first.T, second)
    norm_product = np.sqrt(np.sum(first * first)) * np.sqrt(np.sum(second * second))
    cosine = dot / norm_product

    # cosine = 1 -> 100, cosine = 0 -> 50, cosine = -1 -> 0
    return 100 * (1 - (1 - cosine) / 2), 0
def align_hstack(frame_00, frame_01, keypoints_01=None):
    """Place two frames side by side, resizing the second to the first's height.

    Args:
        frame_00: Image array; its height (``shape[0]``) is the target height.
        frame_01: Image array; resized with its aspect ratio kept whenever its
            height differs from ``frame_00``'s.
        keypoints_01: Optional iterable of points for ``frame_01`` whose first
            two entries are x and y.  May be a list or a NumPy array.

    Returns:
        ``(combined, scaled_keypoints)`` where ``combined`` is the horizontal
        stack of ``frame_00`` and the (possibly resized) ``frame_01``, and
        ``scaled_keypoints`` is a list of ``[x, y]`` pairs multiplied by the
        height ratio, or ``None`` when ``keypoints_01`` is ``None``.
    """
    height_00 = frame_00.shape[0]
    height_01 = frame_01.shape[0]

    if height_01 != height_00:
        # Scale by the height ratio so both frames end up equally tall.
        scale_factor = height_00 / height_01
        # Never ask OpenCV for a zero-width image.
        new_width = max(1, int(frame_01.shape[1] * scale_factor))
        frame_01_resized = cv2.resize(frame_01, (new_width, height_00))
    else:
        scale_factor = 1.0
        frame_01_resized = frame_01

    combined_frame_ori = np.hstack((frame_00, frame_01_resized))

    # Identity check: `== None` is element-wise for NumPy arrays and would
    # raise when used as a condition.
    if keypoints_01 is None:
        return combined_frame_ori, None

    # Only the x and y coordinates are scaled; any extra entries are dropped.
    keypoints_01_scaled = [
        [point[0] * scale_factor, point[1] * scale_factor] for point in keypoints_01
    ]
    return combined_frame_ori, keypoints_01_scaled
def findCosineSimilarity_1(keypoints1, keypoints2):
    """Score two full keypoint sets against each other on a 0-100 scale.

    Only the keypoints at indices 5-12 and 91-132 take part in the comparison.
    They are flattened into one column vector per input, and the cosine
    similarity of the two vectors is mapped linearly from [-1, 1] to [0, 100].

    Args:
        keypoints1: Sliceable sequence of points (reference).
        keypoints2: Sliceable sequence of points (candidate), same layout.

    Returns:
        A ``(score, 0)`` tuple.  ``score`` is a NumPy array of shape (1, 1);
        the second element is always the constant 0.
    """
    def _feature_column(keypoints):
        # Join the two index ranges of interest and flatten to a column.
        picked = np.concatenate((keypoints[5:13], keypoints[91:133]), axis=0)
        return picked.reshape(-1, 1)

    reference = _feature_column(keypoints1)
    candidate = _feature_column(keypoints2)

    dot = np.matmul(reference.T, candidate)
    reference_sq = np.sum(reference * reference)
    candidate_sq = np.sum(candidate * candidate)
    cosine = dot / (np.sqrt(reference_sq) * np.sqrt(candidate_sq))

    return 100 * (1 - (1 - cosine) / 2), 0
def load_json(path):
    """Read the file at ``path`` and return its parsed JSON content.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's locale encoding.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)
def _segment_pairs(arm_links, hand_chains):
    """Expand 1-based connection specs into 0-based (label, i, j, is_arm) segments."""
    segments = []
    for first, last in arm_links:
        segments.append((first - 1, first - 1, last - 1, True))
    for first, last in hand_chains:
        # A hand chain (first, last) is drawn as consecutive links first..last.
        segments.extend((first - 1, i, i + 1, False) for i in range(first - 1, last - 1))
    return segments


def _pixel(keypoints, index):
    """Return keypoint ``index`` as an integer (x, y) tuple."""
    return int(keypoints[index][0]), int(keypoints[index][1])


def _draw_joint_dots(frame, keypoints):
    """Draw a small green dot on every keypoint in index ranges 5-12 and 91-132."""
    for point in keypoints[5:13] + keypoints[91:133]:
        cv2.circle(frame, (int(point[0]), int(point[1])), 1, (0, 210, 0), -1)


def _draw_reference_skeleton(frame, keypoints, segments, line_width):
    """Draw every in-range segment of the reference skeleton in blue, then the dots."""
    count = len(keypoints)
    for _label, i, j, _is_arm in segments:
        if i < count and j < count:
            cv2.line(frame, _pixel(keypoints, i), _pixel(keypoints, j), (255, 0, 0), line_width)
    _draw_joint_dots(frame, keypoints)


def _draw_user_skeleton(frame, user_kps, ref_kps, segments, line_width,
                        error_threshold=98.8, big_error_threshold=97.8):
    """Draw the user skeleton (red where it deviates) and return big-error labels."""
    big_errors = []
    count = len(user_kps)
    for label, i, j, is_arm in segments:
        if not (i < count and j < count):
            continue
        start_point, end_point = _pixel(user_kps, i), _pixel(user_kps, j)
        raw_score = findcos_single(
            [list(start_point), list(end_point)],
            [list(_pixel(ref_kps, i)), list(_pixel(ref_kps, j))],
        )
        score = float(np.ravel(raw_score[0])[0])
        # Arm segments whose label index is 5 are never flagged.
        if score < error_threshold and not (is_arm and label == 5):
            cv2.line(frame, start_point, end_point, (0, 0, 255), 2)  # red (BGR)
            if score < big_error_threshold:
                big_errors.append(label)
        else:
            cv2.line(frame, start_point, end_point, (255, 0, 0), line_width)  # blue (BGR)
    _draw_joint_dots(frame, user_kps)
    return big_errors


def _error_parts(indices):
    """Map segment label indices to a sorted, de-duplicated list of body-part names."""
    parts = set()
    for idx in indices:
        if idx in (5, 7):
            parts.add('Left Arm')
        if idx in (6, 8):
            parts.add('Right Arm')
        if 90 < idx < 112:
            parts.add('Left Hand')
        if idx >= 112:
            parts.add('Right Hand')
    # Sorted so equal label sets always compare equal when intervals are merged.
    return sorted(parts)


def eval(test, standard, tmpdir):
    """Compare the user video in ``tmpdir`` against the standard one.

    Runs ``inferencer_demo.py`` on ``<tmpdir>/user.mp4``, calls
    ``align_filter`` on the two videos, then walks both frame by frame.  Each
    frame pair is written to ``<tmpdir>/output.mp4`` as a 2x2 grid: the raw
    frames on top, the frames with skeletons drawn on the bottom.  User
    segments that deviate from the reference are drawn in red.

    Args:
        test: Unused by this function.
        standard: Unused by this function.
        tmpdir: Directory holding ``user.mp4``, ``standard.mp4`` and
            ``standard.json``; ``user.json`` and ``output.mp4`` are expected
            and written there.

    Returns:
        ``(average_score, merged_intervals, comments)``:
        * ``average_score``: mean per-frame similarity, or ``0.0`` when no
          frame could be scored.
        * ``merged_intervals``: result of ``merge_intervals_with_breaks`` over
          the frames that had large errors.
        * ``comments``: ``0`` if the standard video ended first, ``1`` if the
          user video ended first, ``2`` if both ended together, ``-1`` if the
          loop never set it.
    """
    test_p = tmpdir + "/user.mp4"
    standard_p = tmpdir + "/standard.mp4"

    # Argument list instead of a shell string, so paths containing spaces or
    # shell metacharacters are passed through verbatim.  Expected to produce
    # user.json in tmpdir.
    subprocess.run(['python', 'inferencer_demo.py', test_p, '--pred-out-dir', tmpdir])

    align_filter(tmpdir + '/standard', tmpdir + '/user', tmpdir)  # frame alignment

    data_00 = load_json(tmpdir + '/standard.json')
    data_01 = load_json(tmpdir + '/user.json')
    min_length = min(len(data_00), len(data_01))

    # Keypoint connections, 1-based: arm links as (from, to) pairs, and hand
    # chains as (first, last) index ranges.
    connections1 = [(9, 11), (7, 9), (6, 7), (6, 8), (8, 10), (7, 13), (6, 12), (12, 13)]
    connections2 = [(130, 133), (126, 129), (122, 125), (118, 121), (114, 117),
                    (93, 96), (97, 100), (101, 104), (105, 108), (109, 112)]
    segments = _segment_pairs(connections1, connections2)

    fps = 5  # used for both the output video and the frame -> seconds conversion
    scores = []
    comments = -1
    error_dict = {}
    cnt = 0
    min_cos = None
    keypoints_00_ori = None

    cap_00 = cv2.VideoCapture(standard_p)
    cap_01 = cv2.VideoCapture(test_p)
    out = None
    try:
        frame_width = int(cap_00.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap_00.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # NOTE(review): the writer assumes every composed frame is exactly
        # (2 * width, 2 * height) -- confirm for user videos whose aspect
        # ratio differs from the standard one.
        out = cv2.VideoWriter(tmpdir + '/output.mp4', cv2.VideoWriter_fourcc(*'XVID'),
                              fps, (frame_width * 2, frame_height * 2))
        cap_00.set(cv2.CAP_PROP_POS_FRAMES, 0)  # start both videos from the first frame
        cap_01.set(cv2.CAP_PROP_POS_FRAMES, 0)
        line_width = max(1, frame_width // 300)

        while True:
            ret_00, frame_00 = cap_00.read()
            ret_01, frame_01 = cap_01.read()
            if not ret_00 or not ret_01:
                # Record which video ran out of frames first.
                if ret_01:
                    comments = 0
                elif ret_00:
                    comments = 1
                else:
                    comments = 2
                break

            # NOTE(review): these positions are queried after read() and are
            # used directly as indices into the JSON frame lists -- confirm
            # this is not off by one.
            frame_id_00 = int(cap_00.get(cv2.CAP_PROP_POS_FRAMES))
            frame_id_01 = int(cap_01.get(cv2.CAP_PROP_POS_FRAMES))
            has_reference = frame_id_00 < min_length
            has_user = frame_id_01 < min_length

            # Top row: the untouched frames (np.hstack copies, so later drawing
            # on frame_00 / frame_01 does not affect it).
            if has_user:
                combined_frame_ori, keypoints_01_scaled = align_hstack(
                    frame_00, frame_01, data_01[frame_id_01]["instances"][0]["keypoints"])
            else:
                combined_frame_ori, _ = align_hstack(frame_00, frame_01)

            if has_reference:
                keypoints_00_ori = data_00[frame_id_00]["instances"][0]["keypoints"]
                _draw_reference_skeleton(frame_00, keypoints_00_ori, segments, line_width)

            bigerror = []
            if has_user and keypoints_00_ori is not None:
                # NOTE(review): height-scaled keypoints are drawn on the
                # unscaled frame_01, which align_hstack resizes afterwards --
                # confirm this is intended when the heights differ.
                bigerror = _draw_user_skeleton(
                    frame_01, keypoints_01_scaled, keypoints_00_ori, segments, line_width)

            # Bottom row: the frames with skeletons drawn on them.
            combined_frame, _ = align_hstack(frame_00, frame_01)

            if has_reference and has_user:
                min_cos, _ = findCosineSimilarity_1(
                    data_00[frame_id_00]["instances"][0]["keypoints"],
                    data_01[frame_id_01]["instances"][0]["keypoints"])
                if bigerror:
                    # Remember which body parts had large errors on this frame.
                    error_dict[cnt] = _error_parts(bigerror)
            cnt += 1

            out.write(np.vstack((combined_frame_ori, combined_frame)))
            # Frames without keypoint data reuse the most recent score; nothing
            # is recorded until a first score exists.
            if min_cos is not None:
                scores.append(float(np.ravel(min_cos)[0]))
    finally:
        cap_00.release()
        cap_01.release()
        if out is not None:
            out.release()

    # Turn the frame counters of large-error frames into second intervals and
    # merge neighbouring intervals that report the same body parts.
    frame_numbers = list(error_dict)
    time_intervals = [(frame / fps, (frame + 1) / fps) for frame in frame_numbers]
    errors = [error_dict[frame] for frame in frame_numbers]
    final_merged_intervals = merge_intervals_with_breaks(time_intervals, errors)

    average_score = sum(scores) / len(scores) if scores else 0.0
    return average_score, final_merged_intervals, comments
def install():
    """Install the pose-estimation toolchain into the current environment.

    Pins NumPy below 2 (failing loudly if pip fails), then runs the remaining
    shell commands in order without checking their exit status: installs
    mmengine and mmcv via ``mim``, clones and installs mmpose and mmdetection
    in editable mode (overwriting ``mmdet/__init__.py`` with ``./test.py``),
    and finally installs ffmpeg and imagemagick via ``apt-get``.
    """
    import subprocess

    # NumPy must be replaced first; check=True aborts if pip reports failure.
    for pip_args in (["uninstall", "-y", "numpy"], ["install", "numpy<2"]):
        subprocess.run(["pip", *pip_args], check=True)

    # Each step is (callable, argument); they run strictly in this order
    # because the chdir steps change where the following commands execute.
    steps = (
        (os.system, 'mim install mmengine'),
        (os.system, 'mim install mmcv==2.2.0'),
        (os.system, 'git clone https://github.com/open-mmlab/mmpose.git'),
        (os.chdir, 'mmpose'),
        (os.system, 'pip install -r requirements.txt'),
        (os.system, 'pip install -v -e .'),
        (os.chdir, '../'),
        (os.system, 'git clone https://github.com/open-mmlab/mmdetection.git'),
        (os.system, 'cp ./test.py ./mmdetection/mmdet/__init__.py'),
        (os.chdir, 'mmdetection'),
        (os.system, 'pip install -v -e .'),
        (os.chdir, '../'),
        (os.system, 'apt-get install ffmpeg imagemagick'),
    )
    for action, argument in steps:
        action(argument)