Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -6,51 +6,9 @@ import json
|
|
6 |
import os
|
7 |
import matplotlib.pyplot as plt
|
8 |
import random
|
9 |
-
|
10 |
import torch
|
11 |
import torch.nn as nn
|
12 |
-
|
13 |
-
# The Generator model
class Generator(nn.Module):
    """Conditional DCGAN-style generator.

    Maps a noise vector plus a projected text embedding to an image batch
    via a stack of transposed convolutions.
    """

    def __init__(self, channels, noise_dim=100, embed_dim=1024, embed_out_dim=128):
        super(Generator, self).__init__()
        self.channels = channels
        self.noise_dim = noise_dim
        self.embed_dim = embed_dim
        self.embed_out_dim = embed_out_dim

        # Project the raw text embedding down to embed_out_dim before
        # conditioning the generator on it.
        self.text_embedding = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_out_dim),
            nn.BatchNorm1d(1),
            nn.LeakyReLU(0.2, inplace=True)
        )

        # Transposed-conv stack: a 1x1 input grows to 4x4, then doubles in
        # spatial size at every subsequent stage (4 -> 8 -> 16 -> 32 -> 64 -> 128).
        layer_specs = [
            (self.noise_dim + self.embed_out_dim, 512, 1, 0),
            (512, 256, 2, 1),
            (256, 128, 2, 1),
            (128, 64, 2, 1),
            (64, 32, 2, 1),
        ]
        stages = []
        for c_in, c_out, stride, padding in layer_specs:
            stages += self._create_layer(c_in, c_out, 4, stride=stride, padding=padding)
        stages += self._create_layer(32, self.channels, 4, stride=2, padding=1, output=True)
        self.model = nn.Sequential(*stages)

    def _create_layer(self, size_in, size_out, kernel_size=4, stride=2, padding=1, output=False):
        """Build one upsampling stage as a list of modules.

        Hidden stages get ConvTranspose2d + BatchNorm2d + ReLU; the output
        stage gets ConvTranspose2d + Tanh.
        """
        block = [nn.ConvTranspose2d(size_in, size_out, kernel_size, stride=stride, padding=padding, bias=False)]
        if output:
            block.append(nn.Tanh())  # final activation squashes pixels to [-1, 1]
        else:
            block += [nn.BatchNorm2d(size_out), nn.ReLU(True)]
        return block

    def forward(self, noise, text):
        """Generate images from `noise` (B, noise_dim, 1, 1) and `text` (B, 1, embed_dim)."""
        text = self.text_embedding(text)
        # (B, 1, embed_out_dim) -> (B, embed_out_dim, 1, 1) so the embedding
        # can be concatenated with the noise along the channel dimension.
        text = text.view(text.shape[0], text.shape[2], 1, 1)
        z = torch.cat([text, noise], 1)
        return self.model(z)
|
54 |
|
55 |
noise_dim = 16
|
56 |
embed_dim = 384
|
|
|
6 |
import os
|
7 |
import matplotlib.pyplot as plt
|
8 |
import random
|
|
|
9 |
import torch
|
10 |
import torch.nn as nn
|
11 |
+
from generator import Generator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
# Hyperparameters used to configure the generator in this app.
# NOTE(review): these differ from the Generator class defaults
# (noise_dim=100, embed_dim=1024) — confirm they match the checkpoint/
# embedding model actually loaded.
noise_dim = 16
embed_dim = 384