Spaces:
Running
Running
fix labels
Browse files
app.py
CHANGED
@@ -4,7 +4,7 @@ import torch
|
|
4 |
from torchvision.transforms import InterpolationMode
|
5 |
|
6 |
BICUBIC = InterpolationMode.BICUBIC
|
7 |
-
from utils import setup, get_similarity_map,
|
8 |
from vpt.launch import default_argument_parser
|
9 |
from collections import OrderedDict
|
10 |
import numpy as np
|
@@ -81,8 +81,52 @@ def run(sketch, caption, threshold, seed):
|
|
81 |
pixel_similarity[pixel_similarity < threshold] = 0
|
82 |
pixel_similarity_array = pixel_similarity.cpu().numpy().transpose(2, 0, 1)
|
83 |
|
84 |
-
display_segmented_sketch(pixel_similarity_array, sketch2, classes, classes_colors, live=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
rgb_image = Image.open('output.png')
|
87 |
|
88 |
return rgb_image
|
|
|
4 |
from torchvision.transforms import InterpolationMode
|
5 |
|
6 |
BICUBIC = InterpolationMode.BICUBIC
|
7 |
+
from utils import setup, get_similarity_map,get_noun_phrase, rgb_to_hsv, hsv_to_rgb
|
8 |
from vpt.launch import default_argument_parser
|
9 |
from collections import OrderedDict
|
10 |
import numpy as np
|
|
|
81 |
pixel_similarity[pixel_similarity < threshold] = 0
|
82 |
pixel_similarity_array = pixel_similarity.cpu().numpy().transpose(2, 0, 1)
|
83 |
|
84 |
+
# display_segmented_sketch(pixel_similarity_array, sketch2, classes, classes_colors, live=True)
|
85 |
+
|
86 |
+
# Find the class index with the highest similarity for each pixel
|
87 |
+
class_indices = np.argmax(pixel_similarity_array, axis=0)
|
88 |
+
# Create an HSV image placeholder
|
89 |
+
hsv_image = np.zeros(class_indices.shape + (3,)) # Shape (512, 512, 3)
|
90 |
+
hsv_image[..., 2] = 1 # Set Value to 1 for a white base
|
91 |
+
|
92 |
+
# Set the hue, saturation, and value channels for each class color
|
93 |
+
for i, color in enumerate(classes_colors):
|
94 |
+
rgb_color = np.array(color).reshape(1, 1, 3)
|
95 |
+
hsv_color = rgb_to_hsv(rgb_color)
|
96 |
+
mask = class_indices == i
|
97 |
+
if i < len(classes): # For the first N-2 classes, set color based on similarity
|
98 |
+
hsv_image[..., 0][mask] = hsv_color[0, 0, 0] # Hue
|
99 |
+
hsv_image[..., 1][mask] = pixel_similarity_array[i][mask] > 0 # Saturation
|
100 |
+
hsv_image[..., 2][mask] = pixel_similarity_array[i][mask] # Value
|
101 |
+
else: # For the last two classes, set pixels to black
|
102 |
+
hsv_image[..., 0][mask] = 0 # Hue doesn't matter for black
|
103 |
+
hsv_image[..., 1][mask] = 0 # Saturation set to 0
|
104 |
+
hsv_image[..., 2][mask] = 0 # Value set to 0, making it black
|
105 |
+
|
106 |
+
mask_tensor_org = sketch2[:,:,0]/255
|
107 |
+
hsv_image[mask_tensor_org==1] = [0,0,1]
|
108 |
|
109 |
+
# Convert the HSV image back to RGB to display and save
|
110 |
+
rgb_image = hsv_to_rgb(hsv_image)
|
111 |
+
|
112 |
+
|
113 |
+
if len(classes) > 1:
|
114 |
+
# Calculate centroids and render class names
|
115 |
+
for i, class_name in enumerate(classes):
|
116 |
+
mask = class_indices == i
|
117 |
+
if np.any(mask):
|
118 |
+
y, x = np.nonzero(mask)
|
119 |
+
centroid_x, centroid_y = np.mean(x), np.mean(y)
|
120 |
+
plt.text(centroid_x, centroid_y, class_name, color=classes_colors[i], ha='center', va='center',fontsize=10, # color=classes_colors[i]
|
121 |
+
bbox=dict(facecolor='lightgrey', edgecolor='none', boxstyle='round,pad=0.2', alpha=0.8))
|
122 |
+
|
123 |
+
# Display the image with class names
|
124 |
+
plt.imshow(rgb_image)
|
125 |
+
plt.axis('off')
|
126 |
+
plt.tight_layout()
|
127 |
+
plt.savefig('output.png', bbox_inches='tight', pad_inches=0)
|
128 |
+
plt.close()
|
129 |
+
|
130 |
rgb_image = Image.open('output.png')
|
131 |
|
132 |
return rgb_image
|