keithhon committed on
Commit
e9e7667
1 Parent(s): 9e6195f

Upload encoder/train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. encoder/train.py +123 -0
encoder/train.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from encoder.visualizations import Visualizations
2
+ from encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
3
+ from encoder.params_model import *
4
+ from encoder.model import SpeakerEncoder
5
+ from utils.profiler import Profiler
6
+ from pathlib import Path
7
+ import torch
8
+
9
def sync(device: torch.device):
    """Block until pending work on *device* has finished, if it is a CUDA device.

    CUDA operations are asynchronous, so without this barrier the profiler timings taken
    around them would not reflect the real cost of each stage. Other device types need no
    barrier and the call is a no-op for them.

    :param device: the device to synchronize.
    """
    if device.type == "cuda":
        torch.cuda.synchronize(device)
13
+
14
+
15
def _save_checkpoint(fpath: Path, step: int, model, optimizer) -> None:
    """Write a resumable checkpoint (next step, model state, optimizer state) to *fpath*."""
    torch.save({
        # Stored as the step to resume FROM, i.e. one past the step that just completed.
        "step": step + 1,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }, fpath)


def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
          backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
          no_visdom: bool):
    """Train the speaker encoder, resuming from an existing checkpoint when one is available.

    The loop runs for as long as the data loader yields batches. Checkpoints are written to
    ``<models_dir>/<run_id>.pt``; backups and projection images go to
    ``<models_dir>/<run_id>_backups``.

    :param run_id: name of the run; used for the checkpoint/backup file names and passed to
        the visualization environment.
    :param clean_data_root: root directory handed to ``SpeakerVerificationDataset``.
    :param models_dir: directory holding the checkpoint and the backup folder. It is created
        if it does not exist.
    :param umap_every: draw and save projections every this many steps; 0 disables it.
    :param save_every: overwrite the latest checkpoint every this many steps; 0 disables it.
    :param backup_every: write a separate, step-stamped backup every this many steps; 0
        disables it.
    :param vis_every: forwarded to ``Visualizations``.
    :param force_restart: when True, ignore any existing checkpoint and start from scratch.
    :param visdom_server: forwarded to ``Visualizations`` as ``server``.
    :param no_visdom: forwarded to ``Visualizations`` as ``disabled``.
    """
    # Create a dataset and a dataloader
    dataset = SpeakerVerificationDataset(clean_data_root)
    loader = SpeakerVerificationDataLoader(
        dataset,
        speakers_per_batch,
        utterances_per_speaker,
        num_workers=8,
    )

    # Setup the device on which to run the forward pass and the loss. These can be different,
    # because the forward pass is faster on the GPU whereas the loss is often (depending on your
    # hyperparameters) faster on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # FIXME: currently, the gradient is None if loss_device is cuda
    loss_device = torch.device("cpu")

    # Create the model and the optimizer
    model = SpeakerEncoder(device, loss_device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
    init_step = 1

    # Configure file path for the model
    state_fpath = models_dir.joinpath(run_id + ".pt")
    backup_dir = models_dir.joinpath(run_id + "_backups")
    # Make sure the output directory exists so that saving cannot fail after hours of training.
    models_dir.mkdir(parents=True, exist_ok=True)

    # Load any existing model
    if not force_restart:
        if state_fpath.exists():
            print("Found existing model \"%s\", loading it and resuming training." % run_id)
            # map_location lets a checkpoint written on a GPU machine be resumed on a CPU-only
            # one; load_state_dict then copies the tensors onto each parameter's own device.
            checkpoint = torch.load(state_fpath, map_location=device)
            init_step = checkpoint["step"]
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            # The learning rate stored in the checkpoint is overridden by the current setting.
            optimizer.param_groups[0]["lr"] = learning_rate_init
        else:
            print("No model \"%s\" found, starting training from scratch." % run_id)
    else:
        print("Starting the training from scratch.")
    model.train()

    # Initialize the visualization environment
    vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
    vis.log_dataset(dataset)
    vis.log_params()
    device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
    vis.log_implementation({"Device": device_name})

    # Training loop
    profiler = Profiler(summarize_every=10, disabled=False)
    for step, speaker_batch in enumerate(loader, init_step):
        profiler.tick("Blocking, waiting for batch (threaded)")

        # Forward pass
        inputs = torch.from_numpy(speaker_batch.data).to(device)
        sync(device)
        profiler.tick("Data to %s" % device)
        embeds = model(inputs)
        sync(device)
        profiler.tick("Forward pass")
        # Regroup the flat batch by speaker and move it to the device the loss runs on.
        embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
        loss, eer = model.loss(embeds_loss)
        sync(loss_device)
        profiler.tick("Loss")

        # Backward pass
        model.zero_grad()
        loss.backward()
        profiler.tick("Backward pass")
        model.do_gradient_ops()
        optimizer.step()
        profiler.tick("Parameter update")

        # Update visualizations
        vis.update(loss.item(), eer, step)

        # Draw projections and save them to the backup folder
        if umap_every != 0 and step % umap_every == 0:
            print("Drawing and saving projections (step %d)" % step)
            backup_dir.mkdir(parents=True, exist_ok=True)
            projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step))
            embeds_np = embeds.detach().cpu().numpy()
            vis.draw_projections(embeds_np, utterances_per_speaker, step, projection_fpath)
            vis.save()

        # Overwrite the latest version of the model
        if save_every != 0 and step % save_every == 0:
            print("Saving the model (step %d)" % step)
            _save_checkpoint(state_fpath, step, model, optimizer)

        # Make a backup
        if backup_every != 0 and step % backup_every == 0:
            print("Making a backup (step %d)" % step)
            backup_dir.mkdir(parents=True, exist_ok=True)
            backup_fpath = backup_dir.joinpath("%s_bak_%06d.pt" % (run_id, step))
            _save_checkpoint(backup_fpath, step, model, optimizer)

        profiler.tick("Extras (visualizations, saving)")