# Compute DINO features

In [2]:
import argparse
import math
import os

import torch
import torchpq
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from sklearn.decomposition import PCA
from torchvision.transforms import transforms
from tqdm import tqdm
from transformers.utils import constants

from dreamcreature.dino import DINO
from dreamcreature.dataset import ImageDataset

# Per-channel normalisation statistics: the ImageNet default mean/std
# constants exposed by `transformers.utils.constants`.
MEAN, STD = constants.IMAGENET_DEFAULT_MEAN, constants.IMAGENET_DEFAULT_STD

In [4]:
# Dataset selection: swap which assignment is active to process another dataset.
dataset_name = 'cub200_2011'
# dataset_name = 'dogs'

rootdir = f'data/{dataset_name}'
resize = 256
crop = 224

# Preprocessing pipeline: bicubic resize -> centre crop -> tensor -> normalise
# with the MEAN / STD constants.
preprocess = transforms.Compose([
    transforms.Resize(resize, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(crop),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])

# NOTE(review): 'train.txt' is presumably the split/listing file read by
# ImageDataset relative to `rootdir` -- confirm against its implementation.
dataset = ImageDataset(rootdir, 'train.txt', transform=preprocess)

In [None]:
# Feature extraction runs on the GPU.
device = torch.device('cuda')

# Deterministic pass over the dataset: no shuffling, and the final partial
# batch is kept so that every image gets a feature map.
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=False,
    drop_last=False,
    num_workers=4,
)

# Build the DINO wrapper, switch it to evaluation mode and move it to the GPU.
model = DINO()
model.eval()
model = model.to(device)

In [None]:
# Extract one DINO feature map per training image and cache all of them in a
# single tensor on disk.

# The output directory lives under `rootdir` (the dataset root defined in the
# configuration cell); no `config` object exists in this notebook.
os.makedirs(rootdir + '/dinov2', exist_ok=True)

image_feats = []
with tqdm(dataloader, bar_format='{l_bar}{bar:10}{r_bar}') as tepoch:
    # Only the image tensor is needed here; label and index are ignored.
    for image, _label, _index in tepoch:
        image = image.to(device)

        with torch.no_grad():
            output = model.get_feat_maps(image)  # (B, C, H, W)

        # Flatten the spatial grid so each image becomes a (C, H*W) matrix,
        # and move it off the GPU immediately to keep device memory flat.
        B, C, H, W = output.size()
        output = output.reshape(B, C, H * W)
        image_feats.append(output.cpu())

image_feats = torch.cat(image_feats, dim=0)  # (N, C, H*W)
torch.save(image_feats, rootdir + '/dinov2_image_feats.pth')

# Train Kmeans Segmentation

In [None]:
import torch
import random
import numpy as np

dataset_name = 'cub200_2011'
# dataset_name = 'dogs'

# Load the cached DINO features onto the CPU; the last expression displays
# the tensor's shape in the notebook.
feats_path = f'data/{dataset_name}/dinov2_image_feats.pth'
sd = torch.load(feats_path, map_location='cpu')
sd.shape

In [None]:
from dataset import code_to_int, int_to_caption
from dataset import ImageDataset
from torchvision.transforms import transforms

# Un-normalised view of the dataset (resize + centre crop only, no tensor
# conversion).
# NOTE(review): the feature-extraction cell calls
# ImageDataset(rootdir, 'train.txt', transform=...); this call omits the
# 'train.txt' argument and imports from `dataset` rather than
# `dreamcreature.dataset` -- confirm both are intended.
ds = ImageDataset(
    f'data/{dataset_name}',
    transform=transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224)]),
)

# Read the training list through a context manager so the file handle is
# closed deterministically instead of being left to the garbage collector.
with open(f'data/{dataset_name}/train.txt') as f:
    train_lines = f.readlines()

In [None]:
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Args:
        seed: Integer seed passed unchanged to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seeder in seeders:
        seeder(seed)

set_seed(42)

# Cluster on a small random subset of images only, to avoid running out of
# memory.
n = 100
randidx = torch.randperm(sd.size(0))[:n]
# Swap the last two axes so that features are the trailing dimension.
randsd = sd[randidx].transpose(1, 2)  # (N, HW, C)
randsd.shape

