bryandts commited on
Commit
e26bc6e
1 Parent(s): e05c407

Update discriminatorModel.py

Browse files
Files changed (1) hide show
  1. discriminatorModel.py +15 -1
discriminatorModel.py CHANGED
@@ -1,7 +1,21 @@
1
 
2
  import torch
3
  import torch.nn as nn
4
- from discriminatorEmbedding import Embedding
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  # The Discriminator model
7
  class Discriminator(nn.Module):
 
1
 
2
  import torch
3
  import torch.nn as nn
4
+
5
+ class Embedding(nn.Module):
6
+ def __init__(self, size_in, size_out):
7
+ super(Embedding, self).__init__()
8
+ self.text_embedding = nn.Sequential(
9
+ nn.Linear(size_in, size_out),
10
+ nn.BatchNorm1d(1),
11
+ nn.LeakyReLU(0.2, inplace=True)
12
+ )
13
+
14
+ def forward(self, x, text):
15
+ embed_out = self.text_embedding(text)
16
+ embed_out_resize = embed_out.repeat(4, 1, 4, 1).permute(1, 3, 0, 2) # Resize to match the discriminator input size
17
+ out = torch.cat([x, embed_out_resize], 1) # Concatenate text embedding with the input feature map
18
+ return out
19
 
20
  # The Discriminator model
21
  class Discriminator(nn.Module):