Spaces:
Running
on
T4
Running
on
T4
liuyizhang
commited on
Commit
•
6e3e561
1
Parent(s):
1c6fec7
update app.py
Browse files
app.py
CHANGED
@@ -434,7 +434,9 @@ def concatenate_images_vertical(image1, image2):
|
|
434 |
|
435 |
return new_image
|
436 |
|
437 |
-
def relate_anything(
|
|
|
|
|
438 |
w, h = input_image.size
|
439 |
max_edge = 1500
|
440 |
if w > max_edge or h > max_edge:
|
@@ -442,12 +444,14 @@ def relate_anything(input_image, k):
|
|
442 |
new_size = (int(w / ratio), int(h / ratio))
|
443 |
input_image.thumbnail(new_size)
|
444 |
|
|
|
445 |
# load image
|
446 |
pil_image = input_image.convert('RGBA')
|
447 |
image = np.array(input_image)
|
448 |
sam_masks = sam_mask_generator.generate(image)
|
449 |
filtered_masks = sort_and_deduplicate(sam_masks)
|
450 |
|
|
|
451 |
feat_list = []
|
452 |
for fm in filtered_masks:
|
453 |
feat = torch.Tensor(fm['feat']).unsqueeze(0).unsqueeze(0).to(device)
|
@@ -455,6 +459,7 @@ def relate_anything(input_image, k):
|
|
455 |
feat = torch.cat(feat_list, dim=1).to(device)
|
456 |
matrix_output, rel_triplets = ram_model.predict(feat)
|
457 |
|
|
|
458 |
pil_image_list = []
|
459 |
for i, rel in enumerate(rel_triplets[:k]):
|
460 |
s,o,r = int(rel[0]),int(rel[1]),int(rel[2])
|
@@ -473,6 +478,7 @@ def relate_anything(input_image, k):
|
|
473 |
concate_pil_image = concatenate_images_vertical(current_pil_image, title_image)
|
474 |
pil_image_list.append(concate_pil_image)
|
475 |
|
|
|
476 |
yield pil_image_list
|
477 |
|
478 |
|
|
|
434 |
|
435 |
return new_image
|
436 |
|
437 |
+
def relate_anything(input_image_mask, k):
|
438 |
+
logger.info(f'relate_anything_1_')
|
439 |
+
input_image = input_image_mask['image']
|
440 |
w, h = input_image.size
|
441 |
max_edge = 1500
|
442 |
if w > max_edge or h > max_edge:
|
|
|
444 |
new_size = (int(w / ratio), int(h / ratio))
|
445 |
input_image.thumbnail(new_size)
|
446 |
|
447 |
+
logger.info(f'relate_anything_2_')
|
448 |
# load image
|
449 |
pil_image = input_image.convert('RGBA')
|
450 |
image = np.array(input_image)
|
451 |
sam_masks = sam_mask_generator.generate(image)
|
452 |
filtered_masks = sort_and_deduplicate(sam_masks)
|
453 |
|
454 |
+
logger.info(f'relate_anything_3_')
|
455 |
feat_list = []
|
456 |
for fm in filtered_masks:
|
457 |
feat = torch.Tensor(fm['feat']).unsqueeze(0).unsqueeze(0).to(device)
|
|
|
459 |
feat = torch.cat(feat_list, dim=1).to(device)
|
460 |
matrix_output, rel_triplets = ram_model.predict(feat)
|
461 |
|
462 |
+
logger.info(f'relate_anything_4_')
|
463 |
pil_image_list = []
|
464 |
for i, rel in enumerate(rel_triplets[:k]):
|
465 |
s,o,r = int(rel[0]),int(rel[1]),int(rel[2])
|
|
|
478 |
concate_pil_image = concatenate_images_vertical(current_pil_image, title_image)
|
479 |
pil_image_list.append(concate_pil_image)
|
480 |
|
481 |
+
logger.info(f'relate_anything_5_')
|
482 |
yield pil_image_list
|
483 |
|
484 |
|