# Hosting-page residue captured together with this file (not part of the module):
# kxhit | cuda reinit? | 5ca3a35 | raw | history blame | No virus | 8.97 kB
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# DUSt3R model class
# --------------------------------------------------------
from copy import deepcopy
import torch
import os
from packaging import version
import huggingface_hub
from .utils.misc import (
fill_default_args,
freeze_all_params,
is_symmetrized,
interleave,
transpose_to_landscape,
)
from .heads import head_factory
from mini_dust3r.patch_embed import get_patch_embed
from mini_dust3r.croco.croco import CroCoNet
# Positive infinity; used below as an open bound in the default
# depth_mode / conf_mode tuples of AsymmetricCroCo3DStereo.__init__.
inf = float("inf")

# Import-time guard: refuse to continue with a huggingface_hub release
# older than 0.22.0.
hf_version_number = huggingface_hub.__version__
assert version.parse("0.22.0") <= version.parse(
    hf_version_number
), "Outdated huggingface_hub version, please reinstall requirements.txt"
def load_model(model_path, device, verbose=True):
    """Load a checkpoint from disk and instantiate the network it describes.

    The checkpoint must be a mapping with two entries: ``"args"``, an object
    whose ``model`` attribute is a constructor expression string, and
    ``"model"``, the state dict. Before the expression is evaluated it is
    rewritten so that ``ManyAR_PatchEmbed`` becomes ``PatchEmbedDust3R`` and
    ``landscape_only`` is forced to ``False``.

    Args:
        model_path: Path of the checkpoint file passed to ``torch.load``.
        device: Device the instantiated network is moved to.
        verbose: When true, print the path, the evaluated constructor
            expression and the result of ``load_state_dict``.

    Returns:
        The instantiated network, with weights loaded non-strictly, on
        ``device``.

    Raises:
        AssertionError: If the rewritten constructor expression does not
            contain ``landscape_only=False``.
    """
    if verbose:
        print("... loading model from", model_path)

    # The checkpoint stores a pickled ``args`` object next to the weights.
    # Recent PyTorch releases default ``torch.load`` to the restricted
    # weights-only unpickler, which rejects such objects, so full unpickling
    # is requested explicitly.
    # SECURITY: full unpickling can execute arbitrary code; only load
    # checkpoints obtained from a trusted source.
    ckpt = torch.load(model_path, map_location="cpu", weights_only=False)

    args = ckpt["args"].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
    if "landscape_only" not in args:
        # Drop the last character (assumed to be the closing parenthesis of
        # the constructor call) and append the extra keyword argument.
        args = args[:-1] + ", landscape_only=False)"
    else:
        # Strip spaces so that a spaced-out "landscape_only = True" is also
        # matched by the replacement.
        args = args.replace(" ", "").replace(
            "landscape_only=True", "landscape_only=False"
        )
    assert "landscape_only=False" in args

    if verbose:
        print(f"instantiating : {args}")
    # SECURITY: ``eval`` runs an expression taken from the checkpoint file.
    # This is only acceptable for trusted checkpoints; do not call this
    # function on files of unknown origin.
    net = eval(args)
    s = net.load_state_dict(ckpt["model"], strict=False)
    if verbose:
        print(s)
    return net.to(device)
