0308-022448-Synchronize_GitHub_update_improve_inference_speed
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- AR/__pycache__/__init__.cpython-310.pyc +0 -0
- AR/data/bucket_sampler.py +2 -1
- AR/data/data_module.py +4 -2
- AR/data/dataset.py +2 -1
- AR/models/__pycache__/__init__.cpython-310.pyc +0 -0
- AR/models/__pycache__/t2s_lightning_module.cpython-310.pyc +0 -0
- AR/models/__pycache__/t2s_model.cpython-310.pyc +0 -0
- AR/models/__pycache__/utils.cpython-310.pyc +0 -0
- AR/models/t2s_lightning_module.py +4 -3
- AR/models/t2s_lightning_module_onnx.py +2 -1
- AR/models/t2s_model.py +165 -44
- AR/models/t2s_model_onnx.py +2 -1
- AR/models/utils.py +72 -3
- AR/modules/__pycache__/__init__.cpython-310.pyc +0 -0
- AR/modules/__pycache__/activation.cpython-310.pyc +0 -0
- AR/modules/__pycache__/embedding.cpython-310.pyc +0 -0
- AR/modules/__pycache__/lr_schedulers.cpython-310.pyc +0 -0
- AR/modules/__pycache__/optim.cpython-310.pyc +0 -0
- AR/modules/__pycache__/patched_mha_with_cache.cpython-310.pyc +0 -0
- AR/modules/__pycache__/scaling.cpython-310.pyc +0 -0
- AR/modules/__pycache__/transformer.cpython-310.pyc +0 -0
- AR/modules/lr_schedulers.py +2 -1
- AR/modules/patched_mha_with_cache.py +4 -2
- AR/modules/scaling.py +1 -1
- AR/text_processing/phonemizer.py +2 -1
- AR/text_processing/symbols.py +2 -1
- MODELS/21/1.mp3 +0 -0
- MODELS/21/11.mp3 +0 -0
- MODELS/21/191.mp3 +0 -0
- MODELS/21/21.ckpt +0 -3
- MODELS/21/21.pth +0 -3
- MODELS/21/s1.mp3 +0 -0
- MODELS/21/s2.mp3 +0 -0
- MODELS/21/s3.mp3 +0 -0
- MODELS/22/22.ckpt +0 -3
- MODELS/22/22.pth +0 -3
- MODELS/22/passion.mp3 +0 -0
- MODELS/22/s1.mp3 +0 -0
- MODELS/22/s2.mp3 +0 -0
- MODELS/22/s3.mp3 +0 -0
- MODELS/22/slow_calm.mp3 +0 -0
- MODELS/22/speed.mp3 +0 -0
- MODELS/31/1.mp3 +0 -0
- MODELS/31/148.mp3 +0 -0
- MODELS/31/31.ckpt +0 -3
- MODELS/31/31.pth +0 -3
- MODELS/31/96.mp3 +0 -0
- MODELS/31/s1.mp3 +0 -0
- MODELS/31/s2.mp3 +0 -0
- MODELS/31/s3.mp3 +0 -0
AR/__pycache__/__init__.cpython-310.pyc
CHANGED
Binary files a/AR/__pycache__/__init__.cpython-310.pyc and b/AR/__pycache__/__init__.cpython-310.pyc differ
|
|
AR/data/bucket_sampler.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
-
# modified from https://github.com/
|
|
|
2 |
import itertools
|
3 |
import math
|
4 |
import random
|
|
|
1 |
+
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/bucket_sampler.py
|
2 |
+
# reference: https://github.com/lifeiteng/vall-e
|
3 |
import itertools
|
4 |
import math
|
5 |
import random
|
AR/data/data_module.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
-
# modified from https://github.com/
|
|
|
2 |
from pytorch_lightning import LightningDataModule
|
3 |
from AR.data.bucket_sampler import DistributedBucketSampler
|
4 |
from AR.data.dataset import Text2SemanticDataset
|
@@ -41,7 +42,8 @@ class Text2SemanticDataModule(LightningDataModule):
|
|
41 |
# pad_val=self.config['data']['pad_val'])
|
42 |
|
43 |
def train_dataloader(self):
|
44 |
-
batch_size
|
|
|
45 |
sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
|
46 |
return DataLoader(
|
47 |
self._train_dataset,
|
|
|
1 |
+
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/data_module.py
|
2 |
+
# reference: https://github.com/lifeiteng/vall-e
|
3 |
from pytorch_lightning import LightningDataModule
|
4 |
from AR.data.bucket_sampler import DistributedBucketSampler
|
5 |
from AR.data.dataset import Text2SemanticDataset
|
|
|
42 |
# pad_val=self.config['data']['pad_val'])
|
43 |
|
44 |
def train_dataloader(self):
|
45 |
+
batch_size=self.config["train"]["batch_size"]//2 if self.config["train"].get("if_dpo",False)==True else self.config["train"]["batch_size"]
|
46 |
+
batch_size = max(min(batch_size,len(self._train_dataset)//4),1)#防止不保存
|
47 |
sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
|
48 |
return DataLoader(
|
49 |
self._train_dataset,
|
AR/data/dataset.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
-
# modified from https://github.com/
|
|
|
2 |
import pdb
|
3 |
import sys
|
4 |
|
|
|
1 |
+
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/dataset.py
|
2 |
+
# reference: https://github.com/lifeiteng/vall-e
|
3 |
import pdb
|
4 |
import sys
|
5 |
|
AR/models/__pycache__/__init__.cpython-310.pyc
CHANGED
Binary files a/AR/models/__pycache__/__init__.cpython-310.pyc and b/AR/models/__pycache__/__init__.cpython-310.pyc differ
|
|
AR/models/__pycache__/t2s_lightning_module.cpython-310.pyc
CHANGED
Binary files a/AR/models/__pycache__/t2s_lightning_module.cpython-310.pyc and b/AR/models/__pycache__/t2s_lightning_module.cpython-310.pyc differ
|
|
AR/models/__pycache__/t2s_model.cpython-310.pyc
CHANGED
Binary files a/AR/models/__pycache__/t2s_model.cpython-310.pyc and b/AR/models/__pycache__/t2s_model.cpython-310.pyc differ
|
|
AR/models/__pycache__/utils.cpython-310.pyc
CHANGED
Binary files a/AR/models/__pycache__/utils.cpython-310.pyc and b/AR/models/__pycache__/utils.cpython-310.pyc differ
|
|
AR/models/t2s_lightning_module.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
-
# modified from https://github.com/
|
|
|
2 |
import os, sys
|
3 |
|
4 |
now_dir = os.getcwd()
|
@@ -11,7 +12,6 @@ from AR.models.t2s_model import Text2SemanticDecoder
|
|
11 |
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
|
12 |
from AR.modules.optim import ScaledAdam
|
13 |
|
14 |
-
|
15 |
class Text2SemanticLightningModule(LightningModule):
|
16 |
def __init__(self, config, output_dir, is_train=True):
|
17 |
super().__init__()
|
@@ -35,7 +35,8 @@ class Text2SemanticLightningModule(LightningModule):
|
|
35 |
def training_step(self, batch: Dict, batch_idx: int):
|
36 |
opt = self.optimizers()
|
37 |
scheduler = self.lr_schedulers()
|
38 |
-
|
|
|
39 |
batch["phoneme_ids"],
|
40 |
batch["phoneme_ids_len"],
|
41 |
batch["semantic_ids"],
|
|
|
1 |
+
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
|
2 |
+
# reference: https://github.com/lifeiteng/vall-e
|
3 |
import os, sys
|
4 |
|
5 |
now_dir = os.getcwd()
|
|
|
12 |
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
|
13 |
from AR.modules.optim import ScaledAdam
|
14 |
|
|
|
15 |
class Text2SemanticLightningModule(LightningModule):
|
16 |
def __init__(self, config, output_dir, is_train=True):
|
17 |
super().__init__()
|
|
|
35 |
def training_step(self, batch: Dict, batch_idx: int):
|
36 |
opt = self.optimizers()
|
37 |
scheduler = self.lr_schedulers()
|
38 |
+
forward=self.model.forward if self.config["train"].get("if_dpo",False)==True else self.model.forward_old
|
39 |
+
loss, acc = forward(
|
40 |
batch["phoneme_ids"],
|
41 |
batch["phoneme_ids_len"],
|
42 |
batch["semantic_ids"],
|
AR/models/t2s_lightning_module_onnx.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
-
# modified from https://github.com/
|
|
|
2 |
import os, sys
|
3 |
|
4 |
now_dir = os.getcwd()
|
|
|
1 |
+
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
|
2 |
+
# reference: https://github.com/lifeiteng/vall-e
|
3 |
import os, sys
|
4 |
|
5 |
now_dir = os.getcwd()
|
AR/models/t2s_model.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
-
# modified from https://github.com/
|
|
|
2 |
import torch
|
3 |
from tqdm import tqdm
|
4 |
|
@@ -8,6 +9,9 @@ from AR.models.utils import (
|
|
8 |
sample,
|
9 |
logits_to_probs,
|
10 |
multinomial_sample_one_no_sync,
|
|
|
|
|
|
|
11 |
)
|
12 |
from AR.modules.embedding import SinePositionalEmbedding
|
13 |
from AR.modules.embedding import TokenEmbedding
|
@@ -85,11 +89,104 @@ class Text2SemanticDecoder(nn.Module):
|
|
85 |
ignore_index=self.EOS,
|
86 |
)
|
87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
def forward(self, x, x_lens, y, y_lens, bert_feature):
|
89 |
"""
|
90 |
x: phoneme_ids
|
91 |
y: semantic_ids
|
92 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
x = self.ar_text_embedding(x)
|
94 |
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
95 |
x = self.ar_text_position(x)
|
@@ -231,6 +328,7 @@ class Text2SemanticDecoder(nn.Module):
|
|
231 |
prompts, ####参考音频token
|
232 |
bert_feature,
|
233 |
top_k: int = -100,
|
|
|
234 |
early_stop_num: int = -1,
|
235 |
temperature: float = 1.0,
|
236 |
):
|
@@ -240,7 +338,7 @@ class Text2SemanticDecoder(nn.Module):
|
|
240 |
|
241 |
# AR Decoder
|
242 |
y = prompts
|
243 |
-
|
244 |
x_len = x.shape[1]
|
245 |
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
|
246 |
stop = False
|
@@ -256,47 +354,41 @@ class Text2SemanticDecoder(nn.Module):
|
|
256 |
"first_infer": 1,
|
257 |
"stage": 0,
|
258 |
}
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
|
265 |
-
)
|
266 |
-
cache["y_emb"] = y_emb
|
267 |
y_pos = self.ar_audio_position(y_emb)
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
y_len =
|
274 |
-
|
275 |
-
|
276 |
-
|
|
|
|
|
|
|
|
|
277 |
x_attn_mask,
|
278 |
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
|
279 |
value=True,
|
280 |
)
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
###最下面一行(是对的)
|
294 |
-
xy_attn_mask = torch.zeros(
|
295 |
-
(1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
|
296 |
-
)
|
297 |
-
# pdb.set_trace()
|
298 |
-
###缓存重头戏
|
299 |
-
# print(1111,xy_pos.shape,xy_attn_mask.shape,x_len,y_len)
|
300 |
xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
|
301 |
logits = self.ar_predict_layer(
|
302 |
xy_dec[:, -1]
|
@@ -305,8 +397,12 @@ class Text2SemanticDecoder(nn.Module):
|
|
305 |
if(idx==0):###第一次跑不能EOS否则没有了
|
306 |
logits = logits[:, :-1] ###刨除1024终止符号的概率
|
307 |
samples = sample(
|
308 |
-
logits[0], y, top_k=top_k, top_p=
|
309 |
)[0].unsqueeze(0)
|
|
|
|
|
|
|
|
|
310 |
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
311 |
print("use early stop num:", early_stop_num)
|
312 |
stop = True
|
@@ -315,13 +411,38 @@ class Text2SemanticDecoder(nn.Module):
|
|
315 |
# print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
|
316 |
stop = True
|
317 |
if stop:
|
318 |
-
if prompts.shape[1] == y.shape[1]:
|
|
|
|
|
|
|
319 |
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
|
320 |
print("bad zero prediction")
|
321 |
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
|
322 |
break
|
323 |
-
|
324 |
-
|
325 |
-
y = torch.concat([y, samples], dim=1)
|
326 |
cache["first_infer"] = 0
|
327 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
|
2 |
+
# reference: https://github.com/lifeiteng/vall-e
|
3 |
import torch
|
4 |
from tqdm import tqdm
|
5 |
|
|
|
9 |
sample,
|
10 |
logits_to_probs,
|
11 |
multinomial_sample_one_no_sync,
|
12 |
+
dpo_loss,
|
13 |
+
make_reject_y,
|
14 |
+
get_batch_logps
|
15 |
)
|
16 |
from AR.modules.embedding import SinePositionalEmbedding
|
17 |
from AR.modules.embedding import TokenEmbedding
|
|
|
89 |
ignore_index=self.EOS,
|
90 |
)
|
91 |
|
92 |
+
def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
|
93 |
+
x = self.ar_text_embedding(x)
|
94 |
+
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
95 |
+
x = self.ar_text_position(x)
|
96 |
+
x_mask = make_pad_mask(x_lens)
|
97 |
+
|
98 |
+
y_mask = make_pad_mask(y_lens)
|
99 |
+
y_mask_int = y_mask.type(torch.int64)
|
100 |
+
codes = y.type(torch.int64) * (1 - y_mask_int)
|
101 |
+
|
102 |
+
# Training
|
103 |
+
# AR Decoder
|
104 |
+
y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
|
105 |
+
x_len = x_lens.max()
|
106 |
+
y_len = y_lens.max()
|
107 |
+
y_emb = self.ar_audio_embedding(y)
|
108 |
+
y_pos = self.ar_audio_position(y_emb)
|
109 |
+
|
110 |
+
xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
|
111 |
+
|
112 |
+
ar_xy_padding_mask = xy_padding_mask
|
113 |
+
|
114 |
+
x_attn_mask = F.pad(
|
115 |
+
torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
|
116 |
+
(0, y_len),
|
117 |
+
value=True,
|
118 |
+
)
|
119 |
+
|
120 |
+
y_attn_mask = F.pad(
|
121 |
+
torch.triu(
|
122 |
+
torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
|
123 |
+
diagonal=1,
|
124 |
+
),
|
125 |
+
(x_len, 0),
|
126 |
+
value=False,
|
127 |
+
)
|
128 |
+
|
129 |
+
xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
|
130 |
+
bsz, src_len = x.shape[0], x_len + y_len
|
131 |
+
_xy_padding_mask = (
|
132 |
+
ar_xy_padding_mask.view(bsz, 1, 1, src_len)
|
133 |
+
.expand(-1, self.num_head, -1, -1)
|
134 |
+
.reshape(bsz * self.num_head, 1, src_len)
|
135 |
+
)
|
136 |
+
xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
|
137 |
+
new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
|
138 |
+
new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
|
139 |
+
xy_attn_mask = new_attn_mask
|
140 |
+
# x 和完整的 y 一次性输入模型
|
141 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
142 |
+
|
143 |
+
return xy_pos, xy_attn_mask, targets
|
144 |
+
|
145 |
def forward(self, x, x_lens, y, y_lens, bert_feature):
|
146 |
"""
|
147 |
x: phoneme_ids
|
148 |
y: semantic_ids
|
149 |
"""
|
150 |
+
|
151 |
+
reject_y, reject_y_lens = make_reject_y(y, y_lens)
|
152 |
+
|
153 |
+
xy_pos, xy_attn_mask, targets = self.make_input_data(x, x_lens, y, y_lens, bert_feature)
|
154 |
+
|
155 |
+
xy_dec, _ = self.h(
|
156 |
+
(xy_pos, None),
|
157 |
+
mask=xy_attn_mask,
|
158 |
+
)
|
159 |
+
x_len = x_lens.max()
|
160 |
+
logits = self.ar_predict_layer(xy_dec[:, x_len:])
|
161 |
+
|
162 |
+
###### DPO #############
|
163 |
+
reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature)
|
164 |
+
|
165 |
+
reject_xy_dec, _ = self.h(
|
166 |
+
(reject_xy_pos, None),
|
167 |
+
mask=reject_xy_attn_mask,
|
168 |
+
)
|
169 |
+
x_len = x_lens.max()
|
170 |
+
reject_logits = self.ar_predict_layer(reject_xy_dec[:, x_len:])
|
171 |
+
|
172 |
+
# loss
|
173 |
+
# from feiteng: 每次 duration 越多, 梯度更新也应该更多, 所以用 sum
|
174 |
+
|
175 |
+
loss_1 = F.cross_entropy(logits.permute(0, 2, 1), targets, reduction="sum")
|
176 |
+
acc = self.ar_accuracy_metric(logits.permute(0, 2, 1).detach(), targets).item()
|
177 |
+
|
178 |
+
A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
|
179 |
+
loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)
|
180 |
+
|
181 |
+
loss = loss_1 + loss_2
|
182 |
+
|
183 |
+
return loss, acc
|
184 |
+
|
185 |
+
def forward_old(self, x, x_lens, y, y_lens, bert_feature):
|
186 |
+
"""
|
187 |
+
x: phoneme_ids
|
188 |
+
y: semantic_ids
|
189 |
+
"""
|
190 |
x = self.ar_text_embedding(x)
|
191 |
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
192 |
x = self.ar_text_position(x)
|
|
|
328 |
prompts, ####参考音频token
|
329 |
bert_feature,
|
330 |
top_k: int = -100,
|
331 |
+
top_p: int = 100,
|
332 |
early_stop_num: int = -1,
|
333 |
temperature: float = 1.0,
|
334 |
):
|
|
|
338 |
|
339 |
# AR Decoder
|
340 |
y = prompts
|
341 |
+
|
342 |
x_len = x.shape[1]
|
343 |
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
|
344 |
stop = False
|
|
|
354 |
"first_infer": 1,
|
355 |
"stage": 0,
|
356 |
}
|
357 |
+
################### first step ##########################
|
358 |
+
if y is not None:
|
359 |
+
y_emb = self.ar_audio_embedding(y)
|
360 |
+
y_len = y_emb.shape[1]
|
361 |
+
prefix_len = y.shape[1]
|
|
|
|
|
|
|
362 |
y_pos = self.ar_audio_position(y_emb)
|
363 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
364 |
+
cache["y_emb"] = y_emb
|
365 |
+
ref_free = False
|
366 |
+
else:
|
367 |
+
y_emb = None
|
368 |
+
y_len = 0
|
369 |
+
prefix_len = 0
|
370 |
+
y_pos = None
|
371 |
+
xy_pos = x
|
372 |
+
y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
|
373 |
+
ref_free = True
|
374 |
+
|
375 |
+
x_attn_mask_pad = F.pad(
|
376 |
x_attn_mask,
|
377 |
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
|
378 |
value=True,
|
379 |
)
|
380 |
+
y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
|
381 |
+
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
|
382 |
+
(x_len, 0),
|
383 |
+
value=False,
|
384 |
+
)
|
385 |
+
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
|
386 |
+
x.device
|
387 |
+
)
|
388 |
+
|
389 |
+
|
390 |
+
for idx in tqdm(range(1500)):
|
391 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
392 |
xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
|
393 |
logits = self.ar_predict_layer(
|
394 |
xy_dec[:, -1]
|
|
|
397 |
if(idx==0):###第一次跑不能EOS否则没有了
|
398 |
logits = logits[:, :-1] ###刨除1024终止符号的概率
|
399 |
samples = sample(
|
400 |
+
logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
|
401 |
)[0].unsqueeze(0)
|
402 |
+
# 本次生成的 semantic_ids 和之前的 y 构成新的 y
|
403 |
+
# print(samples.shape)#[1,1]#第一个1是bs
|
404 |
+
y = torch.concat([y, samples], dim=1)
|
405 |
+
|
406 |
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
407 |
print("use early stop num:", early_stop_num)
|
408 |
stop = True
|
|
|
411 |
# print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
|
412 |
stop = True
|
413 |
if stop:
|
414 |
+
# if prompts.shape[1] == y.shape[1]:
|
415 |
+
# y = torch.concat([y, torch.zeros_like(samples)], dim=1)
|
416 |
+
# print("bad zero prediction")
|
417 |
+
if y.shape[1]==0:
|
418 |
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
|
419 |
print("bad zero prediction")
|
420 |
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
|
421 |
break
|
422 |
+
|
423 |
+
####################### update next step ###################################
|
|
|
424 |
cache["first_infer"] = 0
|
425 |
+
if cache["y_emb"] is not None:
|
426 |
+
y_emb = torch.cat(
|
427 |
+
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], dim = 1
|
428 |
+
)
|
429 |
+
cache["y_emb"] = y_emb
|
430 |
+
y_pos = self.ar_audio_position(y_emb)
|
431 |
+
xy_pos = y_pos[:, -1:]
|
432 |
+
else:
|
433 |
+
y_emb = self.ar_audio_embedding(y[:, -1:])
|
434 |
+
cache["y_emb"] = y_emb
|
435 |
+
y_pos = self.ar_audio_position(y_emb)
|
436 |
+
xy_pos = y_pos
|
437 |
+
y_len = y_pos.shape[1]
|
438 |
+
|
439 |
+
###最右边一列(是错的)
|
440 |
+
# xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
|
441 |
+
# xy_attn_mask[:,-1]=False
|
442 |
+
###最下面一行(是对的)
|
443 |
+
xy_attn_mask = torch.zeros(
|
444 |
+
(1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
|
445 |
+
)
|
446 |
+
if ref_free:
|
447 |
+
return y[:, :-1], 0
|
448 |
+
return y[:, :-1], idx-1
|
AR/models/t2s_model_onnx.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
-
# modified from https://github.com/
|
|
|
2 |
import torch
|
3 |
from tqdm import tqdm
|
4 |
|
|
|
1 |
+
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
|
2 |
+
# reference: https://github.com/lifeiteng/vall-e
|
3 |
import torch
|
4 |
from tqdm import tqdm
|
5 |
|
AR/models/utils.py
CHANGED
@@ -1,7 +1,8 @@
|
|
1 |
-
# modified from https://github.com/
|
|
|
2 |
import torch
|
3 |
import torch.nn.functional as F
|
4 |
-
|
5 |
|
6 |
def sequence_mask(length, max_length=None):
|
7 |
if max_length is None:
|
@@ -114,7 +115,8 @@ def logits_to_probs(
|
|
114 |
top_p: Optional[int] = None,
|
115 |
repetition_penalty: float = 1.0,
|
116 |
):
|
117 |
-
previous_tokens
|
|
|
118 |
# print(logits.shape,previous_tokens.shape)
|
119 |
# pdb.set_trace()
|
120 |
if previous_tokens is not None and repetition_penalty != 1.0:
|
@@ -158,3 +160,70 @@ def sample(
|
|
158 |
)
|
159 |
idx_next = multinomial_sample_one_no_sync(probs)
|
160 |
return idx_next, probs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/utils.py
|
2 |
+
# reference: https://github.com/lifeiteng/vall-e
|
3 |
import torch
|
4 |
import torch.nn.functional as F
|
5 |
+
from typing import Tuple
|
6 |
|
7 |
def sequence_mask(length, max_length=None):
|
8 |
if max_length is None:
|
|
|
115 |
top_p: Optional[int] = None,
|
116 |
repetition_penalty: float = 1.0,
|
117 |
):
|
118 |
+
if previous_tokens is not None:
|
119 |
+
previous_tokens = previous_tokens.squeeze()
|
120 |
# print(logits.shape,previous_tokens.shape)
|
121 |
# pdb.set_trace()
|
122 |
if previous_tokens is not None and repetition_penalty != 1.0:
|
|
|
160 |
)
|
161 |
idx_next = multinomial_sample_one_no_sync(probs)
|
162 |
return idx_next, probs
|
163 |
+
|
164 |
+
def dpo_loss(policy_chosen_logps: torch.FloatTensor,
|
165 |
+
policy_rejected_logps: torch.FloatTensor,
|
166 |
+
reference_chosen_logps: torch.FloatTensor,
|
167 |
+
reference_rejected_logps: torch.FloatTensor,
|
168 |
+
beta: float,
|
169 |
+
reference_free: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
170 |
+
pi_logratios = policy_chosen_logps - policy_rejected_logps
|
171 |
+
ref_logratios = reference_chosen_logps - reference_rejected_logps
|
172 |
+
|
173 |
+
if reference_free:
|
174 |
+
ref_logratios = 0
|
175 |
+
|
176 |
+
logits = pi_logratios - ref_logratios
|
177 |
+
|
178 |
+
losses = -F.logsigmoid(beta * logits)
|
179 |
+
chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
|
180 |
+
rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()
|
181 |
+
|
182 |
+
return losses.mean(), chosen_rewards, rejected_rewards
|
183 |
+
|
184 |
+
def get_batch_logps(logits_target: torch.FloatTensor, logits_reject: torch.FloatTensor, labels_target: torch.LongTensor, labels_reject: torch.LongTensor, average_log_prob: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
185 |
+
|
186 |
+
# dummy token; we'll ignore the losses on these tokens later
|
187 |
+
|
188 |
+
per_token_logps_target = torch.gather(logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)).squeeze(2)
|
189 |
+
per_token_logps_reject = torch.gather(logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)).squeeze(2)
|
190 |
+
|
191 |
+
return per_token_logps_target.sum(-1), per_token_logps_reject.sum(-1)
|
192 |
+
|
193 |
+
def make_reject_y(y_o, y_lens):
|
194 |
+
def repeat_P(y):
|
195 |
+
range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
|
196 |
+
pre = y[:range_idx[0]]
|
197 |
+
shf = y[range_idx[1]:]
|
198 |
+
range_text = y[range_idx[0]:range_idx[1]]
|
199 |
+
new_y = torch.cat([pre, range_text, range_text, shf])
|
200 |
+
return new_y
|
201 |
+
def lost_P(y):
|
202 |
+
range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
|
203 |
+
pre = y[:range_idx[0]]
|
204 |
+
shf = y[range_idx[1]:]
|
205 |
+
range_text = y[range_idx[0]:range_idx[1]]
|
206 |
+
new_y = torch.cat([pre, shf])
|
207 |
+
return new_y
|
208 |
+
bs = len(y_lens)
|
209 |
+
reject_y = []
|
210 |
+
reject_y_lens = []
|
211 |
+
for b in range(bs):
|
212 |
+
process_item_idx = torch.randint(0, 1, size=(1, ))[0]
|
213 |
+
if process_item_idx == 0:
|
214 |
+
new_y = repeat_P(y_o[b])
|
215 |
+
reject_y.append(new_y)
|
216 |
+
reject_y_lens.append(len(new_y))
|
217 |
+
elif process_item_idx==1:
|
218 |
+
new_y = lost_P(y_o[b])
|
219 |
+
reject_y.append(new_y)
|
220 |
+
reject_y_lens.append(len(new_y))
|
221 |
+
max_length = max(reject_y_lens)
|
222 |
+
for b in range(bs):
|
223 |
+
pad_length = max_length - reject_y_lens[b]
|
224 |
+
reject_y[b] = torch.cat([reject_y[b], torch.zeros(pad_length, dtype=y_o.dtype, device=y_o.device)], dim=0)
|
225 |
+
|
226 |
+
reject_y = torch.stack(reject_y, dim = 0)
|
227 |
+
reject_y_lens = torch.tensor(reject_y_lens, device=y_lens.device)
|
228 |
+
|
229 |
+
return reject_y, reject_y_lens
|
AR/modules/__pycache__/__init__.cpython-310.pyc
CHANGED
Binary files a/AR/modules/__pycache__/__init__.cpython-310.pyc and b/AR/modules/__pycache__/__init__.cpython-310.pyc differ
|
|
AR/modules/__pycache__/activation.cpython-310.pyc
CHANGED
Binary files a/AR/modules/__pycache__/activation.cpython-310.pyc and b/AR/modules/__pycache__/activation.cpython-310.pyc differ
|
|
AR/modules/__pycache__/embedding.cpython-310.pyc
CHANGED
Binary files a/AR/modules/__pycache__/embedding.cpython-310.pyc and b/AR/modules/__pycache__/embedding.cpython-310.pyc differ
|
|
AR/modules/__pycache__/lr_schedulers.cpython-310.pyc
CHANGED
Binary files a/AR/modules/__pycache__/lr_schedulers.cpython-310.pyc and b/AR/modules/__pycache__/lr_schedulers.cpython-310.pyc differ
|
|
AR/modules/__pycache__/optim.cpython-310.pyc
CHANGED
Binary files a/AR/modules/__pycache__/optim.cpython-310.pyc and b/AR/modules/__pycache__/optim.cpython-310.pyc differ
|
|
AR/modules/__pycache__/patched_mha_with_cache.cpython-310.pyc
CHANGED
Binary files a/AR/modules/__pycache__/patched_mha_with_cache.cpython-310.pyc and b/AR/modules/__pycache__/patched_mha_with_cache.cpython-310.pyc differ
|
|
AR/modules/__pycache__/scaling.cpython-310.pyc
CHANGED
Binary files a/AR/modules/__pycache__/scaling.cpython-310.pyc and b/AR/modules/__pycache__/scaling.cpython-310.pyc differ
|
|
AR/modules/__pycache__/transformer.cpython-310.pyc
CHANGED
Binary files a/AR/modules/__pycache__/transformer.cpython-310.pyc and b/AR/modules/__pycache__/transformer.cpython-310.pyc differ
|
|
AR/modules/lr_schedulers.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
-
# modified from https://github.com/
|
|
|
2 |
import math
|
3 |
|
4 |
import torch
|
|
|
1 |
+
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/modules/lr_schedulers.py
|
2 |
+
# reference: https://github.com/lifeiteng/vall-e
|
3 |
import math
|
4 |
|
5 |
import torch
|
AR/modules/patched_mha_with_cache.py
CHANGED
@@ -5,8 +5,8 @@ from torch.nn.functional import (
|
|
5 |
_none_or_dtype,
|
6 |
_in_projection_packed,
|
7 |
)
|
8 |
-
|
9 |
-
|
10 |
# Tensor = torch.Tensor
|
11 |
# from typing import Callable, List, Optional, Tuple, Union
|
12 |
|
@@ -448,9 +448,11 @@ def multi_head_attention_forward_patched(
|
|
448 |
k = k.view(bsz, num_heads, src_len, head_dim)
|
449 |
v = v.view(bsz, num_heads, src_len, head_dim)
|
450 |
|
|
|
451 |
attn_output = scaled_dot_product_attention(
|
452 |
q, k, v, attn_mask, dropout_p, is_causal
|
453 |
)
|
|
|
454 |
attn_output = (
|
455 |
attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
|
456 |
)
|
|
|
5 |
_none_or_dtype,
|
6 |
_in_projection_packed,
|
7 |
)
|
8 |
+
from torch.nn import functional as F
|
9 |
+
import torch
|
10 |
# Tensor = torch.Tensor
|
11 |
# from typing import Callable, List, Optional, Tuple, Union
|
12 |
|
|
|
448 |
k = k.view(bsz, num_heads, src_len, head_dim)
|
449 |
v = v.view(bsz, num_heads, src_len, head_dim)
|
450 |
|
451 |
+
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
|
452 |
attn_output = scaled_dot_product_attention(
|
453 |
q, k, v, attn_mask, dropout_p, is_causal
|
454 |
)
|
455 |
+
|
456 |
attn_output = (
|
457 |
attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
|
458 |
)
|
AR/modules/scaling.py
CHANGED
@@ -13,7 +13,7 @@
|
|
13 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
# See the License for the specific language governing permissions and
|
15 |
# limitations under the License.
|
16 |
-
|
17 |
import math
|
18 |
import random
|
19 |
from typing import Optional
|
|
|
13 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
# See the License for the specific language governing permissions and
|
15 |
# limitations under the License.
|
16 |
+
import logging
|
17 |
import math
|
18 |
import random
|
19 |
from typing import Optional
|
AR/text_processing/phonemizer.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
-
# modified from https://github.com/
|
|
|
2 |
import itertools
|
3 |
import re
|
4 |
from typing import Dict
|
|
|
1 |
+
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/text_processing/phonemizer.py
|
2 |
+
# reference: https://github.com/lifeiteng/vall-e
|
3 |
import itertools
|
4 |
import re
|
5 |
from typing import Dict
|
AR/text_processing/symbols.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
-
# modified from https://github.com/
|
|
|
2 |
PAD = "_"
|
3 |
PUNCTUATION = ';:,.!?¡¿—…"«»“” '
|
4 |
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
|
|
1 |
+
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/text_processing/symbols.py
|
2 |
+
# reference: https://github.com/lifeiteng/vall-e
|
3 |
PAD = "_"
|
4 |
PUNCTUATION = ';:,.!?¡¿—…"«»“” '
|
5 |
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
MODELS/21/1.mp3
DELETED
Binary file (30.9 kB)
|
|
MODELS/21/11.mp3
DELETED
Binary file (28 kB)
|
|
MODELS/21/191.mp3
DELETED
Binary file (29.5 kB)
|
|
MODELS/21/21.ckpt
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:c4b29bb398a9dbed95c50489a2633f90a01c0c4ae1e4432f5d37d388401f9887
|
3 |
-
size 155077753
|
|
|
|
|
|
|
|
MODELS/21/21.pth
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:bfb359648e858765e9c1e3f7d51869aec9f607d18efd90d059cb83f1a7988141
|
3 |
-
size 84927748
|
|
|
|
|
|
|
|
MODELS/21/s1.mp3
DELETED
Binary file (29 kB)
|
|
MODELS/21/s2.mp3
DELETED
Binary file (29 kB)
|
|
MODELS/21/s3.mp3
DELETED
Binary file (28.5 kB)
|
|
MODELS/22/22.ckpt
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:c3632e3d1876f7a8e86850f346338c5e2390d09f382891277acf77a4e1a65a25
|
3 |
-
size 155083315
|
|
|
|
|
|
|
|
MODELS/22/22.pth
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:3dfe7fe2765b179db75d8e12bd2b32e1f8d624dcee9a3fecdecbc94904757c29
|
3 |
-
size 84927982
|
|
|
|
|
|
|
|
MODELS/22/passion.mp3
DELETED
Binary file (131 kB)
|
|
MODELS/22/s1.mp3
DELETED
Binary file (26.8 kB)
|
|
MODELS/22/s2.mp3
DELETED
Binary file (33.1 kB)
|
|
MODELS/22/s3.mp3
DELETED
Binary file (30.2 kB)
|
|
MODELS/22/slow_calm.mp3
DELETED
Binary file (79.2 kB)
|
|
MODELS/22/speed.mp3
DELETED
Binary file (122 kB)
|
|
MODELS/31/1.mp3
DELETED
Binary file (111 kB)
|
|
MODELS/31/148.mp3
DELETED
Binary file (86.8 kB)
|
|
MODELS/31/31.ckpt
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:532d92b5b2a1550ed1151aa2d0a801a2fc390fc7b87a7d0278ca7af4cad50c7f
|
3 |
-
size 155084485
|
|
|
|
|
|
|
|
MODELS/31/31.pth
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:e3d128cd00c3853ebe375dd5aeccd979c55a7e8d036cc41843507e2191ccd6d3
|
3 |
-
size 84929396
|
|
|
|
|
|
|
|
MODELS/31/96.mp3
DELETED
Binary file (83.4 kB)
|
|
MODELS/31/s1.mp3
DELETED
Binary file (32.2 kB)
|
|
MODELS/31/s2.mp3
DELETED
Binary file (43 kB)
|
|
MODELS/31/s3.mp3
DELETED
Binary file (39.1 kB)
|
|