{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "animegan_v2_for_videos.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "authorship_tag": "ABX9TyP/bydrfrVmE0CzRt9JBw+x",
      "include_colab_link": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        ""
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dufmM-T1Helt"
      },
      "source": [
        "%%capture\n",
        "# NOTE(review): versions are unpinned, and the UI cell at the bottom uses the\n",
        "# legacy gr.inputs / gr.outputs API -- TODO pin a gradio release that still ships it.\n",
        "# %pip (rather than ! pip) installs into the environment of the running kernel.\n",
        "%pip install gradio encoded-video"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9CY3n8A0Lvdi"
      },
      "source": [
        "import gc\n",
        "import math\n",
        "import tempfile\n",
        "from io import BytesIO\n",
        "\n",
        "import gradio as gr\n",
        "import numpy as np\n",
        "import torch\n",
        "from encoded_video import EncodedVideo, write_video\n",
        "from PIL import Image\n",
        "from torchvision.transforms.functional import to_tensor, center_crop"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YxdCnrTzLw5V"
      },
      "source": [
        "# Fall back to the CPU when no GPU is attached so the notebook still runs (slowly).\n",
        "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "\n",
        "model = torch.hub.load(\n",
        "    \"AK391/animegan2-pytorch:main\",\n",
        "    \"generator\",\n",
        "    pretrained=True,\n",
        "    device=DEVICE,\n",
        "    progress=True,\n",
        ")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TYAyXUP1UeOd"
      },
      "source": [
        "! curl https://upload.wikimedia.org/wikipedia/commons/transcoded/2/29/2017-01-07_President_Obama%27s_Weekly_Address.webm/2017-01-07_President_Obama%27s_Weekly_Address.webm.360p.vp9.webm -o obama.webm"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TxT45Nlc88tD"
      },
      "source": [
        "def face2paint(model: torch.nn.Module, img: Image.Image, size: int = 512, device: str = 'cuda'):\n",
        "    \"\"\"\n",
        "    Stylizes a single PIL image with the given generator.\n",
        "    The image is center-cropped to a square, resized to (size, size), scaled to\n",
        "    [-1, 1] and passed through the model.\n",
        "    Args:\n",
        "        model (torch.nn.Module): The generator to apply.\n",
        "        img (Image.Image): The input image.\n",
        "        size (int): Side length of the square fed to the model.\n",
        "        device (str): Device the input tensor is moved to before the forward pass.\n",
        "    Returns:\n",
        "        The model output for the image as a CPU tensor rescaled to [0, 255].\n",
        "    \"\"\"\n",
        "    w, h = img.size\n",
        "    s = min(w, h)\n",
        "    img = img.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))\n",
        "    img = img.resize((size, size), Image.LANCZOS)\n",
        "\n",
        "    with torch.no_grad():\n",
        "        # Named `batch` rather than `input` so the builtin is not shadowed.\n",
        "        batch = to_tensor(img).unsqueeze(0) * 2 - 1\n",
        "        output = model(batch.to(device)).cpu()[0]\n",
        "\n",
        "    output = (output * 0.5 + 0.5).clip(0, 1) * 255.\n",
        "\n",
        "    return output\n",
        "\n",
        "# This function is taken from pytorchvideo!\n",
        "def uniform_temporal_subsample(x: torch.Tensor, num_samples: int, temporal_dim: int = -3) -> torch.Tensor:\n",
        "    \"\"\"\n",
        "    Uniformly subsamples num_samples indices from the temporal dimension of the video.\n",
        "    When num_samples is larger than the size of temporal dimension of the video, it\n",
        "    will sample frames based on nearest neighbor interpolation.\n",
        "    Args:\n",
        "        x (torch.Tensor): A video tensor with dimension larger than one with torch\n",
        "            tensor type includes int, long, float, complex, etc.\n",
        "        num_samples (int): The number of equispaced samples to be selected\n",
        "        temporal_dim (int): dimension of temporal to perform temporal subsample.\n",
        "    Returns:\n",
        "        An x-like Tensor with subsampled temporal dimension.\n",
        "    \"\"\"\n",
        "    t = x.shape[temporal_dim]\n",
        "    assert num_samples > 0 and t > 0\n",
        "    # Sample by nearest neighbor interpolation if num_samples > t.\n",
        "    indices = torch.linspace(0, t - 1, num_samples)\n",
        "    indices = torch.clamp(indices, 0, t - 1).long()\n",
        "    return torch.index_select(x, temporal_dim, indices)\n",
        "\n",
        "\n",
        "def short_side_scale(\n",
        "    x: torch.Tensor,\n",
        "    size: int,\n",
        "    interpolation: str = \"bilinear\",\n",
        ") -> torch.Tensor:\n",
        "    \"\"\"\n",
        "    Determines the shorter spatial dim of the video (i.e. width or height) and scales\n",
        "    it to the given size. To maintain aspect ratio, the longer side is then scaled\n",
        "    accordingly.\n",
        "    Args:\n",
        "        x (torch.Tensor): A video tensor of shape (C, T, H, W) and type torch.float32.\n",
        "        size (int): The size the shorter side is scaled to.\n",
        "        interpolation (str): Algorithm used for upsampling,\n",
        "            options: 'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' | 'area'\n",
        "    Returns:\n",
        "        An x-like Tensor with scaled spatial dims.\n",
        "    \"\"\"\n",
        "    assert len(x.shape) == 4\n",
        "    assert x.dtype == torch.float32\n",
        "    c, t, h, w = x.shape\n",
        "    if w < h:\n",
        "        new_h = int(math.floor((float(h) / w) * size))\n",
        "        new_w = size\n",
        "    else:\n",
        "        new_h = size\n",
        "        new_w = int(math.floor((float(w) / h) * size))\n",
        "\n",
        "    return torch.nn.functional.interpolate(\n",
        "        x, size=(new_h, new_w), mode=interpolation, align_corners=False\n",
        "    )\n",
        "\n",
        "def inference_step(vid, start_sec, duration, out_fps):\n",
        "    \"\"\"\n",
        "    Stylizes `duration` seconds of `vid`, starting at `start_sec`, with the global `model`.\n",
        "    Args:\n",
        "        vid: The opened EncodedVideo to read the clip from.\n",
        "        start_sec: Start of the clip in seconds.\n",
        "        duration: Length of the clip in seconds.\n",
        "        out_fps: Number of frames kept per second of video.\n",
        "    Returns:\n",
        "        A tuple (output_video, audio_arr, out_fps, audio_fps). output_video is a numpy\n",
        "        array of stylized 512x512 frames with values in [0, 255]; audio_arr and\n",
        "        audio_fps are both None when the clip carries no audio.\n",
        "    Raises:\n",
        "        ValueError: If no video frames were decoded for the requested range.\n",
        "    \"\"\"\n",
        "    clip = vid.get_clip(start_sec, start_sec + duration)\n",
        "    if clip['video'] is None:\n",
        "        raise ValueError(f'no video frames decoded between {start_sec}s and {start_sec + duration}s')\n",
        "    video_arr = torch.from_numpy(clip['video']).permute(3, 0, 1, 2)\n",
        "\n",
        "    # Without this guard np.expand_dims(None, 0) silently yields an object array.\n",
        "    has_audio = bool(vid._has_audio) and clip['audio'] is not None\n",
        "    audio_arr = np.expand_dims(clip['audio'], 0) if has_audio else None\n",
        "    audio_fps = vid._container.streams.audio[0].sample_rate if has_audio else None\n",
        "\n",
        "    # torch.linspace needs an integer step count; UI values may arrive as floats.\n",
        "    num_frames = int(duration * out_fps)\n",
        "    x = uniform_temporal_subsample(video_arr, num_frames)\n",
        "    x = center_crop(short_side_scale(x, 512), 512)\n",
        "    # Scale to [-1, 1]: the same input range face2paint feeds the generator, and the\n",
        "    # inverse of the `* 0.5 + 0.5` applied to the output below.\n",
        "    x = x / 255. * 2 - 1\n",
        "    x = x.permute(1, 0, 2, 3)\n",
        "    with torch.no_grad():\n",
        "        output = model(x.to(DEVICE)).cpu()\n",
        "    output = (output * 0.5 + 0.5).clip(0, 1) * 255.\n",
        "    output_video = output.permute(0, 2, 3, 1).numpy()\n",
        "\n",
        "    return output_video, audio_arr, out_fps, audio_fps\n",
        "\n",
        "def predict_fn(filepath, start_sec, duration, out_fps):\n",
        "    \"\"\"\n",
        "    Stylizes `duration` seconds of the video at `filepath` and writes them to 'out.mp4'.\n",
        "    The video is processed one second at a time to bound the size of each model batch.\n",
        "    Args:\n",
        "        filepath: Path of the input video.\n",
        "        start_sec: Second at which to start.\n",
        "        duration: Number of whole seconds to process.\n",
        "        out_fps: Frame rate of the written video.\n",
        "    Returns:\n",
        "        The path of the written file, 'out.mp4'.\n",
        "    Raises:\n",
        "        ValueError: If `duration` is smaller than one second.\n",
        "    \"\"\"\n",
        "    # range() below needs an int; UI values may arrive as floats.\n",
        "    duration = int(duration)\n",
        "    out_fps = int(out_fps)\n",
        "    if duration < 1:\n",
        "        raise ValueError('duration must be at least 1 second')\n",
        "\n",
        "    vid = EncodedVideo.from_path(filepath)\n",
        "    video_chunks, audio_chunks = [], []\n",
        "    fps, audio_fps = out_fps, None\n",
        "    for i in range(duration):\n",
        "        video, audio, fps, audio_fps = inference_step(\n",
        "            vid=vid,\n",
        "            start_sec=i + start_sec,\n",
        "            duration=1,\n",
        "            out_fps=out_fps,\n",
        "        )\n",
        "        gc.collect()\n",
        "        video_chunks.append(video)\n",
        "        if audio is not None:\n",
        "            audio_chunks.append(audio)\n",
        "\n",
        "    # Join once at the end instead of re-copying the accumulated frames every iteration.\n",
        "    video_all = np.concatenate(video_chunks)\n",
        "\n",
        "    # Only hand audio arguments to write_video when there is audio to write.\n",
        "    # NOTE(review): assumes write_video treats the audio arguments as optional -- confirm.\n",
        "    audio_kwargs = {}\n",
        "    if audio_chunks and audio_fps is not None:\n",
        "        audio_kwargs = dict(\n",
        "            audio_array=np.hstack(audio_chunks),\n",
        "            audio_fps=audio_fps,\n",
        "            audio_codec='aac',\n",
        "        )\n",
        "\n",
        "    write_video('out.mp4', video_all, fps=fps, **audio_kwargs)\n",
        "\n",
        "    return 'out.mp4'\n",
        "\n",
        "article = \"\"\"\n",
        "\n",
        " Github Repo Pytorch\n",
        "\n",
        "\"\"\"\n",
        "\n",
        "gr.Interface(\n",
        "    predict_fn,\n",
        "    inputs=[\n",
        "        gr.inputs.Video(),\n",
        "        gr.inputs.Slider(minimum=0, maximum=300, step=1, default=0),\n",
        "        gr.inputs.Slider(minimum=1, maximum=10, step=1, default=2),\n",
        "        gr.inputs.Slider(minimum=12, maximum=30, step=6, default=24),\n",
        "    ],\n",
        "    outputs=gr.outputs.Video(),\n",
        "    title='AnimeGANV2 On Videos',\n",
        "    description=\"Applying AnimeGAN-V2 to frame from video clips\",\n",
        "    article=article,\n",
        "    enable_queue=True,\n",
        "    examples=[\n",
        "        ['obama.webm', 23, 10, 30],\n",
        "    ],\n",
        "    allow_flagging=False\n",
        ").launch(debug=True)"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}