|
from collections import OrderedDict |
|
from typing import List, Union, Dict |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch import Tensor |
|
from torch.nn.utils.rnn import pad_sequence |
|
|
|
# Width of both linear layers' outputs in the toy ``Model`` below.
HIDDEN_DIM = 8


class Model(nn.Module):
    """Toy two-layer network used as the example upstream model.

    Two linear layers are stacked: the first maps a size-1 last dimension to
    ``HIDDEN_DIM`` features, the second maps ``HIDDEN_DIM`` to ``HIDDEN_DIM``.
    """

    def __init__(self):
        super().__init__()

        # Attribute names (and creation order) define the state-dict keys,
        # so they must stay ``model1`` / ``model2``.
        self.model1 = nn.Linear(in_features=1, out_features=HIDDEN_DIM)
        self.model2 = nn.Linear(in_features=HIDDEN_DIM, out_features=HIDDEN_DIM)

    def forward(self, wavs, upstream_feature_selection="hidden_states"):
        """Run both layers and return their outputs.

        Args:
            wavs: Input tensor whose last dimension has size 1 (required by
                the first linear layer).
            upstream_feature_selection: Accepted for interface compatibility;
                this example implementation does not use it.

        Returns:
            A two-element list ``[first_layer_output, second_layer_output]``.
        """
        first_layer_out = self.model1(wavs)
        second_layer_out = self.model2(first_layer_out)
        return [first_layer_out, second_layer_out]
|
|
|
class UpstreamExpert(nn.Module):
    """Example upstream wrapper: loads a checkpoint into ``Model`` and exposes
    its per-layer outputs under the ``"hidden_states"`` key."""

    def __init__(
        self,
        ckpt: str = "./model.pt",
        upstream_feature_selection: str = "hidden_states",
        **kwargs
    ):
        """
        Args:
            ckpt:
                Path of the checkpoint holding the pretrained weights.
                Should be fixed as model.pt for SUPERB Challenge.
            upstream_feature_selection:
                One of 'hidden_states', 'PR', 'SID', 'ER', 'ASR', 'QbE',
                'ASV', 'SD', 'ST', 'SE', 'SS', 'secret', or any other value
                (new tasks). It can be used to switch task-specific
                pre- / post-processing; it is stored on the instance and
                forwarded to the wrapped model on every call.
            **kwargs:
                Extra keyword arguments are accepted and ignored.
        """
        super().__init__()
        self.name = "[Example UpstreamExpert]"
        self.upstream_feature_selection = upstream_feature_selection

        # NOTE: torch.load unpickles the file, so only point ``ckpt`` at
        # checkpoints from a trusted source.
        # The loaded object is kept under its own name instead of rebinding
        # the ``ckpt`` path argument.
        state_dict = torch.load(ckpt, map_location="cpu")
        self.model = Model()
        self.model.load_state_dict(state_dict)

    def get_downsample_rates(self, key: str) -> int:
        """
        Return the downsample rate of the representation stored under ``key``.

        This example upstream performs no downsampling, so every key maps to
        a rate of 1.
        Eg. 10ms stride representation has the downsample rate 160
        (input wavs are all in 16kHz)
        """
        return 1

    def forward(self, wavs: List[Tensor]) -> Dict[str, List[Tensor]]:
        """
        Pad the input waveforms into one batch and run the wrapped model.

        When the returning Dict contains the List with more than one Tensor,
        those Tensors should be in the same shape to train a weighted-sum on
        them.

        Args:
            wavs: List of waveform tensors (presumably 1-D and of varying
                length — confirm against the caller).

        Returns:
            ``{"hidden_states": [...]}`` holding the list returned by the
            wrapped model.
        """
        # Zero-pad to the longest sequence with the batch dimension first.
        padded = pad_sequence(wavs, batch_first=True)
        # Add a trailing size-1 feature axis for the model's first layer.
        batch = padded.unsqueeze(-1)

        layer_outputs = self.model(
            batch,
            upstream_feature_selection=self.upstream_feature_selection,
        )

        return {"hidden_states": layer_outputs}
|
|