# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# DUSt3R model class
# --------------------------------------------------------
from copy import deepcopy
import torch
import os
from packaging import version
import huggingface_hub

from .utils.misc import (
    fill_default_args,
    freeze_all_params,
    is_symmetrized,
    interleave,
    transpose_to_landscape,
)
from .heads import head_factory
from mini_dust3r.patch_embed import get_patch_embed
from mini_dust3r.croco.croco import CroCoNet

# NOTE: `inf` is used in the default arguments below and must stay a module
# global: the constructor expression evaluated in `load_model` may refer to it.
inf = float("inf")

hf_version_number = huggingface_hub.__version__
assert version.parse(hf_version_number) >= version.parse(
    "0.22.0"
), "Outdated huggingface_hub version, please reinstall requirements.txt"


def load_model(model_path, device, verbose=True):
    """Load a checkpoint file and rebuild the network it describes.

    The checkpoint is expected to hold ``ckpt["args"].model`` (a string with a
    constructor expression) and ``ckpt["model"]`` (the state dict). The
    expression is rewritten so that ``landscape_only=False`` is always used.

    Args:
        model_path: Path handed to ``torch.load``.
        device: Device the returned network is moved to.
        verbose: When true, print the path, the constructor expression and
            the result of ``load_state_dict``.

    Returns:
        The instantiated network with the checkpoint weights loaded
        (``strict=False``), moved to ``device``.
    """
    if verbose:
        print("... loading model from", model_path)

    # The checkpoint stores a non-tensor `args` object next to the weights
    # (presumably an argparse.Namespace -- confirm against the training code).
    # PyTorch >= 2.6 defaults to `weights_only=True`, which rejects such
    # objects, so full unpickling is requested explicitly.
    # SECURITY: full unpickling executes arbitrary code; only load trusted files.
    ckpt = torch.load(model_path, map_location="cpu", weights_only=False)

    args = ckpt["args"].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
    if "landscape_only" not in args:
        # Drop the last character (assumed to be the closing parenthesis) and
        # append the extra keyword argument.
        args = args[:-1] + ", landscape_only=False)"
    else:
        args = args.replace(" ", "").replace(
            "landscape_only=True", "landscape_only=False"
        )
    assert "landscape_only=False" in args

    if verbose:
        print(f"instantiating : {args}")

    # SECURITY: `eval` runs a string read from the checkpoint. This is unsafe
    # for checkpoints from untrusted sources; kept because the checkpoint
    # format encodes the model as a Python expression.
    net = eval(args)
    s = net.load_state_dict(ckpt["model"], strict=False)
    if verbose:
        print(s)
    return net.to(device)


