File size: 8,340 Bytes
260d870
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# -*- coding: utf-8 -*-
"""
End-to-End Referring Video Object Segmentation with Multimodal Transformers

This notebook provides a (limited) hands-on demonstration of MTTR.

Given a text query and a short clip based on a YouTube video, we demonstrate how MTTR can be used to segment the referred object instance throughout the video.


### Disclaimer
This is a **limited** demonstration of MTTR's performance. The model used here was trained **exclusively** on Refer-YouTube-VOS with window size `w=12` (as described in our paper). No additional training data was used whatsoever. 
Hence, the model's performance may be limited, especially on instances from unseen categories.

Additionally, slow processing times may be encountered, depending on the input clip length and/or resolution, and due to Colab's limited computational resources.

Finally, we emphasize that this demonstration is intended to be used for academic purposes only. We do not take any responsibility for how the created content is used or distributed, and discourage the users from copyright infringement of YouTube videos.

And now, with all formalities aside, let's begin!

"""

import gradio as gr
import torch
import torchvision
import torchvision.transforms.functional as F
from einops import rearrange
import numpy as np
from PIL import Image, ImageDraw, ImageOps, ImageFont
from moviepy.editor import VideoFileClip, AudioFileClip, ImageSequenceClip
from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
from tqdm import trange, tqdm

class NestedTensor:
    """Plain container pairing a batch of padded tensors with its padding mask."""

    def __init__(self, tensors, mask):
        # Stored as given; no copying or validation is performed.
        self.mask = mask
        self.tensors = tensors

def nested_tensor_from_videos_list(videos_list):
    """Zero-pad a list of videos to a common shape and wrap them in a NestedTensor.

    Each video is indexed as (frames, channels, height, width). The result holds
    the padded videos and a boolean mask that is ``True`` on padded positions and
    ``False`` on real content; both are returned with their first two axes
    swapped, i.e. frames first and batch second.
    """
    # Element-wise maximum over all video shapes: the shape every video is padded to.
    shapes = [list(clip.shape) for clip in videos_list]
    max_size = shapes[0]
    for shape in shapes[1:]:
        for axis, extent in enumerate(shape):
            if extent > max_size[axis]:
                max_size[axis] = extent

    padded_batch_shape = [len(videos_list)] + max_size
    batch, frames, _, height, width = padded_batch_shape

    reference = videos_list[0]
    padded_videos = torch.zeros(padded_batch_shape, dtype=reference.dtype, device=reference.device)
    # Start fully masked; the real content of each video is unmasked below.
    videos_pad_masks = torch.ones((batch, frames, height, width), dtype=torch.bool, device=reference.device)

    for idx, clip in enumerate(videos_list):
        n_frames, clip_h, clip_w = clip.shape[0], clip.shape[2], clip.shape[3]
        padded_videos[idx, :n_frames, :, :clip_h, :clip_w].copy_(clip)
        videos_pad_masks[idx, :n_frames, :clip_h, :clip_w] = False

    return NestedTensor(padded_videos.transpose(0, 1), videos_pad_masks.transpose(0, 1))

def apply_mask(image, mask, color, transparency=0.7):
    """Alpha-blend a solid color onto an image wherever the mask is set.

    Args:
        image: Array of shape (H, W, 3).
        mask: Array of shape (H, W); each value is scaled by ``transparency``
            and used as the per-pixel blending weight of ``color``.
        color: Color with 3 components, in the same value range as ``image``.
        transparency: Blending weight applied to fully masked pixels.

    Returns:
        Float array of shape (H, W, 3): ``color * a + image * (1 - a)`` with
        ``a = mask * transparency``.
    """
    # ``np.float`` was removed in NumPy 1.24; the builtin ``float`` is the same dtype (float64).
    color = np.asarray(color, dtype=float)
    # A trailing axis lets the (H, W) weights broadcast over the color channels.
    alpha = np.asarray(mask)[..., np.newaxis] * transparency
    return color * alpha + image * (1.0 - alpha)

def _text_size(draw, text, font):
    """Return the (width, height) of ``text`` measured from the drawing origin."""
    if hasattr(draw, 'textbbox'):
        # ``ImageDraw.textsize`` was removed in Pillow 10. The right/bottom edges of the
        # bounding box anchored at (0, 0) give the extent measured from the origin,
        # which is what the layout below needs.
        _, _, right, bottom = draw.textbbox((0, 0), text, font=font)
        return right, bottom
    # Older Pillow releases without ``textbbox``.
    return draw.textsize(text, font=font)


