import copy

import numpy as np
import torch
from text_encoder import T5Attention, T5Stack
from torch import nn
from transformers.models.t5.configuration_t5 import T5Config
from transformers.models.t5.modeling_t5 import (
    T5LayerNorm,
    T5DenseGatedActDense,
)

###
# Code from zer0int/CLIP-fine-tune

class GeometricLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        # Per-output-row magnitude r and direction theta (geometric parameterization).
        self.r = nn.Parameter(torch.empty(out_features, 1))
        self.theta = nn.Parameter(torch.empty(out_features, in_features))
        # Initialize right away so the layer is usable on its own;
        # T5EncoderGmpModel._init_weights re-initializes it anyway.
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.theta, a=np.sqrt(5))
        bound = 1 / np.sqrt(self.in_features)
        nn.init.uniform_(self.r, -bound, bound)

    def forward(self, x):
        # Normalize theta to get unit vector U.
        u = torch.nn.functional.normalize(self.theta, p=2, dim=1)
        # Geometric parameterization: effective weight is r * U,
        # i.e. per-row magnitude times unit direction.
        output = torch.nn.functional.linear(x, self.r * u)

        return output

###
# Code from huggingface/twodgirl
# License: apache-2.0

class T5EncoderGmpModel(nn.Module):
    def __init__(self, config: T5Config):
        super().__init__()
        config.is_encoder_decoder = False
        assert not config.tie_word_embeddings
        self.config = config
        self.model_dim = config.d_model
        self.shared = nn.Embedding(config.vocab_size, config.d_model)
        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)
        # Flux.1-dev expects 4096-dim text embeddings (T5-XXL width).
        flux_dev_d_model = 4096
        self.last_layer = nn.Sequential(
            GeometricLinear(config.d_model, flux_dev_d_model),
            nn.ReLU(),
            GeometricLinear(flux_dev_d_model, flux_dev_d_model)
        )
        # Apply fn recursively to every submodule, children() and self.
        self.apply(self._init_weights)

    def forward(self, input_ids=None, attention_mask=None):
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # Project the encoder's hidden states to the Flux text-embedding width.
        return self.last_layer(encoder_outputs.hidden_states)

    def get_input_embeddings(self):
        return self.shared

    def _init_weights(self, module):
        factor = self.config.initializer_factor  # Used for testing weights initialization
        if isinstance(module, T5LayerNorm):
            module.weight.data.fill_(factor * 1.0)
        elif isinstance(module, T5EncoderGmpModel):
            module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
        elif isinstance(module, T5DenseGatedActDense):
            d_ff, d_model = module.wi_0.weight.data.size()
            module.wi_0.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
            module.wi_1.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
            module.wo.weight.data.normal_(mean=0.0, std=factor * (d_ff ** -0.5))
        elif isinstance(module, T5Attention):
            d_model = self.config.d_model
            key_value_proj_dim = self.config.d_kv
            n_heads = self.config.num_heads
            module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
            module.k.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
            module.v.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
            module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
            if hasattr(module, 'relative_attention_bias'):
                module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
        elif isinstance(module, GeometricLinear):
            module.reset_parameters()
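
###
# Usage sketch (not part of the original module): builds a small, illustrative T5Config
# and runs one forward pass. The config values below are hypothetical; a real setup would
# load the T5-XXL config and tokenizer that Flux.1-dev pairs with this encoder. It also
# assumes the vendored text_encoder.T5Stack accepts (config, embed_tokens) and returns
# the final hidden states under the attribute that forward() above reads.

if __name__ == '__main__':
    config = T5Config(
        vocab_size=32128,
        d_model=512,
        d_kv=64,
        d_ff=1024,
        num_layers=2,
        num_heads=8,
        tie_word_embeddings=False,  # Required by the assert in T5EncoderGmpModel.
    )
    model = T5EncoderGmpModel(config)
    input_ids = torch.randint(0, config.vocab_size, (1, 16))
    attention_mask = torch.ones_like(input_ids)
    with torch.no_grad():
        embeddings = model(input_ids=input_ids, attention_mask=attention_mask)
    # Expected shape: (batch, sequence_length, 4096), the Flux text-embedding width.
    print(embeddings.shape)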