yvokeller commited on
Commit
1912458
1 Parent(s): 5b24075

ignore data

Browse files
Files changed (2) hide show
  1. .gitignore +3 -1
  2. inference.py +18 -12
.gitignore CHANGED
@@ -1,2 +1,4 @@
1
  hf_cache
2
- __pycache__
 
 
 
1
  hf_cache
2
+ __pycache__
3
+ .DS_Store
4
+ data/
inference.py CHANGED
@@ -66,6 +66,9 @@ class InferenceDataLoader:
66
  with rasterio.open(path) as src:
67
  # Transform the coordinates from WGS84 to UTM (EPSG:32632)
68
  utm_x, utm_y = self.transformer.transform(lon, lat)
 
 
 
69
 
70
  try:
71
  px, py = rowcol(src.transform, utm_x, utm_y)
@@ -77,20 +80,23 @@ class InferenceDataLoader:
77
 
78
  half_window_size = self.window_size // 2
79
 
80
- col_off = px - half_window_size
81
- row_off = py - half_window_size
82
 
83
- if col_off < 0:
84
- col_off = 0
85
  if row_off < 0:
86
  row_off = 0
87
- if col_off + self.window_size > src.width:
88
- col_off = src.width - self.window_size
89
- if row_off + self.window_size > src.height:
90
- row_off = src.height - self.window_size
 
 
91
 
92
  window = Window(col_off, row_off, self.window_size, self.window_size)
93
  window_transform = src.window_transform(window)
 
 
 
94
  crs = src.crs
95
 
96
  return window, window_transform, crs
@@ -193,10 +199,10 @@ def crop_predictions_to_gdf(field_ids, targets, predictions, transform, crs, cla
193
  return gdf
194
 
195
  def perform_inference(lon, lat, model, config, debug=False):
196
- features_path = "./data/stacked_features.tif"
197
- labels_path = "./data/labels.tif"
198
- field_ids_path = "./data/field_ids.tif"
199
- stats_path = "./data/chips_stats.yaml"
200
 
201
  loader = InferenceDataLoader(features_path, labels_path, field_ids_path, stats_path, n_timesteps=9, fold_indices=[0], debug=True)
202
 
 
66
  with rasterio.open(path) as src:
67
  # Transform the coordinates from WGS84 to UTM (EPSG:32632)
68
  utm_x, utm_y = self.transformer.transform(lon, lat)
69
+ if self.debug:
70
+ print("Source Transform", src.transform)
71
+ print(f"UTM X: {utm_x}, UTM Y: {utm_y}")
72
 
73
  try:
74
  px, py = rowcol(src.transform, utm_x, utm_y)
 
80
 
81
  half_window_size = self.window_size // 2
82
 
83
+ row_off = px - half_window_size
84
+ col_off = py - half_window_size
85
 
 
 
86
  if row_off < 0:
87
  row_off = 0
88
+ if col_off < 0:
89
+ col_off = 0
90
+ if row_off + self.window_size > src.width:
91
+ row_off = src.width - self.window_size
92
+ if col_off + self.window_size > src.height:
93
+ col_off = src.height - self.window_size
94
 
95
  window = Window(col_off, row_off, self.window_size, self.window_size)
96
  window_transform = src.window_transform(window)
97
+ if self.debug:
98
+ print(f"Window: {window}")
99
+ print(f"Window Transform: {window_transform}")
100
  crs = src.crs
101
 
102
  return window, window_transform, crs
 
199
  return gdf
200
 
201
  def perform_inference(lon, lat, model, config, debug=False):
202
+ features_path = "../data/stacked_features.tif"
203
+ labels_path = "../data/labels.tif"
204
+ field_ids_path = "../data/field_ids.tif"
205
+ stats_path = "../data/chips_stats.yaml"
206
 
207
  loader = InferenceDataLoader(features_path, labels_path, field_ids_path, stats_path, n_timesteps=9, fold_indices=[0], debug=True)
208