import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, AutoModelForImageClassification
from .configuration_moe import MoEConfig


# MLP subgate: flattens a 224x224 RGB image and produces num_classes * 2 values,
# i.e. per-class mixing weights for an (expert, base-model) pair of logits.
def subgate(num_classes):
    layers = nn.Sequential(
        nn.Flatten(),
        nn.Linear(224 * 224 * 3, 1024),
        nn.ReLU(),
        nn.Linear(1024, 512),
        nn.ReLU(),
        nn.Linear(512, num_classes * 2),
    )
    return layers


class MoEModelForImageClassification(PreTrainedModel):
    config_class = MoEConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_classes = config.num_classes
        # Top-level switch gate: scores how much each expert branch contributes.
        self.switch_gate_model = AutoModelForImageClassification.from_pretrained(
            config.switch_gate
        )
        # Frozen base classifiers, one per expert branch.
        self.base_model1 = AutoModelForImageClassification.from_pretrained(
            config.base_model
        )
        self.base_model2 = AutoModelForImageClassification.from_pretrained(
            config.base_model
        )
        # Frozen expert classifiers, one per expert listed in the config.
        self.expert_model_1 = AutoModelForImageClassification.from_pretrained(
            config.experts[0]
        )
        self.expert_model_2 = AutoModelForImageClassification.from_pretrained(
            config.experts[1]
        )

        # Trainable subgates that blend each expert with its base model.
        self.subgate1 = subgate(config.num_classes)
        self.subgate2 = subgate(config.num_classes)

        # Freeze the base and expert models; the switch gate and subgates stay trainable.
        for module in [
            self.base_model1,
            self.base_model2,
            self.expert_model_1,
            self.expert_model_2,
        ]:
            for param in module.parameters():
                param.requires_grad = False

    def forward(self, pixel_values, labels=None):
        # Logits from the switch gate and from the frozen base and expert classifiers.
        switch_gate_result = self.switch_gate_model(pixel_values).logits
        base_model1_result = self.base_model1(pixel_values).logits
        base_model2_result = self.base_model2(pixel_values).logits

        expert1_result = self.expert_model_1(pixel_values).logits
        expert2_result = self.expert_model_2(pixel_values).logits

        # Each subgate emits num_classes * 2 values per sample: per-class weights for
        # the expert logits and for the base-model logits. Reshape to
        # (batch, 2, num_classes) so the two weight sets stay aligned per sample.
        subgate1_result = self.subgate1(pixel_values)
        subgate1_result = torch.reshape(subgate1_result, (-1, 2, self.num_classes))

        subgate2_result = self.subgate2(pixel_values)
        subgate2_result = torch.reshape(subgate2_result, (-1, 2, self.num_classes))

        # Blend each expert's logits with its base model's logits.
        expert1_and_base_res = (
            expert1_result * subgate1_result[:, 0, :]
            + base_model1_result * subgate1_result[:, 1, :]
        )
        expert2_and_base_res = (
            expert2_result * subgate2_result[:, 0, :]
            + base_model2_result * subgate2_result[:, 1, :]
        )

        # Switch gate: weight each blended branch by the gate's score for that expert.
        expert1_and_base_res = expert1_and_base_res * switch_gate_result[
            :, 0
        ].unsqueeze(1)
        expert2_and_base_res = expert2_and_base_res * switch_gate_result[
            :, 1
        ].unsqueeze(1)

        logits = expert1_and_base_res + expert2_and_base_res
        if labels is not None:
            loss = F.cross_entropy(logits, labels)
            return {"loss": loss, "logits": logits}
        return {"logits": logits}