import time
import torch
import random
import torch.nn as nn
import numpy as np

from transformers import AdamW, get_linear_schedule_with_warmup

from colbert.infra import ColBERTConfig
from colbert.training.rerank_batcher import RerankBatcher
from colbert.utils.amp import MixedPrecisionManager
from colbert.training.lazy_batcher import LazyBatcher
from colbert.parameters import DEVICE

from colbert.modeling.colbert import ColBERT
from colbert.modeling.reranker.electra import ElectraReranker

from colbert.utils.utils import print_message
from colbert.training.utils import print_progress, manage_checkpoints


def train(config: ColBERTConfig, triples, queries=None, collection=None):
    config.checkpoint = config.checkpoint or 'bert-base-uncased'

    if config.rank < 1:
        config.help()

    # Fix seeds for reproducibility across Python, NumPy, and PyTorch (CPU and CUDA).
    random.seed(12345)
    np.random.seed(12345)
    torch.manual_seed(12345)
    torch.cuda.manual_seed_all(12345)

    # config.bsize is the global batch size; each of the nranks processes gets an equal share.
    assert config.bsize % config.nranks == 0, (config.bsize, config.nranks)
    config.bsize = config.bsize // config.nranks

    print("Using config.bsize =", config.bsize, "(per process) and config.accumsteps =", config.accumsteps)
    # Choose the batcher: RerankBatcher feeds the cross-encoder reranker,
    # LazyBatcher feeds the late-interaction ColBERT model.
    if collection is not None:
        if config.reranker:
            reader = RerankBatcher(config, triples, queries, collection, (0 if config.rank == -1 else config.rank), config.nranks)
        else:
            reader = LazyBatcher(config, triples, queries, collection, (0 if config.rank == -1 else config.rank), config.nranks)
    else:
        raise NotImplementedError()
    if not config.reranker:
        colbert = ColBERT(name=config.checkpoint, colbert_config=config)
    else:
        colbert = ElectraReranker.from_pretrained(config.checkpoint)

    colbert = colbert.to(DEVICE)
    colbert.train()

    # Wrap the model for multi-GPU training; each process owns the device matching its rank.
    colbert = torch.nn.parallel.DistributedDataParallel(colbert, device_ids=[config.rank],
                                                        output_device=config.rank,
                                                        find_unused_parameters=True)
    optimizer = AdamW(filter(lambda p: p.requires_grad, colbert.parameters()), lr=config.lr, eps=1e-8)
    optimizer.zero_grad()

    scheduler = None
    if config.warmup is not None:
        print(f"#> LR will use {config.warmup} warmup steps and linear decay over {config.maxsteps} steps.")
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=config.warmup,
                                                    num_training_steps=config.maxsteps)

    # Optionally freeze the underlying BERT encoder for the first `warmup_bert` steps.
    warmup_bert = config.warmup_bert
    if warmup_bert is not None:
        set_bert_grad(colbert, False)

    amp = MixedPrecisionManager(config.amp)
    # For the cross-entropy objective over nway scores, the positive passage sits at index 0.
    labels = torch.zeros(config.bsize, dtype=torch.long, device=DEVICE)

    start_time = time.time()
    train_loss = None
    train_loss_mu = 0.999  # exponential moving average factor for the reported loss

    start_batch_idx = 0

    # if config.resume:
    #     assert config.checkpoint is not None
    #     start_batch_idx = checkpoint['batch']

    #     reader.skip_to_batch(start_batch_idx, checkpoint['arguments']['bsize'])
    for batch_idx, BatchSteps in zip(range(start_batch_idx, config.maxsteps), reader):
        # Once the BERT warmup period has elapsed, unfreeze the encoder.
        if (warmup_bert is not None) and warmup_bert <= batch_idx:
            set_bert_grad(colbert, True)
            warmup_bert = None

        this_batch_loss = 0.0

        for batch in BatchSteps:
            with amp.context():
                try:
                    # ColBERT batches: (queries, passages, target_scores).
                    queries, passages, target_scores = batch
                    encoding = [queries, passages]
                except:
                    # Reranker batches: (encoding, target_scores).
                    encoding, target_scores = batch
                    encoding = [encoding.to(DEVICE)]
                scores = colbert(*encoding)

                if config.use_ib_negatives:
                    scores, ib_loss = scores

                scores = scores.view(-1, config.nway)

                if len(target_scores) and not config.ignore_scores:
                    # Distillation: KL divergence between the model's scores and the scaled
                    # teacher scores, both as log-softmax distributions over the nway passages.
                    target_scores = torch.tensor(target_scores).view(-1, config.nway).to(DEVICE)
                    target_scores = target_scores * config.distillation_alpha
                    target_scores = torch.nn.functional.log_softmax(target_scores, dim=-1)

                    log_scores = torch.nn.functional.log_softmax(scores, dim=-1)
                    loss = torch.nn.KLDivLoss(reduction='batchmean', log_target=True)(log_scores, target_scores)
                else:
                    # No teacher scores: cross-entropy with the positive passage at index 0.
                    loss = nn.CrossEntropyLoss()(scores, labels[:scores.size(0)])

                if config.use_ib_negatives:
                    if config.rank < 1:
                        print('\t\t\t\t', loss.item(), ib_loss.item())

                    loss += ib_loss

                loss = loss / config.accumsteps

            if config.rank < 1:
                print_progress(scores)

            amp.backward(loss)

            this_batch_loss += loss.item()
        # Track an exponentially smoothed training loss for logging.
        train_loss = this_batch_loss if train_loss is None else train_loss
        train_loss = train_loss_mu * train_loss + (1 - train_loss_mu) * this_batch_loss

        amp.step(colbert, optimizer, scheduler)

        if config.rank < 1:
            print_message(batch_idx, train_loss)

            manage_checkpoints(config, colbert, optimizer, batch_idx+1, savepath=None)

    if config.rank < 1:
        print_message("#> Done with all triples!")

        ckpt_path = manage_checkpoints(config, colbert, optimizer, batch_idx+1, savepath=None,
                                       consumed_all_triples=True)

        return ckpt_path  # TODO: This should validate and return the best checkpoint, not just the last one.


def set_bert_grad(colbert, value):
    """Enable or disable gradients for the wrapped BERT encoder (follows DDP's `.module` indirection)."""
    try:
        for p in colbert.bert.parameters():
            assert p.requires_grad is (not value)
            p.requires_grad = value
    except AttributeError:
        set_bert_grad(colbert.module, value)
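

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): train() expects
# a distributed context to already be configured (config.rank, config.nranks),
# so it is normally reached through the repository's high-level wrappers rather
# than called directly. The guarded example below is a minimal sketch assuming
# the repo's public Run / RunConfig / Trainer API; all data paths are
# placeholders and the chosen config values are examples, not defaults.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from colbert.infra import Run, RunConfig
    from colbert import Trainer

    with Run().context(RunConfig(nranks=1, experiment="example")):
        # Only options that this module actually reads are set here.
        example_config = ColBERTConfig(bsize=32, lr=1e-5, warmup=1000, accumsteps=1)

        trainer = Trainer(triples='/path/to/triples.tsv',
                          queries='/path/to/queries.tsv',
                          collection='/path/to/collection.tsv',
                          config=example_config)

        trainer.train()  # dispatches to train() above on each rank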