bryandts committed
Commit 8f5a57a
1 Parent(s): 1574dd5

Create app.py

Files changed (1)
  1. app.py +120 -0
app.py ADDED
@@ -0,0 +1,120 @@
+ import gradio as gr
+ import torch
+ import torch.nn as nn
+ import torchvision.transforms as transforms
+ from sentence_transformers import SentenceTransformer, util
+ import json
+ import os
+ import matplotlib.pyplot as plt
+ import random
+
+ # The Generator model: maps a noise vector plus a projected text embedding
+ # to an image through a stack of transposed convolutions
+ class Generator(nn.Module):
+     def __init__(self, channels, noise_dim=100, embed_dim=1024, embed_out_dim=128):
+         super(Generator, self).__init__()
+         self.channels = channels
+         self.noise_dim = noise_dim
+         self.embed_dim = embed_dim
+         self.embed_out_dim = embed_out_dim
+
+         # Text embedding layers: project the sentence embedding down to embed_out_dim
+         self.text_embedding = nn.Sequential(
+             nn.Linear(self.embed_dim, self.embed_out_dim),
+             nn.BatchNorm1d(1),
+             nn.LeakyReLU(0.2, inplace=True)
+         )
+
+         # Generator architecture: upsample a 1x1 input to a full-resolution image
+         model = []
+         model += self._create_layer(self.noise_dim + self.embed_out_dim, 512, 4, stride=1, padding=0)
+         model += self._create_layer(512, 256, 4, stride=2, padding=1)
+         model += self._create_layer(256, 128, 4, stride=2, padding=1)
+         model += self._create_layer(128, 64, 4, stride=2, padding=1)
+         model += self._create_layer(64, 32, 4, stride=2, padding=1)
+         model += self._create_layer(32, self.channels, 4, stride=2, padding=1, output=True)
+
+         self.model = nn.Sequential(*model)
+
+     def _create_layer(self, size_in, size_out, kernel_size=4, stride=2, padding=1, output=False):
+         layers = [nn.ConvTranspose2d(size_in, size_out, kernel_size, stride=stride, padding=padding, bias=False)]
+         if output:
+             layers.append(nn.Tanh())  # Tanh activation for the output layer
+         else:
+             layers += [nn.BatchNorm2d(size_out), nn.ReLU(True)]  # Batch normalization and ReLU for other layers
+         return layers
+
+     def forward(self, noise, text):
+         # Project the text embedding: (batch, 1, embed_dim) -> (batch, 1, embed_out_dim)
+         text = self.text_embedding(text)
+         # Reshape to (batch, embed_out_dim, 1, 1) to match the convolutional input
+         text = text.view(text.shape[0], text.shape[2], 1, 1)
+         z = torch.cat([text, noise], 1)  # Concatenate text embedding with noise along channels
+         return self.model(z)
+
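+ # Shape sketch (illustrative, not part of the original commit): with
+ # channels=3, noise_dim=16, embed_dim=384, embed_out_dim=256, calling
+ #     Generator(3, 16, 384, 256)(torch.randn(2, 16, 1, 1), torch.randn(2, 1, 384))
+ # yields a (2, 3, 128, 128) tensor: the 1x1 input becomes 4x4 after the first
+ # layer, then doubles five times (4 -> 8 -> 16 -> 32 -> 64 -> 128).
+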
+ noise_dim = 16
+ embed_dim = 384
+ embed_out_dim = 256
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ # Instantiate the generator before loading weights (channels=3 assumes RGB output)
+ generator = Generator(channels=3, noise_dim=noise_dim, embed_dim=embed_dim, embed_out_dim=embed_out_dim).to(device)
+
+ # Path to your .pth file
+ gen_weight = 'generator_20240421_3.pth'
+
+ # Load the weights and apply them to the model
+ weights_gen = torch.load(gen_weight, map_location=device)
+ generator.load_state_dict(weights_gen)
+ generator.eval()
+
+ # Load the sentence encoder and the class descriptions
+ model = SentenceTransformer('sentence-transformers/all-MiniLM-L12-v2')
+ with open(os.path.join("descriptions.json"), 'r') as file:
+     dataset = json.load(file)
+
+ classes = [e["text"] for e in dataset]
+ embeddings = {cls: model.encode(cls, convert_to_tensor=True) for cls in classes}
+
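+ # Note: descriptions.json is expected to hold a list of objects with a "text"
+ # field, e.g. [{"text": "..."}], since classes is built from e["text"] above
+ # (the example entry is illustrative).
+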
+ def generate_image(caption):
+     # Rank the known descriptions by cosine similarity to the input caption
+     caption_emb = model.encode(caption, convert_to_tensor=True)
+     results = [(util.pytorch_cos_sim(caption_emb, embeddings[cls]).item(), cls) for cls in classes]
+     sorted_results = sorted(results, key=lambda x: x[0], reverse=True)[:5]
+
+     threshold = 0.40
+     coeff = 0.89
+
+     # If nothing is similar enough, fall back to the closest description and re-rank
+     if sorted_results[0][0] <= threshold:
+         caption = sorted_results[0][1]
+         caption_emb = model.encode(caption, convert_to_tensor=True)
+         results = [(util.pytorch_cos_sim(caption_emb, embeddings[cls]).item(), cls) for cls in classes]
+         sorted_results = sorted(results, key=lambda x: x[0], reverse=True)[:5]
+
+     if sorted_results[0][0] >= 0.99:
+         coeff = 0.85
+
+     # Keep candidates whose score stays within a factor coeff of the previous one
+     last_score = sorted_results[0][0]
+     filtered_results = []
+     for score, cls in sorted_results:
+         if score >= last_score * coeff:
+             filtered_results.append((score, cls))
+             last_score = score
+         else:
+             break
+
+     # Sample one candidate, weighted by its similarity score
+     items = [cls for score, cls in filtered_results]
+     probabilities = [score for score, cls in filtered_results]
+     sampled_item = random.choices(items, weights=probabilities, k=1)[0]
+
+     # Generate an image conditioned on the sampled description's embedding
+     noise = torch.randn(1, noise_dim, 1, 1, device=device)
+     text = embeddings[sampled_item].unsqueeze(0).unsqueeze(0).to(device)  # (1, 1, embed_dim)
+     with torch.no_grad():
+         fake_images = generator(noise, text)
+     img = fake_images.squeeze(0).permute(1, 2, 0).cpu().numpy()
+     img = (img - img.min()) / (img.max() - img.min())  # Rescale Tanh output to [0, 1]
+
+     return img
+
+ iface = gr.Interface(fn=generate_image,
+                      inputs=gr.Textbox(lines=2, placeholder="Enter Caption Here..."),
+                      outputs=gr.Image(type="numpy"),
+                      title="Text-to-Image Generation",
+                      description="Enter a caption to generate an image.")
+
+ if __name__ == "__main__":
+     iface.launch()
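
A quick way to exercise the pipeline outside the Gradio UI (a minimal sketch: the file name smoke_test.py and the caption are hypothetical, and it assumes generator_20240421_3.pth and descriptions.json sit next to app.py):

    # smoke_test.py (hypothetical) -- generate one image and save it to disk
    import matplotlib.pyplot as plt
    from app import generate_image  # import is safe: launch() sits behind the __main__ guard

    img = generate_image("a red flower")  # example caption, not taken from the dataset
    plt.imshow(img)
    plt.axis("off")
    plt.savefig("sample.png", bbox_inches="tight")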