# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import contextlib

from omegaconf import open_dict, OmegaConf

from fairseq.tasks import register_task
from fairseq.tasks.sentence_prediction import (
    SentencePredictionTask,
    SentencePredictionConfig,
)


logger = logging.getLogger(__name__)

@register_task("sentence_prediction_adapters", dataclass=SentencePredictionConfig)
class SentencePredictionAdapterTask(SentencePredictionTask):
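    """Sentence prediction fine-tuning for adapter-equipped encoders.

    Builds the model with a task-specific classification head, then freezes
    the embedding parameters and the per-layer adapter modules (e.g. the
    language adapters of X-MOD-style checkpoints), so that gradients only
    reach the shared Transformer weights and the new head.
    """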
    def build_model(self, cfg):
        from fairseq import models

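        # open_dict() temporarily lifts OmegaConf's struct flag so a key can
        # be written into a structured config; plain dataclass configs fall
        # through to a no-op contextlib.ExitStack() instead.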
        with open_dict(cfg) if OmegaConf.is_config(cfg) else contextlib.ExitStack():
            cfg.max_positions = self.cfg.max_positions

        model = models.build_model(cfg, self)

        model.register_classification_head(
            self.cfg.classification_head_name,
            num_classes=self.cfg.num_classes,
        )

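        # Freeze embeddings and adapters below so that only the shared
        # Transformer weights and the classification head are fine-tuned.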
logger.info("Freezing Embedding Parameters")
for parameter in model.encoder.sentence_encoder.embed_positions.parameters():
parameter.requires_grad = False
for (
parameter
) in model.encoder.sentence_encoder.layernorm_embedding.parameters():
parameter.requires_grad = False
for parameter in model.encoder.sentence_encoder.embed_tokens.parameters():
parameter.requires_grad = False
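        # `layers` is an nn.ModuleList; its `_modules` OrderedDict maps each
        # (string) layer index to the encoder layer, so items() yields
        # (index, layer) pairs.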
logger.info("Freezing Adapters")
for k, v in model.encoder.sentence_encoder.layers._modules.items():
logger.info("Freezing Adapters in Layer " + str(k))
if hasattr(v, "adapter_layer_norm"):
logger.info("Freezing Adapter LN")
for parameter in v.adapter_layer_norm.parameters():
parameter.requires_grad = False
for parameter in v.adapter_modules.parameters():
parameter.requires_grad = False
return model
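

# Minimal usage sketch (hypothetical paths and values; the remaining flags
# follow the standard `sentence_prediction` task options):
#
#   fairseq-train /path/to/data-bin \
#       --task sentence_prediction_adapters \
#       --criterion sentence_prediction \
#       --classification-head-name sentence_classification_head \
#       --num-classes 2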