K00B404 commited on
Commit
9ba87f1
1 Parent(s): 9aba175

Update 256_model.py

Browse files
Files changed (1) hide show
  1. 256_model.py +26 -0
256_model.py CHANGED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class UNet(nn.Module):
2
+ def __init__(self):
3
+ super(UNet, self).__init__()
4
+ # Encoder
5
+ self.encoder = nn.Sequential(
6
+ nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
7
+ nn.ReLU(inplace=True),
8
+ nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
9
+ nn.ReLU(inplace=True),
10
+ nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
11
+ nn.ReLU(inplace=True),
12
+ )
13
+ # Decoder
14
+ self.decoder = nn.Sequential(
15
+ nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
16
+ nn.ReLU(inplace=True),
17
+ nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
18
+ nn.ReLU(inplace=True),
19
+ nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
20
+ nn.Tanh()
21
+ )
22
+
23
+ def forward(self, x):
24
+ enc = self.encoder(x)
25
+ dec = self.decoder(enc)
26
+ return dec