bryandts commited on
Commit
0838938
1 Parent(s): e26bc6e

Update discriminatorModel.py

Browse files
Files changed (1) hide show
  1. discriminatorModel.py +16 -16
discriminatorModel.py CHANGED
@@ -1,21 +1,6 @@
1
-
2
  import torch
3
  import torch.nn as nn
4
 
5
- class Embedding(nn.Module):
6
- def __init__(self, size_in, size_out):
7
- super(Embedding, self).__init__()
8
- self.text_embedding = nn.Sequential(
9
- nn.Linear(size_in, size_out),
10
- nn.BatchNorm1d(1),
11
- nn.LeakyReLU(0.2, inplace=True)
12
- )
13
-
14
- def forward(self, x, text):
15
- embed_out = self.text_embedding(text)
16
- embed_out_resize = embed_out.repeat(4, 1, 4, 1).permute(1, 3, 0, 2) # Resize to match the discriminator input size
17
- out = torch.cat([x, embed_out_resize], 1) # Concatenate text embedding with the input feature map
18
- return out
19
 
20
  # The Discriminator model
21
  class Discriminator(nn.Module):
@@ -37,7 +22,22 @@ class Discriminator(nn.Module):
37
  self.output = nn.Sequential(
38
  nn.Conv2d(512 + self.embed_out_dim, 1, 4, 1, 0, bias=False), nn.Sigmoid()
39
  )
40
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  def _create_layer(self, size_in, size_out, kernel_size=4, stride=2, padding=1, normalize=True):
42
  layers = [nn.Conv2d(size_in, size_out, kernel_size=kernel_size, stride=stride, padding=padding)]
43
  if normalize:
 
 
1
  import torch
2
  import torch.nn as nn
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  # The Discriminator model
6
  class Discriminator(nn.Module):
 
22
  self.output = nn.Sequential(
23
  nn.Conv2d(512 + self.embed_out_dim, 1, 4, 1, 0, bias=False), nn.Sigmoid()
24
  )
25
+
26
+ class Embedding(nn.Module):
27
+ def __init__(self, size_in, size_out):
28
+ super(Embedding, self).__init__()
29
+ self.text_embedding = nn.Sequential(
30
+ nn.Linear(size_in, size_out),
31
+ nn.BatchNorm1d(1),
32
+ nn.LeakyReLU(0.2, inplace=True)
33
+ )
34
+
35
+ def forward(self, x, text):
36
+ embed_out = self.text_embedding(text)
37
+ embed_out_resize = embed_out.repeat(4, 1, 4, 1).permute(1, 3, 0, 2) # Resize to match the discriminator input size
38
+ out = torch.cat([x, embed_out_resize], 1) # Concatenate text embedding with the input feature map
39
+ return out
40
+
41
  def _create_layer(self, size_in, size_out, kernel_size=4, stride=2, padding=1, normalize=True):
42
  layers = [nn.Conv2d(size_in, size_out, kernel_size=kernel_size, stride=stride, padding=padding)]
43
  if normalize: