Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -10,25 +10,6 @@ import torch
|
|
10 |
import torch.nn as nn
|
11 |
from generatorModel import Generator
|
12 |
|
13 |
-
noise_dim = 16
|
14 |
-
embed_dim = 384
|
15 |
-
embed_out_dim = 256
|
16 |
-
device = 'cpu'
|
17 |
-
|
18 |
-
generator = Generator(channels=3, embed_dim=embed_dim, noise_dim=noise_dim, embed_out_dim=embed_out_dim).to(device)
|
19 |
-
|
20 |
-
|
21 |
-
# Path to your .pth file
|
22 |
-
gen_weight = 'generator_20240421_3.pth'
|
23 |
-
|
24 |
-
# Load the weights
|
25 |
-
weights_gen = torch.load(gen_weight, map_location=torch.device(device))
|
26 |
-
|
27 |
-
# Apply the weights to your model
|
28 |
-
generator.load_state_dict(weights_gen)
|
29 |
-
|
30 |
-
model = SentenceTransformer('sentence-transformers/all-MiniLM-L12-v2')
|
31 |
-
|
32 |
def load_embedding(model):
|
33 |
# Load your model and other components here
|
34 |
with open(os.path.join("descriptions.json"), 'r') as file:
|
@@ -39,7 +20,6 @@ def load_embedding(model):
|
|
39 |
return embeddings_list, classes
|
40 |
|
41 |
def generate_image(caption):
|
42 |
-
embeddings, classes = load_embedding(model)
|
43 |
noise_dim = 16
|
44 |
results = [(util.pytorch_cos_sim(model.encode(caption, convert_to_tensor=True), embeddings[cls]).item(), cls) for cls in classes]
|
45 |
sorted_results = sorted(results, key=lambda x: x[0], reverse=True)[:5]
|
@@ -75,10 +55,27 @@ def generate_image(caption):
|
|
75 |
|
76 |
return img
|
77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
iface = gr.Interface(fn=generate_image,
|
79 |
inputs=gr.Textbox(lines=2, placeholder="Enter Caption Here..."),
|
80 |
outputs=gr.Image(type="numpy"),
|
81 |
title="CUHK Shenzhen Building Text-to-Image Generation",
|
82 |
description="Enter a caption of some specific building in CUHK-Shenzhen to generate an image..")
|
83 |
|
84 |
-
iface.launch(
|
|
|
10 |
import torch.nn as nn
|
11 |
from generatorModel import Generator
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
def load_embedding(model):
|
14 |
# Load your model and other components here
|
15 |
with open(os.path.join("descriptions.json"), 'r') as file:
|
|
|
20 |
return embeddings_list, classes
|
21 |
|
22 |
def generate_image(caption):
|
|
|
23 |
noise_dim = 16
|
24 |
results = [(util.pytorch_cos_sim(model.encode(caption, convert_to_tensor=True), embeddings[cls]).item(), cls) for cls in classes]
|
25 |
sorted_results = sorted(results, key=lambda x: x[0], reverse=True)[:5]
|
|
|
55 |
|
56 |
return img
|
57 |
|
58 |
+
noise_dim = 16
|
59 |
+
embed_dim = 384
|
60 |
+
embed_out_dim = 256
|
61 |
+
device = 'cpu'
|
62 |
+
|
63 |
+
generator = Generator(channels=3, embed_dim=embed_dim, noise_dim=noise_dim, embed_out_dim=embed_out_dim).to(device)
|
64 |
+
|
65 |
+
# Path to .pth file and load the weights
|
66 |
+
gen_weight = 'generator_20240421_3.pth'
|
67 |
+
weights_gen = torch.load(gen_weight, map_location=torch.device(device))
|
68 |
+
generator.load_state_dict(weights_gen)
|
69 |
+
|
70 |
+
model = SentenceTransformer('sentence-transformers/all-MiniLM-L12-v2')
|
71 |
+
|
72 |
+
embeddings, classes = load_embedding(model)
|
73 |
+
|
74 |
+
|
75 |
iface = gr.Interface(fn=generate_image,
|
76 |
inputs=gr.Textbox(lines=2, placeholder="Enter Caption Here..."),
|
77 |
outputs=gr.Image(type="numpy"),
|
78 |
title="CUHK Shenzhen Building Text-to-Image Generation",
|
79 |
description="Enter a caption of some specific building in CUHK-Shenzhen to generate an image..")
|
80 |
|
81 |
+
iface.launch(debug=True)
|