{ "cells": [ { "cell_type": "markdown", "id": "1c073d83-8e73-407a-a669-3a837a90aac6", "metadata": {}, "source": [ "# Compute DINO features" ] }, { "cell_type": "code", "execution_count": 2, "id": "901240c3-5111-4733-97ee-69891e4e7184", "metadata": {}, "outputs": [], "source": [ "import argparse\n", "import math\n", "import os\n", "\n", "import torch\n", "import torchpq\n", "from omegaconf import OmegaConf\n", "from torch.utils.data import DataLoader\n", "from sklearn.decomposition import PCA\n", "from torchvision.transforms import transforms\n", "from tqdm import tqdm\n", "from transformers.utils import constants\n", "\n", "from dreamcreature.dino import DINO\n", "from dreamcreature.dataset import ImageDataset\n", "\n", "MEAN = constants.IMAGENET_DEFAULT_MEAN\n", "STD = constants.IMAGENET_DEFAULT_STD" ] }, { "cell_type": "code", "execution_count": 4, "id": "2f83c248-c3d5-4fe2-a111-ecdeda648214", "metadata": {}, "outputs": [], "source": [ "dataset_name = 'cub200_2011'\n", "# dataset_name = 'dogs'\n", "\n", "rootdir = f'data/{dataset_name}'\n", "resize = 256\n", "crop = 224\n", "\n", "dataset = ImageDataset(rootdir,\n", " 'train.txt',\n", " transform=transforms.Compose([\n", " transforms.Resize(resize, interpolation=transforms.InterpolationMode.BICUBIC),\n", " transforms.CenterCrop(crop),\n", " transforms.ToTensor(),\n", " transforms.Normalize(MEAN, STD)\n", " ]))" ] }, { "cell_type": "code", "execution_count": null, "id": "946a88e5-b368-4cc3-a625-fd952650036b", "metadata": {}, "outputs": [], "source": [ "dataloader = DataLoader(dataset, 32, shuffle=False, drop_last=False, num_workers=4)\n", "model = DINO()\n", "model.eval()\n", "\n", "device = torch.device('cuda')\n", "model = model.to(device)" ] }, { "cell_type": "code", "execution_count": null, "id": "f526a02a-745d-43df-94af-1a50ed438fda", "metadata": {}, "outputs": [], "source": [ "os.makedirs(config.rootdir + '/dinov2', exist_ok=True)\n", "\n", "image_feats = []\n", "with tqdm(dataloader, bar_format='{l_bar}{bar:10}{r_bar}') as tepoch:\n", " for i, (image, label, index) in enumerate(tepoch):\n", " image = image.to(device)\n", "\n", " with torch.no_grad():\n", " output = model.get_feat_maps(image) # (B, C, H, W)\n", "\n", " B, C, H, W = output.size()\n", " output = output.reshape(B, C, H * W)\n", " image_feats.append(output.cpu())\n", "\n", "image_feats = torch.cat(image_feats, dim=0) # (N, C, H*W)\n", "torch.save(image_feats, rootdir + '/dinov2_image_feats.pth')" ] }, { "cell_type": "markdown", "id": "50f9ed32-5231-47d0-864e-cfbcd8b6d732", "metadata": {}, "source": [ "# Train Kmeans Segmentation" ] }, { "cell_type": "code", "execution_count": null, "id": "da8b9bde-2c1a-40d5-a32c-dda98334fe17", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import random\n", "import numpy as np\n", "\n", "dataset_name = 'cub200_2011'\n", "# dataset_name = 'dogs'\n", "\n", "sd = torch.load(f'data/{dataset_name}/dinov2_image_feats.pth', map_location='cpu')\n", "sd.size()" ] }, { "cell_type": "code", "execution_count": null, "id": "67b0f5ba-e272-4aea-b17d-05a80e2ce025", "metadata": {}, "outputs": [], "source": [ "from dataset import code_to_int, int_to_caption\n", "from dataset import ImageDataset\n", "from torchvision.transforms import transforms\n", "\n", "ds = ImageDataset(f'data/{dataset_name}', transform=transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224)]))\n", "train_lines = open(f'data/{dataset_name}/train.txt').readlines()" ] }, { "cell_type": "code", "execution_count": null, "id": "46c57aa3-de67-4d1a-b9a7-584687e437f5", "metadata": {}, "outputs": [], "source": [ "def set_seed(seed):\n", " random.seed(seed)\n", " np.random.seed(seed)\n", " torch.manual_seed(seed)\n", " torch.cuda.manual_seed(seed)\n", "\n", "set_seed(42)\n", " \n", "n = 100 # use small training sample to avoid OOM\n", "randidx = torch.randperm(len(sd))[:n]\n", "randsd = sd[randidx].permute(0, 2, 1) # (N, HW, C)\n", "randsd.size()" ] }, { "cell_type": "code", "execution_count": null, "id": "fab9e741-ff9a-4b6a-8ee3-41bc1d43cb52", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import torchpq\n", "import torch.nn.functional as F\n", "import matplotlib.pyplot as plt\n", "import random\n", "from sklearn.decomposition import PCA\n", "\n", "set_seed(42)\n", "\n", "fg_kmeans = torchpq.clustering.KMeans(n_clusters=2,\n", " distance=\"cosine\",\n", " verbose=1,\n", " n_redo=5,\n", " max_iter=1000)\n", "fg_labels = fg_kmeans.fit(randsd.reshape(-1, 768).t().contiguous().cuda()).cpu().reshape(n, -1)" ] }, { "cell_type": "code", "execution_count": null, "id": "9a12550b-1c66-48a0-9bd6-a39b928cf57d", "metadata": {}, "outputs": [], "source": [ "torch.unique(fg_labels, return_counts=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "b9a3b9f0-fc73-4e44-bca8-a2a2da16df38", "metadata": {}, "outputs": [], "source": [ "for i in range(100):\n", " plt.subplot(10, 10, i+1)\n", " plt.imshow(fg_labels[i].reshape(16, 16))\n", " plt.axis('off')" ] }, { "cell_type": "code", "execution_count": null, "id": "6f8e6121-74ee-4539-bf61-9d0ba2198ef5", "metadata": {}, "outputs": [], "source": [ "fg_idx = 0 # this have to do manual inspection, based on the visualization above\n", "bg_idx = 1 - fg_idx\n", "\n", "randsd_bgnorm = []\n", "randsd_nobg = []\n", "randsd_bgmean = []\n", "\n", "for i in range(n):\n", " bgnorm_mean = randsd[i][fg_labels[i] == bg_idx].mean(dim=0, keepdim=True)\n", " \n", " if fg_idx == 0:\n", " bg_mask = fg_labels[i]\n", " else:\n", " bg_mask = 1 - fg_labels[i]\n", " \n", " bg_mask = bg_mask.unsqueeze(1)\n", " bgnorm = (randsd[i] * (1 - bg_mask)) + (bgnorm_mean * bg_mask)\n", " \n", " randsd_bgnorm.append(bgnorm)\n", " randsd_nobg.append(randsd[i] * (1 - bg_mask) + (-1 * bg_mask))\n", " randsd_bgmean.append(bgnorm_mean)\n", " \n", "randsd_bgnorm = torch.stack(randsd_bgnorm, dim=0)\n", "randsd_nobg = torch.stack(randsd_nobg, dim=0)\n", "randsd_bgmean = torch.cat(randsd_bgmean, dim=0)" ] }, { "cell_type": "code", "execution_count": null, "id": "f90785f8-d1e4-4dc9-9e6d-c242058639a5", "metadata": {}, "outputs": [], "source": [ "set_seed(42)\n", "M = 8\n", "\n", "coarse_kmeans = torchpq.clustering.KMeans(n_clusters=M,\n", " distance=\"cosine\",\n", " verbose=1,\n", " n_redo=5,\n", " max_iter=1000)\n", "coarse_labels = coarse_kmeans.fit(randsd_nobg.reshape(-1, 768).t().contiguous().cuda()).cpu().reshape(n, -1)" ] }, { "cell_type": "code", "execution_count": null, "id": "d4930216-de8a-420b-b7cd-45260299b7b9", "metadata": {}, "outputs": [], "source": [ "for i in range(100):\n", " plt.subplot(10, 10, i+1)\n", " plt.imshow(coarse_labels[i].reshape(16, 16))\n", " plt.axis('off')" ] }, { "cell_type": "code", "execution_count": null, "id": "d5e72d5e-df24-485d-85aa-de17ba381d84", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "disp = coarse_labels[0].reshape(16, 16)\n", "\n", "plt.imshow(disp)\n", "plt.axis('off')" ] }, { "cell_type": "code", "execution_count": null, "id": "b508e11d-082a-40a6-83db-4d306c0f9f00", "metadata": {}, "outputs": [], "source": [ "torch.unique(coarse_labels, return_counts=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "8f4bf418-2297-4dc5-8bdd-15af7cf44c7a", "metadata": {}, "outputs": [], "source": [ "sd_bgnorm = []\n", "sd_nobg = []\n", "sd_bgmean = []\n", "\n", "inp = sd.permute(0, 2, 1)\n", "N = inp.size(0)\n", "\n", "sd_fg_labels = []\n", "bs = 1000\n", "for bidx in range(N // bs + 1):\n", " if bidx * bs >= N:\n", " break\n", " \n", " start_bidx = bidx*bs\n", " end_bidx = min((bidx+1)*bs, N)\n", " \n", " sd_fg_labels.append(fg_kmeans.predict(inp[start_bidx:end_bidx].reshape(-1, 768).t().contiguous().cuda()).cpu().reshape(end_bidx - start_bidx, -1))\n", " \n", "sd_fg_labels = torch.cat(sd_fg_labels, dim=0)\n", "\n", "for i in range(N):\n", " bgnorm_mean = inp[i][sd_fg_labels[i] == bg_idx].mean(dim=0, keepdim=True)\n", " \n", " if fg_idx == 0:\n", " bg_mask = sd_fg_labels[i]\n", " else:\n", " bg_mask = 1 - sd_fg_labels[i]\n", " \n", " bg_mask = bg_mask.unsqueeze(1)\n", " bgnorm = (inp[i] * (1 - bg_mask)) + (bgnorm_mean * bg_mask)\n", " \n", " sd_bgnorm.append(bgnorm)\n", " sd_nobg.append(inp[i] * (1 - bg_mask) + (-1 * bg_mask))\n", " sd_bgmean.append(bgnorm_mean)\n", " print(i, end='\\r')\n", " \n", "sd_bgnorm = torch.stack(sd_bgnorm, dim=0)\n", "sd_nobg = torch.stack(sd_nobg, dim=0)\n", "sd_bgmean = torch.cat(sd_bgmean, dim=0)" ] }, { "cell_type": "code", "execution_count": null, "id": "d8026046-2ac5-4d62-9d18-4cbf7e2ebdb0", "metadata": {}, "outputs": [], "source": [ "sd_coarse_labels = []\n", "bs = 1000\n", "for bidx in range(N // bs + 1):\n", " if bidx * bs >= N:\n", " break\n", " \n", " start_bidx = bidx*bs\n", " end_bidx = min((bidx+1)*bs, N)\n", " \n", " sd_coarse_labels.append(coarse_kmeans.predict(sd_nobg[start_bidx:end_bidx].reshape(-1, 768).t().contiguous().cuda()).cpu().reshape(end_bidx - start_bidx, -1))\n", " \n", "sd_coarse_labels = torch.cat(sd_coarse_labels, dim=0)" ] }, { "cell_type": "code", "execution_count": null, "id": "80a46c2d-6cd4-4344-b56f-4ff1fa7235d3", "metadata": {}, "outputs": [], "source": [ "for i in range(100):\n", " plt.subplot(10, 10, i+1)\n", " coarse_mask = sd_coarse_labels[i].reshape(16, 16)\n", " plt.imshow(coarse_mask)\n", " plt.axis('off')" ] }, { "cell_type": "code", "execution_count": null, "id": "e41a997b-cccc-40c2-a983-c6092ffe69be", "metadata": {}, "outputs": [], "source": [ "torch.save(sd_coarse_labels.reshape(N, 16, 16).long().cpu(), f'data/{dataset_name}/coarse_mask_m8.pth')" ] }, { "cell_type": "code", "execution_count": null, "id": "cc918d47-35ca-4177-b63b-a12f4f4a3d5a", "metadata": {}, "outputs": [], "source": [ "torch.unique(sd_coarse_labels, return_counts=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "9d4beb2d-e385-4ecc-9f90-287d7bc13c0d", "metadata": {}, "outputs": [], "source": [ "from tqdm.auto import tqdm\n", "\n", "sd_fgmean = []\n", "\n", "inp = sd.permute(0, 2, 1)\n", "N = inp.size(0)\n", "M = 8\n", "\n", "for i in tqdm(range(N)):\n", " mean_feats = []\n", " for m in range(M):\n", " coarse_mask = sd_coarse_labels[i] == m\n", " if coarse_mask.sum().item() == 0:\n", " m_mean_feats = torch.zeros(1, 768)\n", " else:\n", " m_mean_feats = inp[i][coarse_mask].mean(dim=0, keepdim=True)\n", " \n", " mean_feats.append(m_mean_feats)\n", " \n", " mean_feats = torch.cat(mean_feats, dim=0)\n", " sd_fgmean.append(mean_feats)\n", " print(i, end='\\r')\n", " \n", "sd_fgmean = torch.stack(sd_fgmean, dim=0)" ] }, { "cell_type": "code", "execution_count": null, "id": "5034a91a-ae6d-468b-a738-f9ec0d019d72", "metadata": {}, "outputs": [], "source": [ "N = inp.size(0)\n", "M = 8\n", "K = 256\n", "bgm = {'cub200_2011': 7, 'dogs': 1}[dataset_name] # 7 for cub, 1 for dog, this means which index is background\n", "\n", "final_labels = torch.ones(N, M) * K\n", "\n", "set_seed(42)\n", "\n", "zero_mean_idxs = []\n", "fine_feats = []\n", "fine_kmeans_trained = []\n", "\n", "for m in range(M):\n", " fine_kmeans = torchpq.clustering.KMeans(n_clusters=K,\n", " distance=\"cosine\",\n", " verbose=1,\n", " n_redo=5,\n", " max_iter=1000)\n", " \n", " if m == bgm:\n", " fine_labels = fine_kmeans.fit(sd_bgmean.t().contiguous().cuda()).cpu()\n", " final_labels[:, m] = fine_labels\n", " else:\n", " fine_inp = sd_fgmean[:, m].reshape(-1, 768)\n", " fine_labels = fine_kmeans.fit(fine_inp.t().contiguous().cuda()).cpu()\n", " \n", " final_labels[:, m] = fine_labels\n", " \n", " fine_kmeans_trained.append(fine_kmeans)\n", " \n", " fine_feats.append(fine_kmeans.centroizds.cpu().t()[fine_labels])\n", " \n", " print('zero mean', torch.arange(K)[fine_kmeans.centroids.t().sum(dim=-1).cpu() == 0].tolist())\n", " zero_mean_idxs.append(torch.arange(K)[fine_kmeans.centroids.t().sum(dim=-1).cpu() == 0].tolist())\n", " \n", "fine_feats = torch.cat(fine_feats, dim=1)\n", "print(fine_feats.size())" ] }, { "cell_type": "code", "execution_count": null, "id": "450bc5c1-cc41-44b3-ab6a-c0c65eb1210a", "metadata": {}, "outputs": [], "source": [ "torch.save({\n", " 'foreground_background': fg_kmeans,\n", " 'coarse_kmeans': coarse_kmeans,\n", " 'fine_kmeans': fine_kmeans_trained,\n", "}, f'data/{dataset_name}/pretrained_kmeans.pth')" ] }, { "cell_type": "code", "execution_count": null, "id": "9ad2f095-f99d-489e-b645-8933b6f66372", "metadata": {}, "outputs": [], "source": [ "from tqdm.auto import tqdm\n", "\n", "final_code_captions = []\n", "counts = [[0 for _ in range(K)] for _ in range(M)]\n", "\n", "for i in tqdm(range(N)):\n", " m_labels = final_labels[i] # M\n", " \n", " line = []\n", " for m in range(M):\n", " k = m_labels[m].long().item()\n", " \n", " if k not in zero_mean_idxs[m]:\n", " line.append(f'{m}:{k}')\n", " counts[m][k] += 1\n", " \n", " assert len(line) != 0, f'error at {i}'\n", " final_code_captions.append(' '.join(line))" ] }, { "cell_type": "code", "execution_count": null, "id": "1be7233d-d9fe-4e0a-bda9-93fc45f8e982", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "for m in range(M):\n", " if max(counts[m]) == 0:\n", " continue\n", " \n", " plt.scatter(range(K), counts[m])\n", " print(m, min(counts[m]), max(counts[m]), np.mean(counts[m]))" ] }, { "cell_type": "code", "execution_count": null, "id": "e344ab48-fba6-4c17-a88c-e2528c0ac5cf", "metadata": {}, "outputs": [], "source": [ "with open(f'data/{dataset_name}/train_caps_better_m{M}_k{K}.txt', 'w+') as f:\n", " for line in final_code_captions:\n", " f.write(line + '\\n')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.8" } }, "nbformat": 4, "nbformat_minor": 5 }