import copy
import os
from typing import Any

import numpy as np
import pandas as pd
import pytorch_lightning as L
import torch
import torch.nn as nn
from hydra.utils import instantiate
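

# Geolocalizer wraps the geolocalization network, its loss, the validation/test
# metrics and (optionally) a text model for text tuning, all instantiated from
# the Hydra config passed to the constructor.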
class Geolocalizer(L.LightningModule):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.model = instantiate(cfg.network.instance)
        if cfg.text_tuning:
            self.text_model = instantiate(cfg.text_network.instance)
        self.loss = instantiate(cfg.loss)
        self.val_metrics = instantiate(cfg.val_metrics)
        self.test_metrics = instantiate(cfg.test_metrics)
        self.text_tuning = cfg.text_tuning

    def training_step(self, batch, batch_idx):
        pred = self.model(batch)
        if self.text_tuning:
            pred["text_features"] = self.text_model(batch)
        # The loss is a dict of named terms; Lightning backpropagates the "loss"
        # key, and every term is logged for monitoring.
        loss = self.loss(pred, batch, average=True)
        for metric_name, metric_value in loss.items():
            self.log(
                f"train/{metric_name}",
                metric_value,
                sync_dist=True,
                on_step=True,
                on_epoch=True,
            )
        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        pred = self.model(batch)
        if self.text_tuning:
            pred["text_features"] = self.text_model(batch)
        loss = self.loss(pred, batch, average=True)["loss"]
        self.val_metrics.update(pred, batch)
        self.log("val/loss", loss, sync_dist=True, on_step=False, on_epoch=True)

    def on_validation_epoch_end(self):
        metrics = self.val_metrics.compute()
        for metric_name, metric_value in metrics.items():
            self.log(
                f"val/{metric_name}",
                metric_value,
                sync_dist=True,
                on_step=False,
                on_epoch=True,
            )

    @torch.no_grad()
    def test_step(self, batch, batch_idx):
        pred = self.model(batch)
        self.test_metrics.update(pred, batch)

    def on_test_epoch_end(self):
        metrics = self.test_metrics.compute()
        for metric_name, metric_value in metrics.items():
            self.log(
                f"test/{metric_name}",
                metric_value,
                sync_dist=True,
                on_step=False,
                on_epoch=True,
            )

    def configure_optimizers(self):
        # Split parameters into groups so that LoRA adapters, the backbone, its
        # last block and the remaining modules can each get their own learning rate.
        lora_params = []
        backbone_params = []
        other_params = []
        last_block_params = []
        for name, param in self.model.named_parameters():
            if "lora" in name:
                lora_params.append(param)
            elif "backbone" in name:
                if self.cfg.optimizer.diff_backbone_last and ".11." in name:
                    last_block_params.append(param)
                else:
                    backbone_params.append(param)
            else:
                other_params.append(param)
        params_to_optimize = [{"params": other_params}]
        if self.cfg.optimizer.unfreeze_lr:
            params_to_optimize += [
                {"params": backbone_params, "lr": self.cfg.optimizer.backbone_lr}
            ]
        if self.cfg.optimizer.diff_backbone_last:
            params_to_optimize += [
                {
                    "params": last_block_params,
                    "lr": self.cfg.optimizer.last_block_lr,
                }
            ]
        if len(lora_params) > 0:
            # LoRA params sometimes train better with a different lr (~1e-4 for CLIP)
            params_to_optimize += [
                {"params": lora_params, "lr": self.cfg.optimizer.lora_lr}
            ]
        if self.cfg.optimizer.exclude_ln_and_biases_from_weight_decay:
            # Apply weight decay to everything except LayerNorm weights and biases.
            # Note that this branch regroups all parameters and therefore does not
            # use the per-group learning rates built above.
            parameters_names_wd = get_parameter_names(self.model, [nn.LayerNorm])
            parameters_names_wd = [
                name for name in parameters_names_wd if "bias" not in name
            ]
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p
                        for n, p in self.model.named_parameters()
                        if n in parameters_names_wd
                    ],
                    "weight_decay": self.cfg.optimizer.optim.weight_decay,
                },
                {
                    "params": [
                        p
                        for n, p in self.model.named_parameters()
                        if n not in parameters_names_wd
                    ],
                    "weight_decay": 0.0,
                },
            ]
            optimizer = instantiate(
                self.cfg.optimizer.optim, optimizer_grouped_parameters
            )
        else:
            optimizer = instantiate(self.cfg.optimizer.optim, params_to_optimize)
        scheduler = instantiate(self.cfg.lr_scheduler)(optimizer)
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

    def lr_scheduler_step(self, scheduler, metric):
        # Lightning calls this hook at every scheduler interval; stepping with the
        # global step matches the "step" interval configured above.
        scheduler.step(self.global_step)


def get_parameter_names(model, forbidden_layer_types):
    """
    Returns the names of the model parameters that are not inside a forbidden layer.
    Taken from HuggingFace transformers.
    """
    result = []
    for name, child in model.named_children():
        result += [
            f"{name}.{n}"
            for n in get_parameter_names(child, forbidden_layer_types)
            if not isinstance(child, tuple(forbidden_layer_types))
        ]
    # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
    result += list(model._parameters.keys())
    return result
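

if __name__ == "__main__":
    # Illustrative sanity check only (not part of the training code): shows how
    # get_parameter_names skips parameters that live inside forbidden layer types,
    # which is what drives the weight-decay exclusion in configure_optimizers.
    # The _demo/_names variables below are just for this example.
    _demo = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
    _names = get_parameter_names(_demo, [nn.LayerNorm])
    print(_names)  # -> ['0.weight', '0.bias']; the LayerNorm parameters are excluded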