bryandts commited on
Commit
8610a1c
1 Parent(s): bd223fb

Create discriminatorEmbedding.py

Browse files
Files changed (1) hide show
  1. discriminatorEmbedding.py +17 -0
discriminatorEmbedding.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class Embedding(nn.Module):
5
+ def __init__(self, size_in, size_out):
6
+ super(Embedding, self).__init__()
7
+ self.text_embedding = nn.Sequential(
8
+ nn.Linear(size_in, size_out),
9
+ nn.BatchNorm1d(1),
10
+ nn.LeakyReLU(0.2, inplace=True)
11
+ )
12
+
13
+ def forward(self, x, text):
14
+ embed_out = self.text_embedding(text)
15
+ embed_out_resize = embed_out.repeat(4, 1, 4, 1).permute(1, 3, 0, 2) # Resize to match the discriminator input size
16
+ out = torch.cat([x, embed_out_resize], 1) # Concatenate text embedding with the input feature map
17
+ return out