In [None]:
import numpy as np
import torchpq
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random
from sklearn.decomposition import PCA

set_seed(42)

# Two-cluster cosine k-means over every patch feature of the sampled images;
# the two clusters are later interpreted as foreground and background.
fg_kmeans = torchpq.clustering.KMeans(
    n_clusters=2,
    distance="cosine",
    verbose=1,
    n_redo=5,
    max_iter=1000,
)

# All patches are stacked into a (768, n * HW) matrix for fitting; the
# returned per-patch labels are regrouped into one row per image.
fg_labels = (
    fg_kmeans
    .fit(randsd.reshape(-1, 768).transpose(0, 1).contiguous().cuda())
    .cpu()
    .reshape(n, -1)
)

In [None]:
# How many sampled patches fell into each of the two clusters.
fg_labels.unique(return_counts=True)

In [None]:
# Preview grid (10 x 10) of the two-cluster masks, one 16x16 patch map per
# sampled image. Slicing caps the preview at 100 images and also keeps the
# loop from indexing past the end when fewer than 100 images were sampled.
for i, patch_labels in enumerate(fg_labels[:100]):
    plt.subplot(10, 10, i + 1)
    plt.imshow(patch_labels.reshape(16, 16))
    plt.axis('off')

In [None]:
# Which of the two k-means clusters is the foreground. This has to be set by
# hand after inspecting the visualisation of `fg_labels`.
fg_idx = 0
bg_idx = 1 - fg_idx

randsd_bgnorm = []  # background patches replaced by the image's mean background feature
randsd_nobg = []    # background patches replaced by the constant -1
randsd_bgmean = []  # one mean background feature per image

for i in range(n):
    feats = randsd[i]               # (HW, C)
    is_bg = fg_labels[i] == bg_idx  # (HW,) boolean background mask

    if is_bg.any():
        bgnorm_mean = feats[is_bg].mean(dim=0, keepdim=True)
    else:
        # No background patch in this image: mean() over an empty selection
        # would be NaN, so use a zero vector instead.
        bgnorm_mean = feats.new_zeros(1, feats.size(-1))

    bg_mask = is_bg.unsqueeze(1)    # (HW, 1), broadcast over the feature axis

    # Select rather than blend arithmetically, so foreground features pass
    # through untouched whatever value the background mean has.
    randsd_bgnorm.append(torch.where(bg_mask, bgnorm_mean, feats))
    randsd_nobg.append(feats.masked_fill(bg_mask, -1))
    randsd_bgmean.append(bgnorm_mean)

randsd_bgnorm = torch.stack(randsd_bgnorm, dim=0)
randsd_nobg = torch.stack(randsd_nobg, dim=0)
randsd_bgmean = torch.cat(randsd_bgmean, dim=0)

In [None]:
set_seed(42)
M = 8  # number of coarse clusters

# Cosine k-means with M clusters over the sampled patches whose background
# has been overwritten with -1.
coarse_kmeans = torchpq.clustering.KMeans(
    n_clusters=M,
    distance="cosine",
    verbose=1,
    n_redo=5,
    max_iter=1000,
)

# Fit on a (768, n * HW) matrix and regroup the labels into one row per image.
coarse_labels = (
    coarse_kmeans
    .fit(randsd_nobg.reshape(-1, 768).transpose(0, 1).contiguous().cuda())
    .cpu()
    .reshape(n, -1)
)

In [None]:
# Preview grid (10 x 10) of the coarse cluster maps, one 16x16 patch map per
# sampled image. Slicing caps the preview at 100 images and avoids indexing
# past the end when fewer than 100 images were sampled.
for i, patch_labels in enumerate(coarse_labels[:100]):
    plt.subplot(10, 10, i + 1)
    plt.imshow(patch_labels.reshape(16, 16))
    plt.axis('off')

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

# Close-up of the coarse cluster map of the first sampled image, laid out as
# a 16x16 patch grid.
disp = torch.reshape(coarse_labels[0], (16, 16))

plt.imshow(disp)
plt.axis('off')

In [None]:
# Number of sampled patches assigned to each coarse cluster.
coarse_labels.unique(return_counts=True)

