|
from typing import Callable, Dict, Optional, Union, Tuple |
|
|
|
import copy |
|
import math |
|
import multiprocessing |
|
import os |
|
|
|
import torch |
|
import torch.nn as nn |
|
import transformers |
|
|
|
from .misc import ContextualModelConfig |
|
|
|
def load_embedder_and_tokenizer(name: str) -> Tuple[
    transformers.PreTrainedModel,
    transformers.PreTrainedTokenizer
]:
    """Load a backbone model and its tokenizer for a shorthand or hub name.

    Args:
        name: One of the recognised aliases ("gtr-base", "gtr_base",
            "pile-t5-base-encoder", "pile-t5-base-decoder"), a name starting
            with "nomic"/"gpt2"/"meta-llama" or containing "Llama",
            "bert-base-uncased", or any other identifier, which is passed
            straight to ``transformers.AutoModel``.

    Returns:
        A ``(model, tokenizer)`` pair.
    """
    # T5-style checkpoints: alias -> (hub id, sub-module to keep, use EOS as pad).
    t5_variants = {
        "gtr-base": ("sentence-transformers/gtr-t5-base", "encoder", False),
        "gtr_base": ("sentence-transformers/gtr-t5-base", "encoder", False),
        "pile-t5-base-encoder": ("EleutherAI/pile-t5-base", "encoder", True),
        "pile-t5-base-decoder": ("EleutherAI/pile-t5-base", "decoder", True),
    }

    if name.startswith("nomic") or (name == "bert-base-uncased"):
        # Keep only the `.bert` sub-module of the masked-LM wrapper.
        # NOTE(security): trust_remote_code=True runs code shipped with the
        # checkpoint; only pass names you trust.
        masked_lm = transformers.AutoModelForMaskedLM.from_pretrained(name, trust_remote_code=True)
        model = masked_lm.bert
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
        return model, tokenizer

    if name in t5_variants:
        hub_id, stack_attr, eos_as_pad = t5_variants[name]
        model = getattr(transformers.AutoModel.from_pretrained(hub_id), stack_attr)
        tokenizer = transformers.AutoTokenizer.from_pretrained(hub_id)
        if eos_as_pad:
            tokenizer.pad_token = tokenizer.eos_token
        return model, tokenizer

    if name.startswith(("gpt2", "meta-llama")) or ("Llama" in name):
        model = transformers.AutoModelForCausalLM.from_pretrained(
            name,
            attn_implementation="flash_attention_2",
            low_cpu_mem_usage=True,
        )
        # NOTE(review): padding_side is set on the model, not the tokenizer —
        # confirm this is intended.
        model.padding_side = "right"
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.add_eos_token = True
        return model, tokenizer

    # Fallback: treat `name` as a generic hub identifier.
    # NOTE(security): trust_remote_code=True — see the note above.
    model = transformers.AutoModel.from_pretrained(name, trust_remote_code=True)
    tokenizer = transformers.AutoTokenizer.from_pretrained(name)
    return model, tokenizer
|
def get_world_size() -> int:
    """Return the number of processes in the default distributed group.

    Returns:
        The world size, or ``1`` when ``torch.distributed`` is unavailable in
        this build, no process group has been initialized, or the query fails.
    """
    dist = torch.distributed
    # Builds without distributed support do not define get_world_size at all,
    # which would surface as an AttributeError the fallback below misses.
    if not (dist.is_available() and dist.is_initialized()):
        return 1
    try:
        return dist.get_world_size()
    except (RuntimeError, ValueError):
        return 1
|
|
|
|
|
def get_rank() -> int:
    """Return this process's rank in the default distributed group.

    Returns:
        The rank, or ``0`` when ``torch.distributed`` is unavailable in this
        build, no process group has been initialized, or the query fails.
    """
    dist = torch.distributed
    # Builds without distributed support do not define get_rank at all,
    # which would surface as an AttributeError the fallback below misses.
    if not (dist.is_available() and dist.is_initialized()):
        return 0
    try:
        return dist.get_rank()
    except (RuntimeError, ValueError):
        return 0
|
|
|
def gather(t: torch.Tensor) -> torch.Tensor:
    """Concatenate ``t`` from every rank along dimension 0.

    Args:
        t: Local tensor; a 0-d tensor is promoted to shape ``(1,)`` first.

    Returns:
        ``t`` itself when the world size is 1, otherwise the per-rank tensors
        concatenated along dim 0.
    """
    n_ranks = get_world_size()
    if n_ranks == 1:
        return t

    # 0-d tensors cannot be concatenated, so give them a leading axis.
    local = t.unsqueeze(0) if t.ndim == 0 else t

    buffers = [torch.empty_like(local) for _ in range(n_ranks)]
    torch.distributed.all_gather(buffers, local)
    # Put the local tensor object back in this rank's slot
    # (presumably so the local slice keeps its autograd graph — confirm).
    buffers[get_rank()] = local
    return torch.cat(buffers, dim=0)
|
|
|
|
|
def gather_sum(t: torch.Tensor) -> torch.Tensor:
    """Sum ``t`` element-wise across all ranks.

    Args:
        t: Local tensor; a 0-d tensor is promoted to shape ``(1,)`` first.

    Returns:
        ``t`` itself when the world size is 1, otherwise the element-wise sum
        of the tensors gathered from every rank.
    """
    n_ranks = get_world_size()
    if n_ranks == 1:
        return t

    local = t.unsqueeze(0) if t.ndim == 0 else t

    buffers = [torch.empty_like(local) for _ in range(n_ranks)]
    torch.distributed.all_gather(buffers, local)
    # Stack on a new leading rank axis, then reduce that axis away.
    stacked = torch.stack(buffers, dim=0)
    return stacked.sum(dim=0)
