Spaces:
Runtime error
Runtime error
dtypes weird stuff
Browse files
backend/disentangle_concepts.py
CHANGED
@@ -84,14 +84,16 @@ def regenerate_images(model, z, decision_boundary, min_epsilon=-3, max_epsilon=3
|
|
84 |
for _, lambda_ in enumerate(lambdas):
|
85 |
z_0 = z + lambda_ * decision_boundary
|
86 |
if latent_space == 'Z':
|
87 |
-
W_0 = G.mapping(z_0, label, truncation_psi=1)
|
88 |
-
W = G.mapping(z, label, truncation_psi=1)
|
|
|
89 |
else:
|
90 |
-
W_0 = z_0.expand((14, -1)).unsqueeze(0)
|
91 |
-
W = z.expand((14, -1)).unsqueeze(0)
|
|
|
92 |
|
93 |
if layers:
|
94 |
-
W_f = torch.empty_like(W).copy_(W)
|
95 |
W_f[:, layers, :] = W_0[:, layers, :]
|
96 |
img = G.synthesis(W_f, noise_mode='const')
|
97 |
else:
|
@@ -125,7 +127,7 @@ def generate_original_image(z, model, latent_space='Z'):
|
|
125 |
label = torch.zeros([1, G.c_dim], device=device)
|
126 |
if latent_space == 'Z':
|
127 |
z = torch.from_numpy(z.copy()).to(device)
|
128 |
-
img = G(z, label, truncation_psi=
|
129 |
else:
|
130 |
W = torch.from_numpy(np.repeat(z, 14, axis=0).reshape(1, 14, z.shape[1]).copy()).to(device)
|
131 |
print(W.shape)
|
|
|
84 |
for _, lambda_ in enumerate(lambdas):
|
85 |
z_0 = z + lambda_ * decision_boundary
|
86 |
if latent_space == 'Z':
|
87 |
+
W_0 = G.mapping(z_0, label, truncation_psi=1).to(torch.float32)
|
88 |
+
W = G.mapping(z, label, truncation_psi=1).to(torch.float32)
|
89 |
+
print(W.dtype)
|
90 |
else:
|
91 |
+
W_0 = z_0.expand((14, -1)).unsqueeze(0).to(torch.float32)
|
92 |
+
W = z.expand((14, -1)).unsqueeze(0).to(torch.float32)
|
93 |
+
print(W.dtype)
|
94 |
|
95 |
if layers:
|
96 |
+
W_f = torch.empty_like(W).copy_(W).to(torch.float32)
|
97 |
W_f[:, layers, :] = W_0[:, layers, :]
|
98 |
img = G.synthesis(W_f, noise_mode='const')
|
99 |
else:
|
|
|
127 |
label = torch.zeros([1, G.c_dim], device=device)
|
128 |
if latent_space == 'Z':
|
129 |
z = torch.from_numpy(z.copy()).to(device)
|
130 |
+
img = G(z, label, truncation_psi=1, noise_mode='const')
|
131 |
else:
|
132 |
W = torch.from_numpy(np.repeat(z, 14, axis=0).reshape(1, 14, z.shape[1]).copy()).to(device)
|
133 |
print(W.shape)
|