In [None]:
# Apply the fitted foreground/background k-means to the whole dataset, then
# build per-image background statistics and background-masked features.
sd_bgnorm = []  # background patches replaced by the image's mean background feature
sd_nobg = []    # background patches replaced by the constant -1
sd_bgmean = []  # one mean background feature per image

inp = sd.permute(0, 2, 1)  # features become the trailing axis: (N, HW, C)
N = inp.size(0)

# Predict in chunks of `bs` images so that only one chunk is on the GPU at a
# time.
sd_fg_labels = []
bs = 1000
for start_bidx in range(0, N, bs):
    end_bidx = min(start_bidx + bs, N)
    sd_fg_labels.append(
        fg_kmeans
        .predict(inp[start_bidx:end_bidx].reshape(-1, 768).t().contiguous().cuda())
        .cpu()
        .reshape(end_bidx - start_bidx, -1)
    )

sd_fg_labels = torch.cat(sd_fg_labels, dim=0)

for i in range(N):
    feats = inp[i]                     # (HW, C)
    is_bg = sd_fg_labels[i] == bg_idx  # (HW,) boolean background mask

    if is_bg.any():
        bgnorm_mean = feats[is_bg].mean(dim=0, keepdim=True)
    else:
        # No background patch in this image: mean() over an empty selection
        # would be NaN (and would end up in `sd_bgmean`), so use a zero
        # vector instead.
        bgnorm_mean = feats.new_zeros(1, feats.size(-1))

    bg_mask = is_bg.unsqueeze(1)       # (HW, 1), broadcast over the feature axis

    # Select rather than blend arithmetically, so foreground features pass
    # through untouched whatever value the background mean has.
    sd_bgnorm.append(torch.where(bg_mask, bgnorm_mean, feats))
    sd_nobg.append(feats.masked_fill(bg_mask, -1))
    sd_bgmean.append(bgnorm_mean)
    print(i, end='\r')

sd_bgnorm = torch.stack(sd_bgnorm, dim=0)
sd_nobg = torch.stack(sd_nobg, dim=0)
sd_bgmean = torch.cat(sd_bgmean, dim=0)

In [None]:
# Assign every patch of every image to a coarse cluster, working through the
# dataset in chunks of `bs` images.
sd_coarse_labels = []
bs = 1000
for chunk_start in range(0, N, bs):
    chunk_end = min(chunk_start + bs, N)
    chunk_size = chunk_end - chunk_start

    # The chunk is flattened to a (768, chunk_size * HW) matrix for
    # prediction; labels are regrouped into one row per image.
    sd_coarse_labels.append(
        coarse_kmeans
        .predict(sd_nobg[chunk_start:chunk_end].reshape(-1, 768).t().contiguous().cuda())
        .cpu()
        .reshape(chunk_size, -1)
    )

sd_coarse_labels = torch.cat(sd_coarse_labels, dim=0)

In [None]:
# Preview grid (10 x 10) of the coarse cluster maps predicted for the first
# images of the full dataset. Slicing caps the preview at 100 images and
# avoids indexing past the end when the dataset holds fewer than 100.
for i, patch_labels in enumerate(sd_coarse_labels[:100]):
    plt.subplot(10, 10, i + 1)
    coarse_mask = patch_labels.reshape(16, 16)
    plt.imshow(coarse_mask)
    plt.axis('off')

In [None]:
# Persist the coarse cluster maps as an (N, 16, 16) integer tensor.
coarse_mask_path = f'data/{dataset_name}/coarse_mask_m8.pth'
torch.save(sd_coarse_labels.reshape(N, 16, 16).long().cpu(), coarse_mask_path)

In [None]:
# Number of patches per coarse cluster across the whole dataset.
sd_coarse_labels.unique(return_counts=True)

In [None]:
from tqdm.auto import tqdm

# For every image, average the patch features inside each coarse cluster,
# giving one (M, 768) matrix of part descriptors per image.
sd_fgmean = []

inp = sd.transpose(1, 2)  # features become the trailing axis: (N, HW, C)
N = inp.size(0)
M = 8