class AsymmetricCroCo3DStereo(
    CroCoNet,
    huggingface_hub.PyTorchModelHubMixin,
    library_name="dust3r",
    repo_url="https://github.com/naver/dust3r",
    tags=["image-to-3d"],
):
    """Two siamese encoders, followed by two decoders.
    The goal is to output 3d points directly, both images in view1's frame
    (hence the asymmetry).
    """

    def __init__(
        self,
        output_mode="pts3d",
        head_type="linear",
        depth_mode=("exp", -inf, inf),
        conf_mode=("exp", 1, inf),
        freeze="none",
        landscape_only=True,
        patch_embed_cls="PatchEmbedDust3R",  # PatchEmbedDust3R or ManyAR_PatchEmbed
        **croco_kwargs,
    ):
        """Build the CroCo backbone, the second decoder and the two heads.

        Args:
            output_mode: Forwarded to ``head_factory``.
            head_type: Forwarded to ``head_factory``.
            depth_mode: Stored on the instance as ``self.depth_mode``.
            conf_mode: Stored as ``self.conf_mode``; its truthiness decides
                whether the heads are built with ``has_conf=True``.
            freeze: One of ``"none"``, ``"mask"`` or ``"encoder"``.
            landscape_only: Passed as ``activate`` to ``transpose_to_landscape``.
            patch_embed_cls: Name handed to ``get_patch_embed``.
            **croco_kwargs: Arguments for the ``CroCoNet`` constructor.
        """
        # Stored before the parent constructor runs; `_set_patch_embed` reads
        # it (presumably called from CroCoNet.__init__ -- TODO confirm).
        self.patch_embed_cls = patch_embed_cls
        # NOTE(review): `set_downstream_head` below needs `patch_size` and
        # `img_size`, so `fill_default_args` is presumably completing
        # `croco_kwargs` in place -- confirm in utils.misc.
        self.croco_args = fill_default_args(croco_kwargs, super().__init__)
        super().__init__(**croco_kwargs)

        # dust3r specific initialization
        self.dec_blocks2 = deepcopy(self.dec_blocks)
        self.set_downstream_head(
            output_mode,
            head_type,
            landscape_only,
            depth_mode,
            conf_mode,
            **croco_kwargs,
        )
        self.set_freeze(freeze)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kw):
        """Load a model from a local checkpoint file or via the parent classes.

        If the argument is an existing file it is loaded with ``load_model``
        on the CPU (``kw`` is ignored in that case); otherwise the call is
        delegated to the next ``from_pretrained`` in the MRO.
        """
        if os.path.isfile(pretrained_model_name_or_path):
            return load_model(pretrained_model_name_or_path, device="cpu")
        else:
            return super(AsymmetricCroCo3DStereo, cls).from_pretrained(
                pretrained_model_name_or_path, **kw
            )

    def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
        """Create ``self.patch_embed`` from the configured patch-embed class name."""
        self.patch_embed = get_patch_embed(
            self.patch_embed_cls, img_size, patch_size, enc_embed_dim
        )

    def load_state_dict(self, ckpt, **kw):
        """Load a state dict, cloning decoder entries for the second decoder.

        When no key of ``ckpt`` starts with ``dec_blocks2``, every
        ``dec_blocks*`` entry is also registered under the matching
        ``dec_blocks2*`` key (same value object). ``ckpt`` is not modified.

        Returns:
            Whatever the parent ``load_state_dict`` returns.
        """
        # duplicate all weights for the second decoder if not present
        new_ckpt = dict(ckpt)
        if not any(k.startswith("dec_blocks2") for k in ckpt):
            for key, value in ckpt.items():
                if key.startswith("dec_blocks"):
                    new_ckpt[key.replace("dec_blocks", "dec_blocks2")] = value
        return super().load_state_dict(new_ckpt, **kw)

    def set_freeze(self, freeze):  # this is for use by downstream models
        """Freeze the parameter groups selected by ``freeze``.

        Args:
            freeze: ``"none"``, ``"mask"`` or ``"encoder"``.

        Raises:
            KeyError: If ``freeze`` is not one of the supported modes.
        """
        self.freeze = freeze
        to_be_frozen = {
            "none": [],
            "mask": [self.mask_token],
            "encoder": [self.mask_token, self.patch_embed, self.enc_blocks],
        }
        freeze_all_params(to_be_frozen[freeze])

    def _set_prediction_head(self, *args, **kwargs):
        """No prediction head"""
        return

    def set_downstream_head(
        self,
        output_mode,
        head_type,
        landscape_only,
        depth_mode,
        conf_mode,
        patch_size,
        img_size,
        **kw,
    ):
        """Create the two downstream heads and their orientation wrappers.

        Args:
            output_mode: Forwarded to ``head_factory``.
            head_type: Forwarded to ``head_factory``.
            landscape_only: Passed as ``activate`` to ``transpose_to_landscape``.
            depth_mode: Stored as ``self.depth_mode``.
            conf_mode: Stored as ``self.conf_mode``; truthy enables confidence.
            patch_size: Patch size; both image dimensions must be multiples of it.
            img_size: Indexable pair of image dimensions.
            **kw: Ignored; lets callers pass the full set of backbone kwargs.
        """
        assert (
            img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0
        ), f"{img_size=} must be multiple of {patch_size=}"
        self.output_mode = output_mode
        self.head_type = head_type
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode
        # allocate heads
        self.downstream_head1 = head_factory(
            head_type, output_mode, self, has_conf=bool(conf_mode)
        )
        self.downstream_head2 = head_factory(
            head_type, output_mode, self, has_conf=bool(conf_mode)
        )
        # magic wrapper
        self.head1 = transpose_to_landscape(
            self.downstream_head1, activate=landscape_only
        )
        self.head2 = transpose_to_landscape(
            self.downstream_head2, activate=landscape_only
        )

    def _encode_image(self, image, true_shape):
        """Run the patch embedding, the encoder blocks and the final norm.

        Returns:
            A ``(features, positions, None)`` triple.
        """
        # embed the image into patches  (x has size B x Npatches x C)
        x, pos = self.patch_embed(image, true_shape=true_shape)

        # add positional embedding without cls token
        assert self.enc_pos_embed is None

        # now apply the transformer encoder and normalization
        for blk in self.enc_blocks:
            x = blk(x, pos)

        x = self.enc_norm(x)
        return x, pos, None

    def _encode_image_pairs(self, img1, img2, true_shape1, true_shape2):
        """Encode two image batches, in one pass when their sizes match.

        When the last two dimensions of both inputs are equal, the inputs are
        concatenated along dim 0, encoded once and split back in two halves;
        otherwise each input is encoded separately.

        Returns:
            ``(feat1, feat2, pos1, pos2)``.
        """
        if img1.shape[-2:] == img2.shape[-2:]:
            out, pos, _ = self._encode_image(
                torch.cat((img1, img2), dim=0),
                torch.cat((true_shape1, true_shape2), dim=0),
            )
            out, out2 = out.chunk(2, dim=0)
            pos, pos2 = pos.chunk(2, dim=0)
        else:
            out, pos, _ = self._encode_image(img1, true_shape1)
            out2, pos2, _ = self._encode_image(img2, true_shape2)
        return out, out2, pos, pos2

    def _encode_symmetrized(self, view1, view2):
        """Encode both views, skipping redundant work for symmetrized batches.

        Args:
            view1: Mapping with an ``"img"`` entry and optionally ``"true_shape"``.
            view2: Same layout as ``view1``.

        Returns:
            ``((shape1, shape2), (feat1, feat2), (pos1, pos2))``.
        """
        img1 = view1["img"]
        img2 = view2["img"]
        B = img1.shape[0]
        # Recover true_shape when available, otherwise assume that the img
        # shape is the true one
        shape1 = view1.get(
            "true_shape", torch.tensor(img1.shape[-2:])[None].repeat(B, 1)
        )
        shape2 = view2.get(
            "true_shape", torch.tensor(img2.shape[-2:])[None].repeat(B, 1)
        )
        # warning! maybe the images have different portrait/landscape orientations

        if is_symmetrized(view1, view2):
            # computing half of forward pass!
            # Only every other sample is encoded; `interleave` is presumably
            # rebuilding the full batch from the two halves -- confirm in
            # utils.misc.
            feat1, feat2, pos1, pos2 = self._encode_image_pairs(
                img1[::2], img2[::2], shape1[::2], shape2[::2]
            )
            feat1, feat2 = interleave(feat1, feat2)
            pos1, pos2 = interleave(pos1, pos2)
        else:
            feat1, feat2, pos1, pos2 = self._encode_image_pairs(
                img1, img2, shape1, shape2
            )

        return (shape1, shape2), (feat1, feat2), (pos1, pos2)

    def _decoder(self, f1, pos1, f2, pos2):
        """Run both decoders in lockstep and collect every intermediate output.

        Returns:
            An iterator yielding two tuples (one per view). Each tuple holds
            the encoder features followed by the output of every decoder
            block, with ``dec_norm`` applied to the last entry.
        """
        final_output = [(f1, f2)]  # before projection

        # project to decoder dim
        f1 = self.decoder_embed(f1)
        f2 = self.decoder_embed(f2)

        final_output.append((f1, f2))
        for blk1, blk2 in zip(self.dec_blocks, self.dec_blocks2):
            # img1 side
            f1, _ = blk1(*final_output[-1][::+1], pos1, pos2)
            # img2 side
            f2, _ = blk2(*final_output[-1][::-1], pos2, pos1)
            # store the result
            final_output.append((f1, f2))

        # normalize last output
        del final_output[1]  # duplicate with final_output[0]
        final_output[-1] = tuple(map(self.dec_norm, final_output[-1]))
        return zip(*final_output)

    def _downstream_head(self, head_num, decout, img_shape):
        """Apply ``self.head<head_num>`` to the decoder outputs."""
        # The values are unused, but the unpacking fails early when the last
        # decoder output does not have exactly three dimensions.
        B, S, D = decout[-1].shape
        head = getattr(self, f"head{head_num}")
        return head(decout, img_shape)

    def forward(self, view1, view2):
        """Encode, decode and run both heads on a pair of views.

        Args:
            view1: Mapping with an ``"img"`` entry and optionally ``"true_shape"``.
            view2: Same layout as ``view1``.

        Returns:
            ``(res1, res2)``, the outputs of the two heads. In ``res2`` the
            ``"pts3d"`` entry is renamed to ``"pts3d_in_other_view"``.
        """
        # encode the two images --> B,S,D
        (shape1, shape2), (feat1, feat2), (pos1, pos2) = self._encode_symmetrized(
            view1, view2
        )

        # combine all ref images into object-centric representation
        dec1, dec2 = self._decoder(feat1, pos1, feat2, pos2)

        # The heads run with autocast disabled on float32 copies of the tokens.
        with torch.cuda.amp.autocast(enabled=False):
            res1 = self._downstream_head(1, [tok.float() for tok in dec1], shape1)
            res2 = self._downstream_head(2, [tok.float() for tok in dec2], shape2)

        res2["pts3d_in_other_view"] = res2.pop(
            "pts3d"
        )  # predict view2's pts3d in view1's frame
        return res1, res2