diff --git a/featup/__init__.py b/featup/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dba658fd603bac1e6b5fce7d7189a508c6adc50d --- /dev/null +++ b/featup/__init__.py @@ -0,0 +1 @@ +from featup.upsamplers import JBULearnedRange \ No newline at end of file diff --git a/featup/adaptive_conv_cuda/__init__.py b/featup/adaptive_conv_cuda/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/featup/adaptive_conv_cuda/adaptive_conv.cpp b/featup/adaptive_conv_cuda/adaptive_conv.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ce3e6e1455d97f75009183b94ff60ac59ed6baf9 --- /dev/null +++ b/featup/adaptive_conv_cuda/adaptive_conv.cpp @@ -0,0 +1,142 @@ +#include + +#include +using torch::Tensor; + +Tensor adaptive_conv_forward(Tensor input, Tensor filters) { + + assert(input.dtype() == filters.dtype()); + + auto B = input.sizes()[0]; + auto C_in = input.sizes()[1]; + auto H_in = input.sizes()[2]; + auto W_in = input.sizes()[3]; + + assert(filters.sizes()[0] == B); + auto H_out = filters.sizes()[1]; + auto W_out = filters.sizes()[2]; + auto I = filters.sizes()[3]; + auto J = filters.sizes()[4]; + + assert(I == J); + assert(H_out + I - 1 == H_in); + assert(W_out + J - 1 == W_in); + + auto out = torch::zeros({ B, C_in, H_out, W_out }, input.dtype()); + + // output stationary + for (uint32_t b = 0; b < B; b++) { + for (uint32_t c = 0; c < C_in; c++) { + for (uint32_t h = 0; h < H_out; h++) { + for (uint32_t w = 0; w < W_out; w++) { + // produce output pixel b, h, w, c + for (uint32_t i = 0; i < I; i++) { + for (uint32_t j = 0; j < J; j++) { + auto weight = filters[b][h][w][i][j]; + assert(h+i < H_in); + assert(w+j < W_in); + auto input_val = input[b][c][h+i][w+j]; + out[b][c][h][w] += weight * input_val; + } + } + } + } + } + } + return out; +} + +Tensor adaptive_conv_grad_input(Tensor grad_output, Tensor filters) { + + auto B = grad_output.sizes()[0]; + auto C = grad_output.sizes()[1]; + auto H_out = grad_output.sizes()[2]; + auto W_out = grad_output.sizes()[3]; + + assert(filters.sizes()[0] == B); + assert(filters.sizes()[1] == H_out); + assert(filters.sizes()[2] == W_out); + auto I = filters.sizes()[3]; + auto J = filters.sizes()[4]; + assert(I == J); + + auto H_in = H_out + I - 1; + auto W_in = W_out + J - 1; + + assert(grad_output.dtype() == filters.dtype()); + + auto out = torch::zeros({ B, C, H_in, W_in }, grad_output.dtype()); + + for (int32_t b = 0; b < B; b++) { + for (int32_t c = 0; c < C; c++) { + for (int32_t h = 0; h < H_in; h++) { + for (int32_t w = 0; w < W_in; w++) { + for (int32_t i = 0; i < I; i++) { + for (int32_t j = 0; j < J; j++) { + + int32_t h_out = h - i; + int32_t w_out = w - j; + + if ((h_out >= 0) && (w_out >= 0) && (h_out < H_out) && (w_out < W_out)) { + auto grad = grad_output[b][c][h_out][w_out]; + auto weight = filters[b][h_out][w_out][i][j]; + + out[b][c][h][w] += grad * weight; + } + } + } + } + } + } + } + return out; +} + +Tensor adaptive_conv_grad_filters(Tensor grad_output, Tensor input) { + + auto B = grad_output.sizes()[0]; + auto C = grad_output.sizes()[1]; + auto H_out = grad_output.sizes()[2]; + auto W_out = grad_output.sizes()[3]; + + assert(input.sizes()[0] == B); + assert(input.sizes()[1] == C); + auto H_in = input.sizes()[2]; + auto W_in = input.sizes()[3]; + + assert(H_in > H_out); + assert(W_in > W_out); + + auto I = W_in - W_out + 1; + auto J = H_in - H_out + 1; + + assert(grad_output.dtype() == input.dtype()); + + auto 
out = torch::zeros({ B, H_out, W_out, I, J }, grad_output.dtype()); + + for (uint32_t b = 0; b < B; b++) { + for (uint32_t h = 0; h < H_out; h++) { + for (uint32_t w = 0; w < W_out; w++) { + for (uint32_t i = 0; i < I; i++) { + for (uint32_t j = 0; j < J; j++) { + for (uint32_t c = 0; c < C; c++) { + auto grad = grad_output[b][c][h][w]; + assert(h + i < H_in); + assert(w + j < W_in); + auto input_val = input[b][c][h+i][w+j]; + out[b][h][w][i][j] += grad * input_val; + } + } + } + } + } + } + + return out; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &adaptive_conv_forward, "adaptive_conv forward"); + m.def("grad_input", &adaptive_conv_grad_input, "adaptive_conv grad_input"); + m.def("grad_filters", &adaptive_conv_grad_filters, "adaptive_conv grad_filters"); +} diff --git a/featup/adaptive_conv_cuda/adaptive_conv.py b/featup/adaptive_conv_cuda/adaptive_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..b5c370ac92bceac72dfc50bd88ee05986674a90a --- /dev/null +++ b/featup/adaptive_conv_cuda/adaptive_conv.py @@ -0,0 +1,47 @@ +from torch.autograd import Function +import torch + +import adaptive_conv_cuda_impl as cuda_impl +import adaptive_conv_cpp_impl as cpp_impl + +torch.manual_seed(42) + + +class AdaptiveConv(Function): + + @staticmethod + def forward(ctx, input, filters): + ctx.save_for_backward(filters, input) + b, h2, w2, f1, f2 = filters.shape + assert f1 == f2 + + if input.is_cuda: + assert filters.is_cuda + result = cuda_impl.forward(input, filters) + else: + result = cpp_impl.forward(input, filters) + + return result + + @staticmethod + def backward(ctx, grad_output): + filters, input = ctx.saved_tensors + grad_input = grad_filters = None + b, h2, w2, f1, f2 = filters.shape + assert f1 == f2 + + grad_output = grad_output.contiguous() + if grad_output.is_cuda: + assert input.is_cuda + assert filters.is_cuda + if ctx.needs_input_grad[0]: + grad_input = cuda_impl.grad_input(grad_output, filters) + if ctx.needs_input_grad[1]: + grad_filters = cuda_impl.grad_filters(grad_output, input) + else: + if ctx.needs_input_grad[0]: + grad_input = cpp_impl.grad_input(grad_output, filters) + if ctx.needs_input_grad[1]: + grad_filters = cpp_impl.grad_filters(grad_output, input) + + return grad_input, grad_filters diff --git a/featup/adaptive_conv_cuda/adaptive_conv_cuda.cpp b/featup/adaptive_conv_cuda/adaptive_conv_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ed88682bb2d0ae01369b15eaf3ab44acd21292f7 --- /dev/null +++ b/featup/adaptive_conv_cuda/adaptive_conv_cuda.cpp @@ -0,0 +1,39 @@ +#include +using torch::Tensor; + +// CUDA forward declarations + +Tensor adaptive_conv_cuda_forward(Tensor input, Tensor filters); +Tensor adaptive_conv_cuda_grad_input(Tensor grad_output, Tensor filters); +Tensor adaptive_conv_cuda_grad_filters(Tensor grad_output, Tensor input); + +// C++ interface + +// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4. 
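+
+// Shape conventions shared by the three entry points below:
+//   input:       (B, C, H_in, W_in)
+//   filters:     (B, H_out, W_out, I, J), with I == J
+//   grad_output: (B, C, H_out, W_out)
+// where H_in = H_out + I - 1 and W_in = W_out + J - 1, i.e. each output pixel
+// is the inner product of an I x J input patch with its own per-pixel filter.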
+#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +Tensor adaptive_conv_forward(Tensor input, Tensor filters) { + //CHECK_INPUT(input); + //CHECK_INPUT(filters); + return adaptive_conv_cuda_forward(input, filters); +} + +Tensor adaptive_conv_grad_input(Tensor grad_output, Tensor filters) { + //CHECK_INPUT(grad_output); + //CHECK_INPUT(filters); + return adaptive_conv_cuda_grad_input(grad_output, filters); +} + +Tensor adaptive_conv_grad_filters(Tensor grad_output, Tensor input) { + //CHECK_INPUT(grad_output); + //CHECK_INPUT(input); + return adaptive_conv_cuda_grad_filters(grad_output, input); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &adaptive_conv_forward, "adaptive_conv forward"); + m.def("grad_input", &adaptive_conv_grad_input, "adaptive_conv grad_input"); + m.def("grad_filters", &adaptive_conv_grad_filters, "adaptive_conv grad_filters"); +} diff --git a/featup/adaptive_conv_cuda/adaptive_conv_kernel.cu b/featup/adaptive_conv_cuda/adaptive_conv_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..e6126f6e75b903528df7297ddf88acb5813379cb --- /dev/null +++ b/featup/adaptive_conv_cuda/adaptive_conv_kernel.cu @@ -0,0 +1,285 @@ +#include + +#include +#include +#include + +constexpr uint32_t kernel_channel_depth = 2; + +using torch::Tensor; +using namespace at; + +template +__launch_bounds__(1024) __global__ void adaptive_conv_forward_kernel( + torch::PackedTensorAccessor64 out, + torch::PackedTensorAccessor64 input, + torch::PackedTensorAccessor64 filters, + uint32_t batch) { + + const auto w = blockIdx.x * blockDim.x + threadIdx.x; + const auto h = blockIdx.y * blockDim.y + threadIdx.y; + const auto c_lo = blockIdx.z * kernel_channel_depth; + const auto c_hi = min(c_lo + kernel_channel_depth, (uint32_t) input.size(1)); + + const uint32_t I = filters.size(3); + const uint32_t J = filters.size(4); + + if (w < out.size(3) && h < out.size(2)) { + for (uint32_t c = c_lo; c < c_hi; c++) { + scalar_t output_val = 0.0; + for (uint32_t i = 0; i < I; i++) { + for (uint32_t j = 0; j < J; j++) { + + auto weight = filters[batch][h][w][i][j]; + auto input_val = input[batch][c][h+i][w+j]; + + output_val += (weight * input_val); + } + } + out[batch][c][h][w] = output_val; + } + } +} + +template +__launch_bounds__(1024) __global__ void adaptive_conv_grad_input_kernel( + torch::PackedTensorAccessor64 out, + torch::PackedTensorAccessor64 grad_output, + torch::PackedTensorAccessor64 filters, + uint32_t batch) { + + const int32_t w = blockIdx.x * blockDim.x + threadIdx.x; + const int32_t h = blockIdx.y * blockDim.y + threadIdx.y; + + const int32_t H_out = out.size(2); + const int32_t W_out = out.size(3); + + // thread's output index is outside output tensor + if (w >= W_out || h >= H_out) return; + + const int32_t c_lo = blockIdx.z * kernel_channel_depth; + const int32_t c_hi = min(c_lo + kernel_channel_depth, (int32_t) out.size(1)); + + const int32_t I = filters.size(3); + const int32_t J = filters.size(4); + + const int32_t H_grad = grad_output.size(2); + const int32_t W_grad = grad_output.size(3); + + for (int32_t c = c_lo; c < c_hi; c++) { + + scalar_t output_val = 0.0; + + for (int32_t i = 0; i < I; i++) { + for (int32_t j = 0; j < J; j++) { + const int32_t h_grad = h - i; + const int32_t w_grad = w - j; + + if (h_grad >= 0 && w_grad >= 0 && h_grad < H_grad && w_grad < W_grad) { 
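+                    // Gather formulation of the backward pass: in the forward pass,
+                    // input pixel (h, w) contributed to output pixel (h - i, w - j)
+                    // through filter tap (i, j), so the same weight routes the
+                    // incoming gradient back to this input location.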
+ output_val += grad_output[batch][c][h_grad][w_grad] * filters[batch][h_grad][w_grad][i][j]; + } + } + } + out[batch][c][h][w] = output_val; + } +} + + +template +__launch_bounds__(1024) __global__ void adaptive_conv_grad_filters_kernel( + torch::PackedTensorAccessor64 out, + torch::PackedTensorAccessor64 grad_output, + torch::PackedTensorAccessor64 input, + uint32_t batch) { + + const uint32_t w = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t h = blockIdx.y * blockDim.y + threadIdx.y; + const uint32_t f = blockIdx.z * blockIdx.z + threadIdx.z; + + const uint32_t H = out.size(1); + const uint32_t W = out.size(2); + const uint32_t I = out.size(3); + const uint32_t J = out.size(4); + + assert(I == J); + + const uint32_t C = input.size(1); + + if (h >= H || w >= W || f >= (I * J)) return; + + const uint32_t i = f / I; + const uint32_t j = f % I; + + scalar_t output_val = 0.0; + for (uint32_t c = 0; c < C; c++) { + auto grad = grad_output[batch][c][h][w]; + auto input_val = input[batch][c][h+i][w+j]; + output_val += grad * input_val; + } + out[batch][h][w][i][j] = output_val; +} + + +template +T div_round_up(T a, T b) { + return (a + b - 1) / b; +} + +Tensor adaptive_conv_cuda_forward(Tensor input, Tensor filters) { + at::cuda::set_device(input.device().index()); + + // Check for error in the input tensors + TORCH_CHECK(input.dim() == 4, "input must have 4 dimensions"); + TORCH_CHECK(filters.dim() == 5, "filters must have 5 dimensions"); + TORCH_CHECK(input.dtype() == filters.dtype(), "input and filters must have the same data type"); + + const uint32_t B = input.size(0); + const uint32_t C = input.size(1); + const uint32_t H_in = input.size(2); + const uint32_t W_in = input.size(3); + + TORCH_CHECK(filters.size(0) == B, "Inconsistent batch size between input and filters"); + const uint32_t H_out = filters.size(1); + const uint32_t W_out = filters.size(2); + const uint32_t I = filters.size(3); + const uint32_t J = filters.size(4); + + TORCH_CHECK(I == J, "filters dimension I and J must be equal"); + TORCH_CHECK(H_out + I - 1 == H_in, "Inconsistent height between input and filters"); + TORCH_CHECK(W_out + J - 1 == W_in, "Inconsistent width between input and filters"); + + auto options = torch::TensorOptions() + .dtype(input.dtype()) + .device(torch::kCUDA); + + auto out = torch::zeros({ B, C, H_out, W_out }, options); + + const dim3 tpb(32, 32); + const dim3 blocks(div_round_up(W_out, tpb.x), + div_round_up(H_out, tpb.y), + div_round_up(C, kernel_channel_depth)); + + for (uint32_t b = 0; b < B; b++) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(out.scalar_type(), "adaptive_conv_forward_cuda", ([&] { + adaptive_conv_forward_kernel<<>>( + out.packed_accessor64(), + input.packed_accessor64(), + filters.packed_accessor64(), + b); + })); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("Error in adaptive_conv_forward_kernel: %s\n", cudaGetErrorString(err)); + } + } + return out; +} + + +Tensor adaptive_conv_cuda_grad_input(Tensor grad_output, Tensor filters) { + at::cuda::set_device(grad_output.device().index()); + + // Check for error in the input tensors + TORCH_CHECK(grad_output.dim() == 4, "grad_output must have 4 dimensions"); + TORCH_CHECK(filters.dim() == 5, "filters must have 5 dimensions"); + + const uint32_t B = grad_output.size(0); + const uint32_t C = grad_output.size(1); + const uint32_t H_out = grad_output.size(2); + const uint32_t W_out = grad_output.size(3); + + TORCH_CHECK(filters.size(0) == B, "Inconsistent batch size between filters and 
grad_output"); + TORCH_CHECK(filters.size(1) == H_out, "Inconsistent height between filters and grad_output"); + TORCH_CHECK(filters.size(2) == W_out, "Inconsistent width between filters and grad_output"); + + const uint32_t I = filters.size(3); + const uint32_t J = filters.size(4); + TORCH_CHECK(I == J, "filters dimension I and J must be equal"); + + const uint32_t H_in = H_out + I - 1; + const uint32_t W_in = W_out + J - 1; + + TORCH_CHECK(grad_output.dtype() == filters.dtype(), "grad_output and filters must have the same data type"); + + auto options = torch::TensorOptions() + .dtype(filters.dtype()) + .device(torch::kCUDA); + + auto out = torch::zeros({ B, C, H_in, W_in }, options); + + const dim3 tpb(32, 32); + const dim3 blocks(div_round_up(W_in, tpb.x), + div_round_up(H_in, tpb.y), + div_round_up(C, kernel_channel_depth)); + + for (uint32_t b = 0; b < B; b++) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(out.scalar_type(), "adaptive_conv_grad_input_cuda", ([&] { + adaptive_conv_grad_input_kernel<<>>( + out.packed_accessor64(), + grad_output.packed_accessor64(), + filters.packed_accessor64(), + b); + })); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("Error in adaptive_conv_grad_input_kernel: %s\n", cudaGetErrorString(err)); + } + } + return out; +} + +Tensor adaptive_conv_cuda_grad_filters(Tensor grad_output, Tensor input) { + at::cuda::set_device(grad_output.device().index()); + + // Check for error in the input tensors + TORCH_CHECK(grad_output.dim() == 4, "grad_output must have 4 dimensions"); + TORCH_CHECK(input.dim() == 4, "input must have 4 dimensions"); + + const uint32_t B = grad_output.size(0); + const uint32_t C = grad_output.size(1); + const uint32_t H_out = grad_output.size(2); + const uint32_t W_out = grad_output.size(3); + + TORCH_CHECK(input.size(0) == B, "Inconsistent batch size between input and grad_output"); + TORCH_CHECK(input.size(1) == C, "Inconsistent number of channels between input and grad_output"); + + const uint32_t H_in = input.size(2); + const uint32_t W_in = input.size(3); + + TORCH_CHECK(H_in > H_out, "Input height must be greater than grad_output height"); + TORCH_CHECK(W_in > W_out, "Input width must be greater than grad_output width"); + + const uint32_t I = W_in - W_out + 1; + const uint32_t J = H_in - H_out + 1; + + TORCH_CHECK(grad_output.dtype() == input.dtype(), "grad_output and input must have the same data type"); + + auto options = torch::TensorOptions() + .dtype(input.dtype()) + .device(torch::kCUDA); + + auto out = torch::zeros({ B, H_out, W_out, I, J }, options); + + const dim3 tpb(32, 32, 1); + const dim3 blocks(div_round_up(W_out, tpb.x), + div_round_up(H_out, tpb.y), + div_round_up(I * J, tpb.z)); + + + + for (uint32_t b = 0; b < B; b++) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(out.scalar_type(), "adaptive_conv_grad_filters_cuda", ([&] { + adaptive_conv_grad_filters_kernel<<>>( + out.packed_accessor64(), + grad_output.packed_accessor64(), + input.packed_accessor64(), + b); + })); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("Error in adaptive_conv_grad_filters_kernel: %s\n", cudaGetErrorString(err)); + } + } + return out; +} + diff --git a/featup/configs/implicit_upsampler.yaml b/featup/configs/implicit_upsampler.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d8670a6fc896e0a647dcd68346145b22e944193b --- /dev/null +++ b/featup/configs/implicit_upsampler.yaml @@ -0,0 +1,44 @@ +# Environment Args +output_root: '../../' +pytorch_data_dir: '/pytorch-data' 
+submitting_to_aml: false +summarize: true +experiment_name: "exp1" + +# Dataset args +dataset: "sample" +split: "val" +partition: 0 +total_partitions: 1 + +# Model Args +model_type: "maskclip" +activation_type: "token" + +# Upsampler args +outlier_detection: True +downsampler_type: "attention" +blur_attn: True +mag_tv_weight: 0.05 +mag_weight: 0.001 +color_feats: true +pca_batch: 50 +proj_dim: 128 +max_pad: 30 +use_flips: true +max_zoom: 1.8 +blur_pin: 0.1 +n_freqs: 30 +param_type: "implicit" +use_norm: false + +# Training args +steps: 1200 +n_images: 3000 + +# No need to change +hydra: + run: + dir: "." + output_subdir: ~ + diff --git a/featup/configs/jbu_upsampler.yaml b/featup/configs/jbu_upsampler.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a03676b0b44de2d7674326b16b6f9205202107ed --- /dev/null +++ b/featup/configs/jbu_upsampler.yaml @@ -0,0 +1,39 @@ +# Environment Args +output_root: '../../' +pytorch_data_dir: '/pytorch-data' +submitting_to_aml: false + +# Dataset args +dataset: "cocostuff" + +# Model Args +model_type: "vit" +activation_type: "token" + +# Upsampling args +outlier_detection: True +upsampler_type: "jbu_stack" +downsampler_type: "attention" +max_pad: 20 +max_zoom: 2 +n_jitters: 5 +random_projection: 30 +crf_weight: 0.001 +filter_ent_weight: 0.0 +tv_weight: 0.0 + +implicit_sup_weight: 1.0 + +# Training args +batch_size: 4 +epochs: 1 +num_gpus: 1 +num_workers: 24 +lr: 1e-3 + +# No need to change +hydra: + run: + dir: "." + output_subdir: ~ + diff --git a/featup/configs/train_probe.yaml b/featup/configs/train_probe.yaml new file mode 100644 index 0000000000000000000000000000000000000000..db60c483a278d8e329c377192b2badefbcbad5d0 --- /dev/null +++ b/featup/configs/train_probe.yaml @@ -0,0 +1,38 @@ +# Environment Args +output_root: '../../' +pytorch_data_dir: '/pytorch-data' +submitting_to_aml: false + +# Dataset args +task: "seg" + +# Model Args +model_type: "vit" +activation_type: "token" + +# Upsampling args +outlier_detection: True +upsampler_type: "jbu_stack" +downsampler_type: "attention" +max_pad: 20 +max_zoom: 2 +n_jitters: 5 +random_projection: 30 +crf_weight: 0.001 +filter_ent_weight: 0.0 +tv_weight: 0.0 + +# Training args +batch_size: 2 +epochs: 200 +num_workers: 24 +lr: 1e-3 +dropout: .5 +wd: 0.0 + +# No need to change +hydra: + run: + dir: "." + output_subdir: ~ + diff --git a/featup/datasets/COCO.py b/featup/datasets/COCO.py new file mode 100644 index 0000000000000000000000000000000000000000..1ff4d18beb5ebf8e0c1e05d5fa614054cb552c52 --- /dev/null +++ b/featup/datasets/COCO.py @@ -0,0 +1,148 @@ +import random +from os.path import join + +import numpy as np +import torch +import torch.multiprocessing +from PIL import Image +from torch.utils.data import Dataset + + +def bit_get(val, idx): + """Gets the bit value. + Args: + val: Input value, int or numpy int array. + idx: Which bit of the input val. + Returns: + The "idx"-th bit of input val. + """ + return (val >> idx) & 1 + + +def create_pascal_label_colormap(): + """Creates a label colormap used in PASCAL VOC segmentation benchmark. + Returns: + A colormap for visualizing segmentation results. 
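+    Each of the 512 entries is built by distributing the bits of the label index
+    across the R, G, B channels, three bits at a time, from the most significant
+    output bit downward (the standard PASCAL VOC colormap construction).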
+ """ + colormap = np.zeros((512, 3), dtype=int) + ind = np.arange(512, dtype=int) + + for shift in reversed(list(range(8))): + for channel in range(3): + colormap[:, channel] |= bit_get(ind, channel) << shift + ind >>= 3 + + return colormap + + +class Coco(Dataset): + def __init__(self, + root, + split, + transform, + target_transform, + include_labels=True, + coarse_labels=False, + exclude_things=False, + subset=None): + super(Coco, self).__init__() + self.split = split + self.root = join(root, "cocostuff") + self.coarse_labels = coarse_labels + self.transform = transform + self.label_transform = target_transform + self.subset = subset + self.exclude_things = exclude_things + self.include_labels = include_labels + + if self.subset is None: + self.image_list = "Coco164kFull_Stuff_Coarse.txt" + elif self.subset == 6: # IIC Coarse + self.image_list = "Coco164kFew_Stuff_6.txt" + elif self.subset == 7: # IIC Fine + self.image_list = "Coco164kFull_Stuff_Coarse_7.txt" + + assert self.split in ["train", "val", "train+val"] + split_dirs = { + "train": ["train2017"], + "val": ["val2017"], + "train+val": ["train2017", "val2017"] + } + + self.image_files = [] + self.label_files = [] + for split_dir in split_dirs[self.split]: + with open(join(self.root, "curated", split_dir, self.image_list), "r") as f: + img_ids = [fn.rstrip() for fn in f.readlines()] + for img_id in img_ids: + self.image_files.append(join(self.root, "images", split_dir, img_id + ".jpg")) + self.label_files.append(join(self.root, "annotations", split_dir, img_id + ".png")) + + self.fine_to_coarse = {0: 9, 1: 11, 2: 11, 3: 11, 4: 11, 5: 11, 6: 11, 7: 11, 8: 11, 9: 8, 10: 8, 11: 8, 12: 8, + 13: 8, 14: 8, 15: 7, 16: 7, 17: 7, 18: 7, 19: 7, 20: 7, 21: 7, 22: 7, 23: 7, 24: 7, + 25: 6, 26: 6, 27: 6, 28: 6, 29: 6, 30: 6, 31: 6, 32: 6, 33: 10, 34: 10, 35: 10, 36: 10, + 37: 10, 38: 10, 39: 10, 40: 10, 41: 10, 42: 10, 43: 5, 44: 5, 45: 5, 46: 5, 47: 5, 48: 5, + 49: 5, 50: 5, 51: 2, 52: 2, 53: 2, 54: 2, 55: 2, 56: 2, 57: 2, 58: 2, 59: 2, 60: 2, + 61: 3, 62: 3, 63: 3, 64: 3, 65: 3, 66: 3, 67: 3, 68: 3, 69: 3, 70: 3, 71: 0, 72: 0, + 73: 0, 74: 0, 75: 0, 76: 0, 77: 1, 78: 1, 79: 1, 80: 1, 81: 1, 82: 1, 83: 4, 84: 4, + 85: 4, 86: 4, 87: 4, 88: 4, 89: 4, 90: 4, 91: 17, 92: 17, 93: 22, 94: 20, 95: 20, 96: 22, + 97: 15, 98: 25, 99: 16, 100: 13, 101: 12, 102: 12, 103: 17, 104: 17, 105: 23, 106: 15, + 107: 15, 108: 17, 109: 15, 110: 21, 111: 15, 112: 25, 113: 13, 114: 13, 115: 13, 116: 13, + 117: 13, 118: 22, 119: 26, 120: 14, 121: 14, 122: 15, 123: 22, 124: 21, 125: 21, 126: 24, + 127: 20, 128: 22, 129: 15, 130: 17, 131: 16, 132: 15, 133: 22, 134: 24, 135: 21, 136: 17, + 137: 25, 138: 16, 139: 21, 140: 17, 141: 22, 142: 16, 143: 21, 144: 21, 145: 25, 146: 21, + 147: 26, 148: 21, 149: 24, 150: 20, 151: 17, 152: 14, 153: 21, 154: 26, 155: 15, 156: 23, + 157: 20, 158: 21, 159: 24, 160: 15, 161: 24, 162: 22, 163: 25, 164: 15, 165: 20, 166: 17, + 167: 17, 168: 22, 169: 14, 170: 18, 171: 18, 172: 18, 173: 18, 174: 18, 175: 18, 176: 18, + 177: 26, 178: 26, 179: 19, 180: 19, 181: 24} + + self._label_names = [ + "ground-stuff", + "plant-stuff", + "sky-stuff", + ] + self.cocostuff3_coarse_classes = [23, 22, 21] + self.first_stuff_index = 12 + + def __len__(self): + return len(self.image_files) + + def __getitem__(self, index): + image_path = self.image_files[index] + label_path = self.label_files[index] + seed = np.random.randint(2147483647) + batch = {} + + random.seed(seed) + torch.manual_seed(seed) + img = 
self.transform(Image.open(image_path).convert("RGB")) + batch["img"] = img + batch["img_path"] = image_path + + if self.include_labels: + random.seed(seed) + torch.manual_seed(seed) + label = self.label_transform(Image.open(label_path)).squeeze(0) + label[label == 255] = -1 # to be consistent with 10k + coarse_label = torch.zeros_like(label) + for fine, coarse in self.fine_to_coarse.items(): + coarse_label[label == fine] = coarse + coarse_label[label == -1] = -1 + + if self.coarse_labels: + coarser_labels = -torch.ones_like(label) + for i, c in enumerate(self.cocostuff3_coarse_classes): + coarser_labels[coarse_label == c] = i + batch["label"] = coarser_labels + else: + if self.exclude_things: + batch["label"] = coarse_label - self.first_stuff_index + else: + batch["label"] = coarse_label + + return batch + + @staticmethod + def colorize_label(label): + cmap = create_pascal_label_colormap() + return cmap[label.cpu()].astype(np.uint8) diff --git a/featup/datasets/DAVIS.py b/featup/datasets/DAVIS.py new file mode 100644 index 0000000000000000000000000000000000000000..1a5ecdb75328cd4eea1791cdcd27279533eb3641 --- /dev/null +++ b/featup/datasets/DAVIS.py @@ -0,0 +1,42 @@ +from torchvision import transforms +import os +from PIL import Image +from torch.utils.data import Dataset + + +class DAVIS(Dataset): + def __init__(self, root, video_name, transform=None): + """ + Args: + root (string): Directory with all the videos. + video_name (string): Name of the specific video. + transform (callable, optional): Optional transform to be applied on a sample. + """ + self.root_dir = os.path.join(root, "DAVIS/JPEGImages/480p/", video_name) + self.frames = os.listdir(self.root_dir) + self.transform = transform + + def __len__(self): + return len(self.frames) + + def __getitem__(self, idx): + img_path = os.path.join(self.root_dir, self.frames[idx]) + image = Image.open(img_path).convert("RGB") + + if self.transform: + image = self.transform(image) + + return {"img": image, "img_path": img_path} + + +if __name__ == "__main__": + transform = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor() + ]) + + davis_dataset = DAVIS(root='/pytorch-data', video_name="motocross-jump", transform=transform) + + frames = davis_dataset[0] + + print("here") diff --git a/featup/datasets/EmbeddingFile.py b/featup/datasets/EmbeddingFile.py new file mode 100644 index 0000000000000000000000000000000000000000..b2f5ab37561cbdfca588481a025d8f3e415f10db --- /dev/null +++ b/featup/datasets/EmbeddingFile.py @@ -0,0 +1,55 @@ +import numpy as np +from torch.utils.data import Dataset + + +class EmbeddingFile(Dataset): + """ + modified from: https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder + uses cached directory listing if available rather than walking directory + Attributes: + classes (list): List of the class names. + class_to_idx (dict): Dict with items (class_name, class_index). 
+ samples (list): List of (sample path, class_index) tuples + targets (list): The class_index value for each image in the dataset + """ + + def __init__(self, file): + super(Dataset, self).__init__() + self.file = file + loaded = np.load(file) + self.feats = loaded["feats"] + self.labels = loaded["labels"] + + def dim(self): + return self.feats.shape[1] + + def num_classes(self): + return self.labels.max() + 1 + + def __getitem__(self, index): + return self.feats[index], self.labels[index] + + def __len__(self): + return len(self.labels) + + +class EmbeddingAndImage(Dataset): + def __init__(self, file, dataset): + super(Dataset, self).__init__() + self.file = file + loaded = np.load(file) + self.feats = loaded["feats"] + self.labels = loaded["labels"] + self.imgs = dataset + + def dim(self): + return self.feats.shape[1] + + def num_classes(self): + return self.labels.max() + 1 + + def __getitem__(self, index): + return self.feats[index], self.labels[index], self.imgs[index] + + def __len__(self): + return len(self.labels) diff --git a/featup/datasets/HighResEmbs.py b/featup/datasets/HighResEmbs.py new file mode 100644 index 0000000000000000000000000000000000000000..51f6828a95fa69cff4f638409605f26866d8eed0 --- /dev/null +++ b/featup/datasets/HighResEmbs.py @@ -0,0 +1,268 @@ +import collections +import sys +from os.path import join + +import featup.downsamplers +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms as T +from featup.featurizers.util import get_featurizer +from featup.layers import ChannelNorm +from featup.layers import ChannelNorm +from featup.util import norm +from sklearn.decomposition import PCA +from torch.utils.data import Dataset, DataLoader +from torch.utils.data import Subset +from torch.utils.data import default_collate +from tqdm import tqdm + +from util import get_dataset + +torch.multiprocessing.set_sharing_strategy('file_system') + + +def clamp_mag(t, min_mag, max_mag): + mags = mag(t) + clamped_above = t * (max_mag / mags.clamp_min(.000001)).clamp_max(1.0) + clamped_below = clamped_above * (min_mag / mags.clamp_min(.000001)).clamp_min(1.0) + return clamped_below + + +def pca(image_feats_list, dim=3, fit_pca=None): + device = image_feats_list[0].device + + def flatten(tensor, target_size=None): + if target_size is not None and fit_pca is None: + F.interpolate(tensor, (target_size, target_size), mode="bilinear") + B, C, H, W = tensor.shape + return feats.permute(1, 0, 2, 3).reshape(C, B * H * W).permute(1, 0).detach().cpu() + + if len(image_feats_list) > 1 and fit_pca is None: + target_size = image_feats_list[0].shape[2] + else: + target_size = None + + flattened_feats = [] + for feats in image_feats_list: + flattened_feats.append(flatten(feats, target_size)) + x = torch.cat(flattened_feats, dim=0) + + if fit_pca is None: + fit_pca = PCA(n_components=dim).fit(x) + + reduced_feats = [] + for feats in image_feats_list: + x_red = torch.from_numpy(fit_pca.transform(flatten(feats))) + x_red -= x_red.min(dim=0, keepdim=True).values + x_red /= x_red.max(dim=0, keepdim=True).values + B, C, H, W = feats.shape + reduced_feats.append(x_red.reshape(B, H, W, dim).permute(0, 3, 1, 2).to(device)) + + return reduced_feats, fit_pca + + +def mag(t): + return t.square().sum(1, keepdim=True).sqrt() + + +def model_collate(batch): + elem = batch[0] + elem_type = type(elem) + if isinstance(elem, torch.nn.Module): + return batch + elif isinstance(elem, collections.abc.Mapping): + try: + return elem_type({key: model_collate([d[key] for d in 
batch]) for key in elem}) + except TypeError: + # The mapping type may not support `__init__(iterable)`. + return {key: model_collate([d[key] for d in batch]) for key in elem} + else: + return default_collate(batch) + + +class HighResEmbHelper(Dataset): + def __init__(self, + root, + output_root, + dataset_name, + emb_name, + split, + model_type, + transform, + target_transform, + limit, + include_labels): + self.root = root + self.emb_dir = join(output_root, "feats", emb_name, dataset_name, split, model_type) + + self.dataset = get_dataset( + root, dataset_name, split, transform, target_transform, include_labels=include_labels) + + if split == 'train': + self.dataset = Subset(self.dataset, generate_subset(len(self.dataset), 5000)) + # TODO factor this limit out + + if limit is not None: + self.dataset = Subset(self.dataset, range(0, limit)) + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, item): + batch = self.dataset[item] + output_location = join(self.emb_dir, "/".join(batch["img_path"].split("/")[-1:]).replace(".jpg", ".pth")) + state_dicts = torch.load(output_location, map_location="cpu") + from featup.train_implicit_upsampler import get_implicit_upsampler + from featup.util import PCAUnprojector + model = get_implicit_upsampler(**state_dicts["model_args"]) + model.load_state_dict(state_dicts["model"]) + unp_state_dict = state_dicts["unprojector"] + unprojector = PCAUnprojector( + None, + unp_state_dict["components_"].shape[0], + device="cpu", + original_dim=unp_state_dict["components_"].shape[1], + **unp_state_dict + ) + batch["model"] = {"model": model, "unprojector": unprojector} + return batch + + +def load_hr_emb(image, loaded_model, target_res): + image = image.cuda() + if isinstance(loaded_model["model"], list): + hr_model = loaded_model["model"][0].cuda().eval() + unprojector = loaded_model["unprojector"][0].eval() + else: + hr_model = loaded_model["model"].cuda().eval() + unprojector = loaded_model["unprojector"].eval() + + with torch.no_grad(): + original_image = F.interpolate( + image, size=(target_res, target_res), mode='bilinear', antialias=True) + hr_feats = hr_model(original_image) + return unprojector(hr_feats.detach().cpu()) + + +class HighResEmb(Dataset): + def __init__(self, + root, + dataset_name, + emb_name, + split, + output_root, + model_type, + transform, + target_transform, + target_res, + limit, + include_labels, + ): + self.root = root + self.dataset = HighResEmbHelper( + root=root, + output_root=output_root, + dataset_name=dataset_name, + emb_name=emb_name, + split=split, + model_type=model_type, + transform=transform, + target_transform=target_transform, + limit=limit, + include_labels=include_labels) + + self.all_hr_feats = [] + self.target_res = target_res + loader = DataLoader(self.dataset, shuffle=False, batch_size=1, num_workers=12, collate_fn=model_collate) + + for img_num, batch in enumerate(tqdm(loader, "Loading hr embeddings")): + with torch.no_grad(): + self.all_hr_feats.append(load_hr_emb(batch["img"], batch["model"], target_res)) + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, item): + batch = self.dataset.dataset[item] + batch["hr_feat"] = self.all_hr_feats[item].squeeze(0) + return batch + + +def generate_subset(n, batch): + np.random.seed(0) + return np.random.permutation(n)[:batch] + + +def load_some_hr_feats(model_type, + activation_type, + dataset_name, + split, + emb_name, + root, + output_root, + input_size, + samples_per_batch, + num_batches, + num_workers + ): + transform = 
T.Compose([ + T.Resize(input_size), + T.CenterCrop(input_size), + T.ToTensor(), + norm + ]) + + shared_args = dict( + root=root, + dataset_name=dataset_name, + emb_name=emb_name, + output_root=output_root, + model_type=model_type, + transform=transform, + target_transform=None, + target_res=input_size, + include_labels=False, + limit=samples_per_batch * num_batches + ) + + def get_data(model, ds): + loader = DataLoader(ds, batch_size=samples_per_batch, num_workers=num_workers) + all_batches = [] + for batch in loader: + batch["lr_feat"] = model(batch["img"].cuda()).cpu() + all_batches.append(batch) + + big_batch = {} + for k, t in all_batches[0].items(): + if isinstance(t, torch.Tensor): + big_batch[k] = torch.cat([b[k] for b in all_batches], dim=0) + del loader + return big_batch + + with torch.no_grad(): + model, _, dim = get_featurizer(model_type, activation_type) + model = torch.nn.Sequential(model, ChannelNorm(dim)) + model = model.cuda() + batch = get_data(model, HighResEmb(split=split, **shared_args)) + del model + + return batch + + +if __name__ == "__main__": + loaded = load_some_hr_feats( + "vit", + "token", + "cocostuff", + "train", + "3_12_2024", + "/pytorch-data/", + "../../../", + 224, + 50, + 3, + 0 + ) + + print(loaded) diff --git a/featup/datasets/ImageNetSubset.py b/featup/datasets/ImageNetSubset.py new file mode 100644 index 0000000000000000000000000000000000000000..883df169cb50e636da602afd243d0417e7fbfd9c --- /dev/null +++ b/featup/datasets/ImageNetSubset.py @@ -0,0 +1,1093 @@ +from torchvision.datasets import folder +from torchvision.datasets.vision import VisionDataset +from glob import glob +from os.path import join + + +class ImageNetSubset(VisionDataset): + """ + modified from: https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder + uses cached directory listing if available rather than walking directory + Attributes: + classes (list): List of the class names. + class_to_idx (dict): Dict with items (class_name, class_index). 
+ samples (list): List of (sample path, class_index) tuples + targets (list): The class_index value for each image in the dataset + """ + + def __init__(self, + root, + split, + transform=None, + target_transform=None, + subset=None, + include_labels=True, + loader=folder.default_loader): + super(ImageNetSubset, self).__init__(root, transform=transform, target_transform=target_transform) + self.root = join(root, "imagenet") + self.filenames = [] + self.targets = None + self.include_labels = include_labels + + if subset is not None: + self.targets = [] + with open(subset, "r") as f: + for line in f: + (path, idx) = line.strip().split(';') + self.filenames.append( + join(self.root, path)) + self.targets.append(int(idx)) + else: + if split == "train": + dirs = join(split, "*") + else: + dirs = split + self.filenames = sorted(list(glob(join(self.root, dirs, "*")))) + self.targets = None + + # cache = self.root.rstrip('/') + '/' + split + '.txt' + # cache = '../' + split + '.txt' + # print("Using directory list at: %s" % cache) + # with open(cache) as f: + # samples = [] + # for line in f: + # if ';' in line: + # (path, idx) = line.strip().split(';') + # else: + # path = line.strip() + # samples.append(os.path.join(self.root, path)) + # self.filenames = samples + + if len(self.filenames) == 0: + raise RuntimeError(f"Cache file contained no filenames") + self.loader = loader + + self.transform = transform + self.target_transform = target_transform + + def __getitem__(self, index): + image_path = self.filenames[index] + sample = self.loader(image_path) + + if self.transform is not None: + sample = self.transform(sample) + + batch = { + "img": sample, + "index": index, + "img_path": image_path + } + + if self.include_labels: + target = self.targets[index] + if self.target_transform is not None: + target = self.target_transform(target) + batch["label"] = target + + return batch + + def __len__(self): + return len(self.filenames) + + +class_labels = { + 0: 'tench, Tinca tinca', + 1: 'goldfish, Carassius auratus', + 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias', + 3: 'tiger shark, Galeocerdo cuvieri', + 4: 'hammerhead, hammerhead shark', + 5: 'electric ray, crampfish, numbfish, torpedo', + 6: 'stingray', + 7: 'cock', + 8: 'hen', + 9: 'ostrich, Struthio camelus', + 10: 'brambling, Fringilla montifringilla', + 11: 'goldfinch, Carduelis carduelis', + 12: 'house finch, linnet, Carpodacus mexicanus', + 13: 'junco, snowbird', + 14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea', + 15: 'robin, American robin, Turdus migratorius', + 16: 'bulbul', + 17: 'jay', + 18: 'magpie', + 19: 'chickadee', + 20: 'water ouzel, dipper', + 21: 'kite', + 22: 'bald eagle, American eagle, Haliaeetus leucocephalus', + 23: 'vulture', + 24: 'great grey owl, great gray owl, Strix nebulosa', + 25: 'European fire salamander, Salamandra salamandra', + 26: 'common newt, Triturus vulgaris', + 27: 'eft', + 28: 'spotted salamander, Ambystoma maculatum', + 29: 'axolotl, mud puppy, Ambystoma mexicanum', + 30: 'bullfrog, Rana catesbeiana', + 31: 'tree frog, tree-frog', + 32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui', + 33: 'loggerhead, loggerhead turtle, Caretta caretta', + 34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea', + 35: 'mud turtle', + 36: 'terrapin', + 37: 'box turtle, box tortoise', + 38: 'banded gecko', + 39: 'common iguana, iguana, Iguana iguana', + 40: 'American chameleon, anole, Anolis carolinensis', + 41: 'whiptail, 
whiptail lizard', + 42: 'agama', + 43: 'frilled lizard, Chlamydosaurus kingi', + 44: 'alligator lizard', + 45: 'Gila monster, Heloderma suspectum', + 46: 'green lizard, Lacerta viridis', + 47: 'African chameleon, Chamaeleo chamaeleon', + 48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis', + 49: 'African crocodile, Nile crocodile, Crocodylus niloticus', + 50: 'American alligator, Alligator mississipiensis', + 51: 'triceratops', + 52: 'thunder snake, worm snake, Carphophis amoenus', + 53: 'ringneck snake, ring-necked snake, ring snake', + 54: 'hognose snake, puff adder, sand viper', + 55: 'green snake, grass snake', + 56: 'king snake, kingsnake', + 57: 'garter snake, grass snake', + 58: 'water snake', + 59: 'vine snake', + 60: 'night snake, Hypsiglena torquata', + 61: 'boa constrictor, Constrictor constrictor', + 62: 'rock python, rock snake, Python sebae', + 63: 'Indian cobra, Naja naja', + 64: 'green mamba', + 65: 'sea snake', + 66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus', + 67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus', + 68: 'sidewinder, horned rattlesnake, Crotalus cerastes', + 69: 'trilobite', + 70: 'harvestman, daddy longlegs, Phalangium opilio', + 71: 'scorpion', + 72: 'black and gold garden spider, Argiope aurantia', + 73: 'barn spider, Araneus cavaticus', + 74: 'garden spider, Aranea diademata', + 75: 'black widow, Latrodectus mactans', + 76: 'tarantula', + 77: 'wolf spider, hunting spider', + 78: 'tick', + 79: 'centipede', + 80: 'black grouse', + 81: 'ptarmigan', + 82: 'ruffed grouse, partridge, Bonasa umbellus', + 83: 'prairie chicken, prairie grouse, prairie fowl', + 84: 'peacock', + 85: 'quail', + 86: 'partridge', + 87: 'African grey, African gray, Psittacus erithacus', + 88: 'macaw', + 89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita', + 90: 'lorikeet', + 91: 'coucal', + 92: 'bee eater', + 93: 'hornbill', + 94: 'hummingbird', + 95: 'jacamar', + 96: 'toucan', + 97: 'drake', + 98: 'red-breasted merganser, Mergus serrator', + 99: 'goose', + 100: 'black swan, Cygnus atratus', + 101: 'tusker', + 102: 'echidna, spiny anteater, anteater', + 103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus', + 104: 'wallaby, brush kangaroo', + 105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus', + 106: 'wombat', + 107: 'jellyfish', + 108: 'sea anemone, anemone', + 109: 'brain coral', + 110: 'flatworm, platyhelminth', + 111: 'nematode, nematode worm, roundworm', + 112: 'conch', + 113: 'snail', + 114: 'slug', + 115: 'sea slug, nudibranch', + 116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore', + 117: 'chambered nautilus, pearly nautilus, nautilus', + 118: 'Dungeness crab, Cancer magister', + 119: 'rock crab, Cancer irroratus', + 120: 'fiddler crab', + 121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica', + 122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus', + 123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish', + 124: 'crayfish, crawfish, crawdad, crawdaddy', + 125: 'hermit crab', + 126: 'isopod', + 127: 'white stork, Ciconia ciconia', + 128: 'black stork, Ciconia nigra', + 129: 'spoonbill', + 130: 'flamingo', + 131: 'little blue heron, Egretta caerulea', + 132: 'American egret, great white heron, Egretta albus', + 133: 'bittern', + 134: 'crane', + 135: 'limpkin, Aramus pictus', + 136: 'European gallinule, Porphyrio porphyrio', + 137: 
'American coot, marsh hen, mud hen, water hen, Fulica americana', + 138: 'bustard', + 139: 'ruddy turnstone, Arenaria interpres', + 140: 'red-backed sandpiper, dunlin, Erolia alpina', + 141: 'redshank, Tringa totanus', + 142: 'dowitcher', + 143: 'oystercatcher, oyster catcher', + 144: 'pelican', + 145: 'king penguin, Aptenodytes patagonica', + 146: 'albatross, mollymawk', + 147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus', + 148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca', + 149: 'dugong, Dugong dugon', + 150: 'sea lion', + 151: 'Chihuahua', + 152: 'Japanese spaniel', + 153: 'Maltese dog, Maltese terrier, Maltese', + 154: 'Pekinese, Pekingese, Peke', + 155: 'Shih-Tzu', + 156: 'Blenheim spaniel', + 157: 'papillon', + 158: 'toy terrier', + 159: 'Rhodesian ridgeback', + 160: 'Afghan hound, Afghan', + 161: 'basset, basset hound', + 162: 'beagle', + 163: 'bloodhound, sleuthhound', + 164: 'bluetick', + 165: 'black-and-tan coonhound', + 166: 'Walker hound, Walker foxhound', + 167: 'English foxhound', + 168: 'redbone', + 169: 'borzoi, Russian wolfhound', + 170: 'Irish wolfhound', + 171: 'Italian greyhound', + 172: 'whippet', + 173: 'Ibizan hound, Ibizan Podenco', + 174: 'Norwegian elkhound, elkhound', + 175: 'otterhound, otter hound', + 176: 'Saluki, gazelle hound', + 177: 'Scottish deerhound, deerhound', + 178: 'Weimaraner', + 179: 'Staffordshire bullterrier, Staffordshire bull terrier', + 180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier', + 181: 'Bedlington terrier', + 182: 'Border terrier', + 183: 'Kerry blue terrier', + 184: 'Irish terrier', + 185: 'Norfolk terrier', + 186: 'Norwich terrier', + 187: 'Yorkshire terrier', + 188: 'wire-haired fox terrier', + 189: 'Lakeland terrier', + 190: 'Sealyham terrier, Sealyham', + 191: 'Airedale, Airedale terrier', + 192: 'cairn, cairn terrier', + 193: 'Australian terrier', + 194: 'Dandie Dinmont, Dandie Dinmont terrier', + 195: 'Boston bull, Boston terrier', + 196: 'miniature schnauzer', + 197: 'giant schnauzer', + 198: 'standard schnauzer', + 199: 'Scotch terrier, Scottish terrier, Scottie', + 200: 'Tibetan terrier, chrysanthemum dog', + 201: 'silky terrier, Sydney silky', + 202: 'soft-coated wheaten terrier', + 203: 'West Highland white terrier', + 204: 'Lhasa, Lhasa apso', + 205: 'flat-coated retriever', + 206: 'curly-coated retriever', + 207: 'golden retriever', + 208: 'Labrador retriever', + 209: 'Chesapeake Bay retriever', + 210: 'German short-haired pointer', + 211: 'vizsla, Hungarian pointer', + 212: 'English setter', + 213: 'Irish setter, red setter', + 214: 'Gordon setter', + 215: 'Brittany spaniel', + 216: 'clumber, clumber spaniel', + 217: 'English springer, English springer spaniel', + 218: 'Welsh springer spaniel', + 219: 'cocker spaniel, English cocker spaniel, cocker', + 220: 'Sussex spaniel', + 221: 'Irish water spaniel', + 222: 'kuvasz', + 223: 'schipperke', + 224: 'groenendael', + 225: 'malinois', + 226: 'briard', + 227: 'kelpie', + 228: 'komondor', + 229: 'Old English sheepdog, bobtail', + 230: 'Shetland sheepdog, Shetland sheep dog, Shetland', + 231: 'collie', + 232: 'Border collie', + 233: 'Bouvier des Flandres, Bouviers des Flandres', + 234: 'Rottweiler', + 235: 'German shepherd, German shepherd dog, German police dog, alsatian', + 236: 'Doberman, Doberman pinscher', + 237: 'miniature pinscher', + 238: 'Greater Swiss Mountain dog', + 239: 'Bernese mountain dog', + 240: 'Appenzeller', + 241: 'EntleBucher', + 242: 'boxer', 
+ 243: 'bull mastiff', + 244: 'Tibetan mastiff', + 245: 'French bulldog', + 246: 'Great Dane', + 247: 'Saint Bernard, St Bernard', + 248: 'Eskimo dog, husky', + 249: 'malamute, malemute, Alaskan malamute', + 250: 'Siberian husky', + 251: 'dalmatian, coach dog, carriage dog', + 252: 'affenpinscher, monkey pinscher, monkey dog', + 253: 'basenji', + 254: 'pug, pug-dog', + 255: 'Leonberg', + 256: 'Newfoundland, Newfoundland dog', + 257: 'Great Pyrenees', + 258: 'Samoyed, Samoyede', + 259: 'Pomeranian', + 260: 'chow, chow chow', + 261: 'keeshond', + 262: 'Brabancon griffon', + 263: 'Pembroke, Pembroke Welsh corgi', + 264: 'Cardigan, Cardigan Welsh corgi', + 265: 'toy poodle', + 266: 'miniature poodle', + 267: 'standard poodle', + 268: 'Mexican hairless', + 269: 'timber wolf, grey wolf, gray wolf, Canis lupus', + 270: 'white wolf, Arctic wolf, Canis lupus tundrarum', + 271: 'red wolf, maned wolf, Canis rufus, Canis niger', + 272: 'coyote, prairie wolf, brush wolf, Canis latrans', + 273: 'dingo, warrigal, warragal, Canis dingo', + 274: 'dhole, Cuon alpinus', + 275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus', + 276: 'hyena, hyaena', + 277: 'red fox, Vulpes vulpes', + 278: 'kit fox, Vulpes macrotis', + 279: 'Arctic fox, white fox, Alopex lagopus', + 280: 'grey fox, gray fox, Urocyon cinereoargenteus', + 281: 'tabby, tabby cat', + 282: 'tiger cat', + 283: 'Persian cat', + 284: 'Siamese cat, Siamese', + 285: 'Egyptian cat', + 286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor', + 287: 'lynx, catamount', + 288: 'leopard, Panthera pardus', + 289: 'snow leopard, ounce, Panthera uncia', + 290: 'jaguar, panther, Panthera onca, Felis onca', + 291: 'lion, king of beasts, Panthera leo', + 292: 'tiger, Panthera tigris', + 293: 'cheetah, chetah, Acinonyx jubatus', + 294: 'brown bear, bruin, Ursus arctos', + 295: 'American black bear, black bear, Ursus americanus, Euarctos americanus', + 296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus', + 297: 'sloth bear, Melursus ursinus, Ursus ursinus', + 298: 'mongoose', + 299: 'meerkat, mierkat', + 300: 'tiger beetle', + 301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle', + 302: 'ground beetle, carabid beetle', + 303: 'long-horned beetle, longicorn, longicorn beetle', + 304: 'leaf beetle, chrysomelid', + 305: 'dung beetle', + 306: 'rhinoceros beetle', + 307: 'weevil', + 308: 'fly', + 309: 'bee', + 310: 'ant, emmet, pismire', + 311: 'grasshopper, hopper', + 312: 'cricket', + 313: 'walking stick, walkingstick, stick insect', + 314: 'cockroach, roach', + 315: 'mantis, mantid', + 316: 'cicada, cicala', + 317: 'leafhopper', + 318: 'lacewing, lacewing fly', + 319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", + 320: 'damselfly', + 321: 'admiral', + 322: 'ringlet, ringlet butterfly', + 323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus', + 324: 'cabbage butterfly', + 325: 'sulphur butterfly, sulfur butterfly', + 326: 'lycaenid, lycaenid butterfly', + 327: 'starfish, sea star', + 328: 'sea urchin', + 329: 'sea cucumber, holothurian', + 330: 'wood rabbit, cottontail, cottontail rabbit', + 331: 'hare', + 332: 'Angora, Angora rabbit', + 333: 'hamster', + 334: 'porcupine, hedgehog', + 335: 'fox squirrel, eastern fox squirrel, Sciurus niger', + 336: 'marmot', + 337: 'beaver', + 338: 'guinea pig, Cavia cobaya', + 339: 'sorrel', + 340: 'zebra', + 341: 'hog, pig, grunter, squealer, Sus scrofa', + 342: 'wild 
boar, boar, Sus scrofa', + 343: 'warthog', + 344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius', + 345: 'ox', + 346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis', + 347: 'bison', + 348: 'ram, tup', + 349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis', + 350: 'ibex, Capra ibex', + 351: 'hartebeest', + 352: 'impala, Aepyceros melampus', + 353: 'gazelle', + 354: 'Arabian camel, dromedary, Camelus dromedarius', + 355: 'llama', + 356: 'weasel', + 357: 'mink', + 358: 'polecat, fitch, foulmart, foumart, Mustela putorius', + 359: 'black-footed ferret, ferret, Mustela nigripes', + 360: 'otter', + 361: 'skunk, polecat, wood pussy', + 362: 'badger', + 363: 'armadillo', + 364: 'three-toed sloth, ai, Bradypus tridactylus', + 365: 'orangutan, orang, orangutang, Pongo pygmaeus', + 366: 'gorilla, Gorilla gorilla', + 367: 'chimpanzee, chimp, Pan troglodytes', + 368: 'gibbon, Hylobates lar', + 369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus', + 370: 'guenon, guenon monkey', + 371: 'patas, hussar monkey, Erythrocebus patas', + 372: 'baboon', + 373: 'macaque', + 374: 'langur', + 375: 'colobus, colobus monkey', + 376: 'proboscis monkey, Nasalis larvatus', + 377: 'marmoset', + 378: 'capuchin, ringtail, Cebus capucinus', + 379: 'howler monkey, howler', + 380: 'titi, titi monkey', + 381: 'spider monkey, Ateles geoffroyi', + 382: 'squirrel monkey, Saimiri sciureus', + 383: 'Madagascar cat, ring-tailed lemur, Lemur catta', + 384: 'indri, indris, Indri indri, Indri brevicaudatus', + 385: 'Indian elephant, Elephas maximus', + 386: 'African elephant, Loxodonta africana', + 387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens', + 388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca', + 389: 'barracouta, snoek', + 390: 'eel', + 391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch', + 392: 'rock beauty, Holocanthus tricolor', + 393: 'anemone fish', + 394: 'sturgeon', + 395: 'gar, garfish, garpike, billfish, Lepisosteus osseus', + 396: 'lionfish', + 397: 'puffer, pufferfish, blowfish, globefish', + 398: 'abacus', + 399: 'abaya', + 400: "academic gown, academic robe, judge's robe", + 401: 'accordion, piano accordion, squeeze box', + 402: 'acoustic guitar', + 403: 'aircraft carrier, carrier, flattop, attack aircraft carrier', + 404: 'airliner', + 405: 'airship, dirigible', + 406: 'altar', + 407: 'ambulance', + 408: 'amphibian, amphibious vehicle', + 409: 'analog clock', + 410: 'apiary, bee house', + 411: 'apron', + 412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin', + 413: 'assault rifle, assault gun', + 414: 'backpack, back pack, knapsack, packsack, rucksack, haversack', + 415: 'bakery, bakeshop, bakehouse', + 416: 'balance beam, beam', + 417: 'balloon', + 418: 'ballpoint, ballpoint pen, ballpen, Biro', + 419: 'Band Aid', + 420: 'banjo', + 421: 'bannister, banister, balustrade, balusters, handrail', + 422: 'barbell', + 423: 'barber chair', + 424: 'barbershop', + 425: 'barn', + 426: 'barometer', + 427: 'barrel, cask', + 428: 'barrow, garden cart, lawn cart, wheelbarrow', + 429: 'baseball', + 430: 'basketball', + 431: 'bassinet', + 432: 'bassoon', + 433: 'bathing cap, swimming cap', + 434: 'bath towel', + 435: 'bathtub, bathing tub, bath, tub', + 436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon', + 437: 'beacon, lighthouse, beacon light, pharos', + 438: 'beaker', + 439: 
'bearskin, busby, shako', + 440: 'beer bottle', + 441: 'beer glass', + 442: 'bell cote, bell cot', + 443: 'bib', + 444: 'bicycle-built-for-two, tandem bicycle, tandem', + 445: 'bikini, two-piece', + 446: 'binder, ring-binder', + 447: 'binoculars, field glasses, opera glasses', + 448: 'birdhouse', + 449: 'boathouse', + 450: 'bobsled, bobsleigh, bob', + 451: 'bolo tie, bolo, bola tie, bola', + 452: 'bonnet, poke bonnet', + 453: 'bookcase', + 454: 'bookshop, bookstore, bookstall', + 455: 'bottlecap', + 456: 'bow', + 457: 'bow tie, bow-tie, bowtie', + 458: 'brass, memorial tablet, plaque', + 459: 'brassiere, bra, bandeau', + 460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty', + 461: 'breastplate, aegis, egis', + 462: 'broom', + 463: 'bucket, pail', + 464: 'buckle', + 465: 'bulletproof vest', + 466: 'bullet train, bullet', + 467: 'butcher shop, meat market', + 468: 'cab, hack, taxi, taxicab', + 469: 'caldron, cauldron', + 470: 'candle, taper, wax light', + 471: 'cannon', + 472: 'canoe', + 473: 'can opener, tin opener', + 474: 'cardigan', + 475: 'car mirror', + 476: 'carousel, carrousel, merry-go-round, roundabout, whirligig', + 477: "carpenter's kit, tool kit", + 478: 'carton', + 479: 'car wheel', + 480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM', + 481: 'cassette', + 482: 'cassette player', + 483: 'castle', + 484: 'catamaran', + 485: 'CD player', + 486: 'cello, violoncello', + 487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone', + 488: 'chain', + 489: 'chainlink fence', + 490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour', + 491: 'chain saw, chainsaw', + 492: 'chest', + 493: 'chiffonier, commode', + 494: 'chime, bell, gong', + 495: 'china cabinet, china closet', + 496: 'Christmas stocking', + 497: 'church, church building', + 498: 'cinema, movie theater, movie theatre, movie house, picture palace', + 499: 'cleaver, meat cleaver, chopper', + 500: 'cliff dwelling', + 501: 'cloak', + 502: 'clog, geta, patten, sabot', + 503: 'cocktail shaker', + 504: 'coffee mug', + 505: 'coffeepot', + 506: 'coil, spiral, volute, whorl, helix', + 507: 'combination lock', + 508: 'computer keyboard, keypad', + 509: 'confectionery, confectionary, candy store', + 510: 'container ship, containership, container vessel', + 511: 'convertible', + 512: 'corkscrew, bottle screw', + 513: 'cornet, horn, trumpet, trump', + 514: 'cowboy boot', + 515: 'cowboy hat, ten-gallon hat', + 516: 'cradle', + 517: 'crane', + 518: 'crash helmet', + 519: 'crate', + 520: 'crib, cot', + 521: 'Crock Pot', + 522: 'croquet ball', + 523: 'crutch', + 524: 'cuirass', + 525: 'dam, dike, dyke', + 526: 'desk', + 527: 'desktop computer', + 528: 'dial telephone, dial phone', + 529: 'diaper, nappy, napkin', + 530: 'digital clock', + 531: 'digital watch', + 532: 'dining table, board', + 533: 'dishrag, dishcloth', + 534: 'dishwasher, dish washer, dishwashing machine', + 535: 'disk brake, disc brake', + 536: 'dock, dockage, docking facility', + 537: 'dogsled, dog sled, dog sleigh', + 538: 'dome', + 539: 'doormat, welcome mat', + 540: 'drilling platform, offshore rig', + 541: 'drum, membranophone, tympan', + 542: 'drumstick', + 543: 'dumbbell', + 544: 'Dutch oven', + 545: 'electric fan, blower', + 546: 'electric guitar', + 547: 'electric locomotive', + 548: 'entertainment center', + 549: 'envelope', + 550: 'espresso maker', + 551: 'face powder', + 552: 'feather boa, boa', + 553: 'file, file cabinet, filing 
cabinet', + 554: 'fireboat', + 555: 'fire engine, fire truck', + 556: 'fire screen, fireguard', + 557: 'flagpole, flagstaff', + 558: 'flute, transverse flute', + 559: 'folding chair', + 560: 'football helmet', + 561: 'forklift', + 562: 'fountain', + 563: 'fountain pen', + 564: 'four-poster', + 565: 'freight car', + 566: 'French horn, horn', + 567: 'frying pan, frypan, skillet', + 568: 'fur coat', + 569: 'garbage truck, dustcart', + 570: 'gasmask, respirator, gas helmet', + 571: 'gas pump, gasoline pump, petrol pump, island dispenser', + 572: 'goblet', + 573: 'go-kart', + 574: 'golf ball', + 575: 'golfcart, golf cart', + 576: 'gondola', + 577: 'gong, tam-tam', + 578: 'gown', + 579: 'grand piano, grand', + 580: 'greenhouse, nursery, glasshouse', + 581: 'grille, radiator grille', + 582: 'grocery store, grocery, food market, market', + 583: 'guillotine', + 584: 'hair slide', + 585: 'hair spray', + 586: 'half track', + 587: 'hammer', + 588: 'hamper', + 589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier', + 590: 'hand-held computer, hand-held microcomputer', + 591: 'handkerchief, hankie, hanky, hankey', + 592: 'hard disc, hard disk, fixed disk', + 593: 'harmonica, mouth organ, harp, mouth harp', + 594: 'harp', + 595: 'harvester, reaper', + 596: 'hatchet', + 597: 'holster', + 598: 'home theater, home theatre', + 599: 'honeycomb', + 600: 'hook, claw', + 601: 'hoopskirt, crinoline', + 602: 'horizontal bar, high bar', + 603: 'horse cart, horse-cart', + 604: 'hourglass', + 605: 'iPod', + 606: 'iron, smoothing iron', + 607: "jack-o'-lantern", + 608: 'jean, blue jean, denim', + 609: 'jeep, landrover', + 610: 'jersey, T-shirt, tee shirt', + 611: 'jigsaw puzzle', + 612: 'jinrikisha, ricksha, rickshaw', + 613: 'joystick', + 614: 'kimono', + 615: 'knee pad', + 616: 'knot', + 617: 'lab coat, laboratory coat', + 618: 'ladle', + 619: 'lampshade, lamp shade', + 620: 'laptop, laptop computer', + 621: 'lawn mower, mower', + 622: 'lens cap, lens cover', + 623: 'letter opener, paper knife, paperknife', + 624: 'library', + 625: 'lifeboat', + 626: 'lighter, light, igniter, ignitor', + 627: 'limousine, limo', + 628: 'liner, ocean liner', + 629: 'lipstick, lip rouge', + 630: 'Loafer', + 631: 'lotion', + 632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system', + 633: "loupe, jeweler's loupe", + 634: 'lumbermill, sawmill', + 635: 'magnetic compass', + 636: 'mailbag, postbag', + 637: 'mailbox, letter box', + 638: 'maillot', + 639: 'maillot, tank suit', + 640: 'manhole cover', + 641: 'maraca', + 642: 'marimba, xylophone', + 643: 'mask', + 644: 'matchstick', + 645: 'maypole', + 646: 'maze, labyrinth', + 647: 'measuring cup', + 648: 'medicine chest, medicine cabinet', + 649: 'megalith, megalithic structure', + 650: 'microphone, mike', + 651: 'microwave, microwave oven', + 652: 'military uniform', + 653: 'milk can', + 654: 'minibus', + 655: 'miniskirt, mini', + 656: 'minivan', + 657: 'missile', + 658: 'mitten', + 659: 'mixing bowl', + 660: 'mobile home, manufactured home', + 661: 'Model T', + 662: 'modem', + 663: 'monastery', + 664: 'monitor', + 665: 'moped', + 666: 'mortar', + 667: 'mortarboard', + 668: 'mosque', + 669: 'mosquito net', + 670: 'motor scooter, scooter', + 671: 'mountain bike, all-terrain bike, off-roader', + 672: 'mountain tent', + 673: 'mouse, computer mouse', + 674: 'mousetrap', + 675: 'moving van', + 676: 'muzzle', + 677: 'nail', + 678: 'neck brace', + 679: 'necklace', + 680: 'nipple', + 681: 'notebook, notebook computer', + 682: 'obelisk', + 683: 'oboe, hautboy, 
hautbois', + 684: 'ocarina, sweet potato', + 685: 'odometer, hodometer, mileometer, milometer', + 686: 'oil filter', + 687: 'organ, pipe organ', + 688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO', + 689: 'overskirt', + 690: 'oxcart', + 691: 'oxygen mask', + 692: 'packet', + 693: 'paddle, boat paddle', + 694: 'paddlewheel, paddle wheel', + 695: 'padlock', + 696: 'paintbrush', + 697: "pajama, pyjama, pj's, jammies", + 698: 'palace', + 699: 'panpipe, pandean pipe, syrinx', + 700: 'paper towel', + 701: 'parachute, chute', + 702: 'parallel bars, bars', + 703: 'park bench', + 704: 'parking meter', + 705: 'passenger car, coach, carriage', + 706: 'patio, terrace', + 707: 'pay-phone, pay-station', + 708: 'pedestal, plinth, footstall', + 709: 'pencil box, pencil case', + 710: 'pencil sharpener', + 711: 'perfume, essence', + 712: 'Petri dish', + 713: 'photocopier', + 714: 'pick, plectrum, plectron', + 715: 'pickelhaube', + 716: 'picket fence, paling', + 717: 'pickup, pickup truck', + 718: 'pier', + 719: 'piggy bank, penny bank', + 720: 'pill bottle', + 721: 'pillow', + 722: 'ping-pong ball', + 723: 'pinwheel', + 724: 'pirate, pirate ship', + 725: 'pitcher, ewer', + 726: "plane, carpenter's plane, woodworking plane", + 727: 'planetarium', + 728: 'plastic bag', + 729: 'plate rack', + 730: 'plow, plough', + 731: "plunger, plumber's helper", + 732: 'Polaroid camera, Polaroid Land camera', + 733: 'pole', + 734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria', + 735: 'poncho', + 736: 'pool table, billiard table, snooker table', + 737: 'pop bottle, soda bottle', + 738: 'pot, flowerpot', + 739: "potter's wheel", + 740: 'power drill', + 741: 'prayer rug, prayer mat', + 742: 'printer', + 743: 'prison, prison house', + 744: 'projectile, missile', + 745: 'projector', + 746: 'puck, hockey puck', + 747: 'punching bag, punch bag, punching ball, punchball', + 748: 'purse', + 749: 'quill, quill pen', + 750: 'quilt, comforter, comfort, puff', + 751: 'racer, race car, racing car', + 752: 'racket, racquet', + 753: 'radiator', + 754: 'radio, wireless', + 755: 'radio telescope, radio reflector', + 756: 'rain barrel', + 757: 'recreational vehicle, RV, R.V.', + 758: 'reel', + 759: 'reflex camera', + 760: 'refrigerator, icebox', + 761: 'remote control, remote', + 762: 'restaurant, eating house, eating place, eatery', + 763: 'revolver, six-gun, six-shooter', + 764: 'rifle', + 765: 'rocking chair, rocker', + 766: 'rotisserie', + 767: 'rubber eraser, rubber, pencil eraser', + 768: 'rugby ball', + 769: 'rule, ruler', + 770: 'running shoe', + 771: 'safe', + 772: 'safety pin', + 773: 'saltshaker, salt shaker', + 774: 'sandal', + 775: 'sarong', + 776: 'sax, saxophone', + 777: 'scabbard', + 778: 'scale, weighing machine', + 779: 'school bus', + 780: 'schooner', + 781: 'scoreboard', + 782: 'screen, CRT screen', + 783: 'screw', + 784: 'screwdriver', + 785: 'seat belt, seatbelt', + 786: 'sewing machine', + 787: 'shield, buckler', + 788: 'shoe shop, shoe-shop, shoe store', + 789: 'shoji', + 790: 'shopping basket', + 791: 'shopping cart', + 792: 'shovel', + 793: 'shower cap', + 794: 'shower curtain', + 795: 'ski', + 796: 'ski mask', + 797: 'sleeping bag', + 798: 'slide rule, slipstick', + 799: 'sliding door', + 800: 'slot, one-armed bandit', + 801: 'snorkel', + 802: 'snowmobile', + 803: 'snowplow, snowplough', + 804: 'soap dispenser', + 805: 'soccer ball', + 806: 'sock', + 807: 'solar dish, solar collector, solar furnace', + 808: 'sombrero', + 809: 'soup bowl', + 810: 'space bar', + 811: 'space heater', 
+ 812: 'space shuttle', + 813: 'spatula', + 814: 'speedboat', + 815: "spider web, spider's web", + 816: 'spindle', + 817: 'sports car, sport car', + 818: 'spotlight, spot', + 819: 'stage', + 820: 'steam locomotive', + 821: 'steel arch bridge', + 822: 'steel drum', + 823: 'stethoscope', + 824: 'stole', + 825: 'stone wall', + 826: 'stopwatch, stop watch', + 827: 'stove', + 828: 'strainer', + 829: 'streetcar, tram, tramcar, trolley, trolley car', + 830: 'stretcher', + 831: 'studio couch, day bed', + 832: 'stupa, tope', + 833: 'submarine, pigboat, sub, U-boat', + 834: 'suit, suit of clothes', + 835: 'sundial', + 836: 'sunglass', + 837: 'sunglasses, dark glasses, shades', + 838: 'sunscreen, sunblock, sun blocker', + 839: 'suspension bridge', + 840: 'swab, swob, mop', + 841: 'sweatshirt', + 842: 'swimming trunks, bathing trunks', + 843: 'swing', + 844: 'switch, electric switch, electrical switch', + 845: 'syringe', + 846: 'table lamp', + 847: 'tank, army tank, armored combat vehicle, armoured combat vehicle', + 848: 'tape player', + 849: 'teapot', + 850: 'teddy, teddy bear', + 851: 'television, television system', + 852: 'tennis ball', + 853: 'thatch, thatched roof', + 854: 'theater curtain, theatre curtain', + 855: 'thimble', + 856: 'thresher, thrasher, threshing machine', + 857: 'throne', + 858: 'tile roof', + 859: 'toaster', + 860: 'tobacco shop, tobacconist shop, tobacconist', + 861: 'toilet seat', + 862: 'torch', + 863: 'totem pole', + 864: 'tow truck, tow car, wrecker', + 865: 'toyshop', + 866: 'tractor', + 867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi', + 868: 'tray', + 869: 'trench coat', + 870: 'tricycle, trike, velocipede', + 871: 'trimaran', + 872: 'tripod', + 873: 'triumphal arch', + 874: 'trolleybus, trolley coach, trackless trolley', + 875: 'trombone', + 876: 'tub, vat', + 877: 'turnstile', + 878: 'typewriter keyboard', + 879: 'umbrella', + 880: 'unicycle, monocycle', + 881: 'upright, upright piano', + 882: 'vacuum, vacuum cleaner', + 883: 'vase', + 884: 'vault', + 885: 'velvet', + 886: 'vending machine', + 887: 'vestment', + 888: 'viaduct', + 889: 'violin, fiddle', + 890: 'volleyball', + 891: 'waffle iron', + 892: 'wall clock', + 893: 'wallet, billfold, notecase, pocketbook', + 894: 'wardrobe, closet, press', + 895: 'warplane, military plane', + 896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin', + 897: 'washer, automatic washer, washing machine', + 898: 'water bottle', + 899: 'water jug', + 900: 'water tower', + 901: 'whiskey jug', + 902: 'whistle', + 903: 'wig', + 904: 'window screen', + 905: 'window shade', + 906: 'Windsor tie', + 907: 'wine bottle', + 908: 'wing', + 909: 'wok', + 910: 'wooden spoon', + 911: 'wool, woolen, woollen', + 912: 'worm fence, snake fence, snake-rail fence, Virginia fence', + 913: 'wreck', + 914: 'yawl', + 915: 'yurt', + 916: 'web site, website, internet site, site', + 917: 'comic book', + 918: 'crossword puzzle, crossword', + 919: 'street sign', + 920: 'traffic light, traffic signal, stoplight', + 921: 'book jacket, dust cover, dust jacket, dust wrapper', + 922: 'menu', + 923: 'plate', + 924: 'guacamole', + 925: 'consomme', + 926: 'hot pot, hotpot', + 927: 'trifle', + 928: 'ice cream, icecream', + 929: 'ice lolly, lolly, lollipop, popsicle', + 930: 'French loaf', + 931: 'bagel, beigel', + 932: 'pretzel', + 933: 'cheeseburger', + 934: 'hotdog, hot dog, red hot', + 935: 'mashed potato', + 936: 'head cabbage', + 937: 'broccoli', + 938: 'cauliflower', + 939: 'zucchini, courgette', + 940: 'spaghetti squash', + 
941: 'acorn squash', + 942: 'butternut squash', + 943: 'cucumber, cuke', + 944: 'artichoke, globe artichoke', + 945: 'bell pepper', + 946: 'cardoon', + 947: 'mushroom', + 948: 'Granny Smith', + 949: 'strawberry', + 950: 'orange', + 951: 'lemon', + 952: 'fig', + 953: 'pineapple, ananas', + 954: 'banana', + 955: 'jackfruit, jak, jack', + 956: 'custard apple', + 957: 'pomegranate', + 958: 'hay', + 959: 'carbonara', + 960: 'chocolate sauce, chocolate syrup', + 961: 'dough', + 962: 'meat loaf, meatloaf', + 963: 'pizza, pizza pie', + 964: 'potpie', + 965: 'burrito', + 966: 'red wine', + 967: 'espresso', + 968: 'cup', + 969: 'eggnog', + 970: 'alp', + 971: 'bubble', + 972: 'cliff, drop, drop-off', + 973: 'coral reef', + 974: 'geyser', + 975: 'lakeside, lakeshore', + 976: 'promontory, headland, head, foreland', + 977: 'sandbar, sand bar', + 978: 'seashore, coast, seacoast, sea-coast', + 979: 'valley, vale', + 980: 'volcano', + 981: 'ballplayer, baseball player', + 982: 'groom, bridegroom', + 983: 'scuba diver', + 984: 'rapeseed', + 985: 'daisy', + 986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", + 987: 'corn', + 988: 'acorn', + 989: 'hip, rose hip, rosehip', + 990: 'buckeye, horse chestnut, conker', + 991: 'coral fungus', + 992: 'agaric', + 993: 'gyromitra', + 994: 'stinkhorn, carrion fungus', + 995: 'earthstar', + 996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa', + 997: 'bolete', + 998: 'ear, spike, capitulum', + 999: 'toilet tissue, toilet paper, bathroom tissue'} diff --git a/featup/datasets/JitteredImage.py b/featup/datasets/JitteredImage.py new file mode 100644 index 0000000000000000000000000000000000000000..34cc9acd0b8829c0e0ce4e07e2a0a9d49746ea84 --- /dev/null +++ b/featup/datasets/JitteredImage.py @@ -0,0 +1,69 @@ +import random + +import torch +import torch.nn.functional as F +from torch.utils.data import Dataset + + +def apply_jitter(img, max_pad, transform_params): + h, w = img.shape[2:] + + padded = F.pad(img, [max_pad] * 4, mode="reflect") + + zoom = transform_params["zoom"].item() + x = transform_params["x"].item() + y = transform_params["y"].item() + flip = transform_params["flip"].item() + + if zoom > 1.0: + zoomed = F.interpolate(padded, scale_factor=zoom, mode="bilinear") + else: + zoomed = padded + + cropped = zoomed[:, :, x:h + x, y:w + y] + + if flip: + return torch.flip(cropped, [3]) + else: + return cropped + + +def sample_transform(use_flips, max_pad, max_zoom, h, w): + if use_flips: + flip = random.random() > .5 + else: + flip = False + + apply_zoom = random.random() > .5 + if apply_zoom: + zoom = random.random() * (max_zoom - 1) + 1 + else: + zoom = 1.0 + + valid_area_h = (int((h + max_pad * 2) * zoom) - h) + 1 + valid_area_w = (int((w + max_pad * 2) * zoom) - w) + 1 + + return { + "x": torch.tensor(torch.randint(0, valid_area_h, ()).item()), + "y": torch.tensor(torch.randint(0, valid_area_w, ()).item()), + "zoom": torch.tensor(zoom), + "flip": torch.tensor(flip) + } + + +class JitteredImage(Dataset): + + def __init__(self, img, length, use_flips, max_zoom, max_pad): + self.img = img + self.length = length + self.use_flips = use_flips + self.max_zoom = max_zoom + self.max_pad = max_pad + + def __len__(self): + return self.length + + def __getitem__(self, item): + h, w = self.img.shape[2:] + transform_params = sample_transform(self.use_flips, self.max_pad, self.max_zoom, h, w) + return apply_jitter(self.img, self.max_pad, transform_params).squeeze(0), transform_params diff --git 
a/featup/datasets/SampleImage.py b/featup/datasets/SampleImage.py new file mode 100644 index 0000000000000000000000000000000000000000..1baea41cbea1a7ec6275e633b9b8fc96c60c1aee --- /dev/null +++ b/featup/datasets/SampleImage.py @@ -0,0 +1,22 @@ +from PIL import Image +from torch.utils.data import Dataset + + +class SampleImage(Dataset): + def __init__(self, paths, transform, **kwargs): + self.paths = paths + self.transform = transform + + def __getitem__(self, idx): + image_path = self.paths[idx] + image = Image.open(image_path).convert('RGB') + if self.transform is not None: + image = self.transform(image) + batch = { + "img": image, + "img_path": image_path + } + return batch + + def __len__(self): + return len(self.paths) diff --git a/featup/datasets/__init__.py b/featup/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/featup/datasets/util.py b/featup/datasets/util.py new file mode 100644 index 0000000000000000000000000000000000000000..e406d4a5ad8c3052dec911ecb5403e145d0888b6 --- /dev/null +++ b/featup/datasets/util.py @@ -0,0 +1,58 @@ +from torch.utils.data import Dataset +from featup.datasets.ImageNetSubset import ImageNetSubset +from featup.datasets.COCO import Coco +from featup.datasets.DAVIS import DAVIS +from featup.datasets.SampleImage import SampleImage + + +class SlicedDataset(Dataset): + def __init__(self, ds, start, end): + self.ds = ds + self.start = max(0, start) + self.end = min(len(ds), end) + + def __getitem__(self, index): + if index >= self.__len__(): + raise StopIteration + + return self.ds[self.start + index] + + def __len__(self): + return self.end - self.start + + + +class SingleImageDataset(Dataset): + def __init__(self, i, ds, l=None): + self.ds = ds + self.i = i + self.l = len(self.ds) if l is None else l + + def __len__(self): + return self.l + + def __getitem__(self, item): + return self.ds[self.i] + + +def get_dataset(dataroot, name, split, transform, target_transform, include_labels): + if name == 'imagenet': + if split == 'val': + imagenet_subset = f'datalists/val_paths_vit.txt' + else: + imagenet_subset = None + + return ImageNetSubset(dataroot, split, transform, target_transform, + include_labels=include_labels, subset=imagenet_subset) + elif name == 'cocostuff': + return Coco(dataroot, split, transform, target_transform, include_labels=include_labels) + elif name.startswith('davis_'): + return DAVIS(dataroot, name.split("_")[-1], transform) + elif name == "sample": + return SampleImage( + paths=["../sample-images/bird_left.jpg", + "../sample-images/bird_right.jpg"], + transform=transform + ) + else: + raise ValueError(f"Unknown dataset {name}") diff --git a/featup/downsamplers.py b/featup/downsamplers.py new file mode 100644 index 0000000000000000000000000000000000000000..f5eae850b6bcf5f927b437819bbfdf9a722ac320 --- /dev/null +++ b/featup/downsamplers.py @@ -0,0 +1,79 @@ +import torch +import torch.nn.functional as F +from kornia.filters import gaussian_blur2d + + +class SimpleDownsampler(torch.nn.Module): + + def get_kernel(self): + k = self.kernel_params.unsqueeze(0).unsqueeze(0).abs() + k /= k.sum() + return k + + def __init__(self, kernel_size, final_size, *args, **kwargs): + super().__init__(*args, **kwargs) + self.kernel_size = kernel_size + self.final_size = final_size + self.kernel_params = torch.nn.Parameter(torch.ones(kernel_size, kernel_size)) + + def forward(self, imgs, guidance): + b, c, h, w = imgs.shape + input_imgs = imgs.reshape(b * c, 1, h, w) 
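+        # imgs is flattened to (b*c, 1, h, w) so the single learned kernel from
+        # get_kernel() is shared across all channels, and the stride below is chosen
+        # so that `final_size` kernel placements span the whole input.  Hypothetical
+        # numbers: h = 224, kernel_size = 16, final_size = 14 gives
+        # stride = (224 - 16) // 13 = 16, i.e. the windows exactly tile the map.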
+ stride = (h - self.kernel_size) // (self.final_size - 1) + + return F.conv2d( + input_imgs, + self.get_kernel(), + stride=stride + ).reshape(b, c, self.final_size, self.final_size) + + +class AttentionDownsampler(torch.nn.Module): + + def __init__(self, dim, kernel_size, final_size, blur_attn, *args, **kwargs): + super().__init__(*args, **kwargs) + self.kernel_size = kernel_size + self.final_size = final_size + self.in_dim = dim + self.attention_net = torch.nn.Sequential( + torch.nn.Dropout(p=.2), + torch.nn.Linear(self.in_dim, 1) + ) + self.w = torch.nn.Parameter(torch.ones(kernel_size, kernel_size).cuda() + + .01 * torch.randn(kernel_size, kernel_size).cuda()) + self.b = torch.nn.Parameter(torch.zeros(kernel_size, kernel_size).cuda() + + .01 * torch.randn(kernel_size, kernel_size).cuda()) + self.blur_attn = blur_attn + + def forward_attention(self, feats, guidance): + return self.attention_net(feats.permute(0, 2, 3, 1)).squeeze(-1).unsqueeze(1) + + def forward(self, hr_feats, guidance): + b, c, h, w = hr_feats.shape + + if self.blur_attn: + inputs = gaussian_blur2d(hr_feats, 5, (1.0, 1.0)) + else: + inputs = hr_feats + + stride = (h - self.kernel_size) // (self.final_size - 1) + + patches = torch.nn.Unfold(self.kernel_size, stride=stride)(inputs) \ + .reshape( + (b, self.in_dim, self.kernel_size * self.kernel_size, self.final_size, self.final_size * int(w / h))) \ + .permute(0, 3, 4, 2, 1) + + patch_logits = self.attention_net(patches).squeeze(-1) + + b, h, w, p = patch_logits.shape + dropout = torch.rand(b, h, w, 1, device=patch_logits.device) > 0.2 + + w = self.w.flatten().reshape(1, 1, 1, -1) + b = self.b.flatten().reshape(1, 1, 1, -1) + + patch_attn_logits = (patch_logits * dropout) * w + b + patch_attention = F.softmax(patch_attn_logits, dim=-1) + + downsampled = torch.einsum("bhwpc,bhwp->bchw", patches, patch_attention) + + return downsampled[:, :c, :, :] diff --git a/featup/featurizers/CLIP.py b/featup/featurizers/CLIP.py new file mode 100644 index 0000000000000000000000000000000000000000..7c94572fe50ef872ea55da88868432832bf8bc0f --- /dev/null +++ b/featup/featurizers/CLIP.py @@ -0,0 +1,44 @@ +import clip +import torch +from torch import nn +import os + +class CLIPFeaturizer(nn.Module): + + def __init__(self): + super().__init__() + self.model, self.preprocess = clip.load( + "ViT-B/16", + download_root=os.getenv('TORCH_HOME', os.path.join(os.path.expanduser('~'), '.cache', 'torch')) + ) + self.model.eval() + + def get_cls_token(self, img): + return self.model.encode_image(img).to(torch.float32) + + def forward(self, img): + features = self.model.get_visual_features(img, include_cls=False).to(torch.float32) + return features + + +if __name__ == "__main__": + import torchvision.transforms as T + from PIL import Image + from shared import norm, crop_to_divisor + + device = "cuda" if torch.cuda.is_available() else "cpu" + + image = Image.open("../samples/lex1.jpg") + load_size = 224 # * 3 + transform = T.Compose([ + T.Resize(load_size, Image.BILINEAR), + # T.CenterCrop(load_size), + T.ToTensor(), + lambda x: crop_to_divisor(x, 16), + norm]) + + model = CLIPFeaturizer().cuda() + + results = model(transform(image).cuda().unsqueeze(0)) + + print(clip.available_models()) diff --git a/featup/featurizers/DINO.py b/featup/featurizers/DINO.py new file mode 100644 index 0000000000000000000000000000000000000000..f5f526f61c24ac25925b606859ba825f76b66981 --- /dev/null +++ b/featup/featurizers/DINO.py @@ -0,0 +1,448 @@ +import math +import warnings +from functools import partial + +import timm 
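+# The ViT below follows the reference DINO implementation
+# (https://github.com/facebookresearch/dino) so that intermediate tokens can be
+# extracted at arbitrary input resolutions.  Rough usage sketch for the
+# DINOFeaturizer defined at the bottom of this file (the arch string and shapes
+# here are illustrative, not prescriptive):
+#
+#     featurizer = DINOFeaturizer("dino_vits16", patch_size=16, feat_type="token")
+#     feats = featurizer(img)   # img: (B, 3, H, W), H and W divisible by 16
+#     # -> feats: (B, 384, H // 16, W // 16) low-resolution token features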
+import torch +import torch.nn as nn + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, return_qkv=False): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x, attn, qkv + + +class Block(nn.Module): + def __init__(self, dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, return_attention=False, return_qkv=False): + y, attn, qkv = self.attn(self.norm1(x)) + if return_attention: + return attn + x = x + self.drop_path(y) + x = x + self.drop_path(self.mlp(self.norm2(x))) + if return_qkv: + return x, attn, qkv + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + num_patches = (img_size // patch_size) * (img_size // patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + x = self.proj(x) + if x.shape[-2] % 2 == 1: + x = x[:, :, :-1, :-1] + return x.flatten(2).transpose(1, 2) + + +class VisionTransformer(nn.Module): + """ Vision Transformer """ + + def __init__(self, + img_size=[224], + patch_size=16, + in_chans=3, + num_classes=0, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_layer=nn.LayerNorm, + **kwargs): + super().__init__() + + self.num_features = self.embed_dim = embed_dim + + self.patch_embed = PatchEmbed( + img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + # Classifier head + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def interpolate_pos_encoding(self, x, w, h): + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + class_pos_embed = self.pos_embed[:, 0] + patch_pos_embed = self.pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_embed.patch_size + h0 = h // self.patch_embed.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), + mode='bicubic', + ) + assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] + 
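+        # The N pretrained patch positions form a sqrt(N) x sqrt(N) grid (e.g., for
+        # 224px pretraining with 16px patches, N = 196 -> a 14x14 grid); bicubic
+        # resampling of that grid to w0 x h0 is what allows inference on images
+        # whose patch grid differs from the pretraining resolution.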
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def prepare_tokens(self, x): + B, nc, w, h = x.shape + x = self.patch_embed(x) # patch linear embedding + + # add the [CLS] token to the embed patch tokens + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # add positional encoding to each token + x = x + self.interpolate_pos_encoding(x, w, h) + + return self.pos_drop(x) + + def forward(self, x): + x = self.prepare_tokens(x) + for blk in self.blocks: + x = blk(x) + x = self.norm(x) + return x[:, 0] + + def forward_feats(self, x): + x = self.prepare_tokens(x) + for blk in self.blocks: + x = blk(x) + x = self.norm(x) + return x + + def get_intermediate_feat(self, x, n=1, norm=True): + x = self.prepare_tokens(x) + # we return the output tokens from the `n` last blocks + feat = [] + attns = [] + qkvs = [] + for i, blk in enumerate(self.blocks): + x, attn, qkv = blk(x, return_qkv=True) + if len(self.blocks) - i <= n: + if norm: + feat.append(self.norm(x)) + else: + feat.append(x) + qkvs.append(qkv) + attns.append(attn) + return feat, attns, qkvs + + def get_last_selfattention(self, x): + x = self.prepare_tokens(x) + for i, blk in enumerate(self.blocks): + if i < len(self.blocks) - 1: + x = blk(x) + else: + # return attention of the last block + return blk(x, return_attention=True) + + def get_intermediate_layers(self, x, n=1): + x = self.prepare_tokens(x) + # we return the output tokens from the `n` last blocks + output = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if len(self.blocks) - i <= n: + output.append(self.norm(x)) + return output + + +def vit_tiny(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_small(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_base(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +class DINOHead(nn.Module): + def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, + bottleneck_dim=256): + super().__init__() + nlayers = max(nlayers, 1) + if nlayers == 1: + self.mlp = nn.Linear(in_dim, bottleneck_dim) + else: + layers = [nn.Linear(in_dim, hidden_dim)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim)) + self.mlp = nn.Sequential(*layers) + self.apply(self._init_weights) + self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + if norm_last_layer: + self.last_layer.weight_g.requires_grad = False + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + x 
= nn.functional.normalize(x, dim=-1, p=2) + x = self.last_layer(x) + return x + + + +class DINOFeaturizer(nn.Module): + + def __init__(self, arch, patch_size, feat_type): + super().__init__() + self.arch = arch + self.patch_size = patch_size + self.feat_type = feat_type + + self.model = vit_small( + patch_size=patch_size, + num_classes=0) + + if "3d-dino" in arch: + state_dict = torch.load("../models/3d-dino-co3d.pth")["teacher"] + state_dict = {k.replace("module.", "").replace("backbone.", ""): v for k, v in state_dict.items()} + state_dict = {k: v for k, v in state_dict.items() if "head." not in k} + elif "iarpa-dino" in arch: + state_dict = torch.load("../models/dino_iarpa.pth")["teacher"] + state_dict = {k.replace("module.", "").replace("backbone.", ""): v for k, v in state_dict.items()} + state_dict = {k: v for k, v in state_dict.items() if "head." not in k} + elif "chk-dino" in arch: + state_dict = torch.load("../models/dino_deitsmall16_pretrain_full_checkpoint.pth")["teacher"] + state_dict = {k.replace("module.", "").replace("backbone.", ""): v for k, v in state_dict.items()} + state_dict = {k: v for k, v in state_dict.items() if "head." not in k} + elif "ft_dino" in arch: + arch = "_".join(arch.split("_")[:-1]) + state_dict = torch.load("../models/{}.pth".format(arch))["teacher"] + state_dict = {k.replace("module.", "").replace("backbone.", ""): v for k, v in state_dict.items()} + state_dict = {k: v for k, v in state_dict.items() if "head." not in k} + # elif "v2" in arch: + # state_dict = torch.hub.load('facebookresearch/dinov2:main', self.arch).state_dict() + elif "dino" in arch: + state_dict = torch.hub.load('facebookresearch/dino:main', self.arch).state_dict() + elif arch is not None: # model from timm -- load weights from timm to dino model (enables working on arbitrary size images). 
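+            # timm checkpoints bundle a classification head that has no counterpart
+            # in the backbone-only ViT above, so 'head.weight' and 'head.bias' are
+            # dropped below before the strict load_state_dict call.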
+ temp_model = timm.create_model(self.arch, pretrained=True) + state_dict = temp_model.state_dict() + del state_dict['head.weight'] + del state_dict['head.bias'] + + if arch is not None: + self.model.load_state_dict(state_dict, strict=True) + + if arch == "vit_small": + self.n_feats = 384 + else: + self.n_feats = 768 + + def get_cls_token(self, img): + return self.model.forward(img) + + def forward(self, img, n=1, include_cls=False): + assert (img.shape[2] % self.patch_size == 0) + assert (img.shape[3] % self.patch_size == 0) + + feat, attn, qkv = self.model.get_intermediate_feat(img, n=n) + feat, attn, qkv = feat[0], attn[0], qkv[0] + + feat_h = img.shape[2] // self.patch_size + feat_w = img.shape[3] // self.patch_size + + if self.feat_type == "token": + image_feat = feat[:, 1:, :].reshape(feat.shape[0], feat_h, feat_w, -1).permute(0, 3, 1, 2) + elif self.feat_type == "key": + x = qkv[1, :, :, 1:, :] # remove cls token + desc = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1) + image_feat = desc.reshape(desc.shape[0], feat_h, feat_w, desc.shape[2]) \ + .permute(0, 3, 1, 2) + else: + raise ValueError("Unknown feat type:{}".format(self.feat_type)) + + if include_cls: + return image_feat, feat[:, 0, :] + + return image_feat + diff --git a/featup/featurizers/DINOv2.py b/featup/featurizers/DINOv2.py new file mode 100644 index 0000000000000000000000000000000000000000..483f0e4991c19f34ced9cf220076f6b00b76d8a9 --- /dev/null +++ b/featup/featurizers/DINOv2.py @@ -0,0 +1,436 @@ +import math +import warnings +from functools import partial + +import timm +import torch +import torch.nn as nn + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +from featup.featurizers.dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock + + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=NestedTensorBlock, + ffn_layer="mlp", + block_chunks=1, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): 
stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + + def get_intermediate_feat(self, x, n=1, norm=True): + x = self.prepare_tokens_with_masks(x) + # we return the output tokens from the `n` last blocks + feat = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if len(self.blocks) - i <= n: + if norm: + feat.append(self.norm(x)) + else: + feat.append(x) + return feat + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at 
https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), + mode="bicubic", + ) + + assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_patchtokens": x_norm[:, 1:], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_patchtokens": x_norm[:, 1:], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. 
If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1:] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + if drop_prob == 0. 
or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +def vit_base(patch_size=16, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +def vit_large(patch_size=16, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +class DINOv2Featurizer(nn.Module): + + def __init__(self, arch, patch_size, feat_type): + super().__init__() + self.arch = arch + self.patch_size = patch_size + self.feat_type = feat_type + + self.n_feats = 128 + self.model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14') + + def get_cls_token(self, img): + return self.model.forward(img) + + def forward(self, img, n=1, include_cls=False): + h = img.shape[2] // self.patch_size + w = img.shape[3] // self.patch_size + return self.model.forward_features(img)["x_norm_patchtokens"].reshape(-1, h, w, 384).permute(0, 3, 1, 2) diff --git a/featup/featurizers/DeepLabV3.py b/featup/featurizers/DeepLabV3.py new file mode 100644 index 0000000000000000000000000000000000000000..20d80e82c214ff8c7620d7e340c26312689577c7 --- /dev/null +++ b/featup/featurizers/DeepLabV3.py @@ -0,0 +1,13 @@ +from torch import nn + + +class DeepLabV3Featurizer(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def get_cls_token(self, img): + return self.model.forward(img) + + def forward(self, img, layer_num=-1): + return self.model.backbone(img)['out'] diff --git a/featup/featurizers/MAE.py b/featup/featurizers/MAE.py new file mode 100644 index 0000000000000000000000000000000000000000..2bae45464923a05e79923445caa820e3e7b13c82 --- /dev/null +++ b/featup/featurizers/MAE.py @@ -0,0 +1,473 @@ +from functools import partial + +import numpy as np +import torch +import torch.nn as nn +import os +from timm.models.vision_transformer import Block +import torch.nn.functional as F + + +class PatchEmbed(nn.Module): + """ 2D Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): + super().__init__() + img_size = (img_size, img_size) + 
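+        # The scalar sizes are promoted to tuples here and on the next line, and the
+        # patch grid below is img_size // patch_size per side; the default 224 / 16
+        # arguments give a 14 x 14 grid of 196 patches.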
patch_size = (patch_size, patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + B, C, H, W = x.shape + # assert H == self.img_size[0] and W == self.img_size[1], \ + # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float) + omega /= embed_dim / 2. + omega = 1. 
/ 10000 ** omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +# -------------------------------------------------------- +# Interpolate position embeddings for high-resolution +# References: +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- +def interpolate_pos_embed(model, checkpoint_model): + if 'pos_embed' in checkpoint_model: + pos_embed_checkpoint = checkpoint_model['pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model['pos_embed'] = new_pos_embed + + +def sample(t: torch.Tensor, coords: torch.Tensor): + return F.grid_sample(t, coords.permute(0, 2, 1, 3), padding_mode='border', align_corners=True) + + +class MaskedAutoencoderViT(nn.Module): + """ Masked Autoencoder with VisionTransformer backbone + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, + embed_dim=1024, depth=24, num_heads=16, + decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, + mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False): + super().__init__() + + # -------------------------------------------------------------------------- + # MAE encoder specifics + self.embed_dim = embed_dim + self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), + requires_grad=False) # fixed sin-cos embedding + + self.blocks = nn.ModuleList([ + Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + # -------------------------------------------------------------------------- + + # -------------------------------------------------------------------------- + # MAE decoder specifics + self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) + + self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) + + self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), + requires_grad=False) # fixed sin-cos embedding + + self.decoder_blocks = nn.ModuleList([ + Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) + for i in 
range(decoder_depth)]) + + self.decoder_norm = norm_layer(decoder_embed_dim) + self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True) # decoder to patch + # -------------------------------------------------------------------------- + + self.norm_pix_loss = norm_pix_loss + + self.initialize_weights() + + def initialize_weights(self): + # initialization + # initialize (and freeze) pos_embed by sin-cos embedding + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches ** .5), + cls_token=True) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], + int(self.patch_embed.num_patches ** .5), cls_token=True) + self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) + + # initialize patch_embed like nn.Linear (instead of nn.Conv2d) + w = self.patch_embed.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) + torch.nn.init.normal_(self.cls_token, std=.02) + torch.nn.init.normal_(self.mask_token, std=.02) + + # initialize nn.Linear and nn.LayerNorm + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def patchify(self, imgs): + """ + imgs: (N, 3, H, W) + x: (N, L, patch_size**2 *3) + """ + p = self.patch_embed.patch_size[0] + assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 + + h = w = imgs.shape[2] // p + x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) + x = torch.einsum('nchpwq->nhwpqc', x) + x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3)) + return x + + def unpatchify(self, x): + """ + x: (N, L, patch_size**2 *3) + imgs: (N, 3, H, W) + """ + p = self.patch_embed.patch_size[0] + h = w = int(x.shape[1] ** .5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) + return imgs + + def random_masking(self, x, mask_ratio): + """ + Perform per-sample random masking by per-sample shuffling. + Per-sample shuffling is done by argsort random noise. 
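+        Keeping the first len_keep = int(L * (1 - mask_ratio)) entries of the argsort
+        is equivalent to drawing a uniformly random subset of patches, and ids_restore
+        records how to undo the shuffle so mask tokens can later be scattered back to
+        their original positions in the decoder.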
+ x: [N, L, D], sequence + """ + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=x.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return x_masked, mask, ids_restore + + def sample_pe(self, img, pe): + p = self.patch_embed.patch_size[0] + + H = img.shape[2] // p + W = img.shape[3] // p + + original_num_patches = 224 // p + embed_dim = pe.shape[-1] + + reshaped_pe = pe.squeeze(0)[1:] \ + .reshape(1, original_num_patches, original_num_patches, embed_dim) \ + .permute(0, 3, 1, 2) + + XX, YY = torch.meshgrid(torch.linspace(-1, 1, H, device=img.device, dtype=img.dtype), + torch.linspace(-1, 1, W, device=img.device, dtype=img.dtype)) + + coords = torch.cat([XX.unsqueeze(-1), YY.unsqueeze(-1)], dim=-1).unsqueeze(0) + + return sample(reshaped_pe, coords).reshape(embed_dim, H * W).permute(1, 0).unsqueeze(0) + + def featurize(self, img, n_decoder_blocks=None): + p = self.patch_embed.patch_size[0] + H = img.shape[2] // p + W = img.shape[3] // p + + # embed patches + x = self.patch_embed(img) + + # add pos embed w/o cls token + x = x + self.sample_pe(img, self.pos_embed) + + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # apply Transformer blocks + for blk in self.blocks: + x = blk(x) + x = self.norm(x) + + + # embed tokens + #x = self.decoder_embed(x) + # + # # add pos embed + # cls_token = x[:, :1] + self.decoder_pos_embed[0, :1] + # x = x[:, 1:] + self.sample_pe(img, self.decoder_pos_embed) + # x = torch.cat((cls_token, x), dim=1) + + # apply Transformer blocks + + # if n_decoder_blocks == "all": + # for blk in self.decoder_blocks: + # x = blk(x) + # x = self.decoder_norm(x) + # else: + # for blk in self.decoder_blocks[:7]: + # x = blk(x) + + # # predictor projection + # x = self.decoder_pred(x) + + # # remove cls token + # x = x[:, 1:, :] + # + # return x + + return x[:, 1:, :].reshape(shape=(x.shape[0], H, W, -1)) \ + .permute(0, 3, 1, 2), x[:, 0, :] + + def forward_encoder(self, img, mask_ratio): + # embed patches + x = self.patch_embed(img) + + # add pos embed w/o cls token + x = x + self.sample_pe(img, self.pos_embed) + + + # masking: length -> length * mask_ratio + x, mask, ids_restore = self.random_masking(x, mask_ratio) + + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # apply Transformer blocks + for blk in self.blocks: + x = blk(x) + x = self.norm(x) + + return x, mask, ids_restore + + def forward_decoder(self, x, ids_restore, img): + # embed tokens + x = self.decoder_embed(x) + + # append mask tokens to sequence + mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) + x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token + x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle + x 
= torch.cat([x[:, :1, :], x_], dim=1) # append cls token + + # # add pos embed + # x = x + self.decoder_pos_embed + + # add pos embed + cls_token = x[:, :1] + self.decoder_pos_embed[0, :1] + x = x[:, 1:] + self.sample_pe(img, self.decoder_pos_embed) + x = torch.cat((cls_token, x), dim=1) + print("foo") + + # apply Transformer blocks + for blk in self.decoder_blocks: + x = blk(x) + x = self.decoder_norm(x) + + # predictor projection + x = self.decoder_pred(x) + + # remove cls token + x = x[:, 1:, :] + + return x + + def forward_loss(self, imgs, pred, mask): + """ + imgs: [N, 3, H, W] + pred: [N, L, p*p*3] + mask: [N, L], 0 is keep, 1 is remove, + """ + target = self.patchify(imgs) + if self.norm_pix_loss: + mean = target.mean(dim=-1, keepdim=True) + var = target.var(dim=-1, keepdim=True) + target = (target - mean) / (var + 1.e-6) ** .5 + + loss = (pred - target) ** 2 + loss = loss.mean(dim=-1) # [N, L], mean loss per patch + + loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches + return loss + + def forward(self, imgs, mask_ratio=0.75): + latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) + pred = self.forward_decoder(latent, ids_restore, imgs) # [N, L, p*p*3] + loss = self.forward_loss(imgs, pred, mask) + return loss, pred, mask + + +class MAEFeaturizer(nn.Module): + + def __init__(self, arch="mae_vit_large_patch16_gan"): + super().__init__() + # build model + shared_args = dict( + decoder_embed_dim=512, + decoder_depth=8, + decoder_num_heads=16, + mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6) + ) + if arch == "mae_vit_base_patch16": + self.model = MaskedAutoencoderViT( + patch_size=16, embed_dim=768, depth=12, num_heads=12, **shared_args) + chkpoint_dir = '../models/mae_visualize_vit_base.pth' + elif arch == "mae_vit_large_patch16": + self.model = MaskedAutoencoderViT( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, **shared_args) + chkpoint_dir = '../models/mae_visualize_vit_large.pth' + elif arch == "mae_vit_large_patch16_gan": + self.model = MaskedAutoencoderViT( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, **shared_args) + chkpoint_dir = '../models/mae_visualize_vit_large_ganloss.pth' + elif arch == "mae_vit_huge_patch14": + self.model = MaskedAutoencoderViT( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, **shared_args) + chkpoint_dir = '../models/mae_visualize_vit_huge.pth' + else: + raise ValueError("Unknown model arch {}".format(arch)) + + # load model + chkpoint_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), chkpoint_dir) + + checkpoint = torch.load(chkpoint_dir) + self.model.load_state_dict(checkpoint['model'], strict=False) + + def get_cls_token(self, img): + feats, cls_token = self.model.featurize(img) + return cls_token + + def forward(self, img): + feats, cls_token = self.model.featurize(img) + return feats + + +if __name__ == "__main__": + import torchvision.transforms as T + from PIL import Image + from shared import norm, crop_to_divisor + + device = "cuda" if torch.cuda.is_available() else "cpu" + + image = Image.open("../samples/lex1.jpg") + load_size = 224 # * 3 + transform = T.Compose([ + T.Resize(load_size, Image.BILINEAR), + # T.CenterCrop(load_size), + T.ToTensor(), + lambda x: crop_to_divisor(x, 16), + norm]) + + model = MAEFeaturizer().cuda() + + results = model(transform(image).cuda().unsqueeze(0)) + + print(results.shape) diff --git a/featup/featurizers/MIDAS.py b/featup/featurizers/MIDAS.py new file mode 100644 index 
0000000000000000000000000000000000000000..f587fc132c1ffe067bbe10f863a537e29c711a22 --- /dev/null +++ b/featup/featurizers/MIDAS.py @@ -0,0 +1,569 @@ +import timm +import torch +import torch.nn as nn +import torchvision.transforms as T +from PIL import Image +from timm.models.layers import get_act_layer +import numpy as np +from torch.nn.functional import interpolate +import os + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. + + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) + + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, + use_vit_only=False, use_readout="ignore", in_features=[96, 256, 512, 1024]): + if backbone == "levit_384": + pretrained = _make_pretrained_levit_384( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [384, 512, 768], features, groups=groups, expand=expand + ) # LeViT 384 (backbone) + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + if len(in_shape) >= 4: + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. 
+ + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups = 1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn == True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn == True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn == True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups = 1 + + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + self.size = size + + def forward(self, *xs, size=None): + """Forward pass. 
+ + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = nn.functional.interpolate( + output, **modifier, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + + +def forward_levit(pretrained, x): + pretrained.model.forward_features(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + + layer_1 = pretrained.act_postprocess1(layer_1) + layer_2 = pretrained.act_postprocess2(layer_2) + layer_3 = pretrained.act_postprocess3(layer_3) + + return layer_1, layer_2, layer_3 + + +def _make_levit_backbone( + model, + hooks=[3, 11, 21], + patch_grid=[14, 14] +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + + pretrained.activations = activations + + patch_grid_size = np.array(patch_grid, dtype=int) + + pretrained.act_postprocess1 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size(patch_grid_size.tolist())) + ) + pretrained.act_postprocess2 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 2).astype(int)).tolist())) + ) + pretrained.act_postprocess3 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 4).astype(int)).tolist())) + ) + + return pretrained + + +class ConvTransposeNorm(nn.Sequential): + """ + Modification of + https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: ConvNorm + such that ConvTranspose2d is used instead of Conv2d. + """ + + def __init__( + self, in_chs, out_chs, kernel_size=1, stride=1, pad=0, dilation=1, + groups=1, bn_weight_init=1): + super().__init__() + self.add_module('c', + nn.ConvTranspose2d(in_chs, out_chs, kernel_size, stride, pad, dilation, groups, bias=False)) + self.add_module('bn', nn.BatchNorm2d(out_chs)) + + nn.init.constant_(self.bn.weight, bn_weight_init) + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 + m = nn.ConvTranspose2d( + w.size(1), w.size(0), w.shape[2:], stride=self.c.stride, + padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +def stem_b4_transpose(in_chs, out_chs, activation): + """ + Modification of + https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: stem_b16 + such that ConvTranspose2d is used instead of Conv2d and stem is also reduced to the half. 
+ """ + return nn.Sequential( + ConvTransposeNorm(in_chs, out_chs, 3, 2, 1), + activation(), + ConvTransposeNorm(out_chs, out_chs // 2, 3, 2, 1), + activation()) + + +def _make_pretrained_levit_384(pretrained, hooks=None): + model = timm.create_model("levit_384", pretrained=pretrained) + + hooks = [3, 11, 21] if hooks == None else hooks + return _make_levit_backbone( + model, + hooks=hooks + ) + + +def _make_fusion_block(features, use_bn, size=None): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + size=size, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + **kwargs + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + # For the Swin, Swin 2, LeViT and Next-ViT Transformers, the hierarchical architectures prevent setting the + # hooks freely. Instead, the hooks have to be chosen according to the ranges specified in the comments. + hooks = { + "beitl16_512": [5, 11, 17, 23], + "beitl16_384": [5, 11, 17, 23], + "beitb16_384": [2, 5, 8, 11], + "swin2l24_384": [1, 1, 17, 1], # Allowed ranges: [0, 1], [0, 1], [ 0, 17], [ 0, 1] + "swin2b24_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1] + "swin2t16_256": [1, 1, 5, 1], # [0, 1], [0, 1], [ 0, 5], [ 0, 1] + "swinl12_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1] + "next_vit_large_6m": [2, 6, 36, 39], # [0, 2], [3, 6], [ 7, 36], [37, 39] + "levit_384": [3, 11, 21], # [0, 3], [6, 11], [14, 21] + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + }[backbone] + + if "next_vit" in backbone: + in_features = { + "next_vit_large_6m": [96, 256, 512, 1024], + }[backbone] + else: + in_features = None + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks, + use_readout=readout, + in_features=in_features, + ) + + self.number_layers = len(hooks) if hooks is not None else 4 + self.scratch.stem_transpose = None + + self.forward_transformer = forward_levit + size_refinenet3 = 7 + self.scratch.stem_transpose = stem_b4_transpose(256, 128, get_act_layer("hard_swish")) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn, size_refinenet3) + if self.number_layers >= 4: + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + def forward_features(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layers = self.forward_transformer(self.pretrained, x) + if self.number_layers == 3: + layer_1, layer_2, layer_3 = layers + else: + layer_1, layer_2, layer_3, layer_4 = layers + + all_feats = [] + target_size = layer_1.shape[2:] + + def prep(l): + if target_size != l.shape[2:]: + l = interpolate(l, size=target_size, mode="bilinear") + return l + + all_feats.append(prep(self.scratch.layer1_rn(layer_1))) + all_feats.append(prep(self.scratch.layer2_rn(layer_2))) + all_feats.append(prep(self.scratch.layer3_rn(layer_3))) + if self.number_layers >= 4: + all_feats.append(prep(self.scratch.layer4_rn(layer_4))) + return torch.cat([f for f in all_feats], 
dim=1)
+
+    def forward(self, x):
+        if self.channels_last == True:
+            x.contiguous(memory_format=torch.channels_last)
+
+        layers = self.forward_transformer(self.pretrained, x)
+        if self.number_layers == 3:
+            layer_1, layer_2, layer_3 = layers
+        else:
+            layer_1, layer_2, layer_3, layer_4 = layers
+
+        layer_1_rn = self.scratch.layer1_rn(layer_1)
+        layer_2_rn = self.scratch.layer2_rn(layer_2)
+        layer_3_rn = self.scratch.layer3_rn(layer_3)
+        if self.number_layers >= 4:
+            layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+        if self.number_layers == 3:
+            path_3 = self.scratch.refinenet3(layer_3_rn, size=layer_2_rn.shape[2:])
+        else:
+            path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
+            path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
+        path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
+        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+        if self.scratch.stem_transpose is not None:
+            path_1 = self.scratch.stem_transpose(path_1)
+
+        out = self.scratch.output_conv(path_1)
+
+        return out
+
+
+class DPTDepthModel(DPT):
+    def __init__(self, path=None, non_negative=True, **kwargs):
+        features = kwargs["features"] if "features" in kwargs else 256
+        head_features_1 = kwargs["head_features_1"] if "head_features_1" in kwargs else features
+        head_features_2 = kwargs["head_features_2"] if "head_features_2" in kwargs else 32
+        kwargs.pop("head_features_1", None)
+        kwargs.pop("head_features_2", None)
+
+        head = nn.Sequential(
+            nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1),
+            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+            nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(True),
+            nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
+            nn.ReLU(True) if non_negative else nn.Identity(),
+            nn.Identity(),
+        )
+
+        super().__init__(head, **kwargs)
+
+        if path is not None:
+            self.load(path)
+
+    def forward(self, x):
+        return super().forward(x).squeeze(dim=1)
+
+    def forward_features(self, x):
+        return super().forward_features(x).squeeze(dim=1)
+
+
+class MIDASFeaturizer(nn.Module):
+
+    def __init__(self, output_root):
+        super().__init__()
+        self.model = DPTDepthModel(
+            path=os.path.join(output_root, 'models/dpt_levit_224.pt'),
+            backbone="levit_384",
+            non_negative=True,
+            head_features_1=64,
+            head_features_2=8,
+        )
+
+    def get_cls_token(self, img):
+        return None
+
+    def forward(self, img):
+        feats = self.model.forward_features(img)
+        return feats
+
+
+if __name__ == "__main__":
+    model = DPTDepthModel(
+        path='../../models/dpt_levit_224.pt',
+        backbone="levit_384",
+        non_negative=True,
+        head_features_1=64,
+        head_features_2=8,
+    ).cuda()
+
+    image = Image.open("../../sample-images/car.jpg").convert("RGB")
+
+    input_size = 224
+
+    transform = T.Compose([
+        T.Resize(input_size),
+        T.CenterCrop(input_size),
+        T.ToTensor(),
+        T.Normalize([0.5] * 3, [0.5] * 3)
+    ])
+
+    t_img = transform(image).unsqueeze(0).cuda()
+
+    with torch.no_grad():
+        prediction = model.forward(t_img)
+
+    import matplotlib.pyplot as plt
+
+    plt.imshow(prediction.squeeze().cpu())
+    plt.show()
diff --git a/featup/featurizers/MaskCLIP.py b/featup/featurizers/MaskCLIP.py
new file mode 100644
index 0000000000000000000000000000000000000000..d72dba578dbb3121c695ec98d30fd3d879670fd3
--- /dev/null
+++ b/featup/featurizers/MaskCLIP.py
@@ -0,0 +1,47 @@
+import torch
+from torch import nn
+import os
+
+from featup.featurizers.maskclip import clip
+
+
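+# MaskCLIPFeaturizer wraps the modified CLIP visual encoder and exposes its dense patch-level
+# features: `get_patch_encodings` returns (B, H/P * W/P, C) tokens, which forward() reshapes into
+# a (B, C, H/P, W/P) feature map, where P is the ViT patch size (16 for "ViT-B/16").
+# The demo in __main__ below crops inputs to a multiple of P via `crop_to_divisor`.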
+class MaskCLIPFeaturizer(nn.Module): + + def __init__(self): + super().__init__() + self.model, self.preprocess = clip.load( + "ViT-B/16", + download_root=os.getenv('TORCH_HOME', os.path.join(os.path.expanduser('~'), '.cache', 'torch')) + ) + self.model.eval() + self.patch_size = self.model.visual.patch_size + + def forward(self, img): + b, _, input_size_h, input_size_w = img.shape + patch_h = input_size_h // self.patch_size + patch_w = input_size_w // self.patch_size + features = self.model.get_patch_encodings(img).to(torch.float32) + return features.reshape(b, patch_h, patch_w, -1).permute(0, 3, 1, 2) + + +if __name__ == "__main__": + import torchvision.transforms as T + from PIL import Image + from featup.util import norm, unnorm, crop_to_divisor + + device = "cuda" if torch.cuda.is_available() else "cpu" + + image = Image.open("../samples/lex1.jpg") + load_size = 224 # * 3 + transform = T.Compose([ + T.Resize(load_size, Image.BILINEAR), + # T.CenterCrop(load_size), + T.ToTensor(), + lambda x: crop_to_divisor(x, 16), + norm]) + + model = MaskCLIPFeaturizer().cuda() + + results = model(transform(image).cuda().unsqueeze(0)) + + print(clip.available_models()) diff --git a/featup/featurizers/ResNet.py b/featup/featurizers/ResNet.py new file mode 100644 index 0000000000000000000000000000000000000000..5b84f8a694d18d3191a559c6f776d6c6bd076f77 --- /dev/null +++ b/featup/featurizers/ResNet.py @@ -0,0 +1,16 @@ +from torch import nn + + +class ResNetFeaturizer(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def get_cls_token(self, img): + return self.model.forward(img) + + def get_layer(self, img, layer_num): + return self.model.get_layer(img, layer_num) + + def forward(self, img, layer_num=-1): + return self.model.get_layer(img, layer_num) diff --git a/featup/featurizers/__init__.py b/featup/featurizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/featup/featurizers/dinov2/__init__.py b/featup/featurizers/dinov2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/featup/featurizers/dinov2/layers/__init__.py b/featup/featurizers/dinov2/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..05a0b61868e43abb821ca05a813bab2b8b43629e --- /dev/null +++ b/featup/featurizers/dinov2/layers/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .dino_head import DINOHead +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/featup/featurizers/dinov2/layers/attention.py b/featup/featurizers/dinov2/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..0fb76ef2816164729a58cceb18d0f000cfb18777 --- /dev/null +++ b/featup/featurizers/dinov2/layers/attention.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import os +import warnings + +from torch import Tensor +from torch import nn + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import memory_efficient_attention, unbind + + XFORMERS_AVAILABLE = True + warnings.warn("xFormers is available (Attention)") + else: + warnings.warn("xFormers is disabled (Attention)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + warnings.warn("xFormers is not available (Attention)") + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/featup/featurizers/dinov2/layers/block.py b/featup/featurizers/dinov2/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..930787b262faac4f2264797496faff75ac56b7cc --- /dev/null +++ b/featup/featurizers/dinov2/layers/block.py @@ -0,0 +1,260 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +import os +from typing import Callable, List, Any, Tuple, Dict +import warnings + +import torch +from torch import nn, Tensor + +from .attention import Attention, MemEffAttention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import fmha, scaled_index_add, index_select_cat + + XFORMERS_AVAILABLE = True + warnings.warn("xFormers is available (Block)") + else: + warnings.warn("xFormers is disabled (Block)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + + warnings.warn("xFormers is not available (Block)") + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to 
get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) 
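+
+            # drop_add_residual_stochastic_depth_list (defined above) keeps a random subset of the
+            # samples in each tensor, concatenates the kept sequences behind a block-diagonal
+            # attention bias, evaluates the residual branch only on that subset, and index-adds it
+            # back rescaled by b / subset_size; passing the LayerScale gamma as scaling_vector lets
+            # scaled_index_add fuse that multiplication into the add.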
+ + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + if not XFORMERS_AVAILABLE: + raise AssertionError("xFormers is required for using nested tensors") + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/featup/featurizers/dinov2/layers/dino_head.py b/featup/featurizers/dinov2/layers/dino_head.py new file mode 100644 index 0000000000000000000000000000000000000000..0ace8ffd6297a1dd480b19db407b662a6ea0f565 --- /dev/null +++ b/featup/featurizers/dinov2/layers/dino_head.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from torch.nn.utils import weight_norm + + +class DINOHead(nn.Module): + def __init__( + self, + in_dim, + out_dim, + use_bn=False, + nlayers=3, + hidden_dim=2048, + bottleneck_dim=256, + mlp_bias=True, + ): + super().__init__() + nlayers = max(nlayers, 1) + self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) + self.apply(self._init_weights) + self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + eps = 1e-6 if x.dtype == torch.float16 else 1e-12 + x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) + x = self.last_layer(x) + return x + + +def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): + if nlayers == 1: + return nn.Linear(in_dim, bottleneck_dim, bias=bias) + else: + layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) + return nn.Sequential(*layers) diff --git a/featup/featurizers/dinov2/layers/drop_path.py b/featup/featurizers/dinov2/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5 --- /dev/null +++ b/featup/featurizers/dinov2/layers/drop_path.py @@ -0,0 
+1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/featup/featurizers/dinov2/layers/layer_scale.py b/featup/featurizers/dinov2/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..51df0d7ce61f2b41fa9e6369f52391dd7fe7d386 --- /dev/null +++ b/featup/featurizers/dinov2/layers/layer_scale.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/featup/featurizers/dinov2/layers/mlp.py b/featup/featurizers/dinov2/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..bbf9432aae9258612caeae910a7bde17999e328e --- /dev/null +++ b/featup/featurizers/dinov2/layers/mlp.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/featup/featurizers/dinov2/layers/patch_embed.py b/featup/featurizers/dinov2/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..8b7c0804784a42cf80c0297d110dcc68cc85b339 --- /dev/null +++ b/featup/featurizers/dinov2/layers/patch_embed.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. 
+ """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/featup/featurizers/dinov2/layers/swiglu_ffn.py b/featup/featurizers/dinov2/layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..5e9dafa4592a408f6874d54853e8f60db5c41f74 --- /dev/null +++ b/featup/featurizers/dinov2/layers/swiglu_ffn.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +import os +from typing import Callable, Optional +import warnings + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True + warnings.warn("xFormers is available (SwiGLU)") + else: + warnings.warn("xFormers is disabled (SwiGLU)") + raise ImportError +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + warnings.warn("xFormers is not available (SwiGLU)") + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/featup/featurizers/maskclip/README.md b/featup/featurizers/maskclip/README.md new file mode 100644 index 0000000000000000000000000000000000000000..95bf7ab481a75b7e5f2ffd52dc52f49b39d9bda6 --- /dev/null +++ b/featup/featurizers/maskclip/README.md @@ -0,0 +1,3 @@ +# CLIP +Modified version of [CLIP](https://github.com/openai/CLIP) with support for dense patch-level feature extraction +(based on [MaskCLIP](https://arxiv.org/abs/2112.01071) parametrization) and interpolation of the positional encoding. 
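A minimal usage sketch of the dense-feature path described in the README above; the image path is a placeholder, and the channel count of the returned feature map depends on the CLIP variant:

    import torch
    import torchvision.transforms as T
    from PIL import Image

    from featup.featurizers.MaskCLIP import MaskCLIPFeaturizer
    from featup.util import norm, crop_to_divisor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = MaskCLIPFeaturizer().to(device).eval()

    transform = T.Compose([
        T.Resize(224, Image.BILINEAR),
        T.ToTensor(),
        lambda x: crop_to_divisor(x, model.patch_size),  # H and W must be multiples of the patch size
        norm,
    ])

    img = transform(Image.open("sample.jpg").convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        feats = model(img)  # (B, C, H // 16, W // 16) dense patch features
    print(feats.shape)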
diff --git a/featup/featurizers/maskclip/__init__.py b/featup/featurizers/maskclip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bf5a8f20a224a45ae315f2b3e33b2df5c7090223 --- /dev/null +++ b/featup/featurizers/maskclip/__init__.py @@ -0,0 +1,5 @@ +from .clip import * + +""" +Modified from https://github.com/openai/CLIP +""" diff --git a/featup/featurizers/maskclip/bpe_simple_vocab_16e6.txt.gz b/featup/featurizers/maskclip/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/featup/featurizers/maskclip/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/featup/featurizers/maskclip/clip.py b/featup/featurizers/maskclip/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..1e291a45262bb9eddb4bba9a793160ec8c6f6c27 --- /dev/null +++ b/featup/featurizers/maskclip/clip.py @@ -0,0 +1,247 @@ +import hashlib +import os +import urllib +import warnings +from typing import Any, Union, List +from pkg_resources import packaging + +import torch +from PIL import Image +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from tqdm import tqdm + +from .model import build_model +from .simple_tokenizer import SimpleTokenizer as _Tokenizer + +try: + from torchvision.transforms import InterpolationMode + + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + +if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): + warnings.warn("PyTorch version 1.7.1 or higher is recommended") + +__all__ = ["available_models", "load", "tokenize"] +_tokenizer = _Tokenizer() + +_MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", + "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", + "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", + "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", + "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", +} + + +def _download(url: str, root: str): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if 
hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
+            return download_target
+        else:
+            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+
+    print(f"Downloading CLIP model from {url}")
+    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True,
+                  unit_divisor=1024) as loop:
+            while True:
+                buffer = source.read(8192)
+                if not buffer:
+                    break
+
+                output.write(buffer)
+                loop.update(len(buffer))
+
+    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
+        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
+
+    return download_target
+
+
+def _convert_image_to_rgb(image):
+    return image.convert("RGB")
+
+
+def _transform(n_px):
+    return Compose([
+        Resize(n_px, interpolation=BICUBIC),
+        CenterCrop(n_px),
+        _convert_image_to_rgb,
+        ToTensor(),
+        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+    ])
+
+
+def available_models() -> List[str]:
+    """Returns the names of available CLIP models"""
+    return list(_MODELS.keys())
+
+
+TORCH_HUB_ROOT = os.path.expandvars(os.getenv("TORCH_HUB_ROOT", "$HOME/.torch_hub"))
+
+
+def load(
+        name: str,
+        device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
+        jit: bool = False,
+        download_root: str = None
+):
+    """Load a CLIP model
+
+    Parameters
+    ----------
+    name : str
+        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
+
+    device : Union[str, torch.device]
+        The device to put the loaded model
+
+    jit : bool
+        Whether to load the optimized JIT model or more hackable non-JIT model (default).
+
+    download_root: str
+        path to download the model files; by default, it uses the TORCH_HUB_ROOT environment variable or "~/.torch_hub"
+
+    Returns
+    -------
+    model : torch.nn.Module
+        The CLIP model
+
+    preprocess : Callable[[PIL.Image], torch.Tensor]
+        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
+    """
+    if name in _MODELS:
+        model_path = _download(_MODELS[name], download_root or TORCH_HUB_ROOT)
+    elif os.path.isfile(name):
+        model_path = name
+    else:
+        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
+
+    with open(model_path, 'rb') as opened_file:
+        try:
+            # loading JIT archive
+            model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
+            state_dict = None
+        except RuntimeError:
+            # loading saved state dict
+            if jit:
+                warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") + jit = False + state_dict = torch.load(opened_file, map_location="cpu") + + if not jit: + model = build_model(state_dict or model.state_dict()).to(device) + if str(device) == "cpu": + model.float() + return model, _transform(model.visual.input_resolution) + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model, _transform(model.input_resolution.item()) + + +def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[ + torch.IntTensor, torch.LongTensor]: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. + We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 
+ """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder["<|startoftext|>"] + eot_token = _tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + else: + result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + result[i, :len(tokens)] = torch.tensor(tokens) + + return result diff --git a/featup/featurizers/maskclip/interpolate.py b/featup/featurizers/maskclip/interpolate.py new file mode 100644 index 0000000000000000000000000000000000000000..aaaf7d8f563b085b60115c978b5925b651160a6f --- /dev/null +++ b/featup/featurizers/maskclip/interpolate.py @@ -0,0 +1,54 @@ +import numpy as np +import torch + + +def interpolate_positional_embedding( + positional_embedding: torch.Tensor, x: torch.Tensor, patch_size: int, w: int, h: int +): + """ + Interpolate the positional encoding for CLIP to the number of patches in the image given width and height. + Modified from DINO ViT `interpolate_pos_encoding` method. + https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/vision_transformer.py#L174 + """ + assert positional_embedding.ndim == 2, "pos_encoding must be 2D" + + # Number of patches in input + num_patches = x.shape[1] - 1 + # Original number of patches for square images + num_og_patches = positional_embedding.shape[0] - 1 + + if num_patches == num_og_patches and w == h: + # No interpolation needed + return positional_embedding.to(x.dtype) + + dim = x.shape[-1] + class_pos_embed = positional_embedding[:1] # (1, dim) + patch_pos_embed = positional_embedding[1:] # (num_og_patches, dim) + + # Compute number of tokens + w0 = w // patch_size + h0 = h // patch_size + assert w0 * h0 == num_patches, "Number of patches does not match" + + # Add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + + # Interpolate + patch_per_ax = int(np.sqrt(num_og_patches)) + patch_pos_embed_interp = torch.nn.functional.interpolate( + patch_pos_embed.reshape(1, patch_per_ax, patch_per_ax, dim).permute(0, 3, 1, 2), + # (1, dim, patch_per_ax, patch_per_ax) + scale_factor=(w0 / patch_per_ax, h0 / patch_per_ax), + mode="bicubic", + align_corners=False, + recompute_scale_factor=False, + ) # (1, dim, w0, h0) + assert ( + int(w0) == patch_pos_embed_interp.shape[-2] and int(h0) == patch_pos_embed_interp.shape[-1] + ), "Interpolation error." 
+ + patch_pos_embed_interp = patch_pos_embed_interp.permute(0, 2, 3, 1).reshape(-1, dim) # (w0 * h0, dim) + # Concat class token embedding and interpolated patch embeddings + pos_embed_interp = torch.cat([class_pos_embed, patch_pos_embed_interp], dim=0) # (w0 * h0 + 1, dim) + return pos_embed_interp.to(x.dtype) diff --git a/featup/featurizers/maskclip/model.py b/featup/featurizers/maskclip/model.py new file mode 100644 index 0000000000000000000000000000000000000000..a7aff7251977399677438b9a55c3e21831f6b9e2 --- /dev/null +++ b/featup/featurizers/maskclip/model.py @@ -0,0 +1,506 @@ +from collections import OrderedDict +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from .interpolate import interpolate_positional_embedding + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + self.spacial_dim = spacial_dim + + def forward(self, x): + x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, 
+ use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + return x.squeeze(0) + + def forward_v(self, x: torch.Tensor): + """ + Forward function for computing the value features for dense prediction (i.e., features for every image patch). + """ + _, _, w, h = x.shape + x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + + # Interpolate positional embedding to match the size of the input + interpolated_pe = interpolate_positional_embedding(self.positional_embedding, x.permute(1, 0, 2), patch_size=1, w=w, h=h) + x = x + interpolated_pe[:, None, :] # (HW+1)NC + + v_in = F.linear(x, self.v_proj.weight, self.v_proj.bias) + v_out = F.linear(v_in, self.c_proj.weight, self.c_proj.bias) + v_out = v_out.permute(1, 0, 2) # (HW+1)NC -> N(HW+1)C + return v_out + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.relu3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x, patch_output: bool = False): + def stem(x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + if patch_output: + x = self.attnpool.forward_v(x) + x = x[:, 1:, :] # remove the cls token + else: + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def 
forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward_v(self, x: torch.Tensor): + """ + Forward function for computing the value features for dense prediction (i.e., features for every image patch). + """ + # Get the weights and biases for the value projection, multihead attention uses 3 * embed_dim for the input projection + v_in_proj_weight = self.attn.in_proj_weight[-self.attn.embed_dim:] + v_in_proj_bias = self.attn.in_proj_bias[-self.attn.embed_dim:] + + v_in = F.linear(self.ln_1(x), v_in_proj_weight, v_in_proj_bias) + v_out = F.linear(v_in, self.attn.out_proj.weight, self.attn.out_proj.bias) + + # Using the value features works the best. Adding this to 'x' or feeding 'v' to the LayerNorm then MLP degrades the performance + return v_out + + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + self.patch_size = patch_size + + def forward(self, x: torch.Tensor, patch_output: bool = False): + _, _, w, h = x.shape + + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + interpolate_positional_embedding(self.positional_embedding, x, patch_size=self.patch_size, w=w, h=h) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + + if patch_output: + *layers, last_resblock = self.transformer.resblocks + penultimate = nn.Sequential(*layers) + + x = penultimate(x) + x = 
last_resblock.forward_v(x) + x = x.permute(1, 0, 2) # LND -> NLD + + # Extract the patch tokens, not the class token + x = x[:, 1:, :] + x = self.ln_post(x) + if self.proj is not None: + # This is equivalent to conv1d + x = x @ self.proj + return x + + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +class CLIP(nn.Module): + def __init__(self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int + ): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width + ) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask() + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + 
return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def get_patch_encodings(self, image) -> torch.Tensor: + """ Get the encodings for each patch in the image """ + return self.visual(image.type(self.dtype), patch_output=True) + + def get_image_encoder_projection(self) -> nn.Parameter: + """ Get vision transformer projection matrix.""" + assert isinstance(self.visual, VisionTransformer) + return self.visual.proj + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(dim=1, keepdim=True) + text_features = text_features / text_features.norm(dim=1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def build_model(state_dict: dict): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + else: + counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = 
state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) + + model = CLIP( + embed_dim, + image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + convert_weights(model) + model.load_state_dict(state_dict) + return model.eval() diff --git a/featup/featurizers/maskclip/simple_tokenizer.py b/featup/featurizers/maskclip/simple_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..542b79d4fd44d54ba0933ab349e70ed4f782cea4 --- /dev/null +++ b/featup/featurizers/maskclip/simple_tokenizer.py @@ -0,0 +1,138 @@ +import gzip +import html +import os +from collections.abc import Sequence +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + # note: pretty hacky but it is okay! 
+ # lists of strings (e.g. from the cli_multi_label.py script) are joined into a single string + if not isinstance(text, str): + text = ', '.join(text) + + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'</w>' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '</w>',) + pairs = get_pairs(word) + + if not pairs: + return token+'</w>' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') + return text diff --git a/featup/featurizers/modules/__init__.py b/featup/featurizers/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/featup/featurizers/modules/layers.py b/featup/featurizers/modules/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..b94813a0dcdc971ee34c50d7f9f322803c3d41e6 --- /dev/null +++ b/featup/featurizers/modules/layers.py @@ -0,0 +1,309 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +import math + +__all__ = ['forward_hook', 'AdaptiveAvgPool2d', 'Add', 'AvgPool2d', 'BatchNorm2d', 'Clone', 'Conv2d', 'ConvTranspose2d', + 'Dropout', 'Identity', 'LeakyReLU', 'Linear', 'MaxPool2d', 'Multiply', 'ReLU', 'Sequential', 'safe_divide', + 'ZeroPad2d', 'LayerNorm', 'GELU', 'einsum', 'Softmax'] + + +def safe_divide(a, b): + return a
/ (b + b.eq(0).type(b.type()) * 1e-9) * b.ne(0).type(b.type()) + + +def forward_hook(self, input, output): + if type(input[0]) in (list, tuple): + self.X = [] + for i in input[0]: + x = i.detach() + x.requires_grad = True + self.X.append(x) + else: + self.X = input[0].detach() + self.X.requires_grad = True + + self.Y = output + + +class RelProp(nn.Module): + def __init__(self): + super(RelProp, self).__init__() + # if not self.training: + self.register_forward_hook(forward_hook) + + def gradprop(self, Z, X, S): + C = torch.autograd.grad(Z, X, S, retain_graph=True) + return C + + def relprop(self, R, alpha=1): + return R + + +class RelPropSimple(RelProp): + def relprop(self, R, alpha=1): + Z = self.forward(self.X) + S = safe_divide(R, Z) + C = self.gradprop(Z, self.X, S) + + if torch.is_tensor(self.X) == False: + outputs = [] + outputs.append(self.X[0] * C[0]) + outputs.append(self.X[1] * C[1]) + else: + outputs = self.X * C[0] + return outputs + + +class Identity(nn.Identity, RelProp): + pass + + +class ReLU(nn.ReLU, RelProp): + pass + + +class GELU(nn.GELU, RelProp): + pass + +class LeakyReLU(nn.LeakyReLU, RelProp): + pass + +class Softmax(nn.Softmax, RelProp): + pass + +class einsum(RelPropSimple): + def __init__(self, equation): + super().__init__() + self.equation = equation + def forward(self, *operands): + return torch.einsum(self.equation, *operands) + +class Dropout(nn.Dropout, RelProp): + pass + + +class MaxPool2d(nn.MaxPool2d, RelPropSimple): + pass + +class LayerNorm(nn.LayerNorm, RelProp): + pass + +class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelProp): + def relprop(self, R, alpha=1): + px = torch.clamp(self.X, min=0) + + def f(x1): + Z1 = F.adaptive_avg_pool2d(x1, self.output_size) + S1 = safe_divide(R, Z1) + C1 = x1 * self.gradprop(Z1, x1, S1)[0] + return C1 + + activator_relevances = f(px) + out = activator_relevances + return out + + +class ZeroPad2d(nn.ZeroPad2d, RelPropSimple): + def relprop(self, R, alpha=1): + Z = self.forward(self.X) + S = safe_divide(R, Z) + C = self.gradprop(Z, self.X, S) + outputs = self.X * C[0] + return outputs + + +class AvgPool2d(nn.AvgPool2d, RelPropSimple): + pass + + +class Add(RelPropSimple): + def forward(self, inputs): + return torch.add(*inputs) + + def relprop(self, R, alpha): + Z = self.forward(self.X) + S = safe_divide(R, Z) + C = self.gradprop(Z, self.X, S) + + a = self.X[0] * C[0] + b = self.X[1] * C[1] + + a_sum = a.sum() + b_sum = b.sum() + + a_fact = safe_divide(a_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum() + b_fact = safe_divide(b_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum() + + a = a * safe_divide(a_fact, a.sum()) + b = b * safe_divide(b_fact, b.sum()) + + outputs = [a, b] + + return outputs + + +class Clone(RelProp): + def forward(self, input, num): + self.__setattr__('num', num) + outputs = [] + for _ in range(num): + outputs.append(input) + + return outputs + + def relprop(self, R, alpha = 1): + Z = [] + for _ in range(self.num): + Z.append(self.X) + S = [safe_divide(r, z) for r, z in zip(R, Z)] + C = self.gradprop(Z, self.X, S)[0] + + R = self.X * C + + return R + + +class Multiply(RelPropSimple): + def forward(self, inputs): + return torch.mul(*inputs) + + def relprop(self, R, alpha=1): + x0 = torch.clamp(self.X[0], min=0) + x1 = torch.clamp(self.X[1], min=0) + x = [x0, x1] + Z = self.forward(x) + S = safe_divide(R, Z) + C = self.gradprop(Z, x, S) + outputs = [] + outputs.append(x[0] * C[0]) + outputs.append(x[1] * C[1]) + return outputs + +class Sequential(nn.Sequential): + def relprop(self, R, alpha=1): + for m 
in reversed(self._modules.values()): + R = m.relprop(R, alpha) + return R + + + +class BatchNorm2d(nn.BatchNorm2d, RelProp): + def relprop(self, R, alpha=1): + X = self.X + beta = 1 - alpha + weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / ( + (self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.eps).pow(0.5)) + Z = X * weight + 1e-9 + S = R / Z + Ca = S * weight + R = self.X * (Ca) + return R + + +class Linear(nn.Linear, RelProp): + def relprop(self, R, alpha=1): + beta = alpha - 1 + pw = torch.clamp(self.weight, min=0) + nw = torch.clamp(self.weight, max=0) + px = torch.clamp(self.X, min=0) + nx = torch.clamp(self.X, max=0) + + # def f(w1, w2, x1, x2): + # Z1 = F.linear(x1, w1) + # Z2 = F.linear(x2, w2) + # S1 = safe_divide(R, Z1) + # S2 = safe_divide(R, Z2) + # C1 = x1 * self.gradprop(Z1, x1, S1)[0] + # C2 = x2 * self.gradprop(Z2, x2, S2)[0] + # return C1 #+ C2 + + def f(w1, w2, x1, x2): + Z1 = F.linear(x1, w1) + Z2 = F.linear(x2, w2) + Z = Z1 + Z2 + S = safe_divide(R, Z) + C1 = x1 * self.gradprop(Z1, x1, S)[0] + C2 = x2 * self.gradprop(Z2, x2, S)[0] + return C1 + C2 + + activator_relevances = f(pw, nw, px, nx) + inhibitor_relevances = f(nw, pw, px, nx) + + out = alpha * activator_relevances - beta * inhibitor_relevances + + return out + + + +class Conv2d(nn.Conv2d, RelProp): + + def relprop(self, R, alpha=1): + if self.X.shape[1] == 3: + pw = torch.clamp(self.weight, min=0) + nw = torch.clamp(self.weight, max=0) + X = self.X + L = self.X * 0 + \ + torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3, + keepdim=True)[0] + H = self.X * 0 + \ + torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3, + keepdim=True)[0] + Za = torch.conv2d(X, self.weight, bias=None, stride=self.stride, padding=self.padding) - \ + torch.conv2d(L, pw, bias=None, stride=self.stride, padding=self.padding) - \ + torch.conv2d(H, nw, bias=None, stride=self.stride, padding=self.padding) + 1e-9 + + S = R / Za + C = X * self.gradprop2(S, self.weight) - L * self.gradprop2(S, pw) - H * self.gradprop2(S, nw) + R = C + else: + beta = alpha - 1 + pw = torch.clamp(self.weight, min=0) + nw = torch.clamp(self.weight, max=0) + px = torch.clamp(self.X, min=0) + nx = torch.clamp(self.X, max=0) + + def f(w1, w2, x1, x2): + Z1 = F.conv2d(x1, w1, bias=self.bias, stride=self.stride, padding=self.padding, groups=self.groups) + Z2 = F.conv2d(x2, w2, bias=self.bias, stride=self.stride, padding=self.padding, groups=self.groups) + Z = Z1 + Z2 + S = safe_divide(R, Z) + C1 = x1 * self.gradprop(Z1, x1, S)[0] + C2 = x2 * self.gradprop(Z2, x2, S)[0] + return C1 + C2 + + activator_relevances = f(pw, nw, px, nx) + inhibitor_relevances = f(nw, pw, px, nx) + + R = alpha * activator_relevances - beta * inhibitor_relevances + return R + + + +class ConvTranspose2d(nn.ConvTranspose2d, RelProp): + def relprop(self, R, alpha=1): + pw = torch.clamp(self.weight, min=0) + px = torch.clamp(self.X, min=0) + + def f(w1, x1): + Z1 = F.conv_transpose2d(x1, w1, bias=None, stride=self.stride, padding=self.padding, + output_padding=self.output_padding) + S1 = safe_divide(R, Z1) + C1 = x1 * self.gradprop(Z1, x1, S1)[0] + return C1 + + activator_relevances = f(pw, px) + R = activator_relevances + return R + + + +if __name__ == '__main__': + convt = ConvTranspose2d(100, 50, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False).cuda() + + rand = torch.rand((1, 100, 224, 224)).cuda() + out = convt(rand) + rel = convt.relprop(out) + + 
print(out.shape) diff --git a/featup/featurizers/modules/resnet.py b/featup/featurizers/modules/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..57af947a23fe6b37dcb73e6b90d4511a46216edd --- /dev/null +++ b/featup/featurizers/modules/resnet.py @@ -0,0 +1,339 @@ +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.model_zoo as model_zoo + +from featup.featurizers.modules.layers import * +import torch + +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152'] + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.clone = Clone() + + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = BatchNorm2d(planes) + self.conv2 = conv3x3(planes, planes) + self.bn2 = BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + self.relu1 = ReLU(inplace=True) + self.relu2 = ReLU(inplace=True) + + self.add = Add() + + self.register_forward_hook(forward_hook) + + def forward(self, x): + x1, x2 = self.clone(x, 2) + + out = self.conv1(x1) + out = self.bn1(out) + out = self.relu1(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + x2 = self.downsample(x2) + + out = self.add([out, x2]) + out = self.relu2(out) + + return out + + def relprop(self, R, alpha): + out = self.relu2.relprop(R, alpha) + out, x2 = self.add.relprop(out, alpha) + + if self.downsample is not None: + x2 = self.downsample.relprop(x2, alpha) + + out = self.bn2.relprop(out, alpha) + out = self.conv2.relprop(out, alpha) + + out = self.relu1.relprop(out, alpha) + out = self.bn1.relprop(out, alpha) + x1 = self.conv1.relprop(out, alpha) + + return self.clone.relprop([x1, x2], alpha) + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + + self.conv1 = conv1x1(inplanes, planes) + self.bn1 = BatchNorm2d(planes) + self.conv2 = conv3x3(planes, planes, stride) + self.bn2 = BatchNorm2d(planes) + self.conv3 = conv1x1(planes, planes * self.expansion) + self.bn3 = BatchNorm2d(planes * self.expansion) + self.downsample = downsample + self.stride = stride + + self.relu1 = ReLU(inplace=True) + self.relu2 = ReLU(inplace=True) + self.relu3 = ReLU(inplace=True) + + self.add = Add() + + self.register_forward_hook(forward_hook) + + def forward(self, x): + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu1(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu2(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + x = self.downsample(x) + + out = self.add([out, x]) + out = self.relu3(out) + + return out + + def relprop(self, R, 
alpha): + out = self.relu3.relprop(R, alpha) + + out, x = self.add.relprop(out, alpha) + + if self.downsample is not None: + x = self.downsample.relprop(x, alpha) + + out = self.bn3.relprop(out, alpha) + out = self.conv3.relprop(out, alpha) + + out = self.relu2.relprop(out, alpha) + out = self.bn2.relprop(out, alpha) + out = self.conv2.relprop(out, alpha) + + out = self.relu1.relprop(out, alpha) + out = self.bn1.relprop(out, alpha) + x1 = self.conv1.relprop(out, alpha) + + return x1 + x + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000, long=False, zero_init_residual=False): + super(ResNet, self).__init__() + self.inplanes = 64 + self.conv1 = Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = BatchNorm2d(64) + self.relu = ReLU(inplace=True) + self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = AdaptiveAvgPool2d((1, 1)) + self.fc = Linear(512 * block.expansion, num_classes) + self.long = long + self.num_classes = num_classes + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return Sequential(*layers) + + def CLRP(self, x): + maxindex = torch.argmax(x, dim=1) + R = torch.ones(x.shape, device=x.device) + R /= -self.num_classes + for i in range(R.size(0)): + R[i, maxindex[i]] = 1 + return R + + def forward(self, img): + x = self.conv1(img) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + layer1 = self.layer1(x) + layer2 = self.layer2(layer1) + layer3 = self.layer3(layer2) + layer4 = self.layer4(layer3) + + x = self.avgpool(layer4) + x = x.view(x.size(0), -1) + return self.fc(x) + + def get_layer(self, img, layer_num): + x = self.conv1(img) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + layer1 = self.layer1(x) + if layer_num == 1: + return layer1 + layer2 = self.layer2(layer1) + if layer_num == 2: + return layer2 + layer3 = self.layer3(layer2) + if layer_num == 3: + return layer3 + layer4 = self.layer4(layer3) + if layer_num == 4 or layer_num == -1: + return layer4 + if isinstance(layer_num, tuple): + return [[layer1, layer2, layer3, layer4][i-1] for i in layer_num] + + raise ValueError(f"Unknown layer num: {layer_num}") + + def relevance_cam(self, 
large_img, layer_num, upsampler): + small_img = F.interpolate(large_img, size=(224, 224), mode='bilinear') + layer1, layer2, layer3, layer4 = self.get_layer(small_img, (1, 2, 3, 4)) + x = self.avgpool(layer4) + x = x.view(x.size(0), -1) + z = self.fc(x) + + R = self.CLRP(z) + R = self.fc.relprop(R, 1) + R = R.reshape_as(self.avgpool.Y) + R4 = self.avgpool.relprop(R, 1) + + if layer_num == 4: + r_weight4 = torch.mean(R4, dim=(2, 3), keepdim=True) + r_cam4 = upsampler(large_img, source=layer4) * r_weight4 + r_cam4 = torch.sum(r_cam4, dim=(1), keepdim=True) + return r_cam4 + elif layer_num == 3: + R3 = self.layer4.relprop(R4, 1) + r_weight3 = torch.mean(R3, dim=(2, 3), keepdim=True) + r_cam3 = upsampler(large_img, source=layer3) * r_weight3 + r_cam3 = torch.sum(r_cam3, dim=(1), keepdim=True) + return r_cam3 + elif layer_num == 2: + R3 = self.layer4.relprop(R4, 1) + R2 = self.layer3.relprop(R3, 1) + r_weight2 = torch.mean(R2, dim=(2, 3), keepdim=True) + r_cam2 = upsampler(large_img, source=layer2) * r_weight2 + r_cam2 = torch.sum(r_cam2, dim=(1), keepdim=True) + return r_cam2 + else: + raise ValueError(f"Unknown layer_num: {layer_num}") + + +def resnet18(pretrained=False, **kwargs): + """Constructs a ResNet-18 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) + return model + + +def resnet34(pretrained=False, **kwargs): + """Constructs a ResNet-34 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) + return model + + +def resnet50(pretrained=False, long=False, **kwargs): + """Constructs a ResNet-50 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], long=long, **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) + return model + + +def resnet101(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) + return model + + +def resnet152(pretrained=False, **kwargs): + """Constructs a ResNet-152 model. 
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) + return model diff --git a/featup/featurizers/modules/vgg.py b/featup/featurizers/modules/vgg.py new file mode 100644 index 0000000000000000000000000000000000000000..98f157b4d1ff992068f8b1cc4d7abfcbc5e45422 --- /dev/null +++ b/featup/featurizers/modules/vgg.py @@ -0,0 +1,366 @@ +import copy +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.model_zoo as model_zoo +import torch +from featup.featurizers.modules.layers import * + +__all__ = [ + 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', + 'vgg19_bn', 'vgg19', +] + + +model_urls = { + 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', + 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', + 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', + 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', + 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', + 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', + 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', + 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', +} + +class VGG_spread(nn.Module): + + def __init__(self, features, num_classes=1000, init_weights=True): + super(VGG_spread, self).__init__() + self.features = features + self.avgpool = AdaptiveAvgPool2d((7, 7)) + self.classifier = Sequential( + Linear(512 * 7 * 7, 4096), + ReLU(True), + Dropout(), + Linear(4096, 4096), + ReLU(True), + Dropout(), + Linear(4096, num_classes), + ) + if init_weights: + self._initialize_weights() + + def forward(self, x): + for layer in self.features: + x = layer(x) + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.classifier(x) + return x + + def relprop(self, R, alpha): + x = self.classifier.relprop(R, alpha) + x = x.reshape_as(next(reversed(self.features._modules.values())).Y) + x = self.avgpool.relprop(x, alpha) + x = self.features.relprop(x, alpha) + return x + + def m_relprop(self, R, pred, alpha): + x = self.classifier.m_relprop(R, pred, alpha) + if torch.is_tensor(x) == False: + for i in range(len(x)): + x[i] = x[i].reshape_as(next(reversed(self.features._modules.values())).Y) + else: + x = x.reshape_as(next(reversed(self.features._modules.values())).Y) + x = self.avgpool.m_relprop(x, pred, alpha) + x = self.features.m_relprop(x, pred, alpha) + return x + + def RAP_relprop(self, R): + x1 = self.classifier.RAP_relprop(R) + if torch.is_tensor(x1) == False: + for i in range(len(x1)): + x1[i] = x1[i].reshape_as(next(reversed(self.features._modules.values())).Y) + else: + x1 = x1.reshape_as(next(reversed(self.features._modules.values())).Y) + x1 = self.avgpool.RAP_relprop(x1) + x1 = self.features.RAP_relprop(x1) + return x1 + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) + + +class VGG(nn.Module): + + def __init__(self, features, num_classes=1000, init_weights=True): + super(VGG, self).__init__() + self.features 
= features + self.avgpool = AdaptiveAvgPool2d((7, 7)) + self.classifier = Sequential( + Linear(512 * 7 * 7, 4096), + ReLU(True), + Dropout(), + Linear(4096, 4096), + ReLU(True), + Dropout(), + Linear(4096, num_classes), + ) + self.num_classes = num_classes + if init_weights: + self._initialize_weights() + + def CLRP(self, x, maxindex = [None]): + if maxindex == [None]: + maxindex = torch.argmax(x, dim=1) + R = torch.ones(x.shape, x.device) + R /= -self.num_classes + for i in range(R.size(0)): + R[i, maxindex[i]] = 1 + return R + + def upsample(self, source, guidance_unscaled, upsampler, scale): + _, _, H, W = source.shape + guidance = F.interpolate(guidance_unscaled, size=(H * scale, W * scale), mode='bilinear') + return upsampler(source, guidance) + + def forward(self, x,mode='output', target_class = [None], upsampler=None, scale=1): + inp = copy.deepcopy(x) + for i, layer in enumerate(self.features): + x = layer(x) + if mode.lstrip('-').isnumeric(): + if int(mode) == i: + target_layer = x + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.classifier(x) + + if mode == 'output': + return x + + R = self.CLRP(x, target_class) + R = self.classifier.relprop(R) + R = R.reshape_as(next(reversed(self.features._modules.values())).Y) + R = self.avgpool.relprop(R) + + for i in range(len(self.features)-1, int(mode), -1): + R = self.features[i].relprop(R) + + if upsampler is not None: + target_layer = self.upsample(target_layer, inp, upsampler, scale) + + r_weight = torch.mean(R, dim=(2, 3), keepdim=True) + r_cam = target_layer * r_weight + r_cam = torch.sum(r_cam, dim=(1), keepdim=True) + return r_cam, x + + + + def relprop(self, R, alpha, flag=-1): + x = self.classifier.relprop(R, alpha) + x = x.reshape_as(next(reversed(self.features._modules.values())).Y) + x = self.avgpool.relprop(x, alpha) + # x = self.features.relprop(x, alpha) + for i in range(43, flag, -1): + x = self.features[i].relprop(x, alpha) + return x + + def m_relprop(self, R, pred, alpha): + x = self.classifier.m_relprop(R, pred, alpha) + if torch.is_tensor(x) == False: + for i in range(len(x)): + x[i] = x[i].reshape_as(next(reversed(self.features._modules.values())).Y) + else: + x = x.reshape_as(next(reversed(self.features._modules.values())).Y) + x = self.avgpool.m_relprop(x, pred, alpha) + x = self.features.m_relprop(x, pred, alpha) + return x + + def RAP_relprop(self, R): + x1 = self.classifier.RAP_relprop(R) + if torch.is_tensor(x1) == False: + for i in range(len(x1)): + x1[i] = x1[i].reshape_as(next(reversed(self.features._modules.values())).Y) + else: + x1 = x1.reshape_as(next(reversed(self.features._modules.values())).Y) + x1 = self.avgpool.RAP_relprop(x1) + x1 = self.features.RAP_relprop(x1) + + return x1 + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) + +def make_layers(cfg, batch_norm=False): + layers = [] + in_channels = 3 + + for v in cfg: + if v == 'M': + layers += [MaxPool2d(kernel_size=2, stride=2)] + else: + conv2d = Conv2d(in_channels, v, kernel_size=3, padding=1) + if batch_norm: + layers += [conv2d, BatchNorm2d(v), ReLU(inplace=True)] + else: + layers += [conv2d, ReLU(inplace=True)] + in_channels = v + + return Sequential(*layers) + 
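# Illustrative sanity check for make_layers above (assumes the patched package is
# importable as featup.featurizers.modules.vgg; cfg is the table defined just below):
#
#   from featup.featurizers.modules.vgg import make_layers, cfg
#   features = make_layers(cfg['D'])  # 13 Conv2d+ReLU pairs and 5 MaxPool2d -> 31 modules
#   assert len(features) == 31        # each module is the relprop-aware wrapper from layers.py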
+def make_layers_list(cfg, batch_norm=False): + layers = [] + in_channels = 3 + for v in cfg: + if v == 'M': + layers += [MaxPool2d(kernel_size=2, stride=2)] + else: + conv2d = Conv2d(in_channels, v, kernel_size=3, padding=1) + if batch_norm: + layers += [conv2d, BatchNorm2d(v), ReLU(inplace=True)] + else: + layers += [conv2d, ReLU(inplace=True)] + in_channels = v + return layers + + +cfg = { + 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], + 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], +} + + +def vgg11(pretrained=False, **kwargs): + """VGG 11-layer model (configuration "A") + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if pretrained: + kwargs['init_weights'] = False + model = VGG(make_layers(cfg['A']), **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['vgg11'])) + return model + + +def vgg11_bn(pretrained=False, **kwargs): + """VGG 11-layer model (configuration "A") with batch normalization + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if pretrained: + kwargs['init_weights'] = False + model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn'])) + return model + + +def vgg13(pretrained=False, **kwargs): + """VGG 13-layer model (configuration "B") + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if pretrained: + kwargs['init_weights'] = False + model = VGG(make_layers(cfg['B']), **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['vgg13'])) + return model + + +def vgg13_bn(pretrained=False, **kwargs): + """VGG 13-layer model (configuration "B") with batch normalization + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if pretrained: + kwargs['init_weights'] = False + model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn'])) + return model + + +def vgg16(pretrained=False, **kwargs): + """VGG 16-layer model (configuration "D") + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if pretrained: + kwargs['init_weights'] = False + model = VGG(make_layers(cfg['D']), **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['vgg16'])) + return model + +def vgg16_spread(pretrained=False, **kwargs): + """VGG 16-layer model (configuration "D") + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if pretrained: + kwargs['init_weights'] = False + model = VGG_spread(make_layers_list(cfg['D']), **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['vgg16'])) + return model + +def vgg16_bn(pretrained=False, **kwargs): + """VGG 16-layer model (configuration "D") with batch normalization + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if pretrained: + kwargs['init_weights'] = False + model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn'])) + return model + + +def vgg19(pretrained=False, **kwargs): + """VGG 
19-layer model (configuration "E") + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if pretrained: + kwargs['init_weights'] = False + model = VGG(make_layers(cfg['E']), **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['vgg19'])) + return model + + +def vgg19_bn(pretrained=False, **kwargs): + """VGG 19-layer model (configuration 'E') with batch normalization + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + if pretrained: + kwargs['init_weights'] = False + model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn'])) + return model diff --git a/featup/featurizers/util.py b/featup/featurizers/util.py new file mode 100644 index 0000000000000000000000000000000000000000..4d9e4184fcf780731b3e0ed058f9a44c72ed7f3e --- /dev/null +++ b/featup/featurizers/util.py @@ -0,0 +1,73 @@ +import torch + +def get_featurizer(name, activation_type="key", **kwargs): + name = name.lower() + if name == "vit": + from .DINO import DINOFeaturizer + patch_size = 16 + model = DINOFeaturizer("vit_small_patch16_224", patch_size, activation_type) + dim = 384 + elif name == "midas": + from .MIDAS import MIDASFeaturizer + patch_size = 16 + model = MIDASFeaturizer(output_root=kwargs["output_root"]) + dim = 768 + elif name == "dino16": + from .DINO import DINOFeaturizer + patch_size = 16 + model = DINOFeaturizer("dino_vits16", patch_size, activation_type) + dim = 384 + elif name == "dino8": + from .DINO import DINOFeaturizer + patch_size = 8 + model = DINOFeaturizer("dino_vits8", patch_size, activation_type) + dim = 384 + elif name == "dinov2": + from .DINOv2 import DINOv2Featurizer + patch_size = 14 + model = DINOv2Featurizer("dinov2_vits14", patch_size, activation_type) + dim = 384 + elif name == "clip": + from .CLIP import CLIPFeaturizer + patch_size = 16 + model = CLIPFeaturizer() + dim = 512 + elif name == "maskclip": + from .MaskCLIP import MaskCLIPFeaturizer + patch_size = 16 + model = MaskCLIPFeaturizer() + dim = 512 + elif name == "mae": + from .MAE import MAEFeaturizer + patch_size = 16 + model = MAEFeaturizer(**kwargs) + dim = 1024 + elif name == "mocov3": + from .MOCOv3 import MOCOv3Featurizer + patch_size = 16 + model = MOCOv3Featurizer() + dim = 384 + elif name == "msn": + from .MSN import MSNFeaturizer + patch_size = 16 + model = MSNFeaturizer() + dim = 384 + elif name == "pixels": + patch_size = 1 + model = lambda x: x + dim = 3 + elif name == "resnet50": + from .modules.resnet import resnet50 + from .ResNet import ResNetFeaturizer + model = ResNetFeaturizer(resnet50(pretrained=True)) + patch_size = 1 + dim = 2048 + elif name == "deeplab": + from .DeepLabV3 import DeepLabV3Featurizer + model = torch.hub.load('pytorch/vision:v0.10.0', 'deeplabv3_resnet50', pretrained=True) + model = DeepLabV3Featurizer(model) + patch_size = 1 + dim = 2048 + else: + raise ValueError("unknown model: {}".format(name)) + return model, patch_size, dim diff --git a/featup/layers.py b/featup/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..782603d958d652ad68ddc1fba89def44134ca3bd --- /dev/null +++ b/featup/layers.py @@ -0,0 +1,89 @@ +import torch + + +def id_conv(dim, strength=.9): + conv = torch.nn.Conv2d(dim, dim, 1, padding="same") + start_w = conv.weight.data + conv.weight.data = torch.nn.Parameter( + torch.eye(dim, device=start_w.device).unsqueeze(-1).unsqueeze(-1) * strength + start_w * (1 - strength)) + 
conv.bias.data = torch.nn.Parameter(conv.bias.data * (1 - strength)) + return conv + + +class ImplicitFeaturizer(torch.nn.Module): + + def __init__(self, color_feats=True, n_freqs=10, learn_bias=False, time_feats=False, *args, **kwargs): + super().__init__(*args, **kwargs) + self.color_feats = color_feats + self.time_feats = time_feats + self.n_freqs = n_freqs + self.learn_bias = learn_bias + + self.dim_multiplier = 2 + + if self.color_feats: + self.dim_multiplier += 3 + + if self.time_feats: + self.dim_multiplier += 1 + + if self.learn_bias: + self.biases = torch.nn.Parameter(torch.randn(2, self.dim_multiplier, n_freqs).to(torch.float32)) + + def forward(self, original_image): + b, c, h, w = original_image.shape + grid_h = torch.linspace(-1, 1, h, device=original_image.device) + grid_w = torch.linspace(-1, 1, w, device=original_image.device) + feats = torch.cat([t.unsqueeze(0) for t in torch.meshgrid([grid_h, grid_w])]).unsqueeze(0) + feats = torch.broadcast_to(feats, (b, feats.shape[1], h, w)) + + if self.color_feats: + feat_list = [feats, original_image] + else: + feat_list = [feats] + + feats = torch.cat(feat_list, dim=1).unsqueeze(1) + freqs = torch.exp(torch.linspace(-2, 10, self.n_freqs, device=original_image.device)) \ + .reshape(1, self.n_freqs, 1, 1, 1) + feats = (feats * freqs) + + if self.learn_bias: + sin_feats = feats + self.biases[0].reshape(1, self.n_freqs, self.dim_multiplier, 1, 1) + cos_feats = feats + self.biases[1].reshape(1, self.n_freqs, self.dim_multiplier, 1, 1) + else: + sin_feats = feats + cos_feats = feats + + sin_feats = sin_feats.reshape(b, self.n_freqs * self.dim_multiplier, h, w) + cos_feats = cos_feats.reshape(b, self.n_freqs * self.dim_multiplier, h, w) + + if self.color_feats: + all_feats = [torch.sin(sin_feats), torch.cos(cos_feats), original_image] + else: + all_feats = [torch.sin(sin_feats), torch.cos(cos_feats)] + + return torch.cat(all_feats, dim=1) + + +class MinMaxScaler(torch.nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x): + c = x.shape[1] + flat_x = x.permute(1, 0, 2, 3).reshape(c, -1) + flat_x_min = flat_x.min(dim=-1).values.reshape(1, c, 1, 1) + flat_x_scale = flat_x.max(dim=-1).values.reshape(1, c, 1, 1) - flat_x_min + return ((x - flat_x_min) / flat_x_scale.clamp_min(0.0001)) - .5 + + +class ChannelNorm(torch.nn.Module): + + def __init__(self, dim, *args, **kwargs): + super().__init__(*args, **kwargs) + self.norm = torch.nn.LayerNorm(dim) + + def forward(self, x): + new_x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return new_x diff --git a/featup/losses.py b/featup/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..46c259e1f6f5e579946c6aa4921a05bd69f8e6f7 --- /dev/null +++ b/featup/losses.py @@ -0,0 +1,198 @@ +import torch +import torch.nn as nn + + +def entropy(t): + return -(t * torch.log(t.clamp_min(.0000001))).sum(dim=[-1, -2, -3]).mean() + + +def total_variation(img): + b, c, h, w = img.size() + return ((img[:, :, 1:, :] - img[:, :, :-1, :]).square().sum() + + (img[:, :, :, 1:] - img[:, :, :, :-1]).square().sum()) / (b * c * h * w) + + +class SampledCRFLoss(torch.nn.Module): + + def __init__(self, n_samples, alpha, beta, gamma, w1, w2, shift): + super(SampledCRFLoss, self).__init__() + self.alpha = alpha + self.beta = beta + self.gamma = gamma + self.w1 = w1 + self.w2 = w2 + self.n_samples = n_samples + self.shift = shift + + def forward(self, guidance, features): + device = features.device + assert (guidance.shape[0] == features.shape[0]) + assert 
(guidance.shape[2:] == features.shape[2:]) + h = guidance.shape[2] + w = guidance.shape[3] + + coords = torch.cat([ + torch.randint(0, h, size=[1, self.n_samples], device=device), + torch.randint(0, w, size=[1, self.n_samples], device=device)], 0) + norm_coords = coords / torch.tensor([h, w], device=guidance.device).unsqueeze(-1) + + selected_guidance = guidance[:, :, coords[0, :], coords[1, :]] + + coord_diff = (norm_coords.unsqueeze(-1) - norm_coords.unsqueeze(-2)).square().sum(0).unsqueeze(0) + guidance_diff = (selected_guidance.unsqueeze(-1) - selected_guidance.unsqueeze(-2)).square().sum(1) + + sim_kernel = self.w1 * torch.exp(- coord_diff / (2 * self.alpha) - guidance_diff / (2 * self.beta)) + \ + self.w2 * torch.exp(- coord_diff / (2 * self.gamma)) - self.shift + + # selected_clusters = F.normalize(features[:, :, coords[0, :], coords[1, :]], dim=1) + # cluster_sims = torch.einsum("bcn,bcm->bnm", selected_clusters, selected_clusters) + selected_feats = features[:, :, coords[0, :], coords[1, :]] + feat_diff = (selected_feats.unsqueeze(-1) - selected_feats.unsqueeze(-2)).square().sum(1) + + return (feat_diff * sim_kernel).mean() + + +class TVLoss(torch.nn.Module): + + def __init__(self): + super(TVLoss, self).__init__() + + def forward(self, img): + b, c, h, w = img.size() + return ((img[:, :, 1:, :] - img[:, :, :-1, :]).square().sum() + + (img[:, :, :, 1:] - img[:, :, :, :-1]).square().sum()) / (b * c * h * w) + + +def compute_scale_and_shift(prediction, target, mask): + # system matrix: A = [[a_00, a_01], [a_10, a_11]] + a_00 = torch.sum(mask * prediction * prediction, (1, 2)) + a_01 = torch.sum(mask * prediction, (1, 2)) + a_11 = torch.sum(mask, (1, 2)) + + # right hand side: b = [b_0, b_1] + b_0 = torch.sum(mask * prediction * target, (1, 2)) + b_1 = torch.sum(mask * target, (1, 2)) + + # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . 
b + x_0 = torch.zeros_like(b_0) + x_1 = torch.zeros_like(b_1) + + det = a_00 * a_11 - a_01 * a_01 + valid = det.nonzero() + + x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid] + x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid] + + return x_0, x_1 + + +def reduction_batch_based(image_loss, M): + # average of all valid pixels of the batch + + # avoid division by 0 (if sum(M) = sum(sum(mask)) = 0: sum(image_loss) = 0) + divisor = torch.sum(M) + + if divisor == 0: + return 0 + else: + return torch.sum(image_loss) / divisor + + +def reduction_image_based(image_loss, M): + # mean of average of valid pixels of an image + + # avoid division by 0 (if M = sum(mask) = 0: image_loss = 0) + valid = M.nonzero() + + image_loss[valid] = image_loss[valid] / M[valid] + + return torch.mean(image_loss) + + +def mse_loss(prediction, target, mask, reduction=reduction_batch_based): + M = torch.sum(mask, (1, 2)) + res = prediction - target + image_loss = torch.sum(mask * res * res, (1, 2)) + + return reduction(image_loss, 2 * M) + + +def gradient_loss(prediction, target, mask, reduction=reduction_batch_based): + M = torch.sum(mask, (1, 2)) + + diff = prediction - target + diff = torch.mul(mask, diff) + + grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1]) + mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1]) + grad_x = torch.mul(mask_x, grad_x) + + grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :]) + mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :]) + grad_y = torch.mul(mask_y, grad_y) + + image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2)) + + return reduction(image_loss, M) + + +class MSELoss(nn.Module): + def __init__(self, reduction='batch-based'): + super().__init__() + + if reduction == 'batch-based': + self.__reduction = reduction_batch_based + else: + self.__reduction = reduction_image_based + + def forward(self, prediction, target, mask): + return mse_loss(prediction, target, mask, reduction=self.__reduction) + + +class GradientLoss(nn.Module): + def __init__(self, scales=4, reduction='batch-based'): + super().__init__() + + if reduction == 'batch-based': + self.__reduction = reduction_batch_based + else: + self.__reduction = reduction_image_based + + self.__scales = scales + + def forward(self, prediction, target, mask): + total = 0 + + for scale in range(self.__scales): + step = pow(2, scale) + + total += gradient_loss(prediction[:, ::step, ::step], target[:, ::step, ::step], + mask[:, ::step, ::step], reduction=self.__reduction) + + return total + + +class ScaleAndShiftInvariantLoss(nn.Module): + def __init__(self, alpha=0.5, scales=4, reduction='batch-based'): + super().__init__() + + self.__data_loss = MSELoss(reduction=reduction) + self.__regularization_loss = GradientLoss(scales=scales, reduction=reduction) + self.__alpha = alpha + + self.__prediction_ssi = None + + def forward(self, prediction, target, mask): + scale, shift = compute_scale_and_shift(prediction, target, mask) + self.__prediction_ssi = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1) + + total = self.__data_loss(self.__prediction_ssi, target, mask) + if self.__alpha > 0: + total += self.__alpha * self.__regularization_loss(self.__prediction_ssi, target, mask) + + return total + + def __get_prediction_ssi(self): + return self.__prediction_ssi + + prediction_ssi = property(__get_prediction_ssi) diff --git a/featup/plotting.py b/featup/plotting.py new file mode 100644 index 
0000000000000000000000000000000000000000..84eaced174d96f03aedcb76710b512d665394e1f --- /dev/null +++ b/featup/plotting.py @@ -0,0 +1,54 @@ +import matplotlib.pyplot as plt +from featup.util import pca, remove_axes +from featup.featurizers.maskclip.clip import tokenize +from pytorch_lightning import seed_everything +import torch +import torch.nn.functional as F + + +@torch.no_grad() +def plot_feats(image, lr, hr): + assert len(image.shape) == len(lr.shape) == len(hr.shape) == 3 + seed_everything(0) + [lr_feats_pca, hr_feats_pca], _ = pca([lr.unsqueeze(0), hr.unsqueeze(0)]) + fig, ax = plt.subplots(1, 3, figsize=(15, 5)) + ax[0].imshow(image.permute(1, 2, 0).detach().cpu()) + ax[0].set_title("Image") + ax[1].imshow(lr_feats_pca[0].permute(1, 2, 0).detach().cpu()) + ax[1].set_title("Original Features") + ax[2].imshow(hr_feats_pca[0].permute(1, 2, 0).detach().cpu()) + ax[2].set_title("Upsampled Features") + remove_axes(ax) + plt.show() + + +@torch.no_grad() +def plot_lang_heatmaps(model, image, lr_feats, hr_feats, text_query): + assert len(image.shape) == len(lr_feats.shape) == len(hr_feats.shape) == 3 + fig, ax = plt.subplots(1, 3, figsize=(15, 5)) + cmap = plt.get_cmap("turbo") + + # encode query + text = tokenize(text_query).to(lr_feats.device) + text_feats = model.model.encode_text(text).squeeze().to(torch.float32) + assert len(text_feats.shape) == 1 + + lr_sims = torch.einsum( + "chw,c->hw", F.normalize(lr_feats.to(torch.float32), dim=0), F.normalize(text_feats, dim=0)) + hr_sims = torch.einsum( + "chw,c->hw", F.normalize(hr_feats.to(torch.float32), dim=0), F.normalize(text_feats, dim=0)) + + lr_sims_norm = (lr_sims - lr_sims.min()) / (lr_sims.max() - lr_sims.min()) + hr_sims_norm = (hr_sims - hr_sims.min()) / (hr_sims.max() - hr_sims.min()) + lr_heatmap = cmap(lr_sims_norm.cpu().numpy()) + hr_heatmap = cmap(hr_sims_norm.cpu().numpy()) + + ax[0].imshow(image.permute(1, 2, 0).detach().cpu()) + ax[0].set_title("Image") + ax[1].imshow(lr_heatmap) + ax[1].set_title(f"Original Similarity to \"{text_query}\"") + ax[2].imshow(hr_heatmap) + ax[2].set_title(f"Upsampled Similarity to \"{text_query}\"") + remove_axes(ax) + + return plt.show() diff --git a/featup/train_implicit_upsampler.py b/featup/train_implicit_upsampler.py new file mode 100644 index 0000000000000000000000000000000000000000..b2b560aeff9b064f4d2edbf53c79f0955ac000d6 --- /dev/null +++ b/featup/train_implicit_upsampler.py @@ -0,0 +1,401 @@ +import math +import os +from collections import defaultdict +from datetime import datetime +from os.path import join, dirname + +import hydra +import matplotlib.pyplot as plt +import torch +import torch.nn.functional as F +import torchvision.transforms as T +from PIL import Image +from kornia.filters import gaussian_blur2d +from omegaconf import DictConfig, OmegaConf +from pytorch_lightning import seed_everything +from torch.utils.data import DataLoader, Subset +from torch.utils.tensorboard import SummaryWriter +from torchmetrics.functional.regression import explained_variance +from tqdm import tqdm + +from featup.datasets.JitteredImage import JitteredImage, apply_jitter +from featup.datasets.util import get_dataset, SlicedDataset +from featup.downsamplers import SimpleDownsampler, AttentionDownsampler +from featup.featurizers.util import get_featurizer +from featup.layers import ImplicitFeaturizer, MinMaxScaler, ChannelNorm +from featup.losses import total_variation +from featup.util import (norm as reg_norm, unnorm as reg_unorm, generate_subset, + midas_norm, midas_unnorm, pca, PCAUnprojector, 
prep_image) + +torch.multiprocessing.set_sharing_strategy('file_system') + + +def mag(t): + return t.square().sum(1, keepdim=True).sqrt() + + +class ExplicitUpsampler(torch.nn.Module): + + def __init__(self, size, dim, *args, **kwargs): + super().__init__(*args, **kwargs) + self.size = size + self.dim = dim + self.feats = torch.nn.Parameter(F.normalize(torch.randn(1, dim, size, size), dim=1)) + + def forward(self, x): + return self.feats + + +def get_implicit_upsampler(start_dim, end_dim, color_feats, n_freqs): + return torch.nn.Sequential( + MinMaxScaler(), + ImplicitFeaturizer(color_feats, n_freqs=n_freqs, learn_bias=True), + ChannelNorm(start_dim), + torch.nn.Dropout2d(p=.2), + torch.nn.Conv2d(start_dim, end_dim, 1), + torch.nn.ReLU(), + torch.nn.Dropout2d(p=.2), + ChannelNorm(end_dim), + torch.nn.Conv2d(end_dim, end_dim, 1), + torch.nn.ReLU(), + torch.nn.Conv2d(end_dim, end_dim, 1), + ) + + +@hydra.main(config_path="configs", config_name="implicit_upsampler.yaml") +def my_app(cfg: DictConfig) -> None: + print(OmegaConf.to_yaml(cfg)) + print(cfg.output_root) + seed_everything(0) + + input_size_h = 224 + input_size_w = 224 + final_size = 14 + redo = False + + steps = cfg.steps + if cfg.model_type in {"dino16", "vit", "clip", "midas", "maskclip"}: + multiplier = 1 + featurize_batch_size = 64 + kernel_size = 29 + elif cfg.model_type == "dinov2": + multiplier = 1 + featurize_batch_size = 64 + kernel_size = 29 + final_size = 16 + elif cfg.model_type == "dino8": + multiplier = 1 + featurize_batch_size = 64 + kernel_size = 8 + final_size = 28 + elif cfg.model_type == "deeplab": + multiplier = 1 + featurize_batch_size = 16 + kernel_size = 35 + final_size = 28 + elif cfg.model_type == "resnet50": + multiplier = 2 + final_size = 14 + kernel_size = 35 + featurize_batch_size = 16 + steps = 500 + else: + raise ValueError(f"Unknown model type {cfg.model_type}") + + if cfg.downsampler_type == "attention": + batch_size = 10 + inner_batch = 10 + else: + batch_size = 10 + inner_batch = 10 + + feat_dir = join(cfg.output_root, "feats", cfg.experiment_name, cfg.dataset, cfg.split, cfg.model_type) + log_dir = join(cfg.output_root, "logs", cfg.experiment_name, cfg.dataset, cfg.split, cfg.model_type) + + model, _, dim = get_featurizer(cfg.model_type, activation_type=cfg.activation_type, output_root=cfg.output_root) + + if cfg.use_norm: + model = torch.nn.Sequential(model, ChannelNorm(dim)) + + model = model.cuda() + + if cfg.model_type == "midas": + norm = midas_norm + unnorm = midas_unnorm + else: + norm = reg_norm + unnorm = reg_unorm + + def project(imgs): + if multiplier > 1: + imgs = F.interpolate(imgs, scale_factor=multiplier, mode="bilinear") + return model(imgs) + + transform = T.Compose([ + T.Resize(input_size_h), + T.CenterCrop((input_size_h, input_size_w)), + T.ToTensor(), + norm + ]) + + full_dataset = get_dataset(dataroot=cfg.pytorch_data_dir, + name=cfg.dataset, + split=cfg.split, + transform=transform, + target_transform=None, + include_labels=False) + + if "sample" in cfg.dataset: + partition_size = 1 + dataset = full_dataset + else: + if cfg.split == "val": + full_dataset = full_dataset + elif cfg.split == "train": + full_dataset = Subset(full_dataset, generate_subset(len(full_dataset), 5000)) + else: + raise ValueError(f"Unknown dataset {cfg.dataset}") + + full_size = len(full_dataset) + partition_size = math.ceil(full_size / cfg.total_partitions) + dataset = SlicedDataset( + full_dataset, + int(cfg.partition * partition_size), + int((cfg.partition + 1) * partition_size)) + loader = 
DataLoader(dataset, shuffle=False) + + for img_num, batch in enumerate(loader): + original_image = batch["img"].cuda() + output_location = join(feat_dir, "/".join(batch["img_path"][0].split("/")[-1:]).replace(".jpg", ".pth")) + + os.makedirs(dirname(output_location), exist_ok=True) + if not redo and os.path.exists(output_location) and not cfg.dataset == "sample": + print(f"Found {output_location}, skipping") + continue + else: + print(f"Did not find {output_location}, computing") + + if cfg.summarize: + writer = SummaryWriter(join(log_dir, str(datetime.now()))) + + params = [] + dataset = JitteredImage(original_image, cfg.n_images, cfg.use_flips, cfg.max_zoom, cfg.max_pad) + loader = DataLoader(dataset, featurize_batch_size) + with torch.no_grad(): + transform_params = defaultdict(list) + lr_feats = project(original_image) + [red_lr_feats], fit_pca = pca([lr_feats], dim=9, use_torch_pca=True) + + jit_features = [] + for transformed_image, tp in tqdm(loader): + for k, v in tp.items(): + transform_params[k].append(v) + jit_features.append(project(transformed_image).cpu()) + jit_features = torch.cat(jit_features, dim=0) + transform_params = {k: torch.cat(v, dim=0) for k, v in transform_params.items()} + + unprojector = PCAUnprojector(jit_features[:cfg.pca_batch], cfg.proj_dim, lr_feats.device, + use_torch_pca=True) + jit_features = unprojector.project(jit_features) + lr_feats = unprojector.project(lr_feats) + + if cfg.param_type == "implicit": + end_dim = cfg.proj_dim + if cfg.color_feats: + start_dim = 5 * cfg.n_freqs * 2 + 3 + else: + start_dim = 2 * cfg.n_freqs * 2 + + upsampler = get_implicit_upsampler( + start_dim, end_dim, cfg.color_feats, cfg.n_freqs).cuda() + elif cfg.param_type == "explicit": + upsampler = ExplicitUpsampler(input_size_h, cfg.proj_dim).cuda() + else: + raise ValueError(f"Unknown param type {cfg.param_type}") + params.append({"params": upsampler.parameters()}) + + if cfg.downsampler_type == "simple": + downsampler = SimpleDownsampler(kernel_size, final_size) + else: + downsampler = AttentionDownsampler(cfg.proj_dim + 1, kernel_size, final_size, cfg.blur_attn).cuda() + + params.append({"params": downsampler.parameters()}) + + if cfg.outlier_detection: + with torch.no_grad(): + outlier_detector = torch.nn.Conv2d(cfg.proj_dim, 1, 1).cuda() + outlier_detector.weight.copy_(outlier_detector.weight * .1) + outlier_detector.bias.copy_(outlier_detector.bias * .1) + + params.append({"params": outlier_detector.parameters()}) + get_scale = lambda feats: torch.exp(outlier_detector(feats) + .1).clamp_min(.0001) + else: + get_scale = lambda feats: torch.ones(feats.shape[0], 1, feats.shape[2], feats.shape[2], + device=feats.device, + dtype=feats.dtype) + + optim = torch.optim.NAdam(params) + + for step in tqdm(range(steps), f"Image {img_num} of {partition_size}"): + for i in range(batch_size // inner_batch): + upsampler.train() + downsampler.train() + + hr_feats = upsampler(original_image) + hr_mag = mag(hr_feats) + hr_both = torch.cat([hr_mag, hr_feats], dim=1) + loss = 0.0 + + target = [] + hr_feats_transformed = [] + for j in range(inner_batch): + idx = torch.randint(cfg.n_images, size=()) + target.append(jit_features[idx].unsqueeze(0)) + selected_tp = {k: v[idx] for k, v in transform_params.items()} + hr_feats_transformed.append(apply_jitter(hr_both, cfg.max_pad, selected_tp)) + + target = torch.cat(target, dim=0).cuda(non_blocking=True) + hr_feats_transformed = torch.cat(hr_feats_transformed, dim=0) + + output_both = downsampler(hr_feats_transformed, None) + magnitude = 
output_both[:, 0:1, :, :] + output = output_both[:, 1:, :, :] + + scales = get_scale(target) + + rec_loss = ((1 / (2 * scales ** 2)) * (output - target).square() + scales.log()).mean() + + loss += rec_loss + + if cfg.mag_weight > 0.0: + mag_loss = (magnitude - mag(target)).square().mean() + mag_loss2 = (mag(output) - mag(target)).square().mean() + loss += mag_loss * cfg.mag_weight + + if cfg.mag_tv_weight > 0.0: + mag_tv = total_variation(hr_mag) + loss += cfg.mag_tv_weight * mag_tv + + if cfg.blur_pin > 0.0: + blur_pin_loss = (gaussian_blur2d(hr_feats, 5, (1.0, 1.0)) - hr_feats).square().mean() + loss += cfg.blur_pin * blur_pin_loss + + loss.backward() + + should_log = cfg.summarize and (i == (batch_size // inner_batch - 1)) + + if should_log and step % 10 == 0: + upsampler.eval() + downsampler.eval() + + writer.add_scalar("loss", loss, step) + mean_mae = (lr_feats.mean(dim=[2, 3]) - hr_feats.mean(dim=[2, 3])).abs().mean() + writer.add_scalar("mean_mae", mean_mae, step) + writer.add_scalar("rec loss", rec_loss, step) + writer.add_scalar("mean scale", scales.mean(), step) + + if cfg.mag_weight > 0.0: + writer.add_scalar("mag loss", mag_loss, step) + writer.add_scalar("mag loss2", mag_loss2, step) + + if cfg.mag_tv_weight > 0.0: + writer.add_scalar("mag tv", mag_tv, step) + + if cfg.blur_pin > 0.0: + writer.add_scalar("blur pin loss", blur_pin_loss, step) + + if should_log and step % 100 == 0: + with torch.no_grad(): + upsampler.eval() + downsampler.eval() + + hr_feats = upsampler(original_image) + hr_mag = mag(hr_feats) + hr_both = torch.cat([hr_mag, hr_feats], dim=1) + target = [] + hr_feats_transformed = [] + for j in range(inner_batch): + idx = torch.randint(cfg.n_images, size=()) + target.append(jit_features[idx].unsqueeze(0)) + selected_tp = {k: v[idx] for k, v in transform_params.items()} + hr_feats_transformed.append(apply_jitter(hr_both, cfg.max_pad, selected_tp)) + + target = torch.cat(target, dim=0).cuda(non_blocking=True) + hr_feats_transformed = torch.cat(hr_feats_transformed, dim=0) + + output_both = downsampler(hr_feats_transformed, None) + output = output_both[:, 1:, :, :] + scales = get_scale(target) + big_target = unprojector(target) + big_output = unprojector(output) + + ev = explained_variance( + big_output.flatten(), + big_target.flatten()) + writer.add_scalar("explained_variance", ev, step) + + [red_hr_feats, red_target, red_output], _ = pca([ + unprojector(hr_feats), big_target, big_output + ], fit_pca=fit_pca, dim=9) + + def up(x): + return F.interpolate(x.unsqueeze(0), hr_feats.shape[2:], mode="nearest").squeeze(0) + + writer.add_image("feats/1/hr", red_hr_feats[0, :3], step) + writer.add_image("feats/2/hr", red_hr_feats[0, 3:6], step) + writer.add_image("feats/3/hr", red_hr_feats[0, 6:9], step) + + np_arr = (red_lr_feats[0, :3].permute(1, 2, 0) * 255).clamp(0, 255).to(torch.uint8) + Image.fromarray(np_arr.detach().cpu().numpy()).save("../sample-images/low_res_feats.png") + + writer.add_image("feats/1/lr", up(red_lr_feats[0, :3]), step) + writer.add_image("feats/2/lr", up(red_lr_feats[0, 3:6]), step) + writer.add_image("feats/3/lr", up(red_lr_feats[0, 6:9]), step) + writer.add_image("feats/1/pred", red_target[0, :3], step) + writer.add_image("feats/1/true", red_output[0, :3], step) + writer.add_image("image/original", unnorm(original_image)[0], step) + writer.add_image("image/transformed", unnorm(transformed_image)[0], step) + + norm_scales = scales[0] + norm_scales /= scales.max() + writer.add_image("scales", norm_scales, step) + writer.add_histogram("scales hist", 
scales, step) + + hr_lr_feats = F.interpolate(lr_feats, size=(input_size_h, input_size_w)) + fig, axes = plt.subplots(1, 2, figsize=(10, 5)) + plt1 = axes[0].imshow(mag(hr_feats)[0, 0].detach().cpu()) + plt.colorbar(plt1) + plt2 = axes[1].imshow(mag(hr_lr_feats)[0, 0].detach().cpu()) + plt.colorbar(plt2) + writer.add_figure("magnitudes", fig, step) + + if isinstance(downsampler, SimpleDownsampler): + writer.add_image( + "down/filter", + prep_image(downsampler.get_kernel().squeeze(), subtract_min=False), + step) + + if isinstance(downsampler, AttentionDownsampler): + writer.add_image( + "down/att", + prep_image(downsampler.forward_attention(hr_both, None)[0]), + step) + writer.add_image( + "down/w", + prep_image(downsampler.w.clone().squeeze()), + step) + writer.add_image( + "down/b", + prep_image(downsampler.b.clone().squeeze()), + step) + + writer.flush() + + optim.step() + optim.zero_grad() + + torch.save({"model": upsampler.state_dict(), "unprojector": unprojector.state_dict()}, output_location) + + +if __name__ == "__main__": + my_app() diff --git a/featup/train_jbu_upsampler.py b/featup/train_jbu_upsampler.py new file mode 100644 index 0000000000000000000000000000000000000000..a4559059ac716c149725b23dc09962b3da5da9d5 --- /dev/null +++ b/featup/train_jbu_upsampler.py @@ -0,0 +1,381 @@ +import gc +import os + +import hydra +import pytorch_lightning as pl +import torch +import torchvision.transforms as T +from omegaconf import DictConfig +from omegaconf import OmegaConf +from pytorch_lightning import Trainer +from pytorch_lightning import seed_everything +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger +from torch.utils.data import DataLoader +from torchvision.transforms import InterpolationMode +from os.path import join + +from featup.datasets.JitteredImage import apply_jitter, sample_transform +from featup.datasets.util import get_dataset, SingleImageDataset +from featup.downsamplers import SimpleDownsampler, AttentionDownsampler +from featup.featurizers.util import get_featurizer +from featup.layers import ChannelNorm +from featup.losses import TVLoss, SampledCRFLoss, entropy +from featup.upsamplers import get_upsampler +from featup.util import pca, RollingAvg, unnorm, norm, prep_image + +torch.multiprocessing.set_sharing_strategy('file_system') + + +class ScaleNet(torch.nn.Module): + + def __init__(self, dim): + super().__init__() + self.dim = dim + self.net = torch.nn.Conv2d(dim, 1, 1) + with torch.no_grad(): + self.net.weight.copy_(self.net.weight * .1) + self.net.bias.copy_(self.net.bias * .1) + + def forward(self, x): + return torch.exp(self.net(x) + .1).clamp_min(.0001) + + +class JBUFeatUp(pl.LightningModule): + def __init__(self, + model_type, + activation_type, + n_jitters, + max_pad, + max_zoom, + kernel_size, + final_size, + lr, + random_projection, + predicted_uncertainty, + crf_weight, + filter_ent_weight, + tv_weight, + upsampler, + downsampler, + chkpt_dir, + ): + super().__init__() + self.model_type = model_type + self.activation_type = activation_type + self.n_jitters = n_jitters + self.max_pad = max_pad + self.max_zoom = max_zoom + self.kernel_size = kernel_size + self.final_size = final_size + self.lr = lr + self.random_projection = random_projection + self.predicted_uncertainty = predicted_uncertainty + self.crf_weight = crf_weight + self.filter_ent_weight = filter_ent_weight + self.tv_weight = tv_weight + self.chkpt_dir = chkpt_dir + + self.model, self.patch_size, self.dim = 
get_featurizer(model_type, activation_type, num_classes=1000) + for p in self.model.parameters(): + p.requires_grad = False + self.model = torch.nn.Sequential(self.model, ChannelNorm(self.dim)) + self.upsampler = get_upsampler(upsampler, self.dim) + + if downsampler == 'simple': + self.downsampler = SimpleDownsampler(self.kernel_size, self.final_size) + elif downsampler == 'attention': + self.downsampler = AttentionDownsampler(self.dim, self.kernel_size, self.final_size, blur_attn=True) + else: + raise ValueError(f"Unknown downsampler {downsampler}") + + if self.predicted_uncertainty: + self.scale_net = ScaleNet(self.dim) + + self.avg = RollingAvg(20) + + self.crf = SampledCRFLoss( + alpha=.1, + beta=.15, + gamma=.005, + w1=10.0, + w2=3.0, + shift=0.00, + n_samples=1000) + self.tv = TVLoss() + + self.automatic_optimization = False + + def forward(self, x): + return self.upsampler(self.model(x)) + + def project(self, feats, proj): + if proj is None: + return feats + else: + return torch.einsum("bchw,bcd->bdhw", feats, proj) + + def training_step(self, batch, batch_idx): + opt = self.optimizers() + opt.zero_grad() + + with torch.no_grad(): + if type(batch) == dict: + img = batch['img'] + else: + img, _ = batch + lr_feats = self.model(img) + + full_rec_loss = 0.0 + full_crf_loss = 0.0 + full_entropy_loss = 0.0 + full_tv_loss = 0.0 + full_total_loss = 0.0 + for i in range(self.n_jitters): + hr_feats = self.upsampler(lr_feats, img) + + if hr_feats.shape[2] != img.shape[2]: + hr_feats = torch.nn.functional.interpolate(hr_feats, img.shape[2:], mode="bilinear") + + with torch.no_grad(): + transform_params = sample_transform( + True, self.max_pad, self.max_zoom, img.shape[2], img.shape[3]) + jit_img = apply_jitter(img, self.max_pad, transform_params) + lr_jit_feats = self.model(jit_img) + + if self.random_projection is not None: + proj = torch.randn(lr_feats.shape[0], + lr_feats.shape[1], + self.random_projection, device=lr_feats.device) + proj /= proj.square().sum(1, keepdim=True).sqrt() + else: + proj = None + + hr_jit_feats = apply_jitter(hr_feats, self.max_pad, transform_params) + proj_hr_feats = self.project(hr_jit_feats, proj) + + down_jit_feats = self.project(self.downsampler(hr_jit_feats, jit_img), proj) + + if self.predicted_uncertainty: + scales = self.scale_net(lr_jit_feats) + scale_factor = (1 / (2 * scales ** 2)) + mse = (down_jit_feats - self.project(lr_jit_feats, proj)).square() + rec_loss = (scale_factor * mse + scales.log()).mean() / self.n_jitters + else: + rec_loss = (self.project(lr_jit_feats, proj) - down_jit_feats).square().mean() / self.n_jitters + + full_rec_loss += rec_loss.item() + + if self.crf_weight > 0 and i == 0: + crf_loss = self.crf(img, proj_hr_feats) + full_crf_loss += crf_loss.item() + else: + crf_loss = 0.0 + + if self.filter_ent_weight > 0.0: + entropy_loss = entropy(self.downsampler.get_kernel()) + full_entropy_loss += entropy_loss.item() + else: + entropy_loss = 0 + + if self.tv_weight > 0 and i == 0: + tv_loss = self.tv(proj_hr_feats.square().sum(1, keepdim=True)) + full_tv_loss += tv_loss.item() + else: + tv_loss = 0.0 + + loss = rec_loss + self.crf_weight * crf_loss + self.tv_weight * tv_loss - self.filter_ent_weight * entropy_loss + full_total_loss += loss.item() + self.manual_backward(loss) + + self.avg.add("loss/crf", full_crf_loss) + self.avg.add("loss/ent", full_entropy_loss) + self.avg.add("loss/tv", full_tv_loss) + self.avg.add("loss/rec", full_rec_loss) + self.avg.add("loss/total", full_total_loss) + + if self.global_step % 100 == 0: + 
self.trainer.save_checkpoint(self.chkpt_dir[:-5] + '/' + self.chkpt_dir[:-5] + f'_{self.global_step}.ckpt') + + self.avg.logall(self.log) + if self.global_step < 10: + self.clip_gradients(opt, gradient_clip_val=.0001, gradient_clip_algorithm="norm") + + opt.step() + + return None + + def validation_step(self, batch, batch_idx): + with torch.no_grad(): + if self.trainer.is_global_zero and batch_idx == 0: + + if type(batch) == dict: + img = batch['img'] + else: + img, _ = batch + lr_feats = self.model(img) + + hr_feats = self.upsampler(lr_feats, img) + + if hr_feats.shape[2] != img.shape[2]: + hr_feats = torch.nn.functional.interpolate(hr_feats, img.shape[2:], mode="bilinear") + + transform_params = sample_transform( + True, self.max_pad, self.max_zoom, img.shape[2], img.shape[3]) + jit_img = apply_jitter(img, self.max_pad, transform_params) + lr_jit_feats = self.model(jit_img) + + if self.random_projection is not None: + proj = torch.randn(lr_feats.shape[0], + lr_feats.shape[1], + self.random_projection, device=lr_feats.device) + proj /= proj.square().sum(1, keepdim=True).sqrt() + else: + proj = None + + scales = self.scale_net(lr_jit_feats) + + writer = self.logger.experiment + + hr_jit_feats = apply_jitter(hr_feats, self.max_pad, transform_params) + down_jit_feats = self.downsampler(hr_jit_feats, jit_img) + + [red_lr_feats], fit_pca = pca([lr_feats[0].unsqueeze(0)]) + [red_hr_feats], _ = pca([hr_feats[0].unsqueeze(0)], fit_pca=fit_pca) + [red_lr_jit_feats], _ = pca([lr_jit_feats[0].unsqueeze(0)], fit_pca=fit_pca) + [red_hr_jit_feats], _ = pca([hr_jit_feats[0].unsqueeze(0)], fit_pca=fit_pca) + [red_down_jit_feats], _ = pca([down_jit_feats[0].unsqueeze(0)], fit_pca=fit_pca) + + writer.add_image("viz/image", unnorm(img[0].unsqueeze(0))[0], self.global_step) + writer.add_image("viz/lr_feats", red_lr_feats[0], self.global_step) + writer.add_image("viz/hr_feats", red_hr_feats[0], self.global_step) + writer.add_image("jit_viz/jit_image", unnorm(jit_img[0].unsqueeze(0))[0], self.global_step) + writer.add_image("jit_viz/lr_jit_feats", red_lr_jit_feats[0], self.global_step) + writer.add_image("jit_viz/hr_jit_feats", red_hr_jit_feats[0], self.global_step) + writer.add_image("jit_viz/down_jit_feats", red_down_jit_feats[0], self.global_step) + + norm_scales = scales[0] + norm_scales /= scales.max() + writer.add_image("scales", norm_scales, self.global_step) + writer.add_histogram("scales hist", scales, self.global_step) + + if isinstance(self.downsampler, SimpleDownsampler): + writer.add_image( + "down/filter", + prep_image(self.downsampler.get_kernel().squeeze(), subtract_min=False), + self.global_step) + + if isinstance(self.downsampler, AttentionDownsampler): + writer.add_image( + "down/att", + prep_image(self.downsampler.forward_attention(hr_feats, None)[0]), + self.global_step) + writer.add_image( + "down/w", + prep_image(self.downsampler.w.clone().squeeze()), + self.global_step) + writer.add_image( + "down/b", + prep_image(self.downsampler.b.clone().squeeze()), + self.global_step) + + writer.flush() + + def configure_optimizers(self): + all_params = [] + all_params.extend(list(self.downsampler.parameters())) + all_params.extend(list(self.upsampler.parameters())) + + if self.predicted_uncertainty: + all_params.extend(list(self.scale_net.parameters())) + + return torch.optim.NAdam(all_params, lr=self.lr) + + +@hydra.main(config_path="configs", config_name="jbu_upsampler.yaml") +def my_app(cfg: DictConfig) -> None: + print(OmegaConf.to_yaml(cfg)) + print(cfg.output_root) + seed_everything(seed=0, 
workers=True) + + load_size = 224 + + if cfg.model_type == "dinov2": + final_size = 16 + kernel_size = 14 + else: + final_size = 14 + kernel_size = 16 + + name = (f"{cfg.model_type}_{cfg.upsampler_type}_" + f"{cfg.dataset}_{cfg.downsampler_type}_" + f"crf_{cfg.crf_weight}_tv_{cfg.tv_weight}" + f"_ent_{cfg.filter_ent_weight}") + + log_dir = join(cfg.output_root, f"logs/jbu/{name}") + chkpt_dir = join(cfg.output_root, f"checkpoints/jbu/{name}.ckpt") + os.makedirs(log_dir, exist_ok=True) + + model = JBUFeatUp( + model_type=cfg.model_type, + activation_type=cfg.activation_type, + n_jitters=cfg.n_jitters, + max_pad=cfg.max_pad, + max_zoom=cfg.max_zoom, + kernel_size=kernel_size, + final_size=final_size, + lr=cfg.lr, + random_projection=cfg.random_projection, + predicted_uncertainty=cfg.outlier_detection, + crf_weight=cfg.crf_weight, + filter_ent_weight=cfg.filter_ent_weight, + tv_weight=cfg.tv_weight, + upsampler=cfg.upsampler_type, + downsampler=cfg.downsampler_type, + chkpt_dir=chkpt_dir + ) + + transform = T.Compose([ + T.Resize(load_size, InterpolationMode.BILINEAR), + T.CenterCrop(load_size), + T.ToTensor(), + norm]) + + dataset = get_dataset( + cfg.pytorch_data_dir, + cfg.dataset, + "train", + transform=transform, + target_transform=None, + include_labels=False) + + loader = DataLoader( + dataset, cfg.batch_size, shuffle=True, num_workers=cfg.num_workers) + val_loader = DataLoader( + SingleImageDataset(0, dataset, 1), 1, shuffle=False, num_workers=cfg.num_workers) + + tb_logger = TensorBoardLogger(log_dir, default_hp_metric=False) + callbacks = [ModelCheckpoint(chkpt_dir[:-5], every_n_epochs=1)] + + trainer = Trainer( + accelerator='gpu', + strategy="ddp", + devices=cfg.num_gpus, + max_epochs=cfg.epochs, + logger=tb_logger, + val_check_interval=100, + log_every_n_steps=10, + callbacks=callbacks, + reload_dataloaders_every_n_epochs=1, + ) + + gc.collect() + torch.cuda.empty_cache() + gc.collect() + + trainer.fit(model, loader, val_loader) + trainer.save_checkpoint(chkpt_dir) + + +if __name__ == "__main__": + my_app() diff --git a/featup/train_probes.py b/featup/train_probes.py new file mode 100644 index 0000000000000000000000000000000000000000..e45f08d5bd146105f989a3852370be0b6e6cd744 --- /dev/null +++ b/featup/train_probes.py @@ -0,0 +1,234 @@ +from os.path import join + +import hydra +import matplotlib.pyplot as plt +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +from omegaconf import DictConfig +from omegaconf import OmegaConf +from pytorch_lightning import Trainer +from pytorch_lightning import seed_everything +from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.utilities.seed import seed_everything +from torch.utils.data import DataLoader +from torchmetrics.classification import Accuracy, JaccardIndex + +from featup.datasets.COCO import Coco +from featup.datasets.EmbeddingFile import EmbeddingFile +from featup.losses import ScaleAndShiftInvariantLoss +from featup.util import pca +from featup.util import remove_axes + + +def tensor_correlation(a, b): + return torch.einsum("nchw,ncij->nhwij", a, b) + + +def sample(t: torch.Tensor, coords: torch.Tensor): + return F.grid_sample(t, coords.permute(0, 2, 1, 3), padding_mode='border', align_corners=True) + + +class LitPrototypeEvaluator(pl.LightningModule): + def __init__(self, task, n_dim): + super().__init__() + self.task = task + self.n_dim = n_dim + + if self.task == 'seg': + n_classes = 27 + elif self.task == 'depth': + n_classes = 1 + + self.midas = 
torch.hub.load('intel-isl/MiDaS', 'MiDaS_small').cuda() + self.midas.eval() + self.midas_loss = ScaleAndShiftInvariantLoss() + + self.mse = 0 + self.ssil = 0 + self.steps = 0 + + self.prototypes_buff = self.register_buffer("prototypes", torch.zeros(n_classes, n_dim)) + self.classifier = torch.nn.Conv2d(n_dim, n_classes, 1) + + self.prot_acc_metric = Accuracy(num_classes=n_classes, task="multiclass") + self.prot_acc_buff = self.register_buffer("prot_acc", torch.tensor(0.0)) + self.prot_iou_metric = JaccardIndex(num_classes=n_classes, task="multiclass") + self.prot_iou_buff = self.register_buffer("prot_iou", torch.tensor(0.0)) + + self.linear_acc_metric = Accuracy(num_classes=n_classes, task="multiclass") + self.linear_acc_buff = self.register_buffer("linear_acc", torch.tensor(0.0)) + self.linear_iou_metric = JaccardIndex(num_classes=n_classes, task="multiclass") + self.linear_iou_buff = self.register_buffer("linear_iou", torch.tensor(0.0)) + + self.ce = torch.nn.CrossEntropyLoss() + + def get_prototypes(self, feats): + b, c, h, w = feats.shape + k = self.prototypes.shape[0] + matches = torch.einsum("kc,bchw->kbhw", F.normalize(self.prototypes, dim=1), F.normalize(feats, dim=1)) \ + .reshape(k, -1).argmax(0) + return self.prototypes[matches].reshape(b, h, w, c).permute(0, 3, 1, 2) + + def training_step(self, batch, batch_idx): + feats, label = batch + b, c, h, w = feats.shape + + small_labels = F.interpolate( + label.unsqueeze(1).to(torch.float32), + size=(feats.shape[2], feats.shape[3])).to(torch.int64) + + linear_preds = self.classifier(feats) + + if self.task == 'seg': + flat_labels = small_labels.permute(0, 2, 3, 1).reshape(b * h * w) + flat_linear_preds = linear_preds.permute(0, 2, 3, 1).reshape(b * h * w, -1) + + selected = flat_labels > -1 + linear_loss = self.ce( + flat_linear_preds[selected], + flat_labels[selected]) + loss = linear_loss + self.log("linear_loss", linear_loss) + self.log("loss", loss) + + for l in range(self.n_classes): + self.prototypes[l] += feats.permute(0, 2, 3, 1).reshape(b * h * w, -1)[flat_labels == l].sum(dim=0) + + if self.global_step % 10 == 1 and self.trainer.is_global_zero: + with torch.no_grad(): + prots = self.get_prototypes(feats) + prot_loss = -(F.normalize(feats, dim=1) * F.normalize(prots, dim=1)).sum(1).mean() + self.logger.experiment.add_scalar("prot_loss", prot_loss, self.global_step) + + elif self.task == 'depth': + loss = self.midas_loss(linear_preds.squeeze(), small_labels.squeeze(), + torch.ones_like(linear_preds.squeeze())) + self.log('loss', loss) + + if self.global_step % 200 == 0 and self.trainer.is_global_zero: + n_images = 5 + fig, axes = plt.subplots(4, n_images, figsize=(4 * n_images, 5 * 5)) + + prot_preds = torch.einsum("bchw,kc->bkhw", + F.normalize(feats, dim=1), + F.normalize(self.prototypes, dim=1, eps=1e-10)) + + colorize = Coco.colorize_label if self.task == 'seg' else lambda x: x.detach().cpu() + for i in range(n_images): + feats_pca = pca([feats])[0][0][i] + axes[0, i].imshow(feats_pca) + axes[1, i].imshow(colorize(label[i])) + if self.task == 'depth': + axes[2, i].imshow(colorize(linear_preds[i][0])) + axes[3, i].imshow(colorize(prot_preds[i][0])) + elif self.task == 'seg': + axes[2, i].imshow(colorize(linear_preds.argmax(1)[i])) + axes[3, i].imshow(colorize(prot_preds.argmax(1)[i])) + + plt.tight_layout() + remove_axes(axes) + self.logger.experiment.add_figure('predictions', fig, self.global_step) + + return loss + + def validation_step(self, batch, batch_idx): + with torch.no_grad(): + feats, label = batch + + if 
self.task == 'seg': + label = F.interpolate( + label.to(torch.float32).unsqueeze(1), size=(224, 224)).to(torch.int64).squeeze(1) + + prot_preds = torch.einsum( + "bchw,kc->bkhw", + F.normalize(feats, dim=1), + F.normalize(self.prototypes, dim=1, eps=1e-10)).argmax(1, keepdim=True) + linear_preds = self.classifier(feats).argmax(1, keepdim=True) + + b, h, w = label.shape + flat_labels = label.flatten() + selected = flat_labels > -1 + flat_labels = flat_labels[selected] + + flat_prot_preds = F.interpolate( + prot_preds.to(torch.float32), (h, w)).to(torch.int64).flatten()[selected] + self.prot_acc_metric.update(flat_prot_preds, flat_labels) + self.prot_iou_metric.update(flat_prot_preds, flat_labels) + + flat_linear_preds = F.interpolate( + linear_preds.to(torch.float32), (h, w)).to(torch.int64).flatten()[selected] + self.linear_acc_metric.update(flat_linear_preds, flat_labels) + self.linear_iou_metric.update(flat_linear_preds, flat_labels) + + elif self.task == 'depth': + linear_preds = self.classifier(feats) + small_labels = F.interpolate( + label.unsqueeze(1).to(torch.float32), + size=(feats.shape[2], feats.shape[3])).to(torch.int64) + mse = (small_labels - linear_preds).pow(2).mean() + midas_l = self.midas_loss(linear_preds.squeeze(), small_labels.squeeze(), + torch.ones_like(linear_preds.squeeze())) + self.mse += mse.item() + self.ssil += midas_l.item() + + self.steps += 1 + + return None + + def validation_epoch_end(self, outputs): + self.prot_acc = self.prot_acc_metric.compute() + self.prot_iou = self.prot_iou_metric.compute() + self.linear_acc = self.linear_acc_metric.compute() + self.linear_iou = self.linear_iou_metric.compute() + + def configure_optimizers(self): + return torch.optim.Adam(self.classifier.parameters(), lr=5e-3) + + +@hydra.main(config_path="configs", config_name="train_probe.yaml") +def my_app(cfg: DictConfig) -> None: + print(OmegaConf.to_yaml(cfg)) + print(cfg.output_root) + seed_everything(seed=0, workers=True) + + log_dir = f"../probes/{cfg.task}-probe" + chkpt_dir = f"../probes/{cfg.task}-probe-{cfg.model_type}.ckpt" + + emb_root = join(cfg.pytorch_data_dir, "cocostuff", "embedding", cfg.model_type) + + train_dataset = EmbeddingFile(join(emb_root, "train", f"embeddings_{cfg.activation_type}.npz")) + train_loader = DataLoader(train_dataset, cfg.batch_size, shuffle=True, num_workers=cfg.num_workers) + + val_dataset = EmbeddingFile(join(emb_root, "val", f"embeddings_{cfg.activation_type}.npz")) + val_loader = DataLoader(val_dataset, cfg.batch_size, shuffle=True, num_workers=cfg.num_workers) + + evaluator = LitPrototypeEvaluator(cfg.task, train_dataset.dim()) + tb_logger = TensorBoardLogger(log_dir, default_hp_metric=False) + + trainer = Trainer( + accelerator='gpu', + devices=1, + max_epochs=cfg.epochs, + logger=tb_logger, + log_every_n_steps=100, + reload_dataloaders_every_n_epochs=1, + check_val_every_n_epoch=10, + ) + + trainer.fit(evaluator, train_loader, val_loader) + + trainer.save_checkpoint(chkpt_dir) + + result = { + "Prototype Accuracy": float(evaluator.prot_acc), + "Prototype mIoU": float(evaluator.prot_iou), + "Linear Accuracy": float(evaluator.linear_acc), + "Linear mIoU": float(evaluator.linear_iou), + "Model": cfg.model_type + } + print(result) + + +if __name__ == "__main__": + my_app() diff --git a/featup/upsamplers.py b/featup/upsamplers.py new file mode 100644 index 0000000000000000000000000000000000000000..410c0e7afe4c9b416d0a6f9a4705de8ed7f1908c --- /dev/null +++ b/featup/upsamplers.py @@ -0,0 +1,303 @@ +import math + +import torch +import 
torch.nn as nn +import torch.nn.functional as F + +from featup.adaptive_conv_cuda.adaptive_conv import AdaptiveConv + + +class SimpleImplicitFeaturizer(torch.nn.Module): + + def __init__(self, n_freqs=20): + super().__init__() + self.n_freqs = n_freqs + self.dim_multiplier = 2 + + def forward(self, original_image): + b, c, h, w = original_image.shape + grid_h = torch.linspace(-1, 1, h, device=original_image.device) + grid_w = torch.linspace(-1, 1, w, device=original_image.device) + feats = torch.cat([t.unsqueeze(0) for t in torch.meshgrid([grid_h, grid_w])]).unsqueeze(0) + feats = torch.broadcast_to(feats, (b, feats.shape[1], h, w)) + + feat_list = [feats] + feats = torch.cat(feat_list, dim=1).unsqueeze(1) + freqs = torch.exp(torch.linspace(-2, 10, self.n_freqs, device=original_image.device)) \ + .reshape(1, self.n_freqs, 1, 1, 1) + feats = (feats * freqs) + + feats = feats.reshape(b, self.n_freqs * self.dim_multiplier, h, w) + + all_feats = [torch.sin(feats), torch.cos(feats), original_image] + + return torch.cat(all_feats, dim=1) + + +class IFA(torch.nn.Module): + + def __init__(self, feat_dim, num_scales=20): + super().__init__() + self.scales = 2 * torch.exp(torch.tensor(torch.arange(1, num_scales + 1))) + self.feat_dim = feat_dim + self.sin_feats = SimpleImplicitFeaturizer() + self.mlp = nn.Sequential( + nn.Conv2d(feat_dim + (num_scales * 4) + 2, feat_dim, 1), + nn.BatchNorm2d(feat_dim), + nn.LeakyReLU(), + nn.Conv2d(feat_dim, feat_dim, 1), + ) + + def forward(self, source, guidance): + b, c, h, w = source.shape + up_source = F.interpolate(source, (h * 2, w * 2), mode="nearest") + assert h == w + lr_cord = torch.linspace(0, h, steps=h, device=source.device) + hr_cord = torch.linspace(0, h, steps=2 * h, device=source.device) + lr_coords = torch.cat([x.unsqueeze(0) for x in torch.meshgrid(lr_cord, lr_cord)], dim=0).unsqueeze(0) + hr_coords = torch.cat([x.unsqueeze(0) for x in torch.meshgrid(hr_cord, hr_cord)], dim=0).unsqueeze(0) + up_lr_coords = F.interpolate(lr_coords, (h * 2, w * 2), mode="nearest") + coord_diff = up_lr_coords - hr_coords + coord_diff_feats = self.sin_feats(coord_diff) + c2 = coord_diff_feats.shape[1] + bcast_coord_feats = torch.broadcast_to(coord_diff_feats, (b, c2, h * 2, w * 2)) + return self.mlp(torch.cat([up_source, bcast_coord_feats], dim=1)) # + up_source + + +class SAPAModule(nn.Module): + def __init__(self, dim_y, dim_x=None, + up_factor=2, up_kernel_size=5, embedding_dim=64, + qkv_bias=True, norm=nn.LayerNorm): + super().__init__() + dim_x = dim_x if dim_x is not None else dim_y + + self.up_factor = up_factor + self.up_kernel_size = up_kernel_size + self.embedding_dim = embedding_dim + + self.norm_y = norm(dim_y) + self.norm_x = norm(dim_x) + + self.q = nn.Linear(dim_y, embedding_dim, bias=qkv_bias) + self.k = nn.Linear(dim_x, embedding_dim, bias=qkv_bias) + + self.apply(self._init_weights) + + def forward(self, y, x): + y = y.permute(0, 2, 3, 1).contiguous() + x = x.permute(0, 2, 3, 1).contiguous() + y = self.norm_y(y) + x_ = self.norm_x(x) + + q = self.q(y) + k = self.k(x_) + + return self.attention(q, k, x).permute(0, 3, 1, 2).contiguous() + + def attention(self, q, k, v): + from sapa import sim, atn + + attn = F.softmax(sim(q, k, self.up_kernel_size, self.up_factor), dim=-1) + return atn(attn, v, self.up_kernel_size, self.up_factor) + + def _init_weights(self, m): + from timm.models.layers import trunc_normal_ + + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + 
nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + +class SAPAUpsampler(torch.nn.Module): + def __init__(self, dim_x, *args, **kwargs): + super().__init__(*args, **kwargs) + self.up1 = SAPAModule(dim_x=dim_x, dim_y=3) + self.up2 = SAPAModule(dim_x=dim_x, dim_y=3) + self.up3 = SAPAModule(dim_x=dim_x, dim_y=3) + self.up4 = SAPAModule(dim_x=dim_x, dim_y=3) + + def adapt_guidance(self, source, guidance): + _, _, h, w = source.shape + small_guidance = F.adaptive_avg_pool2d(guidance, (h * 2, w * 2)) + return small_guidance + + def forward(self, source, guidance): + source_2 = self.up1(self.adapt_guidance(source, guidance), source) + source_4 = self.up2(self.adapt_guidance(source_2, guidance), source_2) + source_8 = self.up3(self.adapt_guidance(source_4, guidance), source_4) + source_16 = self.up4(self.adapt_guidance(source_8, guidance), source_8) + return source_16 + + +class CarafeUpsampler(torch.nn.Module): + + def __init__(self, dim, kernel_size, *args, **kwargs): + super().__init__(*args, **kwargs) + from mmcv.ops import CARAFEPack + self.up1 = CARAFEPack(dim, up_kernel=3, up_group=1, scale_factor=2) + self.up2 = CARAFEPack(dim, up_kernel=3, up_group=1, scale_factor=2) + self.up3 = CARAFEPack(dim, up_kernel=3, up_group=1, scale_factor=2) + self.up4 = CARAFEPack(dim, up_kernel=3, up_group=1, scale_factor=2) + + def forward(self, source, guidance): + source_2 = self.up1(source) + source_4 = self.up2(source_2) + source_8 = self.up3(source_4) + source_16 = self.up4(source_8) + return source_16 + + +class LayeredResizeConv(torch.nn.Module): + + def __init__(self, dim, kernel_size, *args, **kwargs): + super().__init__(*args, **kwargs) + self.conv1 = torch.nn.Conv2d(dim + 3, dim, kernel_size, padding="same") + self.conv2 = torch.nn.Conv2d(dim + 3, dim, kernel_size, padding="same") + self.conv3 = torch.nn.Conv2d(dim + 3, dim, kernel_size, padding="same") + self.conv4 = torch.nn.Conv2d(dim + 3, dim, kernel_size, padding="same") + + def apply_conv(self, source, guidance, conv, activation): + big_source = F.interpolate(source, scale_factor=2, mode="bilinear") + _, _, h, w = big_source.shape + small_guidance = F.interpolate(guidance, (h, w), mode="bilinear") + output = activation(conv(torch.cat([big_source, small_guidance], dim=1))) + return big_source + output + + def forward(self, source, guidance): + source_2 = self.apply_conv(source, guidance, self.conv1, F.relu) + source_4 = self.apply_conv(source_2, guidance, self.conv2, F.relu) + source_8 = self.apply_conv(source_4, guidance, self.conv3, F.relu) + source_16 = self.apply_conv(source_8, guidance, self.conv4, lambda x: x) + return source_16 + + +class JBULearnedRange(torch.nn.Module): + + def __init__(self, guidance_dim, feat_dim, key_dim, scale=2, radius=3): + super().__init__() + self.scale = scale + self.radius = radius + self.diameter = self.radius * 2 + 1 + + self.guidance_dim = guidance_dim + self.key_dim = key_dim + self.feat_dim = feat_dim + + self.range_temp = nn.Parameter(torch.tensor(0.0)) + self.range_proj = torch.nn.Sequential( + torch.nn.Conv2d(guidance_dim, key_dim, 1, 1), + torch.nn.GELU(), + torch.nn.Dropout2d(.1), + torch.nn.Conv2d(key_dim, key_dim, 1, 1), + ) + + self.fixup_proj = torch.nn.Sequential( + torch.nn.Conv2d(guidance_dim 
+ self.diameter ** 2, self.diameter ** 2, 1, 1), + torch.nn.GELU(), + torch.nn.Dropout2d(.1), + torch.nn.Conv2d(self.diameter ** 2, self.diameter ** 2, 1, 1), + ) + + self.sigma_spatial = nn.Parameter(torch.tensor(1.0)) + + def get_range_kernel(self, x): + GB, GC, GH, GW = x.shape + proj_x = self.range_proj(x) + proj_x_padded = F.pad(proj_x, pad=[self.radius] * 4, mode='reflect') + queries = torch.nn.Unfold(self.diameter)(proj_x_padded) \ + .reshape((GB, self.key_dim, self.diameter * self.diameter, GH, GW)) \ + .permute(0, 1, 3, 4, 2) + pos_temp = self.range_temp.exp().clamp_min(1e-4).clamp_max(1e4) + return F.softmax(pos_temp * torch.einsum("bchwp,bchw->bphw", queries, proj_x), dim=1) + + def get_spatial_kernel(self, device): + dist_range = torch.linspace(-1, 1, self.diameter, device=device) + x, y = torch.meshgrid(dist_range, dist_range) + patch = torch.cat([x.unsqueeze(0), y.unsqueeze(0)], dim=0) + return torch.exp(- patch.square().sum(0) / (2 * self.sigma_spatial ** 2)) \ + .reshape(1, self.diameter * self.diameter, 1, 1) + + def forward(self, source, guidance): + GB, GC, GH, GW = guidance.shape + SB, SC, SH, SQ = source.shape + assert (SB == GB) + + spatial_kernel = self.get_spatial_kernel(source.device) + range_kernel = self.get_range_kernel(guidance) + + combined_kernel = range_kernel * spatial_kernel + combined_kernel /= combined_kernel.sum(1, keepdim=True).clamp(1e-7) + + combined_kernel += .1 * self.fixup_proj(torch.cat([combined_kernel, guidance], dim=1)) + combined_kernel = combined_kernel.permute(0, 2, 3, 1) \ + .reshape(GB, GH, GW, self.diameter, self.diameter) + + hr_source = torch.nn.Upsample((GH, GW), mode='bicubic', align_corners=False)(source) + hr_source_padded = F.pad(hr_source, pad=[self.radius] * 4, mode='reflect') + + # (B C, H+Pad, W+Pad) x (B, H, W, KH, KW) -> BCHW + result = AdaptiveConv.apply(hr_source_padded, combined_kernel) + return result + + +class JBUStack(torch.nn.Module): + + def __init__(self, feat_dim, *args, **kwargs): + super().__init__(*args, **kwargs) + self.up1 = JBULearnedRange(3, feat_dim, 32, radius=3) + self.up2 = JBULearnedRange(3, feat_dim, 32, radius=3) + self.up3 = JBULearnedRange(3, feat_dim, 32, radius=3) + self.up4 = JBULearnedRange(3, feat_dim, 32, radius=3) + self.fixup_proj = torch.nn.Sequential( + torch.nn.Dropout2d(0.2), + torch.nn.Conv2d(feat_dim, feat_dim, kernel_size=1)) + + def upsample(self, source, guidance, up): + _, _, h, w = source.shape + small_guidance = F.adaptive_avg_pool2d(guidance, (h * 2, w * 2)) + upsampled = up(source, small_guidance) + return upsampled + + def forward(self, source, guidance): + source_2 = self.upsample(source, guidance, self.up1) + source_4 = self.upsample(source_2, guidance, self.up2) + source_8 = self.upsample(source_4, guidance, self.up3) + source_16 = self.upsample(source_8, guidance, self.up4) + return self.fixup_proj(source_16) * 0.1 + source_16 + + +class Bilinear(torch.nn.Module): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, feats, img): + _, _, h, w = img.shape + return F.interpolate(feats, (h, w), mode="bilinear") + + +def get_upsampler(upsampler, dim): + if upsampler == 'bilinear': + return Bilinear() + elif upsampler == 'jbu_stack': + return JBUStack(dim) + elif upsampler == 'resize_conv': + return LayeredResizeConv(dim, 1) + elif upsampler == 'carafe': + return CarafeUpsampler(dim, 1) + elif upsampler == 'sapa': + return SAPAUpsampler(dim_x=dim) + elif upsampler == 'ifa': + return IFA(dim) + else: + raise ValueError(f"Unknown 
upsampler {upsampler}") diff --git a/featup/util.py b/featup/util.py new file mode 100644 index 0000000000000000000000000000000000000000..e4c5058834a894180b89f00e33ecc91b4cdc4abe --- /dev/null +++ b/featup/util.py @@ -0,0 +1,275 @@ +import matplotlib.pyplot as plt +import torch +import torchvision.transforms as T +import numpy as np +from sklearn.decomposition import PCA +import torch.nn.functional as F +from collections import defaultdict, deque +import torch +import torch.nn as nn + + +class RollingAvg: + + def __init__(self, length): + self.length = length + self.metrics = defaultdict(lambda: deque(maxlen=self.length)) + + def add(self, name, metric): + self.metrics[name].append(metric) + + def get(self, name): + return torch.tensor(list(self.metrics[name])).mean() + + def logall(self, log_func): + for k in self.metrics.keys(): + log_func(k, self.get(k)) + + +def _remove_axes(ax): + ax.xaxis.set_major_formatter(plt.NullFormatter()) + ax.yaxis.set_major_formatter(plt.NullFormatter()) + ax.set_xticks([]) + ax.set_yticks([]) + + +def remove_axes(axes): + if len(axes.shape) == 2: + for ax1 in axes: + for ax in ax1: + _remove_axes(ax) + else: + for ax in axes: + _remove_axes(ax) + + +class UnNormalize(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, image): + image2 = torch.clone(image) + if len(image2.shape) == 4: + # batched + image2 = image2.permute(1, 0, 2, 3) + for t, m, s in zip(image2, self.mean, self.std): + t.mul_(s).add_(m) + return image2.permute(1, 0, 2, 3) + + +norm = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) +unnorm = UnNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + +midas_norm = T.Normalize([0.5] * 3, [0.5] * 3) +midas_unnorm = UnNormalize([0.5] * 3, [0.5] * 3) + + +class ToTargetTensor(object): + def __call__(self, target): + return torch.as_tensor(np.array(target), dtype=torch.int64).unsqueeze(0) + + +def show_heatmap(ax, + image, + heatmap, + cmap="bwr", + color=False, + center=False, + show_negative=False, + cax=None, + vmax=None): + frame = [] + + if color: + frame.append(ax.imshow(image)) + else: + bw = np.dot(np.array(image)[..., :3] / 255, [0.2989, 0.5870, 0.1140]) + bw = np.ones_like(image) * np.expand_dims(bw, -1) + frame.append(ax.imshow(bw)) + + if center: + heatmap -= heatmap.mean() + + if not show_negative: + heatmap = heatmap.clamp_min(0) + + heatmap = F.interpolate(heatmap.unsqueeze(0).unsqueeze(0), (image.shape[0], image.shape[1])) \ + .squeeze(0).squeeze(0) + + if vmax is None: + vmax = np.abs(heatmap).max() + + hm = ax.imshow(heatmap, alpha=.5, cmap=cmap, vmax=vmax, vmin=-vmax) + if cax is not None: + plt.colorbar(hm, cax=cax, orientation='vertical') + + frame.extend([hm]) + return frame + + +def implicit_feats(original_image, input_size, color_feats): + n_freqs = 20 + grid = torch.linspace(-1, 1, input_size, device=original_image.device) + feats = torch.cat([t.unsqueeze(0) for t in torch.meshgrid([grid, grid])]).unsqueeze(0) + + if color_feats: + feat_list = [feats, original_image] + dim_multiplier = 5 + else: + feat_list = [feats] + dim_multiplier = 2 + + feats = torch.cat(feat_list, dim=1) + freqs = torch.exp(torch.linspace(-2, 10, n_freqs, device=original_image.device)) \ + .reshape(n_freqs, 1, 1, 1) + feats = (feats * freqs).reshape(1, n_freqs * dim_multiplier, input_size, input_size) + + if color_feats: + all_feats = [torch.sin(feats), torch.cos(feats), original_image] + else: + all_feats = [torch.sin(feats), torch.cos(feats)] + return torch.cat(all_feats, dim=1) + + +def 
+    model = torch.load(model_path, map_location="cpu")
+    hr_model = model["model"].cuda().eval()
+    unprojector = model["unprojector"].cuda().eval()
+
+    with torch.no_grad():
+        h, w = original_image.shape[2:]
+        assert h == w
+        feats = implicit_feats(original_image, h, color_feats).cuda()
+        hr_feats = hr_model(feats)
+        hr_feats = unprojector(hr_feats.detach().cpu())
+
+    return hr_feats
+
+
+def generate_subset(n, batch):
+    np.random.seed(0)
+    return np.random.permutation(n)[:batch]
+
+
+class TorchPCA(object):
+
+    def __init__(self, n_components):
+        self.n_components = n_components
+
+    def fit(self, X):
+        self.mean_ = X.mean(dim=0)
+        unbiased = X - self.mean_.unsqueeze(0)
+        U, S, V = torch.pca_lowrank(unbiased, q=self.n_components, center=False, niter=4)
+        self.components_ = V.T
+        self.singular_values_ = S
+        return self
+
+    def transform(self, X):
+        t0 = X - self.mean_.unsqueeze(0)
+        projected = t0 @ self.components_.T
+        return projected
+
+
+def pca(image_feats_list, dim=3, fit_pca=None, use_torch_pca=True, max_samples=None):
+    device = image_feats_list[0].device
+
+    def flatten(tensor, target_size=None):
+        if target_size is not None and fit_pca is None:
+            tensor = F.interpolate(tensor, (target_size, target_size), mode="bilinear")
+        B, C, H, W = tensor.shape
+        return tensor.permute(1, 0, 2, 3).reshape(C, B * H * W).permute(1, 0).detach().cpu()
+
+    if len(image_feats_list) > 1 and fit_pca is None:
+        target_size = image_feats_list[0].shape[2]
+    else:
+        target_size = None
+
+    flattened_feats = []
+    for feats in image_feats_list:
+        flattened_feats.append(flatten(feats, target_size))
+    x = torch.cat(flattened_feats, dim=0)
+
+    # Subsample the data if max_samples is set and the number of samples exceeds max_samples
+    if max_samples is not None and x.shape[0] > max_samples:
+        indices = torch.randperm(x.shape[0])[:max_samples]
+        x = x[indices]
+
+    if fit_pca is None:
+        if use_torch_pca:
+            fit_pca = TorchPCA(n_components=dim).fit(x)
+        else:
+            fit_pca = PCA(n_components=dim).fit(x)
+
+    reduced_feats = []
+    for feats in image_feats_list:
+        x_red = fit_pca.transform(flatten(feats))
+        if isinstance(x_red, np.ndarray):
+            x_red = torch.from_numpy(x_red)
+        x_red -= x_red.min(dim=0, keepdim=True).values
+        x_red /= x_red.max(dim=0, keepdim=True).values
+        B, C, H, W = feats.shape
+        reduced_feats.append(x_red.reshape(B, H, W, dim).permute(0, 3, 1, 2).to(device))
+
+    return reduced_feats, fit_pca
+
+
+class PCAUnprojector(nn.Module):
+
+    def __init__(self, feats, dim, device, use_torch_pca=False, **kwargs):
+        super().__init__()
+        self.dim = dim
+
+        if feats is not None:
+            self.original_dim = feats.shape[1]
+        else:
+            self.original_dim = kwargs["original_dim"]
+
+        if self.dim != self.original_dim:
+            if feats is not None:
+                sklearn_pca = pca([feats], dim=dim, use_torch_pca=use_torch_pca)[1]
+
+                # Register tensors as buffers
+                self.register_buffer('components_',
+                                     torch.tensor(sklearn_pca.components_, device=device, dtype=feats.dtype))
+                self.register_buffer('singular_values_',
+                                     torch.tensor(sklearn_pca.singular_values_, device=device, dtype=feats.dtype))
+                self.register_buffer('mean_', torch.tensor(sklearn_pca.mean_, device=device, dtype=feats.dtype))
+            else:
+                self.register_buffer('components_', kwargs["components_"].t())
+                self.register_buffer('singular_values_', kwargs["singular_values_"])
+                self.register_buffer('mean_', kwargs["mean_"])
+
+        else:
+            print("PCAUnprojector will not transform data")
+
+    def forward(self, red_feats):
+        if self.dim == self.original_dim:
+            return red_feats
+        else:
+            b, c, h, w = red_feats.shape
+            red_feats_reshaped = red_feats.permute(0, 2, 3, 1).reshape(b * h * w, c)
+            unprojected = (red_feats_reshaped @ self.components_) + self.mean_.unsqueeze(0)
+            return unprojected.reshape(b, h, w, self.original_dim).permute(0, 3, 1, 2)
+
+    def project(self, feats):
+        if self.dim == self.original_dim:
+            return feats
+        else:
+            b, c, h, w = feats.shape
+            feats_reshaped = feats.permute(0, 2, 3, 1).reshape(b * h * w, c)
+            t0 = feats_reshaped - self.mean_.unsqueeze(0).to(feats.device)
+            projected = t0 @ self.components_.t().to(feats.device)
+            return projected.reshape(b, h, w, self.dim).permute(0, 3, 1, 2)
+
+
+def prep_image(t, subtract_min=True):
+    if subtract_min:
+        t -= t.min()
+    t /= t.max()
+    t = (t * 255).clamp(0, 255).to(torch.uint8)
+
+    if len(t.shape) == 2:
+        t = t.unsqueeze(0)
+
+    return t
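A similarly hedged sketch of the PCA utilities in featup/util.py (again not part of the patch; the feature shapes and the 128-dimensional working space are made-up examples): pca() reduces feature maps to a few channels for visualization, while PCAUnprojector moves between a reduced space and the original feature space.

# Illustrative sketch only -- not part of the diff. Runs on CPU; shapes are arbitrary.
import torch
from featup.util import pca, PCAUnprojector

feats = torch.randn(2, 384, 28, 28)  # (B, C, H, W) features from some backbone

# Reduce to 3 channels for visualization; reuse `fit` to color other maps consistently.
[feats_rgb], fit = pca([feats], dim=3)
print(feats_rgb.shape)  # torch.Size([2, 3, 28, 28]), each channel scaled to [0, 1]

# Project into a 128-d working space and (approximately) reconstruct the original 384-d.
unprojector = PCAUnprojector(feats, dim=128, device='cpu', use_torch_pca=True)
low = unprojector.project(feats)   # (2, 128, 28, 28)
recon = unprojector(low)           # (2, 384, 28, 28)

Registering the PCA components, singular values, and mean as buffers keeps them inside the module's state_dict, so a saved unprojector can be reloaded alongside the model it was fit for.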