// Source path: sentis-blaze-pose / Assets / Scripts / PoseDetection.cs
// Repository page residue: UnityGiles - "add project" - 58c78a9
using System;
using Unity.Mathematics;
using Unity.Sentis;
using UnityEngine;
/// <summary>
/// Repeatedly runs a two-stage pose pipeline on <see cref="imageTexture"/>:
/// a detector model picks the best-scoring pose candidate, then a landmarker model
/// is run on a crop derived from that candidate and its keypoints are pushed to
/// <see cref="posePreview"/>.
/// </summary>
public class PoseDetection : MonoBehaviour
{
    // Preview object that receives the bounding box, bounding circle and keypoints.
    public PosePreview posePreview;
    // Preview object that displays the texture currently being processed.
    public ImagePreview imagePreview;
    // Image the pipeline is run on.
    public Texture2D imageTexture;
    // Model asset for the first (detection) stage.
    public ModelAsset poseDetector;
    // Model asset for the second (landmark) stage.
    public ModelAsset poseLandmarker;
    // Text asset handed to BlazeUtils.LoadAnchors to build the anchor table.
    public TextAsset anchorsCSV;
    // Minimum detector score required to show the pose and run the landmarker.
    public float scoreThreshold = 0.75f;

    const int k_NumAnchors = 2254;
    // Anchor table; columns 0 and 1 are read as the anchor position
    // (scaled by detectorInputSize when used).
    float[,] m_Anchors;

    const int k_NumKeypoints = 33;
    const int detectorInputSize = 224;
    const int landmarkerInputSize = 256;

    Worker m_PoseDetectorWorker;
    Worker m_PoseLandmarkerWorker;
    Tensor<float> m_DetectorInput;
    Tensor<float> m_LandmarkerInput;
    // Handle of the Detect call currently in flight; null when none is pending.
    Awaitable m_DetectAwaitable;
    float m_TextureWidth;
    float m_TextureHeight;

    /// <summary>
    /// Builds both workers and the input tensors, then runs <see cref="Detect"/> in a loop
    /// until the pending call is cancelled (see <see cref="OnDestroy"/>).
    /// All workers and tensors are disposed when the method exits, whether through
    /// cancellation or through any other exception.
    /// </summary>
    public async void Start()
    {
        try
        {
            m_Anchors = BlazeUtils.LoadAnchors(anchorsCSV.text, k_NumAnchors);

            var poseDetectorModel = ModelLoader.Load(poseDetector);

            // post process the model to filter scores + argmax select the best pose
            var graph = new FunctionalGraph();
            var input = graph.AddInput(poseDetectorModel, 0);
            var outputs = Functional.Forward(poseDetectorModel, input);
            var boxes = outputs[0]; // (1, 2254, 12)
            var scores = outputs[1]; // (1, 2254, 1)
            var idx_scores_boxes = BlazeUtils.ArgMaxFiltering(boxes, scores);
            poseDetectorModel = graph.Compile(idx_scores_boxes.Item1, idx_scores_boxes.Item2, idx_scores_boxes.Item3);

            m_PoseDetectorWorker = new Worker(poseDetectorModel, BackendType.GPUCompute);

            var poseLandmarkerModel = ModelLoader.Load(poseLandmarker);
            m_PoseLandmarkerWorker = new Worker(poseLandmarkerModel, BackendType.GPUCompute);

            m_DetectorInput = new Tensor<float>(new TensorShape(1, detectorInputSize, detectorInputSize, 3));
            m_LandmarkerInput = new Tensor<float>(new TensorShape(1, landmarkerInputSize, landmarkerInputSize, 3));

            while (true)
            {
                try
                {
                    m_DetectAwaitable = Detect(imageTexture);
                    await m_DetectAwaitable;
                }
                catch (OperationCanceledException)
                {
                    // Cancellation is the normal way to stop the loop.
                    break;
                }
            }
        }
        finally
        {
            // The awaitable has completed (or was never created); drop the handle so
            // OnDestroy cannot call Cancel on an instance that is no longer ours.
            m_DetectAwaitable = null;

            // Null-safe: setup may have failed before some of these were created.
            m_PoseDetectorWorker?.Dispose();
            m_PoseLandmarkerWorker?.Dispose();
            m_DetectorInput?.Dispose();
            m_LandmarkerInput?.Dispose();
        }
    }

    /// <summary>
    /// Converts a position in image (pixel) coordinates to preview coordinates:
    /// the texture centre becomes the origin and values are divided by the texture height.
    /// </summary>
    /// <param name="position">Position in image coordinates.</param>
    /// <returns>The re-centred, height-normalised position.</returns>
    Vector3 ImageToWorld(Vector2 position)
    {
        return (position - 0.5f * new Vector2(m_TextureWidth, m_TextureHeight)) / m_TextureHeight;
    }