|
|
|
|
|
def get_num_proc() -> int:
    """Return how many CPU workers each distributed rank may use.

    The CPUs usable by this process are divided evenly across ranks.

    Returns:
        CPUs per rank, never less than 1 (a worker count of 0 is not usable
        when there are more ranks than CPUs).
    """
    world_size: int = get_world_size()
    try:
        # Counts only the CPUs this process is allowed to run on.
        n_cpus = len(os.sched_getaffinity(0))
    except AttributeError:
        # os.sched_getaffinity does not exist on every platform.
        n_cpus = multiprocessing.cpu_count()
    return max(1, n_cpus // world_size)
|
|
|
|
|
def torch_main_worker_finish_first(func: Callable):
    """Decorate ``func`` so the main worker runs it before all other ranks.

    When a process group is active, rank 0 calls ``func`` first while the
    other ranks wait at a barrier; they then call it, and a second barrier
    re-synchronizes everyone. Without a process group, ``func`` is simply
    called once.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper with the same call signature that returns ``func``'s result.
    """
    import functools  # local import keeps this block self-contained

    @functools.wraps(func)  # preserve the wrapped function's name and docstring
    def wrapper(*args, **kwargs):
        try:
            local_rank = torch.distributed.get_rank()
            ddp_enabled = True
        except (RuntimeError, ValueError):
            # No initialized process group: behave as a single main worker.
            local_rank = -1
            ddp_enabled = False
        is_main_worker = local_rank <= 0

        if is_main_worker:
            result = func(*args, **kwargs)

        if ddp_enabled:
            # Non-main ranks wait here until the main worker has finished.
            torch.distributed.barrier()

        if not is_main_worker:
            result = func(*args, **kwargs)

        if ddp_enabled:
            # The main worker waits here until every other rank has finished.
            torch.distributed.barrier()
        return result

    return wrapper
|
|
|
|
|
def print0(*args, **kwargs) -> None:
    """Forward the arguments to ``print`` on rank 0 only; do nothing elsewhere."""
    if get_rank() != 0:
        return
    print(*args, **kwargs)
|
|
|
|
|
def verify_ddp_weights_equal(model: torch.nn.Module, atol: float = 1e-5) -> None:
    """Assert that every rank holds the same parameters and gradients.

    Each parameter (and its gradient) is gathered from all ranks and compared
    element-wise against rank 0's copy.

    Args:
        model: Model to check; a ``.module`` attribute (wrapper) is unwrapped.
        atol: Absolute difference below which two values count as equal.

    Raises:
        AssertionError: If a parameter or gradient differs across ranks.
    """
    def _check_matches_rank0(tensor, n_ranks, name, suffix):
        # One flattened row per rank; row 0 is the reference copy.
        per_rank = gather(tensor).reshape((n_ranks, -1))
        deviation = (per_rank[None, 0, :] - per_rank).abs()
        assert (deviation < atol).all(), f"❌ param [{name}]{suffix} not equal - got max_absolute_diff={deviation.max()}"

    if hasattr(model, "module"):
        model = model.module

    world_size = get_world_size()

    if world_size > 8:
        print0(f"[verify_ddp_weights_equal] Skipping with world_size={world_size} ⚠️")
        return

    for name, param in model.named_parameters():
        if param is None:
            continue
        if param.grad is None:
            print0(f"[verify_ddp_weights_equal] Skipping param [{name}] with no grad")
            continue
        _check_matches_rank0(param, world_size, name, "")
        _check_matches_rank0(param.grad, world_size, name, " grad")

    print0("[verify_ddp_weights_equal] Verified DDP parameter correctness ✅")
|
|
|
|
|
|
|
def mean_pool_3d(
    hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Masked mean over the third axis of a ``(B, T, S, D)`` tensor.

    Args:
        hidden_states: Tensor of shape ``(B, T, S, D)``.
        attention_mask: Mask of shape ``(B, T, S)``; non-zero marks kept
            positions.

    Returns:
        Tensor of shape ``(B, T, D)``. Groups whose mask is entirely zero get
        the plain (unmasked) mean over all ``T * S`` positions of that batch
        element.
    """
    B, T, S, D = hidden_states.shape

    kept_per_group = attention_mask.sum(dim=2)[..., None]  # (B, T, 1)
    masked_sum = (hidden_states * attention_mask[..., None]).sum(dim=2)
    pooled_outputs = masked_sum / (kept_per_group + 1e-9)

    # Fallback for groups with no kept positions.
    fallback = hidden_states.reshape((B, S * T, D)).mean(dim=1, keepdim=True)
    fallback = fallback.expand(-1, T, -1)
    pooled_outputs = pooled_outputs.where(kept_per_group > 0, fallback)
    assert pooled_outputs.shape == (B, T, D)

    return pooled_outputs
|
|
|
def mean_pool(
    hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Masked mean over the sequence axis.

    Args:
        hidden_states: Tensor of shape ``(B, S, D)``.
        attention_mask: Mask of shape ``(B, S)``; non-zero marks kept tokens.

    Returns:
        Tensor of shape ``(B, D)``.
    """
    batch, _seq, dim = hidden_states.shape
    masked_sum = (hidden_states * attention_mask[..., None]).sum(dim=1)
    # The small epsilon avoids dividing by exactly zero for all-masked rows.
    kept_tokens = attention_mask.sum(dim=1)[:, None] + 1e-20
    pooled = masked_sum / kept_tokens

    assert pooled.shape == (batch, dim)
    return pooled
|
|
|
|
|
def mean_pool_weighted(
    hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Position-weighted masked mean over the sequence axis.

    Each kept token is weighted by the running count of kept tokens up to and
    including it (1, 2, 3, ...), so later tokens contribute more.

    Args:
        hidden_states: Tensor of shape ``(B, S, D)``.
        attention_mask: Mask of shape ``(B, S)``; it is not modified.

    Returns:
        Tensor of shape ``(B, D)``; rows with no kept tokens are all zeros.
    """
    # Out-of-place on purpose: an in-place `*=` would overwrite the caller's
    # mask (and any tensor it is a view of).
    weights = (attention_mask * attention_mask.cumsum(dim=1)).float()
    weighted_sum = torch.sum(hidden_states * weights.unsqueeze(-1), dim=1)
    total_weight = weights.sum(dim=1, keepdim=True)
    # Kept rows have total weight >= 1, so the clamp only affects all-masked
    # rows, turning 0/0 = NaN into 0.
    return weighted_sum / total_weight.clamp(min=1e-9)
|
|
|
|
|
def slice_sparse_tensor_rows(t: torch.Tensor, min_row: int, max_row: int) -> torch.Tensor:
    """Return rows ``[min_row, max_row)`` of a sparse COO tensor.

    Args:
        t: Sparse COO tensor whose first dimension is the row axis.
        min_row: First row to keep (inclusive).
        max_row: One past the last row to keep; must exceed ``min_row``.

    Returns:
        A coalesced sparse tensor with ``max_row - min_row`` rows and the
        remaining dimensions of ``t``.

    Raises:
        AssertionError: If ``min_row >= max_row``.
    """
    assert min_row < max_row, f"can't slice from row {min_row} to {max_row}"
    t = t.coalesce()
    indices = t.indices()
    keep = (min_row <= indices[0]) & (indices[0] < max_row)

    kept = indices[:, keep]
    # Re-base row indices so the first kept row becomes row 0; without this
    # any slice with min_row > 0 has indices out of bounds for the new size.
    kept = torch.cat((kept[:1] - min_row, kept[1:]), dim=0)

    sliced_shape = (max_row - min_row, *t.shape[1:])
    return torch.sparse_coo_tensor(kept, t.values()[keep], size=sliced_shape).coalesce()
|
|
|
|
|
def slice_tensor_rows(t: torch.Tensor, min_row: int, max_row: int) -> torch.Tensor:
    """Return rows ``[min_row, max_row)`` of a dense or sparse tensor.

    Dense tensors are sliced directly; sparse tensors are handed to
    ``slice_sparse_tensor_rows``.
    """
    if not t.is_sparse:
        return t[min_row:max_row]
    return slice_sparse_tensor_rows(t=t, min_row=min_row, max_row=max_row)
|
|
|
|
|
@torch.no_grad()
def maxsim(
    X: torch.Tensor, y: torch.Tensor,
    maximize: bool, chunk_size: int = 8_000,
    debug_mem_usage: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    """For each row of ``X``, find the best-scoring column of ``X @ y``.

    Rows are split into one contiguous range per distributed rank. Each rank
    processes its range in chunks of ``chunk_size`` rows, and the per-rank
    results are combined by summation (rows a rank does not own stay zero).

    Args:
        X: Matrix with one sample per row; may be sparse.
        y: Matrix multiplied on the right of ``X``.
        maximize: Take the row-wise maximum if True, else the minimum.
        chunk_size: Number of rows of ``X`` multiplied at once.
        debug_mem_usage: Print CUDA memory info and shapes for each chunk.

    Returns:
        ``(values, indices)``, each of shape ``(n_samples,)``.
    """
    device = X.device
    n_samples = X.shape[0]

    max_sim_v = torch.zeros(n_samples, device=device, dtype=X.dtype)
    max_sim_i = torch.zeros(n_samples, device=device, dtype=torch.int64)

    rank = get_rank()
    world_size = get_world_size()

    # Contiguous range of rows owned by this rank, clipped to the data size so
    # trailing ranks may get a short or empty range.
    worker_worklist_size = int(math.ceil(n_samples / world_size))
    splits_start_idx = min(worker_worklist_size * rank, n_samples)
    splits_end_idx = min(worker_worklist_size * (rank + 1), n_samples)

    for i in range(splits_start_idx, splits_end_idx, chunk_size):
        # Never run past this rank's own range: rows computed by two ranks
        # would be added together by gather_sum below.
        start, end = i, min(i + chunk_size, splits_end_idx)
        sub_x = slice_tensor_rows(X, start, end)
        if debug_mem_usage:
            print(f"[maxsim] step {i} cuda mem free/total = {torch.cuda.mem_get_info()}")
            print("[maxsim] sub_x.shape:", sub_x.shape, "//", "y.shape:", y.shape)
        sub_sim = (sub_x @ y).to_dense()
        if maximize:
            sub_max_sim_v, sub_max_sim_i = sub_sim.max(dim=-1)
        else:
            sub_max_sim_v, sub_max_sim_i = sub_sim.min(dim=-1)
        # Drop the large intermediates before the next chunk is allocated.
        del sub_sim
        del sub_x
        torch.cuda.empty_cache()
        max_sim_v[start: end] = sub_max_sim_v
        max_sim_i[start: end] = sub_max_sim_i

    max_sim_v = gather_sum(max_sim_v)
    max_sim_i = gather_sum(max_sim_i)
    k = y.shape[1]

    assert max_sim_v.shape == (n_samples,)
    assert max_sim_i.shape == (n_samples,)
    if n_samples > 0:
        assert max_sim_i.min() >= 0
        # Valid column indices are 0 .. k-1.
        assert max_sim_i.max() < k

    return max_sim_v, max_sim_i
|
|
|
|
|
def forward_batched(
    model: torch.nn.Module,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    batch_size: int,
    dataset_input_ids: Optional[torch.Tensor] = None,
    dataset_attention_mask: Optional[torch.Tensor] = None,
    **second_stage_model_kwargs,
) -> torch.Tensor:
    """Run ``model`` over the inputs in mini-batches and concatenate outputs.

    If ``model`` (after unwrapping ``.module``) has a ``first_stage_model``,
    the dataset inputs are embedded first in batches, averaged over the
    leading axis, and passed to ``model.second_stage_model`` as
    ``dataset_embeddings``. Otherwise ``model`` is called directly.

    Args:
        model: Plain or two-stage model; a ``.module`` wrapper is unwrapped.
        input_ids: Inputs, batched along dim 0.
        attention_mask: Mask matching ``input_ids``.
        batch_size: Maximum rows per forward call; must be positive.
        dataset_input_ids: Dataset inputs, 2-D or 3-D (a 2-D tensor gets a
            leading axis of size 1). Required for two-stage models.
        dataset_attention_mask: Mask matching ``dataset_input_ids``.
        **second_stage_model_kwargs: Extra keyword arguments forwarded to the
            second-stage model (or to ``model`` when it is single-stage).

    Returns:
        The per-batch outputs concatenated along dim 0.

    Raises:
        ValueError: If ``batch_size`` is not positive, or a two-stage model
            is given without dataset inputs.
    """
    if batch_size <= 0:
        # A non-positive step would otherwise never advance through the data.
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    if hasattr(model, "module"):
        model = model.module

    if not hasattr(model, "first_stage_model"):
        outputs = [
            model(
                input_ids=input_ids[i:i+batch_size],
                attention_mask=attention_mask[i:i+batch_size],
                **second_stage_model_kwargs,
            )
            for i in range(0, len(input_ids), batch_size)
        ]
        return torch.cat(outputs, dim=0)

    if dataset_input_ids is None or dataset_attention_mask is None:
        raise ValueError(
            "dataset_input_ids and dataset_attention_mask are required "
            "for a model with a first_stage_model"
        )

    if len(dataset_input_ids.shape) == 2:
        # Treat a single dataset as a group of one.
        dataset_input_ids = dataset_input_ids[None]
        dataset_attention_mask = dataset_attention_mask[None]

    dataset_embeddings = []
    for group_ids, group_mask in zip(dataset_input_ids, dataset_attention_mask):
        group_batches = [
            model.first_stage_model(
                input_ids=group_ids[i:i+batch_size],
                attention_mask=group_mask[i:i+batch_size],
            )
            for i in range(0, dataset_input_ids.shape[1], batch_size)
        ]
        dataset_embeddings.append(torch.cat(group_batches, dim=0))

    # Average the embeddings across the dataset groups.
    dataset_embeddings = torch.stack(dataset_embeddings, dim=0).mean(dim=0)

    outputs = [
        model.second_stage_model(
            input_ids=input_ids[j:j+batch_size],
            attention_mask=attention_mask[j:j+batch_size],
            dataset_embeddings=dataset_embeddings,
            **second_stage_model_kwargs,
        )
        for j in range(0, len(input_ids), batch_size)
    ]
    return torch.cat(outputs, dim=0)
|
|
|
|
|
def last_token_pool(hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Select the hidden state at the last attended position of each sequence.

    Args:
        hidden_state: Tensor of shape ``(b, n, d)``.
        attention_mask: Mask of shape ``(b, n)``.

    Returns:
        Tensor of shape ``(b, d)``. Hidden states are multiplied by the mask
        before selection, so a row with an all-zero mask yields zeros.
    """
    b, n, d = hidden_state.size()

    # In the flipped mask, the first maximum sits as far from the right edge
    # as the last attended token does in the original mask.
    offset_from_end = torch.flip(attention_mask, dims=(1,)).argmax(dim=1, keepdim=False)
    last_position = torch.clamp(attention_mask.size(1) - offset_from_end - 1, min=0)

    # torch.gather needs one index per output element: shape (b, 1, d).
    gather_indices = last_position[:, None, None].repeat(1, 1, d)
    assert gather_indices.shape == (b, 1, d)

    mask_3d = attention_mask.unsqueeze(-1).expand((b, n, d)).float()
    masked_states = hidden_state * mask_3d
    return torch.gather(masked_states, 1, gather_indices).squeeze(dim=1)
|
|
|
# NOTE: this repeats an identical `print0` defined earlier in this module;
# being later in the file, this is the definition that takes effect.
def print0(*args, **kwargs) -> None:
    """Forward the arguments to ``print`` on rank 0 only; do nothing elsewhere."""
    if get_rank() != 0:
        return
    print(*args, **kwargs)
|
|
|
|
|
def limit_layers(model: transformers.PreTrainedModel, n_layers: int) -> None:
    """Truncate the model's layer stack to its first ``n_layers`` layers, in place.

    The stack is looked up as ``model.transformer.h``,
    ``model.transformer.layer``, ``model.encoder.layers`` or
    ``model.encoder.layer``, in that order of preference.

    Raises:
        RuntimeError: If the model has neither a ``transformer`` nor an
            ``encoder`` attribute.
    """
    if hasattr(model, 'transformer'):
        container = model.transformer
        stack_attr = 'h' if hasattr(container, 'h') else 'layer'
    elif hasattr(model, 'encoder'):
        container = model.encoder
        stack_attr = 'layers' if hasattr(container, 'layers') else 'layer'
    else:
        raise RuntimeError(f"unknown how to limit layers of model {type(model)}")

    setattr(container, stack_attr, getattr(container, stack_attr)[:n_layers])
|
|
|
|
|
|
|
def disable_dropout(model: torch.nn.Module):
    """Set ``p = 0.0`` on every ``torch.nn.Dropout`` inside ``model``.

    The number of modified modules is reported via ``print0``.
    """
    n_disabled = 0
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0.0
            n_disabled += 1
    print0(
        f"Disabled {n_disabled} dropout modules from model type {type(model)}"
    )
|
|
|
|
|
def disable_causality(model: torch.nn.Module):
    """Set ``is_causal = False`` on every sub-module that has that attribute.

    The number of modified modules is reported via ``print0``.
    """
    causal_modules = [m for m in model.modules() if hasattr(m, "is_causal")]
    for module in causal_modules:
        module.is_causal = False
    print0(
        f"Set is_causal=False in {len(causal_modules)} modules from model type {type(model)}"
    )
|
|
|
class ContextualModelMixin(nn.Module):
    """Mixin that turns dataset embeddings into a soft prompt.

    The host class must set ``self.hidden_size`` and ``self.config`` before
    calling ``contextual_init``.
    """

    @property
    def num_corpus_tokens(self) -> int:
        """Total number of corpus tokens: corpus size times tokens per document."""
        return self.transductive_corpus_size * self.transductive_tokens_per_document

    def contextual_init(self):
        """Create the projection layers and read the transductive settings.

        Settings come from ``vars(self.config)`` with defaults. The null
        embedding parameter is only created when the configured sequence
        dropout probability is positive.
        """
        self.n_soft_prompt = 8
        self.prompt_projection = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_size, self.hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.hidden_size, self.hidden_size * self.n_soft_prompt)
        )
        self.transductive_corpus_size = vars(self.config).get("transductive_corpus_size", 1)
        self.transductive_tokens_per_document = vars(self.config).get("transductive_tokens_per_document", 1)
        self.randomize_dataset_sequence_order = True
        self.sequence_dropout_prob = vars(self.config).get("transductive_sequence_dropout_prob", 0.0)
        if self.sequence_dropout_prob > 0.0:
            self.sequence_dropout_null_embedding = torch.nn.Parameter(
                torch.randn(self.hidden_size) * 0.01,
                requires_grad = True
            )
        self.output_projection = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_size, self.hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.hidden_size, self.hidden_size)
        )

    def _prepare_dataset_embeddings(
            self,
            input_ids: torch.Tensor, dataset_embeddings: torch.Tensor,
            null_dataset_embedding: bool = False,
        ) -> torch.Tensor:
        """Build the per-example soft prompt from dataset embeddings.

        Args:
            input_ids: Used only for its batch size and device.
            dataset_embeddings: 2-D or 3-D embeddings (or data convertible to
                a tensor); a 2-D input gets a leading axis of size 1.
            null_dataset_embedding: Outside of training-time sequence dropout,
                replace all dataset embeddings with the learned null embedding.

        Returns:
            Tensor of shape ``(batch, corpus_size + n_soft_prompt, hidden)``:
            the dataset embeddings followed by the learned soft-prompt tokens.

        Raises:
            AttributeError: If ``null_dataset_embedding`` is requested but the
                null embedding was never created.
        """
        if not isinstance(dataset_embeddings, torch.Tensor):
            dataset_embeddings = torch.tensor(dataset_embeddings)

        if len(dataset_embeddings.shape) == 2:
            dataset_embeddings = dataset_embeddings[None, :, :]
        dataset_embeddings = dataset_embeddings.to(input_ids.device)

        batch_size = input_ids.shape[0]
        if (self.transductive_tokens_per_document > 1):
            if self.training:
                assert dataset_embeddings.shape[1] == self.transductive_tokens_per_document
                # Sample, with replacement, a set of documents per example.
                R = torch.randint(
                    low=0,
                    high=len(dataset_embeddings),
                    size=(batch_size, self.config.transductive_corpus_size),
                    device=dataset_embeddings.device
                )
                dataset_embeddings = dataset_embeddings[R].reshape((batch_size, self.num_corpus_tokens, self.hidden_size))
            else:
                dataset_embeddings = dataset_embeddings.reshape((1, self.num_corpus_tokens, self.hidden_size))

        if dataset_embeddings.shape[1] > self.num_corpus_tokens:
            # Keep only the first num_corpus_tokens embeddings.
            dataset_embeddings = dataset_embeddings[:, :self.num_corpus_tokens, :]

        n_sets, corpus_size, _hidden_size = dataset_embeddings.shape
        if n_sets == 1:
            # One shared set of embeddings: broadcast it over the batch.
            dataset_embeddings = dataset_embeddings.expand((batch_size, -1, -1))

        if self.training and self.sequence_dropout_prob > 0.0:
            sequence_dropout_mask = (
                torch.rand((batch_size, corpus_size), device=dataset_embeddings.device) < self.sequence_dropout_prob
            )
            null_embeddings = self.sequence_dropout_null_embedding[None, None].expand(batch_size, corpus_size, -1)
            dataset_embeddings = torch.where(
                sequence_dropout_mask[..., None], null_embeddings, dataset_embeddings
            )
        elif null_dataset_embedding:
            if not hasattr(self, "sequence_dropout_null_embedding"):
                # The parameter only exists when sequence dropout is enabled.
                raise AttributeError(
                    "null_dataset_embedding=True requires a model configured with "
                    "transductive_sequence_dropout_prob > 0 (no null embedding exists)"
                )
            dataset_embeddings = self.sequence_dropout_null_embedding[None, None].expand(batch_size, corpus_size, -1)

        # The learned soft-prompt tokens come from projecting a constant input.
        soft_prompt = torch.ones((1, self.hidden_size), device=dataset_embeddings.device, dtype=dataset_embeddings.dtype)
        soft_prompt = self.prompt_projection(soft_prompt).reshape((1, self.n_soft_prompt, self.hidden_size))
        soft_prompt = soft_prompt.expand((len(dataset_embeddings), -1, -1))
        soft_prompt = torch.cat((dataset_embeddings, soft_prompt), dim=1)

        if self.training and self.randomize_dataset_sequence_order:
            # Shuffle the corpus tokens per example; the trailing soft-prompt
            # tokens keep their positions.
            prompt_positions = torch.arange(self.n_soft_prompt, device=soft_prompt.device) + corpus_size
            randomized_order = torch.stack([
                torch.cat(
                    (torch.randperm(corpus_size, device=soft_prompt.device), prompt_positions),
                    dim=0,
                )
                for _ in range(batch_size)
            ])
            soft_prompt = soft_prompt.gather(1, randomized_order[..., None].expand_as(soft_prompt))

        return soft_prompt
