{ "cells": [ { "cell_type": "markdown", "id": "a2220df6", "metadata": {}, "source": [ "# Import Libraries" ] }, { "cell_type": "code", "execution_count": 4, "id": "7249bea4", "metadata": {}, "outputs": [], "source": [ "import gradio as gr\n", "import torch\n", "import torch.nn.functional as F\n", "from facenet_pytorch import MTCNN, InceptionResnetV1\n", "import numpy as np\n", "from PIL import Image\n", "import cv2\n", "from pytorch_grad_cam import GradCAM\n", "from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget\n", "from pytorch_grad_cam.utils.image import show_cam_on_image\n", "import warnings\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "code", "execution_count": 2, "id": "62f0492b-aad6-4464-ab96-1365b7f3a44e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: gradio in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (3.39.0)\n", "Collecting gradio\n", " Downloading gradio-4.19.1-py3-none-any.whl.metadata (15 kB)\n", "Requirement already satisfied: aiofiles<24.0,>=22.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (23.2.1)\n", "Requirement already satisfied: altair<6.0,>=4.2.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (5.2.0)\n", "Requirement already satisfied: fastapi in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (0.109.2)\n", "Requirement already satisfied: ffmpy in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (0.3.2)\n", "Requirement already satisfied: gradio-client==0.10.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (0.10.0)\n", "Requirement already satisfied: httpx in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (0.26.0)\n", "Requirement already satisfied: huggingface-hub>=0.19.3 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (0.20.3)\n", "Requirement already satisfied: importlib-resources<7.0,>=1.3 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (6.1.1)\n", "Requirement already satisfied: jinja2<4.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (3.1.3)\n", "Requirement already satisfied: markupsafe~=2.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (2.1.5)\n", "Requirement already satisfied: matplotlib~=3.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (3.8.3)\n", "Requirement already satisfied: numpy~=1.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (1.26.4)\n", "Requirement already satisfied: orjson~=3.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (3.9.14)\n", "Requirement already satisfied: packaging in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (23.2)\n", "Requirement already satisfied: pandas<3.0,>=1.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (2.2.0)\n", "Requirement already satisfied: pillow<11.0,>=8.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (9.4.0)\n", "Requirement already satisfied: pydantic>=2.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (2.6.1)\n", "Requirement already satisfied: pydub in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (0.25.1)\n", "Requirement already satisfied: python-multipart>=0.0.9 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (0.0.9)\n", "Requirement already satisfied: pyyaml<7.0,>=5.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (6.0.1)\n", "Collecting ruff>=0.1.7 (from gradio)\n", " Downloading ruff-0.2.1-py3-none-win_amd64.whl.metadata (23 kB)\n", "Requirement already satisfied: semantic-version~=2.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (2.10.0)\n", "Collecting tomlkit==0.12.0 (from gradio)\n", " Downloading tomlkit-0.12.0-py3-none-any.whl.metadata (2.7 kB)\n", "Collecting typer<1.0,>=0.9 (from typer[all]<1.0,>=0.9->gradio)\n", " Downloading typer-0.9.0-py3-none-any.whl.metadata (14 kB)\n", "Requirement already satisfied: typing-extensions~=4.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (4.9.0)\n", "Requirement already satisfied: uvicorn>=0.14.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (0.27.1)\n", "Requirement already satisfied: fsspec in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio-client==0.10.0->gradio) (2024.2.0)\n", "Requirement already satisfied: websockets<12.0,>=10.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio-client==0.10.0->gradio) (11.0.3)\n", "Requirement already satisfied: jsonschema>=3.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from altair<6.0,>=4.2.0->gradio) (4.21.1)\n", "Requirement already satisfied: toolz in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from altair<6.0,>=4.2.0->gradio) (0.12.1)\n", "Requirement already satisfied: filelock in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from huggingface-hub>=0.19.3->gradio) (3.13.1)\n", "Requirement already satisfied: requests in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from huggingface-hub>=0.19.3->gradio) (2.31.0)\n", "Requirement already satisfied: tqdm>=4.42.1 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from huggingface-hub>=0.19.3->gradio) (4.66.2)\n", "Requirement already satisfied: zipp>=3.1.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from importlib-resources<7.0,>=1.3->gradio) (3.17.0)\n", "Requirement already satisfied: contourpy>=1.0.1 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from matplotlib~=3.0->gradio) (1.2.0)\n", "Requirement already satisfied: cycler>=0.10 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from matplotlib~=3.0->gradio) (0.12.1)\n", "Requirement already satisfied: fonttools>=4.22.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from matplotlib~=3.0->gradio) (4.49.0)\n", "Requirement already satisfied: kiwisolver>=1.3.1 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from matplotlib~=3.0->gradio) (1.4.5)\n", "Requirement already satisfied: pyparsing>=2.3.1 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from matplotlib~=3.0->gradio) (3.1.1)\n", "Requirement already satisfied: python-dateutil>=2.7 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from matplotlib~=3.0->gradio) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from pandas<3.0,>=1.0->gradio) (2024.1)\n", "Requirement already satisfied: tzdata>=2022.7 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from pandas<3.0,>=1.0->gradio) (2024.1)\n", "Requirement already satisfied: annotated-types>=0.4.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from pydantic>=2.0->gradio) (0.6.0)\n", "Requirement already satisfied: pydantic-core==2.16.2 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from pydantic>=2.0->gradio) (2.16.2)\n", "Requirement already satisfied: click<9.0.0,>=7.1.1 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from typer<1.0,>=0.9->typer[all]<1.0,>=0.9->gradio) (8.1.7)\n", "Requirement already satisfied: colorama<0.5.0,>=0.4.3 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from typer[all]<1.0,>=0.9->gradio) (0.4.6)\n", "Collecting shellingham<2.0.0,>=1.3.0 (from typer[all]<1.0,>=0.9->gradio)\n", " Downloading shellingham-1.5.4-py2.py3-none-any.whl.metadata (3.5 kB)\n", "Collecting rich<14.0.0,>=10.11.0 (from typer[all]<1.0,>=0.9->gradio)\n", " Downloading rich-13.7.0-py3-none-any.whl.metadata (18 kB)\n", "Requirement already satisfied: h11>=0.8 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from uvicorn>=0.14.0->gradio) (0.14.0)\n", "Requirement already satisfied: starlette<0.37.0,>=0.36.3 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from fastapi->gradio) (0.36.3)\n", "Requirement already satisfied: anyio in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from httpx->gradio) (4.2.0)\n", "Requirement already satisfied: certifi in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from httpx->gradio) (2024.2.2)\n", "Requirement already satisfied: httpcore==1.* in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from httpx->gradio) (1.0.3)\n", "Requirement already satisfied: idna in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from httpx->gradio) (3.6)\n", "Requirement already satisfied: sniffio in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from httpx->gradio) (1.3.0)\n", "Requirement already satisfied: attrs>=22.2.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio) (23.2.0)\n", "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio) (2023.12.1)\n", "Requirement already satisfied: referencing>=0.28.4 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio) (0.33.0)\n", "Requirement already satisfied: rpds-py>=0.7.1 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio) (0.18.0)\n", "Requirement already satisfied: six>=1.5 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from python-dateutil>=2.7->matplotlib~=3.0->gradio) (1.16.0)\n", "Requirement already satisfied: markdown-it-py>=2.2.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from rich<14.0.0,>=10.11.0->typer[all]<1.0,>=0.9->gradio) (2.2.0)\n", "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from rich<14.0.0,>=10.11.0->typer[all]<1.0,>=0.9->gradio) (2.17.2)\n", "Requirement already satisfied: exceptiongroup>=1.0.2 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from anyio->httpx->gradio) (1.2.0)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from requests->huggingface-hub>=0.19.3->gradio) (3.3.2)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from requests->huggingface-hub>=0.19.3->gradio) (2.2.0)\n", "Requirement already satisfied: mdurl~=0.1 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from markdown-it-py>=2.2.0->rich<14.0.0,>=10.11.0->typer[all]<1.0,>=0.9->gradio) (0.1.2)\n", "Downloading gradio-4.19.1-py3-none-any.whl (16.9 MB)\n", " ---------------------------------------- 0.0/16.9 MB ? eta -:--:--\n", " ---------------------------------------- 0.0/16.9 MB 1.9 MB/s eta 0:00:09\n", " ---------------------------------------- 0.2/16.9 MB 2.4 MB/s eta 0:00:08\n", " - -------------------------------------- 0.5/16.9 MB 3.9 MB/s eta 0:00:05\n", " -- ------------------------------------- 1.1/16.9 MB 6.1 MB/s eta 0:00:03\n", " ----- ---------------------------------- 2.2/16.9 MB 10.2 MB/s eta 0:00:02\n", " -------- ------------------------------- 3.7/16.9 MB 13.9 MB/s eta 0:00:01\n", " ----------- ---------------------------- 5.1/16.9 MB 15.4 MB/s eta 0:00:01\n", " --------------- ------------------------ 6.6/16.9 MB 18.4 MB/s eta 0:00:01\n", " ------------------- -------------------- 8.4/16.9 MB 20.6 MB/s eta 0:00:01\n", " --------------------- ------------------ 9.2/16.9 MB 21.1 MB/s eta 0:00:01\n", " --------------------- ------------------ 9.2/16.9 MB 21.1 MB/s eta 0:00:01\n", " --------------------- ------------------ 9.2/16.9 MB 21.1 MB/s eta 0:00:01\n", " --------------------- ------------------ 9.3/16.9 MB 16.5 MB/s eta 0:00:01\n", " ------------------------ --------------- 10.3/16.9 MB 17.2 MB/s eta 0:00:01\n", " ------------------------ --------------- 10.5/16.9 MB 17.3 MB/s eta 0:00:01\n", " -------------------------- ------------- 11.1/16.9 MB 18.7 MB/s eta 0:00:01\n", " ---------------------------- ----------- 12.1/16.9 MB 17.7 MB/s eta 0:00:01\n", " -------------------------------- ------- 13.6/16.9 MB 18.2 MB/s eta 0:00:01\n", " ----------------------------------- ---- 14.9/16.9 MB 18.2 MB/s eta 0:00:01\n", " --------------------------------------- 16.6/16.9 MB 18.2 MB/s eta 0:00:01\n", " ---------------------------------------- 16.9/16.9 MB 16.8 MB/s eta 0:00:00\n", "Downloading tomlkit-0.12.0-py3-none-any.whl (37 kB)\n", "Downloading ruff-0.2.1-py3-none-win_amd64.whl (7.4 MB)\n", " ---------------------------------------- 0.0/7.4 MB ? eta -:--:--\n", " -------- ------------------------------- 1.6/7.4 MB 51.9 MB/s eta 0:00:01\n", " ---------------- ----------------------- 3.1/7.4 MB 40.2 MB/s eta 0:00:01\n", " ------------------------- -------------- 4.8/7.4 MB 37.9 MB/s eta 0:00:01\n", " --------------------------------- ------ 6.3/7.4 MB 36.5 MB/s eta 0:00:01\n", " --------------------------------------- 7.4/7.4 MB 33.7 MB/s eta 0:00:01\n", " ---------------------------------------- 7.4/7.4 MB 31.6 MB/s eta 0:00:00\n", "Downloading typer-0.9.0-py3-none-any.whl (45 kB)\n", " ---------------------------------------- 0.0/45.9 kB ? eta -:--:--\n", " ---------------------------------------- 45.9/45.9 kB ? eta 0:00:00\n", "Downloading rich-13.7.0-py3-none-any.whl (240 kB)\n", " ---------------------------------------- 0.0/240.6 kB ? eta -:--:--\n", " --------------------------------------- 240.6/240.6 kB 14.4 MB/s eta 0:00:00\n", "Downloading shellingham-1.5.4-py2.py3-none-any.whl (9.8 kB)\n", "Installing collected packages: tomlkit, shellingham, ruff, typer, rich, gradio\n", " Attempting uninstall: gradio\n", " Found existing installation: gradio 3.39.0\n", " Uninstalling gradio-3.39.0:\n", " Successfully uninstalled gradio-3.39.0\n", "Successfully installed gradio-4.19.1 rich-13.7.0 ruff-0.2.1 shellingham-1.5.4 tomlkit-0.12.0 typer-0.9.0\n" ] } ], "source": [ "!pip install -U gradio" ] }, { "cell_type": "markdown", "id": "d25e1c5d", "metadata": {}, "source": [ "# Download and Load Model" ] }, { "cell_type": "code", "execution_count": 5, "id": "237fbf44", "metadata": {}, "outputs": [], "source": [ "DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n", "\n", "mtcnn = MTCNN(\n", " select_largest=False,\n", " post_process=False,\n", " device=DEVICE\n", ").to(DEVICE).eval()" ] }, { "cell_type": "code", "execution_count": 6, "id": "f3ef2b4f", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "43131e0cdbdf44beb6f775f854ebbf07", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0.00/107M [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "InceptionResnetV1(\n", " (conv2d_1a): BasicConv2d(\n", " (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (conv2d_2a): BasicConv2d(\n", " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (conv2d_2b): BasicConv2d(\n", " (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (maxpool_3a): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " (conv2d_3b): BasicConv2d(\n", " (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (conv2d_4a): BasicConv2d(\n", " (conv): Conv2d(80, 192, kernel_size=(3, 3), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (conv2d_4b): BasicConv2d(\n", " (conv): Conv2d(192, 256, kernel_size=(3, 3), stride=(2, 2), bias=False)\n", " (bn): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (repeat_1): Sequential(\n", " (0): Block35(\n", " (branch0): BasicConv2d(\n", " (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (branch1): Sequential(\n", " (0): BasicConv2d(\n", " (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (1): BasicConv2d(\n", " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " )\n", " (branch2): Sequential(\n", " (0): BasicConv2d(\n", " (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (1): BasicConv2d(\n", " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (2): BasicConv2d(\n", " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " )\n", " (conv2d): Conv2d(96, 256, kernel_size=(1, 1), stride=(1, 1))\n", " (relu): ReLU()\n", " )\n", " (1): Block35(\n", " (branch0): BasicConv2d(\n", " (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (branch1): Sequential(\n", " (0): BasicConv2d(\n", " (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (1): BasicConv2d(\n", " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " )\n", " (branch2): Sequential(\n", " (0): BasicConv2d(\n", " (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (1): BasicConv2d(\n", " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (2): BasicConv2d(\n", " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " )\n", " (conv2d): Conv2d(96, 256, kernel_size=(1, 1), stride=(1, 1))\n", " (relu): ReLU()\n", " )\n", " (2): Block35(\n", " (branch0): BasicConv2d(\n", " (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (branch1): Sequential(\n", " (0): BasicConv2d(\n", " (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (1): BasicConv2d(\n", " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " )\n", " (branch2): Sequential(\n", " (0): BasicConv2d(\n", " (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (1): BasicConv2d(\n", " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (2): BasicConv2d(\n", " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " )\n", " (conv2d): Conv2d(96, 256, kernel_size=(1, 1), stride=(1, 1))\n", " (relu): ReLU()\n", " )\n", " (3): Block35(\n", " (branch0): BasicConv2d(\n", " (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (branch1): Sequential(\n", " (0): BasicConv2d(\n", " (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (1): BasicConv2d(\n", " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " )\n", " (branch2): Sequential(\n", " (0): BasicConv2d(\n", " (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (1): BasicConv2d(\n", " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (2): BasicConv2d(\n", " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " )\n", " (conv2d): Conv2d(96, 256, kernel_size=(1, 1), stride=(1, 1))\n", " (relu): ReLU()\n", " )\n", " (4): Block35(\n", " (branch0): BasicConv2d(\n", " (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (branch1): Sequential(\n", " (0): BasicConv2d(\n", " (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (1): BasicConv2d(\n", " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " )\n", " (branch2): Sequential(\n", " (0): BasicConv2d(\n", " (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (1): BasicConv2d(\n", " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (2): BasicConv2d(\n", " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " )\n", " (conv2d): Conv2d(96, 256, kernel_size=(1, 1), stride=(1, 1))\n", " (relu): ReLU()\n", " )\n", " )\n", " (mixed_6a): Mixed_6a(\n", " (branch0): BasicConv2d(\n", " (conv): Conv2d(256, 384, kernel_size=(3, 3), stride=(2, 2), bias=False)\n", " (bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (branch1): Sequential(\n", " (0): BasicConv2d(\n", " (conv): Conv2d(256, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (1): BasicConv2d(\n", " (conv): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (2): BasicConv2d(\n", " (conv): Conv2d(192, 256, kernel_size=(3, 3), stride=(2, 2), bias=False)\n", " (bn): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " )\n", " (branch2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " )\n", " (repeat_2): Sequential(\n", " (0): Block17(\n", " (branch0): BasicConv2d(\n", " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (branch1): Sequential(\n", " (0): BasicConv2d(\n", " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (1): BasicConv2d(\n", " (conv): Conv2d(128, 128, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (2): BasicConv2d(\n", " (conv): Conv2d(128, 128, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " )\n", " (conv2d): Conv2d(256, 896, kernel_size=(1, 1), stride=(1, 1))\n", " (relu): ReLU()\n", " )\n", " (1): Block17(\n", " (branch0): BasicConv2d(\n", " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (branch1): Sequential(\n", " (0): BasicConv2d(\n", " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (1): BasicConv2d(\n", " (conv): Conv2d(128, 128, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (2): BasicConv2d(\n", " (conv): Conv2d(128, 128, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " )\n", " (conv2d): Conv2d(256, 896, kernel_size=(1, 1), stride=(1, 1))\n", " (relu): ReLU()\n", " )\n", " (2): Block17(\n", " (branch0): BasicConv2d(\n", " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (branch1): Sequential(\n", " (0): BasicConv2d(\n", " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (1): BasicConv2d(\n", " (conv): Conv2d(128, 128, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (2): BasicConv2d(\n", " (conv): Conv2d(128, 128, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " )\n", " (conv2d): Conv2d(256, 896, kernel_size=(1, 1), stride=(1, 1))\n", " (relu): ReLU()\n", " )\n", " (3): Block17(\n", " (branch0): BasicConv2d(\n", " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (branch1): Sequential(\n", " (0): BasicConv2d(\n", " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (1): BasicConv2d(\n", " (conv): Conv2d(128, 128, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (2): BasicConv2d(\n", " (conv): Conv2d(128, 128, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " )\n", " (conv2d): Conv2d(256, 896, kernel_size=(1, 1), stride=(1, 1))\n", " (relu): ReLU()\n", " )\n", " (4): Block17(\n", " (branch0): BasicConv2d(\n", " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (branch1): Sequential(\n", " (0): BasicConv2d(\n", " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (1): BasicConv2d(\n", " (conv): Conv2d(128, 128, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (2): BasicConv2d(\n", " (conv): Conv2d(128, 128, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " )\n", " (conv2d): Conv2d(256, 896, kernel_size=(1, 1), stride=(1, 1))\n", " (relu): ReLU()\n", " )\n", " (5): Block17(\n", " (branch0): BasicConv2d(\n", " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (branch1): Sequential(\n", " (0): BasicConv2d(\n", " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (1): BasicConv2d(\n", " (conv): Conv2d(128, 128, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (2): BasicConv2d(\n", " (conv): Conv2d(128, 128, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " )\n", " (conv2d): Conv2d(256, 896, kernel_size=(1, 1), stride=(1, 1))\n", " (relu): ReLU()\n", " )\n", " (6): Block17(\n", " (branch0): BasicConv2d(\n", " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (branch1): Sequential(\n", " (0): BasicConv2d(\n", " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (1): BasicConv2d(\n", " (conv): Conv2d(128, 128, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (2): BasicConv2d(\n", " (conv): Conv2d(128, 128, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " )\n", " (conv2d): Conv2d(256, 896, kernel_size=(1, 1), stride=(1, 1))\n", " (relu): ReLU()\n", " )\n", " (7): Block17(\n", " (branch0): BasicConv2d(\n", " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (branch1): Sequential(\n", " (0): BasicConv2d(\n", " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (1): BasicConv2d(\n", " (conv): Conv2d(128, 128, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (2): BasicConv2d(\n", " (conv): Conv2d(128, 128, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " )\n", " (conv2d): Conv2d(256, 896, kernel_size=(1, 1), stride=(1, 1))\n", " (relu): ReLU()\n", " )\n", " (8): Block17(\n", " (branch0): BasicConv2d(\n", " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (branch1): Sequential(\n", " (0): BasicConv2d(\n", " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (1): BasicConv2d(\n", " (conv): Conv2d(128, 128, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (2): BasicConv2d(\n", " (conv): Conv2d(128, 128, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " )\n", " (conv2d): Conv2d(256, 896, kernel_size=(1, 1), stride=(1, 1))\n", " (relu): ReLU()\n", " )\n", " (9): Block17(\n", " (branch0): BasicConv2d(\n", " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (branch1): Sequential(\n", " (0): BasicConv2d(\n", " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (1): BasicConv2d(\n", " (conv): Conv2d(128, 128, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (2): BasicConv2d(\n", " (conv): Conv2d(128, 128, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)\n", " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " )\n", " (conv2d): Conv2d(256, 896, kernel_size=(1, 1), stride=(1, 1))\n", " (relu): ReLU()\n", " )\n", " )\n", " (mixed_7a): Mixed_7a(\n", " (branch0): Sequential(\n", " (0): BasicConv2d(\n", " (conv): Conv2d(896, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (1): BasicConv2d(\n", " (conv): Conv2d(256, 384, kernel_size=(3, 3), stride=(2, 2), bias=False)\n", " (bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " )\n", " (branch1): Sequential(\n", " (0): BasicConv2d(\n", " (conv): Conv2d(896, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (1): BasicConv2d(\n", " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), bias=False)\n", " (bn): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " )\n", " (branch2): Sequential(\n", " (0): BasicConv2d(\n", " (conv): Conv2d(896, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (1): BasicConv2d(\n", " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (2): BasicConv2d(\n", " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), bias=False)\n", " (bn): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " )\n", " (branch3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " )\n", " (repeat_3): Sequential(\n", " (0): Block8(\n", " (branch0): BasicConv2d(\n", " (conv): Conv2d(1792, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (branch1): Sequential(\n", " (0): BasicConv2d(\n", " (conv): Conv2d(1792, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (1): BasicConv2d(\n", " (conv): Conv2d(192, 192, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)\n", " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (2): BasicConv2d(\n", " (conv): Conv2d(192, 192, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)\n", " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " )\n", " (conv2d): Conv2d(384, 1792, kernel_size=(1, 1), stride=(1, 1))\n", " (relu): ReLU()\n", " )\n", " (1): Block8(\n", " (branch0): BasicConv2d(\n", " (conv): Conv2d(1792, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (branch1): Sequential(\n", " (0): BasicConv2d(\n", " (conv): Conv2d(1792, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (1): BasicConv2d(\n", " (conv): Conv2d(192, 192, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)\n", " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (2): BasicConv2d(\n", " (conv): Conv2d(192, 192, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)\n", " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " )\n", " (conv2d): Conv2d(384, 1792, kernel_size=(1, 1), stride=(1, 1))\n", " (relu): ReLU()\n", " )\n", " (2): Block8(\n", " (branch0): BasicConv2d(\n", " (conv): Conv2d(1792, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (branch1): Sequential(\n", " (0): BasicConv2d(\n", " (conv): Conv2d(1792, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (1): BasicConv2d(\n", " (conv): Conv2d(192, 192, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)\n", " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (2): BasicConv2d(\n", " (conv): Conv2d(192, 192, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)\n", " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " )\n", " (conv2d): Conv2d(384, 1792, kernel_size=(1, 1), stride=(1, 1))\n", " (relu): ReLU()\n", " )\n", " (3): Block8(\n", " (branch0): BasicConv2d(\n", " (conv): Conv2d(1792, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (branch1): Sequential(\n", " (0): BasicConv2d(\n", " (conv): Conv2d(1792, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (1): BasicConv2d(\n", " (conv): Conv2d(192, 192, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)\n", " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (2): BasicConv2d(\n", " (conv): Conv2d(192, 192, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)\n", " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " )\n", " (conv2d): Conv2d(384, 1792, kernel_size=(1, 1), stride=(1, 1))\n", " (relu): ReLU()\n", " )\n", " (4): Block8(\n", " (branch0): BasicConv2d(\n", " (conv): Conv2d(1792, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (branch1): Sequential(\n", " (0): BasicConv2d(\n", " (conv): Conv2d(1792, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (1): BasicConv2d(\n", " (conv): Conv2d(192, 192, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)\n", " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (2): BasicConv2d(\n", " (conv): Conv2d(192, 192, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)\n", " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " )\n", " (conv2d): Conv2d(384, 1792, kernel_size=(1, 1), stride=(1, 1))\n", " (relu): ReLU()\n", " )\n", " )\n", " (block8): Block8(\n", " (branch0): BasicConv2d(\n", " (conv): Conv2d(1792, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (branch1): Sequential(\n", " (0): BasicConv2d(\n", " (conv): Conv2d(1792, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (1): BasicConv2d(\n", " (conv): Conv2d(192, 192, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)\n", " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " (2): BasicConv2d(\n", " (conv): Conv2d(192, 192, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)\n", " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU()\n", " )\n", " )\n", " (conv2d): Conv2d(384, 1792, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (avgpool_1a): AdaptiveAvgPool2d(output_size=1)\n", " (dropout): Dropout(p=0.6, inplace=False)\n", " (last_linear): Linear(in_features=1792, out_features=512, bias=False)\n", " (last_bn): BatchNorm1d(512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", " (logits): Linear(in_features=512, out_features=1, bias=True)\n", ")" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = InceptionResnetV1(\n", " pretrained=\"vggface2\",\n", " classify=True,\n", " num_classes=1,\n", " device=DEVICE\n", ")\n", "\n", "checkpoint = torch.load(\"resnetinceptionv1_epoch_32.pth\", map_location=torch.device('cpu'))\n", "model.load_state_dict(checkpoint['model_state_dict'])\n", "model.to(DEVICE)\n", "model.eval()" ] }, { "cell_type": "markdown", "id": "a499194a", "metadata": {}, "source": [ "# Model Inference " ] }, { "cell_type": "code", "execution_count": 8, "id": "376e6cd6", "metadata": {}, "outputs": [], "source": [ "def predict(input_image:Image.Image):\n", " \"\"\"Predict the label of the input_image\"\"\"\n", " face = mtcnn(input_image)\n", " if face is None:\n", " raise Exception('No face detected')\n", " face = face.unsqueeze(0) # add the batch dimension\n", " face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)\n", " \n", " # convert the face into a numpy array to be able to plot it\n", " prev_face = face.squeeze(0).permute(1, 2, 0).cpu().detach().int().numpy()\n", " prev_face = prev_face.astype('uint8')\n", "\n", " face = face.to(DEVICE)\n", " face = face.to(torch.float32)\n", " face = face / 255.0\n", " face_image_to_plot = face.squeeze(0).permute(1, 2, 0).cpu().detach().int().numpy()\n", "\n", " target_layers=[model.block8.branch1[-1]]\n", " use_cuda = True if torch.cuda.is_available() else False\n", " cam = GradCAM(model=model, target_layers=target_layers, use_cuda=use_cuda)\n", " targets = [ClassifierOutputTarget(0)]\n", "\n", " grayscale_cam = cam(input_tensor=face, targets=targets, eigen_smooth=True)\n", " grayscale_cam = grayscale_cam[0, :]\n", " visualization = show_cam_on_image(face_image_to_plot, grayscale_cam, use_rgb=True)\n", " face_with_mask = cv2.addWeighted(prev_face, 1, visualization, 0.5, 0)\n", "\n", " with torch.no_grad():\n", " output = torch.sigmoid(model(face).squeeze(0))\n", " prediction = \"real\" if output.item() < 0.5 else \"fake\"\n", " \n", " real_prediction = 1 - output.item()\n", " fake_prediction = output.item()\n", " \n", " confidences = {\n", " 'real': real_prediction,\n", " 'fake': fake_prediction\n", " }\n", " return confidences, face_with_mask\n" ] }, { "cell_type": "markdown", "id": "14f47b5a", "metadata": {}, "source": [ "# Gradio Interface" ] }, { "cell_type": "code", "execution_count": 9, "id": "d62177b5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7860\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "