Spaces:
Runtime error
Runtime error
import torch | |
from contextlib import contextmanager | |
from colbert.utils.utils import NullContextManager | |
class MixedPrecisionManager:
    """Wrap forward/backward/optimizer-step logic with optional CUDA AMP.

    When ``activated`` is truthy, forward passes run under
    ``torch.cuda.amp.autocast()`` and losses are scaled with a
    ``torch.cuda.amp.GradScaler``. Otherwise backward and optimizer steps are
    plain (unscaled). In both modes the total gradient norm is clipped to
    ``max_grad_norm`` before each optimizer step.

    Args:
        activated: Enable automatic mixed precision (autocast + grad scaling).
        max_grad_norm: Max total gradient norm passed to
            ``torch.nn.utils.clip_grad_norm_``. Defaults to 2.0, the value
            this class previously hard-coded.
    """

    # NOTE(review): torch.cuda.amp.GradScaler / torch.cuda.amp.autocast are
    # deprecated in recent PyTorch releases in favour of
    # torch.amp.GradScaler("cuda") / torch.autocast("cuda"). They are kept
    # here for compatibility with older torch versions.

    def __init__(self, activated, max_grad_norm=2.0):
        self.activated = activated
        self.max_grad_norm = max_grad_norm
        # Always define the attribute so it never raises AttributeError.
        # It is only used when AMP is activated.
        self.scaler = torch.cuda.amp.GradScaler() if activated else None

    def context(self):
        """Return the context manager to wrap the forward pass in.

        Returns ``torch.cuda.amp.autocast()`` when AMP is activated, and a
        no-op context manager otherwise.
        """
        if self.activated:
            return torch.cuda.amp.autocast()
        from contextlib import nullcontext
        return nullcontext()

    def backward(self, loss):
        """Backpropagate ``loss``, scaling it first when AMP is activated."""
        if self.activated:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

    def step(self, colbert, optimizer, scheduler=None):
        """Clip gradients, step the optimizer (and scheduler), then zero grads.

        Args:
            colbert: Module whose ``parameters()`` have their gradients clipped.
            optimizer: Optimizer to step and then zero.
            scheduler: Optional LR scheduler, stepped after the optimizer.
        """
        if self.activated:
            # Unscale before clipping so the norm is measured on the true
            # gradients rather than on loss-scaled values.
            self.scaler.unscale_(optimizer)
            self._clip_gradients(colbert)
            self.scaler.step(optimizer)
            self.scaler.update()
        else:
            self._clip_gradients(colbert)
            optimizer.step()

        # NOTE(review): the scheduler advances even if the scaler skipped the
        # optimizer step because of inf/NaN gradients. This is preserved from
        # the original behavior.
        if scheduler is not None:
            scheduler.step()

        optimizer.zero_grad()

    def _clip_gradients(self, colbert):
        """Clip the total gradient norm of ``colbert``'s parameters in place."""
        # error_if_nonfinite=False: non-finite grads do not raise here.
        # Under AMP the scaler is responsible for handling them.
        torch.nn.utils.clip_grad_norm_(
            colbert.parameters(), self.max_grad_norm, error_if_nonfinite=False
        )