kevinwang676 committed
Commit 0fc73ad • Parent(s): d7de8f4

Update bark/generation.py

bark/generation.py (+9, -11)

bark/generation.py CHANGED
@@ -494,8 +494,7 @@ def generate_text_semantic(
             )
             if top_p is not None:
                 # faster to convert to numpy
-                logits_device = relevant_logits.device
-                logits_dtype = relevant_logits.type()
+                original_device = relevant_logits.device
                 relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
                 sorted_indices = np.argsort(relevant_logits)[::-1]
                 sorted_logits = relevant_logits[sorted_indices]
@@ -505,7 +504,7 @@ def generate_text_semantic(
                 sorted_indices_to_remove[0] = False
                 relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                 relevant_logits = torch.from_numpy(relevant_logits)
-                relevant_logits = relevant_logits.to(logits_device).type(logits_dtype)
+                relevant_logits = relevant_logits.to(original_device)
                 if top_k is not None:
                     v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                     relevant_logits[relevant_logits < v[-1]] = -float("Inf")
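Note: the two hunks above change the nucleus (top-p) sampling path in generate_text_semantic. The logits are round-tripped through numpy for speed, and since torch.from_numpy always returns a CPU tensor, the patched code records the tensor's device up front and moves the filtered logits back afterwards. A self-contained sketch of the patched pattern (the helper name and its packaging as a function are ours, not bark's):

import numpy as np
import torch
from scipy.special import softmax  # bark/generation.py uses scipy's softmax

def top_p_filter(relevant_logits: torch.Tensor, top_p: float) -> torch.Tensor:
    # Remember where the tensor lives before the CPU/numpy round trip.
    original_device = relevant_logits.device
    logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
    sorted_indices = np.argsort(logits)[::-1]            # token ids, best first
    sorted_logits = logits[sorted_indices]
    cumulative_probs = np.cumsum(softmax(sorted_logits))
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift right so the token that crosses the threshold is still kept,
    # and never remove the single most likely token.
    sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy()
    sorted_indices_to_remove[0] = False
    logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
    # torch.from_numpy yields a CPU tensor; restore the original device.
    return torch.from_numpy(logits).to(original_device)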
@@ -599,10 +598,10 @@ def generate_coarse(
         and x_coarse_history.shape[-1] >= 0
         and x_coarse_history.min() >= 0
         and x_coarse_history.max() <= CODEBOOK_SIZE - 1
-        and (
-            round(x_coarse_history.shape[-1] / len(x_semantic_history), 1)
-            == round(semantic_to_coarse_ratio / N_COARSE_CODEBOOKS, 1)
-        )
+        #and (
+        #    round(x_coarse_history.shape[-1] / len(x_semantic_history), 1)
+        #    == round(semantic_to_coarse_ratio / N_COARSE_CODEBOOKS, 1)
+        #)
     )
     x_coarse_history = _flatten_codebooks(x_coarse_history) + SEMANTIC_VOCAB_SIZE
     # trim histories correctly
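Note: the hunk above disables a sanity check tying the coarse history length to the semantic history length in a voice prompt. A small worked sketch of what the check enforced, using the constants defined in bark/generation.py (the example history lengths are made up):

# Constants as defined in bark/generation.py.
COARSE_RATE_HZ = 75
SEMANTIC_RATE_HZ = 49.9
N_COARSE_CODEBOOKS = 2

semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS

# Per codebook, the assert expects roughly 1.5 coarse tokens per semantic token.
expected = round(semantic_to_coarse_ratio / N_COARSE_CODEBOOKS, 1)  # 1.5

# Hypothetical prompt: 700 semantic tokens, 1000 coarse tokens per codebook.
actual = round(1000 / 700, 1)  # 1.4

print(expected == actual)  # False: the original assert would reject this prompt

Commenting the clause out lets such off-ratio histories (for example, hand-assembled voice prompts) pass the remaining shape and range checks.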
@@ -685,8 +684,7 @@ def generate_coarse(
                 relevant_logits = logits[0, 0, logit_start_idx:logit_end_idx]
                 if top_p is not None:
                     # faster to convert to numpy
-                    logits_device = relevant_logits.device
-                    logits_dtype = relevant_logits.type()
+                    original_device = relevant_logits.device
                     relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
                     sorted_indices = np.argsort(relevant_logits)[::-1]
                     sorted_logits = relevant_logits[sorted_indices]
@@ -696,7 +694,7 @@ def generate_coarse(
                     sorted_indices_to_remove[0] = False
                     relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                     relevant_logits = torch.from_numpy(relevant_logits)
-                    relevant_logits = relevant_logits.to(logits_device).type(logits_dtype)
+                    relevant_logits = relevant_logits.to(original_device)
                     if top_k is not None:
                         v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                         relevant_logits[relevant_logits < v[-1]] = -float("Inf")
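Note: these two hunks apply the same device fix inside generate_coarse. The adjacent, unchanged top_k branch visible in the context is ordinary top-k filtering; as a standalone sketch (the wrapper function is ours, not bark's):

import torch

def top_k_filter(relevant_logits: torch.Tensor, top_k: int) -> torch.Tensor:
    # v is sorted descending, so v[-1] is the smallest logit among the k best.
    v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
    out = relevant_logits.clone()  # bark mutates in place; we copy for clarity
    out[out < v[-1]] = -float("Inf")
    return out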
@@ -862,4 +860,4 @@ def codec_decode(fine_tokens):
     del arr, emb, out
     if OFFLOAD_CPU:
         model.to("cpu")
-    return audio_arr
+    return audio_arr
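Note: the context of the final hunk shows bark's CPU-offload pattern at the end of codec_decode. A minimal sketch of that pattern, assuming OFFLOAD_CPU is bark's module-level flag (bark derives it from the SUNO_OFFLOAD_CPU environment variable; the function name and placeholder body here are ours):

import os
import torch

# A flag like the one bark reads from the environment.
OFFLOAD_CPU = os.environ.get("SUNO_OFFLOAD_CPU", "") in ("True", "true", "1")

def decode_with_offload(model: torch.nn.Module, device: torch.device):
    model.to(device)        # move weights onto the compute device for this call
    audio_arr = None        # placeholder for the decoded waveform
    # ... EnCodec decoding would happen here ...
    if OFFLOAD_CPU:
        model.to("cpu")     # park weights on CPU between calls to save VRAM
    return audio_arr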