# Captured from a Hugging Face Space page (status banner at capture time: "Sleeping").
import gradio as gr | |
import torch | |
import torchvision.transforms as transforms | |
from sentence_transformers import SentenceTransformer, util | |
import json | |
import os | |
import matplotlib.pyplot as plt | |
import random | |
import torch | |
import torch.nn as nn | |
# The Generator model
class Generator(nn.Module):
    """Text-conditioned transposed-convolution image generator.

    A text embedding is projected to ``embed_out_dim`` features and
    concatenated channel-wise with a noise tensor. The resulting
    (batch, embed_out_dim + noise_dim, 1, 1) tensor is upsampled by a stack
    of ``ConvTranspose2d`` stages: 1x1 -> 4x4, then five x2 stages -> 128x128.
    The final ``Tanh`` puts pixel values in [-1, 1].

    Args:
        channels: Number of output image channels.
        noise_dim: Number of noise features expected by ``forward``.
        embed_dim: Size of the incoming text embedding.
        embed_out_dim: Size of the projected text embedding.
    """

    def __init__(self, channels, noise_dim=100, embed_dim=1024, embed_out_dim=128):
        super().__init__()
        self.channels = channels
        self.noise_dim = noise_dim
        self.embed_dim = embed_dim
        self.embed_out_dim = embed_out_dim

        # Text embedding layers. BatchNorm1d(1) normalises the single
        # "channel" of a (batch, 1, embed_out_dim) tensor, which is why
        # forward() feeds a 3-D text tensor through this stack.
        # Attribute names (text_embedding, model) define the state_dict keys
        # and must not change, or saved checkpoints will no longer load.
        self.text_embedding = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_out_dim),
            nn.BatchNorm1d(1),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Generator architecture: 1x1 -> 4x4 (stride 1, no padding), then each
        # stride-2 / kernel-4 / padding-1 stage doubles height and width.
        layers = []
        layers += self._create_layer(self.noise_dim + self.embed_out_dim, 512, 4, stride=1, padding=0)
        layers += self._create_layer(512, 256, 4, stride=2, padding=1)
        layers += self._create_layer(256, 128, 4, stride=2, padding=1)
        layers += self._create_layer(128, 64, 4, stride=2, padding=1)
        layers += self._create_layer(64, 32, 4, stride=2, padding=1)
        layers += self._create_layer(32, self.channels, 4, stride=2, padding=1, output=True)
        self.model = nn.Sequential(*layers)

    def _create_layer(self, size_in, size_out, kernel_size=4, stride=2, padding=1, output=False):
        """Return the list of modules for one upsampling stage.

        Every stage starts with a bias-free ``ConvTranspose2d``. The output
        stage (``output=True``) appends ``Tanh``. All other stages append
        ``BatchNorm2d`` and ``ReLU``.
        """
        layers = [nn.ConvTranspose2d(size_in, size_out, kernel_size, stride=stride, padding=padding, bias=False)]
        if output:
            layers.append(nn.Tanh())  # Tanh activation for the output layer
        else:
            layers += [nn.BatchNorm2d(size_out), nn.ReLU(True)]  # Batch normalization and ReLU for other layers
        return layers

    def forward(self, noise, text):
        """Generate a batch of images from noise and text embeddings.

        Args:
            noise: Noise of shape (batch, noise_dim, 1, 1) or (batch, noise_dim).
            text: Text embeddings of shape (batch, 1, embed_dim) or
                (batch, embed_dim).

        Returns:
            Tensor of shape (batch, channels, 128, 128) with values in [-1, 1].
        """
        if text.dim() == 2:
            # Add the singleton channel axis that BatchNorm1d(1) expects.
            text = text.unsqueeze(1)
        text = self.text_embedding(text)
        # (batch, 1, embed_out_dim) -> (batch, embed_out_dim, 1, 1)
        text = text.reshape(text.shape[0], -1, 1, 1)
        # Accept flat (batch, noise_dim) noise as well as the 4-D layout.
        noise = noise.reshape(noise.shape[0], -1, 1, 1)
        # Order matters: the first layer's input channels are [text | noise].
        z = torch.cat([text, noise], 1)
        return self.model(z)
# Generator hyper-parameters; they must match the ones used to train the
# checkpoint below.
noise_dim = 16
embed_dim = 384  # must equal the sentence-embedding size produced by `model` below
embed_out_dim = 256
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Path to the trained generator weights.
gen_weight = 'generator_20240421_3.pth'
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
weights_gen = torch.load(gen_weight, map_location=device)

