Spaces:
Runtime error
Runtime error
Upload encoder/train.py with huggingface_hub
Browse files- encoder/train.py +123 -0
encoder/train.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from encoder.visualizations import Visualizations
|
2 |
+
from encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
|
3 |
+
from encoder.params_model import *
|
4 |
+
from encoder.model import SpeakerEncoder
|
5 |
+
from utils.profiler import Profiler
|
6 |
+
from pathlib import Path
|
7 |
+
import torch
|
8 |
+
|
9 |
+
def sync(device: torch.device):
|
10 |
+
# For correct profiling (cuda operations are async)
|
11 |
+
if device.type == "cuda":
|
12 |
+
torch.cuda.synchronize(device)
|
13 |
+
|
14 |
+
|
15 |
+
def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
|
16 |
+
backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
|
17 |
+
no_visdom: bool):
|
18 |
+
# Create a dataset and a dataloader
|
19 |
+
dataset = SpeakerVerificationDataset(clean_data_root)
|
20 |
+
loader = SpeakerVerificationDataLoader(
|
21 |
+
dataset,
|
22 |
+
speakers_per_batch,
|
23 |
+
utterances_per_speaker,
|
24 |
+
num_workers=8,
|
25 |
+
)
|
26 |
+
|
27 |
+
# Setup the device on which to run the forward pass and the loss. These can be different,
|
28 |
+
# because the forward pass is faster on the GPU whereas the loss is often (depending on your
|
29 |
+
# hyperparameters) faster on the CPU.
|
30 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
31 |
+
# FIXME: currently, the gradient is None if loss_device is cuda
|
32 |
+
loss_device = torch.device("cpu")
|
33 |
+
|
34 |
+
# Create the model and the optimizer
|
35 |
+
model = SpeakerEncoder(device, loss_device)
|
36 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
|
37 |
+
init_step = 1
|
38 |
+
|
39 |
+
# Configure file path for the model
|
40 |
+
state_fpath = models_dir.joinpath(run_id + ".pt")
|
41 |
+
backup_dir = models_dir.joinpath(run_id + "_backups")
|
42 |
+
|
43 |
+
# Load any existing model
|
44 |
+
if not force_restart:
|
45 |
+
if state_fpath.exists():
|
46 |
+
print("Found existing model \"%s\", loading it and resuming training." % run_id)
|
47 |
+
checkpoint = torch.load(state_fpath)
|
48 |
+
init_step = checkpoint["step"]
|
49 |
+
model.load_state_dict(checkpoint["model_state"])
|
50 |
+
optimizer.load_state_dict(checkpoint["optimizer_state"])
|
51 |
+
optimizer.param_groups[0]["lr"] = learning_rate_init
|
52 |
+
else:
|
53 |
+
print("No model \"%s\" found, starting training from scratch." % run_id)
|
54 |
+
else:
|
55 |
+
print("Starting the training from scratch.")
|
56 |
+
model.train()
|
57 |
+
|
58 |
+
# Initialize the visualization environment
|
59 |
+
vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
|
60 |
+
vis.log_dataset(dataset)
|
61 |
+
vis.log_params()
|
62 |
+
device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
|
63 |
+
vis.log_implementation({"Device": device_name})
|
64 |
+
|
65 |
+
# Training loop
|
66 |
+
profiler = Profiler(summarize_every=10, disabled=False)
|
67 |
+
for step, speaker_batch in enumerate(loader, init_step):
|
68 |
+
profiler.tick("Blocking, waiting for batch (threaded)")
|
69 |
+
|
70 |
+
# Forward pass
|
71 |
+
inputs = torch.from_numpy(speaker_batch.data).to(device)
|
72 |
+
sync(device)
|
73 |
+
profiler.tick("Data to %s" % device)
|
74 |
+
embeds = model(inputs)
|
75 |
+
sync(device)
|
76 |
+
profiler.tick("Forward pass")
|
77 |
+
embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
|
78 |
+
loss, eer = model.loss(embeds_loss)
|
79 |
+
sync(loss_device)
|
80 |
+
profiler.tick("Loss")
|
81 |
+
|
82 |
+
# Backward pass
|
83 |
+
model.zero_grad()
|
84 |
+
loss.backward()
|
85 |
+
profiler.tick("Backward pass")
|
86 |
+
model.do_gradient_ops()
|
87 |
+
optimizer.step()
|
88 |
+
profiler.tick("Parameter update")
|
89 |
+
|
90 |
+
# Update visualizations
|
91 |
+
# learning_rate = optimizer.param_groups[0]["lr"]
|
92 |
+
vis.update(loss.item(), eer, step)
|
93 |
+
|
94 |
+
# Draw projections and save them to the backup folder
|
95 |
+
if umap_every != 0 and step % umap_every == 0:
|
96 |
+
print("Drawing and saving projections (step %d)" % step)
|
97 |
+
backup_dir.mkdir(exist_ok=True)
|
98 |
+
projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step))
|
99 |
+
embeds = embeds.detach().cpu().numpy()
|
100 |
+
vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
|
101 |
+
vis.save()
|
102 |
+
|
103 |
+
# Overwrite the latest version of the model
|
104 |
+
if save_every != 0 and step % save_every == 0:
|
105 |
+
print("Saving the model (step %d)" % step)
|
106 |
+
torch.save({
|
107 |
+
"step": step + 1,
|
108 |
+
"model_state": model.state_dict(),
|
109 |
+
"optimizer_state": optimizer.state_dict(),
|
110 |
+
}, state_fpath)
|
111 |
+
|
112 |
+
# Make a backup
|
113 |
+
if backup_every != 0 and step % backup_every == 0:
|
114 |
+
print("Making a backup (step %d)" % step)
|
115 |
+
backup_dir.mkdir(exist_ok=True)
|
116 |
+
backup_fpath = backup_dir.joinpath("%s_bak_%06d.pt" % (run_id, step))
|
117 |
+
torch.save({
|
118 |
+
"step": step + 1,
|
119 |
+
"model_state": model.state_dict(),
|
120 |
+
"optimizer_state": optimizer.state_dict(),
|
121 |
+
}, backup_fpath)
|
122 |
+
|
123 |
+
profiler.tick("Extras (visualizations, saving)")
|