# %% Clone official repo
# Fetch the DIS sources, make DIS/IS-Net the working directory and install
# gdown. Written as plain Python with guards (instead of `!git clone`, `%cd`,
# `!pip install`) so the cell can be executed more than once in the same
# runtime without cloning into an existing directory or changing into a
# relative path that no longer resolves.
import os
import subprocess
import sys

REPO_URL = "https://github.com/xuebinqin/DIS"
ISNET_DIR = os.path.join("DIS", "IS-Net")

# When the cell is re-run the kernel is already inside DIS/IS-Net; in that
# case there is nothing left to clone or to change into.
if not os.path.normpath(os.getcwd()).endswith(os.path.normpath(ISNET_DIR)):
    if not os.path.isdir("DIS"):
        subprocess.run(["git", "clone", REPO_URL], check=True)
    os.chdir(ISNET_DIR)

# Install with the interpreter that runs this kernel so that the later
# `import gdown` sees the package. `check=True` turns a failed install into
# an exception instead of a silently ignored shell error.
subprocess.run([sys.executable, "-m", "pip", "install", "gdown"], check=True)
# %% Imports
import numpy as np
from PIL import Image
import torch
from torch.autograd import Variable
from torchvision import transforms
import torch.nn.functional as F
import gdown
import os

import requests
import matplotlib.pyplot as plt
from io import BytesIO

# project imports
from data_loader_cache import normalize, im_reader, im_preprocess
from models import *

# %% Download official weights
drive_link = "https://drive.google.com/uc?id=1XHIzgTzY5BQHw140EDIgwIb53K659ENH"

# Specify the local path and filename
local_path = "/content/DIS/IS-Net/saved_models/isnet.pth"

# Create the target directory here (replaces the separate `!mkdir` cell);
# `exist_ok=True` keeps the cell re-runnable.
os.makedirs(os.path.dirname(local_path), exist_ok=True)

# Download the file only when it is not on disk yet, and fail loudly when
# gdown reports that nothing was fetched.
if not os.path.isfile(local_path):
    if gdown.download(drive_link, local_path, quiet=False) is None:
        raise RuntimeError("Could not download model weights from " + drive_link)
local_path

# %% Helpers
device = 'cuda' if torch.cuda.is_available() else 'cpu'


class GOSNormalize(object):
    '''
    Callable transform that passes an image to the project's `normalize`
    helper together with a fixed mean and std.
    '''
    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        # Store private list copies so instances never share (or mutate)
        # the default values or the caller's sequences.
        self.mean = list(mean)
        self.std = list(std)

    def __call__(self, image):
        return normalize(image, self.mean, self.std)


transform = transforms.Compose([GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])])


def load_image(im_path, hypar, timeout=30):
    '''
    Read an image and turn it into a one-element batch for the network.

    im_path : local path, or an http(s) URL that is downloaded into memory
              first; anything that is not a string is handed to `im_reader`
              unchanged.
    hypar   : parameter dict; only hypar["cache_size"] is read here.
    timeout : seconds to wait for the HTTP download (URL inputs only).

    Returns (image_batch, shape_batch): the transformed image and the shape
    reported by `im_preprocess`, each with a leading batch axis added.
    Raises requests.HTTPError when a URL download returns an error status.
    '''
    if isinstance(im_path, str) and im_path.startswith("http"):
        response = requests.get(im_path, timeout=timeout)
        # Without this an HTML error page would be passed on as image bytes.
        response.raise_for_status()
        im_path = BytesIO(response.content)

    im = im_reader(im_path)
    # NOTE(review): im_shp is presumably the original (height, width);
    # predict() uses it as the size to resize the mask to - confirm in
    # data_loader_cache.im_preprocess.
    im, im_shp = im_preprocess(im, hypar["cache_size"])
    im = torch.divide(im, 255.0)
    shape = torch.from_numpy(np.array(im_shp))
    return transform(im).unsqueeze(0), shape.unsqueeze(0)  # make a batch of image, shape
def build_model(hypar, device):
    '''
    Prepare the network stored in hypar["model"] for inference.

    Reads hypar["model_digit"] ("half" converts the weights to float16 while
    keeping BatchNorm2d layers in float32), hypar["restore_model"] and
    hypar["model_path"] (checkpoint file name and directory; an empty file
    name skips loading). Returns the same module object, moved to `device`
    and switched to eval mode.
    '''
    net = hypar["model"]

    # convert to half precision
    if hypar["model_digit"] == "half":
        net.half()
        for layer in net.modules():
            # `torch.nn` is used explicitly: no visible import binds a bare
            # `nn` name in this notebook.
            if isinstance(layer, torch.nn.BatchNorm2d):
                layer.float()

    net.to(device)

    if hypar["restore_model"] != "":
        weights_file = os.path.join(hypar["model_path"], hypar["restore_model"])
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a source you trust.
        net.load_state_dict(torch.load(weights_file, map_location=device))

    net.eval()
    return net