# The generator has to exist before weights can be loaded into it.
# channels=3 assumes an RGB checkpoint (output is shown as an HxWxC image).
# load_state_dict is strict, so a wrong guess fails loudly here.
generator = Generator(
    channels=3,
    noise_dim=noise_dim,
    embed_dim=embed_dim,
    embed_out_dim=embed_out_dim,
).to(device)
generator.load_state_dict(weights_gen)
# Inference only: make BatchNorm use its running statistics instead of the
# statistics of each single-image request.
generator.eval()

# Sentence encoder used for both the class descriptions and user captions.
model = SentenceTransformer('sentence-transformers/all-MiniLM-L12-v2')

with open("descriptions.json", "r", encoding="utf-8") as file:
    dataset = json.load(file)
# Each entry of descriptions.json is read as a mapping with a "text" key.
classes = [e["text"] for e in dataset]
# One cached embedding per class description. generate_image reads this
# mapping under the name `embeddings`.
embeddings = {cls: model.encode(cls, convert_to_tensor=True) for cls in classes}
embeddings_list = embeddings  # backward-compatible alias for the old name
def _rank_classes(query_embedding, top_k=5):
    """Return the ``top_k`` (score, class) pairs most similar to the query.

    Scores are cosine similarities between ``query_embedding`` and the cached
    ``embeddings`` of every entry in ``classes``. Pairs are sorted from most
    to least similar; ties keep the order of ``classes``.
    """
    scored = [
        (util.pytorch_cos_sim(query_embedding, embeddings[cls]).item(), cls)
        for cls in classes
    ]
    return sorted(scored, key=lambda pair: pair[0], reverse=True)[:top_k]


def _filter_close_scores(ranked, coeff):
    """Keep the leading run of ``ranked`` whose scores fall off gently.

    Each kept score must be at least ``coeff`` times the previously kept
    score. The first sharper drop ends the run. For a positive top score,
    the top entry is always kept.
    """
    kept = []
    last_score = ranked[0][0]
    for score, cls in ranked:
        if score < last_score * coeff:
            break
        kept.append((score, cls))
        last_score = score
    return kept


def generate_image(caption):
    """Generate an image for ``caption`` with the text-conditioned generator.

    The caption is matched against the known class descriptions by
    sentence-embedding cosine similarity. One of the closest descriptions is
    sampled, weighted by similarity, and its cached embedding conditions the
    generator.

    Args:
        caption: Free-form text entered by the user.

    Returns:
        A float numpy array of shape (H, W, C), min-max scaled to [0, 1]
        (all zeros if the generated image is constant).
    """
    threshold = 0.40
    coeff = 0.89

    # Encode the caption once rather than once per class.
    query = model.encode(caption, convert_to_tensor=True)
    ranked = _rank_classes(query)
    if ranked[0][0] <= threshold:
        # Poor match: snap to the closest known description and rank its
        # neighbours instead. Its embedding is already cached, so there is
        # no need to re-encode it.
        ranked = _rank_classes(embeddings[ranked[0][1]])
    # NOTE(review): source indentation was lost; this check is placed at
    # function level (it also fires after a snap, where the top score is ~1).
    if ranked[0][0] >= 0.99:
        coeff = 0.85

    candidates = _filter_close_scores(ranked, coeff)
    labels = [cls for _, cls in candidates]
    weights = [score for score, _ in candidates]
    sampled_item = random.choices(labels, weights=weights, k=1)[0]

    # (1, 1, embed_dim): the layout Generator.forward feeds to BatchNorm1d(1).
    text_embedding = embeddings[sampled_item].to(device).unsqueeze(0).unsqueeze(0)
    noise = torch.randn(1, generator.noise_dim, 1, 1, device=device)
    with torch.no_grad():
        fake_images = generator(noise, text_embedding)

    img = fake_images[0].permute(1, 2, 0).cpu().numpy()
    img = img - img.min()
    span = img.max()
    if span > 0:  # avoid 0/0 -> NaN for a constant image
        img = img / span
    return img
# Gradio UI: one multi-line caption box in, one numpy image out.
caption_box = gr.Textbox(lines=2, placeholder="Enter Caption Here...")
image_out = gr.Image(type="numpy")
iface = gr.Interface(
    fn=generate_image,
    inputs=caption_box,
    outputs=image_out,
    title="Text-to-Image Generation",
    description="Enter a caption to generate an image.",
)
iface.launch()