Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import torchvision.transforms as transforms
|
4 |
+
from sentence_transformers import SentenceTransformer, util
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import random
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
|
13 |
+
# The Generator model
|
14 |
+
class Generator(nn.Module):
    """Text-conditioned image generator built from transposed convolutions.

    A text embedding is projected to ``embed_out_dim`` features, concatenated
    with a noise tensor along the channel axis and up-sampled by six
    ``ConvTranspose2d`` stages. With 1x1 spatial noise the stages produce
    4 -> 8 -> 16 -> 32 -> 64 -> 128 pixel feature maps, and the final ``Tanh``
    keeps the output in ``[-1, 1]``.

    Args:
        channels: Number of channels of the generated image.
        noise_dim: Channel count of the noise tensor.
        embed_dim: Size of the incoming text embedding.
        embed_out_dim: Size of the projected text embedding.
    """

    def __init__(self, channels, noise_dim=100, embed_dim=1024, embed_out_dim=128):
        super().__init__()
        self.channels = channels
        self.noise_dim = noise_dim
        self.embed_dim = embed_dim
        self.embed_out_dim = embed_out_dim

        # Text projection. BatchNorm1d(1) normalises over a single "channel",
        # so the projected text must be shaped (N, 1, embed_out_dim).
        self.text_embedding = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_out_dim),
            nn.BatchNorm1d(1),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Hidden up-sampling stages as (in_channels, out_channels, stride, padding).
        # The order (and therefore the nn.Sequential indices / state-dict keys)
        # must not change, otherwise existing checkpoints stop loading.
        hidden_stages = [
            (self.noise_dim + self.embed_out_dim, 512, 1, 0),
            (512, 256, 2, 1),
            (256, 128, 2, 1),
            (128, 64, 2, 1),
            (64, 32, 2, 1),
        ]
        model = []
        for size_in, size_out, stride, padding in hidden_stages:
            model += self._create_layer(size_in, size_out, 4, stride=stride, padding=padding)
        model += self._create_layer(32, self.channels, 4, stride=2, padding=1, output=True)

        self.model = nn.Sequential(*model)

    def _create_layer(self, size_in, size_out, kernel_size=4, stride=2, padding=1, output=False):
        """Return the modules of one up-sampling stage as a list.

        Hidden stages are ``ConvTranspose2d -> BatchNorm2d -> ReLU``; the
        output stage (``output=True``) is ``ConvTranspose2d -> Tanh``.
        """
        conv = nn.ConvTranspose2d(size_in, size_out, kernel_size, stride=stride, padding=padding, bias=False)
        if output:
            return [conv, nn.Tanh()]
        return [conv, nn.BatchNorm2d(size_out), nn.ReLU(True)]

    def forward(self, noise, text):
        """Generate images from noise and text embeddings.

        Args:
            noise: Tensor of shape ``(N, noise_dim, 1, 1)``.
            text: Tensor of shape ``(N, 1, embed_dim)``; a 2-D
                ``(N, embed_dim)`` tensor is also accepted and gets the
                singleton axis inserted automatically.

        Returns:
            Tensor of shape ``(N, channels, H, W)`` with values in ``[-1, 1]``.
        """
        if text.dim() == 2:
            # Insert the singleton "channel" axis that BatchNorm1d(1) expects.
            text = text.unsqueeze(1)
        text = self.text_embedding(text)
        # (N, 1, embed_out_dim) -> (N, embed_out_dim, 1, 1) so it can be
        # concatenated with the noise along the channel axis.
        text = text.view(text.shape[0], text.shape[2], 1, 1)
        z = torch.cat([text, noise], 1)
        return self.model(z)
|
54 |
+
|
55 |
+
# Generator hyper-parameters. load_state_dict() below is strict, so these have
# to agree with the tensor shapes stored in the checkpoint.
noise_dim = 16
embed_dim = 384
embed_out_dim = 256
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Path to your .pth file
gen_weight = 'generator_20240421_3.pth'

# Load the weights
# NOTE(review): torch.load unpickles the file - only load checkpoints from a
# trusted source.
weights_gen = torch.load(gen_weight, map_location=device)

# The final ConvTranspose2d sits at index 15 of Generator.model and its weight
# is shaped (in_channels, out_channels, kH, kW); read the image channel count
# from the checkpoint instead of hard-coding it.
channels = weights_gen["model.15.weight"].shape[1]