def predict(net, inputs_val, shapes_val, hypar, device):
    '''
    Given an image batch, predict the mask.

    net        : module returned by build_model; net(x)[0] must be a sequence
                 whose first element is a B x 1 x H x W prediction.
    inputs_val : input batch tensor (only the first sample is used).
    shapes_val : per-sample sizes; shapes_val[0][0] and shapes_val[0][1] give
                 the output height and width.
    hypar      : parameter dict; hypar["model_digit"] selects float32
                 ("full") or float16 inputs.
    device     : device string or torch.device the network lives on.

    Returns a uint8 numpy array of the requested size, min-max scaled to
    0..255. A prediction without any contrast yields an all-zero mask.
    '''
    net.eval()

    if hypar["model_digit"] == "full":
        inputs_val = inputs_val.float()
    else:
        inputs_val = inputs_val.half()

    # Inference only: do not record an autograd graph.
    with torch.no_grad():
        ds_val = net(inputs_val.to(device))[0]  # list of side outputs

        # first output, first sample of the batch -> 1 x H x W
        pred_val = ds_val[0][0, :, :, :]

        ## recover the prediction spatial size to the original image size
        target_size = (int(shapes_val[0][0]), int(shapes_val[0][1]))
        pred_val = F.interpolate(
            # float32 so the resize and the scaling below also work when the
            # network itself runs in half precision
            torch.unsqueeze(pred_val, 0).float(),
            size=target_size,
            mode='bilinear',
            align_corners=False,  # what the deprecated F.upsample defaulted to
        )
        pred_val = torch.squeeze(pred_val)

        ma = torch.max(pred_val)
        mi = torch.min(pred_val)
        value_range = ma - mi
        if value_range > 0:
            pred_val = (pred_val - mi) / value_range  # max = 1
        else:
            # Constant prediction: dividing by zero would fill the mask with
            # NaN, which has no defined uint8 value.
            pred_val = torch.zeros_like(pred_val)

    # Also matches torch.device objects and strings such as 'cuda:0'.
    if str(device).startswith('cuda'):
        torch.cuda.empty_cache()
    return (pred_val.detach().cpu().numpy() * 255).astype(np.uint8)  # it is the mask we need
# %% Set Parameters
hypar = {}  # parameters for inference

hypar["model_path"] = "./saved_models"  ## load trained weights from this path
hypar["restore_model"] = "isnet.pth"  ## name of the to-be-loaded weights
hypar["interm_sup"] = False  ## indicate if activate intermediate feature supervision

## choose floating point accuracy --
hypar["model_digit"] = "full"  ## indicates "half" or "full" accuracy of float number
hypar["seed"] = 0

hypar["cache_size"] = [1024, 1024]  ## cached input spatial resolution, can be configured into different size

## data augmentation parameters ---
hypar["input_size"] = [1024, 1024]  ## model input spatial size, usually the same value as hypar["cache_size"], which means we don't further resize the images
hypar["crop_size"] = [1024, 1024]  ## random crop size from the input, it is usually set as smaller than hypar["cache_size"], e.g., [920,920] for data augmentation

hypar["model"] = ISNetDIS()

# %% Build Model
net = build_model(hypar, device)

# %% Predict Mask
gsheetid = "1n9kk7IHyBzkw5e08wpjjt-Ry5aE_thqGrJ97rMeN-K4"
sheet_name = "sarvm"

# %%
# CSV export endpoint of the sheet identified above.
gsheet_url = "https://docs.google.com/spreadsheets/d/{}/gviz/tq?tqx=out:csv&sheet={}".format(gsheetid, sheet_name)

# %%
gsheet_url

# %%
import pandas as pd
df = pd.read_csv(gsheet_url)

# %%
# Take the image reference from the last row of the sheet. Check first that
# the export really has rows and an "Image" column, so a wrong sheet name or
# an empty sheet is reported clearly instead of as a bare IndexError/KeyError.
if df.empty or "Image" not in df.columns:
    raise ValueError("Sheet '{}' has no rows or no 'Image' column".format(sheet_name))
image_path = df.iloc[-1]['Image']

# %%
# NOTE: these two names are the ones the weights download used earlier; from
# here on they refer to the input image.
drive_link = "https://drive.google.com/uc?id=132iFIWDU6NSzZy4oEUurGplQ2Z3tGGKb"

# Specify the local path and filename
local_path = "/content/DIS/IS-Net/saved_models/input2.jpg"

# Download the file and stop here when gdown reports that nothing was fetched.
if gdown.download(drive_link, local_path, quiet=False) is None:
    raise RuntimeError("Could not download input image from " + drive_link)
local_path
# %% Predict the mask and show it next to the input
import cv2

image_path = "/content/DIS/IS-Net/saved_models/input2.jpg"

image_tensor, orig_size = load_image(image_path, hypar)
mask = predict(net, image_tensor, orig_size, hypar, device)

# Read the pixels inside a `with` block so the file handle is closed again.
with Image.open(image_path) as pil_image:
    original_rgb = np.array(pil_image.convert("RGB"))

# Two panels: the input on the left, the predicted mask on the right.
f, ax = plt.subplots(1, 2, figsize=(35, 20))

ax[0].imshow(original_rgb)
ax[0].set_title("Original Image")

ax[1].imshow(mask, cmap='gray')
ax[1].set_title("Mask")

plt.show()

# %% Cut the foreground out and save it with a transparent background
image = cv2.imread(image_path)
if image is None:
    # cv2.imread signals an unreadable file by returning None.
    raise FileNotFoundError("Could not read image: " + image_path)

h, w = image.shape[:2]
if mask.shape != (h, w):
    raise ValueError("Mask shape {} does not match image size {}".format(mask.shape, (h, w)))

# Keep the pixels where the mask is non-zero, black elsewhere.
new_image = cv2.bitwise_and(image, image, mask=mask)

# Append the mask as a fourth (alpha) channel.
transparent_bg = np.dstack((new_image, mask))

# Save the new image with a transparent background
output_path = "/content/output.png"
if not cv2.imwrite(output_path, transparent_bg):
    raise OSError("Could not write " + output_path)
output_path
"True" ] }, "metadata": {}, "execution_count": 18 } ] }, { "cell_type": "code", "source": [], "metadata": { "id": "rJoaRMr8T5um" }, "id": "rJoaRMr8T5um", "execution_count": null, "outputs": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.11" }, "colab": { "provenance": [] } }, "nbformat": 4, "nbformat_minor": 5 }