Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -36,10 +36,10 @@ def load_embedding(model):
|
|
36 |
|
37 |
classes = [e["text"] for e in dataset]
|
38 |
embeddings_list = {cls: model.encode(cls, convert_to_tensor=True) for cls in classes}
|
39 |
-
return embeddings_list
|
40 |
|
41 |
def generate_image(caption):
|
42 |
-
embeddings = load_embedding(model)
|
43 |
noise_dim = 16
|
44 |
results = [(util.pytorch_cos_sim(model.encode(caption, convert_to_tensor=True), embeddings[cls]).item(), cls) for cls in classes]
|
45 |
sorted_results = sorted(results, key=lambda x: x[0], reverse=True)[:5]
|
|
|
36 |
|
37 |
classes = [e["text"] for e in dataset]
|
38 |
embeddings_list = {cls: model.encode(cls, convert_to_tensor=True) for cls in classes}
|
39 |
+
return embeddings_list, classes
|
40 |
|
41 |
def generate_image(caption):
|
42 |
+
embeddings, classes = load_embedding(model)
|
43 |
noise_dim = 16
|
44 |
results = [(util.pytorch_cos_sim(model.encode(caption, convert_to_tensor=True), embeddings[cls]).item(), cls) for cls in classes]
|
45 |
sorted_results = sorted(results, key=lambda x: x[0], reverse=True)[:5]
|