# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import logging

from omegaconf import OmegaConf, open_dict

from fairseq.tasks import register_task
from fairseq.tasks.sentence_prediction import (
    SentencePredictionConfig,
    SentencePredictionTask,
)


logger = logging.getLogger(__name__)


# Register the task so it is selectable via --task sentence_prediction_adapters;
# without this decorator the register_task and SentencePredictionConfig imports
# above are unused.
@register_task("sentence_prediction_adapters", dataclass=SentencePredictionConfig)
class SentencePredictionAdapterTask(SentencePredictionTask):
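    """Sentence prediction task that fine-tunes with pre-trained adapters frozen.

    Builds the model as in SentencePredictionTask, then freezes the embedding
    parameters and all adapter modules (plus their LayerNorms), so only the
    remaining encoder weights and the classification head are updated.
    """
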
    def build_model(self, cfg):
        from fairseq import models

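        # cfg may be an OmegaConf node (struct-locked) or a plain Namespace.
        # open_dict temporarily unlocks OmegaConf configs so keys can be
        # written; contextlib.ExitStack() is a no-op stand-in otherwise.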
        with open_dict(cfg) if OmegaConf.is_config(cfg) else contextlib.ExitStack():
            cfg.max_positions = self.cfg.max_positions

        model = models.build_model(cfg, self)
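
        # Attach the classification head used for this sentence-level task.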
        model.register_classification_head(
            self.cfg.classification_head_name,
            num_classes=self.cfg.num_classes,
        )
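
        # Freeze token embeddings, learned positional embeddings, and the
        # embedding LayerNorm so they are not updated during fine-tuning.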
logger.info("Freezing Embedding Parameters") | |
for parameter in model.encoder.sentence_encoder.embed_positions.parameters(): | |
parameter.requires_grad = False | |
for ( | |
parameter | |
) in model.encoder.sentence_encoder.layernorm_embedding.parameters(): | |
parameter.requires_grad = False | |
for parameter in model.encoder.sentence_encoder.embed_tokens.parameters(): | |
parameter.requires_grad = False | |
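
        # Keep the pre-trained adapters (and their LayerNorms) frozen as well:
        # only the transformer body and the classification head stay trainable.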
logger.info("Freezing Adapters") | |
for k, v in model.encoder.sentence_encoder.layers._modules.items(): | |
logger.info("Freezing Adapters in Layer " + str(k)) | |
if hasattr(v, "adapter_layer_norm"): | |
logger.info("Freezing Adapter LN") | |
for parameter in v.adapter_layer_norm.parameters(): | |
parameter.requires_grad = False | |
for parameter in v.adapter_modules.parameters(): | |
parameter.requires_grad = False | |
return model | |
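

# Example usage (a sketch; assumes a binarized sentence-prediction dataset and
# an adapter-equipped RoBERTa-style checkpoint; flag names follow the standard
# fairseq-train sentence prediction recipe):
#
#   fairseq-train <data-bin> \
#       --task sentence_prediction_adapters \
#       --arch roberta_base \
#       --criterion sentence_prediction \
#       --num-classes 2 \
#       --classification-head-name sentence_classification_head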