# Hugging Face Space file view (Space status: Running)
# File size: 854 bytes, commit 94f372a
import torch
from torch import nn
from hydra.utils import instantiate
from omegaconf import OmegaConf
from huggingface_hub import PyTorchModelHubMixin
class Geolocalizer(nn.Module, PyTorchModelHubMixin):
    """Image-to-GPS predictor assembled from a Hydra/OmegaConf config.

    ``config`` is wrapped with ``OmegaConf.create`` and two of its entries are
    built with ``hydra.utils.instantiate``:

    * ``config.transform`` -> ``self.transform``. It is not applied anywhere in
      this class; presumably callers use it to preprocess images before
      ``forward`` (NOTE(review): confirm against callers).
    * ``config.model`` -> ``self.model``. Its ``head``, ``mid`` and ``backbone``
      attributes are also exposed directly on this module.

    Save/load helpers are inherited from ``PyTorchModelHubMixin``.
    """

    def __init__(self, config):
        super().__init__()
        # Keep this assignment order: nn.Module records submodules in the
        # order they are assigned.
        self.config = OmegaConf.create(config)
        self.transform = instantiate(self.config.transform)
        self.model = instantiate(self.config.model)
        # Direct aliases onto parts of self.model. If these parts are
        # nn.Modules they end up registered under two names (e.g. "head.*"
        # and "model.head.*"). Presumably published checkpoints rely on that,
        # so verify before removing any alias.
        self.head = self.model.head
        self.mid = self.model.mid
        self.backbone = self.model.backbone

    def _gps_from_features(self, features):
        """Run ``mid`` then ``head`` on backbone output; return the ``"gps"`` entry."""
        # The head takes a second positional argument, passed as None here.
        # Its meaning is defined by the head implementation (not visible here).
        prediction = self.head(self.mid(features), None)
        return prediction["gps"]

    def forward(self, img: torch.Tensor):
        """Predict the ``"gps"`` output for ``img``.

        The backbone receives the image wrapped in a dict under the key
        ``"img"``.
        """
        batch = {"img": img}
        return self._gps_from_features(self.backbone(batch))

    def forward_tensor(self, img: torch.Tensor):
        """Like ``forward``, but pass ``img`` to the backbone unwrapped."""
        return self._gps_from_features(self.backbone(img))
|