for i in tqdm(range(N)):
    patch_labels = sd_coarse_labels[i]
    patch_feats = inp[i]

    mean_feats = []
    for m in range(M):
        coarse_mask = patch_labels == m
        if coarse_mask.any():
            m_mean_feats = patch_feats[coarse_mask].mean(dim=0, keepdim=True)
        else:
            # Cluster absent from this image: represent it by a zero vector.
            m_mean_feats = torch.zeros(1, 768)
        mean_feats.append(m_mean_feats)

    mean_feats = torch.cat(mean_feats, dim=0)
    sd_fgmean.append(mean_feats)
    print(i, end='\r')

sd_fgmean = torch.stack(sd_fgmean, dim=0)

In [None]:
# Fit one fine-grained k-means (K clusters) per coarse part and record, for
# every image, the fine cluster id of each part.
N = inp.size(0)
M = 8
K = 256
# Index of the coarse cluster that is the background: 7 for cub, 1 for dog.
bgm = {'cub200_2011': 7, 'dogs': 1}[dataset_name]

# Every column is overwritten in the loop below; K is only the initial value.
final_labels = torch.ones(N, M) * K

set_seed(42)

zero_mean_idxs = []       # per part: fine cluster ids whose centroid sums to zero
fine_feats = []           # per part: the centroid assigned to each image
fine_kmeans_trained = []  # per part: the fitted k-means model

for m in range(M):
    fine_kmeans = torchpq.clustering.KMeans(
        n_clusters=K,
        distance="cosine",
        verbose=1,
        n_redo=5,
        max_iter=1000,
    )

    # The background part is clustered on the per-image mean background
    # feature; every other part on its per-image mean part feature.
    if m == bgm:
        fine_inp = sd_bgmean
    else:
        fine_inp = sd_fgmean[:, m].reshape(-1, 768)

    fine_labels = fine_kmeans.fit(fine_inp.t().contiguous().cuda()).cpu()
    final_labels[:, m] = fine_labels

    fine_kmeans_trained.append(fine_kmeans)

    # Look up the centroid assigned to each image.
    fine_feats.append(fine_kmeans.centroids.cpu().t()[fine_labels])

    # Clusters whose centroid sums to exactly zero are reported and recorded
    # so later steps can skip them.
    zero_mean = torch.arange(K)[fine_kmeans.centroids.t().sum(dim=-1).cpu() == 0].tolist()
    print('zero mean', zero_mean)
    zero_mean_idxs.append(zero_mean)

fine_feats = torch.cat(fine_feats, dim=1)
print(fine_feats.size())

In [None]:
# Bundle the three fitted clustering stages and save them in one checkpoint.
pretrained_kmeans = {
    'foreground_background': fg_kmeans,
    'coarse_kmeans': coarse_kmeans,
    'fine_kmeans': fine_kmeans_trained,
}
torch.save(pretrained_kmeans, f'data/{dataset_name}/pretrained_kmeans.pth')

In [None]:
from tqdm.auto import tqdm

# Turn each image's fine cluster ids into a caption of "part:code" tokens,
# and count how often every (part, code) pair is used.
final_code_captions = []
counts = [[0] * K for _ in range(M)]

for i in tqdm(range(N)):
    codes = final_labels[i].long().tolist()  # M fine cluster ids

    line = []
    for m in range(M):
        k = codes[m]

        # Codes whose centroid summed to zero are left out of the caption.
        if k in zero_mean_idxs[m]:
            continue

        line.append(f'{m}:{k}')
        counts[m][k] += 1

    assert len(line) != 0, f'error at {i}'
    final_code_captions.append(' '.join(line))

In [None]:
import matplotlib.pyplot as plt

# Scatter the usage count of every fine code, one series per part, and print
# min / max / mean usage. Parts with no counted code are skipped.
for m in range(M):
    part_counts = counts[m]
    if max(part_counts) != 0:
        plt.scatter(range(K), part_counts)
        print(m, min(part_counts), max(part_counts), np.mean(part_counts))

In [None]:
# Write one caption per line, in dataset order.
caps_path = f'data/{dataset_name}/train_caps_better_m{M}_k{K}.txt'
with open(caps_path, 'w+') as f:
    f.writelines(caption + '\n' for caption in final_code_captions)