import os
import click
from tqdm import tqdm

try:
    import ffmpeg
except ImportError:
    raise ImportError('ffmpeg-python not found! Install it via "pip install ffmpeg-python"')

try:
    import skvideo.io
except ImportError:
    raise ImportError('scikit-video not found! Install it via "pip install scikit-video"')

import PIL
from PIL import Image
import scipy.ndimage as nd

from fractions import Fraction
import numpy as np
import torch
from torchvision import transforms

from typing import Union, Tuple, Optional

import dnnlib
import legacy
from torch_utils import gen_utils

from network_features import VGG16FeaturesNVIDIA, DiscriminatorFeatures

os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = 'hide'
import moviepy.editor


# ----------------------------------------------------------------------------


def normalize_image(image: Union[PIL.Image.Image, np.ndarray]) -> np.ndarray:
    """Change dynamic range of an image from [0, 255] to [-1, 1]"""
    image = np.array(image, dtype=np.float32)
    image = image / 127.5 - 1.0
    return image


def get_video_information(mp4_filename: Union[str, os.PathLike],
                          max_length_seconds: Optional[float] = None,
                          starting_second: float = 0.0) -> Tuple[int, float, int, int, int, int]:
    """Probe an mp4 file and return its fps, the selected duration, the starting frame, the number of frames to use, and its width and height"""
    metadata = skvideo.io.ffprobe(mp4_filename)
    # Get video properties
    fps = int(np.rint(float(Fraction(metadata['video']['@avg_frame_rate']))))  # Possible error here if we get 0
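    if fps == 0:
        # Fallback sketch (assumption): some containers report '@avg_frame_rate' as '0/1',
        # in which case ffprobe's '@r_frame_rate' field is usually still populated
        fps = int(np.rint(float(Fraction(metadata['video']['@r_frame_rate']))))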
    total_video_num_frames = int(metadata['video']['@nb_frames'])
    video_duration = float(metadata['video']['@duration'])
    video_width = int(metadata['video']['@width'])
    video_height = int(metadata['video']['@height'])
    # Maximum number of frames to return (if not provided, return the full video)
    if max_length_seconds is None:
        print('Considering the full video...')
        max_length_seconds = video_duration
    if starting_second != 0.0:
        print('Using part of the video...')
        starting_second = min(starting_second, video_duration)
        max_length_seconds = min(video_duration - starting_second, max_length_seconds)
    max_num_frames = int(np.rint(max_length_seconds * fps))
    max_frames = min(total_video_num_frames, max_num_frames)
    returned_duration = min(video_duration, max_length_seconds)
    # Frame to start from
    starting_frame = int(np.rint(starting_second * fps))

    return fps, returned_duration, starting_frame, max_frames, video_width, video_height
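
# Usage sketch (hypothetical filename):
#   fps, duration, start_frame, n_frames, width, height = get_video_information('video.mp4', max_length_seconds=5.0)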


def get_video_frames(mp4_filename: Union[str, os.PathLike],
                     run_dir: Union[str, os.PathLike],
                     starting_frame: int,
                     max_frames: int,
                     center_crop: bool = False,
                     save_selected_frames: bool = False) -> np.ndarray:
    """Get all the frames of a video as a np.ndarray"""
    # DEPRECATED
    print('Getting video frames...')
    frames = skvideo.io.vread(mp4_filename)  # TODO: crazy things with scikit-video
    frames = frames[starting_frame:min(starting_frame + max_frames, len(frames)), :, :, :]
    frames = np.transpose(frames, (0, 3, 2, 1))  # NHWC => NCWH
    if center_crop:
        frame_width, frame_height = frames.shape[2], frames.shape[3]
        min_side = min(frame_width, frame_height)
        frames = frames[:, :, (frame_width - min_side) // 2:(frame_width + min_side) // 2, (frame_height - min_side) // 2:(frame_height + min_side) // 2]

    if save_selected_frames:
        skvideo.io.vwrite(os.path.join(run_dir, 'selected_frames.mp4'), np.transpose(frames, (0, 3, 2, 1)))
    return frames


# ----------------------------------------------------------------------------


@click.command()
@click.pass_context
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
# Encoder options
@click.option('--encoder', type=click.Choice(['discriminator', 'vgg16', 'clip']), help='Choose the model to encode each frame into the latent space Z.', default='discriminator', show_default=True)
@click.option('--vgg16-layer', type=click.Choice(['conv4_1', 'conv4_2', 'conv4_3', 'conv5_1', 'conv5_2', 'conv5_3', 'adavgpool', 'fc1', 'fc2']), help='Choose the layer to use from VGG16 (if used as encoder)', default='adavgpool', show_default=True)
# Source video options
@click.option('--source-video', '-video', 'video_file', type=click.Path(exists=True, dir_okay=False), help='Path to video file', required=True)
@click.option('--max-video-length', type=click.FloatRange(min=0.0, min_open=True), help='How many seconds of the video to take (from the starting second)', default=None, show_default=True)
@click.option('--starting-second', type=click.FloatRange(min=0.0), help='Second to start the video from', default=0.0, show_default=True)
@click.option('--frame-transform', type=click.Choice(['none', 'center-crop', 'resize']), help='TODO: Transform to apply to the individual frame.')
@click.option('--center-crop', is_flag=True, help='Center-crop each frame of the video')
@click.option('--save-selected-frames', is_flag=True, help='Save the selected frames of the input video after the selected transform')
# Synthesis options
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
@click.option('--new-center', type=gen_utils.parse_new_center, help='New center for the W latent space; a seed (int) or a path to a dlatent (.npy/.npz)', default=None)
@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
@click.option('--anchor-latent-space', '-anchor', is_flag=True, help='Anchor the latent space to w_avg to stabilize the video')
# Video options
@click.option('--compress', is_flag=True, help='Add flag to compress the final mp4 file with ffmpeg-python (same resolution, lower file size)')
# Extra parameters for saving the results
@click.option('--outdir', type=click.Path(file_okay=False), help='Directory path to save the results', default=os.path.join(os.getcwd(), 'out', 'visual-reactive'), show_default=True, metavar='DIR')
@click.option('--description', '-desc', type=str, help='Description name for the directory path to save results', default='', show_default=True)
def visual_reactive_interpolation(
        ctx: click.Context,
        network_pkl: Union[str, os.PathLike],
        encoder: str,
        vgg16_layer: str,
        video_file: Union[str, os.PathLike],
        max_video_length: float,
        starting_second: float,
        frame_transform: str,
        center_crop: bool,
        save_selected_frames: bool,
        truncation_psi: float,
        new_center: Tuple[str, Union[int, np.ndarray]],
        noise_mode: str,
        anchor_latent_space: bool,
        outdir: Union[str, os.PathLike],
        description: str,
        compress: bool,
        smoothing_sec: float = 0.1  # Std (in seconds) of the Gaussian smoothing applied to the latents; lower reacts faster to the video, higher yields a more static result
):
    print(f'Loading networks from "{network_pkl}"...')

    # Select the device and define the encoder (the Discriminator's features, VGG16's features, or CLIP)
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    if encoder == 'discriminator':
        print('Loading Discriminator and its features...')
        with dnnlib.util.open_url(network_pkl) as f:
            D = legacy.load_network_pkl(f)['D'].eval().requires_grad_(False).to(device)  # type: ignore

        D_features = DiscriminatorFeatures(D).requires_grad_(False).to(device)
        del D
    elif encoder == 'vgg16':
        print('Loading VGG16 and its features...')
        url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
        with dnnlib.util.open_url(url) as f:
            vgg16 = torch.jit.load(f).eval().to(device)

        vgg16_features = VGG16FeaturesNVIDIA(vgg16).requires_grad_(False).to(device)
        del vgg16

    elif encoder == 'clip':
        print('Loading CLIP model...')
        try:
            import clip
        except ImportError:
            raise ImportError('clip not installed! Install it via "pip install git+https://github.com/openai/CLIP.git"')
        model, preprocess = clip.load('ViT-B/32', device=device)
        model = model.requires_grad_(False)  # No gradients needed; keeping them enabled can cause an OOM error

    print('Loading Generator...')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].eval().requires_grad_(False).to(device)  # type: ignore

    if anchor_latent_space:
        gen_utils.anchor_latent_space(G)

    if new_center is None:
        # Stick to the tracked center of W during training
        w_avg = G.mapping.w_avg
    else:
        new_center, new_center_value = new_center
        # We get the new center using the int (a seed) or recovered dlatent (an np.ndarray)
        if isinstance(new_center_value, int):
            new_center = f'seed_{new_center}'
            w_avg = gen_utils.get_w_from_seed(G, device, new_center_value, truncation_psi=1.0)  # We want the pure dlatent
        elif isinstance(new_center_value, np.ndarray):
            w_avg = torch.from_numpy(new_center_value).to(device)
        else:
            ctx.fail('Error: New center has an unrecognized format! Only an int (seed) or a path to a .npy/.npz file are accepted!')

    # Create the run dir with the given description
    description = 'visual-reactive' if len(description) == 0 else description
    run_dir = gen_utils.make_run_dir(outdir, description)
    # Name of the video
    video_name = os.path.splitext(os.path.basename(video_file))[0]  # Actual name of the video file, sans extension
    mp4_name = f'visual-reactive_{video_name}'

    # Get all the frames of the video and its properties
    # TODO: resize the frames to the size of the network (G.img_resolution)
    fps, max_video_length, starting_frame, max_frames, width, height = get_video_information(video_file,
                                                                                             max_video_length,
                                                                                             starting_second)

    videogen = skvideo.io.vreader(video_file)
    fake_dlatents = list()
    if save_selected_frames:
        # skvideo.io.vwrite sets FPS=25, so we have to manually enter it via FFmpeg
        # TODO: use only ffmpeg-python
        writer = skvideo.io.FFmpegWriter(os.path.join(run_dir, f'selected-frames_{video_name}.mp4'),
                                         inputdict={'-r': str(fps)})

    for idx, frame in enumerate(tqdm(videogen, desc=f'Getting frames+latents of "{video_name}"', unit='frames')):
        # Only save the frames that the user has selected
        if idx < starting_frame:
            continue
        if idx >= starting_frame + max_frames:  # Stop once max_frames frames have been selected
            break

        if center_crop:
            frame_width, frame_height = frame.shape[1], frame.shape[0]
            min_side = min(frame_width, frame_height)
            frame = frame[(frame_height - min_side) // 2:(frame_height + min_side) // 2, (frame_width - min_side) // 2:(frame_width + min_side) // 2, :]

        if save_selected_frames:
            writer.writeFrame(frame)

        # Get fake latents
        if encoder == 'discriminator':
            frame = normalize_image(frame)  # [0, 255] => [-1, 1]
            frame = torch.from_numpy(np.transpose(frame, (2, 1, 0))).unsqueeze(0).to(device)  # HWC => CWH => NCWH, N=1
            fake_z = D_features.get_layers_features(frame, layers=['fc'])[0]

        elif encoder == 'vgg16':
            preprocess = transforms.Compose([transforms.ToTensor(),
                                             transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                  std=[0.229, 0.224, 0.225])])
            frame = preprocess(frame).unsqueeze(0).to(device)
            fake_z = vgg16_features.get_layers_features(frame, layers=[vgg16_layer])[0]
            fake_z = fake_z.view(1, 512, -1).mean(2)  # [1, C, H, W] => [1, 512]; every available layer has a multiple of 512 elements

        elif encoder == 'clip':
            frame = Image.fromarray(frame)  # [0, 255]
            frame = preprocess(frame).unsqueeze(0).to(device)
            fake_z = model.encode_image(frame)

        # Standardize the latent to be ~N(0, 1), as the mapping network expects
        # (an alternative is to divide by its max: fake_z = fake_z / fake_z.max())
        fake_z = (fake_z - fake_z.mean()) / fake_z.std()

        # Get dlatent
        fake_w = G.mapping(fake_z, None)
        # Truncation trick
        fake_w = w_avg + (fake_w - w_avg) * truncation_psi
        fake_dlatents.append(fake_w)

    if save_selected_frames:
        # Close the video writer
        writer.close()

    # Set the fake_dlatents as a torch tensor; we can't just do torch.tensor(fake_dlatents) as with NumPy :(
    fake_dlatents = torch.cat(fake_dlatents, 0)
    # Smooth the dlatents along the time axis so that only larger changes in the scene affect the generation;
    # sigma is given in frames (smoothing_sec * fps), and the trailing zeros leave the (num_ws, w_dim) axes untouched
    fake_dlatents = torch.from_numpy(nd.gaussian_filter(fake_dlatents.cpu(),
                                                        sigma=[smoothing_sec * fps, 0, 0])).to(device)

    # Auxiliary function for moviepy
    def make_frame(t):
        # Get the frame, dlatent, and respective image
        frame_idx = int(np.clip(np.round(t * fps), 0, len(fake_dlatents) - 1))
        fake_w = fake_dlatents[frame_idx]
        image = gen_utils.w_to_img(G, fake_w, noise_mode)
        # Create grid for this timestamp
        grid = gen_utils.create_image_grid(image, (1, 1))
        # Grayscale => RGB
        if grid.shape[2] == 1:
            grid = grid.repeat(3, 2)
        return grid

    # Generate the video with the make_frame function defined above (the duration is set in the constructor)
    videoclip = moviepy.editor.VideoClip(make_frame, duration=max_video_length)

    # Change the video parameters (codec, bitrate) if you so desire
    final_video = os.path.join(run_dir, f'{mp4_name}.mp4')
    videoclip.write_videofile(final_video, fps=fps, codec='libx264', bitrate='16M')

    # Compress the video (lower file size, same resolution, if successful)
    if compress:
        gen_utils.compress_video(original_video=final_video, original_video_name=mp4_name, outdir=run_dir, ctx=ctx)

    # Merge the videos side by side; hstack requires equal heights, so scale both clips first
    if save_selected_frames:
        # GUIDE: https://github.com/kkroening/ffmpeg-python/issues/150
        min_height = min(height, G.img_resolution)

        input0 = ffmpeg.input(os.path.join(run_dir, f'selected-frames_{video_name}.mp4'))
        input1 = ffmpeg.input(os.path.join(run_dir, f'{mp4_name}-compressed.mp4' if compress else f'{mp4_name}.mp4'))
        # 'scale=-2:h' keeps the aspect ratio while forcing an even width (needed by libx264)
        input0 = input0.filter('scale', -2, min_height)
        input1 = input1.filter('scale', -2, min_height)
        out = ffmpeg.filter([input0, input1], 'hstack').output(os.path.join(run_dir, 'side-by-side.mp4'))
        out.run(overwrite_output=True)

    # Save the configuration used
    new_center = 'w_avg' if new_center is None else new_center
    ctx.obj = {
        'network_pkl': network_pkl,
        'encoder_options': {
            'encoder': encoder,
            'vgg16_layer': vgg16_layer,
        },
        'source_video_options': {
            'source_video': video_file,
            'source_video_params': {
                'fps': fps,
                'height': height,
                'width': width,
                'length': max_video_length,
                'starting_frame': starting_frame,
                'total_frames': max_frames
            },
            'max_video_length': max_video_length,
            'starting_second': starting_second,
            'frame_transform': frame_transform,
            'center_crop': center_crop,
            'save_selected_frames': save_selected_frames
        },
        'synthesis_options': {
            'truncation_psi': truncation_psi,
            'new_center': new_center,
            'noise_mode': noise_mode,
            'smoothing_sec': smoothing_sec
        },
        'video_options': {
            'compress': compress
        },
        'extra_parameters': {
            'outdir': run_dir,
            'description': description
        }
    }

    gen_utils.save_config(ctx=ctx, run_dir=run_dir)


# ----------------------------------------------------------------------------
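
# Example invocation (hypothetical script and file names; adjust to your setup):
#   python visual-reactive.py --network=ffhq.pkl --source-video=input.mp4 --encoder=clip --trunc=0.7 --compress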


if __name__ == '__main__':
    visual_reactive_interpolation()


# ----------------------------------------------------------------------------