File size: 854 Bytes
94f372a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
from torch import nn
from hydra.utils import instantiate
from omegaconf import OmegaConf
from huggingface_hub import PyTorchModelHubMixin

class Geolocalizer(nn.Module, PyTorchModelHubMixin):
    """Hub-loadable wrapper around a Hydra-configured geolocation model.

    ``config`` is converted to an OmegaConf node. Hydra then instantiates
    ``config.transform`` and ``config.model`` from it. The model's
    ``backbone``, ``mid`` and ``head`` stages are re-exposed as attributes of
    this wrapper. Both forward methods chain them in that order and return the
    head's ``"gps"`` entry.
    """

    def __init__(self, config):
        super().__init__()
        cfg = OmegaConf.create(config)
        self.config = cfg
        # NOTE(review): ``transform`` is built here but not used by either
        # forward method below; presumably callers apply it themselves.
        self.transform = instantiate(cfg.transform)
        self.model = instantiate(cfg.model)
        # Re-expose the model's stages directly on the wrapper.
        # If these are nn.Module instances, assigning them registers them
        # again, so they appear under both ``model.*`` and their own names
        # in ``state_dict()``. Changing this layout (or its order) would
        # change the checkpoint keys.
        net = self.model
        self.head = net.head
        self.mid = net.mid
        self.backbone = net.backbone

    def _gps_from(self, backbone_input):
        """Run backbone -> mid -> head and return the head's ``"gps"`` entry.

        ``head`` is always called with ``None`` as its second argument.
        NOTE(review): the meaning of that argument is not visible here.
        """
        features = self.backbone(backbone_input)
        features = self.mid(features)
        prediction = self.head(features, None)
        return prediction["gps"]

    def forward(self, img: torch.Tensor):
        """Predict GPS output for ``img``.

        The tensor is wrapped as ``{"img": img}`` before it reaches the
        backbone.
        """
        return self._gps_from({"img": img})

    def forward_tensor(self, img: torch.Tensor):
        """Predict GPS output, passing ``img`` to the backbone unwrapped."""
        return self._gps_from(img)