bryandts commited on
Commit
98688ba
1 Parent(s): 0838938

Update discriminatorModel.py

Browse files
Files changed (1) hide show
  1. discriminatorModel.py +15 -15
discriminatorModel.py CHANGED
@@ -2,6 +2,21 @@ import torch
2
  import torch.nn as nn
3
 
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  # The Discriminator model
6
  class Discriminator(nn.Module):
7
  def __init__(self, channels, embed_dim=1024, embed_out_dim=128):
@@ -23,21 +38,6 @@ class Discriminator(nn.Module):
23
  nn.Conv2d(512 + self.embed_out_dim, 1, 4, 1, 0, bias=False), nn.Sigmoid()
24
  )
25
 
26
- class Embedding(nn.Module):
27
- def __init__(self, size_in, size_out):
28
- super(Embedding, self).__init__()
29
- self.text_embedding = nn.Sequential(
30
- nn.Linear(size_in, size_out),
31
- nn.BatchNorm1d(1),
32
- nn.LeakyReLU(0.2, inplace=True)
33
- )
34
-
35
- def forward(self, x, text):
36
- embed_out = self.text_embedding(text)
37
- embed_out_resize = embed_out.repeat(4, 1, 4, 1).permute(1, 3, 0, 2) # Resize to match the discriminator input size
38
- out = torch.cat([x, embed_out_resize], 1) # Concatenate text embedding with the input feature map
39
- return out
40
-
41
  def _create_layer(self, size_in, size_out, kernel_size=4, stride=2, padding=1, normalize=True):
42
  layers = [nn.Conv2d(size_in, size_out, kernel_size=kernel_size, stride=stride, padding=padding)]
43
  if normalize:
 
2
  import torch.nn as nn
3
 
4
 
5
+ class Embedding(nn.Module):
6
+ def __init__(self, size_in, size_out):
7
+ super(Embedding, self).__init__()
8
+ self.text_embedding = nn.Sequential(
9
+ nn.Linear(size_in, size_out),
10
+ nn.BatchNorm1d(1),
11
+ nn.LeakyReLU(0.2, inplace=True)
12
+ )
13
+
14
+ def forward(self, x, text):
15
+ embed_out = self.text_embedding(text)
16
+ embed_out_resize = embed_out.repeat(4, 1, 4, 1).permute(1, 3, 0, 2) # Resize to match the discriminator input size
17
+ out = torch.cat([x, embed_out_resize], 1) # Concatenate text embedding with the input feature map
18
+ return out
19
+
20
  # The Discriminator model
21
  class Discriminator(nn.Module):
22
  def __init__(self, channels, embed_dim=1024, embed_out_dim=128):
 
38
  nn.Conv2d(512 + self.embed_out_dim, 1, 4, 1, 0, bias=False), nn.Sigmoid()
39
  )
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  def _create_layer(self, size_in, size_out, kernel_size=4, stride=2, padding=1, normalize=True):
42
  layers = [nn.Conv2d(size_in, size_out, kernel_size=kernel_size, stride=stride, padding=padding)]
43
  if normalize: