Spaces:
Runtime error
Runtime error
# Copyright (c) 2023 Amphion. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import torch | |
import torch.nn as nn | |
import numpy as np | |
class ScaledDotProductAttention(nn.Module): | |
"""Scaled Dot-Product Attention""" | |
def __init__(self, temperature): | |
super().__init__() | |
self.temperature = temperature | |
self.softmax = nn.Softmax(dim=2) | |
def forward(self, q, k, v, mask=None): | |
attn = torch.bmm(q, k.transpose(1, 2)) | |
attn = attn / self.temperature | |
if mask is not None: | |
attn = attn.masked_fill(mask, -np.inf) | |
attn = self.softmax(attn) | |
output = torch.bmm(attn, v) | |
return output, attn | |