bryandts commited on
Commit
70ad629
1 Parent(s): a6111ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -43
app.py CHANGED
@@ -6,51 +6,9 @@ import json
6
  import os
7
  import matplotlib.pyplot as plt
8
  import random
9
-
10
  import torch
11
  import torch.nn as nn
12
-
13
- # The Generator model
14
- class Generator(nn.Module):
15
- def __init__(self, channels, noise_dim=100, embed_dim=1024, embed_out_dim=128):
16
- super(Generator, self).__init__()
17
- self.channels = channels
18
- self.noise_dim = noise_dim
19
- self.embed_dim = embed_dim
20
- self.embed_out_dim = embed_out_dim
21
-
22
- # Text embedding layers
23
- self.text_embedding = nn.Sequential(
24
- nn.Linear(self.embed_dim, self.embed_out_dim),
25
- nn.BatchNorm1d(1),
26
- nn.LeakyReLU(0.2, inplace=True)
27
- )
28
-
29
- # Generator architecture
30
- model = []
31
- model += self._create_layer(self.noise_dim + self.embed_out_dim, 512, 4, stride=1, padding=0)
32
- model += self._create_layer(512, 256, 4, stride=2, padding=1)
33
- model += self._create_layer(256, 128, 4, stride=2, padding=1)
34
- model += self._create_layer(128, 64, 4, stride=2, padding=1)
35
- model += self._create_layer(64, 32, 4, stride=2, padding=1)
36
- model += self._create_layer(32, self.channels, 4, stride=2, padding=1, output=True)
37
-
38
- self.model = nn.Sequential(*model)
39
-
40
- def _create_layer(self, size_in, size_out, kernel_size=4, stride=2, padding=1, output=False):
41
- layers = [nn.ConvTranspose2d(size_in, size_out, kernel_size, stride=stride, padding=padding, bias=False)]
42
- if output:
43
- layers.append(nn.Tanh()) # Tanh activation for the output layer
44
- else:
45
- layers += [nn.BatchNorm2d(size_out), nn.ReLU(True)] # Batch normalization and ReLU for other layers
46
- return layers
47
-
48
- def forward(self, noise, text):
49
- # Apply text embedding to the input text
50
- text = self.text_embedding(text)
51
- text = text.view(text.shape[0], text.shape[2], 1, 1) # Reshape to match the generator input size
52
- z = torch.cat([text, noise], 1) # Concatenate text embedding with noise
53
- return self.model(z)
54
 
55
  noise_dim = 16
56
  embed_dim = 384
 
6
  import os
7
  import matplotlib.pyplot as plt
8
  import random
 
9
  import torch
10
  import torch.nn as nn
11
+ from generator import Generator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  noise_dim = 16
14
  embed_dim = 384