Spaces:
Runtime error
Runtime error
File size: 1,058 Bytes
58627fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import torch
from contextlib import contextmanager
from colbert.utils.utils import NullContextManager
class MixedPrecisionManager:
    """Run backward passes and optimizer steps with optional CUDA mixed precision.

    When ``activated`` is truthy, callers should wrap the forward pass in
    :meth:`context` (``torch.cuda.amp.autocast``). Losses are then scaled
    through a ``torch.cuda.amp.GradScaler`` in :meth:`backward` and
    :meth:`step`. Otherwise plain ``loss.backward()`` / ``optimizer.step()``
    are used.

    In both modes, :meth:`step` clips the total gradient norm of the model's
    parameters to ``max_grad_norm``. It then steps the optimizer, steps the
    optional scheduler, and zeroes the gradients.
    """

    # Class-level fallback so instances created without ``__init__``
    # still clip at the historical threshold.
    max_grad_norm = 2.0

    def __init__(self, activated, max_grad_norm=2.0):
        """
        Args:
            activated: enable AMP (autocast + gradient scaling) when truthy.
            max_grad_norm: max total gradient norm passed to
                ``torch.nn.utils.clip_grad_norm_`` in :meth:`step`.
                Defaults to 2.0, the value previously hard-coded here.
        """
        self.activated = activated
        self.max_grad_norm = max_grad_norm

        # The scaler is only created (and only ever used) in AMP mode.
        if self.activated:
            self.scaler = torch.cuda.amp.GradScaler()

    def context(self):
        """Return the context manager to wrap the forward pass in.

        Returns ``torch.cuda.amp.autocast()`` in AMP mode, otherwise a
        ``NullContextManager()`` (imported from ``colbert.utils.utils``;
        presumably a no-op context — confirm there).
        """
        return torch.cuda.amp.autocast() if self.activated else NullContextManager()

    def backward(self, loss):
        """Backpropagate ``loss``, multiplying it by the loss scale first in AMP mode."""
        if self.activated:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

    def step(self, colbert, optimizer, scheduler=None):
        """Clip gradients, update parameters, advance the scheduler, zero grads.

        Args:
            colbert: module whose ``parameters()`` are gradient-clipped.
            optimizer: optimizer to step and then ``zero_grad()``.
            scheduler: optional LR scheduler, stepped after the optimizer.
        """
        if self.activated:
            # Gradients must be unscaled first so the clip threshold applies
            # to their true magnitudes, not the loss-scaled ones.
            self.scaler.unscale_(optimizer)
            # inf/NaN grads are expected when the loss scale is too large;
            # GradScaler.step detects them and skips optimizer.step(), so
            # clipping must not raise on them.
            torch.nn.utils.clip_grad_norm_(
                colbert.parameters(), self.max_grad_norm, error_if_nonfinite=False
            )
            self.scaler.step(optimizer)
            self.scaler.update()
        else:
            torch.nn.utils.clip_grad_norm_(colbert.parameters(), self.max_grad_norm)
            optimizer.step()

        # NOTE(review): in AMP mode the scheduler advances even on iterations
        # where GradScaler skipped the optimizer step; kept as-is so LR
        # schedules progress exactly as before.
        if scheduler is not None:
            scheduler.step()

        optimizer.zero_grad()
|