Denys Rozumnyi committed on
Commit fc034ff
1 Parent(s): c336685

Public release

dataset.py DELETED
@@ -1,88 +0,0 @@
-
-
- class ShapeNetDataset(data.Dataset):
-     def __init__(self,
-                  root,
-                  npoints=2500,
-                  classification=False,
-                  class_choice=None,
-                  split='train',
-                  data_augmentation=True):
-         self.npoints = npoints
-         self.root = root
-         self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
-         self.cat = {}
-         self.data_augmentation = data_augmentation
-         self.classification = classification
-         self.seg_classes = {}
-
-         with open(self.catfile, 'r') as f:
-             for line in f:
-                 ls = line.strip().split()
-                 self.cat[ls[0]] = ls[1]
-         #print(self.cat)
-         if not class_choice is None:
-             self.cat = {k: v for k, v in self.cat.items() if k in class_choice}
-
-         self.id2cat = {v: k for k, v in self.cat.items()}
-
-         self.meta = {}
-         splitfile = os.path.join(self.root, 'train_test_split', 'shuffled_{}_file_list.json'.format(split))
-         #from IPython import embed; embed()
-         filelist = json.load(open(splitfile, 'r'))
-         for item in self.cat:
-             self.meta[item] = []
-
-         for file in filelist:
-             _, category, uuid = file.split('/')
-             if category in self.cat.values():
-                 self.meta[self.id2cat[category]].append((os.path.join(self.root, category, 'points', uuid+'.pts'),
-                                                          os.path.join(self.root, category, 'points_label', uuid+'.seg')))
-
-         self.datapath = []
-         for item in self.cat:
-             for fn in self.meta[item]:
-                 self.datapath.append((item, fn[0], fn[1]))
-
-         self.classes = dict(zip(sorted(self.cat), range(len(self.cat))))
-         print(self.classes)
-         with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/num_seg_classes.txt'), 'r') as f:
-             for line in f:
-                 ls = line.strip().split()
-                 self.seg_classes[ls[0]] = int(ls[1])
-         self.num_seg_classes = self.seg_classes[list(self.cat.keys())[0]]
-         print(self.seg_classes, self.num_seg_classes)
-
-     def __getitem__(self, index):
-         fn = self.datapath[index]
-         cls = self.classes[self.datapath[index][0]]
-         point_set = np.loadtxt(fn[1]).astype(np.float32)
-         seg = np.loadtxt(fn[2]).astype(np.int64)
-         #print(point_set.shape, seg.shape)
-
-         choice = np.random.choice(len(seg), self.npoints, replace=True)
-         #resample
-         point_set = point_set[choice, :]
-
-         point_set = point_set - np.expand_dims(np.mean(point_set, axis = 0), 0) # center
-         dist = np.max(np.sqrt(np.sum(point_set ** 2, axis = 1)),0)
-         point_set = point_set / dist #scale
-
-         if self.data_augmentation:
-             theta = np.random.uniform(0,np.pi*2)
-             rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)],[np.sin(theta), np.cos(theta)]])
-             point_set[:,[0,2]] = point_set[:,[0,2]].dot(rotation_matrix) # random rotation
-             point_set += np.random.normal(0, 0.02, size=point_set.shape) # random jitter
-
-         seg = seg[choice]
-         point_set = torch.from_numpy(point_set)
-         seg = torch.from_numpy(seg)
-         cls = torch.from_numpy(np.array([cls]).astype(np.int64))
-
-         if self.classification:
-             return point_set, cls
-         else:
-             return point_set, seg
-
-     def __len__(self):
-         return len(self.datapath)
geom_solver.py CHANGED
@@ -1,7 +1,6 @@
  import numpy as np
  from pytorch3d.ops import ball_query
  from helpers import *
- from handcrafted_solution import convert_entry_to_human_readable
  import cv2
  import hoho
  import itertools
@@ -10,13 +9,14 @@ from pytorch3d.renderer import PerspectiveCameras
  from hoho.color_mappings import gestalt_color_mapping
  from PIL import Image

+
  def my_empty_solution():
-     return np.zeros((19,3)), [(0, 0)]
+     return np.zeros((20,3)), [(0, 0)]


- def cheat_the_metric_solution(vertices=None):
+ def fully_connected_solution(vertices=None):
      if vertices is None:
-         nverts = 19
+         nverts = 20
          vertices_new = np.zeros((nverts,3))
      else:
          nverts = vertices.shape[0]
@@ -31,7 +31,7 @@ class GeomSolver(object):

      def __init__(self):
          self.min_vertices = 10
-         self.mean_vertices = 19
+         self.mean_vertices = 20
          self.max_vertices = 30
          self.kmeans_th = 200
          self.point_dist_th = 50
@@ -41,7 +41,7 @@
          self.return_edges = False
          self.mean_fixed = False
          self.repeat_predicted = True
-         self.cheat_metric = True
+         self.return_fully_connected = True

      def cluster_points(self, point_types):
          point_colors = []
@@ -65,9 +65,6 @@
          vert_mask = (vert_mask > 0).astype(np.uint8)

          dist = cv2.distanceTransform(1-vert_mask, cv2.DIST_L2, 3)
-         # dist[dist > 100] = 100
-         # ndist = np.zeros_like(dist)
-         # ndist = cv2.normalize(dist, ndist, 0, 1.0, cv2.NORM_MINMAX)

          in_this_image = np.array([cki in p.image_ids for p in self.points3D.values()])
          uv = torch.round(self.pyt_cameras[ki].transform_points(self.verts)[:, :2]).cpu().numpy().astype(int)
@@ -130,16 +127,12 @@
          human_entry = self.human_entry

          col_cams = [hoho.Rt_to_eye_target(Image.new('RGB', (human_entry['cameras'][colmap_img.camera_id].width, human_entry['cameras'][colmap_img.camera_id].height)), to_K(*human_entry['cameras'][colmap_img.camera_id].params), quaternion_to_rotation_matrix(colmap_img.qvec), colmap_img.tvec) for colmap_img in human_entry['images'].values()]
-         # eye, target, up, fov = col_cams[0]

          cameras, images, self.points3D = human_entry['cameras'], human_entry['images'], human_entry['points3d']
          colmap_cameras_tf = list(human_entry['images'].keys())
          self.xyz = np.stack([p.xyz for p in self.points3D.values()])
          color = np.stack([p.rgb for p in self.points3D.values()])
          self.gests = [np.array(gest0) for gest0 in human_entry['gestalt']]
-         # for ki in range(1, len(self.gests)):
-         #     if self.gests[ki].shape != self.gests[0].shape:
-         #         self.gests[ki] = self.gests[ki].transpose(1,0,2)

          to_camera_ids = np.array([colmap_img.camera_id for colmap_img in human_entry['images'].values()])

@@ -183,16 +176,11 @@

          self.vertices = centers
          nvert = centers.shape[0]
-         # desired_vertices = (self.xyz[:,-1] > z_th).sum() // 300
          desired_vertices = int(2.2*nvert)
-         # desired_vertices = self.mean_vertices
          if desired_vertices < self.min_vertices:
              desired_vertices = self.mean_vertices
          if desired_vertices > self.max_vertices:
              desired_vertices = self.mean_vertices
-         # if self.broken_cams.any():
-         #     vertices = centers
-         #     print("There are broken cams.")
          if nvert >= desired_vertices:
              vertices = centers[:desired_vertices]
              print("Enough vertices.")
@@ -248,8 +236,8 @@
          uvs.append(uv)

          edges = []
-         # thresholds_min_mean = {0 : [5, 7], 1 : [9, 25], 2: [30, 1000]}
-         thresholds_min_mean = {0 : [1, 7], 1 : [3, 25], 2: [3, 1000]}
+         thresholds_min_mean = {0 : [5, 7], 1 : [9, 25], 2: [30, 1000]}
+         # thresholds_min_mean = {0 : [1, 7], 1 : [3, 25], 2: [3, 1000]}
          for i in range(pyt_centers.shape[0]):
              for j in range(i+1, pyt_centers.shape[0]):
                  etype = (self.is_apex[i] + self.is_apex[j])
@@ -298,12 +286,10 @@
          else:
              edges = [(0, 0)]

-         if self.cheat_metric:
-             dumb_vertices = np.zeros((vertices.shape[0],3))
-             # dumb_vertices = self.wf_center[None].repeat(vertices.shape[0], axis=0)
-             vertices, edges = cheat_the_metric_solution(dumb_vertices)
-             # vertices_new, edges = cheat_the_metric_solution(np.zeros((vertices.shape[0] // 2,3)))
-             # vertices = np.concatenate((vertices_new, vertices[:vertices_new.shape[0]]))
+         if self.return_fully_connected:
+             zero_vertices = np.zeros((vertices.shape[0],3))
+             # zero_vertices = self.wf_center[None].repeat(vertices.shape[0], axis=0)
+             vertices, edges = fully_connected_solution(zero_vertices)

          if visualize:
              from hoho.viz3d import plot_estimate_and_gt
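
Note on the renamed fallback: the hunks above show only the head of fully_connected_solution, so the sketch below is a reconstruction, not the released code. The 20-vertex default and the zero-filled coordinates come from the diff; the all-pairs edge list is an assumption about the part of the body that is not shown.

import itertools
import numpy as np

def fully_connected_solution_sketch(vertices=None, nverts=20):
    # Hypothetical reconstruction: zero-filled vertices unless some are supplied,
    # plus an edge between every pair of vertex indices (a fully connected wireframe).
    if vertices is None:
        vertices = np.zeros((nverts, 3))
    edges = list(itertools.combinations(range(vertices.shape[0]), 2))
    return vertices, edges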
handcrafted_solution.py DELETED
@@ -1,245 +0,0 @@
- # Description: This file contains the handcrafted solution for the task of wireframe reconstruction
-
- import io
- from PIL import Image as PImage
- import numpy as np
- from collections import defaultdict
- import cv2
- from typing import Tuple, List
- from scipy.spatial.distance import cdist
-
- from hoho.read_write_colmap import read_cameras_binary, read_images_binary, read_points3D_binary
- from hoho.color_mappings import gestalt_color_mapping, ade20k_color_mapping
-
-
- def empty_solution():
-     '''Return a minimal valid solution, i.e. 2 vertices and 1 edge.'''
-     return np.zeros((2,3)), [(0, 1)]
-
-
- def convert_entry_to_human_readable(entry):
-     out = {}
-     already_good = ['__key__', 'wf_vertices', 'wf_edges', 'edge_semantics', 'mesh_vertices', 'mesh_faces', 'face_semantics', 'K', 'R', 't']
-     for k, v in entry.items():
-         if k in already_good:
-             out[k] = v
-             continue
-         if k == 'points3d':
-             out[k] = read_points3D_binary(fid=io.BytesIO(v))
-         if k == 'cameras':
-             out[k] = read_cameras_binary(fid=io.BytesIO(v))
-         if k == 'images':
-             out[k] = read_images_binary(fid=io.BytesIO(v))
-         if k in ['ade20k', 'gestalt']:
-             out[k] = [PImage.open(io.BytesIO(x)).convert('RGB') for x in v]
-         if k == 'depthcm':
-             out[k] = [PImage.open(io.BytesIO(x)) for x in entry['depthcm']]
-     return out
-
-
- def get_vertices_and_edges_from_segmentation(gest_seg_np, edge_th = 50.0):
-     '''Get the vertices and edges from the gestalt segmentation mask of the house'''
-     vertices = []
-     connections = []
-     # Apex
-     apex_color = np.array(gestalt_color_mapping['apex'])
-     apex_mask = cv2.inRange(gest_seg_np, apex_color-0.5, apex_color+0.5)
-     if apex_mask.sum() > 0:
-         output = cv2.connectedComponentsWithStats(apex_mask, 8, cv2.CV_32S)
-         (numLabels, labels, stats, centroids) = output
-         stats, centroids = stats[1:], centroids[1:]
-
-         for i in range(numLabels-1):
-             vert = {"xy": centroids[i], "type": "apex"}
-             vertices.append(vert)
-
-     eave_end_color = np.array(gestalt_color_mapping['eave_end_point'])
-     eave_end_mask = cv2.inRange(gest_seg_np, eave_end_color-0.5, eave_end_color+0.5)
-     if eave_end_mask.sum() > 0:
-         output = cv2.connectedComponentsWithStats(eave_end_mask, 8, cv2.CV_32S)
-         (numLabels, labels, stats, centroids) = output
-         stats, centroids = stats[1:], centroids[1:]
-
-         for i in range(numLabels-1):
-             vert = {"xy": centroids[i], "type": "eave_end_point"}
-             vertices.append(vert)
-     # Connectivity
-     apex_pts = []
-     apex_pts_idxs = []
-     for j, v in enumerate(vertices):
-         apex_pts.append(v['xy'])
-         apex_pts_idxs.append(j)
-     apex_pts = np.array(apex_pts)
-
-     # Ridge connects two apex points
-     for edge_class in ['eave', 'ridge', 'rake', 'valley']:
-         edge_color = np.array(gestalt_color_mapping[edge_class])
-         mask = cv2.morphologyEx(cv2.inRange(gest_seg_np,
-                                             edge_color-0.5,
-                                             edge_color+0.5),
-                                 cv2.MORPH_DILATE, np.ones((11, 11)))
-         line_img = np.copy(gest_seg_np) * 0
-         if mask.sum() > 0:
-             output = cv2.connectedComponentsWithStats(mask, 8, cv2.CV_32S)
-             (numLabels, labels, stats, centroids) = output
-             stats, centroids = stats[1:], centroids[1:]
-             edges = []
-             for i in range(1, numLabels):
-                 y,x = np.where(labels == i)
-                 xleft_idx = np.argmin(x)
-                 x_left = x[xleft_idx]
-                 y_left = y[xleft_idx]
-                 xright_idx = np.argmax(x)
-                 x_right = x[xright_idx]
-                 y_right = y[xright_idx]
-                 edges.append((x_left, y_left, x_right, y_right))
-                 cv2.line(line_img, (x_left, y_left), (x_right, y_right), (255, 255, 255), 2)
-             edges = np.array(edges)
-             if (len(apex_pts) < 2) or len(edges) <1:
-                 continue
-             pts_to_edges_dist = np.minimum(cdist(apex_pts, edges[:,:2]), cdist(apex_pts, edges[:,2:]))
-             connectivity_mask = pts_to_edges_dist <= edge_th
-             edge_connects = connectivity_mask.sum(axis=0)
-             for edge_idx, edgesum in enumerate(edge_connects):
-                 if edgesum>=2:
-                     connected_verts = np.where(connectivity_mask[:,edge_idx])[0]
-                     for a_i, a in enumerate(connected_verts):
-                         for b in connected_verts[a_i+1:]:
-                             connections.append((a, b))
-     return vertices, connections
-
- def get_uv_depth(vertices, depth):
-     '''Get the depth of the vertices from the depth image'''
-     uv = []
-     for v in vertices:
-         uv.append(v['xy'])
-     uv = np.array(uv)
-     uv_int = uv.astype(np.int32)
-     H, W = depth.shape[:2]
-     uv_int[:, 0] = np.clip( uv_int[:, 0], 0, W-1)
-     uv_int[:, 1] = np.clip( uv_int[:, 1], 0, H-1)
-     vertex_depth = depth[(uv_int[:, 1] , uv_int[:, 0])]
-     return uv, vertex_depth
-
-
- def merge_vertices_3d(vert_edge_per_image, th=0.1):
-     '''Merge vertices that are close to each other in 3D space and are of same types'''
-     all_3d_vertices = []
-     connections_3d = []
-     all_indexes = []
-     cur_start = 0
-     types = []
-     for cimg_idx, (vertices, connections, vertices_3d) in vert_edge_per_image.items():
-         types += [int(v['type']=='apex') for v in vertices]
-         all_3d_vertices.append(vertices_3d)
-         connections_3d+=[(x+cur_start,y+cur_start) for (x,y) in connections]
-         cur_start+=len(vertices_3d)
-     all_3d_vertices = np.concatenate(all_3d_vertices, axis=0)
-     #print (connections_3d)
-     distmat = cdist(all_3d_vertices, all_3d_vertices)
-     types = np.array(types).reshape(-1,1)
-     same_types = cdist(types, types)
-     mask_to_merge = (distmat <= th) & (same_types==0)
-     new_vertices = []
-     new_connections = []
-     to_merge = sorted(list(set([tuple(a.nonzero()[0].tolist()) for a in mask_to_merge])))
-     to_merge_final = defaultdict(list)
-     for i in range(len(all_3d_vertices)):
-         for j in to_merge:
-             if i in j:
-                 to_merge_final[i]+=j
-     for k, v in to_merge_final.items():
-         to_merge_final[k] = list(set(v))
-     already_there = set()
-     merged = []
-     for k, v in to_merge_final.items():
-         if k in already_there:
-             continue
-         merged.append(v)
-         for vv in v:
-             already_there.add(vv)
-     old_idx_to_new = {}
-     count=0
-     for idxs in merged:
-         new_vertices.append(all_3d_vertices[idxs].mean(axis=0))
-         for idx in idxs:
-             old_idx_to_new[idx] = count
-         count +=1
-     #print (connections_3d)
-     new_vertices=np.array(new_vertices)
-     #print (connections_3d)
-     for conn in connections_3d:
-         new_con = sorted((old_idx_to_new[conn[0]], old_idx_to_new[conn[1]]))
-         if new_con[0] == new_con[1]:
-             continue
-         if new_con not in new_connections:
-             new_connections.append(new_con)
-     #print (f'{len(new_vertices)} left after merging {len(all_3d_vertices)} with {th=}')
-     return new_vertices, new_connections
-
- def prune_not_connected(all_3d_vertices, connections_3d):
-     '''Prune vertices that are not connected to any other vertex'''
-     connected = defaultdict(list)
-     for c in connections_3d:
-         connected[c[0]].append(c)
-         connected[c[1]].append(c)
-     new_indexes = {}
-     new_verts = []
-     connected_out = []
-     for k,v in connected.items():
-         vert = all_3d_vertices[k]
-         if tuple(vert) not in new_verts:
-             new_verts.append(tuple(vert))
-             new_indexes[k]=len(new_verts) -1
-     for k,v in connected.items():
-         for vv in v:
-             connected_out.append((new_indexes[vv[0]],new_indexes[vv[1]]))
-     connected_out=list(set(connected_out))
-
-     return np.array(new_verts), connected_out
-
-
- def predict(entry, visualize=False) -> Tuple[np.ndarray, List[int]]:
-     good_entry = convert_entry_to_human_readable(entry)
-     vert_edge_per_image = {}
-     for i, (gest, depth, K, R, t) in enumerate(zip(good_entry['gestalt'],
-                                                    good_entry['depthcm'],
-                                                    good_entry['K'],
-                                                    good_entry['R'],
-                                                    good_entry['t']
-                                                    )):
-         gest_seg = gest.resize(depth.size)
-         gest_seg_np = np.array(gest_seg).astype(np.uint8)
-         # Metric3D
-         depth_np = np.array(depth) / 2.5 # 2.5 is the scale estimation coefficient
-         vertices, connections = get_vertices_and_edges_from_segmentation(gest_seg_np, edge_th = 20.)
-         if (len(vertices) < 2) or (len(connections) < 1):
-             print (f'Not enough vertices or connections in image {i}')
-             vert_edge_per_image[i] = np.empty((0, 2)), [], np.empty((0, 3))
-             continue
-         uv, depth_vert = get_uv_depth(vertices, depth_np)
-         # Normalize the uv to the camera intrinsics
-         xy_local = np.ones((len(uv), 3))
-         xy_local[:, 0] = (uv[:, 0] - K[0,2]) / K[0,0]
-         xy_local[:, 1] = (uv[:, 1] - K[1,2]) / K[1,1]
-         # Get the 3D vertices
-         vertices_3d_local = depth_vert[...,None] * (xy_local/np.linalg.norm(xy_local, axis=1)[...,None])
-         world_to_cam = np.eye(4)
-         world_to_cam[:3, :3] = R
-         world_to_cam[:3, 3] = t.reshape(-1)
-         cam_to_world = np.linalg.inv(world_to_cam)
-         vertices_3d = cv2.transform(cv2.convertPointsToHomogeneous(vertices_3d_local), cam_to_world)
-         vertices_3d = cv2.convertPointsFromHomogeneous(vertices_3d).reshape(-1, 3)
-         vert_edge_per_image[i] = vertices, connections, vertices_3d
-     all_3d_vertices, connections_3d = merge_vertices_3d(vert_edge_per_image, 3.0)
-     all_3d_vertices_clean, connections_3d_clean = prune_not_connected(all_3d_vertices, connections_3d)
-     if (len(all_3d_vertices_clean) < 2) or len(connections_3d_clean) < 1:
-         print (f'Not enough vertices or connections in the 3D vertices')
-         return (good_entry['__key__'], *empty_solution())
-     if visualize:
-         from hoho.viz3d import plot_estimate_and_gt
-         plot_estimate_and_gt( all_3d_vertices_clean,
-                               connections_3d_clean,
-                               good_entry['wf_vertices'],
-                               good_entry['wf_edges'])
-     return good_entry['__key__'], all_3d_vertices_clean, connections_3d_clean
helpers.py CHANGED
@@ -3,6 +3,27 @@ from PIL import Image as PImage
  import io
  from scipy.spatial.distance import cdist
  from scipy.optimize import linear_sum_assignment
+ from hoho.read_write_colmap import read_cameras_binary, read_images_binary, read_points3D_binary
+
+
+ def convert_entry_to_human_readable(entry):
+     out = {}
+     already_good = ['__key__', 'wf_vertices', 'wf_edges', 'edge_semantics', 'mesh_vertices', 'mesh_faces', 'face_semantics', 'K', 'R', 't']
+     for k, v in entry.items():
+         if k in already_good:
+             out[k] = v
+             continue
+         if k == 'points3d':
+             out[k] = read_points3D_binary(fid=io.BytesIO(v))
+         if k == 'cameras':
+             out[k] = read_cameras_binary(fid=io.BytesIO(v))
+         if k == 'images':
+             out[k] = read_images_binary(fid=io.BytesIO(v))
+         if k in ['ade20k', 'gestalt']:
+             out[k] = [PImage.open(io.BytesIO(x)).convert('RGB') for x in v]
+         if k == 'depthcm':
+             out[k] = [PImage.open(io.BytesIO(x)) for x in entry['depthcm']]
+     return out


  def to_K(f, cx, cy):
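
With convert_entry_to_human_readable now living in helpers.py, a typical decode step looks like the sketch below. It reuses the hoho calls that appear in script_cpus.py (removed later in this commit); taking only the first sample is purely illustrative.

import hoho; hoho.setup()  # hoho.setup() must be called before anything else
from helpers import convert_entry_to_human_readable

dataset = hoho.get_dataset(decode=None, split='all', dataset_type='webdataset')
entry = next(iter(dataset))                            # one raw webdataset sample
human_entry = convert_entry_to_human_readable(entry)   # COLMAP structures, PIL images, arrays
print(sorted(human_entry.keys()))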
testing.ipynb → main.ipynb RENAMED
The diff for this file is too large to render. See raw diff
 
my_solution.py CHANGED
@@ -9,43 +9,17 @@ from scipy.spatial.distance import cdist
  from hoho.read_write_colmap import read_cameras_binary, read_images_binary, read_points3D_binary
  from hoho.color_mappings import gestalt_color_mapping, ade20k_color_mapping

- from geom_solver import GeomSolver, my_empty_solution, cheat_the_metric_solution
-
-
- def convert_entry_to_human_readable(entry):
-     out = {}
-     already_good = ['__key__', 'wf_vertices', 'wf_edges', 'edge_semantics', 'mesh_vertices', 'mesh_faces', 'face_semantics', 'K', 'R', 't']
-     for k, v in entry.items():
-         if k in already_good:
-             out[k] = v
-             continue
-         if k == 'points3d':
-             out[k] = read_points3D_binary(fid=io.BytesIO(v))
-         if k == 'cameras':
-             out[k] = read_cameras_binary(fid=io.BytesIO(v))
-         if k == 'images':
-             out[k] = read_images_binary(fid=io.BytesIO(v))
-         if k in ['ade20k', 'gestalt']:
-             out[k] = [PImage.open(io.BytesIO(x)).convert('RGB') for x in v]
-         if k == 'depthcm':
-             out[k] = [PImage.open(io.BytesIO(x)) for x in entry['depthcm']]
-     return out
+ from geom_solver import GeomSolver, my_empty_solution, fully_connected_solution


  def predict(entry, visualize=False) -> Tuple[np.ndarray, List[int]]:
-     # return (entry['__key__'], *my_empty_solution())
      vertices0, edges0 = my_empty_solution()
      try:
          vertices, edges = GeomSolver().solve(entry)
      except:
          print('ERROR')
-         # vertices, edges = vertices0, edges0
-         vertices, edges = cheat_the_metric_solution()
+         vertices, edges = fully_connected_solution()

-     # if vertices.shape[0] < vertices0.shape[0]:
-     #     verts_new = vertices0
-     #     verts_new[:vertices.shape[0]] = vertices
-     #     vertices = verts_new

      if (len(edges) < 1) and (len(vertices) >= 2):
          # print("Added only edges")
pointnet.py DELETED
@@ -1,213 +0,0 @@
- from __future__ import print_function
- import torch
- import torch.nn as nn
- import torch.nn.parallel
- import torch.utils.data
- from torch.autograd import Variable
- import numpy as np
- import torch.nn.functional as F
-
-
- class STN3d(nn.Module):
-     def __init__(self):
-         super(STN3d, self).__init__()
-         self.conv1 = torch.nn.Conv1d(3, 64, 1)
-         self.conv2 = torch.nn.Conv1d(64, 128, 1)
-         self.conv3 = torch.nn.Conv1d(128, 1024, 1)
-         self.fc1 = nn.Linear(1024, 512)
-         self.fc2 = nn.Linear(512, 256)
-         self.fc3 = nn.Linear(256, 9)
-         self.relu = nn.ReLU()
-
-         self.bn1 = nn.BatchNorm1d(64)
-         self.bn2 = nn.BatchNorm1d(128)
-         self.bn3 = nn.BatchNorm1d(1024)
-         self.bn4 = nn.BatchNorm1d(512)
-         self.bn5 = nn.BatchNorm1d(256)
-
-
-     def forward(self, x):
-         batchsize = x.size()[0]
-         x = F.relu(self.bn1(self.conv1(x)))
-         x = F.relu(self.bn2(self.conv2(x)))
-         x = F.relu(self.bn3(self.conv3(x)))
-         x = torch.max(x, 2, keepdim=True)[0]
-         x = x.view(-1, 1024)
-
-         x = F.relu(self.bn4(self.fc1(x)))
-         x = F.relu(self.bn5(self.fc2(x)))
-         x = self.fc3(x)
-
-         iden = Variable(torch.from_numpy(np.array([1,0,0,0,1,0,0,0,1]).astype(np.float32))).view(1,9).repeat(batchsize,1)
-         if x.is_cuda:
-             iden = iden.cuda()
-         x = x + iden
-         x = x.view(-1, 3, 3)
-         return x
-
-
- class STNkd(nn.Module):
-     def __init__(self, k=64):
-         super(STNkd, self).__init__()
-         self.conv1 = torch.nn.Conv1d(k, 64, 1)
-         self.conv2 = torch.nn.Conv1d(64, 128, 1)
-         self.conv3 = torch.nn.Conv1d(128, 1024, 1)
-         self.fc1 = nn.Linear(1024, 512)
-         self.fc2 = nn.Linear(512, 256)
-         self.fc3 = nn.Linear(256, k*k)
-         self.relu = nn.ReLU()
-
-         self.bn1 = nn.BatchNorm1d(64)
-         self.bn2 = nn.BatchNorm1d(128)
-         self.bn3 = nn.BatchNorm1d(1024)
-         self.bn4 = nn.BatchNorm1d(512)
-         self.bn5 = nn.BatchNorm1d(256)
-
-         self.k = k
-
-     def forward(self, x):
-         batchsize = x.size()[0]
-         x = F.relu(self.bn1(self.conv1(x)))
-         x = F.relu(self.bn2(self.conv2(x)))
-         x = F.relu(self.bn3(self.conv3(x)))
-         x = torch.max(x, 2, keepdim=True)[0]
-         x = x.view(-1, 1024)
-
-         x = F.relu(self.bn4(self.fc1(x)))
-         x = F.relu(self.bn5(self.fc2(x)))
-         x = self.fc3(x)
-
-         iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1,self.k*self.k).repeat(batchsize,1)
-         if x.is_cuda:
-             iden = iden.cuda()
-         x = x + iden
-         x = x.view(-1, self.k, self.k)
-         return x
-
- class PointNetfeat(nn.Module):
-     def __init__(self, global_feat = True, feature_transform = False):
-         super(PointNetfeat, self).__init__()
-         self.stn = STN3d()
-         self.conv1 = torch.nn.Conv1d(3, 64, 1)
-         self.conv2 = torch.nn.Conv1d(64, 128, 1)
-         self.conv3 = torch.nn.Conv1d(128, 1024, 1)
-         self.bn1 = nn.BatchNorm1d(64)
-         self.bn2 = nn.BatchNorm1d(128)
-         self.bn3 = nn.BatchNorm1d(1024)
-         self.global_feat = global_feat
-         self.feature_transform = feature_transform
-         if self.feature_transform:
-             self.fstn = STNkd(k=64)
-
-     def forward(self, x):
-         n_pts = x.size()[2]
-         trans = self.stn(x)
-         x = x.transpose(2, 1)
-         x = torch.bmm(x, trans)
-         x = x.transpose(2, 1)
-         x = F.relu(self.bn1(self.conv1(x)))
-
-         if self.feature_transform:
-             trans_feat = self.fstn(x)
-             x = x.transpose(2,1)
-             x = torch.bmm(x, trans_feat)
-             x = x.transpose(2,1)
-         else:
-             trans_feat = None
-
-         pointfeat = x
-         x = F.relu(self.bn2(self.conv2(x)))
-         x = self.bn3(self.conv3(x))
-         x = torch.max(x, 2, keepdim=True)[0]
-         x = x.view(-1, 1024)
-         if self.global_feat:
-             return x, trans, trans_feat
-         else:
-             x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
-             return torch.cat([x, pointfeat], 1), trans, trans_feat
-
- class PointNetCls(nn.Module):
-     def __init__(self, k=2, feature_transform=False):
-         super(PointNetCls, self).__init__()
-         self.feature_transform = feature_transform
-         self.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform)
-         self.fc1 = nn.Linear(1024, 512)
-         self.fc2 = nn.Linear(512, 256)
-         self.fc3 = nn.Linear(256, k)
-         self.dropout = nn.Dropout(p=0.3)
-         self.bn1 = nn.BatchNorm1d(512)
-         self.bn2 = nn.BatchNorm1d(256)
-         self.relu = nn.ReLU()
-
-     def forward(self, x):
-         x, trans, trans_feat = self.feat(x)
-         x = F.relu(self.bn1(self.fc1(x)))
-         x = F.relu(self.bn2(self.dropout(self.fc2(x))))
-         x = self.fc3(x)
-         return F.log_softmax(x, dim=1), trans, trans_feat
-
-
- class PointNetDenseCls(nn.Module):
-     def __init__(self, k = 2, feature_transform=False):
-         super(PointNetDenseCls, self).__init__()
-         self.k = k
-         self.feature_transform=feature_transform
-         self.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform)
-         self.conv1 = torch.nn.Conv1d(1088, 512, 1)
-         self.conv2 = torch.nn.Conv1d(512, 256, 1)
-         self.conv3 = torch.nn.Conv1d(256, 128, 1)
-         self.conv4 = torch.nn.Conv1d(128, self.k, 1)
-         self.bn1 = nn.BatchNorm1d(512)
-         self.bn2 = nn.BatchNorm1d(256)
-         self.bn3 = nn.BatchNorm1d(128)
-
-     def forward(self, x):
-         batchsize = x.size()[0]
-         n_pts = x.size()[2]
-         x, trans, trans_feat = self.feat(x)
-         x = F.relu(self.bn1(self.conv1(x)))
-         x = F.relu(self.bn2(self.conv2(x)))
-         x = F.relu(self.bn3(self.conv3(x)))
-         x = self.conv4(x)
-         x = x.transpose(2,1).contiguous()
-         x = F.log_softmax(x.view(-1,self.k), dim=-1)
-         x = x.view(batchsize, n_pts, self.k)
-         return x, trans, trans_feat
-
- def feature_transform_regularizer(trans):
-     d = trans.size()[1]
-     batchsize = trans.size()[0]
-     I = torch.eye(d)[None, :, :]
-     if trans.is_cuda:
-         I = I.cuda()
-     loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2,1)) - I, dim=(1,2)))
-     return loss
-
- if __name__ == '__main__':
-     sim_data = Variable(torch.rand(32,3,2500))
-     trans = STN3d()
-     out = trans(sim_data)
-     print('stn', out.size())
-     print('loss', feature_transform_regularizer(out))
-
-     sim_data_64d = Variable(torch.rand(32, 64, 2500))
-     trans = STNkd(k=64)
-     out = trans(sim_data_64d)
-     print('stn64d', out.size())
-     print('loss', feature_transform_regularizer(out))
-
-     pointfeat = PointNetfeat(global_feat=True)
-     out, _, _ = pointfeat(sim_data)
-     print('global feat', out.size())
-
-     pointfeat = PointNetfeat(global_feat=False)
-     out, _, _ = pointfeat(sim_data)
-     print('point feat', out.size())
-
-     cls = PointNetCls(k = 5)
-     out, _, _ = cls(sim_data)
-     print('class', out.size())
-
-     seg = PointNetDenseCls(k = 3)
-     out, _, _ = seg(sim_data)
-     print('seg', out.size())
script_cpus.py DELETED
@@ -1,145 +0,0 @@
- ### This is example of the script that will be run in the test environment.
- ### Some parts of the code are compulsory and you should NOT CHANGE THEM.
- ### They are between '''---compulsory---''' comments.
- ### You can change the rest of the code to define and test your solution.
- ### However, you should not change the signature of the provided function.
- ### The script would save "submission.parquet" file in the current directory.
- ### The actual logic of the solution is implemented in the `handcrafted_solution.py` file.
- ### The `handcrafted_solution.py` file is a placeholder for your solution.
- ### You should implement the logic of your solution in that file.
- ### You can use any additional files and subdirectories to organize your code.
-
- '''---compulsory---'''
- # import subprocess
- # from pathlib import Path
- # def install_package_from_local_file(package_name, folder='packages'):
- #     """
- #     Installs a package from a local .whl file or a directory containing .whl files using pip.
-
- #     Parameters:
- #     path_to_file_or_directory (str): The path to the .whl file or the directory containing .whl files.
- #     """
- #     try:
- #         pth = str(Path(folder) / package_name)
- #         subprocess.check_call([subprocess.sys.executable, "-m", "pip", "install",
- #                                "--no-index", # Do not use package index
- #                                "--find-links", pth, # Look for packages in the specified directory or at the file
- #                                package_name]) # Specify the package to install
- #         print(f"Package installed successfully from {pth}")
- #     except subprocess.CalledProcessError as e:
- #         print(f"Failed to install package from {pth}. Error: {e}")
-
- # install_package_from_local_file('hoho')
-
- import hoho; hoho.setup() # YOU MUST CALL hoho.setup() BEFORE ANYTHING ELSE
- # import subprocess
- # import importlib
- # from pathlib import Path
- # import subprocess
-
-
- # ### The function below is useful for installing additional python wheels.
- # def install_package_from_local_file(package_name, folder='packages'):
- #     """
- #     Installs a package from a local .whl file or a directory containing .whl files using pip.
-
- #     Parameters:
- #     path_to_file_or_directory (str): The path to the .whl file or the directory containing .whl files.
- #     """
- #     try:
- #         pth = str(Path(folder) / package_name)
- #         subprocess.check_call([subprocess.sys.executable, "-m", "pip", "install",
- #                                "--no-index", # Do not use package index
- #                                "--find-links", pth, # Look for packages in the specified directory or at the file
- #                                package_name]) # Specify the package to install
- #         print(f"Package installed successfully from {pth}")
- #     except subprocess.CalledProcessError as e:
- #         print(f"Failed to install package from {pth}. Error: {e}")
-
-
- # pip download webdataset -d packages/webdataset --platform manylinux1_x86_64 --python-version 38 --only-binary=:all:
- # install_package_from_local_file('webdataset')
- # install_package_from_local_file('tqdm')
-
- ### Here you can import any library or module you want.
- ### The code below is used to read and parse the input dataset.
- ### Please, do not modify it.
-
- import webdataset as wds
- from tqdm import tqdm
- from typing import Dict
- import pandas as pd
- from transformers import AutoTokenizer
- import os
- import time
- import io
- from PIL import Image as PImage
- import numpy as np
-
- from hoho.read_write_colmap import read_cameras_binary, read_images_binary, read_points3D_binary
- from hoho import proc, Sample
-
- def convert_entry_to_human_readable(entry):
-     out = {}
-     already_good = ['__key__', 'wf_vertices', 'wf_edges', 'edge_semantics', 'mesh_vertices', 'mesh_faces', 'face_semantics', 'K', 'R', 't']
-     for k, v in entry.items():
-         if k in already_good:
-             out[k] = v
-             continue
-         if k == 'points3d':
-             out[k] = read_points3D_binary(fid=io.BytesIO(v))
-         if k == 'cameras':
-             out[k] = read_cameras_binary(fid=io.BytesIO(v))
-         if k == 'images':
-             out[k] = read_images_binary(fid=io.BytesIO(v))
-         if k in ['ade20k', 'gestalt']:
-             out[k] = [PImage.open(io.BytesIO(x)).convert('RGB') for x in v]
-         if k == 'depthcm':
-             out[k] = [PImage.open(io.BytesIO(x)) for x in entry['depthcm']]
-     return out
-
- '''---end of compulsory---'''
-
- ### The part below is used to define and test your solution.
-
- from pathlib import Path
- def save_submission(submission, path):
-     """
-     Saves the submission to a specified path.
-
-     Parameters:
-     submission (List[Dict[]]): The submission to save.
-     path (str): The path to save the submission to.
-     """
-     sub = pd.DataFrame(submission, columns=["__key__", "wf_vertices", "wf_edges"])
-     sub.to_parquet(path)
-     print(f"Submission saved to {path}")
-
- if __name__ == "__main__":
-     from my_solution import predict
-     print ("------------ Loading dataset------------ ")
-     params = hoho.get_params()
-     dataset = hoho.get_dataset(decode=None, split='all', dataset_type='webdataset')
-
-     print('------------ Now you can do your solution ---------------')
-     solution = []
-     from concurrent.futures import ProcessPoolExecutor
-     with ProcessPoolExecutor(max_workers=8) as pool:
-         results = []
-         for i, sample in enumerate(tqdm(dataset)):
-             results.append(pool.submit(predict, sample, visualize=False))
-
-         for i, result in enumerate(tqdm(results)):
-             key, pred_vertices, pred_edges = result.result()
-             solution.append({
-                 '__key__': key,
-                 'wf_vertices': pred_vertices.tolist(),
-                 'wf_edges': pred_edges
-             })
-             if i % 100 == 0:
-                 # incrementally save the results in case we run out of time
-                 print(f"Processed {i} samples")
-                 # save_submission(solution, Path(params['output_path']) / "submission.parquet")
-     print('------------ Saving results ---------------')
-     save_submission(solution, Path(params['output_path']) / "submission.parquet")
-     print("------------ Done ------------ ")
train_pointnet.py DELETED
@@ -1,148 +0,0 @@
- from __future__ import print_function
- import argparse
- import os
- import random
- import torch
- import torch.nn.parallel
- import torch.optim as optim
- import torch.utils.data
- from pointnet.dataset import ShapeNetDataset, ModelNetDataset
- from pointnet import PointNetCls, feature_transform_regularizer
- import torch.nn.functional as F
- from tqdm import tqdm
-
-
- parser = argparse.ArgumentParser()
- parser.add_argument(
-     '--batchSize', type=int, default=32, help='input batch size')
- parser.add_argument(
-     '--num_points', type=int, default=2500, help='input batch size')
- parser.add_argument(
-     '--workers', type=int, help='number of data loading workers', default=4)
- parser.add_argument(
-     '--nepoch', type=int, default=250, help='number of epochs to train for')
- parser.add_argument('--outf', type=str, default='cls', help='output folder')
- parser.add_argument('--model', type=str, default='', help='model path')
- parser.add_argument('--dataset', type=str, required=True, help="dataset path")
- parser.add_argument('--dataset_type', type=str, default='shapenet', help="dataset type shapenet|modelnet40")
- parser.add_argument('--feature_transform', action='store_true', help="use feature transform")
-
- opt = parser.parse_args()
- print(opt)
-
- blue = lambda x: '\033[94m' + x + '\033[0m'
-
- opt.manualSeed = random.randint(1, 10000) # fix seed
- print("Random Seed: ", opt.manualSeed)
- random.seed(opt.manualSeed)
- torch.manual_seed(opt.manualSeed)
-
- if opt.dataset_type == 'shapenet':
-     dataset = ShapeNetDataset(
-         root=opt.dataset,
-         classification=True,
-         npoints=opt.num_points)
-
-     test_dataset = ShapeNetDataset(
-         root=opt.dataset,
-         classification=True,
-         split='test',
-         npoints=opt.num_points,
-         data_augmentation=False)
- elif opt.dataset_type == 'modelnet40':
-     dataset = ModelNetDataset(
-         root=opt.dataset,
-         npoints=opt.num_points,
-         split='trainval')
-
-     test_dataset = ModelNetDataset(
-         root=opt.dataset,
-         split='test',
-         npoints=opt.num_points,
-         data_augmentation=False)
- else:
-     exit('wrong dataset type')
-
-
- dataloader = torch.utils.data.DataLoader(
-     dataset,
-     batch_size=opt.batchSize,
-     shuffle=True,
-     num_workers=int(opt.workers))
-
- testdataloader = torch.utils.data.DataLoader(
-     test_dataset,
-     batch_size=opt.batchSize,
-     shuffle=True,
-     num_workers=int(opt.workers))
-
- print(len(dataset), len(test_dataset))
- num_classes = len(dataset.classes)
- print('classes', num_classes)
-
- try:
-     os.makedirs(opt.outf)
- except OSError:
-     pass
-
- classifier = PointNetCls(k=num_classes, feature_transform=opt.feature_transform)
-
- if opt.model != '':
-     classifier.load_state_dict(torch.load(opt.model))
-
-
- optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999))
- scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
- classifier.cuda()
-
- num_batch = len(dataset) / opt.batchSize
-
- for epoch in range(opt.nepoch):
-     scheduler.step()
-     for i, data in enumerate(dataloader, 0):
-         points, target = data
-         target = target[:, 0]
-         points = points.transpose(2, 1)
-         points, target = points.cuda(), target.cuda()
-         optimizer.zero_grad()
-         classifier = classifier.train()
-         pred, trans, trans_feat = classifier(points)
-         loss = F.nll_loss(pred, target)
-         if opt.feature_transform:
-             loss += feature_transform_regularizer(trans_feat) * 0.001
-         loss.backward()
-         optimizer.step()
-         pred_choice = pred.data.max(1)[1]
-         correct = pred_choice.eq(target.data).cpu().sum()
-         print('[%d: %d/%d] train loss: %f accuracy: %f' % (epoch, i, num_batch, loss.item(), correct.item() / float(opt.batchSize)))
-
-         if i % 10 == 0:
-             j, data = next(enumerate(testdataloader, 0))
-             points, target = data
-             target = target[:, 0]
-             points = points.transpose(2, 1)
-             points, target = points.cuda(), target.cuda()
-             classifier = classifier.eval()
-             pred, _, _ = classifier(points)
-             loss = F.nll_loss(pred, target)
-             pred_choice = pred.data.max(1)[1]
-             correct = pred_choice.eq(target.data).cpu().sum()
-             print('[%d: %d/%d] %s loss: %f accuracy: %f' % (epoch, i, num_batch, blue('test'), loss.item(), correct.item()/float(opt.batchSize)))
-
-     torch.save(classifier.state_dict(), '%s/cls_model_%d.pth' % (opt.outf, epoch))
-
- total_correct = 0
- total_testset = 0
- for i,data in tqdm(enumerate(testdataloader, 0)):
-     points, target = data
-     target = target[:, 0]
-     points = points.transpose(2, 1)
-     points, target = points.cuda(), target.cuda()
-     classifier = classifier.eval()
-     pred, _, _ = classifier(points)
-     pred_choice = pred.data.max(1)[1]
-     correct = pred_choice.eq(target.data).cpu().sum()
-     total_correct += correct.item()
-     total_testset += points.size()[0]
-
- print("final accuracy {}".format(total_correct / float(total_testset)))