Spaces:
Sleeping
Sleeping
Hugo Flores
commited on
Commit
•
9439b64
1
Parent(s):
b862275
fix sampling logic for paella
Browse files- vampnet/modules/base.py +13 -18
vampnet/modules/base.py
CHANGED
@@ -153,7 +153,6 @@ class VampBase(at.ml.BaseModel):
|
|
153 |
sampling_steps: int = 12,
|
154 |
start_tokens: Optional[torch.Tensor] = None,
|
155 |
mask: Optional[torch.Tensor] = None,
|
156 |
-
device: str = "cpu",
|
157 |
temperature: Union[float, Tuple[float, float]] = 1.0,
|
158 |
top_k: int = None,
|
159 |
sample: str = "gumbel",
|
@@ -164,7 +163,8 @@ class VampBase(at.ml.BaseModel):
|
|
164 |
typical_min_tokens=1,
|
165 |
return_signal=True,
|
166 |
):
|
167 |
-
|
|
|
168 |
if renoise_steps == None:
|
169 |
renoise_steps = sampling_steps - 1
|
170 |
|
@@ -186,7 +186,7 @@ class VampBase(at.ml.BaseModel):
|
|
186 |
if self.noise_mode == "noise":
|
187 |
z = torch.randint(
|
188 |
0, self.vocab_size, size=(1, self.n_codebooks, time_steps)
|
189 |
-
).to(device)
|
190 |
elif self.noise_mode == "mask":
|
191 |
z = torch.full((1, self.n_codebooks, time_steps), self.mask_token)
|
192 |
else:
|
@@ -197,19 +197,14 @@ class VampBase(at.ml.BaseModel):
|
|
197 |
assert z.shape[0] == 1, f"batch size must be 1"
|
198 |
|
199 |
if mask is None:
|
200 |
-
mask = torch.ones(z.shape[0], z.shape[-1]).to(device).int()
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
z.shape[0],
|
205 |
-
z.shape[-1],
|
206 |
-
), f"mask must be shape (batch, seq_len), got {mask.shape}"
|
207 |
-
mask = mask[:, None, :]
|
208 |
-
mask = mask.repeat(1, z.shape[1], 1)
|
209 |
mask[:, : self.n_conditioning_codebooks, :] = 0.0
|
210 |
|
211 |
-
|
212 |
-
|
213 |
|
214 |
z, mask = self.add_noise(z, r=r[0], random_x=None, mask=mask)
|
215 |
z_init = z.clone()
|
@@ -228,8 +223,8 @@ class VampBase(at.ml.BaseModel):
|
|
228 |
|
229 |
z = self.sample_from_logits(
|
230 |
logits,
|
231 |
-
|
232 |
-
|
233 |
sample=sample,
|
234 |
typical_filtering=typical_filtering,
|
235 |
typical_mass=typical_mass,
|
@@ -323,7 +318,7 @@ class VampBase(at.ml.BaseModel):
|
|
323 |
# how many codebooks are we inferring vs conditioning on?
|
324 |
n_infer_codebooks = self.n_codebooks - self.n_conditioning_codebooks
|
325 |
|
326 |
-
for i in
|
327 |
# our current temperature
|
328 |
tmpt = temperature[i]
|
329 |
|
@@ -450,7 +445,7 @@ class VampBase(at.ml.BaseModel):
|
|
450 |
probs = torch.softmax(logits, dim=-1)
|
451 |
inferred = torch.stack([pr.multinomial(1).squeeze(-1) for pr in probs])
|
452 |
elif sample == "argmax":
|
453 |
-
inferred = torch.softmax(
|
454 |
elif sample == "gumbel":
|
455 |
inferred = gumbel_sample(logits, dim=-1)
|
456 |
else:
|
|
|
153 |
sampling_steps: int = 12,
|
154 |
start_tokens: Optional[torch.Tensor] = None,
|
155 |
mask: Optional[torch.Tensor] = None,
|
|
|
156 |
temperature: Union[float, Tuple[float, float]] = 1.0,
|
157 |
top_k: int = None,
|
158 |
sample: str = "gumbel",
|
|
|
163 |
typical_min_tokens=1,
|
164 |
return_signal=True,
|
165 |
):
|
166 |
+
|
167 |
+
r = torch.linspace(0, 1, sampling_steps + 1)[:-1][:, None].to(self.device)
|
168 |
if renoise_steps == None:
|
169 |
renoise_steps = sampling_steps - 1
|
170 |
|
|
|
186 |
if self.noise_mode == "noise":
|
187 |
z = torch.randint(
|
188 |
0, self.vocab_size, size=(1, self.n_codebooks, time_steps)
|
189 |
+
).to(self.device)
|
190 |
elif self.noise_mode == "mask":
|
191 |
z = torch.full((1, self.n_codebooks, time_steps), self.mask_token)
|
192 |
else:
|
|
|
197 |
assert z.shape[0] == 1, f"batch size must be 1"
|
198 |
|
199 |
if mask is None:
|
200 |
+
mask = torch.ones(z.shape[0], z.shape[-1]).to(self.device).int()
|
201 |
+
mask = mask[:, None, :]
|
202 |
+
mask = mask.repeat(1, z.shape[1], 1)
|
203 |
+
|
|
|
|
|
|
|
|
|
|
|
204 |
mask[:, : self.n_conditioning_codebooks, :] = 0.0
|
205 |
|
206 |
+
|
207 |
+
z_true = z.clone()
|
208 |
|
209 |
z, mask = self.add_noise(z, r=r[0], random_x=None, mask=mask)
|
210 |
z_init = z.clone()
|
|
|
223 |
|
224 |
z = self.sample_from_logits(
|
225 |
logits,
|
226 |
+
top_k=top_k,
|
227 |
+
temperature=tmpt,
|
228 |
sample=sample,
|
229 |
typical_filtering=typical_filtering,
|
230 |
typical_mass=typical_mass,
|
|
|
318 |
# how many codebooks are we inferring vs conditioning on?
|
319 |
n_infer_codebooks = self.n_codebooks - self.n_conditioning_codebooks
|
320 |
|
321 |
+
for i in range(sampling_steps):
|
322 |
# our current temperature
|
323 |
tmpt = temperature[i]
|
324 |
|
|
|
445 |
probs = torch.softmax(logits, dim=-1)
|
446 |
inferred = torch.stack([pr.multinomial(1).squeeze(-1) for pr in probs])
|
447 |
elif sample == "argmax":
|
448 |
+
inferred = torch.softmax(logits, dim=-1).argmax(dim=-1)
|
449 |
elif sample == "gumbel":
|
450 |
inferred = gumbel_sample(logits, dim=-1)
|
451 |
else:
|