bryandts commited on
Commit
35de619
1 Parent(s): 7436807

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -68,17 +68,20 @@ weights_gen = torch.load(gen_weight, map_location=torch.device(device))
68
 
69
  # Apply the weights to your model
70
  generator.load_state_dict(weights_gen)
71
-
72
 
73
- # Load your model and other components here
74
  model = SentenceTransformer('sentence-transformers/all-MiniLM-L12-v2')
75
- with open(os.path.join("descriptions.json"), 'r') as file:
76
- dataset = json.load(file)
77
 
78
- classes = [e["text"] for e in dataset]
79
- embeddings_list = {cls: model.encode(cls, convert_to_tensor=True) for cls in classes}
 
 
 
 
 
 
80
 
81
  def generate_image(caption):
 
82
  noise_dim = 16
83
  results = [(util.pytorch_cos_sim(model.encode(caption, convert_to_tensor=True), embeddings[cls]).item(), cls) for cls in classes]
84
  sorted_results = sorted(results, key=lambda x: x[0], reverse=True)[:5]
 
68
 
69
  # Apply the weights to your model
70
  generator.load_state_dict(weights_gen)
 
71
 
 
72
  model = SentenceTransformer('sentence-transformers/all-MiniLM-L12-v2')
 
 
73
 
74
+ def load_embedding(model):
75
+ # Load your model and other components here
76
+ with open(os.path.join("descriptions.json"), 'r') as file:
77
+ dataset = json.load(file)
78
+
79
+ classes = [e["text"] for e in dataset]
80
+ embeddings_list = {cls: model.encode(cls, convert_to_tensor=True) for cls in classes}
81
+ return embeddings_list
82
 
83
  def generate_image(caption):
84
+ embeddings = load_embedding(model)
85
  noise_dim = 16
86
  results = [(util.pytorch_cos_sim(model.encode(caption, convert_to_tensor=True), embeddings[cls]).item(), cls) for cls in classes]
87
  sorted_results = sorted(results, key=lambda x: x[0], reverse=True)[:5]