Spaces:
Sleeping
Sleeping
File size: 3,540 Bytes
8f5a57a 67fe846 8f5a57a 7c5a10d 67fe846 8f5a57a 35de619 e1e726a 8f5a57a 67fe846 2c84baf 8f5a57a cb60a7c 8f5a57a 8f2b3c8 8f5a57a cb60a7c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from sentence_transformers import SentenceTransformer, util
import json
import os
import matplotlib.pyplot as plt
import random
import torch
import torch.nn as nn
from generatorModel import Generator
import cv2
def upscale_and_sharpen_image(input_array, target_size=(256, 256)):
    """Upscale an image array and apply a 3x3 sharpening filter.

    Args:
        input_array: Image as a NumPy array, (H, W) or (H, W, C), in any
            dtype accepted by OpenCV.
        target_size: ``(width, height)`` passed to ``cv2.resize``. Defaults
            to ``(256, 256)``, matching the previous fixed behavior.

    Returns:
        The resized, sharpened image as a NumPy array with the same dtype
        as the input.
    """
    # Lanczos interpolation preserves detail better than bilinear/bicubic
    # when upscaling small generator outputs.
    upscaled_img = cv2.resize(input_array, target_size, interpolation=cv2.INTER_LANCZOS4)
    # Sharpening kernel: coefficients sum to 1, so flat regions keep their
    # original intensity while edges are amplified.
    sharpening_kernel = np.array([[-1, -1, -1],
                                  [-1, 9, -1],
                                  [-1, -1, -1]])
    # ddepth=-1 keeps the output depth equal to the input depth.
    return cv2.filter2D(upscaled_img, -1, sharpening_kernel)
def load_embedding(model):
    """Build per-class sentence embeddings from ``descriptions.json``.

    Args:
        model: An encoder exposing ``encode(text, convert_to_tensor=True)``
            (e.g. a SentenceTransformer instance).

    Returns:
        A tuple ``(embeddings_list, classes)``: ``classes`` is the list of
        caption strings found in the dataset (in file order), and
        ``embeddings_list`` maps each caption to its encoded embedding.
    """
    # The dataset is a JSON list of objects, each with a "text" caption.
    # (Was os.path.join("descriptions.json") — a single-arg join is a no-op.)
    with open("descriptions.json", "r", encoding="utf-8") as file:
        dataset = json.load(file)
    classes = [entry["text"] for entry in dataset]
    # Pre-compute one embedding per caption so inference only does dict lookups.
    embeddings_list = {cls: model.encode(cls, convert_to_tensor=True) for cls in classes}
    return embeddings_list, classes
def generate_image(caption):
    """Generate an image for ``caption`` with the conditional GAN generator.

    The caption is ranked against the known class captions by cosine
    similarity; one of the top matches is sampled (similarity-weighted) and
    its embedding conditions the generator.

    Args:
        caption: Free-text caption entered by the user.

    Returns:
        A (256, 256, 3) float NumPy array normalized to [0, 1].
    """
    # Encode the caption once — it is loop-invariant across all classes
    # (previously re-encoded inside the comprehension for every class).
    caption_emb = model.encode(caption, convert_to_tensor=True)
    results = [(util.pytorch_cos_sim(caption_emb, embeddings[cls]).item(), cls)
               for cls in classes]
    sorted_results = sorted(results, key=lambda x: x[0], reverse=True)[:5]
    threshold = 0.40
    coeff = 0.89
    # If nothing matches well, snap the caption to the best-known class
    # caption and re-rank against that canonical text.
    if sorted_results[0][0] <= threshold:
        caption = sorted_results[0][1]
        caption_emb = model.encode(caption, convert_to_tensor=True)
        results = [(util.pytorch_cos_sim(caption_emb, embeddings[cls]).item(), cls)
                   for cls in classes]
        sorted_results = sorted(results, key=lambda x: x[0], reverse=True)[:5]
    # A (near-)exact match widens the sampling band slightly.
    if sorted_results[0][0] >= 0.99:
        coeff = 0.85
    # Keep the leading run of classes whose score stays within ``coeff`` of
    # the previously kept score; stop at the first large drop.
    last_score = sorted_results[0][0]
    filtered_results = []
    for score, cls in sorted_results:
        if score >= last_score * coeff:
            filtered_results.append((score, cls))
            last_score = score
        else:
            break
    items = [cls for _, cls in filtered_results]
    # Cosine similarity can be <= 0, but random.choices requires a positive
    # total weight — clamp each weight to a small positive floor.
    probabilities = [max(score, 1e-6) for score, _ in filtered_results]
    sampled_item = random.choices(items, weights=probabilities, k=1)[0]
    # Use the module-level noise_dim (the old local copy shadowed it and
    # could silently drift out of sync with the generator's configuration).
    noise = torch.randn(1, noise_dim, 1, 1, device=device)
    fake_images = generator(noise, embeddings[sampled_item].unsqueeze(0).unsqueeze(0))
    img = fake_images.squeeze(0).permute(1, 2, 0).cpu().detach().numpy()
    img = upscale_and_sharpen_image(img)
    # Min-max normalize to [0, 1]; guard against a constant image, which
    # previously caused a division by zero.
    span = img.max() - img.min()
    if span > 0:
        img = (img - img.min()) / span
    else:
        img = np.zeros_like(img)
    return img
# --- One-time model setup (runs at import time) ---
noise_dim = 16        # length of the generator's latent noise vector
embed_dim = 384       # caption embedding size fed to the generator; presumably matches the MiniLM encoder's output dim — confirm
embed_out_dim = 256   # projected embedding size inside the generator
device = 'cpu'
# Build the generator architecture first, then load trained weights into it.
generator = Generator(channels=3, embed_dim=embed_dim, noise_dim=noise_dim, embed_out_dim=embed_out_dim).to(device)
# Path to .pth file and load the weights
gen_weight = 'generator_20240421_3.pth'
weights_gen = torch.load(gen_weight, map_location=torch.device(device))
generator.load_state_dict(weights_gen)
# Sentence encoder used both to pre-compute class embeddings and to embed
# user captions at inference time.
model = SentenceTransformer('sentence-transformers/all-MiniLM-L12-v2')
embeddings, classes = load_embedding(model)
# Gradio UI: caption textbox in, generated image (numpy array) out.
iface = gr.Interface(fn=generate_image,
inputs=gr.Textbox(lines=2, placeholder="Enter Caption Here..."),
outputs=gr.Image(type="numpy"),
title="CUHK Shenzhen Building Text-to-Image Generation",
description="Enter a caption of some specific building in CUHK-Shenzhen to generate an image..")
iface.launch(debug=True)
|