|
|
|
class BiEncoder(transformers.PreTrainedModel):
    """Single-stage encoder: backbone, pooling, then a two-layer MLP head."""

    embedder: transformers.PreTrainedModel

    def __init__(
        self,
        config,
    ):
        """Load the backbone named by ``config.embedder`` and build the MLP head.

        Reads ``embedder``, ``limit_layers``, ``embedding_output_dim``,
        ``logit_scale`` and ``disable_dropout`` from ``config``, plus the
        optional ``transductive_tokens_per_document`` and ``pooling_strategy``.
        """
        super().__init__(config=config)
        embedder, _ = load_embedder_and_tokenizer(
            config.embedder,
        )

        if config.limit_layers:
            print0(f"Limiting layers to {config.limit_layers}")
            limit_layers(embedder, config.limit_layers)

        self.embedder = embedder
        self.hidden_size = self.embedder.config.hidden_size

        self.transductive_tokens_per_document = vars(self.config).get("transductive_tokens_per_document", 1)
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_size, self.hidden_size),
            torch.nn.GELU(),
            torch.nn.Linear(self.hidden_size, self.config.embedding_output_dim or self.hidden_size),
        )
        self.temp = config.logit_scale

        if config.disable_dropout:
            disable_dropout(self)
        self.pooling_strategy = vars(config).get("pooling_strategy", "mean")

    def forward(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor,
            dataset_input_ids: Optional[torch.Tensor] = None,
            dataset_attention_mask: Optional[torch.Tensor] = None,
            token_type_ids = None,
            output_hidden_states: bool = False,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Embed a batch of token sequences.

        Args:
            input_ids: Token ids of shape ``(batch, seq)``.
            attention_mask: Mask of shape ``(batch, seq)``.
            dataset_input_ids: Accepted but not used by this model.
            dataset_attention_mask: Accepted but not used by this model.
            token_type_ids: Accepted and discarded.
            output_hidden_states: If True, return a dict with the backbone
                hidden states and the pooled output.

        Returns:
            The MLP output of the pooled embeddings, with one vector per
            example, or ``transductive_tokens_per_document`` vectors per
            example when that setting is greater than 1. With
            ``output_hidden_states`` a dict with keys ``"hidden_states"`` and
            ``"pooled"`` is returned instead.
        """
        del token_type_ids

        outputs = (
            self.embedder(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).last_hidden_state
        )

        if self.transductive_tokens_per_document > 1:
            batch_size, seq_length, output_dim = outputs.shape

            if seq_length % self.transductive_tokens_per_document != 0:
                # Right-pad so the sequence splits evenly into groups. The
                # padding matches the existing dtypes so concatenation does
                # not silently promote them.
                n_extra_embeds = self.transductive_tokens_per_document - (seq_length % self.transductive_tokens_per_document)
                outputs = torch.cat(
                    (outputs, torch.zeros((batch_size, n_extra_embeds, output_dim), device=outputs.device, dtype=outputs.dtype)),
                    dim=1
                )
                attention_mask = torch.cat(
                    (attention_mask, torch.zeros((batch_size, n_extra_embeds), device=attention_mask.device, dtype=attention_mask.dtype)),
                    dim=1
                )
                seq_length += n_extra_embeds
                print(f"Added {n_extra_embeds} padding tokens to input_ids and attention_mask")

            # Split each sequence into equal contiguous groups and mean-pool
            # every group into one embedding.
            outputs = outputs.reshape(
                (batch_size, self.transductive_tokens_per_document, seq_length // self.transductive_tokens_per_document, output_dim)
            )
            attention_mask = attention_mask.reshape((batch_size, self.transductive_tokens_per_document, -1))
            document_embeddings = mean_pool_3d(outputs, attention_mask)
            document_embeddings = document_embeddings.reshape((batch_size, self.transductive_tokens_per_document, output_dim))
        elif self.pooling_strategy == "mean":
            document_embeddings = mean_pool(outputs, attention_mask)
        else:
            # Any other strategy: max-pool over the sequence axis, with
            # masked positions excluded from the maximum.
            pad_positions = ~attention_mask.bool()[..., None]
            lowest = torch.finfo(outputs.dtype).min
            document_embeddings = outputs.masked_fill(pad_positions, lowest).max(dim=1).values
        output = self.mlp(document_embeddings)

        if output_hidden_states:
            return {
                "hidden_states": outputs,
                "pooled": output,
            }
        return output
|
|
|
|
|
class DatasetConditionedAutoregressive(transformers.PreTrainedModel, ContextualModelMixin):
    """Second-stage model that conditions a backbone on dataset embeddings.

    The soft prompt is flattened and re-chunked to the backbone's hidden
    size, then prepended to the token embeddings.
    """

    def __init__(
        self,
        config,
        dataset_backbone: transformers.PreTrainedModel,
        first_stage_hidden_size: int,
    ):
        """Wrap ``dataset_backbone`` and build the conditioning layers.

        Args:
            config: Model configuration.
            dataset_backbone: Backbone that receives ``inputs_embeds``.
            first_stage_hidden_size: Width of the incoming dataset embeddings;
                stored as ``self.hidden_size``.
        """
        super().__init__(config=config)
        self.backbone = dataset_backbone
        self.backbone_hidden_size = self.backbone.config.hidden_size
        self.hidden_size = first_stage_hidden_size
        self.contextual_init()
        disable_causality(self.backbone)

        self.input_ln = torch.nn.LayerNorm(
            self.backbone_hidden_size,
            eps=1e-5
        )

        # Replaces the projection created by contextual_init with one sized
        # for the backbone's hidden width.
        self.output_projection = torch.nn.Sequential(
            torch.nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size)
        )
        self._shift_rotary_embedding()

    @property
    def num_corpus_tokens(self) -> int:
        """Configured corpus size times tokens per document."""
        return self.config.transductive_corpus_size * self.transductive_tokens_per_document

    @property
    def corpus_token_ratio(self) -> float:
        """Ratio of the backbone hidden size to the first-stage hidden size."""
        return self.backbone_hidden_size / self.hidden_size

    def corpus_token_pad_size(self, n_tokens: int) -> int:
        """Return ``hidden_size % backbone_hidden_size``.

        NOTE(review): ``n_tokens`` is not used — confirm whether the result
        was meant to depend on it.
        """
        return self.hidden_size % self.backbone_hidden_size

    def _shift_rotary_embedding(self) -> None:
        """Only prints a warning; no rotary-embedding change is applied here."""
        print("Warning: Positional embedding disabling not implemented for LLAMA.")

    def forward(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor,
            dataset_embeddings: torch.Tensor,
            output_hidden_states: bool = False,
            null_dataset_embedding: bool = False,
        ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Embed ``input_ids`` conditioned on ``dataset_embeddings``.

        Args:
            input_ids: Token ids of shape ``(batch, seq)``.
            attention_mask: Mask of shape ``(batch, seq)``.
            dataset_embeddings: Dataset embeddings, passed to
                ``_prepare_dataset_embeddings``.
            output_hidden_states: If True, return a dict with the per-token
                hidden states (soft prompt removed) and the pooled output.
            null_dataset_embedding: Forwarded to
                ``_prepare_dataset_embeddings``.

        Returns:
            The projected pooled embedding, or a dict with keys
            ``"hidden_states"`` and ``"pooled"``.
        """
        soft_prompt = self._prepare_dataset_embeddings(
            input_ids=input_ids,
            dataset_embeddings=dataset_embeddings,
            null_dataset_embedding=null_dataset_embedding,
        )

        # Flatten the soft prompt per example, pad it with ones up to a
        # multiple of the backbone width, and re-chunk it into backbone-sized
        # tokens. A full extra token is added when it already divides evenly.
        n_examples = soft_prompt.shape[0]
        num_soft_elements = soft_prompt.shape[1] * soft_prompt.shape[2]
        soft_prompt = soft_prompt.reshape((n_examples, num_soft_elements))
        num_padding_elements = self.backbone_hidden_size - (num_soft_elements % self.backbone_hidden_size)
        # Match the soft prompt's dtype so the concatenation does not promote
        # a reduced-precision prompt to float32.
        padding = torch.ones(
            (n_examples, num_padding_elements),
            device=soft_prompt.device,
            dtype=soft_prompt.dtype,
        )
        soft_prompt = torch.cat((soft_prompt, padding), dim=1)
        soft_prompt = soft_prompt.reshape(
            (n_examples, -1, self.backbone_hidden_size)
        )
        soft_prompt = self.input_ln(soft_prompt)

        # Every soft-prompt token is attended to.
        backbone_attention_mask = torch.ones(
            soft_prompt.shape[0:2],
            dtype=torch.long,
            device=soft_prompt.device,
        )
        token_embeddings = self.backbone.get_input_embeddings()
        inputs_embeds = token_embeddings(input_ids)
        inputs_embeds = torch.cat((soft_prompt, inputs_embeds), dim=1)
        input_attention_mask = torch.cat((backbone_attention_mask, attention_mask), dim=1)

        output = self.backbone(
            inputs_embeds=inputs_embeds,
            attention_mask=input_attention_mask,
            output_hidden_states=True,
        )

        # Drop the soft-prompt positions so only real tokens are pooled.
        last_hidden_state = output.hidden_states[-1]
        n_soft_prompt_tokens = soft_prompt.shape[1]
        output_vectors = last_hidden_state[:, n_soft_prompt_tokens:, :]
        output_attention_mask = input_attention_mask[:, n_soft_prompt_tokens:]

        pooling_strategy = vars(self.config).get("pooling_strategy")
        if pooling_strategy == "last_token":
            output_pooled = last_token_pool(output_vectors, output_attention_mask)
        elif pooling_strategy == "mean":
            output_pooled = mean_pool(output_vectors, output_attention_mask)
        else:
            output_pooled = mean_pool_weighted(output_vectors, output_attention_mask)

        output = self.output_projection(output_pooled)

        if output_hidden_states:
            return {
                "hidden_states": output_vectors,
                "pooled": output,
            }
        return output
|
|
|
|
|
class DatasetConditionedBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
    """Second-stage encoder that prepends a dataset soft prompt to its input.

    The soft prompt is concatenated in front of the backbone's own token
    embeddings; pooling then covers only the original token positions.
    """

    def __init__(
        self,
        config,
        dataset_backbone: transformers.PreTrainedModel,
    ):
        """Store the backbone and initialise the contextual layers."""
        super().__init__(config=config)
        self.backbone = dataset_backbone
        self.hidden_size = dataset_backbone.config.hidden_size
        self.contextual_init()
        self._shift_rotary_embedding()

    @property
    def num_corpus_tokens(self) -> int:
        """Configured corpus size times tokens per document."""
        return self.config.transductive_corpus_size * self.transductive_tokens_per_document

    def _shift_rotary_embedding(self) -> None:
        """Set ``rotary_start_pos`` on the rotary modules of a nomic backbone.

        Does nothing unless the backbone's ``model_type`` starts with "nomic"
        and ``disable_transductive_rotary_embedding`` (default True) is set.
        """
        disable_transductive_rotary_embedding = vars(self.config).get("disable_transductive_rotary_embedding", True)
        is_nomic = self.backbone.config.model_type.startswith("nomic")
        if not (is_nomic and disable_transductive_rotary_embedding):
            return

        self.backbone.config.rotary_start_pos = 0.0
        rotary_start_pos = self.num_corpus_tokens

        rotary_disabled = 0
        for module in self.backbone.modules():
            # Modules exposing `rotary_emb_dim` are treated as rotary modules.
            if not hasattr(module, "rotary_emb_dim"):
                continue
            module.rotary_start_pos = rotary_start_pos
            rotary_disabled += 1
        print0(f"modified {rotary_disabled} rotary modules – set rotary_start_pos to {rotary_start_pos}")

    def forward(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor,
            dataset_embeddings: torch.Tensor,
            output_hidden_states: bool = False,
            null_dataset_embedding: bool = False,
        ) -> torch.Tensor:
        """Embed ``input_ids`` conditioned on ``dataset_embeddings``.

        Returns the projected mean-pooled embedding, or, when
        ``output_hidden_states`` is set, a dict with keys ``"hidden_states"``
        (soft-prompt positions removed) and ``"pooled"``.
        """
        soft_prompt = self._prepare_dataset_embeddings(
            input_ids=input_ids,
            dataset_embeddings=dataset_embeddings,
            null_dataset_embedding=null_dataset_embedding,
        )
        n_soft_prompt_tokens = soft_prompt.shape[1]

        # Every soft-prompt position is attended to.
        prompt_mask = torch.ones(
            soft_prompt.shape[0:2],
            dtype=torch.long,
            device=soft_prompt.device,
        )
        token_embeds = self.backbone.embeddings(input_ids)
        inputs_embeds = torch.cat((soft_prompt, token_embeds), dim=1)
        attention_mask = torch.cat((prompt_mask, attention_mask), dim=1)

        output = self.backbone(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )

        # Strip the soft-prompt positions before pooling.
        output_vectors = output.last_hidden_state[:, n_soft_prompt_tokens:, :]
        output_attention_mask = attention_mask[:, n_soft_prompt_tokens:]

        output_pooled = mean_pool(output_vectors, output_attention_mask)
        output = self.output_projection(output_pooled)

        if not output_hidden_states:
            return output
        return {
            "hidden_states": output_vectors,
            "pooled": output,
        }
|
|
|
|
|
class DatasetPrefixBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
    """Encoder that prefixes each input with a randomly sampled dataset sequence."""

    def __init__(
        self,
        config,
        embedder: transformers.PreTrainedModel,
    ):
        """Store the embedder and initialise the contextual layers."""
        super().__init__(config=config)
        self.embedder = embedder
        self.hidden_size = self.embedder.config.hidden_size
        self.contextual_init()

    def forward(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor,
            dataset_input_ids: torch.Tensor,
            dataset_attention_mask: torch.Tensor,
            output_hidden_states: bool = False,
        ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Embed ``input_ids`` with a sampled dataset sequence as prefix.

        Args:
            input_ids: Token ids of shape ``(batch, seq)``.
            attention_mask: Mask of shape ``(batch, seq)``.
            dataset_input_ids: Candidate prefixes of shape
                ``(n_dataset, prefix_len)``; one is drawn per example, with
                replacement.
            dataset_attention_mask: Only its dtype and device are used; the
                prefix is always fully attended.
            output_hidden_states: If True, return a dict with the hidden
                states of the non-prefix positions and the pooled output.

        Returns:
            The projected mean-pooled embedding over the non-prefix tokens, or
            a dict with keys ``"hidden_states"`` and ``"pooled"``.
        """
        R = torch.randint(low=0, high=len(dataset_input_ids), size=(len(input_ids),), device=dataset_input_ids.device)

        dataset_input_ids = dataset_input_ids[R]
        input_ids = torch.cat((dataset_input_ids, input_ids), dim=1)

        # Shape the all-ones prefix mask after the *sampled* prefix so its
        # batch dimension matches `attention_mask` even when the number of
        # dataset sequences differs from the batch size.
        dataset_attention_mask = torch.ones_like(
            dataset_input_ids,
            dtype=dataset_attention_mask.dtype,
            device=dataset_attention_mask.device,
        )
        input_attention_mask = torch.cat((dataset_attention_mask, attention_mask), dim=1)
        # The prefix is attended to but excluded from pooling.
        output_attention_mask = torch.cat(
            (torch.zeros_like(dataset_input_ids), attention_mask), dim=1
        )

        output = self.embedder(
            input_ids=input_ids,
            attention_mask=input_attention_mask,
        )

        output_vectors = output.last_hidden_state
        output_pooled = mean_pool(output_vectors, output_attention_mask)
        output = self.output_projection(output_pooled)

        if output_hidden_states:
            S_d = dataset_attention_mask.shape[1]
            output_vectors = output_vectors[:, S_d:, :]
            return {
                "hidden_states": output_vectors,
                "pooled": output,
            }
        return output
|
|
|
|
|
class DatasetTransformer(transformers.PreTrainedModel):
    """Two-stage model: a first stage embeds the dataset, a second stage
    embeds the input conditioned on those dataset embeddings."""

    config_class = ContextualModelConfig
    embedder: transformers.PreTrainedModel
    dataset_backbone: transformers.PreTrainedModel

    def __init__(
        self,
        config,
    ):
        """Build both stages from ``config``.

        The second-stage backbone is named by ``dataset_backbone`` (falling
        back to ``embedder``). The first stage is a ``BiEncoder`` built from a
        copy of the config with ``embedding_output_dim`` cleared and
        ``limit_layers`` taken from ``limit_layers_first_stage``.
        """
        super().__init__(config=config)

        backbone_name = vars(config).get("dataset_backbone", config.embedder)
        dataset_backbone, _ = load_embedder_and_tokenizer(backbone_name)

        if config.limit_layers:
            print0(f"Limiting layers to {config.limit_layers}")
            limit_layers(dataset_backbone, config.limit_layers)

        # First stage gets its own config so the overrides stay local to it.
        biencoder_config = copy.deepcopy(config)
        biencoder_config.embedding_output_dim = None
        biencoder_config.limit_layers = vars(self.config).get("limit_layers_first_stage", None)
        self.first_stage_model = BiEncoder(
            config=biencoder_config,
        )

        use_autoregressive = vars(config).get("autoregressive_backbone", False)
        if not use_autoregressive:
            self.second_stage_model = DatasetConditionedBiencoder(
                config=config,
                dataset_backbone=dataset_backbone
            )
        else:
            self.second_stage_model = DatasetConditionedAutoregressive(
                config=config,
                dataset_backbone=dataset_backbone,
                first_stage_hidden_size=self.first_stage_model.hidden_size,
            )

        self.temp = config.logit_scale
        if config.disable_dropout:
            disable_dropout(self)

        if vars(self.config).get("transductive_tie_token_embeddings", False):
            # Share one word-embedding matrix between the two stages.
            first_stage_embeddings = self.first_stage_model.embedder.embeddings.word_embeddings
            self.second_stage_model.backbone.embeddings.word_embeddings.weight = (
                first_stage_embeddings.weight
            )

    def forward(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor,
            dataset_input_ids: Optional[torch.Tensor],
            dataset_attention_mask: Optional[torch.Tensor],
            output_hidden_states: bool = False,
        ) -> torch.Tensor:
        """Embed the dataset with the first stage, then the input with the second.

        Args:
            input_ids: Ids of the input tokens.
            attention_mask: Attention mask for ``input_ids``.
            dataset_input_ids: Ids of the dataset (context) tokens.
            dataset_attention_mask: Attention mask for ``dataset_input_ids``.
            output_hidden_states: Forwarded to the second-stage model.

        Returns:
            Whatever the second-stage model returns.
        """
        dataset_embeddings = self.first_stage_model(
            input_ids=dataset_input_ids,
            attention_mask=dataset_attention_mask
        )
        second_stage_output = self.second_stage_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            dataset_embeddings=dataset_embeddings,
            output_hidden_states=output_hidden_states,
        )
        return second_stage_output
|
|
|
|
|
|
|
def get_model_class(name: str):
    """Map a model-type name to its model class.

    Args:
        name: ``"transductive"``, ``"biencoder"`` or
            ``"dataset_prefix_biencoder"``.

    Returns:
        The matching model class.

    Raises:
        ValueError: If ``name`` is not one of the known names.
    """
    # Exact match: a substring test (`name in 'transductive'`) would also
    # accept "" or "duct".
    if name == 'transductive':
        return DatasetTransformer
    elif name == 'biencoder':
        return BiEncoder
    elif name == "dataset_prefix_biencoder":
        return DatasetPrefixBiencoder
    else:
        raise ValueError(f'unknown model cls {name}')
|
|