vshirasuna committed
Commit 2ca1091
1 Parent(s): 82d7dbd
Added gradient checkpointing and fixed bugs
smi-ted/training/trainer.py  CHANGED  (+36 -2)
@@ -2,12 +2,16 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
 from torch.utils.data import DataLoader
 from torch.nn.parallel import DistributedDataParallel as DDP
+from fast_transformers.masking import LengthMask
 
 # Standard library
 from tqdm import tqdm
 import pandas as pd
+import numpy as np
+import random
 import os
 
 
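Two of the new imports feed the checkpointed encoder forward further down: torch.utils.checkpoint drives activation recomputation, and fast_transformers' LengthMask carries per-sequence lengths past the padding. As a minimal, self-contained sketch of how that LengthMask argument is built (token values are illustrative, not from the repository):

# Sketch only (not part of the commit); token ids are made up, padding_idx follows the diff.
import torch
from fast_transformers.masking import LengthMask

padding_idx = 2
idx_masked = torch.tensor([[5, 7, 9, 2, 2],
                           [5, 7, 2, 2, 2]])           # two padded token sequences
mask = (idx_masked != padding_idx)                      # True on real tokens
length_mask = LengthMask(mask.sum(-1), max_len=idx_masked.shape[1])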
@@ -41,6 +45,7 @@ class Trainer:
         self.model = DDP(self.model, device_ids=[self.local_rank])
 
     def _load_checkpoint(self, checkpoint_path):
+        opt_dict = None
         loc = f"cuda:{self.local_rank}"
         ckpt_dict = torch.load(checkpoint_path, map_location=loc)
         if os.path.exists(os.path.join(self.save_checkpoint_path, 'OPTIMIZER_STATES.pt')):
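Initializing opt_dict before the existence check looks like one of the advertised bug fixes: if OPTIMIZER_STATES.pt is missing, later code can still test the variable instead of hitting an UnboundLocalError. A hedged sketch of the guard pattern, using a hypothetical helper name and an illustrative torch.load call:

# Hypothetical helper for illustration; only the guard pattern mirrors the diff.
import os
import torch

def load_optimizer_states(save_checkpoint_path, map_location="cpu"):
    opt_dict = None  # bind the name even when no optimizer state was saved
    opt_path = os.path.join(save_checkpoint_path, 'OPTIMIZER_STATES.pt')
    if os.path.exists(opt_path):
        opt_dict = torch.load(opt_path, map_location=map_location)
    return opt_dict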
@@ -262,6 +267,12 @@ class TrainerEncoderDecoder(Trainer):
         if self.local_rank == 0:
             loss_list.to_csv(os.path.join(self.config.save_checkpoint_path, f'training_loss_epoch{epoch}.csv'), index=False)
 
+    def custom(self, module):
+        def custom_forward(*inputs):
+            inputs = module(inputs[0])
+            return inputs
+        return custom_forward
+
     def _run_batch(self, batch_idx, bucket_idx_masked, bucket_targets, bucket_idx_not_masked):
         self.optimE.zero_grad(set_to_none=True)
         self.optimD.zero_grad(set_to_none=True)
@@ -292,7 +303,13 @@ class TrainerEncoderDecoder(Trainer):
         for param in self.model.module.decoder.parameters():
             param.requires_grad = False
 
-
+        # encoder forward
+        x = self.model.module.encoder.tok_emb(idx_masked)
+        x = self.model.module.encoder.drop(x)
+        x = checkpoint.checkpoint(self.custom(self.model.module.encoder.blocks), x)
+        logits = self.model.module.encoder.lang_model(x)
+
+        # loss function
         logits = logits.view(-1, logits.size(-1))
         targets = targets.view(-1)
         errorE_tmp = self.criterionC(logits, targets) / len(bucket_idx_masked)
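For context, torch.utils.checkpoint.checkpoint does not store the intermediate activations of the callable it wraps; it re-runs it during backward, which is why the encoder blocks above are passed through the small custom closure. A self-contained sketch of the same pattern (the toy module and sizes are invented):

# Sketch only; `blocks` is a toy stand-in for the encoder blocks.
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

def custom(module):
    def custom_forward(*inputs):
        # checkpoint() passes tensors positionally; forward them to the module.
        return module(inputs[0])
    return custom_forward

blocks = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 16))
x = torch.randn(4, 16, requires_grad=True)

# Activations inside `blocks` are recomputed during backward instead of stored,
# trading extra compute for lower peak GPU memory.
y = checkpoint.checkpoint(custom(blocks), x)
y.sum().backward()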
@@ -370,6 +387,12 @@ class TrainerDirectDecoder(Trainer):
         self.criterionC = nn.CrossEntropyLoss(ignore_index=-100)
         self.criterionR = nn.MSELoss()
 
+    def custom(self, module):
+        def custom_forward(*inputs):
+            inputs = module(inputs[0], length_mask=inputs[1])
+            return inputs
+        return custom_forward
+
     def _run_batch(self, bucket_idx_masked, bucket_targets, bucket_idx_not_masked):
         padding_idx = 2
         error = torch.zeros(1).to(self.local_rank)
@@ -385,7 +408,18 @@ class TrainerDirectDecoder(Trainer):
         mask = (idx_masked != padding_idx)
 
         # encoder forward
-
+        x = self.model.module.encoder.tok_emb(idx_masked)
+        x = self.model.module.encoder.drop(x)
+        x = checkpoint.checkpoint(self.custom(self.model.module.encoder.blocks), x, LengthMask(mask.sum(-1), max_len=idx_masked.shape[1]))
+
+        # mean pooling
+        input_masked_expanded = mask.unsqueeze(-1).expand(x.size()).float()
+        sum_embeddings = torch.sum(x*input_masked_expanded, 1)
+        sum_mask = torch.clamp(input_masked_expanded.sum(1), min=1e-9)
+        true_set = sum_embeddings/sum_mask
+        true_cte = x
+        del x
+        torch.cuda.empty_cache()
 
         # add padding
         input_mask_expanded = mask.unsqueeze(-1).expand(true_cte.size()).float()
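The added mean-pooling block averages the encoder output over non-padding positions only (the clamp avoids division by zero for all-padding rows). A standalone sketch of that computation with illustrative shapes:

# Sketch only; batch/sequence/hidden sizes are illustrative.
import torch

x = torch.randn(2, 5, 8)                        # (batch, seq_len, hidden) encoder output
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 0, 0, 0]]).bool()   # True on non-padding tokens

input_masked_expanded = mask.unsqueeze(-1).expand(x.size()).float()
sum_embeddings = torch.sum(x * input_masked_expanded, 1)
sum_mask = torch.clamp(input_masked_expanded.sum(1), min=1e-9)
true_set = sum_embeddings / sum_mask            # per-sequence mean over real tokens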