import json
import random

import gradio as gr
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer, util

# The Generator model
class Generator(nn.Module):
    def __init__(self, channels, noise_dim=100, embed_dim=1024, embed_out_dim=128):
        super(Generator, self).__init__()
        self.channels = channels
        self.noise_dim = noise_dim
        self.embed_dim = embed_dim
        self.embed_out_dim = embed_out_dim

        # Text embedding layers
        self.text_embedding = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_out_dim),
            nn.BatchNorm1d(1),
            nn.LeakyReLU(0.2, inplace=True)
        )

        # Generator architecture
        model = []
        model += self._create_layer(self.noise_dim + self.embed_out_dim, 512, 4, stride=1, padding=0)
        model += self._create_layer(512, 256, 4, stride=2, padding=1)
        model += self._create_layer(256, 128, 4, stride=2, padding=1)
        model += self._create_layer(128, 64, 4, stride=2, padding=1)
        model += self._create_layer(64, 32, 4, stride=2, padding=1)
        model += self._create_layer(32, self.channels, 4, stride=2, padding=1, output=True)

        self.model = nn.Sequential(*model)

    def _create_layer(self, size_in, size_out, kernel_size=4, stride=2, padding=1, output=False):
        layers = [nn.ConvTranspose2d(size_in, size_out, kernel_size, stride=stride, padding=padding, bias=False)]
        if output:
            layers.append(nn.Tanh())  # Tanh activation for the output layer
        else:
            layers += [nn.BatchNorm2d(size_out), nn.ReLU(True)]  # Batch normalization and ReLU for other layers
        return layers

    def forward(self, noise, text):
        # Project the sentence embedding: (batch, 1, embed_dim) -> (batch, 1, embed_out_dim)
        text = self.text_embedding(text)
        # Reshape to (batch, embed_out_dim, 1, 1) so it can be concatenated with the noise
        text = text.view(text.shape[0], text.shape[2], 1, 1)
        z = torch.cat([text, noise], 1)  # (batch, noise_dim + embed_out_dim, 1, 1)
        return self.model(z)
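
# A quick shape sanity check (a minimal sketch, using the dimensions configured
# below; run it separately if desired). The first layer projects the 1x1 input
# to 4x4, and each of the five remaining layers doubles the spatial size, so
# the generator emits 128x128 images:
#
#   g = Generator(channels=3, noise_dim=16, embed_dim=384, embed_out_dim=256)
#   z = torch.randn(2, 16, 1, 1)   # noise
#   t = torch.randn(2, 1, 384)     # sentence embeddings
#   g(z, t).shape                  # torch.Size([2, 3, 128, 128])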

noise_dim = 16
embed_dim = 384  # all-MiniLM-L12-v2 produces 384-dimensional embeddings
embed_out_dim = 256
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Path to the trained generator weights (.pth file)
gen_weight = 'generator_20240421_3.pth'

# Instantiate the generator (channels=3 assumes RGB output) and load the weights
generator = Generator(channels=3, noise_dim=noise_dim, embed_dim=embed_dim, embed_out_dim=embed_out_dim).to(device)
weights_gen = torch.load(gen_weight, map_location=device)
generator.load_state_dict(weights_gen)
generator.eval()  # inference mode: use the stored batch-norm statistics

# Load the sentence encoder and the class descriptions
model = SentenceTransformer('sentence-transformers/all-MiniLM-L12-v2')
with open("descriptions.json", 'r') as file:
    dataset = json.load(file)

classes = [e["text"] for e in dataset]
# Pre-compute an embedding for every class description
embeddings = {cls: model.encode(cls, convert_to_tensor=True) for cls in classes}
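
# descriptions.json is assumed to be a list of objects with a "text" field,
# e.g. (hypothetical entries, for illustration):
#   [{"text": "a red flower with long, thin petals"},
#    {"text": "a yellow flower with rounded petals"}]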

def generate_image(caption):
    # Rank every class description by cosine similarity to the input caption
    results = [(util.pytorch_cos_sim(model.encode(caption, convert_to_tensor=True), embeddings[cls]).item(), cls) for cls in classes]
    sorted_results = sorted(results, key=lambda x: x[0], reverse=True)[:5]

    threshold = 0.40
    coeff = 0.89

    # If even the best match is weak, snap the caption to the closest class and re-rank
    if sorted_results[0][0] <= threshold:
        caption = sorted_results[0][1]
        results = [(util.pytorch_cos_sim(model.encode(caption, convert_to_tensor=True), embeddings[cls]).item(), cls) for cls in classes]
        sorted_results = sorted(results, key=lambda x: x[0], reverse=True)[:5]

    # A near-exact match would dominate the weighted sampling below, so widen
    # the acceptance band slightly to keep some runner-up candidates
    if sorted_results[0][0] >= 0.99:
        coeff = 0.85

    # Keep candidates as long as each score stays within `coeff` of the previous one
    last_score = sorted_results[0][0]
    filtered_results = []
    for score, cls in sorted_results:
        if score >= last_score * coeff:
            filtered_results.append((score, cls))
            last_score = score
        else:
            break

    # Sample one of the remaining classes, weighted by similarity score
    items = [cls for score, cls in filtered_results]
    probabilities = [score for score, cls in filtered_results]
    sampled_item = random.choices(items, weights=probabilities, k=1)[0]

    # Generate an image conditioned on the sampled class embedding
    noise = torch.randn(1, noise_dim, 1, 1, device=device)
    with torch.no_grad():
        fake_images = generator(noise, embeddings[sampled_item].unsqueeze(0).unsqueeze(0).to(device))
    img = fake_images.squeeze(0).permute(1, 2, 0).cpu().numpy()
    img = (img - img.min()) / (img.max() - img.min())  # rescale tanh output to [0, 1]

    return img
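
# Called directly (hypothetical caption, for illustration), the function returns
# an HxWx3 float array in [0, 1], which gr.Image(type="numpy") renders as-is:
#
#   img = generate_image("a yellow flower with rounded petals")
#   img.shape  # (128, 128, 3)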

iface = gr.Interface(fn=generate_image,
                     inputs=gr.Textbox(lines=2, placeholder="Enter Caption Here..."),
                     outputs=gr.Image(type="numpy"),
                     title="Text-to-Image Generation",
                     description="Enter a caption to generate an image.")

iface.launch()