import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, AutoModelForImageClassification

from .configuration_moe import MoEConfig


def subgate(num_classes):
    """Build a small MLP gate over a flattened 224x224 RGB image.

    The gate emits ``num_classes * 2`` scores: a per-class weight for the
    expert's logits and a per-class weight for the base model's logits.
    """
    return nn.Sequential(
        nn.Flatten(),
        nn.Linear(224 * 224 * 3, 1024),
        nn.ReLU(),
        nn.Linear(1024, 512),
        nn.ReLU(),
        nn.Linear(512, num_classes * 2),
    )
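

# A two-level mixture of experts: a switch gate scores the two expert
# branches, and within each branch a subgate blends a frozen expert with a
# frozen copy of the base model, class by class. Only the switch gate and
# the subgates are trained.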
class MoEModelForImageClassification(PreTrainedModel):
    config_class = MoEConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_classes = config.num_classes
        # Switch gate: an image classifier whose logits score the branches.
        self.switch_gate_model = AutoModelForImageClassification.from_pretrained(
            config.switch_gate
        )
        # One copy of the generalist base model per branch.
        self.base_model1 = AutoModelForImageClassification.from_pretrained(
            config.base_model
        )
        self.base_model2 = AutoModelForImageClassification.from_pretrained(
            config.base_model
        )
        self.expert_model_1 = AutoModelForImageClassification.from_pretrained(
            config.experts[0]
        )
        self.expert_model_2 = AutoModelForImageClassification.from_pretrained(
            config.experts[1]
        )

        # Trainable subgates that blend expert and base logits class-wise.
        self.subgate1 = subgate(config.num_classes)
        self.subgate2 = subgate(config.num_classes)

        # Freeze the experts and base models; the switch gate and subgates
        # stay trainable.
        for module in [
            self.base_model1,
            self.base_model2,
            self.expert_model_1,
            self.expert_model_2,
        ]:
            for param in module.parameters():
                param.requires_grad = False
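
    # The forward pass computes, with unnormalized gate scores:
    #
    #   logits = g[:, 0] * (s1[:, 0] * E1 + s1[:, 1] * B1)
    #          + g[:, 1] * (s2[:, 0] * E2 + s2[:, 1] * B2)
    #
    # where g are the switch-gate logits, s1/s2 the reshaped subgate
    # weights, and E*/B* the expert and base-model logits.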
    def forward(self, pixel_values, labels=None):
        switch_gate_result = self.switch_gate_model(pixel_values).logits
        base_model1_result = self.base_model1(pixel_values).logits
        base_model2_result = self.base_model2(pixel_values).logits

        expert1_result = self.expert_model_1(pixel_values).logits
        expert2_result = self.expert_model_2(pixel_values).logits

        # Each subgate emits (batch, 2 * num_classes) scores; reshape to
        # (batch, 2, num_classes) so that index 0 along dim 1 weights the
        # expert logits and index 1 weights the base-model logits.
        subgate1_result = self.subgate1(pixel_values)
        subgate1_result = torch.reshape(subgate1_result, (-1, 2, self.num_classes))

        subgate2_result = self.subgate2(pixel_values)
        subgate2_result = torch.reshape(subgate2_result, (-1, 2, self.num_classes))

        expert1_and_base_res = (
            expert1_result * subgate1_result[:, 0, :]
            + base_model1_result * subgate1_result[:, 1, :]
        )
        expert2_and_base_res = (
            expert2_result * subgate2_result[:, 0, :]
            + base_model2_result * subgate2_result[:, 1, :]
        )

        # Weight each branch by the switch gate's score for it, then sum.
        expert1_and_base_res = expert1_and_base_res * switch_gate_result[
            :, 0
        ].unsqueeze(1)
        expert2_and_base_res = expert2_and_base_res * switch_gate_result[
            :, 1
        ].unsqueeze(1)

        logits = expert1_and_base_res + expert2_and_base_res
        if labels is not None:
            loss = F.cross_entropy(logits, labels)
            return {"loss": loss, "logits": logits}
        return {"logits": logits}
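

# Usage sketch (illustrative only): the checkpoint names below are
# placeholders, not real model ids, and assume a MoEConfig exposing the
# fields used above.
#
#   config = MoEConfig(
#       num_classes=10,
#       switch_gate="org/switch-gate",                 # hypothetical checkpoint
#       base_model="org/base-classifier",              # hypothetical checkpoint
#       experts=["org/expert-0", "org/expert-1"],      # hypothetical checkpoints
#   )
#   model = MoEModelForImageClassification(config)
#   out = model(pixel_values)  # pixel_values: (batch, 3, 224, 224)
#   preds = out["logits"].argmax(dim=-1)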