from typing import Dict, List, Optional

from transformers import PretrainedConfig
MOE_MODEL_TYPE = "moe"
DEFAULT_HUGGINGFACE_ACCOUNT = "louislu9911"
model_checkpoint = "convnextv2-base-1k-224"
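
# Hugging Face Hub repo ids for the expert, switch-gate, and base-model checkpoints.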
EXPERTS = [
    f"{DEFAULT_HUGGINGFACE_ACCOUNT}/Expert1-leaf-disease-{model_checkpoint}-0_4",
    f"{DEFAULT_HUGGINGFACE_ACCOUNT}/Expert2-leaf-disease-{model_checkpoint}-1_2_3",
]
SWITCH_GATE = (
    f"{DEFAULT_HUGGINGFACE_ACCOUNT}/switch_gate-leaf-disease-{model_checkpoint}"
)
BASE_MODEL = (
    f"{DEFAULT_HUGGINGFACE_ACCOUNT}/BaseModel-leaf-disease-{model_checkpoint}-0_1_2_3_4"
)


class MoEConfig(PretrainedConfig):
    model_type = MOE_MODEL_TYPE

    def __init__(
        self,
        experts: List[str] = EXPERTS,
        switch_gate: str = SWITCH_GATE,
        base_model: str = BASE_MODEL,
        num_classes: int = 5,
        expert_class_mapping: Optional[Dict[int, List[int]]] = None,
        **kwargs,
    ):
        # Hub repo ids for the expert models, the switch gate, and the base model.
        self.experts = experts
        self.switch_gate = switch_gate
        self.base_model = base_model
        self.num_classes = num_classes
        # Optional mapping from expert index to the class ids that expert handles.
        self.expert_class_mapping = expert_class_mapping
        super().__init__(**kwargs)
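

# A minimal usage sketch (the mapping values below are an assumption that mirrors the
# class-id suffixes in the expert repo names; adjust them to your own checkpoints):
#
#     config = MoEConfig(
#         num_classes=5,
#         expert_class_mapping={0: [0, 4], 1: [1, 2, 3]},
#     )
#     config.save_pretrained("./moe-leaf-disease")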