Hugo Flores Garcia committed on
Commit
ac059f4
1 Parent(s): bcc3305
conf/{interface.yml → interface-jazzpop-exp.yml} RENAMED
@@ -2,4 +2,8 @@ Interface.coarse_ckpt: /runs/jazzpop-coarse-1m-steps.pth
2
  Interface.coarse2fine_ckpt: /runs/jazzpop-c2f.pth
3
  Interface.codec_ckpt: /runs/codec-ckpt/codec.pth
4
  Interface.coarse_chunk_size_s: 5
5
- Interface.coarse2fine_chunk_size_s: 3
 
 
 
 
 
2
  Interface.coarse2fine_ckpt: /runs/jazzpop-c2f.pth
3
  Interface.codec_ckpt: /runs/codec-ckpt/codec.pth
4
  Interface.coarse_chunk_size_s: 5
5
+ Interface.coarse2fine_chunk_size_s: 3
6
+
7
+ AudioLoader.sources:
8
+ - /data/spotdl/audio/val
9
+ - /data/spotdl/audio/test
demo.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from pathlib import Path
from typing import Tuple
import yaml

import numpy as np
import audiotools as at
import argbind

import gradio as gr
from vampnet.interface import Interface

# Experiment configuration that supplies the argbind-bound constructor
# arguments (checkpoints, chunk sizes, audio sources) below.
conf = yaml.safe_load(Path("conf/interface-jazzpop-exp.yml").read_text())

# argbind wraps the classes so their __init__ kwargs come from `conf`
# while inside the scope.
Interface = argbind.bind(Interface)
AudioLoader = argbind.bind(at.data.datasets.AudioLoader)
with argbind.scope(conf):
    interface = Interface()
    loader = AudioLoader()

# Excerpt dataset used by the "load random audio" button: fixed-duration
# clips matching the codec's sample rate and the coarse model's chunk length.
dataset = at.data.datasets.AudioDataset(
    loader,
    sample_rate=interface.codec.sample_rate,
    duration=interface.coarse.chunk_size_s,
    n_examples=5000,
    without_replacement=True,
)
def load_audio(file):
    """Load a user-uploaded file for the demo's input-audio widget.

    Parameters
    ----------
    file :
        Tempfile-like object provided by gradio's ``gr.File`` widget; only
        its ``.name`` (the on-disk path) is used.

    Returns
    -------
    tuple
        ``(sample_rate, samples)`` in the form ``gr.Audio`` expects, with
        ``samples`` shaped ``(n_samples, n_channels)``.
    """
    # NOTE: a leftover debug `print(file)` was removed here.
    filepath = file.name
    # Trim to a salient excerpt no longer than the coarse model's chunk size.
    sig = at.AudioSignal.salient_excerpt(
        filepath,
        duration=interface.coarse.chunk_size_s
    )
    sig = interface.preprocess(sig)

    # gr.Audio wants (sr, ndarray) with time on the first axis, hence the .T.
    audio = sig.samples.numpy()[0]
    sr = sig.sample_rate
    return sr, audio.T
def load_random_audio():
    """Pick a random dataset excerpt and format it for the gr.Audio widget.

    Returns a ``(sample_rate, samples)`` tuple with samples shaped
    ``(n_samples, n_channels)``.
    """
    idx = np.random.randint(0, len(dataset))
    signal = interface.preprocess(dataset[idx]["signal"])

    # time-major layout for gradio
    samples = signal.samples.numpy()[0]
    return signal.sample_rate, samples.T
def mask_audio(
    prefix_s, suffix_s, rand_mask_intensity,
    mask_periodic_amt, beat_unmask_dur,
    mask_dwn_chk, dwn_factor,
    mask_up_chk, up_factor
):
    """Callback stub for the "compute mask" button.

    Not implemented yet — accepts all mask-setting widget values and
    returns ``None`` (equivalent to the original ``pass`` body).
    """
    return None
def vamp(
    input_audio, prefix_s, suffix_s, rand_mask_intensity,
    mask_periodic_amt, beat_unmask_dur,
    mask_dwn_chk, dwn_factor,
    mask_up_chk, up_factor
):
    """Callback stub for the "vamp" button.

    Generation is not wired up yet; for now this only echoes the input
    audio value to stdout and returns ``None``.
    """
    print(input_audio)
