bryandts commited on
Commit
e1e726a
1 Parent(s): c229261

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -36,10 +36,10 @@ def load_embedding(model):
36
 
37
  classes = [e["text"] for e in dataset]
38
  embeddings_list = {cls: model.encode(cls, convert_to_tensor=True) for cls in classes}
39
- return embeddings_list
40
 
41
  def generate_image(caption):
42
- embeddings = load_embedding(model)
43
  noise_dim = 16
44
  results = [(util.pytorch_cos_sim(model.encode(caption, convert_to_tensor=True), embeddings[cls]).item(), cls) for cls in classes]
45
  sorted_results = sorted(results, key=lambda x: x[0], reverse=True)[:5]
 
36
 
37
  classes = [e["text"] for e in dataset]
38
  embeddings_list = {cls: model.encode(cls, convert_to_tensor=True) for cls in classes}
39
+ return embeddings_list, classes
40
 
41
  def generate_image(caption):
42
+ embeddings, classes = load_embedding(model)
43
  noise_dim = 16
44
  results = [(util.pytorch_cos_sim(model.encode(caption, convert_to_tensor=True), embeddings[cls]).item(), cls) for cls in classes]
45
  sorted_results = sorted(results, key=lambda x: x[0], reverse=True)[:5]