class AsymmetricCroCo3DStereo(
    CroCoNet,
    huggingface_hub.PyTorchModelHubMixin,
    library_name="dust3r",
    repo_url="https://github.com/naver/dust3r",
    tags=["image-to-3d"],
):
    """CroCoNet-based network with a shared encoder and two decoders.

    Both views go through the same encoder blocks; each view then has its
    own stack of decoder blocks and its own downstream head. The second
    view's 3D points are reported under ``"pts3d_in_other_view"``, i.e. they
    are meant to be expressed in the first view's frame, which is what makes
    the architecture asymmetric.
    """

    def __init__(
        self,
        output_mode="pts3d",
        head_type="linear",
        depth_mode=("exp", -inf, inf),
        conf_mode=("exp", 1, inf),
        freeze="none",
        landscape_only=True,
        patch_embed_cls="PatchEmbedDust3R",  # PatchEmbedDust3R or ManyAR_PatchEmbed
        **croco_kwargs,
    ):
        """Build the CroCo backbone, then add the second decoder and heads.

        Args:
            output_mode: Forwarded to ``head_factory``.
            head_type: Forwarded to ``head_factory``.
            depth_mode: Stored on the instance by ``set_downstream_head``.
            conf_mode: Stored on the instance; its truthiness decides
                whether the heads are built with ``has_conf=True``.
            freeze: One of ``"none"``, ``"mask"`` or ``"encoder"``.
            landscape_only: Passed as ``activate`` to
                ``transpose_to_landscape`` when wrapping the heads.
            patch_embed_cls: Name handed to ``get_patch_embed``.
            **croco_kwargs: Arguments for ``CroCoNet.__init__``.
        """
        # Stored first because _set_patch_embed reads it; presumably the
        # parent constructor calls _set_patch_embed -- confirm in CroCoNet.
        self.patch_embed_cls = patch_embed_cls
        # NOTE(review): set_downstream_head below needs patch_size and
        # img_size from croco_kwargs; presumably fill_default_args adds the
        # parent's defaults to it when they were not passed -- confirm.
        self.croco_args = fill_default_args(croco_kwargs, super().__init__)
        super().__init__(**croco_kwargs)

        # dust3r specific initialization: the second decoder starts as a
        # deep copy of the first one.
        self.dec_blocks2 = deepcopy(self.dec_blocks)
        self.set_downstream_head(
            output_mode,
            head_type,
            landscape_only,
            depth_mode,
            conf_mode,
            **croco_kwargs,
        )
        self.set_freeze(freeze)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kw):
        """Load from a local checkpoint file, or defer to the hub mixin.

        When the argument is an existing file it is loaded with
        ``load_model`` on the CPU (``kw`` is not used in that case);
        otherwise the call is forwarded to the parent ``from_pretrained``.
        """
        if not os.path.isfile(pretrained_model_name_or_path):
            return super(AsymmetricCroCo3DStereo, cls).from_pretrained(
                pretrained_model_name_or_path, **kw
            )
        return load_model(pretrained_model_name_or_path, device="cpu")

    def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
        """Create ``self.patch_embed`` from the configured class name."""
        embed = get_patch_embed(
            self.patch_embed_cls, img_size, patch_size, enc_embed_dim
        )
        self.patch_embed = embed

    def load_state_dict(self, ckpt, **kw):
        """Load ``ckpt``, cloning decoder weights for the second decoder.

        If no key of ``ckpt`` starts with ``"dec_blocks2"``, every entry
        whose key starts with ``"dec_blocks"`` is also registered under the
        corresponding ``"dec_blocks2"`` key. ``ckpt`` itself is not modified.
        """
        state = dict(ckpt)
        has_second_decoder = any(name.startswith("dec_blocks2") for name in ckpt)
        if not has_second_decoder:
            state.update(
                {
                    name.replace("dec_blocks", "dec_blocks2"): weights
                    for name, weights in ckpt.items()
                    if name.startswith("dec_blocks")
                }
            )
        return super().load_state_dict(state, **kw)

    def set_freeze(self, freeze):  # this is for use by downstream models
        """Record the freeze mode and freeze the matching components.

        ``freeze`` must be ``"none"``, ``"mask"`` or ``"encoder"``; any other
        value raises ``KeyError`` from the lookup below.
        """
        self.freeze = freeze
        mask_only = [self.mask_token]
        encoder_parts = [self.mask_token, self.patch_embed, self.enc_blocks]
        targets = {
            "none": [],
            "mask": mask_only,
            "encoder": encoder_parts,
        }
        freeze_all_params(targets[freeze])

    def _set_prediction_head(self, *args, **kwargs):
        """Intentionally does nothing: this model has no prediction head."""
        return None

    def set_downstream_head(
        self,
        output_mode,
        head_type,
        landscape_only,
        depth_mode,
        conf_mode,
        patch_size,
        img_size,
        **kw,
    ):
        """Create the two downstream heads and their wrappers.

        Stores the mode settings on the instance, builds one head per view
        with ``head_factory`` and wraps each with ``transpose_to_landscape``.
        Both entries of ``img_size`` must be multiples of ``patch_size``.
        Extra keyword arguments are accepted and ignored.
        """
        assert (
            img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0
        ), f"{img_size=} must be multiple of {patch_size=}"

        self.output_mode = output_mode
        self.head_type = head_type
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode

        # allocate heads (head 1 first, then head 2)
        has_conf = bool(conf_mode)
        self.downstream_head1 = head_factory(
            head_type, output_mode, self, has_conf=has_conf
        )
        self.downstream_head2 = head_factory(
            head_type, output_mode, self, has_conf=has_conf
        )

        # wrap each head with transpose_to_landscape
        self.head1 = transpose_to_landscape(
            self.downstream_head1, activate=landscape_only
        )
        self.head2 = transpose_to_landscape(
            self.downstream_head2, activate=landscape_only
        )

    def _encode_image(self, image, true_shape):
        """Run patch embedding, the encoder blocks and the final norm.

        Returns ``(tokens, pos, None)``.
        """
        # patch embedding (per the original note, tokens are
        # B x Npatches x C)
        tokens, pos = self.patch_embed(image, true_shape=true_shape)

        # no positional embedding is added here; the model is expected to
        # have none
        assert self.enc_pos_embed is None

        # transformer encoder followed by normalization
        for block in self.enc_blocks:
            tokens = block(tokens, pos)
        return self.enc_norm(tokens), pos, None

    def _encode_image_pairs(self, img1, img2, true_shape1, true_shape2):
        """Encode two image batches, in a single pass when sizes match.

        Returns ``(feat1, feat2, pos1, pos2)``.
        """
        if img1.shape[-2:] != img2.shape[-2:]:
            # different spatial sizes: encode each batch on its own
            feat_a, pos_a, _ = self._encode_image(img1, true_shape1)
            feat_b, pos_b, _ = self._encode_image(img2, true_shape2)
            return feat_a, feat_b, pos_a, pos_b

        # same spatial size: stack along the batch axis, encode once, split
        stacked_imgs = torch.cat((img1, img2), dim=0)
        stacked_shapes = torch.cat((true_shape1, true_shape2), dim=0)
        feats, positions, _ = self._encode_image(stacked_imgs, stacked_shapes)
        feat_a, feat_b = feats.chunk(2, dim=0)
        pos_a, pos_b = positions.chunk(2, dim=0)
        return feat_a, feat_b, pos_a, pos_b

    def _encode_symmetrized(self, view1, view2):
        """Encode both views, halving the work for symmetrized batches.

        Returns ``(shape1, shape2), (feat1, feat2), (pos1, pos2)``.
        """
        img1 = view1["img"]
        img2 = view2["img"]
        B = img1.shape[0]

        # Use true_shape when the view provides it, otherwise assume the
        # image shape is the true one.
        fallback1 = torch.tensor(img1.shape[-2:])[None].repeat(B, 1)
        shape1 = view1.get("true_shape", fallback1)
        fallback2 = torch.tensor(img2.shape[-2:])[None].repeat(B, 1)
        shape2 = view2.get("true_shape", fallback2)

        # warning! maybe the images have different portrait/landscape
        # orientations
        if not is_symmetrized(view1, view2):
            feat1, feat2, pos1, pos2 = self._encode_image_pairs(
                img1, img2, shape1, shape2
            )
            return (shape1, shape2), (feat1, feat2), (pos1, pos2)

        # Symmetrized batch: encode every other sample only (half of the
        # forward pass), then interleave the two outputs.
        every_other = slice(None, None, 2)
        feat1, feat2, pos1, pos2 = self._encode_image_pairs(
            img1[every_other],
            img2[every_other],
            shape1[every_other],
            shape2[every_other],
        )
        feat1, feat2 = interleave(feat1, feat2)
        pos1, pos2 = interleave(pos1, pos2)
        return (shape1, shape2), (feat1, feat2), (pos1, pos2)

    def _decoder(self, f1, pos1, f2, pos2):
        """Run both decoder stacks and collect every intermediate output.

        Returns an iterator over two tuples (view 1 outputs, view 2
        outputs). Each tuple starts with the encoder features, followed by
        one entry per decoder block; the last entry is passed through
        ``dec_norm``.
        """
        outputs = [(f1, f2)]  # before projection

        # project to decoder dim
        outputs.append((self.decoder_embed(f1), self.decoder_embed(f2)))

        for block_a, block_b in zip(self.dec_blocks, self.dec_blocks2):
            # both sides read the outputs of the previous stage
            prev_a, prev_b = outputs[-1]
            # img1 side
            next_a, _ = block_a(prev_a, prev_b, pos1, pos2)
            # img2 side
            next_b, _ = block_b(prev_b, prev_a, pos2, pos1)
            # store the result
            outputs.append((next_a, next_b))

        # drop the projected embeddings; the raw encoder features at index 0
        # are kept instead
        outputs.pop(1)

        # normalize last output
        last_a, last_b = outputs[-1]
        outputs[-1] = (self.dec_norm(last_a), self.dec_norm(last_b))
        return zip(*outputs)

    def _downstream_head(self, head_num, decout, img_shape):
        """Apply the wrapped head number ``head_num`` to decoder outputs."""
        # unpacking requires the last decoder output to have three dimensions
        _batch, _seq, _dim = decout[-1].shape
        wrapped_head = getattr(self, f"head{head_num}")
        return wrapped_head(decout, img_shape)

    def forward(self, view1, view2):
        """Predict outputs for both views.

        Args:
            view1: Mapping with an ``"img"`` entry and optionally
                ``"true_shape"``.
            view2: Same layout as ``view1``.

        Returns:
            ``(res1, res2)``, the outputs of head 1 and head 2. In ``res2``
            the ``"pts3d"`` entry is renamed to ``"pts3d_in_other_view"``.
        """
        # encode the two images --> B,S,D
        shapes, feats, positions = self._encode_symmetrized(view1, view2)
        shape1, shape2 = shapes
        feat1, feat2 = feats
        pos1, pos2 = positions

        # combine all ref images into object-centric representation
        dec1, dec2 = self._decoder(feat1, pos1, feat2, pos2)

        # the heads run with autocast disabled on float32 tokens
        with torch.cuda.amp.autocast(enabled=False):
            tokens1 = [tok.float() for tok in dec1]
            res1 = self._downstream_head(1, tokens1, shape1)
            tokens2 = [tok.float() for tok in dec2]
            res2 = self._downstream_head(2, tokens2, shape2)

        # predict view2's pts3d in view1's frame
        other_view_pts = res2.pop("pts3d")
        res2["pts3d_in_other_view"] = other_view_pts
        return res1, res2