from utils import *
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union

import torch
import torch.nn as nn


class LoraInjectedLinear(nn.Module):
    def __init__(
        self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0
    ):
        super().__init__()

        if r > min(in_features, out_features):
            # raise ValueError(
            #     f"LoRA rank {r} must be less than or equal to {min(in_features, out_features)}"
            # )
            print(
                f"LoRA rank {r} is too large; setting it to {min(in_features, out_features)}"
            )
            r = min(in_features, out_features)

        self.r = r
        self.linear = nn.Linear(in_features, out_features, bias)
        self.lora_down = nn.Linear(in_features, r, bias=False)
        self.dropout = nn.Dropout(dropout_p)
        self.lora_up = nn.Linear(r, out_features, bias=False)
        self.scale = scale
        self.selector = nn.Identity()

        nn.init.normal_(self.lora_down.weight, std=1 / r)
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, input):
        return (
            self.linear(input)
            + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
            * self.scale
        )

    def realize_as_lora(self):
        return self.lora_up.weight.data * self.scale, self.lora_down.weight.data

    def set_selector_from_diag(self, diag: torch.Tensor):
        # diag is a 1D tensor of size (r,)
        assert diag.shape == (self.r,)
        self.selector = nn.Linear(self.r, self.r, bias=False)
        self.selector.weight.data = torch.diag(diag)
        # same device + dtype as lora_up
        self.selector.weight.data = self.selector.weight.data.to(
            self.lora_up.weight.device
        ).to(self.lora_up.weight.dtype)
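

# Illustrative sanity check (a sketch added for clarity; the helper name is
# hypothetical, not part of the original file): because lora_up is
# zero-initialized, a freshly constructed LoraInjectedLinear should reproduce
# its base linear layer exactly.
def _sanity_check_injected_linear():
    base = nn.Linear(320, 320)
    lora = LoraInjectedLinear(320, 320, bias=True, r=4)
    lora.linear.weight.data.copy_(base.weight.data)
    lora.linear.bias.data.copy_(base.bias.data)
    lora.eval()  # put dropout in eval mode (the zero-init branch is zero either way)
    x = torch.randn(2, 320)
    assert torch.allclose(base(x), lora(x), atol=1e-6)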


class LoraInjectedConv2d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups: int = 1,
        bias: bool = True,
        r: int = 4,
        dropout_p: float = 0.1,
        scale: float = 1.0,
    ):
        super().__init__()

        if r > min(in_channels, out_channels):
            print(
                f"LoRA rank {r} is too large; setting it to {min(in_channels, out_channels)}"
            )
            r = min(in_channels, out_channels)

        self.r = r
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

        self.lora_down = nn.Conv2d(
            in_channels=in_channels,
            out_channels=r,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=False,
        )
        self.dropout = nn.Dropout(dropout_p)
        self.lora_up = nn.Conv2d(
            in_channels=r,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.selector = nn.Identity()
        self.scale = scale

        nn.init.normal_(self.lora_down.weight, std=1 / r)
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, input):
        return (
            self.conv(input)
            + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
            * self.scale
        )

    def realize_as_lora(self):
        return self.lora_up.weight.data * self.scale, self.lora_down.weight.data

    def set_selector_from_diag(self, diag: torch.Tensor):
        # diag is a 1D tensor of size (r,)
        assert diag.shape == (self.r,)
        self.selector = nn.Conv2d(
            in_channels=self.r,
            out_channels=self.r,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        # Conv2d weights are 4D (out, in, kH, kW), so reshape the diagonal
        # matrix to (r, r, 1, 1) before assigning it.
        self.selector.weight.data = torch.diag(diag).view(self.r, self.r, 1, 1)
        # same device + dtype as lora_up
        self.selector.weight.data = self.selector.weight.data.to(
            self.lora_up.weight.device
        ).to(self.lora_up.weight.dtype)
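

# Illustrative sketch (hypothetical helper, not part of the original file):
# with a diagonal of ones, set_selector_from_diag should behave exactly like
# the default nn.Identity() selector.
def _demo_conv2d_selector():
    layer = LoraInjectedConv2d(8, 8, kernel_size=3, padding=1, r=4)
    nn.init.normal_(layer.lora_up.weight)  # make the LoRA branch non-zero
    layer.eval()  # disable dropout for a deterministic comparison
    x = torch.randn(1, 8, 16, 16)
    y_identity = layer(x)
    layer.set_selector_from_diag(torch.ones(layer.r))
    assert torch.allclose(y_identity, layer(x), atol=1e-5)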


class LoraInjectedConv3d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, int, int] = (3, 1, 1),
        padding: Tuple[int, int, int] = (1, 0, 0),
        bias: bool = False,
        r: int = 4,
        dropout_p: float = 0,
        scale: float = 1.0,
    ):
        super().__init__()

        if r > min(in_channels, out_channels):
            print(
                f"LoRA rank {r} is too large; setting it to {min(in_channels, out_channels)}"
            )
            r = min(in_channels, out_channels)

        self.r = r
        self.kernel_size = kernel_size
        self.padding = padding
        self.conv = nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
            bias=bias,
        )

        self.lora_down = nn.Conv3d(
            in_channels=in_channels,
            out_channels=r,
            kernel_size=kernel_size,
            padding=padding,
            bias=False,
        )
        self.dropout = nn.Dropout(dropout_p)
        self.lora_up = nn.Conv3d(
            in_channels=r,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.selector = nn.Identity()
        self.scale = scale

        nn.init.normal_(self.lora_down.weight, std=1 / r)
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, input):
        return (
            self.conv(input)
            + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
            * self.scale
        )

    def realize_as_lora(self):
        return self.lora_up.weight.data * self.scale, self.lora_down.weight.data

    def set_selector_from_diag(self, diag: torch.Tensor):
        # diag is a 1D tensor of size (r,)
        assert diag.shape == (self.r,)
        self.selector = nn.Conv3d(
            in_channels=self.r,
            out_channels=self.r,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        # Conv3d weights are 5D (out, in, kT, kH, kW), so reshape the diagonal
        # matrix to (r, r, 1, 1, 1) before assigning it.
        self.selector.weight.data = torch.diag(diag).view(self.r, self.r, 1, 1, 1)
        # same device + dtype as lora_up
        self.selector.weight.data = self.selector.weight.data.to(
            self.lora_up.weight.device
        ).to(self.lora_up.weight.dtype)
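

# Shape sketch (illustrative; the sizes below are assumptions): with the
# default (3, 1, 1) kernel and (1, 0, 0) padding, lora_down maps
# (N, C_in, T, H, W) to (N, r, T, H, W) and the 1x1x1 lora_up restores the
# output channels.
def _demo_conv3d_shapes():
    layer = LoraInjectedConv3d(8, 16, kernel_size=(3, 1, 1), padding=(1, 0, 0), r=4)
    x = torch.randn(1, 8, 6, 32, 32)
    assert layer(x).shape == (1, 16, 6, 32, 32)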


def _find_modules(
    model,
    ancestor_class: Optional[Set[str]] = None,
    search_class: List[Type[nn.Module]] = [nn.Linear],
    exclude_children_of: Optional[List[Type[nn.Module]]] = [
        LoraInjectedLinear,
        LoraInjectedConv2d,
        LoraInjectedConv3d,
    ],
):
    """
    Find all modules of a certain class (or union of classes) that are direct or
    indirect descendants of other modules of a certain class (or union of classes).

    Returns all matching modules, along with the parent of those modules and the
    names they are referenced by.
    """
    # Get the targets we should replace all linears under
    if ancestor_class is not None:
        ancestors = (
            module
            for module in model.modules()
            if module.__class__.__name__ in ancestor_class
        )
    else:
        # this, in case you want to naively iterate over all modules.
        ancestors = [module for module in model.modules()]

    # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
    for ancestor in ancestors:
        for fullname, module in ancestor.named_modules():
            if any([isinstance(module, _class) for _class in search_class]):
                # Find the direct parent if this is a descendant, not a child, of target
                *path, name = fullname.split(".")
                parent = ancestor
                while path:
                    parent = parent.get_submodule(path.pop(0))
                # Skip this linear if it's a child of a LoraInjectedLinear
                if exclude_children_of and any(
                    [isinstance(parent, _class) for _class in exclude_children_of]
                ):
                    continue
                # Otherwise, yield it
                yield parent, name, module
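

# Usage sketch (illustrative; the "CrossAttention" ancestor name is an
# assumption about the wrapped model): list every nn.Linear that would be
# matched under the given ancestor classes.
def _demo_find_modules(model: nn.Module):
    for parent, name, child in _find_modules(
        model, {"CrossAttention"}, search_class=[nn.Linear]
    ):
        print(f"{name}: {child.in_features} -> {child.out_features}")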


def inject_trainable_lora(
    model: nn.Module,
    target_replace_module,
    r: int = 4,
    loras=None,  # path to lora .pt
    verbose: bool = False,
    dropout_p: float = 0.0,
    scale: float = 1.0,
):
    """
    Inject LoRA into the model and return the LoRA parameter groups.
    """
    require_grad_params = []
    names = []

    if loras is not None:
        loras = torch.load(loras)

    for _module, name, _child_module in _find_modules(
        model, target_replace_module, search_class=[nn.Linear]
    ):
        weight = _child_module.weight
        bias = _child_module.bias
        if verbose:
            print("LoRA Injection : injecting lora into ", name)
            print("LoRA Injection : weight shape", weight.shape)
        _tmp = LoraInjectedLinear(
            _child_module.in_features,
            _child_module.out_features,
            _child_module.bias is not None,
            r=r,
            dropout_p=dropout_p,
            scale=scale,
        )
        _tmp.linear.weight = weight
        if bias is not None:
            _tmp.linear.bias = bias

        # switch the module
        _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
        _module._modules[name] = _tmp

        require_grad_params.append(_module._modules[name].lora_up.parameters())
        require_grad_params.append(_module._modules[name].lora_down.parameters())

        if loras is not None:
            _module._modules[name].lora_up.weight = loras.pop(0)
            _module._modules[name].lora_down.weight = loras.pop(0)

        _module._modules[name].lora_up.weight.requires_grad = True
        _module._modules[name].lora_down.weight.requires_grad = True
        names.append(name)

    return require_grad_params, names
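

# Training sketch (hypothetical; the target module set is an assumption):
# the returned parameter generators are typically chained into a single
# optimizer parameter group.
def _demo_inject_and_optimize(model: nn.Module):
    import itertools

    params, names = inject_trainable_lora(model, {"CrossAttention"}, r=4)
    optimizer = torch.optim.AdamW(itertools.chain(*params), lr=1e-4)
    return optimizer, names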


def inject_trainable_lora_extended(
    model: nn.Module,
    target_replace_module,
    r: int = 4,
    loras=None,  # path to lora .pt
):
    """
    Inject LoRA into the model and return the LoRA parameter groups.
    """
    require_grad_params = []
    names = []

    if loras is not None:
        loras = torch.load(loras)

    for _module, name, _child_module in _find_modules(
        model, target_replace_module, search_class=[nn.Linear, nn.Conv2d, nn.Conv3d]
    ):
        if _child_module.__class__ == nn.Linear:
            weight = _child_module.weight
            bias = _child_module.bias
            _tmp = LoraInjectedLinear(
                _child_module.in_features,
                _child_module.out_features,
                _child_module.bias is not None,
                r=r,
            )
            _tmp.linear.weight = weight
            if bias is not None:
                _tmp.linear.bias = bias
        elif _child_module.__class__ == nn.Conv2d:
            weight = _child_module.weight
            bias = _child_module.bias
            _tmp = LoraInjectedConv2d(
                _child_module.in_channels,
                _child_module.out_channels,
                _child_module.kernel_size,
                _child_module.stride,
                _child_module.padding,
                _child_module.dilation,
                _child_module.groups,
                _child_module.bias is not None,
                r=r,
            )
            _tmp.conv.weight = weight
            if bias is not None:
                _tmp.conv.bias = bias
        elif _child_module.__class__ == nn.Conv3d:
            weight = _child_module.weight
            bias = _child_module.bias
            _tmp = LoraInjectedConv3d(
                _child_module.in_channels,
                _child_module.out_channels,
                bias=_child_module.bias is not None,
                kernel_size=_child_module.kernel_size,
                padding=_child_module.padding,
                r=r,
            )
            _tmp.conv.weight = weight
            if bias is not None:
                _tmp.conv.bias = bias

        # switch the module
        _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
        if bias is not None:
            _tmp.to(_child_module.bias.device).to(_child_module.bias.dtype)

        _module._modules[name] = _tmp

        require_grad_params.append(_module._modules[name].lora_up.parameters())
        require_grad_params.append(_module._modules[name].lora_down.parameters())

        if loras is not None:
            _module._modules[name].lora_up.weight = loras.pop(0)
            _module._modules[name].lora_down.weight = loras.pop(0)

        _module._modules[name].lora_up.weight.requires_grad = True
        _module._modules[name].lora_down.weight.requires_grad = True
        names.append(name)

    return require_grad_params, names


def extract_lora_ups_down(model, target_replace_module):
    loras = []

    for _m, _n, _child_module in _find_modules(
        model,
        target_replace_module,
        search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d],
    ):
        loras.append((_child_module.lora_up, _child_module.lora_down))

    if len(loras) == 0:
        raise ValueError("No lora injected.")

    return loras
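

# Save sketch (hypothetical helper; the "CrossAttention" target set is an
# assumption): flatten the (up, down) pairs into the alternating
# [up, down, up, down, ...] weight list that inject_trainable_lora expects
# to reload via torch.load.
def _demo_save_lora(model: nn.Module, path: str = "lora.pt"):
    weights = []
    for up, down in extract_lora_ups_down(model, {"CrossAttention"}):
        weights.extend([up.weight, down.weight])
    torch.save(weights, path)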