diff --git a/src/gan_control/inference/controller.py b/src/gan_control/inference/controller.py
index ee464ba..d1907dd 100644
--- a/src/gan_control/inference/controller.py
+++ b/src/gan_control/inference/controller.py
@@ -13,9 +13,9 @@ _log = get_logger(__name__)
 
 
 class Controller(Inference):
-    def __init__(self, controller_dir):
+    def __init__(self, controller_dir, device):
         _log.info('Init Controller class...')
-        super(Controller, self).__init__(os.path.join(controller_dir, 'generator'))
+        super(Controller, self).__init__(os.path.join(controller_dir, 'generator'), device)
         self.fc_controls = {}
         self.config_controls = {}
         for sub_group_name in self.batch_utils.sub_group_names:
@@ -29,21 +29,21 @@ class Controller(Inference):
     @torch.no_grad()
     def gen_batch_by_controls(self, batch_size=1, latent=None, normalize=True, input_is_latent=False, static_noise=True, **kwargs):
         if latent is None:
-            latent = torch.randn(batch_size, self.config.model_config['latent_size'], device='cuda')
+            latent = torch.randn(batch_size, self.config.model_config['latent_size'], device=self.device)
         latent = latent.clone()
         if input_is_latent:
             latent_w = latent
         else:
             if isinstance(self.model, torch.nn.DataParallel):
-                latent_w = self.model.module.style(latent.cuda())
+                latent_w = self.model.module.style(latent.to(self.device))
             else:
-                latent_w = self.model.style(latent.cuda())
+                latent_w = self.model.style(latent.to(self.device))
         for group_key in kwargs.keys():
             if self.check_if_group_has_control(group_key):
                 if group_key == 'expression' and kwargs[group_key].shape[1] == 8:
-                    group_w_latent = self.fc_controls['expression_q'](kwargs[group_key].cuda().float())
+                    group_w_latent = self.fc_controls['expression_q'](kwargs[group_key].to(self.device).float())
                 else:
-                    group_w_latent = self.fc_controls[group_key](kwargs[group_key].cuda().float())
+                    group_w_latent = self.fc_controls[group_key](kwargs[group_key].to(self.device).float())
                 latent_w = self.insert_group_w_latent(latent_w, group_w_latent, group_key)
         injection_noise = None
         if static_noise:
@@ -101,12 +101,12 @@ class Controller(Inference):
         ckpt_path = ckpt_list[-1]
         ckpt_iter = ckpt_path.split('.')[0]
         config = read_json(config_path, return_obj=True)
-        ckpt = torch.load(os.path.join(checkpoints_path, ckpt_path))
+        ckpt = torch.load(os.path.join(checkpoints_path, ckpt_path), map_location=self.device)
         group_chunk = self.batch_utils.place_in_latent_dict[sub_group_name if sub_group_name != 'expression_q' else 'expression']
         group_latent_size = group_chunk[1] - group_chunk[0]
 
         _log.info('Init %s Controller...' % sub_group_name)
-        controller = FcStack(config.model_config['lr_mlp'], config.model_config['n_mlp'], config.model_config['in_dim'], config.model_config['mid_dim'], group_latent_size).cuda()
+        controller = FcStack(config.model_config['lr_mlp'], config.model_config['n_mlp'], config.model_config['in_dim'], config.model_config['mid_dim'], group_latent_size).to(self.device)
         controller.print()
 
         _log.info('Loading Controller: %s, ckpt iter %s' % (controller_dir_path, ckpt_iter))
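
The controller.py hunks above all make the same substitution: each hardcoded .cuda() call and device='cuda' default becomes a use of the device object passed into the constructor, so the controller can run on CPU-only hosts as well. A minimal usage sketch of the new signature; the checkpoint directory path is a placeholder, and the (1, 8) expression tensor is chosen only because gen_batch_by_controls above routes that shape through the 'expression_q' control head:

    import torch
    from gan_control.inference.controller import Controller

    # Pick the device at runtime instead of assuming CUDA is present.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    controller = Controller('path/to/controller_dir', device)

    # An (N, 8) expression tensor is mapped by the 'expression_q' FC stack
    # into the expression chunk of the w latent (see the hunk above).
    out = controller.gen_batch_by_controls(
        batch_size=1,
        expression=torch.zeros(1, 8),
    )
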
diff --git a/src/gan_control/inference/inference.py b/src/gan_control/inference/inference.py
index e6ccedb..4393bb7 100644
--- a/src/gan_control/inference/inference.py
+++ b/src/gan_control/inference/inference.py
@@ -15,10 +15,11 @@ _log = get_logger(__name__)
 
 
 class Inference():
-    def __init__(self, model_dir):
+    def __init__(self, model_dir, device):
         _log.info('Init inference class...')
         self.model_dir = model_dir
-        self.model, self.batch_utils, self.config, self.ckpt_iter = self.retrieve_model(model_dir)
+        self.device = device
+        self.model, self.batch_utils, self.config, self.ckpt_iter = self.retrieve_model(model_dir, device)
         self.noise = None
         self.reset_noise()
         self.mean_w_latent = None
@@ -28,7 +29,7 @@ class Inference():
         _log.info('Calc mean_w_latents...')
         mean_latent_w_list = []
         for i in range(100):
-            latent_z = torch.randn(1000, self.config.model_config['latent_size'], device='cuda')
+            latent_z = torch.randn(1000, self.config.model_config['latent_size'], device=self.device)
             if isinstance(self.model, torch.nn.DataParallel):
                 latent_w = self.model.module.style(latent_z).cpu()
             else:
@@ -41,9 +42,9 @@ class Inference():
 
     def reset_noise(self):
         if isinstance(self.model, torch.nn.DataParallel):
-            self.noise = self.model.module.make_noise(device='cuda')
+            self.noise = self.model.module.make_noise(device=self.device)
         else:
-            self.noise = self.model.make_noise(device='cuda')
+            self.noise = self.model.make_noise(device=self.device)
 
     @staticmethod
     def expend_noise(noise, batch_size):
@@ -56,14 +57,14 @@ class Inference():
             self.calc_mean_w_latents()
         injection_noise = None
         if latent is None:
-            latent = torch.randn(batch_size, self.config.model_config['latent_size'], device='cuda')
+            latent = torch.randn(batch_size, self.config.model_config['latent_size'], device=self.device)
         elif input_is_latent:
-            latent = latent.cuda()
+            latent = latent.to(self.device)
             for group_key in kwargs.keys():
                 if group_key not in self.batch_utils.sub_group_names:
                     raise ValueError('group_key: %s not in sub_group_names %s' % (group_key, str(self.batch_utils.sub_group_names)))
                 if isinstance(kwargs[group_key], str) and kwargs[group_key] == 'random':
-                    group_latent_w = self.model.style(torch.randn(latent.shape[0], self.config.model_config['latent_size'], device='cuda'))
+                    group_latent_w = self.model.style(torch.randn(latent.shape[0], self.config.model_config['latent_size'], device=self.device))
                     group_latent_w = group_latent_w[:, self.batch_utils.place_in_latent_dict[group_key][0], self.batch_utils.place_in_latent_dict[group_key][0]]
                     latent[:, self.batch_utils.place_in_latent_dict[group_key][0], self.batch_utils.place_in_latent_dict[group_key][0]] = group_latent_w
         if static_noise:
@@ -82,11 +83,11 @@ class Inference():
                 latent[:, place_in_latent[0]: place_in_latent[1]] = \
                     truncation * (latent[:, place_in_latent[0]: place_in_latent[1]] - torch.cat(
                         [self.mean_w_latents[key].clone().unsqueeze(0) for _ in range(latent.shape[0])], dim=0
-                    ).cuda()) + torch.cat(
+                    ).to(self.device)) + torch.cat(
                         [self.mean_w_latents[key].clone().unsqueeze(0) for _ in range(latent.shape[0])], dim=0
-                    ).cuda()
+                    ).to(self.device)
 
-        tensor, latent_w = self.model([latent.cuda()], return_latents=True, input_is_latent=input_is_latent, noise=injection_noise)
+        tensor, latent_w = self.model([latent.to(self.device)], return_latents=True, input_is_latent=input_is_latent, noise=injection_noise)
         if normalize:
             tensor = tensor.mul(0.5).add(0.5).clamp(min=0., max=1.).cpu()
         return tensor, latent, latent_w
@@ -107,7 +108,7 @@ class Inference():
         return grid_image
 
     @staticmethod
-    def retrieve_model(model_dir):
+    def retrieve_model(model_dir, device):
         config_path = os.path.join(model_dir, 'args.json')
 
         _log.info('Retrieve config from %s' % config_path)
@@ -117,7 +118,7 @@ class Inference():
         ckpt_path = ckpt_list[-1]
         ckpt_iter = ckpt_path.split('.')[0]
         config = read_json(config_path, return_obj=True)
-        ckpt = torch.load(os.path.join(checkpoints_path, ckpt_path))
+        ckpt = torch.load(os.path.join(checkpoints_path, ckpt_path), map_location=device)
 
         batch_utils = None
         if not config.model_config['vanilla']:
@@ -140,7 +141,7 @@ class Inference():
             fc_config=None if config.model_config['vanilla'] else batch_utils.get_fc_config(),
             conv_transpose=config.model_config['conv_transpose'],
             noise_mode=config.model_config['g_noise_mode']
-        ).cuda()
+        ).to(device)
         _log.info('Loading Model: %s, ckpt iter %s' % (model_dir, ckpt_iter))
         model.load_state_dict(ckpt['g_ema'])
         model = torch.nn.DataParallel(model)
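
inference.py gets the same device-parameter treatment, with two additions worth noting: torch.load now passes map_location=device, which remaps storages that were saved from a GPU onto whatever device the caller requested, and the generator is moved with .to(device) before being wrapped in torch.nn.DataParallel. A standalone sketch of the loading pattern, with a made-up checkpoint filename (the 'g_ema' key is the one retrieve_model reads above):

    import torch

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Without map_location, a checkpoint saved from GPU raises an error on a
    # CPU-only machine; with it, tensors are remapped to the target device.
    ckpt = torch.load('checkpoints/100000.pt', map_location=device)
    state_dict = ckpt['g_ema']

With both constructors now taking an explicit device, the caller decides tensor placement once, and everything the classes allocate afterwards (sampled latents, injected noise, mean-w statistics) follows it.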