# Gradio UI: one row of three columns (input audio / mask settings /
# beat-unmask settings), a second row for output, then the event wiring.
with gr.Blocks() as demo:

    gr.Markdown('# Vampnet')

    with gr.Row():
        # input audio
        with gr.Column():
            gr.Markdown("## Input Audio")

            manual_audio_upload = gr.File(
                label=f"upload some audio (will be randomly trimmed to max of {interface.coarse.chunk_size_s:.2f}s)",
                file_types=["audio"]
            )
            load_random_audio_button = gr.Button("or load random audio")

            input_audio = gr.Audio(
                label="input audio",
                interactive=False,
            )
            input_audio_viz = gr.HTML(
                label="input audio",
            )

            # connect widgets
            load_random_audio_button.click(
                fn=load_random_audio,
                inputs=[],
                outputs=[ input_audio]
            )

            manual_audio_upload.change(
                fn=load_audio,
                inputs=[manual_audio_upload],
                outputs=[ input_audio]
            )

        # mask settings
        with gr.Column():
            gr.Markdown("## Mask Settings")
            prefix_s = gr.Slider(
                label="prefix length (seconds)",
                minimum=0.0,
                maximum=10.0,
                value=0.0
            )
            suffix_s = gr.Slider(
                label="suffix length (seconds)",
                minimum=0.0,
                maximum=10.0,
                value=0.0
            )

            rand_mask_intensity = gr.Slider(
                label="random mask intensity (lower means more freedom)",
                minimum=0.0,
                maximum=1.0,
                value=1.0
            )

            mask_periodic_amt = gr.Slider(
                label="periodic unmasking factor (higher means more freedom)",
                minimum=0,
                maximum=32,
                step=1,
                value=2,
            )
            compute_mask_button = gr.Button("compute mask")
            mask_output = gr.Audio(
                label="masked audio",
                interactive=False,
                visible=False
            )
            mask_output_viz = gr.Video(
                label="masked audio",
                interactive=False
            )

        with gr.Column():
            gr.Markdown("## Beat Unmasking")
            with gr.Accordion(label="beat unmask"):
                beat_unmask_dur = gr.Slider(
                    label="duration",
                    minimum=0.0,
                    maximum=3.0,
                    value=0.1
                )
            with gr.Accordion("downbeat settings"):
                mask_dwn_chk = gr.Checkbox(
                    label="unmask downbeats",
                    value=True
                )
                dwn_factor = gr.Slider(
                    label="downbeat downsample factor (unmask every Nth downbeat)",
                    value=1,
                    minimum=1,
                    maximum=16,
                    step=1
                )
            with gr.Accordion("upbeat settings"):
                mask_up_chk = gr.Checkbox(
                    label="unmask upbeats",
                    value=True
                )
                up_factor = gr.Slider(
                    label="upbeat downsample factor (unmask every Nth upbeat)",
                    value=1,
                    minimum=1,
                    maximum=16,
                    step=1
                )

    # process and output
    with gr.Row():
        with gr.Column():
            vamp_button = gr.Button("vamp")

            output_audio = gr.Audio(
                label="output audio",
                interactive=False,
                visible=False
            )
            output_audio_viz = gr.Video(
                label="output audio",
                interactive=False
            )

    # connect widgets
    compute_mask_button.click(
        fn=mask_audio,
        inputs=[
            prefix_s, suffix_s, rand_mask_intensity,
            mask_periodic_amt, beat_unmask_dur,
            mask_dwn_chk, dwn_factor,
            mask_up_chk, up_factor
        ],
        outputs=[mask_output, mask_output_viz]
    )

    # connect widgets
    vamp_button.click(
        fn=vamp,
        inputs=[input_audio,
            prefix_s, suffix_s, rand_mask_intensity,
            mask_periodic_amt, beat_unmask_dur,
            mask_dwn_chk, dwn_factor,
            mask_up_chk, up_factor
        ],
        outputs=[output_audio, output_audio_viz]
    )

