Paolo-Fraccaro commited on
Commit
842ee21
1 Parent(s): e3f7268

fix channels order and no data

Browse files
Files changed (1) hide show
  1. app.py +15 -16
app.py CHANGED
@@ -137,37 +137,36 @@ def inference_segmentor(model, imgs, custom_test_pipeline=None):
137
  def inference_on_file(target_image, model, custom_test_pipeline):
138
 
139
  target_image = target_image.name
140
- # print(type(target_image))
141
 
142
- # output_image = target_image.replace('.tif', '_pred.tif')
143
  time_taken=-1
144
 
145
  st = time.time()
146
  print('Running inference...')
147
- result = inference_segmentor(model, target_image, custom_test_pipeline)
 
 
 
 
 
 
148
  print("Output has shape: " + str(result[0].shape))
149
 
150
- ##### get metadata mask
151
  mask = open_tiff(target_image)
152
  rgb = stretch_rgb((mask[[3, 2, 1], :, :].transpose((1,2,0))/10000*255).astype(np.uint8))
153
  meta = get_meta(target_image)
154
  mask = np.where(mask == meta['nodata'], 1, 0)
155
  mask = np.max(mask, axis=0)[None]
156
-
157
- result[0] = np.where(mask == 1, -1, result[0])
158
-
159
- ##### Save file to disk
160
- meta["count"] = 1
161
- meta["dtype"] = "int16"
162
- meta["compress"] = "lzw"
163
- meta["nodata"] = -1
164
- print('Saving output...')
165
- # write_tiff(result[0], output_image, meta)
166
  et = time.time()
167
  time_taken = np.round(et - st, 1)
168
  print(f'Inference completed in {str(time_taken)} seconds')
169
-
170
- return rgb, result[0][0]*255
171
 
172
  def process_test_pipeline(custom_test_pipeline, bands=None):
173
 
 
137
  def inference_on_file(target_image, model, custom_test_pipeline):
138
 
139
  target_image = target_image.name
 
140
 
 
141
  time_taken=-1
142
 
143
  st = time.time()
144
  print('Running inference...')
145
+ try:
146
+ result = inference_segmentor(model, target_image, custom_test_pipeline)
147
+ except:
148
+ print('Error: Try different channels order.')
149
+ model.cfg.data.test.pipeline[0]['channels_last'] = True
150
+ result = inference_segmentor(model, target_image, custom_test_pipeline)
151
+
152
  print("Output has shape: " + str(result[0].shape))
153
 
154
+ ##### prep outputs
155
  mask = open_tiff(target_image)
156
  rgb = stretch_rgb((mask[[3, 2, 1], :, :].transpose((1,2,0))/10000*255).astype(np.uint8))
157
  meta = get_meta(target_image)
158
  mask = np.where(mask == meta['nodata'], 1, 0)
159
  mask = np.max(mask, axis=0)[None]
160
+ rgb = np.where(mask.transpose((1,2,0)) == 1, 0, rgb)
161
+ rgb = np.where(rgb < 0, 0, rgb)
162
+ rgb = np.where(rgb > 255, 255, rgb)
163
+
164
+ prediction = np.where(mask == 1, 0, result[0]*255)
 
 
 
 
 
165
  et = time.time()
166
  time_taken = np.round(et - st, 1)
167
  print(f'Inference completed in {str(time_taken)} seconds')
168
+
169
+ return rgb, prediction[0]
170
 
171
  def process_test_pipeline(custom_test_pipeline, bands=None):
172