bryandts commited on
Commit
cb60a7c
1 Parent(s): e1e726a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -21
app.py CHANGED
@@ -10,25 +10,6 @@ import torch
10
  import torch.nn as nn
11
  from generatorModel import Generator
12
 
13
- noise_dim = 16
14
- embed_dim = 384
15
- embed_out_dim = 256
16
- device = 'cpu'
17
-
18
- generator = Generator(channels=3, embed_dim=embed_dim, noise_dim=noise_dim, embed_out_dim=embed_out_dim).to(device)
19
-
20
-
21
- # Path to your .pth file
22
- gen_weight = 'generator_20240421_3.pth'
23
-
24
- # Load the weights
25
- weights_gen = torch.load(gen_weight, map_location=torch.device(device))
26
-
27
- # Apply the weights to your model
28
- generator.load_state_dict(weights_gen)
29
-
30
- model = SentenceTransformer('sentence-transformers/all-MiniLM-L12-v2')
31
-
32
  def load_embedding(model):
33
  # Load your model and other components here
34
  with open(os.path.join("descriptions.json"), 'r') as file:
@@ -39,7 +20,6 @@ def load_embedding(model):
39
  return embeddings_list, classes
40
 
41
  def generate_image(caption):
42
- embeddings, classes = load_embedding(model)
43
  noise_dim = 16
44
  results = [(util.pytorch_cos_sim(model.encode(caption, convert_to_tensor=True), embeddings[cls]).item(), cls) for cls in classes]
45
  sorted_results = sorted(results, key=lambda x: x[0], reverse=True)[:5]
@@ -75,10 +55,27 @@ def generate_image(caption):
75
 
76
  return img
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  iface = gr.Interface(fn=generate_image,
79
  inputs=gr.Textbox(lines=2, placeholder="Enter Caption Here..."),
80
  outputs=gr.Image(type="numpy"),
81
  title="CUHK Shenzhen Building Text-to-Image Generation",
82
  description="Enter a caption of some specific building in CUHK-Shenzhen to generate an image..")
83
 
84
- iface.launch(share=True, debug=True)
 
10
  import torch.nn as nn
11
  from generatorModel import Generator
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  def load_embedding(model):
14
  # Load your model and other components here
15
  with open(os.path.join("descriptions.json"), 'r') as file:
 
20
  return embeddings_list, classes
21
 
22
  def generate_image(caption):
 
23
  noise_dim = 16
24
  results = [(util.pytorch_cos_sim(model.encode(caption, convert_to_tensor=True), embeddings[cls]).item(), cls) for cls in classes]
25
  sorted_results = sorted(results, key=lambda x: x[0], reverse=True)[:5]
 
55
 
56
  return img
57
 
58
+ noise_dim = 16
59
+ embed_dim = 384
60
+ embed_out_dim = 256
61
+ device = 'cpu'
62
+
63
+ generator = Generator(channels=3, embed_dim=embed_dim, noise_dim=noise_dim, embed_out_dim=embed_out_dim).to(device)
64
+
65
+ # Path to .pth file and load the weights
66
+ gen_weight = 'generator_20240421_3.pth'
67
+ weights_gen = torch.load(gen_weight, map_location=torch.device(device))
68
+ generator.load_state_dict(weights_gen)
69
+
70
+ model = SentenceTransformer('sentence-transformers/all-MiniLM-L12-v2')
71
+
72
+ embeddings, classes = load_embedding(model)
73
+
74
+
75
  iface = gr.Interface(fn=generate_image,
76
  inputs=gr.Textbox(lines=2, placeholder="Enter Caption Here..."),
77
  outputs=gr.Image(type="numpy"),
78
  title="CUHK Shenzhen Building Text-to-Image Generation",
79
  description="Enter a caption of some specific building in CUHK-Shenzhen to generate an image..")
80
 
81
+ iface.launch(debug=True)