Spaces:
Build error
Build error
File size: 795 Bytes
62bee25 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from .text_encoder import Text_Encoder
from .resunet_film import UNetRes_FiLM
class LASSNet(nn.Module):
    """Predict a sigmoid mask for an input tensor, conditioned on a caption.

    The module couples a ``Text_Encoder`` (turns the caption into a
    conditioning vector) with a ``UNetRes_FiLM`` network (consumes the input
    together with that conditioning vector).
    """

    def __init__(self, device='cuda'):
        """Create the text encoder and the conditioned UNet.

        Args:
            device: Forwarded unchanged to ``Text_Encoder``. Defaults to
                ``'cuda'``.
        """
        super().__init__()
        self.text_embedder = Text_Encoder(device)
        # NOTE(review): 256 presumably has to match the width of the vector
        # returned by Text_Encoder -- confirm against text_encoder.py.
        self.UNet = UNetRes_FiLM(channels=1, cond_embedding_dim=256)

    def forward(self, x, caption):
        """Return the sigmoid-activated UNet output for ``x`` given ``caption``.

        Args:
            x: Input tensor; per the original author's note its layout is
                (Batch, 1, T, 128).
            caption: Text query handed to ``Text_Encoder.tokenize``.
                NOTE(review): whether this is one string or a batch of
                strings is decided by the tokenizer -- confirm.

        Returns:
            ``torch.sigmoid`` applied to the UNet output.
        """
        # x: (Batch, 1, T, 128)
        input_ids, attns_mask = self.text_embedder.tokenize(caption)

        # Only the first element of the encoder's output is used as the
        # conditioning vector.
        cond_vec = self.text_embedder(input_ids, attns_mask)[0]

        # The UNet takes two conditioning arguments; the very same vector is
        # supplied for both.
        dec_cond_vec = cond_vec
        mask = self.UNet(x, cond_vec, dec_cond_vec)

        return torch.sigmoid(mask)

    def get_tokenizer(self):
        """Return the ``tokenizer`` attribute of the wrapped text encoder."""
        return self.text_embedder.tokenizer
|