ABDALLALSWAITI committed
Commit 7608036
1 Parent(s): 0225636

Upload convert_nf4_flux.py

Files changed (1)
  1. convert_nf4_flux.py +144 -0
convert_nf4_flux.py ADDED
"""
Utilities adapted from

* https://github.com/huggingface/transformers/blob/main/src/transformers/quantizers/quantizer_bnb_4bit.py
* https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/bitsandbytes.py
"""

import torch
import torch.nn as nn
import bitsandbytes as bnb
from accelerate import init_empty_weights
from transformers.quantizers.quantizers_utils import get_module_from_name


def _replace_with_bnb_linear(
    model,
    method="nf4",
    has_been_replaced=False,
):
    """
    Private method that wraps the recursion for module replacement.

    Returns the converted model and a boolean that indicates whether the conversion was successful.
    """
    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            with init_empty_weights():
                in_features = module.in_features
                out_features = module.out_features

                if method == "llm_int8":
                    model._modules[name] = bnb.nn.Linear8bitLt(
                        in_features,
                        out_features,
                        module.bias is not None,
                        has_fp16_weights=False,
                        threshold=6.0,
                    )
                    has_been_replaced = True
                else:
                    model._modules[name] = bnb.nn.Linear4bit(
                        in_features,
                        out_features,
                        module.bias is not None,
                        compute_dtype=torch.bfloat16,
                        compress_statistics=False,
                        quant_type="nf4",
                    )
                    has_been_replaced = True
                # Store the module class in case we need to transpose the weight later
                model._modules[name].source_cls = type(module)
                # Force requires_grad to False to avoid unexpected errors
                model._modules[name].requires_grad_(False)

        if len(list(module.children())) > 0:
            _, has_been_replaced = _replace_with_bnb_linear(
                module,
                has_been_replaced=has_been_replaced,
            )
    return model, has_been_replaced


def check_quantized_param(
    model,
    param_name: str,
) -> bool:
    """Return True if `param_name` belongs to a Linear4bit module and must be handled by `create_quantized_param`."""
    module, tensor_name = get_module_from_name(model, param_name)
    if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Params4bit):
        # Add a check for the loaded components' dtypes here once serialization is implemented.
        return True
    elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias":
        # The bias could be loaded by a regular set_module_tensor_to_device() from accelerate,
        # but that would wrongly use the uninitialized weight there.
        return True
    else:
        return False


def create_quantized_param(
    model,
    param_value: "torch.Tensor",
    param_name: str,
    target_device: "torch.device",
    state_dict=None,
    unexpected_keys=None,
    pre_quantized=False,
):
    """Materialize `param_name` on `target_device`, quantizing it to 4-bit or re-attaching pre-quantized stats."""
    module, tensor_name = get_module_from_name(model, param_name)

    if tensor_name not in module._parameters:
        raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")

    old_value = getattr(module, tensor_name)

    if tensor_name == "bias":
        if param_value is None:
            new_value = old_value.to(target_device)
        else:
            new_value = param_value.to(target_device)

        new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad)
        module._parameters[tensor_name] = new_value
        return

    if not isinstance(module._parameters[tensor_name], bnb.nn.Params4bit):
        raise ValueError("This function only loads `Linear4bit` components.")
    if (
        old_value.device == torch.device("meta")
        and target_device not in ["meta", torch.device("meta")]
        and param_value is None
    ):
        raise ValueError(f"{tensor_name} is on the meta device, we need a `param_value` to put on {target_device}.")

    if pre_quantized:
        if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and (
            param_name + ".quant_state.bitsandbytes__nf4" not in state_dict
        ):
            raise ValueError(
                f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components."
            )

        quantized_stats = {}
        for k, v in state_dict.items():
            # `startswith` covers edge cases where the `param_name` substring
            # can be present in multiple places in the `state_dict`.
            if param_name + "." in k and k.startswith(param_name):
                quantized_stats[k] = v
                if unexpected_keys is not None and k in unexpected_keys:
                    unexpected_keys.remove(k)

        new_value = bnb.nn.Params4bit.from_prequantized(
            data=param_value,
            quantized_stats=quantized_stats,
            requires_grad=False,
            device=target_device,
        )
    else:
        new_value = param_value.to("cpu")
        kwargs = old_value.__dict__
        new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(target_device)

    module._parameters[tensor_name] = new_value
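
The three helpers above are meant to be used together: `_replace_with_bnb_linear` swaps every `nn.Linear` in a meta-device model for a `bnb.nn.Linear4bit` shell, `check_quantized_param` decides which checkpoint entries need special handling, and `create_quantized_param` materializes those entries as NF4 `Params4bit` on the target device. Below is a minimal driver sketch, not part of the uploaded file; the checkpoint path, the `black-forest-labs/FLUX.1-dev` repo id, the `torch.bfloat16` dtype, and the assumption that `save_pretrained` picks up the bitsandbytes quant state are all illustrative choices, not something the commit itself specifies.

import safetensors.torch
import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from diffusers import FluxTransformer2DModel

from convert_nf4_flux import _replace_with_bnb_linear, check_quantized_param, create_quantized_param

dtype = torch.bfloat16
ckpt_path = "flux_transformer.safetensors"  # hypothetical path to a BF16 FLUX transformer checkpoint in diffusers format
original_state_dict = safetensors.torch.load_file(ckpt_path)

# Build the transformer on the meta device so no real memory is allocated yet.
with init_empty_weights():
    config = FluxTransformer2DModel.load_config("black-forest-labs/FLUX.1-dev", subfolder="transformer")
    model = FluxTransformer2DModel.from_config(config).to(dtype)

# Swap every nn.Linear for a bnb.nn.Linear4bit shell (weights still on meta).
_replace_with_bnb_linear(model, "nf4")

# Stream parameters in: quantize Linear4bit weights and biases, copy everything else as-is.
for param_name, param in original_state_dict.items():
    param = param.to(dtype)
    if check_quantized_param(model, param_name):
        create_quantized_param(model, param, param_name, target_device="cuda")
    else:
        set_module_tensor_to_device(model, param_name, device="cuda", value=param)

# Assumes save_pretrained serializes the Params4bit data together with its quant state.
model.save_pretrained("flux-transformer-nf4")

Loading such a pre-quantized checkpoint back would go through the `pre_quantized=True` branch of `create_quantized_param`, passing the saved `state_dict` so the `quant_state.bitsandbytes__nf4` entries can be re-attached.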