|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
# Location of the released CoTracker2 checkpoint. It is fetched with
# torch.hub.load_state_dict_from_url by _make_cotracker_predictor when pretrained=True.
_COTRACKER_URL = "https://huggingface.co/facebook/cotracker/resolve/main/cotracker2.pth"
|
|
|
|
|
def _make_cotracker_predictor(*, pretrained: bool = True, online=False, **kwargs):
    """Construct a CoTracker predictor and optionally load the released weights.

    Args:
        pretrained: if True, download the checkpoint at ``_COTRACKER_URL``
            (mapped to CPU) and load it into ``predictor.model``.
        online: choose ``CoTrackerOnlinePredictor`` when truthy, otherwise
            ``CoTrackerPredictor``.
        **kwargs: accepted for hub-entrypoint compatibility but currently
            not used by this factory.

    Returns:
        The predictor instance, built with ``checkpoint=None``.
    """
    # Import lazily so that only the requested predictor class is pulled in.
    if not online:
        from cotracker.predictor import CoTrackerPredictor as predictor_cls
    else:
        from cotracker.predictor import CoTrackerOnlinePredictor as predictor_cls

    # checkpoint=None: weights (if any) are loaded below from the hub URL instead.
    predictor = predictor_cls(checkpoint=None)
    if not pretrained:
        return predictor

    weights = torch.hub.load_state_dict_from_url(_COTRACKER_URL, map_location="cpu")
    predictor.model.load_state_dict(weights)
    return predictor
|
|
|
|
|
def cotracker2(*, pretrained: bool = True, **kwargs):
    """CoTracker2 (non-online variant) with stride 4 and window length 8.

    Can track up to 265*265 points jointly.

    Args:
        pretrained: when True, load the released CoTracker2 weights.
        **kwargs: passed through unchanged to ``_make_cotracker_predictor``.

    Returns:
        The predictor built by ``_make_cotracker_predictor`` with ``online=False``.
    """
    predictor = _make_cotracker_predictor(online=False, pretrained=pretrained, **kwargs)
    return predictor
|
|
|
|
|
def cotracker2_online(*, pretrained: bool = True, **kwargs):
    """Online CoTracker2 with stride 4 and window length 8.

    Can track up to 265*265 points jointly.

    Args:
        pretrained: when True, load the released CoTracker2 weights.
        **kwargs: passed through unchanged to ``_make_cotracker_predictor``.

    Returns:
        The predictor built by ``_make_cotracker_predictor`` with ``online=True``.
    """
    predictor = _make_cotracker_predictor(online=True, pretrained=pretrained, **kwargs)
    return predictor
|
|