def process(text_query, full_video_path):
    """Segment the object referred to by ``text_query`` in a video and render the result.

    At most the first 10 seconds of the input video are used. The clip is
    re-encoded to ``/tmp/input.mp4``, run through the MTTR model in overlapping
    windows of frames, and the predicted masks plus the query text are drawn
    onto the original-resolution frames.

    Args:
        text_query: Free-form text describing the object to segment.
        full_video_path: Path of the input video file.

    Returns:
        Path of the rendered output clip (``/tmp/output_clip.mp4``).
    """
    start_pt, end_pt = 0, 10
    input_clip_path = '/tmp/input.mp4'
    # extract the relevant subclip:
    with VideoFileClip(full_video_path) as full_video:
        # Clamp the end point so videos shorter than ``end_pt`` seconds are not
        # read past their last frame.
        if full_video.duration is not None:
            end_pt = min(end_pt, full_video.duration)
        subclip = full_video.subclip(start_pt, end_pt)
        subclip.write_videofile(input_clip_path)

    # Build the model from the hub definition and load the local checkpoint on CPU.
    checkpoint_path = './refer-youtube-vos_window-12.pth.tar'
    model, postprocessor = torch.hub.load('Randl/MTTR:main', 'mttr_refer_youtube_vos', get_weights=False)

    model_state_dict = torch.load(checkpoint_path, map_location='cpu')
    # The checkpoint may wrap the weights under a 'model_state_dict' key.
    if 'model_state_dict' in model_state_dict:
        model_state_dict = model_state_dict['model_state_dict']
    model.load_state_dict(model_state_dict, strict=True)

    # clean up the text queries (lower-case, collapse whitespace):
    text_queries = [" ".join(q.lower().split()) for q in [text_query]]

    window_length = 24  # length of window during inference
    window_overlap = 6  # overlap (in frames) between consecutive windows
    window_stride = window_length - window_overlap

    with torch.inference_mode():
        # read and preprocess the video clip:
        video, audio, meta = torchvision.io.read_video(filename=input_clip_path)
        video = rearrange(video, 't h w c -> t c h w')
        input_video = F.resize(video, size=360, max_size=640)
        input_video = input_video.to(torch.float).div_(255)
        input_video = F.normalize(input_video, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        video_metadata = {'resized_frame_size': input_video.shape[-2:], 'original_frame_size': video.shape[-2:]}

        # partition the clip into overlapping windows of frames:
        windows = [input_video[i:i + window_length] for i in range(0, len(input_video), window_stride)]

        pred_masks_per_query = []
        num_frames, _, frame_h, frame_w = video.shape
        for query in tqdm(text_queries, desc='text queries'):
            pred_masks = torch.zeros(size=(num_frames, 1, frame_h, frame_w))
            for i, window in enumerate(tqdm(windows, desc='windows')):
                window = nested_tensor_from_videos_list([window])
                valid_indices = torch.arange(len(window.tensors))
                outputs = model(window, valid_indices, [query])
                window_masks = postprocessor(outputs, [video_metadata], window.tensors.shape[-2:])[0]['pred_masks']
                # Frames shared with the previous window are overwritten by this window's masks.
                win_start_idx = i * window_stride
                pred_masks[win_start_idx:win_start_idx + window_length] = window_masks
            pred_masks_per_query.append(pred_masks)

    # Finally, apply the generated instance masks and their corresponding text
    # queries on the input clip for visualization.

    # RGB colors for instance masks:
    light_blue = (41, 171, 226)
    purple = (237, 30, 121)
    dark_green = (35, 161, 90)
    orange = (255, 148, 59)
    colors = np.array([light_blue, purple, dark_green, orange])

    # height (in pixels) of the black strip above the video on which each text query is displayed:
    text_border_height_per_query = 40

    video_np = rearrange(video, 't c h w -> t h w c').numpy() / 255.0
    pred_masks_per_frame = rearrange(torch.stack(pred_masks_per_query), 'q t 1 h w -> t q h w').numpy()
    # The font does not change between frames, so load it once.
    font = ImageFont.truetype(font='LiberationSans-Regular.ttf', size=30)
    masked_video = []
    for vid_frame, frame_masks in tqdm(zip(video_np, pred_masks_per_frame), total=len(video_np), desc='applying masks...'):
        # apply the masks:
        for inst_mask, color in zip(frame_masks, colors):
            vid_frame = apply_mask(vid_frame, inst_mask, color / 255.0)
        vid_frame = Image.fromarray((vid_frame * 255).astype(np.uint8))
        # visualize the text queries on a strip added above the frame:
        vid_frame = ImageOps.expand(vid_frame, border=(0, len(text_queries) * text_border_height_per_query, 0, 0))
        frame_width, _ = vid_frame.size
        draw = ImageDraw.Draw(vid_frame)
        for i, (query, color) in enumerate(zip(text_queries, colors), start=1):
            text_w, text_h = _text_size(draw, query, font)
            # Center horizontally; keep the text bottom 8 px above the end of its strip.
            draw.text(((frame_width - text_w) / 2, (text_border_height_per_query * i) - text_h - 8),
                      query, fill=tuple(color) + (255,), font=font)
        masked_video.append(np.array(vid_frame))

    # generate and save the output clip:
    output_clip_path = '/tmp/output_clip.mp4'
    fps = meta['video_fps']
    clip = ImageSequenceClip(sequence=masked_video, fps=fps)
    # NOTE(review): relies on read_video returning an empty audio tensor when the
    # clip has no audio stream - confirm for the installed torchvision version.
    audio_clip = AudioFileClip(input_clip_path) if audio.numel() > 0 else None
    try:
        if audio_clip is not None:
            clip = clip.set_audio(audio_clip)
        clip.write_videofile(output_clip_path, fps=fps, audio=True)
    finally:
        # Release the audio reader that the original code left open.
        if audio_clip is not None:
            audio_clip.close()

    return output_clip_path
    
    

# Static texts shown on the demo page.
title = "Interactive demo: MTTR"

description = "To use it, upload a video file. Right now we only suggest using .mp4 files."

# HTML footer with links to the paper and the code repository.
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.14821'>End-to-End Referring Video Object Segmentation with Multimodal Transformers</a> | <a href='https://github.com/mttr2021/MTTR'>Github Repo</a></p>"

# Gradio UI: a text box and a video input are passed to `process`, whose
# return value (a video file path) is shown as the output video.
iface = gr.Interface(fn=process, 
                     inputs=[gr.inputs.Textbox(label="text query"), gr.inputs.Video(label="Input video. First 10 seconds of the video are used.")],
                     outputs='video',
                     title=title,
                     description=description,
                     enable_queue=True,
                     # examples=[[420, 'skate_jump.mp4']],  # Not working for some reason...
                     article=article)

# Runs at import time: the UI is launched as soon as this module is executed.
iface.launch(debug=True)