Spaces:
Running
Running
File size: 5,219 Bytes
254fdf2 37b5ba0 254fdf2 37b5ba0 254fdf2 37b5ba0 254fdf2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import torch
import numpy as np
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC
from vpt.src.configs.config import get_cfg
import os
from time import sleep
from random import randint
from vpt.src.utils.file_io import PathManager
import matplotlib.pyplot as plt
from matplotlib.colors import rgb_to_hsv, hsv_to_rgb
import warnings
import nltk
warnings.filterwarnings("ignore")
def get_noun_phrase(tokenized):
    """Return the distinct noun phrases found in a tokenized sentence.

    The tokens are POS-tagged with ``nltk.pos_tag`` and chunked with a
    two-stage regular-expression grammar (NBAR, then NP). Consecutive
    chunk subtrees are merged into a single space-joined phrase; any
    token that is not part of a chunk ends the current phrase.

    Args:
        tokenized: sequence of word tokens, passed straight to
            ``nltk.pos_tag``.

    Returns:
        List of noun-phrase strings in order of first appearance, without
        duplicates. A phrase that ends the sentence is included.
    """
    # Taken from Su Nam Kim Paper...
    grammar = r"""
        NBAR:
            {<NN.*|JJ>*<NN.*>}  # Nouns and Adjectives, terminated with Nouns
        NP:
            {<NBAR>}
            {<NBAR><IN><NBAR>}  # Above, connected with in/of/etc...
    """
    chunker = nltk.RegexpParser(grammar)
    chunked = chunker.parse(nltk.pos_tag(tokenized))

    phrases = []
    pending = []

    def _flush():
        # Emit the phrase accumulated so far (if any) and always clear the
        # buffer, so a repeated phrase cannot leak into the next one.
        if pending:
            phrase = ' '.join(pending)
            if phrase not in phrases:
                phrases.append(phrase)
            pending.clear()

    for subtree in chunked:
        if isinstance(subtree, nltk.Tree):
            pending.append(' '.join(token for token, _pos in subtree.leaves()))
        else:
            # A token outside any chunk marks a phrase boundary.
            _flush()

    # The sentence may end on a noun phrase with no trailing boundary token.
    _flush()
    return phrases
def setup(args):
    """
    Build the run configuration and claim an output directory for this run.

    The config returned by ``get_cfg()`` is merged with ``args.config_file``
    and then ``args.opts``. Candidate directories
    ``<OUTPUT_DIR>/<DATA.NAME>/<DATA.FEATURE>/lr<lr>_wd<wd>/run<k>`` are
    probed for k = 1 .. ``cfg.RUN_N_TIMES``; the first one that does not
    exist is created and stored in ``cfg.OUTPUT_DIR``. If every candidate
    already exists, ``cfg.OUTPUT_DIR`` is left as configured.

    Returns the frozen config.
    """
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)

    base_dir = cfg.OUTPUT_DIR
    learning_rate = cfg.SOLVER.BASE_LR
    weight_decay = cfg.SOLVER.WEIGHT_DECAY
    experiment_dir = os.path.join(
        cfg.DATA.NAME, cfg.DATA.FEATURE, f"lr{learning_rate}_wd{weight_decay}")

    # train cfg.RUN_N_TIMES times
    run_idx = 1
    while run_idx <= cfg.RUN_N_TIMES:
        candidate = os.path.join(base_dir, experiment_dir, f"run{run_idx}")
        # Random pause so concurrent processes with the same settings
        # are less likely to race for the same run directory.
        sleep(randint(3, 30))
        if PathManager.exists(candidate):
            run_idx += 1
            continue
        PathManager.mkdirs(candidate)
        cfg.OUTPUT_DIR = candidate
        break

    cfg.freeze()
    return cfg
def get_similarity_map(sm, shape):
    """Turn per-patch similarity scores into a dense, normalised map.

    The scores are min-max normalised along dimension 1, laid out on a
    square grid, and bilinearly resized to ``shape``.

    Args:
        sm: tensor indexed as (batch, patches, channels); the patch count
            must be a perfect square (e.g. ``torch.Size([1, 196, 1])``).
        shape: target spatial size forwarded to
            ``torch.nn.functional.interpolate``.

    Returns:
        Tensor laid out as (batch, height, width, channels), with the
        leading dimension squeezed away when it has size 1.
    """
    # sm: torch.Size([1, 196, 1])
    # min-max norm (each reduction is computed once)
    lowest = sm.min(1, keepdim=True)[0]
    highest = sm.max(1, keepdim=True)[0]
    value_range = highest - lowest
    # A constant map has zero range; divide by 1 there so the result is 0
    # instead of NaN. Non-degenerate maps are unaffected.
    value_range = torch.where(
        value_range > 0, value_range, torch.ones_like(value_range))
    sm = (sm - lowest) / value_range  # torch.Size([1, 196, 1])

    # reshape
    side = int(sm.shape[1] ** 0.5)  # square output, side = 14
    sm = sm.reshape(sm.shape[0], side, side, -1).permute(0, 3, 1, 2)

    # interpolate
    sm = torch.nn.functional.interpolate(sm, shape, mode='bilinear')
    sm = sm.permute(0, 2, 3, 1)
    return sm.squeeze(0)