# Build the network before applying the weights (the model object was never
# created, so load_state_dict used to fail with a NameError).
generator = Generator(
    channels,
    noise_dim=noise_dim,
    embed_dim=embed_dim,
    embed_out_dim=embed_out_dim,
).to(device)

# Apply the weights to your model
generator.load_state_dict(weights_gen)

# Inference only: make the BatchNorm layers use their stored running statistics.
generator.eval()
|
68 |
+
|
69 |
+
|
70 |
+
# Sentence encoder used to embed the known class descriptions (and, at request
# time, the user's caption).
model = SentenceTransformer('sentence-transformers/all-MiniLM-L12-v2')

# descriptions.json is read as a list of objects that each carry a "text" field.
with open("descriptions.json", 'r', encoding="utf-8") as file:
    dataset = json.load(file)

classes = [e["text"] for e in dataset]

# Pre-compute one embedding tensor per class description, keyed by its text.
embeddings_list = {cls: model.encode(cls, convert_to_tensor=True) for cls in classes}

# Alias so that lookups under either name (`embeddings` or `embeddings_list`)
# resolve to the same dict; only `embeddings_list` used to be defined, which
# made every `embeddings[...]` lookup raise NameError.
embeddings = embeddings_list
|
77 |
+
|
78 |
+
def _top_matches(caption, top_k=5):
    """Return the ``top_k`` best ``(score, class_text)`` pairs for ``caption``.

    The caption is encoded once and compared by cosine similarity with every
    pre-computed class embedding; pairs are ordered best first.
    """
    query = model.encode(caption, convert_to_tensor=True)
    scored = [
        (util.pytorch_cos_sim(query, embeddings_list[cls]).item(), cls)
        for cls in classes
    ]
    return sorted(scored, key=lambda pair: pair[0], reverse=True)[:top_k]


def generate_image(caption):
    """Generate an image for ``caption`` with the module-level generator.

    Steps:
      1. Rank the known class descriptions by cosine similarity to the caption.
      2. If even the best match scores <= 0.40, replace the caption with that
         best-matching description and rank again.
      3. Walk the top five best-first, keeping a candidate while its score is
         at least ``coeff`` times the previously kept score (0.89, or 0.85
         when the best score is >= 0.99).
      4. Sample one kept description, weighted by score, and feed its
         embedding plus random noise to the generator.

    Args:
        caption: Free-text description entered by the user.

    Returns:
        The generated image as a ``(H, W, C)`` numpy array rescaled to
        ``[0, 1]``.
    """
    threshold = 0.40
    coeff = 0.89

    sorted_results = _top_matches(caption)

    if sorted_results[0][0] <= threshold:
        # Nothing is close enough: snap to the nearest known description.
        caption = sorted_results[0][1]
        sorted_results = _top_matches(caption)

    if sorted_results[0][0] >= 0.99:
        # (Near-)exact match: relax the cut-off so neighbours can be sampled too.
        coeff = 0.85

    # Keep candidates until the score drops by more than `coeff` relative to
    # the previously kept one.
    last_score = sorted_results[0][0]
    filtered_results = []
    for score, cls in sorted_results:
        if score < last_score * coeff:
            break
        filtered_results.append((score, cls))
        last_score = score

    items = [cls for _, cls in filtered_results]
    probabilities = [score for score, _ in filtered_results]
    sampled_item = random.choices(items, weights=probabilities, k=1)[0]

    # Uses the module-level noise_dim so the noise always matches the generator.
    noise = torch.randn(1, noise_dim, 1, 1, device=device)
    # (embed_dim,) -> (1, 1, embed_dim), on the same device as the generator.
    text = embeddings_list[sampled_item].unsqueeze(0).unsqueeze(0).to(device)
    with torch.no_grad():  # inference only, no autograd graph needed
        fake_images = generator(noise, text)

    # (1, C, H, W) -> (H, W, C) for display.
    img = fake_images.squeeze(0).permute(1, 2, 0).cpu().numpy()

    # Min-max rescale to [0, 1]; a constant image would divide by zero.
    low = img.min()
    span = img.max() - low
    img = (img - low) / span if span > 0 else img - low

    return img
|
113 |
+
|
114 |
+
# Gradio front end: one caption text box in, one generated image out.
caption_box = gr.Textbox(lines=2, placeholder="Enter Caption Here...")
image_view = gr.Image(type="numpy")

iface = gr.Interface(
    fn=generate_image,
    inputs=caption_box,
    outputs=image_view,
    title="Text-to-Image Generation",
    description="Enter a caption to generate an image.",
)

# Start the web server for the interface.
iface.launch()
|