demo.launch(share=True)
scripts/exp/eval.py CHANGED
@@ -57,30 +57,31 @@ def eval(
57
  cond_files = cond_files[:num_files]
58
  assert len(list(baseline_files)) == len(list(cond_files)), f"number of files in {baseline_dir} and {cond_dir} do not match. {len(list(baseline_files))} vs {len(list(cond_files))}"
59
 
60
- pbar = tqdm(zip(baseline_files, cond_files), total=len(baseline_files))
61
- for baseline_file, cond_file in pbar:
62
  # make sure the files match (same name)
63
  assert baseline_file.stem == cond_file.stem, f"baseline file {baseline_file} and cond file {cond_file} do not match"
64
- pbar.set_description(baseline_file.stem)
65
 
66
  # load the files
67
  baseline_sig = AudioSignal(str(baseline_file))
68
  cond_sig = AudioSignal(str(cond_file))
69
 
70
  # compute the metrics
71
- try:
72
- vsq = visqol(baseline_sig, cond_sig)
73
- except:
74
- vsq = 0.0
75
- metrics.append({
76
  "sisdr": -sisdr_loss(baseline_sig, cond_sig).item(),
77
  "stft": stft_loss(baseline_sig, cond_sig).item(),
78
  "mel": mel_loss(baseline_sig, cond_sig).item(),
79
  "frechet": frechet_score,
80
- "visqol": vsq,
81
  "condition": condition,
82
  "file": baseline_file.stem,
83
- })
 
 
 
84
 
85
  metric_keys = [k for k in metrics[0].keys() if k not in ("condition", "file")]
86
 
 
57
  cond_files = cond_files[:num_files]
58
  assert len(list(baseline_files)) == len(list(cond_files)), f"number of files in {baseline_dir} and {cond_dir} do not match. {len(list(baseline_files))} vs {len(list(cond_files))}"
59
 
60
+ def process(baseline_file, cond_file):
 
61
  # make sure the files match (same name)
62
  assert baseline_file.stem == cond_file.stem, f"baseline file {baseline_file} and cond file {cond_file} do not match"
 
63
 
64
  # load the files
65
  baseline_sig = AudioSignal(str(baseline_file))
66
  cond_sig = AudioSignal(str(cond_file))
67
 
68
  # compute the metrics
69
+ # try:
70
+ # vsq = visqol(baseline_sig, cond_sig)
71
+ # except:
72
+ # vsq = 0.0
73
+ return {
74
  "sisdr": -sisdr_loss(baseline_sig, cond_sig).item(),
75
  "stft": stft_loss(baseline_sig, cond_sig).item(),
76
  "mel": mel_loss(baseline_sig, cond_sig).item(),
77
  "frechet": frechet_score,
78
+ # "visqol": vsq,
79
  "condition": condition,
80
  "file": baseline_file.stem,
81
+ }
82
+
83
+ print(f"processing {len(baseline_files)} files in {baseline_dir} and {cond_dir}")
84
+ metrics.extend(tqdm(map(process, baseline_files, cond_files), total=len(baseline_files)))
85
 
86
  metric_keys = [k for k in metrics[0].keys() if k not in ("condition", "file")]
87
 
scripts/utils/vamp_folder.py CHANGED
@@ -1,8 +1,10 @@
1
  from pathlib import Path
 
 
2
 
3
  import argbind
4
  from tqdm import tqdm
5
- import torch
6
 
7
  from vampnet.interface import Interface
8
  import audiotools as at
@@ -12,9 +14,9 @@ Interface = argbind.bind(Interface)
12
  # condition wrapper for printing
13
  def condition(cond):
14
  def wrapper(sig, interface):
15
- print(f"Condition: {cond.__name__}")
16
  sig = cond(sig, interface)
17
- print(f"Condition: {cond.__name__} (done)\n")
18
  return sig
19
  return wrapper
20
 
@@ -49,48 +51,27 @@ def coarse2fine_argmax(sig, interface):
49
 
50
  @condition
51
  def one_codebook(sig, interface):
52
- z = interface.encode(sig)
53
-
54
- nb, _, nt = z.shape
55
- nc = interface.coarse.n_codebooks
56
- mask = torch.zeros(nb, nc, nt).to(interface.device)
57
- mask[:, 1:, :] = 1
58
-
59
  zv = interface.coarse_vamp_v2(
60
- sig, ext_mask=mask,
61
  )
62
  zv = interface.coarse_to_fine(zv)
63
 
64
  return interface.to_signal(zv)
65
 
66
- @condition
67
- def four_codebooks_downsampled_4x(sig, interface):
68
- zv = interface.coarse_vamp_v2(
69
- sig, downsample_factor=4
70
- )
71
- zv = interface.coarse_to_fine(zv)
72
- return interface.to_signal(zv)
73
-
74
  @condition
75
  def two_codebooks_downsampled_4x(sig, interface):
76
- z = interface.encode(sig)
77
-
78
- nb, _, nt = z.shape
79
- nc = interface.coarse.n_codebooks
80
- mask = torch.zeros(nb, nc, nt).to(interface.device)
81
- mask[:, 2:, :] = 1
82
-
83
  zv = interface.coarse_vamp_v2(
84
- sig, ext_mask=mask, downsample_factor=4
 
85
  )
86
  zv = interface.coarse_to_fine(zv)
87
 
88
  return interface.to_signal(zv)
89
 
90
- @condition
91
- def four_codebooks_downsampled_8x(sig, interface):
92
  zv = interface.coarse_vamp_v2(
93
- sig, downsample_factor=8
94
  )
95
  zv = interface.coarse_to_fine(zv)
96
  return interface.to_signal(zv)
@@ -101,9 +82,13 @@ COARSE_SAMPLE_CONDS ={
101
  "reconstructed": reconstructed,
102
  "coarse2fine": coarse2fine,
103
  "one_codebook": one_codebook,
104
- "four_codebooks_downsampled_4x": four_codebooks_downsampled_4x,
105
  "two_codebooks_downsampled_4x": two_codebooks_downsampled_4x,
106
- "four_codebooks_downsampled_8x": four_codebooks_downsampled_8x,
 
 
 
 
 
107
  }
108
 
109
  C2F_SAMPLE_CONDS = {
@@ -131,7 +116,7 @@ def main(
131
 
132
  from audiotools.data.datasets import AudioLoader, AudioDataset
133
 
134
- loader = AudioLoader(sources=sources)
135
  dataset = AudioDataset(loader,
136
  sample_rate=interface.codec.sample_rate,
137
  duration=interface.coarse.chunk_size_s,
@@ -141,7 +126,18 @@ def main(
141
 
142
  SAMPLE_CONDS = COARSE_SAMPLE_CONDS if exp_type == "coarse" else C2F_SAMPLE_CONDS
143
 
144
- for i in tqdm(range(max_excerpts)):
 
 
 
 
 
 
 
 
 
 
 
145
  sig = dataset[i]["signal"]
146
 
147
  results = {
 
1
  from pathlib import Path
2
+ import random
3
+ from typing import List
4
 
5
  import argbind
6
  from tqdm import tqdm
7
+ import argbind
8
 
9
  from vampnet.interface import Interface
10
  import audiotools as at
 
14
  # condition wrapper for printing
15
  def condition(cond):
16
  def wrapper(sig, interface):
17
+ # print(f"Condition: {cond.__name__}")
18
  sig = cond(sig, interface)
19
+ # print(f"Condition: {cond.__name__} (done)\n")
20
  return sig
21
  return wrapper
22
 
 
51
 
52
  @condition
53
  def one_codebook(sig, interface):
 
 
 
 
 
 
 
54
  zv = interface.coarse_vamp_v2(
55
+ sig, n_conditioning_codebooks=1
56
  )
57
  zv = interface.coarse_to_fine(zv)
58
 
59
  return interface.to_signal(zv)
60
 
 
 
 
 
 
 
 
 
61
  @condition
62
  def two_codebooks_downsampled_4x(sig, interface):
 
 
 
 
 
 
 
63
  zv = interface.coarse_vamp_v2(
64
+ sig, n_conditioning_codebooks=2,
65
+ downsample_factor=4
66
  )
67
  zv = interface.coarse_to_fine(zv)
68
 
69
  return interface.to_signal(zv)
70
 
71
+
72
+ def four_codebooks_downsampled(sig, interface, x=12):
73
  zv = interface.coarse_vamp_v2(
74
+ sig, downsample_factor=12
75
  )
76
  zv = interface.coarse_to_fine(zv)
77
  return interface.to_signal(zv)
 
82
  "reconstructed": reconstructed,
83
  "coarse2fine": coarse2fine,
84
  "one_codebook": one_codebook,
 
85
  "two_codebooks_downsampled_4x": two_codebooks_downsampled_4x,
86
+ # four codebooks at different downsample factors
87
+ **{
88
+ f"four_codebooks_downsampled_{x}x": lambda sig, interface: four_codebooks_downsampled(sig, interface, x=x)
89
+ for x in [4, 8, 12, 16, 20, 24]
90
+ }
91
+
92
  }
93
 
94
  C2F_SAMPLE_CONDS = {
 
116
 
117
  from audiotools.data.datasets import AudioLoader, AudioDataset
118
 
119
+ loader = AudioLoader(sources=sources, shuffle_state=seed)
120
  dataset = AudioDataset(loader,
121
  sample_rate=interface.codec.sample_rate,
122
  duration=interface.coarse.chunk_size_s,
 
126
 
127
  SAMPLE_CONDS = COARSE_SAMPLE_CONDS if exp_type == "coarse" else C2F_SAMPLE_CONDS
128
 
129
+
130
+ indices = list(range(max_excerpts))
131
+ random.shuffle(indices)
132
+ for i in tqdm(indices):
133
+ # if all our files are already there, skip
134
+ # done = []
135
+ # for name in SAMPLE_CONDS:
136
+ # o_dir = Path(output_dir) / name
137
+ # done.append((o_dir / f"{i}.wav").exists())
138
+ # if all(done):
139
+ # continue
140
+
141
  sig = dataset[i]["signal"]
142
 
143
  results = {
setup.py CHANGED
@@ -26,16 +26,15 @@ setup(
26
  license="MIT",
27
  packages=find_packages(),
28
  install_requires=[
29
- "torch<=1.11.0",
30
  "argbind>=0.3.2",
31
  "pytorch-ignite",
32
  "rich",
33
- "audiotools @ git+https://github.com/descriptinc/lyrebird-audiotools.git@0.6.3",
34
- "lac @ git+https://github.com/descriptinc/lyrebird-audio-codec.git@main",
35
  "tqdm",
36
  "tensorboard",
37
  "google-cloud-logging==2.2.0",
38
- "torchmetrics>=0.7.3",
39
  "einops",
40
  "frechet_audio_distance"
41
  ],
 
26
  license="MIT",
27
  packages=find_packages(),
28
  install_requires=[
29
+ "torch",
30
  "argbind>=0.3.2",
31
  "pytorch-ignite",
32
  "rich",
33
+ "audiotools @ git+https://github.com/descriptinc/lyrebird-audiotools.git@hf/backup-info",
34
+ "lac @ git+https://github.com/descriptinc/lyrebird-audio-codec.git",
35
  "tqdm",
36
  "tensorboard",
37
  "google-cloud-logging==2.2.0",
 
38
  "einops",
39
  "frechet_audio_distance"
40
  ],
vampnet/gradio.py DELETED
@@ -1,4 +0,0 @@
1
-
2
- import gradio as gr
3
-
4
-
 
 
 
 
 
vampnet/interface.py CHANGED
@@ -315,6 +315,7 @@ class Interface:
315
  debug=False,
316
  swap_prefix_suffix=False,
317
  ext_mask=None,
 
318
  verbose=False,
319
  **kwargs
320
  ):
@@ -351,7 +352,8 @@ class Interface:
351
  n_suffix=n_suffix,
352
  downsample_factor=downsample_factor,
353
  mask=cz_mask,
354
- ext_mask=ext_mask
 
355
  )
356
  if debug:
357
  print("tokens to infer")
 
315
  debug=False,
316
  swap_prefix_suffix=False,
317
  ext_mask=None,
318
+ n_conditioning_codebooks=None,
319
  verbose=False,
320
  **kwargs
321
  ):
 
352
  n_suffix=n_suffix,
353
  downsample_factor=downsample_factor,
354
  mask=cz_mask,
355
+ ext_mask=ext_mask,
356
+ n_conditioning_codebooks=n_conditioning_codebooks
357
  )
358
  if debug:
359
  print("tokens to infer")
vampnet/modules/base.py CHANGED
@@ -41,6 +41,7 @@ class VampBase(at.ml.BaseModel):
41
  n_prefix: Optional[torch.Tensor] = None,
42
  n_suffix: Optional[torch.Tensor] = None,
43
  downsample_factor: Optional[int] = None,
 
44
  ) -> Tuple[torch.Tensor, torch.Tensor]:
45
  assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
46
 
@@ -79,7 +80,8 @@ class VampBase(at.ml.BaseModel):
79
  mask = mask.round().long()
80
 
81
  # if we have any conditioning codebooks, set their mask to 0
82
- mask[:, : self.n_conditioning_codebooks, :] = 0
 
83
  else:
84
  assert mask.ndim == 3, "mask must be (batch, n_codebooks, seq)"
85
  assert mask.shape == x.shape, "mask must be same shape as x"
 
41
  n_prefix: Optional[torch.Tensor] = None,
42
  n_suffix: Optional[torch.Tensor] = None,
43
  downsample_factor: Optional[int] = None,
44
+ n_conditioning_codebooks: Optional[int] = None,
45
  ) -> Tuple[torch.Tensor, torch.Tensor]:
46
  assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
47
 
 
80
  mask = mask.round().long()
81
 
82
  # if we have any conditioning codebooks, set their mask to 0
83
+ n_conditioning_codebooks = n_conditioning_codebooks or self.n_conditioning_codebooks
84
+ mask[:, :n_conditioning_codebooks, :] = 0
85
  else:
86
  assert mask.ndim == 3, "mask must be (batch, n_codebooks, seq)"
87
  assert mask.shape == x.shape, "mask must be same shape as x"
vampnet/util.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tqdm
2
+ # import pathos
3
+
def process_map(fn, *iterables, **tqdm_kwargs):
    """
    Equivalent of `list(map(fn, *iterables))`
    driven by `concurrent.futures.ProcessPoolExecutor`.

    Parameters
    ----------
    tqdm_class : optional
        `tqdm` class to use for bars [default: tqdm.auto.tqdm].
    max_workers : int, optional
        Maximum number of workers to spawn; passed to
        `concurrent.futures.ProcessPoolExecutor.__init__`.
        [default: min(32, cpu_count() + 4)].
    chunksize : int, optional
        Size of chunks sent to worker processes; passed to
        `concurrent.futures.ProcessPoolExecutor.map`. [default: 1].
    lock_name : str, optional
        Member of `tqdm_class.get_lock()` to use [default: mp_lock].
    """
    # The previous body was hand-copied from tqdm.contrib.concurrent but
    # dropped the imports of `length_hint`, `TqdmWarning`, and the
    # `_executor_map` helper, so calling it always raised NameError.
    # Delegate to the upstream implementation, which has the same
    # signature and semantics (including the chunksize warning).
    from tqdm.contrib.concurrent import process_map as _process_map
    return _process_map(fn, *iterables, **tqdm_kwargs)
+
40
+ def parallelize(
41
+ fn,
42
+ *iterables,
43
+ parallel: str = "thread_map",
44
+ **kwargs
45
+ ):
46
+ if parallel == "thread_map":
47
+ from tqdm.contrib.concurrent import thread_map
48
+ return thread_map(
49
+ fn,
50
+ *iterables,
51
+ **kwargs
52
+ )
53
+ elif parallel == "process_map":
54
+ from tqdm.contrib.concurrent import process_map
55
+ return process_map(
56
+ fn,
57
+ *iterables,
58
+ **kwargs
59
+ )
60
+ elif parallel == "single":
61
+ return [fn(x) for x in tqdm.tqdm(*iterables)]
62
+ else:
63
+ raise ValueError(f"parallel must be one of 'thread_map', 'process_map', 'single', but got {parallel}")