Spaces:
Runtime error
Runtime error
# Copyright (c) 2023 Amphion. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
# This code is modified from https://github.com/lifeiteng/vall-e | |
import torch | |
import torch.nn as nn | |
class TokenEmbedding(nn.Module): | |
def __init__(self, dim_model: int, vocab_size: int, dropout: float = 0.0): | |
super().__init__() | |
self.dropout = nn.Dropout(p=dropout) | |
self.word_embeddings = nn.Embedding(vocab_size, dim_model) | |
def weight(self) -> torch.Tensor: | |
return self.word_embeddings.weight | |
def forward(self, x: torch.Tensor): | |
x = self.word_embeddings(x) | |
x = self.dropout(x) | |
return x | |