louislu9911's picture
Upload model
82d77a8 verified
raw
history blame
3.08 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, AutoModelForImageClassification
from .configuration_moe import MoEConfig
def subgate(num_classes, in_features=224 * 224 * 3, hidden_sizes=(1024, 512)):
    """Build the small MLP used as a sub-gate.

    The network flattens each sample, passes it through ReLU-activated linear
    layers and ends with a linear layer of ``num_classes * 2`` outputs.
    ``MoEModelForImageClassification.forward`` splits those outputs into two
    groups of ``num_classes`` values that weight an expert model's logits and
    a base model's logits.

    Args:
        num_classes: Number of target classes; the final layer emits twice
            this many values per sample.
        in_features: Number of elements per sample after flattening. The
            default, ``224 * 224 * 3``, is the value that used to be
            hard-coded.
        hidden_sizes: Output widths of the hidden linear layers, in order.
            The default ``(1024, 512)`` reproduces the original architecture.

    Returns:
        An ``nn.Sequential`` mapping ``(batch, ...)`` inputs with
        ``in_features`` elements per sample to ``(batch, num_classes * 2)``.
    """
    # Layers are appended in the same order as the original literal
    # (Flatten, Linear, ReLU, Linear, ReLU, Linear), so with the default
    # arguments the Sequential indices -- and therefore the state-dict keys
    # of existing checkpoints -- are unchanged.
    layers = [nn.Flatten()]
    width = in_features
    for hidden in hidden_sizes:
        layers.append(nn.Linear(width, hidden))
        layers.append(nn.ReLU())
        width = hidden
    layers.append(nn.Linear(width, num_classes * 2))
    return nn.Sequential(*layers)
class MoEModelForImageClassification(PreTrainedModel):
    """Image classifier that mixes two expert/base model pairs.

    For each of the two branches, a sub-gate MLP (see ``subgate``) produces
    per-class weights that blend an expert model's logits with a base model's
    logits. A separate switch-gate classifier then scales the two blended
    branches, and their sum is returned as the final logits.
    """

    config_class = MoEConfig

    def __init__(self, config):
        """Load all sub-models named in ``config`` and freeze the backbones.

        Args:
            config: Configuration object. This constructor reads
                ``config.num_classes``, ``config.switch_gate``,
                ``config.base_model`` and ``config.experts[0]`` /
                ``config.experts[1]``. The last three are passed to
                ``AutoModelForImageClassification.from_pretrained``
                (presumably hub identifiers or local paths -- confirm
                against ``MoEConfig``).
        """
        super().__init__(config)
        self.num_classes = config.num_classes

        self.switch_gate_model = AutoModelForImageClassification.from_pretrained(
            config.switch_gate
        )
        # Two independent copies of the same base checkpoint, one per branch.
        self.base_model1 = AutoModelForImageClassification.from_pretrained(
            config.base_model
        )
        self.base_model2 = AutoModelForImageClassification.from_pretrained(
            config.base_model
        )
        self.expert_model_1 = AutoModelForImageClassification.from_pretrained(
            config.experts[0]
        )
        self.expert_model_2 = AutoModelForImageClassification.from_pretrained(
            config.experts[1]
        )
        self.subgate1 = subgate(config.num_classes)
        self.subgate2 = subgate(config.num_classes)

        # Freeze only the base and expert models. The switch gate and the two
        # sub-gates are deliberately not in this list, so their parameters
        # keep whatever requires_grad they were created with.
        for module in (
            self.base_model1,
            self.base_model2,
            self.expert_model_1,
            self.expert_model_2,
        ):
            for param in module.parameters():
                param.requires_grad = False

    def _mixing_weights(self, gate, pixel_values):
        """Return ``(expert_weights, base_weights)`` produced by one sub-gate.

        Each returned tensor has shape ``(batch, num_classes)``.
        """
        raw = gate(pixel_values)  # (batch, 2 * num_classes)
        # Keep the batch dimension leading so every sample is weighted only by
        # its own gate outputs. The previous reshape to (2, -1, num_classes)
        # reinterpreted the row-major data so that, for batch sizes above 1,
        # one sample's gate values were applied to another sample's logits.
        # For a batch of one, both layouts give the same values.
        paired = raw.reshape(-1, 2, self.num_classes)
        return paired[:, 0, :], paired[:, 1, :]

    def forward(self, pixel_values, labels=None):
        """Compute the mixed logits and, when labels are given, the loss.

        Args:
            pixel_values: Batched image tensor fed unchanged to every
                sub-model. The sub-gates flatten it to ``224 * 224 * 3``
                features per sample, so inputs are expected to be
                ``(batch, 3, 224, 224)``-sized (layout assumed; confirm with
                the image processor in use).
            labels: Optional targets passed to ``F.cross_entropy``.

        Returns:
            ``{"logits": logits}``, or ``{"loss": loss, "logits": logits}``
            when ``labels`` is not ``None``.
        """
        switch_weights = self.switch_gate_model(pixel_values).logits
        base1_logits = self.base_model1(pixel_values).logits
        base2_logits = self.base_model2(pixel_values).logits
        expert1_logits = self.expert_model_1(pixel_values).logits
        expert2_logits = self.expert_model_2(pixel_values).logits

        expert1_w, base1_w = self._mixing_weights(self.subgate1, pixel_values)
        expert2_w, base2_w = self._mixing_weights(self.subgate2, pixel_values)

        # Per-class blend of each expert with its base model.
        branch1 = expert1_logits * expert1_w + base1_logits * base1_w
        branch2 = expert2_logits * expert2_w + base2_logits * base2_w

        # Gating network: columns 0 and 1 of the switch gate's raw logits (no
        # softmax is applied here) scale branch 1 and branch 2 respectively;
        # unsqueeze(1) broadcasts the per-sample scalar across all classes.
        branch1 = branch1 * switch_weights[:, 0].unsqueeze(1)
        branch2 = branch2 * switch_weights[:, 1].unsqueeze(1)

        logits = branch1 + branch2

        if labels is not None:
            loss = F.cross_entropy(logits, labels)
            return {"loss": loss, "logits": logits}
        return {"logits": logits}