File size: 3,444 Bytes
b9e43f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from typing import Dict, Optional, Tuple, Type

import numpy as np
import torch
import torch.nn as nn
from numpy.typing import NDArray
from torch.distributions import Distribution, constraints

from rl_algo_impls.shared.actor.actor import Actor, PiForward
from rl_algo_impls.shared.actor.categorical import MaskedCategorical
from rl_algo_impls.shared.encoder import EncoderOutDim
from rl_algo_impls.shared.module.module import mlp


class MultiCategorical(Distribution):
    """Product of independent (optionally masked) categorical distributions.

    A single flat parameter tensor (``probs`` or ``logits``) is split along its
    last dimension into ``len(nvec)`` chunks with sizes ``nvec``; each chunk
    parameterizes one ``MaskedCategorical``. Log-probabilities and entropies of
    the components are summed; samples and modes are stacked along a new last
    dimension (one entry per component).
    """

    def __init__(
        self,
        nvec: NDArray[np.int64],
        probs=None,
        logits=None,
        validate_args=None,
        masks: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            nvec: Number of categories for each component distribution.
            probs: Concatenated per-component probabilities. Mutually exclusive
                with ``logits``.
            logits: Concatenated per-component logits. Mutually exclusive with
                ``probs``.
            validate_args: Forwarded to every component and to ``Distribution``.
            masks: Optional concatenated per-component masks, split the same way
                as the parameters. ``None`` means no masking for any component.
        """
        # Either probs or logits should be set
        assert (probs is None) != (logits is None)
        split_sizes = nvec.tolist()
        # Split along the last (event) dim, matching how batch_shape is derived
        # below; identical to dim=1 for the usual (batch, sum(nvec)) input.
        masks_split = (
            torch.split(masks, split_sizes, dim=-1)
            if masks is not None
            else [None] * len(nvec)
        )
        # Explicit None check: truth-testing a multi-element tensor raises
        # "Boolean value of Tensor with more than one value is ambiguous".
        if probs is not None:
            self.dists = [
                MaskedCategorical(probs=p, validate_args=validate_args, mask=m)
                for p, m in zip(torch.split(probs, split_sizes, dim=-1), masks_split)
            ]
            param = probs
        else:
            assert logits is not None
            self.dists = [
                MaskedCategorical(logits=lg, validate_args=validate_args, mask=m)
                for lg, m in zip(
                    torch.split(logits, split_sizes, dim=-1), masks_split
                )
            ]
            param = logits
        batch_shape = param.size()[:-1] if param.ndimension() > 1 else torch.Size()
        super().__init__(batch_shape=batch_shape, validate_args=validate_args)

    def log_prob(self, action: torch.Tensor) -> torch.Tensor:
        """Return the sum of the per-component log-probabilities of ``action``.

        Args:
            action: Tensor whose last dimension indexes the components, i.e. the
                layout produced by ``sample``/``mode`` (``(..., len(nvec))``).
        """
        # unbind(-1) picks each component's actions. Unlike ``action.T`` it stays
        # correct when ``action`` has more than two dims (e.g. a sample_shape).
        prob_stack = torch.stack(
            [
                c.log_prob(a)
                for a, c in zip(torch.unbind(action, dim=-1), self.dists)
            ],
            dim=-1,
        )
        return prob_stack.sum(dim=-1)

    def entropy(self) -> torch.Tensor:
        """Return the sum of the component entropies."""
        return torch.stack([c.entropy() for c in self.dists], dim=-1).sum(dim=-1)

    def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        """Sample every component and stack the results on a new last dim."""
        return torch.stack([c.sample(sample_shape) for c in self.dists], dim=-1)

    @property
    def mode(self) -> torch.Tensor:
        """Per-component modes stacked on a new last dim."""
        return torch.stack([c.mode for c in self.dists], dim=-1)

    @property
    def arg_constraints(self) -> Dict[str, constraints.Constraint]:
        # Constraints handled by child distributions in dist
        return {}


class MultiDiscreteActorHead(Actor):
    """Actor head that emits a ``MultiCategorical`` policy for MultiDiscrete actions.

    An MLP maps encoder features to one flat logit vector of length
    ``nvec.sum()``. ``MultiCategorical`` then splits that vector into one chunk
    per action component.
    """

    def __init__(
        self,
        nvec: NDArray[np.int64],
        in_dim: EncoderOutDim,
        hidden_sizes: Tuple[int, ...] = (32,),
        activation: Type[nn.Module] = nn.ReLU,
        init_layers_orthogonal: bool = True,
    ) -> None:
        """
        Args:
            nvec: Number of choices for each action component.
            in_dim: Encoder output size; this head only accepts a plain ``int``.
            hidden_sizes: Widths of the hidden MLP layers.
            activation: Activation module class passed to ``mlp``.
            init_layers_orthogonal: Passed through to ``mlp``.
        """
        super().__init__()
        self.nvec = nvec
        # EncoderOutDim may be something other than an int (presumably a shape
        # for non-flat encoders); this head needs a flat feature size.
        assert isinstance(in_dim, int)
        total_logits = nvec.sum()
        sizes = (in_dim,) + hidden_sizes
        sizes = sizes + (total_logits,)
        self._fc = mlp(
            sizes,
            activation,
            init_layers_orthogonal=init_layers_orthogonal,
            # NOTE(review): presumably shrinks the final layer's init so the
            # initial logits are small; confirm against mlp's definition.
            final_layer_gain=0.01,
        )

    def forward(
        self,
        obs: torch.Tensor,
        actions: Optional[torch.Tensor] = None,
        action_masks: Optional[torch.Tensor] = None,
    ) -> PiForward:
        """Build the masked policy for ``obs`` and hand it to ``pi_forward``.

        Args:
            obs: Encoder features fed to the MLP.
            actions: Optional actions, forwarded to ``pi_forward``.
            action_masks: Optional concatenated per-component masks.
        """
        return self.pi_forward(
            MultiCategorical(self.nvec, logits=self._fc(obs), masks=action_masks),
            actions,
        )

    @property
    def action_shape(self) -> Tuple[int, ...]:
        """One integer per action component: ``(len(nvec),)``."""
        n_components = len(self.nvec)
        return (n_components,)