Spaces:
Running
on
T4
Running
on
T4
Hugo Flores Garcia
commited on
Commit
•
4d0cbfe
1
Parent(s):
85e8a86
tiny sampling refactor
Browse files- vampnet/modules/transformer.py +55 -54
vampnet/modules/transformer.py
CHANGED
@@ -724,7 +724,7 @@ class VampNet(at.ml.BaseModel):
|
|
724 |
|
725 |
logits = torch.log(probs)
|
726 |
|
727 |
-
z_inferred =
|
728 |
logits=logits,
|
729 |
top_k=top_k,
|
730 |
temperature=tmpt,
|
@@ -742,61 +742,60 @@ class VampNet(at.ml.BaseModel):
|
|
742 |
else:
|
743 |
return z
|
744 |
|
745 |
-
|
746 |
-
|
747 |
-
|
748 |
-
|
749 |
-
|
750 |
-
|
751 |
-
|
752 |
-
|
753 |
-
|
754 |
-
|
755 |
-
|
756 |
-
|
757 |
-
|
758 |
-
|
759 |
-
|
760 |
-
|
761 |
-
|
762 |
-
|
763 |
-
|
764 |
-
|
765 |
-
|
766 |
-
|
767 |
-
|
768 |
-
|
769 |
-
|
770 |
-
|
771 |
-
|
772 |
-
|
773 |
-
|
774 |
-
|
775 |
-
)
|
776 |
|
777 |
-
|
778 |
-
|
779 |
-
|
780 |
-
|
781 |
-
|
782 |
-
|
783 |
-
|
784 |
-
|
785 |
-
|
786 |
-
|
787 |
-
|
788 |
-
|
789 |
-
if sample == "multinomial":
|
790 |
-
probs = torch.softmax(logits, dim=-1)
|
791 |
-
inferred = torch.stack([pr.multinomial(1).squeeze(-1) for pr in probs])
|
792 |
-
elif sample == "argmax":
|
793 |
-
inferred = torch.softmax(logits, dim=-1).argmax(dim=-1)
|
794 |
-
elif sample == "gumbel":
|
795 |
-
inferred = gumbel_sample(logits, dim=-1)
|
796 |
-
else:
|
797 |
-
raise ValueError(f"invalid sampling method: {sample}")
|
798 |
|
799 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
800 |
|
801 |
|
802 |
|
@@ -833,3 +832,5 @@ if __name__ == "__main__":
|
|
833 |
args = argbind.parse_args()
|
834 |
with argbind.scope(args):
|
835 |
try_model()
|
|
|
|
|
|
724 |
|
725 |
logits = torch.log(probs)
|
726 |
|
727 |
+
z_inferred = sample_from_logits(
|
728 |
logits=logits,
|
729 |
top_k=top_k,
|
730 |
temperature=tmpt,
|
|
|
742 |
else:
|
743 |
return z
|
744 |
|
745 |
+
def sample_from_logits(
|
746 |
+
logits,
|
747 |
+
top_k: int = None,
|
748 |
+
temperature: float = 1.0,
|
749 |
+
sample: str = "multinomial",
|
750 |
+
typical_filtering=False,
|
751 |
+
typical_mass=0.2,
|
752 |
+
typical_min_tokens=1,
|
753 |
+
):
|
754 |
+
# add temperature
|
755 |
+
logits = logits / temperature
|
756 |
+
|
757 |
+
# add topk
|
758 |
+
if top_k is not None and typical_filtering == False:
|
759 |
+
v, topk_idx = logits.topk(top_k)
|
760 |
+
logits[logits < v[..., [-1]]] = -float("inf")
|
761 |
+
|
762 |
+
if typical_filtering:
|
763 |
+
assert top_k is None
|
764 |
+
nb, nt, _ = logits.shape
|
765 |
+
x_flat = rearrange(logits, "b t l -> (b t ) l")
|
766 |
+
x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1)
|
767 |
+
x_flat_norm_p = torch.exp(x_flat_norm)
|
768 |
+
entropy = -(x_flat_norm * x_flat_norm_p).nansum(-1, keepdim=True)
|
769 |
+
|
770 |
+
c_flat_shifted = torch.abs((-x_flat_norm) - entropy)
|
771 |
+
c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False)
|
772 |
+
x_flat_cumsum = (
|
773 |
+
x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1)
|
774 |
+
)
|
|
|
775 |
|
776 |
+
last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1)
|
777 |
+
sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(
|
778 |
+
1, last_ind.view(-1, 1)
|
779 |
+
)
|
780 |
+
if typical_min_tokens > 1:
|
781 |
+
sorted_indices_to_remove[..., :typical_min_tokens] = 0
|
782 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
783 |
+
1, x_flat_indices, sorted_indices_to_remove
|
784 |
+
)
|
785 |
+
x_flat = x_flat.masked_fill(indices_to_remove, -float("Inf"))
|
786 |
+
logits = rearrange(x_flat, "(b t) l -> b t l", t=nt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
787 |
|
788 |
+
if sample == "multinomial":
|
789 |
+
probs = torch.softmax(logits, dim=-1)
|
790 |
+
inferred = torch.stack([pr.multinomial(1).squeeze(-1) for pr in probs])
|
791 |
+
elif sample == "argmax":
|
792 |
+
inferred = torch.softmax(logits, dim=-1).argmax(dim=-1)
|
793 |
+
elif sample == "gumbel":
|
794 |
+
inferred = gumbel_sample(logits, dim=-1)
|
795 |
+
else:
|
796 |
+
raise ValueError(f"invalid sampling method: {sample}")
|
797 |
+
|
798 |
+
return inferred
|
799 |
|
800 |
|
801 |
|
|
|
832 |
args = argbind.parse_args()
|
833 |
with argbind.scope(args):
|
834 |
try_model()
|
835 |
+
|
836 |
+
|