File size: 629 Bytes
94f372a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
import numpy as np
from torch import nn


class NormGPS(nn.Module):
    """Scale (latitude, longitude) radians into the range [-1, 1].

    Latitude is divided by pi / 2 and longitude by pi. The module holds no
    parameters or buffers.

    NOTE: flagged by the original author as "not used currently".
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Normalize latitude/longitude radians to [-1, 1].

        Args:
            x: Tensor whose last dimension holds ``(latitude, longitude)`` in
                radians; it must broadcast against a ``(1, 2)`` scale.

        Returns:
            ``x`` divided element-wise by ``(pi / 2, pi)``. The input is not
            clamped, so out-of-range angles map outside [-1, 1].
        """
        # Build the scale directly on x's device and at least at x's
        # precision: a float32-rounded pi would make float64 inputs lose
        # accuracy (pi / pi would no longer be exactly 1). Promoting with the
        # default dtype keeps the result dtype for half/integer inputs the
        # same as a default-dtype scale would give.
        scale_dtype = torch.promote_types(x.dtype, torch.get_default_dtype())
        scale = torch.tensor(
            [np.pi * 0.5, np.pi], dtype=scale_dtype, device=x.device
        ).unsqueeze(0)
        return x / scale


class UnormGPS(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Unormalize latitude longtitude radians to -1, 1."""
        x = torch.clamp(x, -1, 1)
        return x * torch.Tensor([np.pi * 0.5, np.pi]).unsqueeze(0).to(x.device)