Spaces:
Running
on
T4
Running
on
T4
Hugo Flores
commited on
Commit
•
326b5bb
1
Parent(s):
04c5b94
fix: sample prefix suffix
Browse files- scripts/exp/train.py +4 -3
scripts/exp/train.py
CHANGED
@@ -216,6 +216,7 @@ def accuracy(
|
|
216 |
return accuracy
|
217 |
|
218 |
def sample_prefix_suffix_amt(
|
|
|
219 |
n_batch,
|
220 |
prefix_amt,
|
221 |
suffix_amt,
|
@@ -362,7 +363,7 @@ def train(
|
|
362 |
n_batch = z.shape[0]
|
363 |
r = rng.draw(n_batch)[:, 0].to(accel.device)
|
364 |
|
365 |
-
n_prefix, n_suffix = sample_prefix_suffix_amt(
|
366 |
n_batch=n_batch, prefix_amt=prefix_amt, suffix_amt=suffix_amt,
|
367 |
prefix_dropout=prefix_dropout, suffix_dropout=suffix_dropout,
|
368 |
rng=rng
|
@@ -448,7 +449,7 @@ def train(
|
|
448 |
n_batch = z.shape[0]
|
449 |
r = rng.draw(n_batch)[:, 0].to(accel.device)
|
450 |
|
451 |
-
n_prefix, n_suffix = sample_prefix_suffix_amt(
|
452 |
n_batch=n_batch, prefix_amt=prefix_amt, suffix_amt=suffix_amt,
|
453 |
prefix_dropout=prefix_dropout, suffix_dropout=suffix_dropout,
|
454 |
rng=rng
|
@@ -606,7 +607,7 @@ def train(
|
|
606 |
|
607 |
n_batch = z.shape[0]
|
608 |
|
609 |
-
n_prefix, n_suffix = sample_prefix_suffix_amt(
|
610 |
n_batch=n_batch, prefix_amt=prefix_amt, suffix_amt=suffix_amt,
|
611 |
prefix_dropout=prefix_dropout, suffix_dropout=suffix_dropout,
|
612 |
rng=rng
|
|
|
216 |
return accuracy
|
217 |
|
218 |
def sample_prefix_suffix_amt(
|
219 |
+
z,
|
220 |
n_batch,
|
221 |
prefix_amt,
|
222 |
suffix_amt,
|
|
|
363 |
n_batch = z.shape[0]
|
364 |
r = rng.draw(n_batch)[:, 0].to(accel.device)
|
365 |
|
366 |
+
n_prefix, n_suffix = sample_prefix_suffix_amt(z=z,
|
367 |
n_batch=n_batch, prefix_amt=prefix_amt, suffix_amt=suffix_amt,
|
368 |
prefix_dropout=prefix_dropout, suffix_dropout=suffix_dropout,
|
369 |
rng=rng
|
|
|
449 |
n_batch = z.shape[0]
|
450 |
r = rng.draw(n_batch)[:, 0].to(accel.device)
|
451 |
|
452 |
+
n_prefix, n_suffix = sample_prefix_suffix_amt(z=z,
|
453 |
n_batch=n_batch, prefix_amt=prefix_amt, suffix_amt=suffix_amt,
|
454 |
prefix_dropout=prefix_dropout, suffix_dropout=suffix_dropout,
|
455 |
rng=rng
|
|
|
607 |
|
608 |
n_batch = z.shape[0]
|
609 |
|
610 |
+
n_prefix, n_suffix = sample_prefix_suffix_amt(z=z,
|
611 |
n_batch=n_batch, prefix_amt=prefix_amt, suffix_amt=suffix_amt,
|
612 |
prefix_dropout=prefix_dropout, suffix_dropout=suffix_dropout,
|
613 |
rng=rng
|