ahmedbrs commited on
Commit
6fde90c
1 Parent(s): c27dea1

fix labels

Browse files
Files changed (1) hide show
  1. app.py +46 -2
app.py CHANGED
@@ -4,7 +4,7 @@ import torch
4
  from torchvision.transforms import InterpolationMode
5
 
6
  BICUBIC = InterpolationMode.BICUBIC
7
- from utils import setup, get_similarity_map, display_segmented_sketch,get_noun_phrase
8
  from vpt.launch import default_argument_parser
9
  from collections import OrderedDict
10
  import numpy as np
@@ -81,8 +81,52 @@ def run(sketch, caption, threshold, seed):
81
  pixel_similarity[pixel_similarity < threshold] = 0
82
  pixel_similarity_array = pixel_similarity.cpu().numpy().transpose(2, 0, 1)
83
 
84
- display_segmented_sketch(pixel_similarity_array, sketch2, classes, classes_colors, live=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  rgb_image = Image.open('output.png')
87
 
88
  return rgb_image
 
4
  from torchvision.transforms import InterpolationMode
5
 
6
  BICUBIC = InterpolationMode.BICUBIC
7
+ from utils import setup, get_similarity_map,get_noun_phrase, rgb_to_hsv, hsv_to_rgb
8
  from vpt.launch import default_argument_parser
9
  from collections import OrderedDict
10
  import numpy as np
 
81
  pixel_similarity[pixel_similarity < threshold] = 0
82
  pixel_similarity_array = pixel_similarity.cpu().numpy().transpose(2, 0, 1)
83
 
84
+ # display_segmented_sketch(pixel_similarity_array, sketch2, classes, classes_colors, live=True)
85
+
86
+ # Find the class index with the highest similarity for each pixel
87
+ class_indices = np.argmax(pixel_similarity_array, axis=0)
88
+ # Create an HSV image placeholder
89
+ hsv_image = np.zeros(class_indices.shape + (3,)) # Shape (512, 512, 3)
90
+ hsv_image[..., 2] = 1 # Set Value to 1 for a white base
91
+
92
+ # Set the hue and value channels
93
+ for i, color in enumerate(classes_colors):
94
+ rgb_color = np.array(color).reshape(1, 1, 3)
95
+ hsv_color = rgb_to_hsv(rgb_color)
96
+ mask = class_indices == i
97
+ if i < len(classes): # For the first N-2 classes, set color based on similarity
98
+ hsv_image[..., 0][mask] = hsv_color[0, 0, 0] # Hue
99
+ hsv_image[..., 1][mask] = pixel_similarity_array[i][mask] > 0 # Saturation
100
+ hsv_image[..., 2][mask] = pixel_similarity_array[i][mask] # Value
101
+ else: # For the last two classes, set pixels to black
102
+ hsv_image[..., 0][mask] = 0 # Hue doesn't matter for black
103
+ hsv_image[..., 1][mask] = 0 # Saturation set to 0
104
+ hsv_image[..., 2][mask] = 0 # Value set to 0, making it black
105
+
106
+ mask_tensor_org = sketch2[:,:,0]/255
107
+ hsv_image[mask_tensor_org==1] = [0,0,1]
108
 
109
+ # Convert the HSV image back to RGB to display and save
110
+ rgb_image = hsv_to_rgb(hsv_image)
111
+
112
+
113
+ if len(classes) > 1:
114
+ # Calculate centroids and render class names
115
+ for i, class_name in enumerate(classes):
116
+ mask = class_indices == i
117
+ if np.any(mask):
118
+ y, x = np.nonzero(mask)
119
+ centroid_x, centroid_y = np.mean(x), np.mean(y)
120
+ plt.text(centroid_x, centroid_y, class_name, color=classes_colors[i], ha='center', va='center',fontsize=10, # color=classes_colors[i]
121
+ bbox=dict(facecolor='lightgrey', edgecolor='none', boxstyle='round,pad=0.2', alpha=0.8))
122
+
123
+ # Display the image with class names
124
+ plt.imshow(rgb_image)
125
+ plt.axis('off')
126
+ plt.tight_layout()
127
+ plt.savefig('output.png', bbox_inches='tight', pad_inches=0)
128
+ plt.close()
129
+
130
  rgb_image = Image.open('output.png')
131
 
132
  return rgb_image