import torch

from contextlib import contextmanager
from colbert.utils.utils import NullContextManager


class MixedPrecisionManager():
    """Thin wrapper around torch.cuda.amp that falls back to plain FP32 training when deactivated."""

    def __init__(self, activated):
        self.activated = activated

        if self.activated:
            # The gradient scaler guards against underflow of small FP16 gradients.
            self.scaler = torch.cuda.amp.GradScaler()

    def context(self):
        # Autocast the forward pass when AMP is on; otherwise return a no-op context.
        return torch.cuda.amp.autocast() if self.activated else NullContextManager()

    def backward(self, loss):
        if self.activated:
            # Scale the loss before backprop so FP16 gradients don't underflow.
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

    def step(self, colbert, optimizer, scheduler=None):
        if self.activated:
            # Unscale gradients in place so clipping sees their true magnitudes.
            self.scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(colbert.parameters(), 2.0, error_if_nonfinite=False)

            # scaler.step() skips the optimizer update if any gradient is inf/NaN;
            # scaler.update() then adjusts the scale factor for the next iteration.
            self.scaler.step(optimizer)
            self.scaler.update()
        else:
            torch.nn.utils.clip_grad_norm_(colbert.parameters(), 2.0)
            optimizer.step()

        if scheduler is not None:
            scheduler.step()

        optimizer.zero_grad()
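

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal, hypothetical training-loop fragment showing how this manager is
# typically driven: run the forward pass under .context(), call .backward() on
# the loss, then .step() to unscale, clip, and update. The toy linear model,
# random data, and AdamW optimizer below are stand-ins, not ColBERT's actual
# training setup.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = torch.nn.Linear(8, 1).to(device)           # stand-in for the ColBERT model
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    amp = MixedPrecisionManager(activated=(device == "cuda"))

    for _ in range(3):                                 # a few dummy training steps
        inputs = torch.randn(4, 8, device=device)
        targets = torch.randn(4, 1, device=device)

        with amp.context():                            # autocast (or no-op) forward pass
            loss = torch.nn.functional.mse_loss(model(inputs), targets)

        amp.backward(loss)                             # scaled (or plain) backward
        amp.step(model, optimizer)                     # clip, optimizer step, zero_grad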