hachreak committed
Commit 22a2f9f
1 Parent(s): 7c04404

swin2mose: runnable version


Signed-off-by: Leonardo Rossi <[email protected]>

.gitignore ADDED
@@ -0,0 +1 @@
+*.pyc
swin2_mose/libs.py ADDED
@@ -0,0 +1,56 @@
+from torch import nn
+
+
+def window_reverse(windows, window_size, H, W):
+    """
+    Args:
+        windows: (num_windows*B, window_size, window_size, C)
+        window_size (int): Window size
+        H (int): Height of image
+        W (int): Width of image
+
+    Returns:
+        x: (B, H, W, C)
+    """
+    B = int(windows.shape[0] / (H * W / window_size / window_size))
+    x = windows.view(B, H // window_size, W // window_size, window_size,
+                     window_size, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+    return x
+
+
+class Mlp(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None,
+                 act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+def window_partition(x, window_size):
+    """
+    Args:
+        x: (B, H, W, C)
+        window_size (int): window size
+
+    Returns:
+        windows: (num_windows*B, window_size, window_size, C)
+    """
+    B, H, W, C = x.shape
+    x = x.view(B, H // window_size, window_size,
+               W // window_size, window_size, C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(
+        -1, window_size, window_size, C)
+    return windows
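
A note on the two helpers above: window_partition and window_reverse are exact inverses whenever H and W are multiples of window_size. A minimal round-trip check, with shapes chosen to match the model defaults used below (the sizes are illustrative):

import torch
from libs import window_partition, window_reverse

# Split a feature map into non-overlapping 16x16 windows, then stitch it back.
x = torch.randn(2, 64, 64, 90)           # (B, H, W, C): img_size=64, embed_dim=90
windows = window_partition(x, 16)        # (num_windows*B, 16, 16, 90) = (32, 16, 16, 90)
y = window_reverse(windows, 16, 64, 64)  # back to (B, H, W, C)
assert torch.equal(x, y)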
swin2_mose/model.py CHANGED
@@ -1,10 +1,9 @@
 #
-# Source code: https://github.com/mv-lab/swin2sr
+# Source code: https://github.com/IMPLabUniPr/swin2-mose
 #
-# -----------------------------------------------------------------------------------
-# Swin2SR: Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/2209.11345
-# Written by Conde and Choi et al.
-# -----------------------------------------------------------------------------------
+# ----------------------------------------------------------------------------
+# https://arxiv.org/abs/2404.18924
+# ----------------------------------------------------------------------------
 
 import math
 import numpy as np
@@ -14,7 +13,7 @@ import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint
 from timm.models.layers import DropPath, to_2tuple, trunc_normal_
 
-from utils import window_reverse, Mlp, window_partition
+from libs import window_reverse, Mlp, window_partition
 from moe import MoE
 
 
@@ -746,9 +745,8 @@ class UpsampleOneStep(nn.Sequential):
 
 
 
-class Swin2SR(nn.Module):
-    r""" Swin2SR
-    A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.
+class Swin2MoSE(nn.Module):
+    r""" Swin2-MoSE
 
     Args:
         img_size (int | tuple(int)): Input image size. Default 64
@@ -784,8 +782,7 @@ class Swin2SR(nn.Module):
                  MoE_config=None,
                  use_rpe_bias=False,
                  **kwargs):
-        super(Swin2SR, self).__init__()
-        print('==== SWIN 2SR')
+        super(Swin2MoSE, self).__init__()
         num_in_ch = in_chans
         num_out_ch = in_chans
         num_feat = 64
@@ -1154,4 +1151,4 @@ class Swin2SR(nn.Module):
         flops += layer.flops()
         flops += H * W * 3 * self.embed_dim * self.embed_dim
         flops += self.upsample.flops()
-        return flops
+        return flops
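
After the rename, downstream code builds the network as Swin2MoSE(**kwargs). A minimal smoke test sketch, assuming the constructor accepts the keywords shown in weights/config-70.yml (untrained weights, random input; with MoE_config set, the forward pass also returns an auxiliary loss, which is why callers unpack the result):

import torch
from model import Swin2MoSE

# Hyper-parameters mirror weights/config-70.yml; this is illustrative only.
net = Swin2MoSE(upscale=2, in_chans=4, img_size=64, window_size=16,
                img_range=1., depths=[6, 6, 6, 6], embed_dim=90,
                num_heads=[6, 6, 6, 6], mlp_ratio=1,
                upsampler='pixelshuffledirect', resi_connection='1conv')
lr = torch.randn(1, 4, 64, 64)           # one 4-band 64x64 tile
sr = net(lr)
if not torch.is_tensor(sr):              # (output, aux_loss) when MoE is enabled
    sr, _ = sr
print(sr.shape)                          # expected torch.Size([1, 4, 128, 128]) for x2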
swin2_mose/moe.py CHANGED
@@ -18,7 +18,8 @@ from torch.distributions.normal import Normal
 from copy import deepcopy
 import numpy as np
 
-from utils import Mlp as MLP
+from libs import Mlp as MLP
+
 
 class SparseDispatcher(object):
     """Helper for implementing a mixture of experts.
@@ -320,4 +321,4 @@ class MoE(nn.Module):
         expert_outputs = [self.experts[i](expert_inputs[i])
                           for i in range(self.num_experts)]
         y = dispatcher.combine(expert_outputs, cnn_combine=self.cnn_combine)
-        return y, loss
+        return y, loss
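
For intuition on the MoE_config keys used later (k, num_experts): sparse mixture-of-experts routing scores every expert for each token, keeps only the top-k gate values, and dispatches each token to just those experts. A generic sketch of that gating step follows; it illustrates the idea only and is not the API of the MoE class in this file:

import torch
import torch.nn.functional as F

def topk_gates(x, w_gate, k=2):
    # x: (tokens, dim), w_gate: (dim, num_experts)
    logits = x @ w_gate                        # score every expert per token
    top_val, top_idx = logits.topk(k, dim=-1)  # keep only the k best experts
    gates = torch.zeros_like(logits)
    gates.scatter_(-1, top_idx, F.softmax(top_val, dim=-1))
    return gates                               # at most k non-zero weights per row

x = torch.randn(5, 90)                         # 5 tokens, embed_dim=90
gates = topk_gates(x, torch.randn(90, 8))      # k=2, num_experts=8 as in config-70.yml
print((gates > 0).sum(dim=-1))                 # tensor([2, 2, 2, 2, 2])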
swin2_mose/run.py CHANGED
@@ -1,20 +1,36 @@
-import torch
-from model import Swin2SR
-
-model_weights = "model-70.pt"
-model_params = {
-    "upscale": 2,
-    "in_chans": 4,
-    "img_size": 64,
-    "window_size": 16,
-    "img_range": 1.,
-    "depths": [6, 6, 6, 6],
-    "embed_dim": 90,
-    "num_heads": [6, 6, 6, 6],
-    "mlp_ratio": 2,
-    "upsampler": "pixelshuffledirect",
-    "resi_connection": "1conv"
-}
-
-sr_model = Swin2SR(**model_params)
-sr_model.load_state_dict(torch.load(model_weights))
+import benchmark
+import matplotlib.pyplot as plt
+import opensr_test
+
+from utils import load_swin2_mose, load_config, run_swin2_mose
+
+
+path = 'swin2_mose/weights/config-70.yml'
+model_weights = "swin2_mose/weights/model-70.pt"
+index = 2
+
+# load config
+cfg = load_config(path)
+# load model
+model = load_swin2_mose(model_weights, cfg)
+
+# load the dataset
+dataset = opensr_test.load("venus")
+lr_dataset, hr_dataset = dataset["L2A"], dataset["HRharm"]
+
+results = run_swin2_mose(model, lr_dataset[index], hr_dataset[index])
+
+# Display the results
+fig, ax = plt.subplots(1, 3, figsize=(10, 5))
+ax[0].imshow(results['lr'].numpy().transpose(1, 2, 0) / 3000)
+ax[0].set_title("LR")
+ax[0].axis("off")
+ax[1].imshow(results["sr"].detach().numpy().transpose(1, 2, 0) / 3000)
+ax[1].set_title("SR")
+ax[1].axis("off")
+ax[2].imshow(results['hr'].numpy().transpose(1, 2, 0) / 3000)
+ax[2].set_title("HR")
+# plt.show()
+
+# Run the experiment
+benchmark.create_geotiff(model, run_swin2_mose, "all", "swin2mose/")
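
Since plt.show() stays commented out, a headless run displays nothing; one option (output path illustrative) is to persist the panel right after it is drawn:

# Save the LR/SR/HR comparison figure instead of displaying it.
fig.savefig("swin2mose/comparison.png", dpi=150, bbox_inches="tight")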
swin2_mose/utils.py CHANGED
@@ -1,56 +1,77 @@
-from torch import nn
-
-
-def window_reverse(windows, window_size, H, W):
-    """
-    Args:
-        windows: (num_windows*B, window_size, window_size, C)
-        window_size (int): Window size
-        H (int): Height of image
-        W (int): Width of image
-
-    Returns:
-        x: (B, H, W, C)
-    """
-    B = int(windows.shape[0] / (H * W / window_size / window_size))
-    x = windows.view(B, H // window_size, W // window_size, window_size,
-                     window_size, -1)
-    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
-    return x
-
-
-class Mlp(nn.Module):
-    def __init__(self, in_features, hidden_features=None, out_features=None,
-                 act_layer=nn.GELU, drop=0.):
-        super().__init__()
-        out_features = out_features or in_features
-        hidden_features = hidden_features or in_features
-        self.fc1 = nn.Linear(in_features, hidden_features)
-        self.act = act_layer()
-        self.fc2 = nn.Linear(hidden_features, out_features)
-        self.drop = nn.Dropout(drop)
-
-    def forward(self, x):
-        x = self.fc1(x)
-        x = self.act(x)
-        x = self.drop(x)
-        x = self.fc2(x)
-        x = self.drop(x)
-        return x
-
-
-def window_partition(x, window_size):
-    """
-    Args:
-        x: (B, H, W, C)
-        window_size (int): window size
-
-    Returns:
-        windows: (num_windows*B, window_size, window_size, C)
-    """
-    B, H, W, C = x.shape
-    x = x.view(B, H // window_size, window_size,
-               W // window_size, window_size, C)
-    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(
-        -1, window_size, window_size, C)
-    return windows
+import torch
+import yaml
+
+from model import Swin2MoSE
+
+
+def to_shape(t1, t2):
+    t1 = t1[None].repeat(t2.shape[0], 1)
+    t1 = t1.view((t2.shape[:2] + (1, 1)))
+    return t1
+
+
+def norm(tensor, mean, std):
+    # get stats
+    mean = torch.tensor(mean).to(tensor.device)
+    std = torch.tensor(std).to(tensor.device)
+    # normalize
+    return (tensor - to_shape(mean, tensor)) / to_shape(std, tensor)
+
+
+def denorm(tensor, mean, std):
+    # get stats
+    mean = torch.tensor(mean).to(tensor.device)
+    std = torch.tensor(std).to(tensor.device)
+    # denorm
+    return (tensor * to_shape(std, tensor)) + to_shape(mean, tensor)
+
+
+def load_config(path):
+    # load config
+    with open(path, 'r') as f:
+        cfg = yaml.safe_load(f)
+    return cfg
+
+
+def load_swin2_mose(model_weights, cfg):
+    # load checkpoint
+    checkpoint = torch.load(model_weights)
+
+    # build model
+    sr_model = Swin2MoSE(**cfg['super_res']['model'])
+    sr_model.load_state_dict(
+        checkpoint['model_state_dict'])
+
+    sr_model.cfg = cfg
+
+    return sr_model
+
+
+def run_swin2_mose(model, lr, hr):
+    cfg = model.cfg
+
+    # norm fun
+    hr_stats = cfg['dataset']['stats']['tensor_05m_b2b3b4b8']
+    lr_stats = cfg['dataset']['stats']['tensor_10m_b2b3b4b8']
+
+    # select 10m lr bands: B02, B03, B04, B08 and hr bands
+    lr_orig = torch.tensor(lr)[None].float()[:, [3, 2, 1, 7]]
+    hr_orig = torch.tensor(hr)[None].float()
+
+    # normalize data
+    lr = norm(lr_orig, mean=lr_stats['mean'], std=lr_stats['std'])
+    hr = norm(hr_orig, mean=hr_stats['mean'], std=hr_stats['std'])
+
+    # predict an image
+    sr = model(lr)
+    if not torch.is_tensor(sr):
+        sr, _ = sr
+
+    # denorm sr
+    sr = denorm(sr, mean=hr_stats['mean'], std=hr_stats['std'])
+
+    return {
+        "lr": lr_orig[0],
+        "sr": sr[0],
+        "hr": hr_orig[0],
+    }
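
norm and denorm above are exact per-band inverses, with to_shape broadcasting the per-band statistics to (B, C, 1, 1). A quick round-trip check, with statistics rounded from config-70.yml and a fake input tile:

import torch
from utils import norm, denorm

x = torch.rand(1, 4, 64, 64) * 3000          # fake 4-band tile, (B, C, H, W)
mean = [444.2, 715.9, 813.4, 2604.9]         # rounded from tensor_05m_b2b3b4b8
std = [279.9, 385.4, 648.5, 797.0]

z = norm(x, mean, std)                       # (x - mean[c]) / std[c] per band
x_back = denorm(z, mean, std)                # inverse transform
print(torch.allclose(x, x_back, atol=1e-3))  # True up to float rounding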
swin2_mose/weights/config-70.yml ADDED
@@ -0,0 +1,46 @@
+dataset:
+  root_path: data/sen2venus
+  stats:
+    use_minmax: true
+    tensor_05m_b2b3b4b8: {
+      mean: [444.21923828125, 715.9031372070312, 813.4345703125, 2604.867919921875],
+      std: [279.85552978515625, 385.3569641113281, 648.458984375, 796.9918212890625],
+      min: [-1025.0, -3112.0, -5122.0, -3851.0],
+      max: [14748.0, 14960.0, 16472.0, 16109.0]
+    }
+    tensor_10m_b2b3b4b8: {
+      mean: [443.78643798828125, 715.4202270507812, 813.0512084960938, 2602.813232421875],
+      std: [283.89276123046875, 389.26361083984375, 651.094970703125, 811.5682373046875],
+      min: [-848.0, -902.0, -946.0, -323.0],
+      max: [19684.0, 17982.0, 17064.0, 15958.0]
+    }
+  hr_name: tensor_05m_b2b3b4b8
+  lr_name: tensor_10m_b2b3b4b8
+  collate_fn: mods.v3.collate_fn
+  denorm: mods.v3.uncollate_fn
+  printable: mods.v3.printable
+super_res: {
+  version: 'v2',
+  model: {
+    upscale: 2,
+    use_lepe: true,
+    use_cpb_bias: false,
+    use_rpe_bias: true,
+    mlp_ratio: 1,
+    MoE_config: {
+      k: 2,
+      num_experts: 8,
+      with_noise: false,
+      with_smart_merger: v1,
+    },
+    depths: [6, 6, 6, 6],
+    embed_dim: 90,
+    img_range: 1.,
+    img_size: 64,
+    in_chans: 4,
+    num_heads: [6, 6, 6, 6],
+    resi_connection: 1conv,
+    upsampler: pixelshuffledirect,
+    window_size: 16,
+  }
+}
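
The model block of this config maps directly onto the Swin2MoSE constructor keywords, which is how load_swin2_mose consumes it. A small sketch reading the file back:

from utils import load_config

# Load the YAML and inspect the kwargs that load_swin2_mose forwards to Swin2MoSE.
cfg = load_config('swin2_mose/weights/config-70.yml')
model_kwargs = cfg['super_res']['model']
print(model_kwargs['upscale'])     # 2
print(model_kwargs['MoE_config'])  # {'k': 2, 'num_experts': 8, ...}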