How to Use SSAST Model Weights in the HuggingFace Ecosystem?

Published August 27, 2024

Fig. 1: The Self-Supervised Audio Spectrogram Transformer architecture from the original paper. (edited by author)

The Self-Supervised Audio Spectrogram Transformer (SSAST) model provides state-of-the-art audio classification capabilities [1, 2]. Self-supervised learning allows the use of unlabeled data, improving model performance and feature learning.
Unlike the supervised AST model, SSAST pretraining is only available through the original implementation in the research repository, which makes working with the model cumbersome. After pretraining, however, the weights can easily be loaded into the HuggingFace Transformers AST implementation for fine-tuning on a downstream task while leveraging the HuggingFace ecosystem.

This tutorial will guide you through the process of integrating the SSAST model weights into the HuggingFace ecosystem, allowing for easier fine-tuning and deployment, and making it accessible to a wider audience.

Why Use SSAST Weights?

By loading SSAST weights into the HuggingFace Transformers AST implementation, you can:

  • Benefit from self-supervised learning on unlabeled data: Enhance model performance on downstream tasks by pretraining the model with the original SSAST implementation and fine-tuning it in the HuggingFace ecosystem.
  • Utilize HuggingFace’s powerful and user-friendly tools: Take advantage of HuggingFace’s comprehensive suite of tools for model training, evaluation, and deployment.
  • Escape the “fragile” research repository: Avoid compatibility and dependency issues by integrating the model into the robust and well-supported HuggingFace platform.

Let’s get started with the step-by-step guide to load the weights.

Step-by-Step Guide to Load SSAST Weights

Install all required packages with pip:

pip install 'transformers[torch]'

1. Configure the Architecture

First, configure the architecture of the SSAST model whose weights we want to load, using the ASTConfig class from the HuggingFace transformers library:

from transformers import ASTConfig, ASTModel

config = ASTConfig(
    architectures=["ASTModel"],
    frequency_stride=16,
    time_stride=16,
    hidden_size=768,
    max_length=1024,
    num_attention_heads=12,
    num_hidden_layers=12,
    num_mel_bins=128,
    qkv_bias=True
)

In the code snippet above, I have configured the base model with 16×16 patches (frequency_stride=16 and time_stride=16, i.e. non-overlapping patches). The weights are available as a download link in the SSAST repository or here. If you have pretrained your own model with a custom architecture, you will need to adjust the configuration accordingly.
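As a quick sanity check, you can compute the number of patches this configuration produces; the formula below should mirror the one used internally by the HuggingFace AST implementation (patch_size defaults to 16):

# Patches per axis: (input_dim - patch_size) // stride + 1
freq_patches = (config.num_mel_bins - config.patch_size) // config.frequency_stride + 1  # 8
time_patches = (config.max_length - config.patch_size) // config.time_stride + 1  # 64
print(freq_patches * time_patches)  # 512 patches, plus the cls and distillation tokens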

Have a look at the other pretrained models in the SSAST repository.

2. Instantiate the AST Model

Next, create an instance of the ASTModel with the specified configuration:

model = ASTModel(config=config)
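At this point, the model is randomly initialized. A quick sanity check is the parameter count, which should come out at roughly 86 million for this base configuration:

print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")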

If you have not yet trained a model with the original SSAST repository, you can simply download the weights of any of the pretrained models available in the repository.
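For example, with wget (the URL below is a placeholder; copy the actual download link from the SSAST repository):

wget -O SSAST-Base-Patch-400.pth "<download-url-from-the-SSAST-repository>"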

To load the weights into the Transformers AST implementation, load the checkpoint's state_dict:

import torch

model.load_state_dict(torch.load("./SSAST-Base-Patch-400.pth", map_location="cpu"))

Upon loading, you see messages indicating that some weights were not used. This is expected when initializing an ASTModel from a checkpoint trained on another task or with a different architecture:

Some weights of the model checkpoint at ./SSAST-Base-Patch-400.pth were not used when initializing ASTModel: ['module.v.blocks.3.mlp.fc2.bias', 'module.gpredlayer.2.bias', 'module.v.blocks.10.attn.qkv.bias', 'module.v.blocks.11.mlp.fc1.bias', 'module.v.blocks.1.mlp.fc1.weight', ...
- This IS expected if you are initializing ASTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ASTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ASTModel were not initialized from the model checkpoint at ./SSAST-Base-Patch-400.pth and are newly initialized: ['encoder.layer.4.layernorm_before.weight', 'encoder.layer.1.attention.attention.value.weight', 'encoder.layer.10.attention.attention.query.weight', ...
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

The key difference lies in the naming conventions of the layers. HuggingFace's ASTModel uses a different naming scheme than the original SSAST model: in the HuggingFace implementation, the layers “encoder.layer.[0–11]” correspond to “module.v.blocks.[0–11]”.
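You can inspect the HuggingFace naming scheme yourself by listing some parameter names of the freshly instantiated model:

# Prints names like 'embeddings.cls_token' and 'encoder.layer.0.layernorm_before.weight'
for name in list(model.state_dict())[:20]:
    print(name)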

In the next step, we will resolve this issue.

3. Convert SSAST State Dictionary to HuggingFace Format

To resolve this, you can map the SSAST layer names to the corresponding HuggingFace layer names. Below is a function to perform this conversion:

import torch

def convert_ssast_state_dict_to_astmodel(pretrained_dict, layers: int = 12):
    # Map the original SSAST layer names to the HuggingFace ASTModel names
    conversion_dict = {
        'module.v.cls_token': 'embeddings.cls_token',
        'module.v.dist_token': 'embeddings.distillation_token',
        'module.v.pos_embed': 'embeddings.position_embeddings',
        'module.v.patch_embed.proj.weight': 'embeddings.patch_embeddings.projection.weight',
        'module.v.patch_embed.proj.bias': 'embeddings.patch_embeddings.projection.bias',
        'module.v.norm.weight': 'layernorm.weight',
        'module.v.norm.bias': 'layernorm.bias',
    }

    for i in range(layers):
        conversion_dict[f'module.v.blocks.{i}.norm1.weight'] = f'encoder.layer.{i}.layernorm_before.weight'
        conversion_dict[f'module.v.blocks.{i}.norm1.bias'] = f'encoder.layer.{i}.layernorm_before.bias'
        # SSAST stores query, key and value as one fused qkv projection;
        # HuggingFace keeps them as three separate layers
        conversion_dict[f'module.v.blocks.{i}.attn.qkv.weight'] = [
            f'encoder.layer.{i}.attention.attention.query.weight',
            f'encoder.layer.{i}.attention.attention.key.weight',
            f'encoder.layer.{i}.attention.attention.value.weight',
        ]
        conversion_dict[f'module.v.blocks.{i}.attn.qkv.bias'] = [
            f'encoder.layer.{i}.attention.attention.query.bias',
            f'encoder.layer.{i}.attention.attention.key.bias',
            f'encoder.layer.{i}.attention.attention.value.bias',
        ]
        conversion_dict[f'module.v.blocks.{i}.attn.proj.weight'] = f'encoder.layer.{i}.attention.output.dense.weight'
        conversion_dict[f'module.v.blocks.{i}.attn.proj.bias'] = f'encoder.layer.{i}.attention.output.dense.bias'
        conversion_dict[f'module.v.blocks.{i}.norm2.weight'] = f'encoder.layer.{i}.layernorm_after.weight'
        conversion_dict[f'module.v.blocks.{i}.norm2.bias'] = f'encoder.layer.{i}.layernorm_after.bias'
        conversion_dict[f'module.v.blocks.{i}.mlp.fc1.weight'] = f'encoder.layer.{i}.intermediate.dense.weight'
        conversion_dict[f'module.v.blocks.{i}.mlp.fc1.bias'] = f'encoder.layer.{i}.intermediate.dense.bias'
        conversion_dict[f'module.v.blocks.{i}.mlp.fc2.weight'] = f'encoder.layer.{i}.output.dense.weight'
        conversion_dict[f'module.v.blocks.{i}.mlp.fc2.bias'] = f'encoder.layer.{i}.output.dense.bias'

    converted_dict = {}
    for key, value in pretrained_dict.items():
        if key in conversion_dict:
            mapped_key = conversion_dict[key]
            if isinstance(mapped_key, list):
                # The fused qkv tensor is split equally into query, key and value
                split_size = value.shape[0] // 3
                converted_dict[mapped_key[0]] = value[:split_size]
                converted_dict[mapped_key[1]] = value[split_size:2 * split_size]
                converted_dict[mapped_key[2]] = value[2 * split_size:]
            else:
                converted_dict[mapped_key] = value

    return converted_dict
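If you want to convince yourself that the qkv split is lossless, an optional check is to concatenate the three converted tensors and compare them against the original fused projection (shown here for the first block):

ssast_state_dict = torch.load("./SSAST-Base-Patch-400.pth", map_location="cpu")
converted = convert_ssast_state_dict_to_astmodel(ssast_state_dict)

q = converted['encoder.layer.0.attention.attention.query.weight']
k = converted['encoder.layer.0.attention.attention.key.weight']
v = converted['encoder.layer.0.attention.attention.value.weight']
assert torch.equal(torch.cat([q, k, v], dim=0), ssast_state_dict['module.v.blocks.0.attn.qkv.weight'])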

4. Load the Converted State Dictionary

Load the SSAST checkpoint, convert the state_dict, and initialize the ASTModel:

ssast_state_dict = torch.load("./SSAST-Base-Patch-400.pth", map_location="cpu")
converted = convert_ssast_state_dict_to_astmodel(ssast_state_dict)

model.load_state_dict(converted)

If the conversion is successful, you should see:

Out[1]: <All keys matched successfully>

You are now able to use the ASTModel with SSAST pretrained weights for any task you like, such as creating embeddings or integrating it into your custom training pipeline.
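For example, here is a minimal sketch of extracting clip-level embeddings. The random waveform is a stand-in for your own audio, and mean pooling is just one possible pooling strategy:

import torch
from transformers import ASTFeatureExtractor

feature_extractor = ASTFeatureExtractor(
    num_mel_bins=config.num_mel_bins, max_length=config.max_length
)

waveform = torch.randn(16000).numpy()  # 1 second of dummy audio at 16 kHz
inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")

model.eval()
with torch.no_grad():
    outputs = model(**inputs)

# Mean-pool the sequence of hidden states into a single clip-level embedding
embedding = outputs.last_hidden_state.mean(dim=1)
print(embedding.shape)  # torch.Size([1, 768])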

Using the SSAST Model Weights for Audio Classification

If you want to use the weights to initialize an audio classifier, you must make some minor adjustments.

To instantiate an ASTForAudioClassification model with the SSAST weights, prefix the encoder and embedding layer names with “audio_spectrogram_transformer.” so that they match, as shown in the snippet below. For example:

'module.v.blocks.0.norm1.weight' --> 'audio_spectrogram_transformer.encoder.layer.0.layernorm_before.weight'
'module.v.cls_token' --> 'audio_spectrogram_transformer.embeddings.cls_token'
'module.v.norm.weight'--> 'audio_spectrogram_transformer.layernorm.weight'

The classification head is not part of the pretrained checkpoint, so load the weights with strict=False; the head keeps its random initialization and is trained during fine-tuning. (Remember to set config.num_labels for your task before instantiating the model.)

from transformers import ASTForAudioClassification

model = ASTForAudioClassification(config=config)

# Prefix the converted keys so they match the classifier's layer names
prefixed = {f"audio_spectrogram_transformer.{k}": v for k, v in converted.items()}

# strict=False: the classification head is deliberately missing from the checkpoint
model.load_state_dict(prefixed, strict=False)

Now, your ASTForAudioClassification model is ready for fine-tuning on an audio classification task.
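As a rough sketch of what that can look like with the HuggingFace Trainer (train_ds and eval_ds are placeholders for your preprocessed datasets yielding input_values and labels):

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./ast-classifier",
    per_device_train_batch_size=8,
    learning_rate=1e-5,
    num_train_epochs=5,
    evaluation_strategy="epoch",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,  # placeholder: preprocessed training split
    eval_dataset=eval_ds,    # placeholder: preprocessed evaluation split
)
trainer.train()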

Learn how to fine-tune the AST in this article.

Conclusion

This guide demonstrated how to load the weights of SSAST models pretrained with the original implementation into the HuggingFace ASTModel class. Integrating SSAST weights into the HuggingFace ecosystem unlocks the powerful capabilities of self-supervised learning for your AST training or fine-tuning pipeline.

Outlook

I've been working with the AST model for the last 1.5 years and have started a series of articles about how to train the model and adapt it to specific problems in the audio domain. This is only the first part of the series. If you're interested in expanding your knowledge of machine learning applied to audio, be sure to check out my audio articles list on Medium.

The second article is about How to Fine-Tune the Audio Spectrogram Transformer (AST) with Hugging Face Transformers and has already been published by Towards Data Science.

Happy modeling!

And thanks for reading! My name is Marius Steger, I'm a Machine Learning Engineer @Renumics. We have developed Spotlight, an open-source tool for interactive data exploration and visualization that integrates with Hugging Face datasets. If you want to learn more about the tool, have a look at this Community Article from my colleague Markus.

References

[1] Leaderboard on Papers With Code: Audio Classification on AudioSet

[2] Yuan Gong, Cheng-I Jeff Lai, Yu-An Chung, James Glass: SSAST: Self-Supervised Audio Spectrogram Transformer (2021), arXiv