File size: 3,121 Bytes
01e655b
 
 
 
 
 
 
 
 
 
84cfd61
 
01e655b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29536f1
01e655b
 
29536f1
49bce5c
 
01e655b
29536f1
01e655b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84cfd61
 
01e655b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import numpy as np
import torch

from modules.speaker import Speaker
from modules.utils.SeedContext import SeedContext

from modules import models, config

import logging

from modules import devices

logger = logging.getLogger(__name__)


@torch.inference_mode()
def generate_audio(
    text: str,
    temperature: float = 0.3,
    top_P: float = 0.7,
    top_K: float = 20,
    spk: int | Speaker = -1,
    infer_seed: int = -1,
    use_decoder: bool = True,
    prompt1: str = "",
    prompt2: str = "",
    prefix: str = "",
):
    (sample_rate, wav) = generate_audio_batch(
        [text],
        temperature=temperature,
        top_P=top_P,
        top_K=top_K,
        spk=spk,
        infer_seed=infer_seed,
        use_decoder=use_decoder,
        prompt1=prompt1,
        prompt2=prompt2,
        prefix=prefix,
    )[0]

    return (sample_rate, wav)


@torch.inference_mode()
def generate_audio_batch(
    texts: list[str],
    temperature: float = 0.3,
    top_P: float = 0.7,
    top_K: float = 20,
    spk: int | Speaker = -1,
    infer_seed: int = -1,
    use_decoder: bool = True,
    prompt1: str = "",
    prompt2: str = "",
    prefix: str = "",
):
    chat_tts = models.load_chat_tts()
    params_infer_code = {
        "spk_emb": None,
        "temperature": temperature,
        "top_P": top_P,
        "top_K": top_K,
        "prompt1": prompt1 or "",
        "prompt2": prompt2 or "",
        "prefix": prefix or "",
        "repetition_penalty": 1.0,
        "disable_tqdm": config.disable_tqdm,
    }

    if isinstance(spk, int):
        with SeedContext(spk):
            params_infer_code["spk_emb"] = chat_tts.sample_random_speaker()
        logger.info(("spk", spk))
    elif isinstance(spk, Speaker):
        params_infer_code["spk_emb"] = spk.emb
        logger.info(("spk", spk.name))
    else:
        raise ValueError("spk must be int or Speaker")

    logger.info(
        {
            "text": texts,
            "infer_seed": infer_seed,
            "temperature": temperature,
            "top_P": top_P,
            "top_K": top_K,
            "prompt1": prompt1 or "",
            "prompt2": prompt2 or "",
            "prefix": prefix or "",
        }
    )

    with SeedContext(infer_seed):
        wavs = chat_tts.generate_audio(
            texts, params_infer_code, use_decoder=use_decoder
        )

    sample_rate = 24000

    devices.torch_gc()

    return [(sample_rate, np.array(wav).flatten().astype(np.float32)) for wav in wavs]


if __name__ == "__main__":
    import soundfile as sf

    # 测试batch生成
    inputs = ["你好[lbreak]", "再见[lbreak]", "长度不同的文本片段[lbreak]"]
    outputs = generate_audio_batch(inputs, spk=5, infer_seed=42)

    for i, (sample_rate, wav) in enumerate(outputs):
        print(i, sample_rate, wav.shape)

        sf.write(f"batch_{i}.wav", wav, sample_rate, format="wav")

    # 单独生成
    for i, text in enumerate(inputs):
        sample_rate, wav = generate_audio(text, spk=5, infer_seed=42)
        print(i, sample_rate, wav.shape)

        sf.write(f"one_{i}.wav", wav, sample_rate, format="wav")