#!/usr/bin/python3
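"""Batch audio super resolution with AudioSR.

Upsamples one audio file (-i) or every file in a list (-il) to 48 kHz
and writes the results into a timestamped output directory.

Example invocations (the script name is illustrative):

    python3 inference.py -i input.wav -s ./output
    python3 inference.py -il file_list.txt --model_name speech --ddim_steps 100
"""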
import os
import torch
import logging
from audiosr import super_resolution, build_model, save_wave, get_time, read_list
import argparse

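# Setting TOKENIZERS_PARALLELISM explicitly silences the Hugging Face tokenizers fork warning.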
os.environ["TOKENIZERS_PARALLELISM"] = "true"
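# matplotlib logs chattily at DEBUG/INFO; keep only warnings and above.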
matplotlib_logger = logging.getLogger('matplotlib')
matplotlib_logger.setLevel(logging.WARNING)

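# Command-line arguments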
parser = argparse.ArgumentParser()

parser.add_argument(
    "-i",
    "--input_audio_file",
    type=str,
    required=False,
    help="Input audio file for audio super resolution",
)

parser.add_argument(
    "-il",
    "--input_file_list",
    type=str,
    required=False,
    default="",
    help="A file that contains all audio files that need to perform audio super resolution",
)

parser.add_argument(
    "-s",
    "--save_path",
    type=str,
    required=False,
    help="The path to save model output",
    default="./output",
)

parser.add_argument(
    "--model_name",
    type=str,
    required=False,
    help="The checkpoint you gonna use",
    default="basic",
    choices=["basic","speech"]
)

parser.add_argument(
    "-d",
    "--device",
    type=str,
    required=False,
    help="The device for computation. If not specified, the script will automatically choose the device based on your environment.",
    default="auto",
)

parser.add_argument(
    "--ddim_steps",
    type=int,
    required=False,
    default=50,
    help="The sampling step for DDIM",
)

parser.add_argument(
    "-gs",
    "--guidance_scale",
    type=float,
    required=False,
    default=3.5,
    help="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)",
)

parser.add_argument(
    "--seed",
    type=int,
    required=False,
    default=42,
    help="Changing this value (any integer number) will lead to a different generation result.",
)

parser.add_argument(
    "--suffix",
    type=str,
    required=False,
    help="Suffix for the output file",
    default="_AudioSR_Processed_48K",
)

args = parser.parse_args()
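# Trade a little float32 matmul precision for speed (enables TF32 on supported GPUs).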
torch.set_float32_matmul_precision("high")
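# Each run gets a timestamped subdirectory so repeated runs do not overwrite earlier output.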
save_path = os.path.join(args.save_path, get_time())

# --input_file_list defaults to "", so test truthiness rather than comparing with None
assert args.input_file_list or args.input_audio_file, \
    "Please provide either a list of audio files or a single audio file"

input_file = args.input_audio_file
random_seed = args.seed
sample_rate = 48000  # AudioSR outputs 48 kHz audio
latent_t_per_second = 12.8  # latent frames per second expected by the pretrained model
guidance_scale = args.guidance_scale

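# Create the output directory and load the chosen checkpoint onto the requested device.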
os.makedirs(save_path, exist_ok=True)
audiosr = build_model(model_name=args.model_name, device=args.device)

if args.input_file_list:
    print("Performing super resolution on the audio files listed in %s" % args.input_file_list)
    files_todo = read_list(args.input_file_list)
else:
    files_todo = [input_file]
    
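# Run super resolution on each file and save the 48 kHz result.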
for input_file in files_todo:
    name = os.path.splitext(os.path.basename(input_file))[0] + args.suffix
    
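    # DDIM sampling reconstructs the high-frequency content of the input.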
    waveform = super_resolution(
        audiosr,
        input_file,
        seed=random_seed,
        guidance_scale=guidance_scale,
        ddim_steps=args.ddim_steps,
        latent_t_per_second=latent_t_per_second
    )
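    # Persist the upsampled waveform (48 kHz) under the timestamped output directory.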
    save_wave(waveform, inputpath=input_file, savepath=save_path, name=name, samplerate=sample_rate)