def display_segmented_sketch(pixel_similarity_array,binary_sketch,classes,classes_colors,save_path=None,live=False):
    """Render a colour-coded segmentation of a sketch and save or show it.

    Each pixel is assigned to the class with the highest similarity. For
    the named classes the pixel takes the class colour's hue, with its
    brightness set to the similarity score; pixels won by any extra entry
    of ``classes_colors`` beyond ``len(classes)`` are painted black.
    Pixels whose first ``binary_sketch`` channel equals 255 are painted
    white. When more than one class is given, each class name is drawn at
    the centroid of its pixels.

    Args:
        pixel_similarity_array: array indexed as (class, row, col).
            # NOTE(review): presumably scores lie in [0, 1] since they are
            # used directly as the HSV value channel - confirm with caller.
        binary_sketch: array indexed as (row, col, channel).
        classes: class names to label.
        classes_colors: one RGB colour per entry; may contain more entries
            than ``classes``.
        save_path: destination file, using "/" as separator. ``None`` or an
            empty string shows the figure instead of saving it.
            # NOTE(review): a path without any directory component is also
            # shown rather than saved, matching the existing control flow.
        live: when true, always write to ``output.png`` and ignore
            ``save_path``.
    """
    # Find the class index with the highest similarity for each pixel
    class_indices = np.argmax(pixel_similarity_array, axis=0)

    # Create an HSV image placeholder with Value = 1 (white base)
    hsv_image = np.zeros(class_indices.shape + (3,))
    hsv_image[..., 2] = 1

    # Set the hue, saturation and value channels class by class
    for i, color in enumerate(classes_colors):
        rgb_color = np.array(color).reshape(1, 1, 3)
        hsv_color = rgb_to_hsv(rgb_color)
        mask = class_indices == i
        if i < len(classes):
            # Named class: colour by hue, brightness by similarity
            hsv_image[..., 0][mask] = hsv_color[0, 0, 0]  # Hue
            hsv_image[..., 1][mask] = pixel_similarity_array[i][mask] > 0  # Saturation
            hsv_image[..., 2][mask] = pixel_similarity_array[i][mask]  # Value
        else:
            # Extra colour entries beyond the named classes: black
            hsv_image[..., 0][mask] = 0  # Hue doesn't matter for black
            hsv_image[..., 1][mask] = 0  # Saturation set to 0
            hsv_image[..., 2][mask] = 0  # Value set to 0, making it black

    # Pixels that are 255 in the first sketch channel are forced to white
    mask_tensor_org = binary_sketch[:, :, 0] / 255
    hsv_image[mask_tensor_org == 1] = [0, 0, 1]

    # Convert the HSV image back to RGB to display and save
    rgb_image = hsv_to_rgb(hsv_image)

    if len(classes) > 1:
        # Calculate centroids and render class names
        for i, class_name in enumerate(classes):
            mask = class_indices == i
            if np.any(mask):
                y, x = np.nonzero(mask)
                centroid_x, centroid_y = np.mean(x), np.mean(y)
                plt.text(centroid_x, centroid_y, class_name, color=classes_colors[i], ha='center', va='center', fontsize=10,
                         bbox=dict(facecolor='lightgrey', edgecolor='none', boxstyle='round,pad=0.2', alpha=0.8))

    # Display the image with class names
    plt.imshow(rgb_image)
    plt.axis('off')
    plt.tight_layout()

    if live:
        plt.savefig('output.png', bbox_inches='tight', pad_inches=0)
        return

    # Everything before the last "/" is the directory. A missing save_path
    # (None or "") has no directory, so it falls through to plt.show()
    # instead of raising AttributeError on None.
    save_dir = save_path.rpartition("/")[0] if save_path else ''
    if save_dir != '':
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
    else:
        plt.show()
|