Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -68,17 +68,20 @@ weights_gen = torch.load(gen_weight, map_location=torch.device(device))
|
|
68 |
|
69 |
# Apply the weights to your model
|
70 |
generator.load_state_dict(weights_gen)
|
71 |
-
|
72 |
|
73 |
-
# Load your model and other components here
|
74 |
model = SentenceTransformer('sentence-transformers/all-MiniLM-L12-v2')
|
75 |
-
with open(os.path.join("descriptions.json"), 'r') as file:
|
76 |
-
dataset = json.load(file)
|
77 |
|
78 |
-
|
79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
81 |
def generate_image(caption):
|
|
|
82 |
noise_dim = 16
|
83 |
results = [(util.pytorch_cos_sim(model.encode(caption, convert_to_tensor=True), embeddings[cls]).item(), cls) for cls in classes]
|
84 |
sorted_results = sorted(results, key=lambda x: x[0], reverse=True)[:5]
|
|
|
68 |
|
69 |
# Apply the weights to your model
|
70 |
generator.load_state_dict(weights_gen)
|
|
|
71 |
|
|
|
72 |
model = SentenceTransformer('sentence-transformers/all-MiniLM-L12-v2')
|
|
|
|
|
73 |
|
74 |
+
def load_embedding(model):
|
75 |
+
# Load your model and other components here
|
76 |
+
with open(os.path.join("descriptions.json"), 'r') as file:
|
77 |
+
dataset = json.load(file)
|
78 |
+
|
79 |
+
classes = [e["text"] for e in dataset]
|
80 |
+
embeddings_list = {cls: model.encode(cls, convert_to_tensor=True) for cls in classes}
|
81 |
+
return embeddings_list
|
82 |
|
83 |
def generate_image(caption):
|
84 |
+
embeddings = load_embedding(model)
|
85 |
noise_dim = 16
|
86 |
results = [(util.pytorch_cos_sim(model.encode(caption, convert_to_tensor=True), embeddings[cls]).item(), cls) for cls in classes]
|
87 |
sorted_results = sorted(results, key=lambda x: x[0], reverse=True)[:5]
|