    /// <summary>
    /// Runs one detection + landmark pass on <paramref name="texture"/> and updates the previews.
    /// If the best detector score is below <see cref="scoreThreshold"/>, the pose preview is
    /// deactivated and the landmarker is not run.
    /// </summary>
    /// <param name="texture">Texture to sample the model inputs from.</param>
    async Awaitable Detect(Texture texture)
    {
        m_TextureWidth = texture.width;
        m_TextureHeight = texture.height;
        imagePreview.SetTexture(texture);

        var size = Mathf.Max(texture.width, texture.height);

        // The affine transformation matrix to go from tensor coordinates to image coordinates
        var scale = size / (float)detectorInputSize;
        var M = BlazeUtils.mul(
            BlazeUtils.TranslationMatrix(0.5f * (new Vector2(texture.width, texture.height) + new Vector2(-size, size))),
            BlazeUtils.ScaleMatrix(new Vector2(scale, -scale)));
        BlazeUtils.SampleImageAffine(texture, m_DetectorInput, M);

        m_PoseDetectorWorker.Schedule(m_DetectorInput);

        // Start all three readbacks before awaiting any of them.
        var outputIdxAwaitable = (m_PoseDetectorWorker.PeekOutput(0) as Tensor<int>).ReadbackAndCloneAsync();
        var outputScoreAwaitable = (m_PoseDetectorWorker.PeekOutput(1) as Tensor<float>).ReadbackAndCloneAsync();
        var outputBoxAwaitable = (m_PoseDetectorWorker.PeekOutput(2) as Tensor<float>).ReadbackAndCloneAsync();

        using var outputIdx = await outputIdxAwaitable;
        using var outputScore = await outputScoreAwaitable;
        using var outputBox = await outputBoxAwaitable;

        var scorePassesThreshold = outputScore[0] >= scoreThreshold;
        posePreview.SetActive(scorePassesThreshold);

        if (!scorePassesThreshold)
            return;

        // Box values are offsets relative to the selected anchor, in detector tensor coordinates.
        var idx = outputIdx[0];
        var anchorPosition = detectorInputSize * new float2(m_Anchors[idx, 0], m_Anchors[idx, 1]);

        // Channels 0-1: box centre; channels 2-3: box size; channels 4-7: two keypoints (x, y each).
        var face_ImageSpace = BlazeUtils.mul(M, anchorPosition + new float2(outputBox[0, 0, 0], outputBox[0, 0, 1]));
        var faceTopRight_ImageSpace = BlazeUtils.mul(M, anchorPosition + new float2(outputBox[0, 0, 0] + 0.5f * outputBox[0, 0, 2], outputBox[0, 0, 1] + 0.5f * outputBox[0, 0, 3]));

        var kp1_ImageSpace = BlazeUtils.mul(M, anchorPosition + new float2(outputBox[0, 0, 4 + 2 * 0 + 0], outputBox[0, 0, 4 + 2 * 0 + 1]));
        var kp2_ImageSpace = BlazeUtils.mul(M, anchorPosition + new float2(outputBox[0, 0, 4 + 2 * 1 + 0], outputBox[0, 0, 4 + 2 * 1 + 1]));

        // The crop for the landmarker is centred on kp1, sized from the kp1-kp2 distance
        // and rotated according to the kp1-kp2 direction.
        var delta_ImageSpace = kp2_ImageSpace - kp1_ImageSpace;
        var dscale = 1.25f;
        var radius = dscale * math.length(delta_ImageSpace);
        var theta = math.atan2(delta_ImageSpace.y, delta_ImageSpace.x);

        // Matrix used both to sample the landmarker input and to map landmark
        // coordinates back to image coordinates below.
        var origin2 = new float2(0.5f * landmarkerInputSize, 0.5f * landmarkerInputSize);
        var scale2 = radius / (0.5f * landmarkerInputSize);
        var M2 = BlazeUtils.mul(
            BlazeUtils.mul(
                BlazeUtils.mul(
                    BlazeUtils.TranslationMatrix(kp1_ImageSpace),
                    BlazeUtils.ScaleMatrix(new float2(scale2, -scale2))),
                BlazeUtils.RotationMatrix(0.5f * Mathf.PI - theta)),
            BlazeUtils.TranslationMatrix(-origin2));
        BlazeUtils.SampleImageAffine(texture, m_LandmarkerInput, M2);

        var boxSize = 2f * (faceTopRight_ImageSpace - face_ImageSpace);

        posePreview.SetBoundingBox(true, ImageToWorld(face_ImageSpace), boxSize / m_TextureHeight);
        posePreview.SetBoundingCircle(true, ImageToWorld(kp1_ImageSpace), radius / m_TextureHeight);

        m_PoseLandmarkerWorker.Schedule(m_LandmarkerInput);

        var landmarksAwaitable = (m_PoseLandmarkerWorker.PeekOutput("Identity") as Tensor<float>).ReadbackAndCloneAsync();
        using var landmarks = await landmarksAwaitable; // (1,195)

        // Each keypoint occupies 5 consecutive values: x, y, z, visibility, presence.
        for (var i = 0; i < k_NumKeypoints; i++)
        {
            // https://arxiv.org/pdf/2006.10204
            var position_ImageSpace = BlazeUtils.mul(M2, new float2(landmarks[5 * i + 0], landmarks[5 * i + 1]));
            var visibility = landmarks[5 * i + 3];
            var presence = landmarks[5 * i + 4];

            // z-position is in unit cube centered on hips
            Vector3 position_WorldSpace = ImageToWorld(position_ImageSpace) + new Vector3(0, 0, landmarks[5 * i + 2] / m_TextureHeight);
            posePreview.SetKeypoint(i, visibility > 0.5f && presence > 0.5f, position_WorldSpace);
        }
    }

    /// <summary>
    /// Cancels the in-flight <see cref="Detect"/> call, if any, which makes the loop in
    /// <see cref="Start"/> exit and release its resources.
    /// </summary>
    void OnDestroy()
    {
        // m_DetectAwaitable is null when Start never ran (e.g. the component was disabled)
        // or has already finished.
        m_DetectAwaitable?.Cancel();
    }
}