diff --git a/Dockerfile b/Dockerfile
index 923fc1e9be52defa858337ec5e24023e3e90a63a..29f78b93435ba9ca10dcb560493f39c0f17d29cc 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -6,17 +6,19 @@
 WORKDIR /usr/src/app
 COPY --link --chown=1000 ./ /usr/src/app
 COPY . .
 
-# install dependcies
-RUN conda install -y pandas numpy scikit-learn
-RUN pip install --no-cache-dir -r requirements.txt
+# RUN conda install -y pandas numpy scikit-learn
+# RUN pip install --no-cache-dir -r requirements.txt
 
-#if you need to download executable and run them switch to the default non-root user
+# Create Conda environment from env.yml
+RUN conda env create -f env.yml
 
+# if you need to download executables and run them, switch to the default non-root user
 USER user
 
 #do not modify below
 EXPOSE 7860
 ENV GRADIO_SERVER_NAME="0.0.0.0"
 
-CMD ["python", "inference_app.py"]
+# CMD ["python", "inference_app.py"]
+CMD ["conda", "run", "-n", "dockformer-venv", "python", "inference_app.py"]
diff --git a/dockformer/__init__.py b/dockformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..079e708b363112f6b3b9a1741d2c7325d215663c
--- /dev/null
+++ b/dockformer/__init__.py
@@ -0,0 +1,6 @@
+from . import model
+from . import utils
+from . import data
+from . import resources
+
+__all__ = ["model", "utils", "data", "resources"]
diff --git a/dockformer/config.py b/dockformer/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cd1f38dc25e4acfde595431ab0079376bf0d083
--- /dev/null
+++ b/dockformer/config.py
@@ -0,0 +1,358 @@
+import copy
+import ml_collections as mlc
+
+from dockformer.utils.config_tools import set_inf, enforce_config_constraints
+
+
+def model_config(
+    name,
+    train=False,
+    low_prec=False,
+    long_sequence_inference=False
+):
+    c = copy.deepcopy(config)
+    # TRAINING PRESETS
+    if name == "initial_training":
+        # AF2 Suppl. Table 4, "initial training" setting
+        pass
+    elif name == "finetune_affinity":
+        c.loss.affinity2d.weight = 0.5
+        c.loss.affinity1d.weight = 0.5
+        c.loss.binding_site.weight = 0.5
+        c.loss.positions_inter_distogram.weight = 0.5  # possibly redundant given FAPE?
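+        # A minimal usage sketch (illustrative, not part of the preset):
+        #   c = model_config("finetune_affinity", train=True)
+        #   assert c.loss.affinity2d.weight == 0.5
+        # All other presets leave these loss weights at their 0.0 defaults.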
+ else: + raise ValueError("Invalid model name") + + c.globals.use_lma = False + + if long_sequence_inference: + assert(not train) + c.globals.use_lma = True + + if train: + c.globals.blocks_per_ckpt = 1 + c.globals.use_lma = False + + if low_prec: + c.globals.eps = 1e-4 + # If we want exact numerical parity with the original, inf can't be + # a global constant + set_inf(c, 1e4) + + enforce_config_constraints(c) + + return c + + +c_z = mlc.FieldReference(128, field_type=int) +c_m = mlc.FieldReference(256, field_type=int) +c_t = mlc.FieldReference(64, field_type=int) +c_e = mlc.FieldReference(64, field_type=int) +c_s = mlc.FieldReference(384, field_type=int) + + +blocks_per_ckpt = mlc.FieldReference(None, field_type=int) +aux_distogram_bins = mlc.FieldReference(64, field_type=int) +aux_affinity_bins = mlc.FieldReference(32, field_type=int) +eps = mlc.FieldReference(1e-8, field_type=float) + +NUM_RES = "num residues placeholder" +NUM_LIG_ATOMS = "num ligand atoms placeholder" +NUM_TOKEN = "num tokens placeholder" + + +config = mlc.ConfigDict( + { + "data": { + "common": { + "feat": { + "aatype": [NUM_TOKEN], + "all_atom_mask": [NUM_TOKEN, None], + "all_atom_positions": [NUM_TOKEN, None, None], + "atom14_alt_gt_exists": [NUM_TOKEN, None], + "atom14_alt_gt_positions": [NUM_TOKEN, None, None], + "atom14_atom_exists": [NUM_TOKEN, None], + "atom14_atom_is_ambiguous": [NUM_TOKEN, None], + "atom14_gt_exists": [NUM_TOKEN, None], + "atom14_gt_positions": [NUM_TOKEN, None, None], + "atom37_atom_exists": [NUM_TOKEN, None], + "backbone_rigid_mask": [NUM_TOKEN], + "backbone_rigid_tensor": [NUM_TOKEN, None, None], + "chi_angles_sin_cos": [NUM_TOKEN, None, None], + "chi_mask": [NUM_TOKEN, None], + "no_recycling_iters": [], + "pseudo_beta": [NUM_TOKEN, None], + "pseudo_beta_mask": [NUM_TOKEN], + "residue_index": [NUM_TOKEN], + "in_chain_residue_index": [NUM_TOKEN], + "chain_index": [NUM_TOKEN], + "residx_atom14_to_atom37": [NUM_TOKEN, None], + "residx_atom37_to_atom14": [NUM_TOKEN, None], + "resolution": [], + "rigidgroups_alt_gt_frames": [NUM_TOKEN, None, None, None], + "rigidgroups_group_exists": [NUM_TOKEN, None], + "rigidgroups_group_is_ambiguous": [NUM_TOKEN, None], + "rigidgroups_gt_exists": [NUM_TOKEN, None], + "rigidgroups_gt_frames": [NUM_TOKEN, None, None, None], + "seq_length": [], + "token_mask": [NUM_TOKEN], + "target_feat": [NUM_TOKEN, None], + "use_clamped_fape": [], + }, + "max_recycling_iters": 1, + "unsupervised_features": [ + "aatype", + "residue_index", + "in_chain_residue_index", + "chain_index", + "seq_length", + "no_recycling_iters", + "all_atom_mask", + "all_atom_positions", + ], + }, + "supervised": { + "clamp_prob": 0.9, + "supervised_features": [ + "resolution", + "use_clamped_fape", + ], + }, + "predict": { + "fixed_size": True, + "crop": False, + "crop_size": None, + "supervised": False, + "uniform_recycling": False, + }, + "eval": { + "fixed_size": True, + "crop": False, + "crop_size": None, + "supervised": True, + "uniform_recycling": False, + }, + "train": { + "fixed_size": True, + "crop": True, + "crop_size": 355, + "supervised": True, + "clamp_prob": 0.9, + "uniform_recycling": True, + "protein_distogram_mask_prob": 0.1, + }, + "data_module": { + "data_loaders": { + "batch_size": 1, + # "batch_size": 2, + "num_workers": 16, + "pin_memory": True, + "should_verify": False, + }, + }, + }, + # Recurring FieldReferences that can be changed globally here + "globals": { + "blocks_per_ckpt": blocks_per_ckpt, + # Use Staats & Rabe's low-memory attention algorithm. 
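+        # (LMA trades compute for a much smaller attention memory footprint;
+        # it is switched on automatically when `long_sequence_inference` is
+        # passed to `model_config` above.)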
+ "use_lma": False, + "max_lr": 1e-3, + "c_z": c_z, + "c_m": c_m, + "c_t": c_t, + "c_e": c_e, + "c_s": c_s, + "eps": eps, + }, + "model": { + "_mask_trans": False, + "structure_input_embedder": { + "protein_tf_dim": 20, + # len(POSSIBLE_ATOM_TYPES) + len(POSSIBLE_CHARGES) + len(POSSIBLE_CHIRALITIES) + "ligand_tf_dim": 34, + "additional_tf_dim": 3, # number of classes (prot, lig, aff) + "ligand_bond_dim": 6, + "c_z": c_z, + "c_m": c_m, + "relpos_k": 32, + "prot_min_bin": 3.25, + "prot_max_bin": 20.75, + "prot_no_bins": 15, + "lig_min_bin": 0.75, + "lig_max_bin": 9.75, + "lig_no_bins": 10, + "inf": 1e8, + }, + "recycling_embedder": { + "c_z": c_z, + "c_m": c_m, + "min_bin": 3.25, + "max_bin": 20.75, + "no_bins": 15, + "inf": 1e8, + }, + "evoformer_stack": { + "c_m": c_m, + "c_z": c_z, + "c_hidden_single_att": 32, + "c_hidden_mul": 128, + "c_hidden_pair_att": 32, + "c_s": c_s, + "no_heads_single": 8, + "no_heads_pair": 4, + # "no_blocks": 48, + "no_blocks": 2, + "transition_n": 4, + "single_dropout": 0.15, + "pair_dropout": 0.25, + "blocks_per_ckpt": blocks_per_ckpt, + "clear_cache_between_blocks": False, + "inf": 1e9, + "eps": eps, # 1e-10, + }, + "structure_module": { + "c_s": c_s, + "c_z": c_z, + "c_ipa": 16, + "c_resnet": 128, + "no_heads_ipa": 12, + "no_qk_points": 4, + "no_v_points": 8, + "dropout_rate": 0.1, + "no_blocks": 8, + "no_transition_layers": 1, + "no_resnet_blocks": 2, + "no_angles": 7, + "trans_scale_factor": 10, + "epsilon": eps, # 1e-12, + "inf": 1e5, + }, + "heads": { + "lddt": { + "no_bins": 50, + "c_in": c_s, + "c_hidden": 128, + }, + "distogram": { + "c_z": c_z, + "no_bins": aux_distogram_bins, + }, + "affinity_2d": { + "c_z": c_z, + "num_bins": aux_affinity_bins, + }, + "affinity_1d": { + "c_s": c_s, + "num_bins": aux_affinity_bins, + }, + "affinity_cls": { + "c_s": c_s, + "num_bins": aux_affinity_bins, + }, + "binding_site": { + "c_s": c_s, + "c_out": 1, + }, + "inter_contact": { + "c_s": c_s, + "c_z": c_z, + "c_out": 1, + }, + }, + # A negative value indicates that no early stopping will occur, i.e. + # the model will always run `max_recycling_iters` number of recycling + # iterations. A positive value will enable early stopping if the + # difference in pairwise distances is less than the tolerance between + # recycling steps. + "recycle_early_stop_tolerance": -1. 
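+        # e.g. a value of 0.5 (illustrative) would stop recycling early once
+        # the pairwise-distance change between recycling steps drops below 0.5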
+ }, + "relax": { + "max_iterations": 0, # no max + "tolerance": 2.39, + "stiffness": 10.0, + "max_outer_iterations": 20, + "exclude_residues": [], + }, + "loss": { + "distogram": { + "min_bin": 2.3125, + "max_bin": 21.6875, + "no_bins": 64, + "eps": eps, # 1e-6, + "weight": 0.3, + }, + "positions_inter_distogram": { + "max_dist": 20.0, + "weight": 0.0, + }, + "positions_intra_distogram": { + "max_dist": 10.0, + "weight": 0.0, + }, + "binding_site": { + "weight": 0.0, + "pos_class_weight": 20.0, + }, + "inter_contact": { + "weight": 0.0, + "pos_class_weight": 200.0, + }, + "affinity2d": { + "min_bin": 0, + "max_bin": 15, + "no_bins": aux_affinity_bins, + "weight": 0.0, + }, + "affinity1d": { + "min_bin": 0, + "max_bin": 15, + "no_bins": aux_affinity_bins, + "weight": 0.0, + }, + "affinity_cls": { + "min_bin": 0, + "max_bin": 15, + "no_bins": aux_affinity_bins, + "weight": 0.0, + }, + "fape_backbone": { + "clamp_distance": 10.0, + "loss_unit_distance": 10.0, + "weight": 0.5, + }, + "fape_sidechain": { + "clamp_distance": 10.0, + "length_scale": 10.0, + "weight": 0.5, + }, + "fape_interface": { + "clamp_distance": 10.0, + "length_scale": 10.0, + "weight": 0.0, + }, + "plddt_loss": { + "min_resolution": 0.1, + "max_resolution": 3.0, + "cutoff": 15.0, + "no_bins": 50, + "eps": eps, # 1e-10, + "weight": 0.01, + }, + "supervised_chi": { + "chi_weight": 0.5, + "angle_norm_weight": 0.01, + "eps": eps, # 1e-6, + "weight": 1.0, + }, + "chain_center_of_mass": { + "clamp_distance": -4.0, + "weight": 0., + "eps": eps, + "enabled": False, + }, + "eps": eps, + }, + "ema": {"decay": 0.999}, + } +) diff --git a/dockformer/data/data_modules.py b/dockformer/data/data_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..4d86ab065c67005b04417876c66f1fc888035e24 --- /dev/null +++ b/dockformer/data/data_modules.py @@ -0,0 +1,643 @@ +import copy +import itertools +import time +import traceback +from collections import Counter +from functools import partial +import json +import os +import pickle +from typing import Optional, Sequence, Any + +import ml_collections as mlc +import lightning as L +import torch +from torch.utils.data import RandomSampler + +from dockformer.data.data_pipeline import parse_input_json +from dockformer.data import data_pipeline +from dockformer.utils.tensor_utils import dict_multimap +from dockformer.utils.tensor_utils import ( + tensor_tree_map, +) + + +class OpenFoldSingleDataset(torch.utils.data.Dataset): + def __init__(self, + data_dir: str, + config: mlc.ConfigDict, + mode: str = "train", + ): + """ + Args: + data_dir: + A path to a directory containing mmCIF files (in train + mode) or FASTA files (in inference mode). + config: + A dataset config object. 
See dockformer.config.
+            mode:
+                "train", "eval", or "predict"
+        """
+        super(OpenFoldSingleDataset, self).__init__()
+        self.data_dir = data_dir
+
+        self.config = config
+        self.mode = mode
+
+        valid_modes = ["train", "eval", "predict"]
+        if mode not in valid_modes:
+            raise ValueError(f'mode must be one of {valid_modes}')
+
+        self._all_input_files = [i for i in os.listdir(data_dir) if i.endswith(".json")]
+        if self.config.data_module.data_loaders.should_verify:
+            self._all_input_files = [i for i in self._all_input_files if self._verify_json_input_file(i)]
+
+        self.data_pipeline = data_pipeline.DataPipeline(config, mode)
+
+    def _verify_json_input_file(self, file_name: str) -> bool:
+        with open(os.path.join(self.data_dir, file_name), "r") as f:
+            try:
+                loaded = json.load(f)
+                for i in ["input_structure"]:
+                    if i not in loaded:
+                        return False
+                if self.mode != "predict":
+                    for i in ["gt_structure", "resolution"]:
+                        if i not in loaded:
+                            return False
+            except json.JSONDecodeError:
+                return False
+        return True
+
+    def get_metadata_for_idx(self, idx: int) -> dict:
+        input_path = os.path.join(self.data_dir, self._all_input_files[idx])
+        input_data = json.load(open(input_path, "r"))
+        metadata = {
+            "resolution": input_data.get("resolution", 99.0),
+            "input_path": input_path,
+            "input_name": os.path.basename(input_path).split(".json")[0],
+        }
+        return metadata
+
+    def __getitem__(self, idx):
+        return parse_input_json(
+            input_path=os.path.join(self.data_dir, self._all_input_files[idx]),
+            mode=self.mode,
+            config=self.config,
+            data_pipeline=self.data_pipeline,
+            data_dir=os.path.dirname(self.data_dir),
+            idx=idx,
+        )
+
+    def __len__(self):
+        return len(self._all_input_files)
+
+
+def resolution_filter(resolution: float, max_resolution: float) -> bool:
+    """Check that the resolution is <= the maximum permitted resolution."""
+    return resolution is not None and resolution <= max_resolution
+
+
+def all_seq_len_filter(seqs: list, minimum_number_of_residues: int) -> bool:
+    """Check that the total combined sequence length is >= minimum_number_of_residues."""
+    total_len = sum([len(i) for i in seqs])
+    return total_len >= minimum_number_of_residues
+
+
+class OpenFoldDataset(torch.utils.data.Dataset):
+    """
+    Implements the stochastic filters applied during AlphaFold's training.
+    Because samples are selected from constituent datasets randomly, the
+    length of an OpenFoldDataset is arbitrary. Samples are selected
+    and filtered once at initialization.
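+
+    A minimal usage sketch (illustrative; ``ds1``/``ds2`` stand for
+    OpenFoldSingleDataset instances):
+
+        ds = OpenFoldDataset(datasets=[ds1, ds2],
+                             probabilities=[0.9, 0.1], epoch_len=1000)
+        assert len(ds) == 1000  # __len__ is the fixed epoch length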
+ """ + + def __init__(self, + datasets: Sequence[OpenFoldSingleDataset], + probabilities: Sequence[float], + epoch_len: int, + generator: torch.Generator = None, + _roll_at_init: bool = True, + ): + self.datasets = datasets + self.probabilities = probabilities + self.epoch_len = epoch_len + self.generator = generator + + self._samples = [self.looped_samples(i) for i in range(len(self.datasets))] + if _roll_at_init: + self.reroll() + + @staticmethod + def deterministic_train_filter( + cache_entry: Any, + max_resolution: float = 9., + max_single_aa_prop: float = 0.8, + *args, **kwargs + ) -> bool: + # Hard filters + resolution = cache_entry["resolution"] + + return all([ + resolution_filter(resolution=resolution, + max_resolution=max_resolution) + ]) + + @staticmethod + def get_stochastic_train_filter_prob( + cache_entry: Any, + *args, **kwargs + ) -> float: + # Stochastic filters + probabilities = [] + + cluster_size = cache_entry.get("cluster_size", None) + if cluster_size is not None and cluster_size > 0: + probabilities.append(1 / cluster_size) + + # Risk of underflow here? + out = 1 + for p in probabilities: + out *= p + + return out + + def looped_shuffled_dataset_idx(self, dataset_len): + while True: + # Uniformly shuffle each dataset's indices + weights = [1. for _ in range(dataset_len)] + shuf = torch.multinomial( + torch.tensor(weights), + num_samples=dataset_len, + replacement=False, + generator=self.generator, + ) + for idx in shuf: + yield idx + + def looped_samples(self, dataset_idx): + max_cache_len = int(self.epoch_len * self.probabilities[dataset_idx]) + dataset = self.datasets[dataset_idx] + idx_iter = self.looped_shuffled_dataset_idx(len(dataset)) + while True: + weights = [] + idx = [] + for _ in range(max_cache_len): + candidate_idx = next(idx_iter) + # chain_id = dataset.idx_to_chain_id(candidate_idx) + # chain_data_cache_entry = chain_data_cache[chain_id] + # data_entry = dataset[candidate_idx.item()] + entry_metadata_for_filter = dataset.get_metadata_for_idx(candidate_idx.item()) + if not self.deterministic_train_filter(entry_metadata_for_filter): + continue + + p = self.get_stochastic_train_filter_prob( + entry_metadata_for_filter, + ) + weights.append([1. 
- p, p]) + idx.append(candidate_idx) + + samples = torch.multinomial( + torch.tensor(weights), + num_samples=1, + generator=self.generator, + ) + samples = samples.squeeze() + + cache = [i for i, s in zip(idx, samples) if s] + + for datapoint_idx in cache: + yield datapoint_idx + + def __getitem__(self, idx): + dataset_idx, datapoint_idx = self.datapoints[idx] + return self.datasets[dataset_idx][datapoint_idx] + + def __len__(self): + return self.epoch_len + + def reroll(self): + # TODO bshor: I have removed support for filters (currently done in preprocess) and to weighting clusters + # now it is much faster, because it doesn't call looped_samples + dataset_choices = torch.multinomial( + torch.tensor(self.probabilities), + num_samples=self.epoch_len, + replacement=True, + generator=self.generator, + ) + self.datapoints = [] + counter_datasets = Counter(dataset_choices.tolist()) + for dataset_idx, num_samples in counter_datasets.items(): + dataset = self.datasets[dataset_idx] + sample_choices = torch.randint(0, len(dataset), (num_samples,), generator=self.generator) + for datapoint_idx in sample_choices: + self.datapoints.append((dataset_idx, datapoint_idx)) + + +class OpenFoldBatchCollator: + def __call__(self, prots): + stack_fn = partial(torch.stack, dim=0) + return dict_multimap(stack_fn, prots) + + +class OpenFoldDataLoader(torch.utils.data.DataLoader): + def __init__(self, *args, config, stage="train", generator=None, **kwargs): + super().__init__(*args, **kwargs) + self.config = config + self.stage = stage + self.generator = generator + self._prep_batch_properties_probs() + + def _prep_batch_properties_probs(self): + keyed_probs = [] + stage_cfg = self.config[self.stage] + + max_iters = self.config.common.max_recycling_iters + + if stage_cfg.uniform_recycling: + recycling_probs = [ + 1. / (max_iters + 1) for _ in range(max_iters + 1) + ] + else: + recycling_probs = [ + 0. for _ in range(max_iters + 1) + ] + recycling_probs[-1] = 1. + + keyed_probs.append( + ("no_recycling_iters", recycling_probs) + ) + + keys, probs = zip(*keyed_probs) + max_len = max([len(p) for p in probs]) + padding = [[0.] 
* (max_len - len(p)) for p in probs] + + self.prop_keys = keys + self.prop_probs_tensor = torch.tensor( + [p + pad for p, pad in zip(probs, padding)], + dtype=torch.float32, + ) + + def _add_batch_properties(self, batch): + # gt_features = batch.pop('gt_features', None) + samples = torch.multinomial( + self.prop_probs_tensor, + num_samples=1, # 1 per row + replacement=True, + generator=self.generator + ) + + aatype = batch["aatype"] + batch_dims = aatype.shape[:-2] + recycling_dim = aatype.shape[-1] + no_recycling = recycling_dim + for i, key in enumerate(self.prop_keys): + sample = int(samples[i][0]) + sample_tensor = torch.tensor( + sample, + device=aatype.device, + requires_grad=False + ) + orig_shape = sample_tensor.shape + sample_tensor = sample_tensor.view( + (1,) * len(batch_dims) + sample_tensor.shape + (1,) + ) + sample_tensor = sample_tensor.expand( + batch_dims + orig_shape + (recycling_dim,) + ) + batch[key] = sample_tensor + + if key == "no_recycling_iters": + no_recycling = sample + + resample_recycling = lambda t: t[..., :no_recycling + 1] + batch = tensor_tree_map(resample_recycling, batch) + # batch['gt_features'] = gt_features + + return batch + + def __iter__(self): + it = super().__iter__() + + def _batch_prop_gen(iterator): + for batch in iterator: + yield self._add_batch_properties(batch) + + return _batch_prop_gen(it) + + +class OpenFoldDataModule(L.LightningDataModule): + def __init__(self, + config: mlc.ConfigDict, + train_data_dir: Optional[str] = None, + val_data_dir: Optional[str] = None, + predict_data_dir: Optional[str] = None, + batch_seed: Optional[int] = None, + train_epoch_len: int = 50000, + **kwargs + ): + super(OpenFoldDataModule, self).__init__() + + self.config = config + self.train_data_dir = train_data_dir + self.val_data_dir = val_data_dir + self.predict_data_dir = predict_data_dir + self.batch_seed = batch_seed + self.train_epoch_len = train_epoch_len + + if self.train_data_dir is None and self.predict_data_dir is None: + raise ValueError( + 'At least one of train_data_dir or predict_data_dir must be ' + 'specified' + ) + + self.training_mode = self.train_data_dir is not None + + # if not self.training_mode and predict_alignment_dir is None: + # raise ValueError( + # 'In inference mode, predict_alignment_dir must be specified' + # ) + # elif val_data_dir is not None and val_alignment_dir is None: + # raise ValueError( + # 'If val_data_dir is specified, val_alignment_dir must ' + # 'be specified as well' + # ) + + def setup(self, stage): + # Most of the arguments are the same for the three datasets + dataset_gen = partial(OpenFoldSingleDataset, + config=self.config) + + if self.training_mode: + train_dataset = dataset_gen( + data_dir=self.train_data_dir, + mode="train", + ) + + datasets = [train_dataset] + probabilities = [1.] 
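+            # With a single training dataset the sampling below is trivial;
+            # several datasets would be mixed according to `probabilities`
+            # inside OpenFoldDataset.reroll().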
+ + generator = None + if self.batch_seed is not None: + generator = torch.Generator() + generator = generator.manual_seed(self.batch_seed + 1) + + self.train_dataset = OpenFoldDataset( + datasets=datasets, + probabilities=probabilities, + epoch_len=self.train_epoch_len, + generator=generator, + _roll_at_init=False, + ) + + if self.val_data_dir is not None: + self.eval_dataset = dataset_gen( + data_dir=self.val_data_dir, + mode="eval", + ) + else: + self.eval_dataset = None + else: + self.predict_dataset = dataset_gen( + data_dir=self.predict_data_dir, + mode="predict", + ) + + def _gen_dataloader(self, stage): + generator = None + if self.batch_seed is not None: + generator = torch.Generator() + generator = generator.manual_seed(self.batch_seed) + + if stage == "train": + dataset = self.train_dataset + # Filter the dataset, if necessary + dataset.reroll() + elif stage == "eval": + dataset = self.eval_dataset + elif stage == "predict": + dataset = self.predict_dataset + else: + raise ValueError("Invalid stage") + + batch_collator = OpenFoldBatchCollator() + + dl = OpenFoldDataLoader( + dataset, + config=self.config, + stage=stage, + generator=generator, + batch_size=self.config.data_module.data_loaders.batch_size, + # num_workers=self.config.data_module.data_loaders.num_workers, + num_workers=0, # TODO bshor: solve generator pickling issue and then bring back num_workers, or just remove generator + collate_fn=batch_collator, + ) + + return dl + + def train_dataloader(self): + return self._gen_dataloader("train") + + def val_dataloader(self): + if self.eval_dataset is not None: + return self._gen_dataloader("eval") + return None + + def predict_dataloader(self): + return self._gen_dataloader("predict") + + +class DummyDataset(torch.utils.data.Dataset): + def __init__(self, batch_path): + with open(batch_path, "rb") as f: + self.batch = pickle.load(f) + + def __getitem__(self, idx): + return copy.deepcopy(self.batch) + + def __len__(self): + return 1000 + + +class DummyDataLoader(L.LightningDataModule): + def __init__(self, batch_path): + super().__init__() + self.dataset = DummyDataset(batch_path) + + def train_dataloader(self): + return torch.utils.data.DataLoader(self.dataset) + + +class DockFormerSimpleDataset(torch.utils.data.Dataset): + def __init__(self, clusters_json: str, config: mlc.ConfigDict, mode: str = "train"): + clusters = json.load(open(clusters_json, "r")) + self.config = config + self.mode = mode + self._data_dir = os.path.dirname(clusters_json) + print("Data dir", self._data_dir) + self._clusters = clusters + self._all_input_files = sum(clusters.values(), []) + self.data_pipeline = data_pipeline.DataPipeline(config, mode) + + def __getitem__(self, idx): + return parse_input_json( + input_path=os.path.join(self._data_dir, self._all_input_files[idx]), + mode=self.mode, + config=self.config, + data_pipeline=self.data_pipeline, + data_dir=self._data_dir, + idx=idx, + ) + + def __len__(self): + return len(self._all_input_files) + + +class DockFormerClusteredDataset(torch.utils.data.Dataset): + def __init__(self, clusters_json: str, config: mlc.ConfigDict, mode: str = "train", generator=None): + clusters = json.load(open(clusters_json, "r")) + self.config = config + self.mode = mode + self._data_dir = os.path.dirname(clusters_json) + self._clusters = list(clusters.values()) + self.data_pipeline = data_pipeline.DataPipeline(config, mode) + self._generator = generator + + def __getitem__(self, idx): + try: + cluster = self._clusters[idx] + # choose random from cluster + 
input_file = cluster[torch.randint(0, len(cluster), (1,), generator=self._generator).item()] + + return parse_input_json( + input_path=os.path.join(self._data_dir, input_file), + mode=self.mode, + config=self.config, + data_pipeline=self.data_pipeline, + data_dir=self._data_dir, + idx=idx, + ) + except Exception as e: + print("ERROR in loading", e) + traceback.print_exc() + return parse_input_json( + input_path=os.path.join(self._data_dir, self._clusters[0][0]), + mode=self.mode, + config=self.config, + data_pipeline=self.data_pipeline, + data_dir=self._data_dir, + idx=idx, + ) + + + def __len__(self): + return len(self._clusters) + + +class DockFormerDataLoader(torch.utils.data.DataLoader): + def __init__(self, *args, config, stage="train", generator=None, **kwargs): + super().__init__(*args, **kwargs) + self.config = config + self.stage = stage + # self.generator = generator + + def _add_batch_properties(self, batch): + if self.config[self.stage].uniform_recycling: + aatype = batch["aatype"] + max_recycling_dim = aatype.shape[-1] + + # num_recycles = torch.randint(0, max_recycling_dim, (1,), generator=self.generator) + num_recycles = torch.randint(0, max_recycling_dim, (1,)).item() + + resample_recycling = lambda t: t[..., :num_recycles + 1] + batch = tensor_tree_map(resample_recycling, batch) + + return batch + + def __iter__(self): + it = super().__iter__() + + def _batch_prop_gen(iterator): + for batch in iterator: + yield self._add_batch_properties(batch) + + return _batch_prop_gen(it) + + +class DockFormerDataModule(L.LightningDataModule): + def __init__(self, + config: mlc.ConfigDict, + train_data_file: Optional[str] = None, + val_data_file: Optional[str] = None, + batch_seed: Optional[int] = None, + **kwargs + ): + super(DockFormerDataModule, self).__init__() + + self.config = config + self.train_data_file = train_data_file + self.val_data_file = val_data_file + self.batch_seed = batch_seed + + assert self.train_data_file is not None, "train_data_file must be specified" + assert self.val_data_file is not None, "val_data_file must be specified" + + self.train_dataset = None + self.val_dataset = None + + def setup(self, stage): + generator = None + if self.batch_seed is not None: + generator = torch.Generator() + generator = generator.manual_seed(self.batch_seed + 1) + + self.train_dataset = DockFormerClusteredDataset( + clusters_json=self.train_data_file, + config=self.config, + mode="train", + generator=generator, + ) + + self.val_dataset = DockFormerSimpleDataset( + clusters_json=self.val_data_file, + config=self.config, + mode="eval", + ) + + def _gen_dataloader(self, stage): + generator = None + if self.batch_seed is not None: + generator = torch.Generator() + generator = generator.manual_seed(self.batch_seed) + + should_shuffle = stage == "train" + if stage == "train": + dataset = self.train_dataset + elif stage == "eval": + dataset = self.val_dataset + else: + raise ValueError("Invalid stage") + + batch_collator = OpenFoldBatchCollator() + + dl = DockFormerDataLoader( + dataset, + config=self.config, + stage=stage, + # generator=generator, + batch_size=self.config.data_module.data_loaders.batch_size, + # num_workers=self.config.data_module.data_loaders.num_workers, + num_workers=0, # TODO bshor: solve generator pickling issue and then bring back num_workers, or just remove generator + collate_fn=batch_collator, + shuffle=should_shuffle, + ) + + return dl + + def train_dataloader(self): + return self._gen_dataloader("train") + + def val_dataloader(self): + if self.val_dataset 
is not None:
+            return self._gen_dataloader("eval")
+        return None
diff --git a/dockformer/data/data_pipeline.py b/dockformer/data/data_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b2a1adebeb41ec01af94319723da5d97f49f54f
--- /dev/null
+++ b/dockformer/data/data_pipeline.py
@@ -0,0 +1,503 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import os
+import time
+from typing import List
+
+import numpy as np
+import torch
+import ml_collections as mlc
+from rdkit import Chem
+
+from dockformer.data import data_transforms
+from dockformer.data.data_transforms import get_restype_atom37_mask, get_restypes
+from dockformer.data.ligand_features import make_ligand_features
+from dockformer.data.protein_features import make_protein_features
+from dockformer.data.utils import FeatureTensorDict, FeatureDict
+from dockformer.utils import protein
+
+
+def _np_filter_and_to_tensor_dict(np_example: FeatureDict, features_to_keep: List[str]) -> FeatureTensorDict:
+    """Creates a dict of tensors from a dict of NumPy arrays.
+
+    Args:
+        np_example: A dict of NumPy feature arrays.
+        features_to_keep: A list of feature names to be returned in the dataset.
+
+    Returns:
+        A dictionary mapping feature names to feature tensors. Only the given
+        features are returned; all others are filtered out.
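+
+    Example (illustrative):
+        >>> feats = {"aatype": np.zeros(5), "extra": np.ones(3)}
+        >>> sorted(_np_filter_and_to_tensor_dict(feats, ["aatype"]))
+        ['aatype']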
+    """
+    # torch generates warnings if the feature is already a torch Tensor
+    to_tensor = lambda t: torch.tensor(t) if type(t) != torch.Tensor else t.clone().detach()
+    tensor_dict = {
+        k: to_tensor(v) for k, v in np_example.items() if k in features_to_keep
+    }
+
+    return tensor_dict
+
+
+def _add_protein_probabilistic_features(features: FeatureDict, cfg: mlc.ConfigDict, mode: str) -> FeatureDict:
+    if mode == "train":
+        p = torch.rand(1).item()
+        use_clamped_fape_value = float(p < cfg.supervised.clamp_prob)
+        features["use_clamped_fape"] = np.float32(use_clamped_fape_value)
+    else:
+        features["use_clamped_fape"] = np.float32(0.0)
+    return features
+
+
+@data_transforms.curry1
+def compose(x, fs):
+    for f in fs:
+        x = f(x)
+    return x
+
+
+def _apply_protein_transforms(tensors: FeatureTensorDict) -> FeatureTensorDict:
+    transforms = [
+        data_transforms.cast_to_64bit_ints,
+        data_transforms.squeeze_features,
+        data_transforms.make_atom14_masks,
+        data_transforms.make_atom14_positions,
+        data_transforms.atom37_to_frames,
+        data_transforms.atom37_to_torsion_angles(""),
+        data_transforms.make_pseudo_beta(),
+        data_transforms.get_backbone_frames,
+        data_transforms.get_chi_angles,
+    ]
+
+    tensors = compose(transforms)(tensors)
+
+    return tensors
+
+
+def _apply_protein_probabilistic_transforms(tensors: FeatureTensorDict, cfg: mlc.ConfigDict, mode: str) \
+        -> FeatureTensorDict:
+    transforms = [data_transforms.make_target_feat()]
+
+    crop_feats = dict(cfg.common.feat)
+
+    if cfg[mode].fixed_size:
+        transforms.append(data_transforms.select_feat(list(crop_feats)))
+        # TODO bshor: restore transforms for training on cropped proteins, need to handle pocket somehow
+        # if so, look for random_crop_to_size and make_fixed_size in data_transforms.py
+
+    # select_feat returns a new dict rather than mutating in place, so the
+    # composed result must be kept
+    tensors = compose(transforms)(tensors)
+
+    return tensors
+
+
+class DataPipeline:
+    """Assembles input features."""
+    def __init__(self, config: mlc.ConfigDict, mode: str):
+        self.config = config
+        self.mode = mode
+
+        self.feature_names = config.common.unsupervised_features
+        if config[mode].supervised:
+            self.feature_names += config.supervised.supervised_features
+
+    def process_pdb(self, pdb_path: str) -> FeatureTensorDict:
+        """
+        Assembles features for a protein in a PDB file.
+        """
+        with open(pdb_path, 'r') as f:
+            pdb_str = f.read()
+
+        protein_object = protein.from_pdb_string(pdb_str)
+        description = os.path.splitext(os.path.basename(pdb_path))[0].upper()
+        pdb_feats = make_protein_features(protein_object, description)
+        pdb_feats = _add_protein_probabilistic_features(pdb_feats, self.config, self.mode)
+
+        tensor_feats = _np_filter_and_to_tensor_dict(pdb_feats, self.feature_names)
+
+        tensor_feats = _apply_protein_transforms(tensor_feats)
+        tensor_feats = _apply_protein_probabilistic_transforms(tensor_feats, self.config, self.mode)
+
+        return tensor_feats
+
+    def process_smiles(self, smiles: str) -> FeatureTensorDict:
+        ligand = Chem.MolFromSmiles(smiles)
+        return make_ligand_features(ligand)
+
+    def process_mol2(self, mol2_path: str) -> FeatureTensorDict:
+        """
+        Assembles features for a ligand in a mol2 file.
+        """
+        ligand = Chem.MolFromMol2File(mol2_path)
+        assert ligand is not None, f"Failed to parse ligand from {mol2_path}"
+
+        conf = ligand.GetConformer()
+        positions = torch.tensor(conf.GetPositions())
+
+        return {
+            **make_ligand_features(ligand),
+            "gt_ligand_positions": positions.float()
+        }
+
+    def process_sdf(self, sdf_path: str) -> FeatureTensorDict:
+        """
+        Assembles features for a ligand in an SDF file.
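+
+        Unlike process_mol2, the conformer coordinates are returned under
+        "ligand_positions" (the reference/input conformer) rather than
+        "gt_ligand_positions".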
+ """ + ligand = Chem.MolFromMolFile(sdf_path) + assert ligand is not None, f"Failed to parse ligand from {sdf_path}" + + conf = ligand.GetConformer(0) + positions = torch.tensor(conf.GetPositions()) + + return { + **make_ligand_features(ligand), + "ligand_positions": positions.float() + } + + def process_sdf_list(self, sdf_path_list: List[str]) -> FeatureTensorDict: + all_sdf_feats = [self.process_sdf(sdf_path) for sdf_path in sdf_path_list] + + all_sizes = [sdf_feats["ligand_target_feat"].shape[0] for sdf_feats in all_sdf_feats] + + joined_ligand_feats = {} + for k in all_sdf_feats[0].keys(): + if k == "ligand_positions": + joined_positions = all_sdf_feats[0][k] + prev_offset = joined_positions.max(dim=0).values + 100 + + for i, sdf_feats in enumerate(all_sdf_feats[1:]): + offset = prev_offset - sdf_feats[k].min(dim=0).values + joined_positions = torch.cat([joined_positions, sdf_feats[k] + offset], dim=0) + prev_offset = joined_positions.max(dim=0).values + 100 + joined_ligand_feats[k] = joined_positions + elif k in ["ligand_target_feat", "ligand_atype", "ligand_charge", "ligand_chirality", "ligand_bonds"]: + joined_ligand_feats[k] = torch.cat([sdf_feats[k] for sdf_feats in all_sdf_feats], dim=0) + if k == "ligand_target_feat": + joined_ligand_feats["ligand_idx"] = torch.cat([torch.full((sdf_feats[k].shape[0],), i) + for i, sdf_feats in enumerate(all_sdf_feats)], dim=0) + elif k == "ligand_bonds": + joined_ligand_feats["ligand_bonds_idx"] = torch.cat([torch.full((sdf_feats[k].shape[0],), i) + for i, sdf_feats in enumerate(all_sdf_feats)], + dim=0) + elif k == "ligand_bonds_feat": + joined_feature = torch.zeros((sum(all_sizes), sum(all_sizes), all_sdf_feats[0][k].shape[2])) + for i, sdf_feats in enumerate(all_sdf_feats): + start_idx = sum(all_sizes[:i]) + end_idx = sum(all_sizes[:i + 1]) + joined_feature[start_idx:end_idx, start_idx:end_idx, :] = sdf_feats[k] + joined_ligand_feats[k] = joined_feature + else: + raise ValueError(f"Unknown key in sdf list features {k}") + return joined_ligand_feats + + def get_matching_positions_list(self, ref_path_list: List[str], gt_path_list: List[str]): + joined_gt_positions = [] + + for ref_ligand_path, gt_ligand_path in zip(ref_path_list, gt_path_list): + ref_ligand = Chem.MolFromMolFile(ref_ligand_path) + gt_ligand = Chem.MolFromMolFile(gt_ligand_path) + + gt_original_positions = gt_ligand.GetConformer(0).GetPositions() + + gt_positions = [gt_original_positions[idx] for idx in gt_ligand.GetSubstructMatch(ref_ligand)] + + joined_gt_positions.extend(gt_positions) + + return torch.tensor(np.array(joined_gt_positions)).float() + + def get_matching_positions(self, ref_ligand_path: str, gt_ligand_path: str): + ref_ligand = Chem.MolFromMolFile(ref_ligand_path) + gt_ligand = Chem.MolFromMolFile(gt_ligand_path) + + gt_original_positions = gt_ligand.GetConformer(0).GetPositions() + + gt_positions = [gt_original_positions[idx] for idx in gt_ligand.GetSubstructMatch(ref_ligand)] + + # ref_positions = ref_ligand.GetConformer(0).GetPositions() + # for i in range(len(ref_positions)): + # for j in range(i + 1, len(ref_positions)): + # dist_ref = np.linalg.norm(ref_positions[i] - ref_positions[j]) + # dist_gt = np.linalg.norm(gt_positions[i] - gt_positions[j]) + # dist_gt = np.linalg.norm(gt_original_positions[i] - gt_original_positions[j]) + # if abs(dist_ref - dist_gt) > 1.0: + # print(f"Distance mismatch {i} {j} {dist_ref} {dist_gt}") + + return torch.tensor(np.array(gt_positions)) .float() + + +def _prepare_recycles(feat: torch.Tensor, num_recycles: int) -> 
torch.Tensor: + return feat.unsqueeze(-1).repeat(*([1] * len(feat.shape)), num_recycles) + + +def _fit_to_crop(target_tensor: torch.Tensor, crop_size: int, start_ind: int) -> torch.Tensor: + if len(target_tensor.shape) == 1: + ret = torch.zeros((crop_size, ), dtype=target_tensor.dtype) + ret[start_ind:start_ind + target_tensor.shape[0]] = target_tensor + return ret + elif len(target_tensor.shape) == 2: + ret = torch.zeros((crop_size, target_tensor.shape[-1]), dtype=target_tensor.dtype) + ret[start_ind:start_ind + target_tensor.shape[0], :] = target_tensor + return ret + else: + ret = torch.zeros((crop_size, *target_tensor.shape[1:]), dtype=target_tensor.dtype) + ret[start_ind:start_ind + target_tensor.shape[0], ...] = target_tensor + return ret + + +def parse_input_json(input_path: str, mode: str, config: mlc.ConfigDict, data_pipeline: DataPipeline, + data_dir: str, idx: int) -> FeatureTensorDict: + start_load_time = time.time() + input_data = json.load(open(input_path, "r")) + if mode == "train" or mode == "eval": + print("loading", input_data["pdb_id"], end=" ") + + num_recycles = config.common.max_recycling_iters + 1 + + input_pdb_path = os.path.join(data_dir, input_data["input_structure"]) + input_protein_feats = data_pipeline.process_pdb(pdb_path=input_pdb_path) + + # load ref sdf + if "ref_sdf" in input_data: + ref_sdf_path = os.path.join(data_dir, input_data["ref_sdf"]) + ref_ligand_feats = data_pipeline.process_sdf(sdf_path=ref_sdf_path) + ref_ligand_feats["ligand_idx"] = torch.zeros((ref_ligand_feats["ligand_target_feat"].shape[0],)) + ref_ligand_feats["ligand_bonds_idx"] = torch.zeros((ref_ligand_feats["ligand_bonds"].shape[0],)) + elif "ref_sdf_list" in input_data: + sdf_path_list = [os.path.join(data_dir, i) for i in input_data["ref_sdf_list"]] + ref_ligand_feats = data_pipeline.process_sdf_list(sdf_path_list=sdf_path_list) + else: + raise ValueError("ref_sdf or ref_sdf_list must be in input_data") + + n_res = input_protein_feats["protein_target_feat"].shape[0] + n_lig = ref_ligand_feats["ligand_target_feat"].shape[0] + n_affinity = 1 + + # add 1 for affinity token + crop_size = n_res + n_lig + n_affinity + if (mode == "train" or mode == "eval") and config.train.fixed_size: + crop_size = config.train.crop_size + + assert crop_size >= n_res + n_lig + n_affinity, f"crop_size: {crop_size}, n_res: {n_res}, n_lig: {n_lig}" + + token_mask = torch.zeros((crop_size,), dtype=torch.float32) + token_mask[:n_res + n_lig + n_affinity] = 1 + + protein_mask = torch.zeros((crop_size,), dtype=torch.float32) + protein_mask[:n_res] = 1 + + ligand_mask = torch.zeros((crop_size,), dtype=torch.float32) + ligand_mask[n_res:n_res + n_lig] = 1 + + affinity_mask = torch.zeros((crop_size,), dtype=torch.float32) + affinity_mask[n_res + n_lig] = 1 + + structural_mask = torch.zeros((crop_size,), dtype=torch.float32) + structural_mask[:n_res + n_lig] = 1 + + inter_pair_mask = torch.zeros((crop_size, crop_size), dtype=torch.float32) + inter_pair_mask[:n_res, n_res:n_res + n_lig] = 1 + inter_pair_mask[n_res:n_res + n_lig, :n_res] = 1 + + protein_tf_dim = input_protein_feats["protein_target_feat"].shape[-1] + ligand_tf_dim = ref_ligand_feats["ligand_target_feat"].shape[-1] + joined_tf_dim = protein_tf_dim + ligand_tf_dim + + target_feat = torch.zeros((crop_size, joined_tf_dim + 3), dtype=torch.float32) + target_feat[:n_res, :protein_tf_dim] = input_protein_feats["protein_target_feat"] + target_feat[n_res:n_res + n_lig, protein_tf_dim:joined_tf_dim] = ref_ligand_feats["ligand_target_feat"] + + 
target_feat[:n_res, joined_tf_dim] = 1 # Set "is_protein" flag for protein rows + target_feat[n_res:n_res + n_lig, joined_tf_dim + 1] = 1 # Set "is_ligand" flag for ligand rows + target_feat[n_res + n_lig, joined_tf_dim + 2] = 1 # Set "is_affinity" flag for affinity row + + ligand_bonds_feat = torch.zeros((crop_size, crop_size, ref_ligand_feats["ligand_bonds_feat"].shape[-1]), + dtype=torch.float32) + ligand_bonds_feat[n_res:n_res + n_lig, n_res:n_res + n_lig] = ref_ligand_feats["ligand_bonds_feat"] + + input_positions = torch.zeros((crop_size, 3), dtype=torch.float32) + input_positions[:n_res] = input_protein_feats["pseudo_beta"] + input_positions[n_res:n_res + n_lig] = ref_ligand_feats["ligand_positions"] + + protein_distogram_mask = torch.zeros(crop_size) + if mode == "train": + ones_indices = torch.randperm(n_res)[:int(n_res * config.train.protein_distogram_mask_prob)] + # print(ones_indices) + protein_distogram_mask[ones_indices] = 1 + input_positions = input_positions * (1 - protein_distogram_mask).unsqueeze(-1) + elif mode == "predict": + # ignore all positions where pseudo_beta is 0, 0, 0 + protein_distogram_mask = (input_positions == 0).all(dim=-1).float() + # print("Ignoring residues", torch.nonzero(distogram_mask).flatten()) + + # Implement ligand as amino acid type 20 + ligand_aatype = 20 * torch.ones((n_lig,), dtype=input_protein_feats["aatype"].dtype) + aatype = torch.cat([input_protein_feats["aatype"], ligand_aatype], dim=0) + + restype_atom14_to_atom37, restype_atom37_to_atom14, restype_atom14_mask = get_restypes(target_feat.device) + lig_residx_atom37_to_atom14 = restype_atom37_to_atom14[20].repeat(n_lig, 1) + residx_atom37_to_atom14 = torch.cat([input_protein_feats["residx_atom37_to_atom14"], lig_residx_atom37_to_atom14], + dim=0) + + restype_atom37_mask = get_restype_atom37_mask(target_feat.device) + lig_atom37_atom_exists = restype_atom37_mask[20].repeat(n_lig, 1) + atom37_atom_exists = torch.cat([input_protein_feats["atom37_atom_exists"], lig_atom37_atom_exists], dim=0) + + feats = { + "token_mask": token_mask, + "protein_mask": protein_mask, + "ligand_mask": ligand_mask, + "affinity_mask": affinity_mask, + "structural_mask": structural_mask, + "inter_pair_mask": inter_pair_mask, + + "target_feat": target_feat, + "ligand_bonds_feat": ligand_bonds_feat, + "input_positions": input_positions, + "protein_distogram_mask": protein_distogram_mask, + "protein_residue_index": _fit_to_crop(input_protein_feats["residue_index"], crop_size, 0), + "aatype": _fit_to_crop(aatype, crop_size, 0), + "residx_atom37_to_atom14": _fit_to_crop(residx_atom37_to_atom14, crop_size, 0), + "atom37_atom_exists": _fit_to_crop(atom37_atom_exists, crop_size, 0), + } + + if mode == "predict": + feats.update({ + "in_chain_residue_index": input_protein_feats["in_chain_residue_index"], + "chain_index": input_protein_feats["chain_index"], + "ligand_atype": ref_ligand_feats["ligand_atype"], + "ligand_chirality": ref_ligand_feats["ligand_chirality"], + "ligand_charge": ref_ligand_feats["ligand_charge"], + "ligand_bonds": ref_ligand_feats["ligand_bonds"], + "ligand_idx": ref_ligand_feats["ligand_idx"], + "ligand_bonds_idx": ref_ligand_feats["ligand_bonds_idx"], + }) + + if mode == 'train' or mode == 'eval': + gt_pdb_path = os.path.join(data_dir, input_data["gt_structure"]) + gt_protein_feats = data_pipeline.process_pdb(pdb_path=gt_pdb_path) + + if "gt_sdf" in input_data: + gt_ligand_positions = data_pipeline.get_matching_positions( + os.path.join(data_dir, input_data["ref_sdf"]), + os.path.join(data_dir, 
input_data["gt_sdf"]), + ) + elif "gt_sdf_list" in input_data: + gt_ligand_positions = data_pipeline.get_matching_positions_list( + [os.path.join(data_dir, i) for i in input_data["ref_sdf_list"]], + [os.path.join(data_dir, i) for i in input_data["gt_sdf_list"]], + ) + else: + raise ValueError("gt_sdf or gt_sdf_list must be in input_data") + + affinity_loss_factor = torch.tensor([1.0], dtype=torch.float32) + if input_data["affinity"] is None: + eps = 1e-6 + affinity_loss_factor = torch.tensor([eps], dtype=torch.float32) + affinity = torch.tensor([0.0], dtype=torch.float32) + else: + affinity = torch.tensor([input_data["affinity"]], dtype=torch.float32) + + resolution = torch.tensor(input_data["resolution"], dtype=torch.float32) + + # prepare inter_contacts + expanded_prot_pos = gt_protein_feats["pseudo_beta"].unsqueeze(1) # Shape: (N_prot, 1, 3) + expanded_lig_pos = gt_ligand_positions.unsqueeze(0) # Shape: (1, N_lig, 3) + distances = torch.sqrt(torch.sum((expanded_prot_pos - expanded_lig_pos) ** 2, dim=-1)) + inter_contact = (distances < 5.0).float() + binding_site_mask = inter_contact.any(dim=1).float() + + inter_contact_reshaped_to_crop = torch.zeros((crop_size, crop_size), dtype=torch.float32) + inter_contact_reshaped_to_crop[:n_res, n_res:n_res + n_lig] = inter_contact + inter_contact_reshaped_to_crop[n_res:n_res + n_lig, :n_res] = inter_contact.T + + # Use CA positions only + lig_single_res_atom37_mask = torch.zeros((37,), dtype=torch.float32) + lig_single_res_atom37_mask[1] = 1 + lig_atom37_mask = lig_single_res_atom37_mask.unsqueeze(0).expand(n_lig, -1) + lig_single_res_atom14_mask = torch.zeros((14,), dtype=torch.float32) + lig_single_res_atom14_mask[1] = 1 + lig_atom14_mask = lig_single_res_atom14_mask.unsqueeze(0).expand(n_lig, -1) + + lig_atom37_positions = gt_ligand_positions.unsqueeze(1).expand(-1, 37, -1) + lig_atom37_positions = lig_atom37_positions * lig_single_res_atom37_mask.view(1, 37, 1).expand(n_lig, -1, 3) + + lig_atom14_positions = gt_ligand_positions.unsqueeze(1).expand(-1, 14, -1) + lig_atom14_positions = lig_atom14_positions * lig_single_res_atom14_mask.view(1, 14, 1).expand(n_lig, -1, 3) + + atom37_gt_positions = torch.cat([gt_protein_feats["all_atom_positions"], lig_atom37_positions], dim=0) + atom37_atom_exists_in_res = torch.cat([gt_protein_feats["atom37_atom_exists"], lig_atom37_mask], dim=0) + atom37_atom_exists_in_gt = torch.cat([gt_protein_feats["all_atom_mask"], lig_atom37_mask], dim=0) + + atom14_gt_positions = torch.cat([gt_protein_feats["atom14_gt_positions"], lig_atom14_positions], dim=0) + atom14_atom_exists_in_res = torch.cat([gt_protein_feats["atom14_atom_exists"], lig_atom14_mask], dim=0) + atom14_atom_exists_in_gt = torch.cat([gt_protein_feats["atom14_gt_exists"], lig_atom14_mask], dim=0) + + gt_pseudo_beta_with_lig = torch.cat([gt_protein_feats["pseudo_beta"], gt_ligand_positions], dim=0) + gt_pseudo_beta_with_lig_mask = torch.cat( + [gt_protein_feats["pseudo_beta_mask"], + torch.ones((n_lig,), dtype=gt_protein_feats["pseudo_beta_mask"].dtype)], + dim=0) + + # IGNORES: residx_atom14_to_atom37, rigidgroups_group_exists, + # rigidgroups_group_is_ambiguous, pseudo_beta_mask, backbone_rigid_mask, protein_target_feat + gt_protein_feats = { + "atom37_gt_positions": atom37_gt_positions, # torch.Size([n_struct, 37, 3]) + "atom37_atom_exists_in_res": atom37_atom_exists_in_res, # torch.Size([n_struct, 37]) + "atom37_atom_exists_in_gt": atom37_atom_exists_in_gt, # torch.Size([n_struct, 37]) + + "atom14_gt_positions": atom14_gt_positions, # 
torch.Size([n_struct, 14, 3]) + "atom14_atom_exists_in_res": atom14_atom_exists_in_res, # torch.Size([n_struct, 14]) + "atom14_atom_exists_in_gt": atom14_atom_exists_in_gt, # torch.Size([n_struct, 14]) + + "gt_pseudo_beta_with_lig": gt_pseudo_beta_with_lig, # torch.Size([n_struct, 3]) + "gt_pseudo_beta_with_lig_mask": gt_pseudo_beta_with_lig_mask, # torch.Size([n_struct]) + + # These we don't need to add the ligand to, because padding is sufficient (everything should be 0) + "atom14_alt_gt_positions": gt_protein_feats["atom14_alt_gt_positions"], # torch.Size([n_res, 14, 3]) + "atom14_alt_gt_exists": gt_protein_feats["atom14_alt_gt_exists"], # torch.Size([n_res, 14]) + "atom14_atom_is_ambiguous": gt_protein_feats["atom14_atom_is_ambiguous"], # torch.Size([n_res, 14]) + "rigidgroups_gt_frames": gt_protein_feats["rigidgroups_gt_frames"], # torch.Size([n_res, 8, 4, 4]) + "rigidgroups_gt_exists": gt_protein_feats["rigidgroups_gt_exists"], # torch.Size([n_res, 8]) + "rigidgroups_alt_gt_frames": gt_protein_feats["rigidgroups_alt_gt_frames"], # torch.Size([n_res, 8, 4, 4]) + "backbone_rigid_tensor": gt_protein_feats["backbone_rigid_tensor"], # torch.Size([n_res, 4, 4]) + "backbone_rigid_mask": gt_protein_feats["backbone_rigid_mask"], # torch.Size([n_res]) + "chi_angles_sin_cos": gt_protein_feats["chi_angles_sin_cos"], + "chi_mask": gt_protein_feats["chi_mask"], + } + + for k, v in gt_protein_feats.items(): + gt_protein_feats[k] = _fit_to_crop(v, crop_size, 0) + + feats = { + **feats, + **gt_protein_feats, + "gt_ligand_positions": _fit_to_crop(gt_ligand_positions, crop_size, n_res), + "resolution": resolution, + "affinity": affinity, + "affinity_loss_factor": affinity_loss_factor, + "seq_length": torch.tensor(n_res + n_lig), + "binding_site_mask": _fit_to_crop(binding_site_mask, crop_size, 0), + "gt_inter_contacts": inter_contact_reshaped_to_crop, + } + + for k, v in feats.items(): + # print(k, v.shape) + feats[k] = _prepare_recycles(v, num_recycles) + + feats["batch_idx"] = torch.tensor( + [idx for _ in range(crop_size)], dtype=torch.int64, device=feats["aatype"].device + ) + + print("load time", round(time.time() - start_load_time, 4)) + + return feats diff --git a/dockformer/data/data_transforms.py b/dockformer/data/data_transforms.py new file mode 100755 index 0000000000000000000000000000000000000000..13ec4cc1eeb9e7ccc2434431e8ff14074a971762 --- /dev/null +++ b/dockformer/data/data_transforms.py @@ -0,0 +1,731 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
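+
+# The transforms below each take and return a single example's feature dict;
+# dockformer.data.data_pipeline composes them into a pipeline via `compose`.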
+ +import itertools +from functools import reduce, wraps +from operator import add + +import numpy as np +import torch + +from dockformer.config import NUM_RES +from dockformer.utils import residue_constants as rc +from dockformer.utils.residue_constants import restypes +from dockformer.utils.rigid_utils import Rotation, Rigid +from dockformer.utils.geometry.rigid_matrix_vector import Rigid3Array +from dockformer.utils.geometry.rotation_matrix import Rot3Array +from dockformer.utils.geometry.vector import Vec3Array +from dockformer.utils.tensor_utils import ( + tree_map, + tensor_tree_map, + batched_gather, +) + + +def cast_to_64bit_ints(protein): + # We keep all ints as int64 + for k, v in protein.items(): + if v.dtype == torch.int32: + protein[k] = v.type(torch.int64) + + return protein + + +def make_one_hot(x, num_classes): + x_one_hot = torch.zeros(*x.shape, num_classes, device=x.device) + x_one_hot.scatter_(-1, x.unsqueeze(-1), 1) + return x_one_hot + + +def curry1(f): + """Supply all arguments but the first.""" + @wraps(f) + def fc(*args, **kwargs): + return lambda x: f(x, *args, **kwargs) + + return fc + + +def squeeze_features(protein): + """Remove singleton and repeated dimensions in protein features.""" + protein["aatype"] = torch.argmax(protein["aatype"], dim=-1) + for k in [ + "domain_name", + "seq_length", + "sequence", + "resolution", + "residue_index", + ]: + if k in protein: + final_dim = protein[k].shape[-1] + if isinstance(final_dim, int) and final_dim == 1: + if torch.is_tensor(protein[k]): + protein[k] = torch.squeeze(protein[k], dim=-1) + else: + protein[k] = np.squeeze(protein[k], axis=-1) + + for k in ["seq_length"]: + if k in protein: + protein[k] = protein[k][0] + + return protein + + +def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask): + """Create pseudo beta features.""" + is_gly = torch.eq(aatype, rc.restype_order["G"]) + ca_idx = rc.atom_order["CA"] + cb_idx = rc.atom_order["CB"] + pseudo_beta = torch.where( + torch.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]), + all_atom_positions[..., ca_idx, :], + all_atom_positions[..., cb_idx, :], + ) + + if all_atom_mask is not None: + pseudo_beta_mask = torch.where( + is_gly, all_atom_mask[..., ca_idx], all_atom_mask[..., cb_idx] + ) + return pseudo_beta, pseudo_beta_mask + else: + return pseudo_beta + + +@curry1 +def make_pseudo_beta(protein): + """Create pseudo-beta (alpha for glycine) position and mask.""" + (protein["pseudo_beta"], protein["pseudo_beta_mask"]) = pseudo_beta_fn( + protein["aatype"], + protein["all_atom_positions"], + protein["all_atom_mask"], + ) + return protein + + +@curry1 +def make_target_feat(protein): + """Create and concatenate protein features.""" + # Whether there is a domain break. Always zero for chains, but keeping for + # compatibility with domain datasets. 
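+    # e.g. aatype 0 ("A") maps to [1, 0, ..., 0] over the 20 standard residue
+    # types (a plain one-hot; no extra features are concatenated here)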
+ aatype_1hot = make_one_hot(protein["aatype"], 20) + + protein["protein_target_feat"] = aatype_1hot + + return protein + + + +@curry1 +def select_feat(protein, feature_list): + return {k: v for k, v in protein.items() if k in feature_list} + + +def get_restypes(device): + restype_atom14_to_atom37 = [] + restype_atom37_to_atom14 = [] + restype_atom14_mask = [] + + for rt in rc.restypes: + atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]] + restype_atom14_to_atom37.append( + [(rc.atom_order[name] if name else 0) for name in atom_names] + ) + atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} + restype_atom37_to_atom14.append( + [ + (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) + for name in rc.atom_types + ] + ) + + restype_atom14_mask.append( + [(1.0 if name else 0.0) for name in atom_names] + ) + + # Add dummy mapping for restype 'UNK' + restype_atom14_to_atom37.append([0] * 14) + restype_atom37_to_atom14.append([0] * 37) + restype_atom14_mask.append([0.0] * 14) + + restype_atom14_to_atom37 = torch.tensor( + restype_atom14_to_atom37, + dtype=torch.int32, + device=device, + ) + restype_atom37_to_atom14 = torch.tensor( + restype_atom37_to_atom14, + dtype=torch.int32, + device=device, + ) + restype_atom14_mask = torch.tensor( + restype_atom14_mask, + dtype=torch.float32, + device=device, + ) + + return restype_atom14_to_atom37, restype_atom37_to_atom14, restype_atom14_mask + + +def get_restype_atom37_mask(device): + # create the corresponding mask + restype_atom37_mask = torch.zeros( + [len(restypes) + 1, 37], dtype=torch.float32, device=device + ) + for restype, restype_letter in enumerate(rc.restypes): + restype_name = rc.restype_1to3[restype_letter] + atom_names = rc.residue_atoms[restype_name] + for atom_name in atom_names: + atom_type = rc.atom_order[atom_name] + restype_atom37_mask[restype, atom_type] = 1 + return restype_atom37_mask + + +def make_atom14_masks(protein): + """Construct denser atom positions (14 dimensions instead of 37).""" + restype_atom14_to_atom37, restype_atom37_to_atom14, restype_atom14_mask = get_restypes(protein["aatype"].device) + + protein_aatype = protein['aatype'].to(torch.long) + + # create the mapping for (residx, atom14) --> atom37, i.e. an array + # with shape (num_res, 14) containing the atom37 indices for this protein + residx_atom14_to_atom37 = restype_atom14_to_atom37[protein_aatype] + residx_atom14_mask = restype_atom14_mask[protein_aatype] + + protein["atom14_atom_exists"] = residx_atom14_mask + protein["residx_atom14_to_atom37"] = residx_atom14_to_atom37.long() + + # create the gather indices for mapping back + residx_atom37_to_atom14 = restype_atom37_to_atom14[protein_aatype] + protein["residx_atom37_to_atom14"] = residx_atom37_to_atom14.long() + + restype_atom37_mask = get_restype_atom37_mask(protein["aatype"].device) + + residx_atom37_mask = restype_atom37_mask[protein_aatype] + protein["atom37_atom_exists"] = residx_atom37_mask + + return protein + + +def make_atom14_positions(protein): + """Constructs denser atom positions (14 dimensions instead of 37).""" + residx_atom14_mask = protein["atom14_atom_exists"] + residx_atom14_to_atom37 = protein["residx_atom14_to_atom37"] + + # Create a mask for known ground truth positions. + residx_atom14_gt_mask = residx_atom14_mask * batched_gather( + protein["all_atom_mask"], + residx_atom14_to_atom37, + dim=-1, + no_batch_dims=len(protein["all_atom_mask"].shape[:-1]), + ) + + # Gather the ground truth positions. 
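+    # (for each residue, gather the 14 atom37 slots named by
+    # residx_atom14_to_atom37, zeroing out atoms that have no ground truth)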
+ residx_atom14_gt_positions = residx_atom14_gt_mask[..., None] * ( + batched_gather( + protein["all_atom_positions"], + residx_atom14_to_atom37, + dim=-2, + no_batch_dims=len(protein["all_atom_positions"].shape[:-2]), + ) + ) + + protein["atom14_atom_exists"] = residx_atom14_mask + protein["atom14_gt_exists"] = residx_atom14_gt_mask + protein["atom14_gt_positions"] = residx_atom14_gt_positions + + # As the atom naming is ambiguous for 7 of the 20 amino acids, provide + # alternative ground truth coordinates where the naming is swapped + restype_3 = [rc.restype_1to3[res] for res in rc.restypes] + restype_3 += ["UNK"] + + # Matrices for renaming ambiguous atoms. + all_matrices = { + res: torch.eye( + 14, + dtype=protein["all_atom_mask"].dtype, + device=protein["all_atom_mask"].device, + ) + for res in restype_3 + } + for resname, swap in rc.residue_atom_renaming_swaps.items(): + correspondences = torch.arange( + 14, device=protein["all_atom_mask"].device + ) + for source_atom_swap, target_atom_swap in swap.items(): + source_index = rc.restype_name_to_atom14_names[resname].index( + source_atom_swap + ) + target_index = rc.restype_name_to_atom14_names[resname].index( + target_atom_swap + ) + correspondences[source_index] = target_index + correspondences[target_index] = source_index + renaming_matrix = protein["all_atom_mask"].new_zeros((14, 14)) + for index, correspondence in enumerate(correspondences): + renaming_matrix[index, correspondence] = 1.0 + all_matrices[resname] = renaming_matrix + + renaming_matrices = torch.stack( + [all_matrices[restype] for restype in restype_3] + ) + + # Pick the transformation matrices for the given residue sequence + # shape (num_res, 14, 14). + renaming_transform = renaming_matrices[protein["aatype"]] + + # Apply it to the ground truth positions. shape (num_res, 14, 3). + alternative_gt_positions = torch.einsum( + "...rac,...rab->...rbc", residx_atom14_gt_positions, renaming_transform + ) + protein["atom14_alt_gt_positions"] = alternative_gt_positions + + # Create the mask for the alternative ground truth (differs from the + # ground truth mask, if only one of the atoms in an ambiguous pair has a + # ground truth position). + alternative_gt_mask = torch.einsum( + "...ra,...rab->...rb", residx_atom14_gt_mask, renaming_transform + ) + protein["atom14_alt_gt_exists"] = alternative_gt_mask + + # Create an ambiguous atoms mask. shape: (21, 14). + restype_atom14_is_ambiguous = protein["all_atom_mask"].new_zeros((21, 14)) + for resname, swap in rc.residue_atom_renaming_swaps.items(): + for atom_name1, atom_name2 in swap.items(): + restype = rc.restype_order[rc.restype_3to1[resname]] + atom_idx1 = rc.restype_name_to_atom14_names[resname].index( + atom_name1 + ) + atom_idx2 = rc.restype_name_to_atom14_names[resname].index( + atom_name2 + ) + restype_atom14_is_ambiguous[restype, atom_idx1] = 1 + restype_atom14_is_ambiguous[restype, atom_idx2] = 1 + + # From this create an ambiguous_mask for the given sequence. 
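+    # (a per-residue row lookup into the (21, 14) table built above)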
+ protein["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[ + protein["aatype"] + ] + + return protein + + +def atom37_to_frames(protein, eps=1e-8): + aatype = protein["aatype"] + all_atom_positions = protein["all_atom_positions"] + all_atom_mask = protein["all_atom_mask"] + + batch_dims = len(aatype.shape[:-1]) + + restype_rigidgroup_base_atom_names = np.full([21, 8, 3], "", dtype=object) + restype_rigidgroup_base_atom_names[:, 0, :] = ["C", "CA", "N"] + restype_rigidgroup_base_atom_names[:, 3, :] = ["CA", "C", "O"] + + for restype, restype_letter in enumerate(rc.restypes): + resname = rc.restype_1to3[restype_letter] + for chi_idx in range(4): + if rc.chi_angles_mask[restype][chi_idx]: + names = rc.chi_angles_atoms[resname][chi_idx] + restype_rigidgroup_base_atom_names[ + restype, chi_idx + 4, : + ] = names[1:] + + restype_rigidgroup_mask = all_atom_mask.new_zeros( + (*aatype.shape[:-1], 21, 8), + ) + restype_rigidgroup_mask[..., 0] = 1 + restype_rigidgroup_mask[..., 3] = 1 + restype_rigidgroup_mask[..., :len(restypes), 4:] = all_atom_mask.new_tensor( + rc.chi_angles_mask + ) + + lookuptable = rc.atom_order.copy() + lookuptable[""] = 0 + lookup = np.vectorize(lambda x: lookuptable[x]) + restype_rigidgroup_base_atom37_idx = lookup( + restype_rigidgroup_base_atom_names, + ) + restype_rigidgroup_base_atom37_idx = aatype.new_tensor( + restype_rigidgroup_base_atom37_idx, + ) + restype_rigidgroup_base_atom37_idx = ( + restype_rigidgroup_base_atom37_idx.view( + *((1,) * batch_dims), *restype_rigidgroup_base_atom37_idx.shape + ) + ) + + residx_rigidgroup_base_atom37_idx = batched_gather( + restype_rigidgroup_base_atom37_idx, + aatype, + dim=-3, + no_batch_dims=batch_dims, + ) + + base_atom_pos = batched_gather( + all_atom_positions, + residx_rigidgroup_base_atom37_idx, + dim=-2, + no_batch_dims=len(all_atom_positions.shape[:-2]), + ) + + gt_frames = Rigid.from_3_points( + p_neg_x_axis=base_atom_pos[..., 0, :], + origin=base_atom_pos[..., 1, :], + p_xy_plane=base_atom_pos[..., 2, :], + eps=eps, + ) + + group_exists = batched_gather( + restype_rigidgroup_mask, + aatype, + dim=-2, + no_batch_dims=batch_dims, + ) + + gt_atoms_exist = batched_gather( + all_atom_mask, + residx_rigidgroup_base_atom37_idx, + dim=-1, + no_batch_dims=len(all_atom_mask.shape[:-1]), + ) + gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists + + rots = torch.eye(3, dtype=all_atom_mask.dtype, device=aatype.device) + rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1)) + rots[..., 0, 0, 0] = -1 + rots[..., 0, 2, 2] = -1 + + rots = Rotation(rot_mats=rots) + gt_frames = gt_frames.compose(Rigid(rots, None)) + + restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros( + *((1,) * batch_dims), 21, 8 + ) + restype_rigidgroup_rots = torch.eye( + 3, dtype=all_atom_mask.dtype, device=aatype.device + ) + restype_rigidgroup_rots = torch.tile( + restype_rigidgroup_rots, + (*((1,) * batch_dims), 21, 8, 1, 1), + ) + + for resname, _ in rc.residue_atom_renaming_swaps.items(): + restype = rc.restype_order[rc.restype_3to1[resname]] + chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1) + restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1 + restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1 + restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1 + + residx_rigidgroup_is_ambiguous = batched_gather( + restype_rigidgroup_is_ambiguous, + aatype, + dim=-2, + no_batch_dims=batch_dims, + ) + + residx_rigidgroup_ambiguity_rot = batched_gather( + restype_rigidgroup_rots, + aatype, + dim=-4, + 
no_batch_dims=batch_dims, + ) + + residx_rigidgroup_ambiguity_rot = Rotation( + rot_mats=residx_rigidgroup_ambiguity_rot + ) + alt_gt_frames = gt_frames.compose( + Rigid(residx_rigidgroup_ambiguity_rot, None) + ) + + gt_frames_tensor = gt_frames.to_tensor_4x4() + alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4() + + protein["rigidgroups_gt_frames"] = gt_frames_tensor + protein["rigidgroups_gt_exists"] = gt_exists + protein["rigidgroups_group_exists"] = group_exists + protein["rigidgroups_group_is_ambiguous"] = residx_rigidgroup_is_ambiguous + protein["rigidgroups_alt_gt_frames"] = alt_gt_frames_tensor + + return protein + + +def get_chi_atom_indices(): + """Returns atom indices needed to compute chi angles for all residue types. + + Returns: + A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are + in the order specified in rc.restypes + unknown residue type + at the end. For chi angles which are not defined on the residue, the + positions indices are by default set to 0. + """ + chi_atom_indices = [] + for residue_name in rc.restypes: + residue_name = rc.restype_1to3[residue_name] + residue_chi_angles = rc.chi_angles_atoms[residue_name] + atom_indices = [] + for chi_angle in residue_chi_angles: + atom_indices.append([rc.atom_order[atom] for atom in chi_angle]) + for _ in range(4 - len(atom_indices)): + atom_indices.append( + [0, 0, 0, 0] + ) # For chi angles not defined on the AA. + chi_atom_indices.append(atom_indices) + + chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue. + + return chi_atom_indices + + +@curry1 +def atom37_to_torsion_angles( + protein, + prefix="", +): + """ + Convert coordinates to torsion angles. + + This function is extremely sensitive to floating point imprecisions + and should be run with double precision whenever possible. 
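+    (most likely because the torsion sin/cos vectors are normalized by
+    norms that can be arbitrarily close to zero)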
+ + Args: + Dict containing: + * (prefix)aatype: + [*, N_res] residue indices + * (prefix)all_atom_positions: + [*, N_res, 37, 3] atom positions (in atom37 + format) + * (prefix)all_atom_mask: + [*, N_res, 37] atom position mask + Returns: + The same dictionary updated with the following features: + + "(prefix)torsion_angles_sin_cos" ([*, N_res, 7, 2]) + Torsion angles + "(prefix)alt_torsion_angles_sin_cos" ([*, N_res, 7, 2]) + Alternate torsion angles (accounting for 180-degree symmetry) + "(prefix)torsion_angles_mask" ([*, N_res, 7]) + Torsion angles mask + """ + aatype = protein[prefix + "aatype"] + all_atom_positions = protein[prefix + "all_atom_positions"] + all_atom_mask = protein[prefix + "all_atom_mask"] + + aatype = torch.clamp(aatype, max=20) + + pad = all_atom_positions.new_zeros( + [*all_atom_positions.shape[:-3], 1, 37, 3] + ) + prev_all_atom_positions = torch.cat( + [pad, all_atom_positions[..., :-1, :, :]], dim=-3 + ) + + pad = all_atom_mask.new_zeros([*all_atom_mask.shape[:-2], 1, 37]) + prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2) + + pre_omega_atom_pos = torch.cat( + [prev_all_atom_positions[..., 1:3, :], all_atom_positions[..., :2, :]], + dim=-2, + ) + phi_atom_pos = torch.cat( + [prev_all_atom_positions[..., 2:3, :], all_atom_positions[..., :3, :]], + dim=-2, + ) + psi_atom_pos = torch.cat( + [all_atom_positions[..., :3, :], all_atom_positions[..., 4:5, :]], + dim=-2, + ) + + pre_omega_mask = torch.prod( + prev_all_atom_mask[..., 1:3], dim=-1 + ) * torch.prod(all_atom_mask[..., :2], dim=-1) + phi_mask = prev_all_atom_mask[..., 2] * torch.prod( + all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype + ) + psi_mask = ( + torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype) + * all_atom_mask[..., 4] + ) + + chi_atom_indices = torch.as_tensor( + get_chi_atom_indices(), device=aatype.device + ) + + atom_indices = chi_atom_indices[..., aatype, :, :] + chis_atom_pos = batched_gather( + all_atom_positions, atom_indices, -2, len(atom_indices.shape[:-2]) + ) + + chi_angles_mask = list(rc.chi_angles_mask) + chi_angles_mask.append([0.0, 0.0, 0.0, 0.0]) + chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask) + + chis_mask = chi_angles_mask[aatype, :] + + chi_angle_atoms_mask = batched_gather( + all_atom_mask, + atom_indices, + dim=-1, + no_batch_dims=len(atom_indices.shape[:-2]), + ) + chi_angle_atoms_mask = torch.prod( + chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype + ) + chis_mask = chis_mask * chi_angle_atoms_mask + + torsions_atom_pos = torch.cat( + [ + pre_omega_atom_pos[..., None, :, :], + phi_atom_pos[..., None, :, :], + psi_atom_pos[..., None, :, :], + chis_atom_pos, + ], + dim=-3, + ) + + torsion_angles_mask = torch.cat( + [ + pre_omega_mask[..., None], + phi_mask[..., None], + psi_mask[..., None], + chis_mask, + ], + dim=-1, + ) + + torsion_frames = Rigid.from_3_points( + torsions_atom_pos[..., 1, :], + torsions_atom_pos[..., 2, :], + torsions_atom_pos[..., 0, :], + eps=1e-8, + ) + + fourth_atom_rel_pos = torsion_frames.invert().apply( + torsions_atom_pos[..., 3, :] + ) + + torsion_angles_sin_cos = torch.stack( + [fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1 + ) + + denom = torch.sqrt( + torch.sum( + torch.square(torsion_angles_sin_cos), + dim=-1, + dtype=torsion_angles_sin_cos.dtype, + keepdims=True, + ) + + 1e-8 + ) + torsion_angles_sin_cos = torsion_angles_sin_cos / denom + + torsion_angles_sin_cos = torsion_angles_sin_cos * all_atom_mask.new_tensor( + [1.0, 1.0, -1.0, 1.0, 
1.0, 1.0, 1.0], + )[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)] + + chi_is_ambiguous = torsion_angles_sin_cos.new_tensor( + rc.chi_pi_periodic, + )[aatype, ...] + + mirror_torsion_angles = torch.cat( + [ + all_atom_mask.new_ones(*aatype.shape, 3), + 1.0 - 2.0 * chi_is_ambiguous, + ], + dim=-1, + ) + + alt_torsion_angles_sin_cos = ( + torsion_angles_sin_cos * mirror_torsion_angles[..., None] + ) + + protein[prefix + "torsion_angles_sin_cos"] = torsion_angles_sin_cos + protein[prefix + "alt_torsion_angles_sin_cos"] = alt_torsion_angles_sin_cos + protein[prefix + "torsion_angles_mask"] = torsion_angles_mask + + return protein + + +def get_backbone_frames(protein): + # DISCREPANCY: AlphaFold uses tensor_7s here. I don't know why. + protein["backbone_rigid_tensor"] = protein["rigidgroups_gt_frames"][ + ..., 0, :, : + ] + protein["backbone_rigid_mask"] = protein["rigidgroups_gt_exists"][..., 0] + + return protein + + +def get_chi_angles(protein): + dtype = protein["all_atom_mask"].dtype + protein["chi_angles_sin_cos"] = ( + protein["torsion_angles_sin_cos"][..., 3:, :] + ).to(dtype) + protein["chi_mask"] = protein["torsion_angles_mask"][..., 3:].to(dtype) + + return protein + + +@curry1 +def random_crop_to_size( + protein, + crop_size, + shape_schema, + seed=None, +): + """Crop randomly to `crop_size`, or keep as is if shorter than that.""" + # We want each ensemble to be cropped the same way + + g = None + if seed is not None: + g = torch.Generator(device=protein["seq_length"].device) + g.manual_seed(seed) + + seq_length = protein["seq_length"] + + num_res_crop_size = min(int(seq_length), crop_size) + + def _randint(lower, upper): + return int(torch.randint( + lower, + upper + 1, + (1,), + device=protein["seq_length"].device, + generator=g, + )[0]) + + n = seq_length - num_res_crop_size + if "use_clamped_fape" in protein and protein["use_clamped_fape"] == 1.: + right_anchor = n + else: + x = _randint(0, n) + right_anchor = n - x + + num_res_crop_start = _randint(0, right_anchor) + + for k, v in protein.items(): + if k not in shape_schema or (NUM_RES not in shape_schema[k]): + continue + + slices = [] + for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)): + is_num_res = dim_size == NUM_RES + crop_start = num_res_crop_start if is_num_res else 0 + crop_size = num_res_crop_size if is_num_res else dim + slices.append(slice(crop_start, crop_start + crop_size)) + protein[k] = v[slices] + + protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size) + + return protein + diff --git a/dockformer/data/errors.py b/dockformer/data/errors.py new file mode 100644 index 0000000000000000000000000000000000000000..7809d5ae29a736ccc7688d555a06a5301a3ab46f --- /dev/null +++ b/dockformer/data/errors.py @@ -0,0 +1,22 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""General-purpose errors used throughout the data pipeline""" +class Error(Exception): + """Base class for exceptions.""" + + +class MultipleChainsError(Error): + """An error indicating that multiple chains were found for a given ID.""" diff --git a/dockformer/data/ligand_features.py b/dockformer/data/ligand_features.py new file mode 100644 index 0000000000000000000000000000000000000000..49e34e408d8ff599ed51dff5ee57fbff465f308f --- /dev/null +++ b/dockformer/data/ligand_features.py @@ -0,0 +1,66 @@ +import os +import numpy as np +import torch +from torch import nn +from rdkit import Chem + +from dockformer.data.utils import FeatureTensorDict +from dockformer.utils.consts import POSSIBLE_BOND_TYPES, POSSIBLE_ATOM_TYPES, POSSIBLE_CHARGES, POSSIBLE_CHIRALITIES + + +def get_atom_features(atom: Chem.Atom): + # TODO: this is temporary, we need to add more features, for example for Zn + if atom.GetSymbol() not in POSSIBLE_ATOM_TYPES: + print(f"********Unknown atom type {atom.GetSymbol()}") + atom_type = POSSIBLE_ATOM_TYPES.index("Ni") + else: + atom_type = POSSIBLE_ATOM_TYPES.index(atom.GetSymbol()) + atom_charge = POSSIBLE_CHARGES.index(max(min(atom.GetFormalCharge(), 1), -1)) + atom_chirality = POSSIBLE_CHIRALITIES.index(atom.GetChiralTag()) + + return {"atom_type": atom_type, "atom_charge": atom_charge, "atom_chirality": atom_chirality} + + +def get_bond_features(bond: Chem.Bond): + bond_type = POSSIBLE_BOND_TYPES.index(bond.GetBondType()) + return {"bond_type": bond_type} + + +def make_ligand_features(ligand: Chem.Mol) -> FeatureTensorDict: + atoms_features = [] + atom_idx_to_atom_pos_idx = {} + for atom in ligand.GetAtoms(): + atom_idx_to_atom_pos_idx[atom.GetIdx()] = len(atoms_features) + atoms_features.append(get_atom_features(atom)) + + atom_types = torch.tensor(np.array([atom["atom_type"] for atom in atoms_features], dtype=np.int64)) + atom_types_one_hot = nn.functional.one_hot(atom_types, num_classes=len(POSSIBLE_ATOM_TYPES), ) + atom_charges = torch.tensor(np.array([atom["atom_charge"] for atom in atoms_features], dtype=np.int64)) + atom_charges_one_hot = nn.functional.one_hot(atom_charges, num_classes=len(POSSIBLE_CHARGES)) + atom_chiralities = torch.tensor(np.array([atom["atom_chirality"] for atom in atoms_features], dtype=np.int64)) + atom_chiralities_one_hot = nn.functional.one_hot(atom_chiralities, num_classes=len(POSSIBLE_CHIRALITIES)) + + ligand_target_feat = torch.cat([atom_types_one_hot.float(), atom_charges_one_hot.float(), + atom_chiralities_one_hot.float()], dim=1) + + # create one-hot matrix encoding for bonds + ligand_bonds_feat = torch.zeros((len(atoms_features), len(atoms_features), len(POSSIBLE_BOND_TYPES))) + ligand_bonds = [] + for bond in ligand.GetBonds(): + atom1_idx = atom_idx_to_atom_pos_idx[bond.GetBeginAtomIdx()] + atom2_idx = atom_idx_to_atom_pos_idx[bond.GetEndAtomIdx()] + bond_features = get_bond_features(bond) + ligand_bonds.append((atom1_idx, atom2_idx, bond_features["bond_type"])) + ligand_bonds_feat[atom1_idx, atom2_idx, bond_features["bond_type"]] = 1 + + return { + # These are used for reconstruction at the end of the pipeline + "ligand_atype": atom_types, + "ligand_charge": atom_charges, + "ligand_chirality": atom_chiralities, + "ligand_bonds": torch.tensor(ligand_bonds, dtype=torch.int64), + # these are the actual features + "ligand_target_feat": ligand_target_feat.float(), + "ligand_bonds_feat": ligand_bonds_feat.float(), + } + diff --git a/dockformer/data/parsers.py b/dockformer/data/parsers.py new file mode 100644 index 
0000000000000000000000000000000000000000..c7db4a6dc27d9148d2473900ea9c01834b8b59f4 --- /dev/null +++ b/dockformer/data/parsers.py @@ -0,0 +1,53 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functions for parsing various file formats.""" +import collections +import dataclasses +import itertools +import re +import string +from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set + + +def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]: + """Parses FASTA string and returns list of strings with amino-acid sequences. + + Arguments: + fasta_string: The string contents of a FASTA file. + + Returns: + A tuple of two lists: + * A list of sequences. + * A list of sequence descriptions taken from the comment lines. In the + same order as the sequences. + """ + sequences = [] + descriptions = [] + index = -1 + for line in fasta_string.splitlines(): + line = line.strip() + if line.startswith(">"): + index += 1 + descriptions.append(line[1:]) # Remove the '>' at the beginning. + sequences.append("") + continue + elif line.startswith("#"): + continue + elif not line: + continue # Skip blank lines. + sequences[index] += line + + return sequences, descriptions diff --git a/dockformer/data/protein_features.py b/dockformer/data/protein_features.py new file mode 100644 index 0000000000000000000000000000000000000000..bc5ba1b8f4c756ca2a2d77c12c7c2da99ea3170a --- /dev/null +++ b/dockformer/data/protein_features.py @@ -0,0 +1,71 @@ +import numpy as np + +from dockformer.data.utils import FeatureDict +from dockformer.utils import residue_constants, protein + + +def _make_sequence_features(sequence: str, description: str, num_res: int) -> FeatureDict: + """Construct a feature dict of sequence features.""" + features = {} + features["aatype"] = residue_constants.sequence_to_onehot( + sequence=sequence, + mapping=residue_constants.restype_order_with_x, + map_unknown_to_x=True, + ) + features["domain_name"] = np.array( + [description.encode("utf-8")], dtype=object + ) + # features["residue_index"] = np.array(range(num_res), dtype=np.int32) + features["seq_length"] = np.array([num_res] * num_res, dtype=np.int32) + features["sequence"] = np.array( + [sequence.encode("utf-8")], dtype=object + ) + return features + + +def _aatype_to_str_sequence(aatype): + return ''.join([ + residue_constants.restypes_with_x[aatype[i]] + for i in range(len(aatype)) + ]) + + +def _make_protein_structure_features(protein_object: protein.Protein) -> FeatureDict: + pdb_feats = {} + + all_atom_positions = protein_object.atom_positions + all_atom_mask = protein_object.atom_mask + + pdb_feats["all_atom_positions"] = all_atom_positions.astype(np.float32) + pdb_feats["all_atom_mask"] = all_atom_mask.astype(np.float32) + pdb_feats["in_chain_residue_index"] = protein_object.residue_index.astype(np.int32) + + gapped_res_indexes = [] + prev_chain_index = protein_object.chain_index[0] + chain_start_res_ind = 0 + for 
relative_res_ind, chain_index in zip(protein_object.residue_index, protein_object.chain_index): + if chain_index != prev_chain_index: + chain_start_res_ind = gapped_res_indexes[-1] + 50 + prev_chain_index = chain_index + gapped_res_indexes.append(relative_res_ind + chain_start_res_ind) + + pdb_feats["residue_index"] = np.array(gapped_res_indexes).astype(np.int32) + pdb_feats["chain_index"] = np.array(protein_object.chain_index).astype(np.int32) + pdb_feats["resolution"] = np.array([0.]).astype(np.float32) + + return pdb_feats + + +def make_protein_features(protein_object: protein.Protein, description: str) -> FeatureDict: + feats = {} + aatype = protein_object.aatype + sequence = _aatype_to_str_sequence(aatype) + feats.update( + _make_sequence_features(sequence=sequence, description=description, num_res=len(protein_object.aatype)) + ) + + feats.update( + _make_protein_structure_features(protein_object=protein_object) + ) + + return feats diff --git a/dockformer/data/utils.py b/dockformer/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ba4cf1534b7e605f1233447fa11a5e7d374f65b6 --- /dev/null +++ b/dockformer/data/utils.py @@ -0,0 +1,54 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common utilities for data pipeline tools.""" +import contextlib +import datetime +import logging +import shutil +import tempfile +import time +from typing import Optional, Mapping, Dict + +import numpy as np +import torch + +FeatureDict = Dict[str, np.ndarray] +FeatureTensorDict = Dict[str, torch.Tensor] + + +@contextlib.contextmanager +def tmpdir_manager(base_dir: Optional[str] = None): + """Context manager that deletes a temporary directory on exit.""" + tmpdir = tempfile.mkdtemp(dir=base_dir) + try: + yield tmpdir + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + +@contextlib.contextmanager +def timing(msg: str): + logging.info("Started %s", msg) + tic = time.perf_counter() + yield + toc = time.perf_counter() + logging.info("Finished %s in %.3f seconds", msg, toc - tic) + + +def to_date(s: str): + return datetime.datetime( + year=int(s[:4]), month=int(s[5:7]), day=int(s[8:10]) + ) diff --git a/dockformer/model/__init__.py b/dockformer/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/dockformer/model/dropout.py b/dockformer/model/dropout.py new file mode 100644 index 0000000000000000000000000000000000000000..c89d0a835c736ae73c08cd34403f3dee586bd9c0 --- /dev/null +++ b/dockformer/model/dropout.py @@ -0,0 +1,69 @@ +# Copyright 2021 AlQuraishi Laboratory +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch.nn as nn
+from functools import partialmethod
+from typing import Union, List
+
+
+class Dropout(nn.Module):
+    """
+    Implementation of dropout with the ability to share the dropout mask
+    along a particular dimension.
+
+    If not in training mode, this module computes the identity function.
+    """
+
+    def __init__(self, r: float, batch_dim: Union[int, List[int]]):
+        """
+        Args:
+            r:
+                Dropout rate
+            batch_dim:
+                Dimension(s) along which the dropout mask is shared
+        """
+        super(Dropout, self).__init__()
+
+        self.r = r
+        if isinstance(batch_dim, int):
+            batch_dim = [batch_dim]
+        self.batch_dim = batch_dim
+        self.dropout = nn.Dropout(self.r)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            x:
+                Tensor to which dropout is applied. Can have any shape
+                compatible with self.batch_dim
+        """
+        # Build a mask of ones with size 1 along the shared dimensions,
+        # and let nn.Dropout zero (and rescale) entries of that mask so
+        # the same pattern is broadcast along those dimensions.
+        shape = list(x.shape)
+        if self.batch_dim is not None:
+            for bd in self.batch_dim:
+                shape[bd] = 1
+        mask = x.new_ones(shape)
+        mask = self.dropout(mask)
+        x *= mask
+        return x
+
+
+class DropoutRowwise(Dropout):
+    """
+    Convenience class for rowwise dropout as described in subsection
+    1.11.6 of the AlphaFold 2 supplement.
+    """
+
+    __init__ = partialmethod(Dropout.__init__, batch_dim=-3)
diff --git a/dockformer/model/embedders.py b/dockformer/model/embedders.py
new file mode 100644
index 0000000000000000000000000000000000000000..a287bf9023983f7ad804c09b4192b24ad69ec182
--- /dev/null
+++ b/dockformer/model/embedders.py
@@ -0,0 +1,346 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+
+import torch
+import torch.nn as nn
+from typing import Tuple, Optional
+
+from dockformer.model.primitives import Linear, LayerNorm
+from dockformer.utils.tensor_utils import add
+
+
+class StructureInputEmbedder(nn.Module):
+    """
+    Embeds a subset of the input features.
+
+    Implements a merge of Algorithms 3 and 32.
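+
+    The single representation is a linear embedding of the concatenated
+    per-token target features; the pair representation sums outer
+    embeddings of those features with relative-position encodings
+    (protein tokens), bond embeddings (ligand tokens) and binned
+    distograms of the input positions.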
+ """ + + def __init__( + self, + protein_tf_dim: int, + ligand_tf_dim: int, + additional_tf_dim: int, + ligand_bond_dim: int, + c_z: int, + c_m: int, + relpos_k: int, + prot_min_bin: float, + prot_max_bin: float, + prot_no_bins: int, + lig_min_bin: float, + lig_max_bin: float, + lig_no_bins: int, + inf: float = 1e8, + **kwargs, + ): + """ + Args: + tf_dim: + Final dimension of the target features + c_z: + Pair embedding dimension + c_m: + Single embedding dimension + relpos_k: + Window size used in relative positional encoding + """ + super(StructureInputEmbedder, self).__init__() + + self.tf_dim = protein_tf_dim + ligand_tf_dim + additional_tf_dim + self.pair_tf_dim = ligand_bond_dim + + self.c_z = c_z + self.c_m = c_m + + self.linear_tf_z_i = Linear(self.tf_dim, c_z) + self.linear_tf_z_j = Linear(self.tf_dim, c_z) + self.linear_tf_m = Linear(self.tf_dim, c_m) + + self.ligand_linear_bond_z = Linear(ligand_bond_dim, c_z) + + # RPE stuff + self.relpos_k = relpos_k + self.no_bins = 2 * relpos_k + 1 + self.linear_relpos = Linear(self.no_bins, c_z) + + # Recycling stuff + self.prot_min_bin = prot_min_bin + self.prot_max_bin = prot_max_bin + self.prot_no_bins = prot_no_bins + self.lig_min_bin = lig_min_bin + self.lig_max_bin = lig_max_bin + self.lig_no_bins = lig_no_bins + self.inf = inf + + self.prot_recycling_linear = Linear(self.prot_no_bins + 1, self.c_z) + self.lig_recycling_linear = Linear(self.lig_no_bins, self.c_z) + self.layer_norm_m = LayerNorm(self.c_m) + self.layer_norm_z = LayerNorm(self.c_z) + + def relpos(self, ri: torch.Tensor): + """ + Computes relative positional encodings + + Implements Algorithm 4. + + Args: + ri: + "residue_index" features of shape [*, N] + """ + d = ri[..., None] - ri[..., None, :] + boundaries = torch.arange( + start=-self.relpos_k, end=self.relpos_k + 1, device=d.device + ) + reshaped_bins = boundaries.view(((1,) * len(d.shape)) + (len(boundaries),)) + d = d[..., None] - reshaped_bins + d = torch.abs(d) + d = torch.argmin(d, dim=-1) + d = nn.functional.one_hot(d, num_classes=len(boundaries)).float() + d = d.to(ri.dtype) + return self.linear_relpos(d) + + def _get_binned_distogram(self, x, min_bin, max_bin, no_bins, recycling_linear, prot_distogram_mask=None): + # This squared method might become problematic in FP16 mode. + bins = torch.linspace( + min_bin, + max_bin, + no_bins, + dtype=x.dtype, + device=x.device, + requires_grad=False, + ) + squared_bins = bins ** 2 + upper = torch.cat( + [squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1 + ) + d = torch.sum((x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdims=True) + + # [*, N, N, no_bins] + d = ((d > squared_bins) * (d < upper)).type(x.dtype) + # print("d shape", d.shape, d[0][0][:10]) + + if prot_distogram_mask is not None: + expanded_d = torch.cat([d, torch.zeros(*d.shape[:-1], 1, device=d.device)], dim=-1) + + # Step 2: Create a mask where `input_positions_masked` is 0 + # Use broadcasting and tensor operations directly without additional variables + input_positions_mask = (prot_distogram_mask == 1).float() # Shape [N, crop_size] + mask_i = input_positions_mask.unsqueeze(2) # Shape [N, crop_size, 1] + mask_j = input_positions_mask.unsqueeze(1) # Shape [N, 1, crop_size] + + # Step 3: Combine masks for both [N, :, i, :] and [N, i, :, :] + combined_mask = mask_i + mask_j # Shape [N, crop_size, crop_size] + combined_mask = combined_mask.clamp(max=1) # Ensure binary mask + + # Step 4: Apply the mask + # a. 
Set all but the last position in the `no_bins + 1` dimension to 0 where the mask is 1 + expanded_d[..., :-1] *= (1 - combined_mask).unsqueeze(-1) # Shape [N, crop_size, crop_size, no_bins] + + # print("expanded_d shape1", expanded_d.shape, expanded_d[0][0][:10]) + + # b. Set the last position in the `no_bins + 1` dimension to 1 where the mask is 1 + expanded_d[..., -1] += combined_mask # Shape [N, crop_size, crop_size, 1] + d = expanded_d + # print("expanded_d shape2", d.shape, d[0][0][:10]) + + return recycling_linear(d) + + def forward( + self, + token_mask: torch.Tensor, + protein_mask: torch.Tensor, + ligand_mask: torch.Tensor, + target_feat: torch.Tensor, + ligand_bonds_feat: torch.Tensor, + input_positions: torch.Tensor, + protein_residue_index: torch.Tensor, + protein_distogram_mask: torch.Tensor, + inplace_safe: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + batch: Dict containing + "protein_target_feat": + Features of shape [*, N_res + N_lig_atoms, tf_dim] + "residue_index": + Features of shape [*, N_res] + input_protein_coords: + [*, N_res, 3] AF predicted C_beta coordinates supplied as input + ligand_bonds_feat: + [*, N_lig_atoms, N_lig_atoms, tf_dim] ligand bonds features + Returns: + single_emb: + [*, N_res + N_lig_atoms, C_m] single embedding + pair_emb: + [*, N_res + N_lig_atoms, N_res + N_lig_atoms, C_z] pair embedding + + """ + device = token_mask.device + pair_protein_mask = protein_mask[..., None] * protein_mask[..., None, :] + pair_ligand_mask = ligand_mask[..., None] * ligand_mask[..., None, :] + + # Single representation embedding - Algorithm 3 + tf_m = self.linear_tf_m(target_feat) + tf_m = self.layer_norm_m(tf_m) # previously this happend in the do_recycle function + + # Pair representation + # protein pair embedding - Algorithm 3 + # [*, N_res, c_z] + tf_emb_i = self.linear_tf_z_i(target_feat) + tf_emb_j = self.linear_tf_z_j(target_feat) + + pair_emb = torch.zeros(*pair_protein_mask.shape, self.c_z, device=device) + pair_emb = add(pair_emb, tf_emb_i[..., None, :], inplace=inplace_safe) + pair_emb = add(pair_emb, tf_emb_j[..., None, :, :], inplace=inplace_safe) + + # Apply relpos + relpos = self.relpos(protein_residue_index.type(tf_emb_i.dtype)) + pair_emb += pair_protein_mask[..., None] * relpos + + del relpos + + # apply ligand bonds + ligand_bonds = self.ligand_linear_bond_z(ligand_bonds_feat) + pair_emb += pair_ligand_mask[..., None] * ligand_bonds + + del ligand_bonds + + # before recycles, do z_norm, this previously was a part of the recycles + pair_emb = self.layer_norm_z(pair_emb) + + # apply protein recycle + prot_distogram_embed = self._get_binned_distogram(input_positions, self.prot_min_bin, self.prot_max_bin, + self.prot_no_bins, self.prot_recycling_linear, + protein_distogram_mask) + + + pair_emb = add(pair_emb, prot_distogram_embed * pair_protein_mask.unsqueeze(-1), inplace_safe) + + del prot_distogram_embed + + # apply ligand recycle + lig_distogram_embed = self._get_binned_distogram(input_positions, self.lig_min_bin, self.lig_max_bin, + self.lig_no_bins, self.lig_recycling_linear) + pair_emb = add(pair_emb, lig_distogram_embed * pair_ligand_mask.unsqueeze(-1), inplace_safe) + + del lig_distogram_embed + + return tf_m, pair_emb + + +class RecyclingEmbedder(nn.Module): + """ + Embeds the output of an iteration of the model for recycling. + + Implements Algorithm 32. 
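+
+    The previous iteration's single and pair activations are layer
+    normalized, and the previous positions are binned into a distogram
+    that is linearly embedded into the pair update.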
+ """ + def __init__( + self, + c_m: int, + c_z: int, + min_bin: float, + max_bin: float, + no_bins: int, + inf: float = 1e8, + **kwargs, + ): + """ + Args: + c_m: + Single channel dimension + c_z: + Pair embedding channel dimension + min_bin: + Smallest distogram bin (Angstroms) + max_bin: + Largest distogram bin (Angstroms) + no_bins: + Number of distogram bins + """ + super(RecyclingEmbedder, self).__init__() + + self.c_m = c_m + self.c_z = c_z + self.min_bin = min_bin + self.max_bin = max_bin + self.no_bins = no_bins + self.inf = inf + + self.linear = Linear(self.no_bins, self.c_z) + self.layer_norm_m = LayerNorm(self.c_m) + self.layer_norm_z = LayerNorm(self.c_z) + + def forward( + self, + m: torch.Tensor, + z: torch.Tensor, + x: torch.Tensor, + inplace_safe: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + m: + First row of the single embedding. [*, N_res, C_m] + z: + [*, N_res, N_res, C_z] pair embedding + x: + [*, N_res, 3] predicted C_beta coordinates + Returns: + m: + [*, N_res, C_m] single embedding update + z: + [*, N_res, N_res, C_z] pair embedding update + """ + # [*, N, C_m] + m_update = self.layer_norm_m(m) + if(inplace_safe): + m.copy_(m_update) + m_update = m + + # [*, N, N, C_z] + z_update = self.layer_norm_z(z) + if(inplace_safe): + z.copy_(z_update) + z_update = z + + # This squared method might become problematic in FP16 mode. + bins = torch.linspace( + self.min_bin, + self.max_bin, + self.no_bins, + dtype=x.dtype, + device=x.device, + requires_grad=False, + ) + squared_bins = bins ** 2 + upper = torch.cat( + [squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1 + ) + d = torch.sum( + (x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdims=True + ) + + # [*, N, N, no_bins] + d = ((d > squared_bins) * (d < upper)).type(x.dtype) + + # [*, N, N, C_z] + d = self.linear(d) + z_update = add(z_update, d, inplace_safe) + + return m_update, z_update + diff --git a/dockformer/model/evoformer.py b/dockformer/model/evoformer.py new file mode 100644 index 0000000000000000000000000000000000000000..bfe2aa732c792834fc4790311e0fc122d5af53dd --- /dev/null +++ b/dockformer/model/evoformer.py @@ -0,0 +1,468 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
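+
+# NOTE: this is a single-sequence variant of the Evoformer: the MSA
+# stack is reduced to one per-token representation, so each block is
+# row attention with pair bias followed by a pair stack and transitions.
+#
+# Usage sketch (hypothetical sizes, for orientation only):
+#
+#     stack = EvoformerStack(
+#         c_m=256, c_z=128, c_hidden_single_att=32, c_hidden_mul=128,
+#         c_hidden_pair_att=32, c_s=384, no_heads_single=8,
+#         no_heads_pair=4, no_blocks=8, transition_n=4,
+#         single_dropout=0.15, pair_dropout=0.25, blocks_per_ckpt=None,
+#         inf=1e9, eps=1e-10,
+#     )
+#     m, z, s = stack(m, z, single_mask=s_mask, pair_mask=p_mask)
+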
+import math +import sys +import torch +import torch.nn as nn +from typing import Tuple, Sequence, Optional +from functools import partial +from abc import ABC, abstractmethod + +from dockformer.model.primitives import Linear, LayerNorm +from dockformer.model.dropout import DropoutRowwise +from dockformer.model.single_attention import SingleRowAttentionWithPairBias + +from dockformer.model.pair_transition import PairTransition +from dockformer.model.triangular_attention import ( + TriangleAttention, +) +from dockformer.model.triangular_multiplicative_update import ( + TriangleMultiplicationOutgoing, + TriangleMultiplicationIncoming, +) +from dockformer.utils.checkpointing import checkpoint_blocks +from dockformer.utils.tensor_utils import add + + +class SingleRepTransition(nn.Module): + """ + Feed-forward network applied to single representation activations after attention. + + Implements Algorithm 9 + """ + def __init__(self, c_m, n): + """ + Args: + c_m: + channel dimension + n: + Factor multiplied to c_m to obtain the hidden channel dimension + """ + super(SingleRepTransition, self).__init__() + + self.c_m = c_m + self.n = n + + self.layer_norm = LayerNorm(self.c_m) + self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu") + self.relu = nn.ReLU() + self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final") + + def _transition(self, m, mask): + m = self.layer_norm(m) + m = self.linear_1(m) + m = self.relu(m) + m = self.linear_2(m) * mask + return m + + def forward( + self, + m: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + m: + [*, N_res, C_m] activation after attention + mask: + [*, N_res, C_m] mask + Returns: + m: + [*, N_res, C_m] activation update + """ + # DISCREPANCY: DeepMind forgets to apply the mask here. + if mask is None: + mask = m.new_ones(m.shape[:-1]) + + mask = mask.unsqueeze(-1) + + m = self._transition(m, mask) + + return m + + +class PairStack(nn.Module): + def __init__( + self, + c_z: int, + c_hidden_mul: int, + c_hidden_pair_att: int, + no_heads_pair: int, + transition_n: int, + pair_dropout: float, + inf: float, + eps: float + ): + super(PairStack, self).__init__() + + self.tri_mul_out = TriangleMultiplicationOutgoing( + c_z, + c_hidden_mul, + ) + self.tri_mul_in = TriangleMultiplicationIncoming( + c_z, + c_hidden_mul, + ) + + self.tri_att_start = TriangleAttention( + c_z, + c_hidden_pair_att, + no_heads_pair, + inf=inf, + ) + self.tri_att_end = TriangleAttention( + c_z, + c_hidden_pair_att, + no_heads_pair, + inf=inf, + ) + + self.pair_transition = PairTransition( + c_z, + transition_n, + ) + + self.ps_dropout_row_layer = DropoutRowwise(pair_dropout) + + def forward(self, + z: torch.Tensor, + pair_mask: torch.Tensor, + use_lma: bool = False, + inplace_safe: bool = False, + _mask_trans: bool = True, + ) -> torch.Tensor: + # DeepMind doesn't mask these transitions in the source, so _mask_trans + # should be disabled to better approximate the exact activations of + # the original. 
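+        # Update order: triangle multiplication (outgoing, then incoming),
+        # triangle attention around the starting and the ending node (the
+        # latter on the transposed pair), then the pair transition; each
+        # update is residual, with rowwise dropout on the attention and
+        # multiplicative updates.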
+ pair_trans_mask = pair_mask if _mask_trans else None + + tmu_update = self.tri_mul_out( + z, + mask=pair_mask, + inplace_safe=inplace_safe, + _add_with_inplace=True, + ) + if (not inplace_safe): + z = z + self.ps_dropout_row_layer(tmu_update) + else: + z = tmu_update + + del tmu_update + + tmu_update = self.tri_mul_in( + z, + mask=pair_mask, + inplace_safe=inplace_safe, + _add_with_inplace=True, + ) + if (not inplace_safe): + z = z + self.ps_dropout_row_layer(tmu_update) + else: + z = tmu_update + + del tmu_update + + z = add(z, + self.ps_dropout_row_layer( + self.tri_att_start( + z, + mask=pair_mask, + use_memory_efficient_kernel=False, + use_lma=use_lma, + ) + ), + inplace=inplace_safe, + ) + + z = z.transpose(-2, -3) + if (inplace_safe): + z = z.contiguous() + + z = add(z, + self.ps_dropout_row_layer( + self.tri_att_end( + z, + mask=pair_mask.transpose(-1, -2), + use_memory_efficient_kernel=False, + use_lma=use_lma, + ) + ), + inplace=inplace_safe, + ) + + z = z.transpose(-2, -3) + if (inplace_safe): + z = z.contiguous() + + z = add(z, + self.pair_transition( + z, mask=pair_trans_mask, + ), + inplace=inplace_safe, + ) + + return z + + +class EvoformerBlock(nn.Module, ABC): + def __init__(self, + c_m: int, + c_z: int, + c_hidden_single_att: int, + c_hidden_mul: int, + c_hidden_pair_att: int, + no_heads_single: int, + no_heads_pair: int, + transition_n: int, + single_dropout: float, + pair_dropout: float, + inf: float, + eps: float, + ): + super(EvoformerBlock, self).__init__() + + self.single_att_row = SingleRowAttentionWithPairBias( + c_m=c_m, + c_z=c_z, + c_hidden=c_hidden_single_att, + no_heads=no_heads_single, + inf=inf, + ) + + self.single_dropout_layer = DropoutRowwise(single_dropout) + + self.single_transition = SingleRepTransition( + c_m=c_m, + n=transition_n, + ) + + self.pair_stack = PairStack( + c_z=c_z, + c_hidden_mul=c_hidden_mul, + c_hidden_pair_att=c_hidden_pair_att, + no_heads_pair=no_heads_pair, + transition_n=transition_n, + pair_dropout=pair_dropout, + inf=inf, + eps=eps + ) + + def forward(self, + m: Optional[torch.Tensor], + z: Optional[torch.Tensor], + single_mask: torch.Tensor, + pair_mask: torch.Tensor, + use_lma: bool = False, + inplace_safe: bool = False, + _mask_trans: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + single_trans_mask = single_mask if _mask_trans else None + + input_tensors = [m, z] + + m, z = input_tensors + + z = self.pair_stack( + z=z, + pair_mask=pair_mask, + use_lma=use_lma, + inplace_safe=inplace_safe, + _mask_trans=_mask_trans, + ) + + m = add(m, + self.single_dropout_layer( + self.single_att_row( + m, + z=z, + mask=single_mask, + use_memory_efficient_kernel=False, + use_lma=use_lma, + ) + ), + inplace=inplace_safe, + ) + + m = add(m, self.single_transition(m, mask=single_mask), inplace=inplace_safe) + + return m, z + + +class EvoformerStack(nn.Module): + """ + Main Evoformer trunk. + + Implements Algorithm 6. 
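+
+    The MSA track of the original is replaced by a single per-token
+    representation; the extra "s" output is a linear projection of that
+    track.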
+ """ + + def __init__( + self, + c_m: int, + c_z: int, + c_hidden_single_att: int, + c_hidden_mul: int, + c_hidden_pair_att: int, + c_s: int, + no_heads_single: int, + no_heads_pair: int, + no_blocks: int, + transition_n: int, + single_dropout: float, + pair_dropout: float, + blocks_per_ckpt: int, + inf: float, + eps: float, + clear_cache_between_blocks: bool = False, + **kwargs, + ): + """ + Args: + c_m: + single channel dimension + c_z: + Pair channel dimension + c_hidden_single_att: + Hidden dimension in single representation attention + c_hidden_mul: + Hidden dimension in multiplicative updates + c_hidden_pair_att: + Hidden dimension in triangular attention + c_s: + Channel dimension of the output "single" embedding + no_heads_single: + Number of heads used for single attention + no_heads_pair: + Number of heads used for pair attention + no_blocks: + Number of Evoformer blocks in the stack + transition_n: + Factor by which to multiply c_m to obtain the SingleTransition + hidden dimension + single_dropout: + Dropout rate for single activations + pair_dropout: + Dropout used for pair activations + blocks_per_ckpt: + Number of Evoformer blocks in each activation checkpoint + clear_cache_between_blocks: + Whether to clear CUDA's GPU memory cache between blocks of the + stack. Slows down each block but can reduce fragmentation + """ + super(EvoformerStack, self).__init__() + + self.blocks_per_ckpt = blocks_per_ckpt + self.clear_cache_between_blocks = clear_cache_between_blocks + + self.blocks = nn.ModuleList() + + for _ in range(no_blocks): + block = EvoformerBlock( + c_m=c_m, + c_z=c_z, + c_hidden_single_att=c_hidden_single_att, + c_hidden_mul=c_hidden_mul, + c_hidden_pair_att=c_hidden_pair_att, + no_heads_single=no_heads_single, + no_heads_pair=no_heads_pair, + transition_n=transition_n, + single_dropout=single_dropout, + pair_dropout=pair_dropout, + inf=inf, + eps=eps, + ) + self.blocks.append(block) + + self.linear = Linear(c_m, c_s) + + def _prep_blocks(self, + use_lma: bool, + single_mask: Optional[torch.Tensor], + pair_mask: Optional[torch.Tensor], + inplace_safe: bool, + _mask_trans: bool, + ): + blocks = [ + partial( + b, + single_mask=single_mask, + pair_mask=pair_mask, + use_lma=use_lma, + inplace_safe=inplace_safe, + _mask_trans=_mask_trans, + ) + for b in self.blocks + ] + + if self.clear_cache_between_blocks: + def block_with_cache_clear(block, *args, **kwargs): + torch.cuda.empty_cache() + return block(*args, **kwargs) + + blocks = [partial(block_with_cache_clear, b) for b in blocks] + + return blocks + + def forward(self, + m: torch.Tensor, + z: torch.Tensor, + single_mask: torch.Tensor, + pair_mask: torch.Tensor, + use_lma: bool = False, + inplace_safe: bool = False, + _mask_trans: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + m: + [*, N_res, C_m] single embedding + z: + [*, N_res, N_res, C_z] pair embedding + single_mask: + [*, N_res] single mask + pair_mask: + [*, N_res, N_res] pair mask + use_lma: + Whether to use low-memory attention during inference. 
+ + Returns: + m: + [*, N_res, C_m] single embedding + z: + [*, N_res, N_res, C_z] pair embedding + s: + [*, N_res, C_s] single embedding after linear layer + """ + blocks = self._prep_blocks( + use_lma=use_lma, + single_mask=single_mask, + pair_mask=pair_mask, + inplace_safe=inplace_safe, + _mask_trans=_mask_trans, + ) + + blocks_per_ckpt = self.blocks_per_ckpt + if(not torch.is_grad_enabled()): + blocks_per_ckpt = None + + m, z = checkpoint_blocks( + blocks, + args=(m, z), + blocks_per_ckpt=blocks_per_ckpt, + ) + + s = self.linear(m) + + return m, z, s diff --git a/dockformer/model/heads.py b/dockformer/model/heads.py new file mode 100644 index 0000000000000000000000000000000000000000..0d1a73cf140fa24301243851a997de0224f08e3c --- /dev/null +++ b/dockformer/model/heads.py @@ -0,0 +1,260 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from torch.nn import Parameter + +from dockformer.model.primitives import Linear, LayerNorm +from dockformer.utils.loss import ( + compute_plddt, + compute_tm, + compute_predicted_aligned_error, +) +from dockformer.utils.precision_utils import is_fp16_enabled + + +class AuxiliaryHeads(nn.Module): + def __init__(self, config): + super(AuxiliaryHeads, self).__init__() + + self.plddt = PerResidueLDDTCaPredictor( + **config["lddt"], + ) + + self.distogram = DistogramHead( + **config["distogram"], + ) + + self.affinity_2d = Affinity2DPredictor( + **config["affinity_2d"], + ) + + self.affinity_1d = Affinity1DPredictor( + **config["affinity_1d"], + ) + + self.affinity_cls = AffinityClsTokenPredictor( + **config["affinity_cls"], + ) + + self.binding_site = BindingSitePredictor( + **config["binding_site"], + ) + + self.inter_contact = InterContactHead( + **config["inter_contact"], + ) + + self.config = config + + def forward(self, outputs, inter_mask, affinity_mask): + aux_out = {} + lddt_logits = self.plddt(outputs["sm"]["single"]) + aux_out["lddt_logits"] = lddt_logits + + # Required for relaxation later on + aux_out["plddt"] = compute_plddt(lddt_logits) + + distogram_logits = self.distogram(outputs["pair"]) + aux_out["distogram_logits"] = distogram_logits + + aux_out["inter_contact_logits"] = self.inter_contact(outputs["single"], outputs["pair"]) + + aux_out["affinity_2d_logits"] = self.affinity_2d(outputs["pair"], aux_out["inter_contact_logits"], inter_mask) + + aux_out["affinity_1d_logits"] = self.affinity_1d(outputs["single"]) + + aux_out["affinity_cls_logits"] = self.affinity_cls(outputs["single"], affinity_mask) + + aux_out["binding_site_logits"] = self.binding_site(outputs["single"]) + + return aux_out + + +class Affinity2DPredictor(nn.Module): + def __init__(self, c_z, num_bins): + super(Affinity2DPredictor, self).__init__() + + self.c_z = c_z + + self.weight_linear = Linear(self.c_z + 1, 1) + self.embed_linear = Linear(self.c_z, self.c_z) + self.bins_linear = Linear(self.c_z, num_bins) + + def forward(self, z, inter_contacts_logits, 
inter_pair_mask): + z_with_inter_contacts = torch.cat((z, inter_contacts_logits), dim=-1) # [*, N, N, c_z + 1] + weights = self.weight_linear(z_with_inter_contacts) # [*, N, N, 1] + + x = self.embed_linear(z) # [*, N, N, c_z] + batch_size, N, M, _ = x.shape + + flat_weights = weights.reshape(batch_size, N*M, -1) # [*, N*M, 1] + flat_x = x.reshape(batch_size, N*M, -1) # [*, N*M, c_z] + flat_inter_pair_mask = inter_pair_mask.reshape(batch_size, N*M, 1) + + flat_weights = flat_weights.masked_fill(~(flat_inter_pair_mask.bool()), float('-inf')) # [*, N*N, 1] + flat_weights = torch.nn.functional.softmax(flat_weights, dim=1) # [*, N*N, 1] + flat_weights = torch.nan_to_num(flat_weights, nan=0.0) # [*, N*N, 1] + weighted_sum = torch.sum((flat_weights * flat_x).reshape(batch_size, N*M, -1), dim=1) # [*, c_z] + + return self.bins_linear(weighted_sum) + + +class Affinity1DPredictor(nn.Module): + def __init__(self, c_s, num_bins, **kwargs): + super(Affinity1DPredictor, self).__init__() + + self.c_s = c_s + + self.linear1 = Linear(self.c_s, self.c_s, init="final") + + self.linear2 = Linear(self.c_s, num_bins, init="final") + + def forward(self, s): + # [*, N, C_out] + s = self.linear1(s) + + # get an average over the sequence + s = torch.mean(s, dim=1) + + logits = self.linear2(s) + return logits + + +class AffinityClsTokenPredictor(nn.Module): + def __init__(self, c_s, num_bins, **kwargs): + super(AffinityClsTokenPredictor, self).__init__() + + self.c_s = c_s + self.linear = Linear(self.c_s, num_bins, init="final") + + def forward(self, s, affinity_mask): + affinity_tokens = (s * affinity_mask.unsqueeze(-1)).sum(dim=1) + return self.linear(affinity_tokens) + + +class BindingSitePredictor(nn.Module): + def __init__(self, c_s, c_out, **kwargs): + super(BindingSitePredictor, self).__init__() + + self.c_s = c_s + self.c_out = c_out + + self.linear = Linear(self.c_s, self.c_out, init="final") + + def forward(self, s): + # [*, N, C_out] + return self.linear(s) + + +class InterContactHead(nn.Module): + def __init__(self, c_s, c_z, c_out, **kwargs): + """ + Args: + c_z: + Input channel dimension + c_out: + Number of bins, but since boolean should be 1 + """ + super(InterContactHead, self).__init__() + + self.c_s = c_s + self.c_z = c_z + self.c_out = c_out + + self.linear = Linear(2 * self.c_s + self.c_z, self.c_out, init="final") + + def forward(self, s, z): # [*, N, N, C_z] + # [*, N, N, no_bins] + batch_size, n, s_dim = s.shape + + s_i = s.unsqueeze(2).expand(batch_size, n, n, s_dim) + s_j = s.unsqueeze(1).expand(batch_size, n, n, s_dim) + joined = torch.cat((s_i, s_j, z), dim=-1) + + logits = self.linear(joined) + + return logits + + +class PerResidueLDDTCaPredictor(nn.Module): + def __init__(self, no_bins, c_in, c_hidden): + super(PerResidueLDDTCaPredictor, self).__init__() + + self.no_bins = no_bins + self.c_in = c_in + self.c_hidden = c_hidden + + self.layer_norm = LayerNorm(self.c_in) + + self.linear_1 = Linear(self.c_in, self.c_hidden, init="relu") + self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="relu") + self.linear_3 = Linear(self.c_hidden, self.no_bins, init="final") + + self.relu = nn.ReLU() + + def forward(self, s): + s = self.layer_norm(s) + s = self.linear_1(s) + s = self.relu(s) + s = self.linear_2(s) + s = self.relu(s) + s = self.linear_3(s) + + return s + + +class DistogramHead(nn.Module): + """ + Computes a distogram probability distribution. 
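+    The logits are symmetrized by summation with their transpose.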
+
+    For use in computation of the distogram loss, subsection 1.9.8.
+    """
+
+    def __init__(self, c_z, no_bins, **kwargs):
+        """
+        Args:
+            c_z:
+                Input channel dimension
+            no_bins:
+                Number of distogram bins
+        """
+        super(DistogramHead, self).__init__()
+
+        self.c_z = c_z
+        self.no_bins = no_bins
+
+        self.linear = Linear(self.c_z, self.no_bins, init="final")
+
+    def _forward(self, z):
+        """
+        Args:
+            z:
+                [*, N_res, N_res, C_z] pair embedding
+        Returns:
+            [*, N, N, no_bins] distogram probability distribution
+        """
+        # [*, N, N, no_bins]
+        logits = self.linear(z)
+        logits = logits + logits.transpose(-2, -3)
+        return logits
+
+    def forward(self, z):
+        # Distogram logits are computed in full precision under autocast.
+        if is_fp16_enabled():
+            with torch.cuda.amp.autocast(enabled=False):
+                return self._forward(z.float())
+        else:
+            return self._forward(z)
diff --git a/dockformer/model/model.py b/dockformer/model/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d666cad1202ed88322d7dcbd304a2a5d7db6e31
--- /dev/null
+++ b/dockformer/model/model.py
@@ -0,0 +1,318 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import partial
+import weakref
+
+import torch
+import torch.nn as nn
+
+from dockformer.utils.tensor_utils import masked_mean
+from dockformer.model.embedders import (
+    StructureInputEmbedder,
+    RecyclingEmbedder,
+)
+from dockformer.model.evoformer import EvoformerStack
+from dockformer.model.heads import AuxiliaryHeads
+from dockformer.model.structure_module import StructureModule
+import dockformer.utils.residue_constants as residue_constants
+from dockformer.utils.feats import (
+    pseudo_beta_fn,
+    atom14_to_atom37,
+)
+from dockformer.utils.tensor_utils import (
+    add,
+    tensor_tree_map,
+)
+
+
+class AlphaFold(nn.Module):
+    """
+    AlphaFold 2.
+
+    Implements Algorithm 2 (but with training).
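+
+    This variant embeds a protein-ligand complex directly (there is no
+    MSA or template stack) and recycles single, pair and position
+    embeddings between iterations.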
+ """ + + def __init__(self, config): + """ + Args: + config: + A dict-like config object (like the one in config.py) + """ + super(AlphaFold, self).__init__() + + self.globals = config.globals + self.config = config.model + + # Main trunk + structure module + self.input_embedder = StructureInputEmbedder( + **self.config["structure_input_embedder"], + ) + + self.recycling_embedder = RecyclingEmbedder( + **self.config["recycling_embedder"], + ) + + self.evoformer = EvoformerStack( + **self.config["evoformer_stack"], + ) + + self.structure_module = StructureModule( + **self.config["structure_module"], + ) + self.aux_heads = AuxiliaryHeads( + self.config["heads"], + ) + + def tolerance_reached(self, prev_pos, next_pos, mask, eps=1e-8) -> bool: + """ + Early stopping criteria based on criteria used in + AF2Complex: https://www.nature.com/articles/s41467-022-29394-2 + Args: + prev_pos: Previous atom positions in atom37/14 representation + next_pos: Current atom positions in atom37/14 representation + mask: 1-D sequence mask + eps: Epsilon used in square root calculation + Returns: + Whether to stop recycling early based on the desired tolerance. + """ + + def distances(points): + """Compute all pairwise distances for a set of points.""" + d = points[..., None, :] - points[..., None, :, :] + return torch.sqrt(torch.sum(d ** 2, dim=-1)) + + if self.config.recycle_early_stop_tolerance < 0: + return False + + ca_idx = residue_constants.atom_order['CA'] + sq_diff = (distances(prev_pos[..., ca_idx, :]) - distances(next_pos[..., ca_idx, :])) ** 2 + mask = mask[..., None] * mask[..., None, :] + sq_diff = masked_mean(mask=mask, value=sq_diff, dim=list(range(len(mask.shape)))) + diff = torch.sqrt(sq_diff + eps).item() + return diff <= self.config.recycle_early_stop_tolerance + + def iteration(self, feats, prevs, _recycle=True): + # Primary output dictionary + outputs = {} + + # This needs to be done manually for DeepSpeed's sake + dtype = next(self.parameters()).dtype + for k in feats: + if feats[k].dtype == torch.float32: + feats[k] = feats[k].to(dtype=dtype) + + # Grab some data about the input + batch_dims, n_total = feats["token_mask"].shape + device = feats["token_mask"].device + + print("doing sample of size", feats["token_mask"].shape, + feats["protein_mask"].sum(dim=1), feats["ligand_mask"].sum(dim=1)) + + # Controls whether the model uses in-place operations throughout + # The dual condition accounts for activation checkpoints + # inplace_safe = not (self.training or torch.is_grad_enabled()) + inplace_safe = False # so we don't need attn_core_inplace_cuda + + # Prep some features + token_mask = feats["token_mask"] + pair_mask = token_mask[..., None] * token_mask[..., None, :] + + # Initialize the single and pair representations + # m: [*, 1, n_total, C_m] + # z: [*, n_total, n_total, C_z] + m, z = self.input_embedder( + feats["token_mask"], + feats["protein_mask"], + feats["ligand_mask"], + feats["target_feat"], + feats["ligand_bonds_feat"], + feats["input_positions"], + feats["protein_residue_index"], + feats["protein_distogram_mask"], + inplace_safe=inplace_safe, + ) + + # Unpack the recycling embeddings. 
Removing them from the list allows + # them to be freed further down in this function, saving memory + m_1_prev, z_prev, x_prev = reversed([prevs.pop() for _ in range(3)]) + + # Initialize the recycling embeddings, if needs be + if None in [m_1_prev, z_prev, x_prev]: + # [*, N, C_m] + m_1_prev = m.new_zeros( + (batch_dims, n_total, self.config.structure_input_embedder.c_m), + requires_grad=False, + ) + + # [*, N, N, C_z] + z_prev = z.new_zeros( + (batch_dims, n_total, n_total, self.config.structure_input_embedder.c_z), + requires_grad=False, + ) + + # [*, N, 3] + x_prev = z.new_zeros( + (batch_dims, n_total, residue_constants.atom_type_num, 3), + requires_grad=False, + ) + + # shape == [1, n_total, 37, 3] + pseudo_beta_or_lig_x_prev = pseudo_beta_fn(feats["aatype"], x_prev, None).to(dtype=z.dtype) + + # m_1_prev_emb: [*, N, C_m] + # z_prev_emb: [*, N, N, C_z] + m_1_prev_emb, z_prev_emb = self.recycling_embedder( + m_1_prev, + z_prev, + pseudo_beta_or_lig_x_prev, + inplace_safe=inplace_safe, + ) + + del pseudo_beta_or_lig_x_prev + + # [*, S_c, N, C_m] + m += m_1_prev_emb + + # [*, N, N, C_z] + z = add(z, z_prev_emb, inplace=inplace_safe) + + # Deletions like these become significant for inference with large N, + # where they free unused tensors and remove references to others such + # that they can be offloaded later + del m_1_prev, z_prev, m_1_prev_emb, z_prev_emb + + # Run single + pair embeddings through the trunk of the network + # m: [*, N, C_m] + # z: [*, N, N, C_z] + # s: [*, N, C_s] + m, z, s = self.evoformer( + m, + z, + single_mask=token_mask.to(dtype=m.dtype), + pair_mask=pair_mask.to(dtype=z.dtype), + use_lma=self.globals.use_lma, + inplace_safe=inplace_safe, + _mask_trans=self.config._mask_trans, + ) + + outputs["pair"] = z + outputs["single"] = s + + del z + + # Predict 3D structure + outputs["sm"] = self.structure_module( + outputs, + feats["aatype"], + mask=token_mask.to(dtype=s.dtype), + inplace_safe=inplace_safe, + ) + outputs["final_atom_positions"] = atom14_to_atom37( + outputs["sm"]["positions"][-1], feats + ) + outputs["final_atom_mask"] = feats["atom37_atom_exists"] + + # Save embeddings for use during the next recycling iteration + + # [*, N, C_m] + m_1_prev = m[..., 0, :, :] + + # [*, N, N, C_z] + z_prev = outputs["pair"] + + # TODO bshor: early stop depends on is_multimer, but I don't think it must + early_stop = False + # if self.globals.is_multimer: + # early_stop = self.tolerance_reached(x_prev, outputs["final_atom_positions"], seq_mask) + + del x_prev + + # [*, N, 3] + x_prev = outputs["final_atom_positions"] + + return outputs, m_1_prev, z_prev, x_prev, early_stop + + def forward(self, batch): + """ + Args: + batch: + Dictionary of arguments outlined in Algorithm 2. Keys must + include the official names of the features in the + supplement subsection 1.2.9. + + The final dimension of each input must have length equal to + the number of recycling iterations. + + Features (without the recycling dimension): + + "aatype" ([*, N_res]): + Contrary to the supplement, this tensor of residue + indices is not one-hot. + "protein_target_feat" ([*, N_res, C_tf]) + One-hot encoding of the target sequence. C_tf is + config.model.input_embedder.tf_dim. + "residue_index" ([*, N_res]) + Tensor whose final dimension consists of + consecutive indices from 0 to N_res. 
+ "token_mask" ([*, N_token]) + 1-D token mask + "pair_mask" ([*, N_token, N_token]) + 2-D pair mask + """ + # Initialize recycling embeddings + m_1_prev, z_prev, x_prev = None, None, None + prevs = [m_1_prev, z_prev, x_prev] + + is_grad_enabled = torch.is_grad_enabled() + + # Main recycling loop + num_iters = batch["aatype"].shape[-1] + early_stop = False + num_recycles = 0 + for cycle_no in range(num_iters): + # Select the features for the current recycling cycle + fetch_cur_batch = lambda t: t[..., cycle_no] + feats = tensor_tree_map(fetch_cur_batch, batch) + + # Enable grad iff we're training and it's the final recycling layer + is_final_iter = cycle_no == (num_iters - 1) or early_stop + with torch.set_grad_enabled(is_grad_enabled and is_final_iter): + if is_final_iter: + # Sidestep AMP bug (PyTorch issue #65766) + if torch.is_autocast_enabled(): + torch.clear_autocast_cache() + + # Run the next iteration of the model + outputs, m_1_prev, z_prev, x_prev, early_stop = self.iteration( + feats, + prevs, + _recycle=(num_iters > 1) + ) + + num_recycles += 1 + + if not is_final_iter: + del outputs + prevs = [m_1_prev, z_prev, x_prev] + del m_1_prev, z_prev, x_prev + else: + break + + outputs["num_recycles"] = torch.tensor(num_recycles, device=feats["aatype"].device) + + # Run auxiliary heads, remove the recycling dimension batch properties + outputs.update(self.aux_heads(outputs, batch["inter_pair_mask"][..., 0], batch["affinity_mask"][..., 0])) + + return outputs diff --git a/dockformer/model/pair_transition.py b/dockformer/model/pair_transition.py new file mode 100644 index 0000000000000000000000000000000000000000..263370718904de4c7756f54e3b73dfa9c2fe76d0 --- /dev/null +++ b/dockformer/model/pair_transition.py @@ -0,0 +1,81 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import torch +import torch.nn as nn + +from dockformer.model.primitives import Linear, LayerNorm + + +class PairTransition(nn.Module): + """ + Implements Algorithm 15. 
+ """ + + def __init__(self, c_z, n): + """ + Args: + c_z: + Pair transition channel dimension + n: + Factor by which c_z is multiplied to obtain hidden channel + dimension + """ + super(PairTransition, self).__init__() + + self.c_z = c_z + self.n = n + + self.layer_norm = LayerNorm(self.c_z) + self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu") + self.relu = nn.ReLU() + self.linear_2 = Linear(self.n * self.c_z, c_z, init="final") + + def _transition(self, z, mask): + # [*, N_res, N_res, C_z] + z = self.layer_norm(z) + + # [*, N_res, N_res, C_hidden] + z = self.linear_1(z) + z = self.relu(z) + + # [*, N_res, N_res, C_z] + z = self.linear_2(z) + z = z * mask + + return z + + def forward(self, + z: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + z: + [*, N_res, N_res, C_z] pair embedding + Returns: + [*, N_res, N_res, C_z] pair embedding update + """ + # DISCREPANCY: DeepMind forgets to apply the mask in this module. + if mask is None: + mask = z.new_ones(z.shape[:-1]) + + # [*, N_res, N_res, 1] + mask = mask.unsqueeze(-1) + + z = self._transition(z=z, mask=mask) + + return z diff --git a/dockformer/model/primitives.py b/dockformer/model/primitives.py new file mode 100644 index 0000000000000000000000000000000000000000..b9b4131fc1460880891f59b2269776772bea88b5 --- /dev/null +++ b/dockformer/model/primitives.py @@ -0,0 +1,598 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import importlib +import math +from typing import Optional, Callable, List, Tuple +import numpy as np + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from scipy.stats import truncnorm + +from dockformer.utils.kernel.attention_core import attention_core +from dockformer.utils.precision_utils import is_fp16_enabled +from dockformer.utils.tensor_utils import ( + permute_final_dims, + flatten_final_dims, +) + + +# Suited for 40gb GPU +# DEFAULT_LMA_Q_CHUNK_SIZE = 1024 +# DEFAULT_LMA_KV_CHUNK_SIZE = 4096 +# Suited for 10gb GPU +DEFAULT_LMA_Q_CHUNK_SIZE = 64 +DEFAULT_LMA_KV_CHUNK_SIZE = 256 + + +def _prod(nums): + out = 1 + for n in nums: + out = out * n + return out + + +def _calculate_fan(linear_weight_shape, fan="fan_in"): + fan_out, fan_in = linear_weight_shape + + if fan == "fan_in": + f = fan_in + elif fan == "fan_out": + f = fan_out + elif fan == "fan_avg": + f = (fan_in + fan_out) / 2 + else: + raise ValueError("Invalid fan option") + + return f + + +def trunc_normal_init_(weights, scale=1.0, fan="fan_in"): + shape = weights.shape + f = _calculate_fan(shape, fan) + scale = scale / max(1, f) + a = -2 + b = 2 + std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1) + size = _prod(shape) + samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size) + samples = np.reshape(samples, shape) + with torch.no_grad(): + weights.copy_(torch.tensor(samples, device=weights.device)) + + +def lecun_normal_init_(weights): + trunc_normal_init_(weights, scale=1.0) + + +def he_normal_init_(weights): + trunc_normal_init_(weights, scale=2.0) + + +def glorot_uniform_init_(weights): + nn.init.xavier_uniform_(weights, gain=1) + + +def final_init_(weights): + with torch.no_grad(): + weights.fill_(0.0) + + +def gating_init_(weights): + with torch.no_grad(): + weights.fill_(0.0) + + +def normal_init_(weights): + torch.nn.init.kaiming_normal_(weights, nonlinearity="linear") + + +def ipa_point_weights_init_(weights): + with torch.no_grad(): + softplus_inverse_1 = 0.541324854612918 + weights.fill_(softplus_inverse_1) + + +class Linear(nn.Linear): + """ + A Linear layer with built-in nonstandard initializations. Called just + like torch.nn.Linear. + + Implements the initializers in 1.11.4, plus some additional ones found + in the code. + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + bias: bool = True, + init: str = "default", + init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None, + precision=None + ): + """ + Args: + in_dim: + The final dimension of inputs to the layer + out_dim: + The final dimension of layer outputs + bias: + Whether to learn an additive bias. True by default + init: + The initializer to use. Choose from: + + "default": LeCun fan-in truncated normal initialization + "relu": He initialization w/ truncated normal distribution + "glorot": Fan-average Glorot uniform initialization + "gating": Weights=0, Bias=1 + "normal": Normal initialization with std=1/sqrt(fan_in) + "final": Weights=0, Bias=0 + + Overridden by init_fn if the latter is not None. + init_fn: + A custom initializer taking weight and bias as inputs. + Overrides init if not None. 
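+            precision:
+                Optional dtype. If set, the forward pass is run in this
+                dtype with autocast disabled and the output is cast
+                back to the input dtype (see forward below).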
+ """ + super(Linear, self).__init__(in_dim, out_dim, bias=bias) + + if bias: + with torch.no_grad(): + self.bias.fill_(0) + + with torch.no_grad(): + if init_fn is not None: + init_fn(self.weight, self.bias) + else: + if init == "default": + lecun_normal_init_(self.weight) + elif init == "relu": + he_normal_init_(self.weight) + elif init == "glorot": + glorot_uniform_init_(self.weight) + elif init == "gating": + gating_init_(self.weight) + if bias: + self.bias.fill_(1.0) + elif init == "normal": + normal_init_(self.weight) + elif init == "final": + final_init_(self.weight) + else: + raise ValueError("Invalid init string.") + + self.precision = precision + + def forward(self, input: torch.Tensor) -> torch.Tensor: + d = input.dtype + if self.precision is not None: + with torch.cuda.amp.autocast(enabled=False): + bias = self.bias.to(dtype=self.precision) if self.bias is not None else None + return nn.functional.linear(input.to(dtype=self.precision), + self.weight.to(dtype=self.precision), + bias).to(dtype=d) + + if d is torch.bfloat16: + with torch.cuda.amp.autocast(enabled=False): + bias = self.bias.to(dtype=d) if self.bias is not None else None + return nn.functional.linear(input, self.weight.to(dtype=d), bias) + + return nn.functional.linear(input, self.weight, self.bias) + + +class LayerNorm(nn.Module): + def __init__(self, c_in, eps=1e-5): + super(LayerNorm, self).__init__() + + self.c_in = (c_in,) + self.eps = eps + + self.weight = nn.Parameter(torch.ones(c_in)) + self.bias = nn.Parameter(torch.zeros(c_in)) + + def forward(self, x): + d = x.dtype + if d is torch.bfloat16: + with torch.cuda.amp.autocast(enabled=False): + out = nn.functional.layer_norm( + x, + self.c_in, + self.weight.to(dtype=d), + self.bias.to(dtype=d), + self.eps + ) + else: + out = nn.functional.layer_norm( + x, + self.c_in, + self.weight, + self.bias, + self.eps, + ) + + return out + + +@torch.jit.ignore +def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor: + """ + Softmax, but without automatic casting to fp32 when the input is of + type bfloat16 + """ + d = t.dtype + if d is torch.bfloat16: + with torch.cuda.amp.autocast(enabled=False): + s = torch.nn.functional.softmax(t, dim=dim) + else: + s = torch.nn.functional.softmax(t, dim=dim) + + return s + + +#@torch.jit.script +def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, biases: List[torch.Tensor]) -> torch.Tensor: + # [*, H, C_hidden, K] + key = permute_final_dims(key, (1, 0)) + + # [*, H, Q, K] + a = torch.matmul(query, key) + + for b in biases: + a += b + + a = softmax_no_cast(a, -1) + + # [*, H, Q, C_hidden] + a = torch.matmul(a, value) + + return a + + +class Attention(nn.Module): + """ + Standard multi-head attention using AlphaFold's default layer + initialization. Allows multiple bias vectors. + """ + def __init__( + self, + c_q: int, + c_k: int, + c_v: int, + c_hidden: int, + no_heads: int, + gating: bool = True, + ): + """ + Args: + c_q: + Input dimension of query data + c_k: + Input dimension of key data + c_v: + Input dimension of value data + c_hidden: + Per-head hidden dimension + no_heads: + Number of attention heads + gating: + Whether the output should be gated using query data + """ + super(Attention, self).__init__() + + self.c_q = c_q + self.c_k = c_k + self.c_v = c_v + self.c_hidden = c_hidden + self.no_heads = no_heads + self.gating = gating + + # DISCREPANCY: c_hidden is not the per-head channel dimension, as + # stated in the supplement, but the overall channel dimension. 
+ + self.linear_q = Linear( + self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot" + ) + self.linear_k = Linear( + self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot" + ) + self.linear_v = Linear( + self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot" + ) + self.linear_o = Linear( + self.c_hidden * self.no_heads, self.c_q, init="final" + ) + + self.linear_g = None + if self.gating: + self.linear_g = Linear( + self.c_q, self.c_hidden * self.no_heads, init="gating" + ) + + self.sigmoid = nn.Sigmoid() + + def _prep_qkv(self, + q_x: torch.Tensor, + kv_x: torch.Tensor, + apply_scale: bool = True + ) -> Tuple[ + torch.Tensor, torch.Tensor, torch.Tensor + ]: + # [*, Q/K/V, H * C_hidden] + q = self.linear_q(q_x) + k = self.linear_k(kv_x) + v = self.linear_v(kv_x) + + # [*, Q/K, H, C_hidden] + q = q.view(q.shape[:-1] + (self.no_heads, -1)) + k = k.view(k.shape[:-1] + (self.no_heads, -1)) + v = v.view(v.shape[:-1] + (self.no_heads, -1)) + + # [*, H, Q/K, C_hidden] + q = q.transpose(-2, -3) + k = k.transpose(-2, -3) + v = v.transpose(-2, -3) + + if apply_scale: + q /= math.sqrt(self.c_hidden) + + return q, k, v + + def _wrap_up(self, + o: torch.Tensor, + q_x: torch.Tensor + ) -> torch.Tensor: + if self.linear_g is not None: + g = self.sigmoid(self.linear_g(q_x)) + + # [*, Q, H, C_hidden] + g = g.view(g.shape[:-1] + (self.no_heads, -1)) + o = o * g + + # [*, Q, H * C_hidden] + o = flatten_final_dims(o, 2) + + # [*, Q, C_q] + o = self.linear_o(o) + + return o + + def forward( + self, + q_x: torch.Tensor, + kv_x: torch.Tensor, + biases: Optional[List[torch.Tensor]] = None, + use_memory_efficient_kernel: bool = False, + use_lma: bool = False, + lma_q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE, + lma_kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE, + ) -> torch.Tensor: + """ + Args: + q_x: + [*, Q, C_q] query data + kv_x: + [*, K, C_k] key data + biases: + List of biases that broadcast to [*, H, Q, K] + use_memory_efficient_kernel: + Whether to use a custom memory-efficient attention kernel. + This should be the default choice for most. If none of the + "use_<...>" flags are True, a stock PyTorch implementation + is used instead + use_lma: + Whether to use low-memory attention (Staats & Rabe 2021). 
If + none of the "use_<...>" flags are True, a stock PyTorch + implementation is used instead + lma_q_chunk_size: + Query chunk size (for LMA) + lma_kv_chunk_size: + Key/Value chunk size (for LMA) + Returns + [*, Q, C_q] attention update + """ + if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None): + raise ValueError( + "If use_lma is specified, lma_q_chunk_size and " + "lma_kv_chunk_size must be provided" + ) + + attn_options = [use_memory_efficient_kernel, use_lma] + if sum(attn_options) > 1: + raise ValueError( + "Choose at most one alternative attention algorithm" + ) + + if biases is None: + biases = [] + + q, k, v = self._prep_qkv(q_x, kv_x, apply_scale=True) + + if is_fp16_enabled(): + use_memory_efficient_kernel = False + + if use_memory_efficient_kernel: + if len(biases) > 2: + raise ValueError( + "If use_memory_efficient_kernel is True, you may only " + "provide up to two bias terms" + ) + o = attention_core(q, k, v, *((biases + [None] * 2)[:2])) + o = o.transpose(-2, -3) + elif use_lma: + biases = [ + b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],)) + for b in biases + ] + o = _lma(q, k, v, biases, lma_q_chunk_size, lma_kv_chunk_size) + o = o.transpose(-2, -3) + else: + o = _attention(q, k, v, biases) + o = o.transpose(-2, -3) + + o = self._wrap_up(o, q_x) + + return o + + +class GlobalAttention(nn.Module): + def __init__(self, c_in, c_hidden, no_heads, inf, eps): + super(GlobalAttention, self).__init__() + + self.c_in = c_in + self.c_hidden = c_hidden + self.no_heads = no_heads + self.inf = inf + self.eps = eps + + self.linear_q = Linear( + c_in, c_hidden * no_heads, bias=False, init="glorot" + ) + + self.linear_k = Linear( + c_in, c_hidden, bias=False, init="glorot", + ) + self.linear_v = Linear( + c_in, c_hidden, bias=False, init="glorot", + ) + self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating") + self.linear_o = Linear(c_hidden * no_heads, c_in, init="final") + + self.sigmoid = nn.Sigmoid() + + def forward(self, + m: torch.Tensor, + mask: torch.Tensor, + use_lma: bool = False, + ) -> torch.Tensor: + # [*, N_res, C_in] + q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / ( + torch.sum(mask, dim=-1)[..., None] + self.eps + ) + + # [*, N_res, H * C_hidden] + q = self.linear_q(q) + q *= (self.c_hidden ** (-0.5)) + + # [*, N_res, H, C_hidden] + q = q.view(q.shape[:-1] + (self.no_heads, -1)) + + # [*, N_res, C_hidden] + k = self.linear_k(m) + v = self.linear_v(m) + + bias = (self.inf * (mask - 1))[..., :, None, :] + if not use_lma: + # [*, N_res, H, N_seq] + a = torch.matmul( + q, + k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq] + ) + a += bias + a = softmax_no_cast(a) + + # [*, N_res, H, C_hidden] + o = torch.matmul( + a, + v, + ) + else: + o = _lma( + q, + k, + v, + [bias], + DEFAULT_LMA_Q_CHUNK_SIZE, + DEFAULT_LMA_KV_CHUNK_SIZE + ) + + # [*, N_res, C_hidden] + g = self.sigmoid(self.linear_g(m)) + + # [*, N_res, H, C_hidden] + g = g.view(g.shape[:-1] + (self.no_heads, -1)) + + # [*, N_res, H, C_hidden] + o = o.unsqueeze(-3) * g + + # [*, N_res, H * C_hidden] + o = o.reshape(o.shape[:-2] + (-1,)) + + # [*, N_res, C_in] + m = self.linear_o(o) + + return m + + +def _lma( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + biases: List[torch.Tensor], + q_chunk_size: int, + kv_chunk_size: int, +): + no_q, no_kv = q.shape[-2], k.shape[-2] + + # [*, H, Q, C_hidden] + o = q.new_zeros(q.shape) + for q_s in range(0, no_q, q_chunk_size): + q_chunk = q[..., q_s: q_s + q_chunk_size, :] + large_bias_chunks = [ + b[..., q_s: q_s + 
q_chunk_size, :] for b in biases + ] + + maxes = [] + weights = [] + values = [] + for kv_s in range(0, no_kv, kv_chunk_size): + k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :] + v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :] + small_bias_chunks = [ + b[..., kv_s: kv_s + kv_chunk_size] for b in large_bias_chunks + ] + + a = torch.einsum( + "...hqd,...hkd->...hqk", q_chunk, k_chunk, + ) + + for b in small_bias_chunks: + a += b + + max_a = torch.max(a, dim=-1, keepdim=True)[0] + exp_a = torch.exp(a - max_a) + exp_v = torch.einsum("...hvf,...hqv->...hqf", v_chunk, exp_a) + + maxes.append(max_a.detach().squeeze(-1)) + weights.append(torch.sum(exp_a, dim=-1)) + values.append(exp_v) + + chunk_max = torch.stack(maxes, dim=-3) + chunk_weights = torch.stack(weights, dim=-3) + chunk_values = torch.stack(values, dim=-4) + + global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0] + max_diffs = torch.exp(chunk_max - global_max) + chunk_values = chunk_values * max_diffs.unsqueeze(-1) + chunk_weights = chunk_weights * max_diffs + + all_values = torch.sum(chunk_values, dim=-4) + all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4) + + q_chunk_out = all_values / all_weights + + o[..., q_s: q_s + q_chunk_size, :] = q_chunk_out + + return o diff --git a/dockformer/model/single_attention.py b/dockformer/model/single_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..27660da8ab8e6e9a6ce2152a0d15dee3af0395b8 --- /dev/null +++ b/dockformer/model/single_attention.py @@ -0,0 +1,184 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial +import math +import torch +import torch.nn as nn +import torch.utils.checkpoint +from typing import Optional, List, Tuple + +from dockformer.model.primitives import ( + Linear, + LayerNorm, + Attention, +) +from dockformer.utils.tensor_utils import permute_final_dims + + +class SingleAttention(nn.Module): + def __init__( + self, + c_in, + c_hidden, + no_heads, + pair_bias=False, + c_z=None, + inf=1e9, + ): + """ + Args: + c_in: + Input channel dimension + c_hidden: + Per-head hidden channel dimension + no_heads: + Number of attention heads + pair_bias: + Whether to use pair embedding bias + c_z: + Pair embedding channel dimension. 
Ignored unless pair_bias + is true + inf: + A large number to be used in computing the attention mask + """ + super(SingleAttention, self).__init__() + + self.c_in = c_in + self.c_hidden = c_hidden + self.no_heads = no_heads + self.pair_bias = pair_bias + self.c_z = c_z + self.inf = inf + + self.layer_norm_m = LayerNorm(self.c_in) + + self.layer_norm_z = None + self.linear_z = None + if self.pair_bias: + self.layer_norm_z = LayerNorm(self.c_z) + self.linear_z = Linear( + self.c_z, self.no_heads, bias=False, init="normal" + ) + + self.mha = Attention( + self.c_in, + self.c_in, + self.c_in, + self.c_hidden, + self.no_heads, + ) + + def _prep_inputs(self, + m: torch.Tensor, + z: Optional[torch.Tensor], + mask: Optional[torch.Tensor], + inplace_safe: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if mask is None: + # [*, N_res] + mask = m.new_ones(m.shape[:-1]) + + # [*, 1, 1, N_res] + mask_bias = (self.inf * (mask - 1))[..., :, None, None, :] + + if (self.pair_bias and + z is not None and # For the + self.layer_norm_z is not None and # benefit of + self.linear_z is not None # TorchScript + ): + chunks = [] + + for i in range(0, z.shape[-3], 256): + z_chunk = z[..., i: i + 256, :, :] + + # [*, N_res, N_res, C_z] + z_chunk = self.layer_norm_z(z_chunk) + + # [*, N_res, N_res, no_heads] + z_chunk = self.linear_z(z_chunk) + + chunks.append(z_chunk) + + z = torch.cat(chunks, dim=-3) + + # [*, no_heads, N_res, N_res] + z = permute_final_dims(z, (2, 0, 1)) + + return m, mask_bias, z + + def forward(self, + m: torch.Tensor, + z: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + use_memory_efficient_kernel: bool = False, + use_lma: bool = False, + inplace_safe: bool = False, + ) -> torch.Tensor: + """ + Args: + m: + [*, N_res, C_m] single embedding + z: + [*, N_res, N_res, C_z] pair embedding. Required only if pair_bias is True + mask: + [*, N_res] single mask + """ + m, mask_bias, z = self._prep_inputs( + m, z, mask, inplace_safe=inplace_safe + ) + + biases = [mask_bias] + if(z is not None): + biases.append(z) + + m = self.layer_norm_m(m) + m = self.mha( + q_x=m, + kv_x=m, + biases=biases, + use_memory_efficient_kernel=use_memory_efficient_kernel, + use_lma=use_lma, + ) + + return m + + +class SingleRowAttentionWithPairBias(SingleAttention): + """ + Implements Algorithm 7. + """ + + def __init__(self, c_m, c_z, c_hidden, no_heads, inf=1e9): + """ + Args: + c_m: + Input channel dimension + c_z: + Pair embedding channel dimension + c_hidden: + Per-head hidden channel dimension + no_heads: + Number of attention heads + inf: + Large number used to construct attention masks + """ + super(SingleRowAttentionWithPairBias, self).__init__( + c_m, + c_hidden, + no_heads, + pair_bias=True, + c_z=c_z, + inf=inf, + ) diff --git a/dockformer/model/structure_module.py b/dockformer/model/structure_module.py new file mode 100644 index 0000000000000000000000000000000000000000..7b777dcbd80f99e51984d4c02dc1fbd505972274 --- /dev/null +++ b/dockformer/model/structure_module.py @@ -0,0 +1,837 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import reduce +import importlib +import math +import sys +from operator import mul + +import torch +import torch.nn as nn +from typing import Optional, Tuple, Sequence, Union + +from dockformer.model.primitives import Linear, LayerNorm, ipa_point_weights_init_ +from dockformer.utils.residue_constants import ( + restype_rigid_group_default_frame, + restype_atom14_to_rigid_group, + restype_atom14_mask, + restype_atom14_rigid_group_positions, +) +from dockformer.utils.geometry.quat_rigid import QuatRigid +from dockformer.utils.geometry.rigid_matrix_vector import Rigid3Array +from dockformer.utils.geometry.vector import Vec3Array, square_euclidean_distance +from dockformer.utils.feats import ( + frames_and_literature_positions_to_atom14_pos, + torsion_angles_to_frames, +) +from dockformer.utils.precision_utils import is_fp16_enabled +from dockformer.utils.rigid_utils import Rotation, Rigid +from dockformer.utils.tensor_utils import ( + dict_multimap, + permute_final_dims, + flatten_final_dims, +) + +import importlib.util +attn_core_is_installed = importlib.util.find_spec("attn_core_inplace_cuda") is not None +attn_core_inplace_cuda = None +if attn_core_is_installed: + attn_core_inplace_cuda = importlib.import_module("attn_core_inplace_cuda") + + +class AngleResnetBlock(nn.Module): + def __init__(self, c_hidden): + """ + Args: + c_hidden: + Hidden channel dimension + """ + super(AngleResnetBlock, self).__init__() + + self.c_hidden = c_hidden + + self.linear_1 = Linear(self.c_hidden, self.c_hidden, init="relu") + self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="final") + + self.relu = nn.ReLU() + + def forward(self, a: torch.Tensor) -> torch.Tensor: + + s_initial = a + + a = self.relu(a) + a = self.linear_1(a) + a = self.relu(a) + a = self.linear_2(a) + + return a + s_initial + + +class AngleResnet(nn.Module): + """ + Implements Algorithm 20, lines 11-14 + """ + + def __init__(self, c_in, c_hidden, no_blocks, no_angles, epsilon): + """ + Args: + c_in: + Input channel dimension + c_hidden: + Hidden channel dimension + no_blocks: + Number of resnet blocks + no_angles: + Number of torsion angles to generate + epsilon: + Small constant for normalization + """ + super(AngleResnet, self).__init__() + + self.c_in = c_in + self.c_hidden = c_hidden + self.no_blocks = no_blocks + self.no_angles = no_angles + self.eps = epsilon + + self.linear_in = Linear(self.c_in, self.c_hidden) + self.linear_initial = Linear(self.c_in, self.c_hidden) + + self.layers = nn.ModuleList() + for _ in range(self.no_blocks): + layer = AngleResnetBlock(c_hidden=self.c_hidden) + self.layers.append(layer) + + self.linear_out = Linear(self.c_hidden, self.no_angles * 2) + + self.relu = nn.ReLU() + + def forward( + self, s: torch.Tensor, s_initial: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + s: + [*, C_hidden] single embedding + s_initial: + [*, C_hidden] single embedding as of the start of the + StructureModule + Returns: + [*, no_angles, 2] predicted angles + """ + # NOTE: The ReLU's applied to the inputs are absent from the supplement + # pseudocode but present in the 
source. For maximal compatibility with + # the pretrained weights, I'm going with the source. + + # [*, C_hidden] + s_initial = self.relu(s_initial) + s_initial = self.linear_initial(s_initial) + s = self.relu(s) + s = self.linear_in(s) + s = s + s_initial + + for l in self.layers: + s = l(s) + + s = self.relu(s) + + # [*, no_angles * 2] + s = self.linear_out(s) + + # [*, no_angles, 2] + s = s.view(s.shape[:-1] + (-1, 2)) + + unnormalized_s = s + norm_denom = torch.sqrt( + torch.clamp( + torch.sum(s ** 2, dim=-1, keepdim=True), + min=self.eps, + ) + ) + s = s / norm_denom + + return unnormalized_s, s + + +class PointProjection(nn.Module): + def __init__(self, + c_hidden: int, + num_points: int, + no_heads: int, + return_local_points: bool = False, + ): + super().__init__() + self.return_local_points = return_local_points + self.no_heads = no_heads + self.num_points = num_points + + # Multimer requires this to be run with fp32 precision during training + precision = None + self.linear = Linear(c_hidden, no_heads * 3 * num_points, precision=precision) + + def forward(self, + activations: torch.Tensor, + rigids: Union[Rigid, Rigid3Array], + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + # TODO: Needs to run in high precision during training + points_local = self.linear(activations) + out_shape = points_local.shape[:-1] + (self.no_heads, self.num_points, 3) + + points_local = torch.split( + points_local, points_local.shape[-1] // 3, dim=-1 + ) + + points_local = torch.stack(points_local, dim=-1).view(out_shape) + + points_global = rigids[..., None, None].apply(points_local) + + if(self.return_local_points): + return points_global, points_local + + return points_global + + +class InvariantPointAttention(nn.Module): + """ + Implements Algorithm 22. + """ + def __init__( + self, + c_s: int, + c_z: int, + c_hidden: int, + no_heads: int, + no_qk_points: int, + no_v_points: int, + inf: float = 1e5, + eps: float = 1e-8, + ): + """ + Args: + c_s: + Single representation channel dimension + c_z: + Pair representation channel dimension + c_hidden: + Hidden channel dimension + no_heads: + Number of attention heads + no_qk_points: + Number of query/key points to generate + no_v_points: + Number of value points to generate + """ + super(InvariantPointAttention, self).__init__() + + self.c_s = c_s + self.c_z = c_z + self.c_hidden = c_hidden + self.no_heads = no_heads + self.no_qk_points = no_qk_points + self.no_v_points = no_v_points + self.inf = inf + self.eps = eps + + # These linear layers differ from their specifications in the + # supplement. There, they lack bias and use Glorot initialization. + # Here as in the official source, they have bias and use the default + # Lecun initialization. 
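+        # Shape bookkeeping for the projections below:
+        #   linear_q:         [*, N_res, c_s] -> [*, N_res, H * c_hidden]
+        #   linear_q_points:  [*, N_res, c_s] -> [*, N_res, H, no_qk_points, 3], in the global frame
+        #   linear_kv:        packs k and v together (2 * H * c_hidden), split in forward
+        #   linear_kv_points: packs key and value points, split along dim -2 in forward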
+ hc = self.c_hidden * self.no_heads + self.linear_q = Linear(self.c_s, hc, bias=True) + + self.linear_q_points = PointProjection( + self.c_s, + self.no_qk_points, + self.no_heads, + ) + + + self.linear_kv = Linear(self.c_s, 2 * hc) + self.linear_kv_points = PointProjection( + self.c_s, + self.no_qk_points + self.no_v_points, + self.no_heads, + ) + + self.linear_b = Linear(self.c_z, self.no_heads) + + self.head_weights = nn.Parameter(torch.zeros((no_heads))) + ipa_point_weights_init_(self.head_weights) + + concat_out_dim = self.no_heads * ( + self.c_z + self.c_hidden + self.no_v_points * 4 + ) + self.linear_out = Linear(concat_out_dim, self.c_s, init="final") + + self.softmax = nn.Softmax(dim=-1) + self.softplus = nn.Softplus() + + def forward( + self, + s: torch.Tensor, + z: torch.Tensor, + r: Union[Rigid, Rigid3Array], + mask: torch.Tensor, + inplace_safe: bool = False, + ) -> torch.Tensor: + """ + Args: + s: + [*, N_res, C_s] single representation + z: + [*, N_res, N_res, C_z] pair representation + r: + [*, N_res] transformation object + mask: + [*, N_res] mask + Returns: + [*, N_res, C_s] single representation update + """ + z = [z] + + ####################################### + # Generate scalar and point activations + ####################################### + # [*, N_res, H * C_hidden] + q = self.linear_q(s) + + # [*, N_res, H, C_hidden] + q = q.view(q.shape[:-1] + (self.no_heads, -1)) + + # [*, N_res, H, P_qk] + q_pts = self.linear_q_points(s, r) + + # The following two blocks are equivalent + # They're separated only to preserve compatibility with old AF weights + + # [*, N_res, H * 2 * C_hidden] + kv = self.linear_kv(s) + + # [*, N_res, H, 2 * C_hidden] + kv = kv.view(kv.shape[:-1] + (self.no_heads, -1)) + + # [*, N_res, H, C_hidden] + k, v = torch.split(kv, self.c_hidden, dim=-1) + + kv_pts = self.linear_kv_points(s, r) + + # [*, N_res, H, P_q/P_v, 3] + k_pts, v_pts = torch.split( + kv_pts, [self.no_qk_points, self.no_v_points], dim=-2 + ) + + ########################## + # Compute attention scores + ########################## + # [*, N_res, N_res, H] + b = self.linear_b(z[0]) + + # [*, H, N_res, N_res] + if (is_fp16_enabled()): + with torch.cuda.amp.autocast(enabled=False): + a = torch.matmul( + permute_final_dims(q.float(), (1, 0, 2)), # [*, H, N_res, C_hidden] + permute_final_dims(k.float(), (1, 2, 0)), # [*, H, C_hidden, N_res] + ) + else: + a = torch.matmul( + permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden] + permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res] + ) + + a *= math.sqrt(1.0 / (3 * self.c_hidden)) + a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1))) + + # [*, N_res, N_res, H, P_q, 3] + pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5) + + if (inplace_safe): + pt_att *= pt_att + else: + pt_att = pt_att ** 2 + + pt_att = sum(torch.unbind(pt_att, dim=-1)) + + head_weights = self.softplus(self.head_weights).view( + *((1,) * len(pt_att.shape[:-2]) + (-1, 1)) + ) + head_weights = head_weights * math.sqrt( + 1.0 / (3 * (self.no_qk_points * 9.0 / 2)) + ) + + if (inplace_safe): + pt_att *= head_weights + else: + pt_att = pt_att * head_weights + + # [*, N_res, N_res, H] + pt_att = torch.sum(pt_att, dim=-1) * (-0.5) + + # [*, N_res, N_res] + square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) + square_mask = self.inf * (square_mask - 1) + + # [*, H, N_res, N_res] + pt_att = permute_final_dims(pt_att, (2, 0, 1)) + + if (inplace_safe): + a += pt_att + del pt_att + a += square_mask.unsqueeze(-3) + # in-place softmax + 
attn_core_inplace_cuda.forward_( + a, + reduce(mul, a.shape[:-1]), + a.shape[-1], + ) + else: + a = a + pt_att + a = a + square_mask.unsqueeze(-3) + a = self.softmax(a) + + ################ + # Compute output + ################ + # [*, N_res, H, C_hidden] + o = torch.matmul( + a, v.transpose(-2, -3).to(dtype=a.dtype) + ).transpose(-2, -3) + + # [*, N_res, H * C_hidden] + o = flatten_final_dims(o, 2) + + # [*, H, 3, N_res, P_v] + if (inplace_safe): + v_pts = permute_final_dims(v_pts, (1, 3, 0, 2)) + o_pt = [ + torch.matmul(a, v.to(a.dtype)) + for v in torch.unbind(v_pts, dim=-3) + ] + o_pt = torch.stack(o_pt, dim=-3) + else: + o_pt = torch.sum( + ( + a[..., None, :, :, None] + * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :] + ), + dim=-2, + ) + + # [*, N_res, H, P_v, 3] + o_pt = permute_final_dims(o_pt, (2, 0, 3, 1)) + o_pt = r[..., None, None].invert_apply(o_pt) + + # [*, N_res, H * P_v] + o_pt_norm = flatten_final_dims( + torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps), 2 + ) + + # [*, N_res, H * P_v, 3] + o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3) + o_pt = torch.unbind(o_pt, dim=-1) + + # [*, N_res, H, C_z] + o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype)) + + # [*, N_res, H * C_z] + o_pair = flatten_final_dims(o_pair, 2) + + # [*, N_res, C_s] + s = self.linear_out( + torch.cat( + (o, *o_pt, o_pt_norm, o_pair), dim=-1 + ).to(dtype=z[0].dtype) + ) + + return s + + +class BackboneUpdate(nn.Module): + """ + Implements part of Algorithm 23. + """ + + def __init__(self, c_s): + """ + Args: + c_s: + Single representation channel dimension + """ + super(BackboneUpdate, self).__init__() + + self.c_s = c_s + + self.linear = Linear(self.c_s, 6, init="final") + + def forward(self, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + [*, N_res, C_s] single representation + Returns: + [*, N_res, 6] update vector + """ + # [*, 6] + update = self.linear(s) + + return update + + +class StructureModuleTransitionLayer(nn.Module): + def __init__(self, c): + super(StructureModuleTransitionLayer, self).__init__() + + self.c = c + + self.linear_1 = Linear(self.c, self.c, init="relu") + self.linear_2 = Linear(self.c, self.c, init="relu") + self.linear_3 = Linear(self.c, self.c, init="final") + + self.relu = nn.ReLU() + + def forward(self, s): + s_initial = s + s = self.linear_1(s) + s = self.relu(s) + s = self.linear_2(s) + s = self.relu(s) + s = self.linear_3(s) + + s = s + s_initial + + return s + + +class StructureModuleTransition(nn.Module): + def __init__(self, c, num_layers, dropout_rate): + super(StructureModuleTransition, self).__init__() + + self.c = c + self.num_layers = num_layers + self.dropout_rate = dropout_rate + + self.layers = nn.ModuleList() + for _ in range(self.num_layers): + l = StructureModuleTransitionLayer(self.c) + self.layers.append(l) + + self.dropout = nn.Dropout(self.dropout_rate) + self.layer_norm = LayerNorm(self.c) + + def forward(self, s): + for l in self.layers: + s = l(s) + + s = self.dropout(s) + s = self.layer_norm(s) + + return s + + +class StructureModule(nn.Module): + def __init__( + self, + c_s, + c_z, + c_ipa, + c_resnet, + no_heads_ipa, + no_qk_points, + no_v_points, + dropout_rate, + no_blocks, + no_transition_layers, + no_resnet_blocks, + no_angles, + trans_scale_factor, + epsilon, + inf, + **kwargs, + ): + """ + Args: + c_s: + Single representation channel dimension + c_z: + Pair representation channel dimension + c_ipa: + IPA hidden channel dimension + c_resnet: + Angle resnet (Alg. 
23 lines 11-14) hidden channel dimension + no_heads_ipa: + Number of IPA heads + no_qk_points: + Number of query/key points to generate during IPA + no_v_points: + Number of value points to generate during IPA + dropout_rate: + Dropout rate used throughout the layer + no_blocks: + Number of structure module blocks + no_transition_layers: + Number of layers in the single representation transition + (Alg. 23 lines 8-9) + no_resnet_blocks: + Number of blocks in the angle resnet + no_angles: + Number of angles to generate in the angle resnet + trans_scale_factor: + Scale of single representation transition hidden dimension + epsilon: + Small number used in angle resnet normalization + inf: + Large number used for attention masking + """ + super(StructureModule, self).__init__() + + self.c_s = c_s + self.c_z = c_z + self.c_ipa = c_ipa + self.c_resnet = c_resnet + self.no_heads_ipa = no_heads_ipa + self.no_qk_points = no_qk_points + self.no_v_points = no_v_points + self.dropout_rate = dropout_rate + self.no_blocks = no_blocks + self.no_transition_layers = no_transition_layers + self.no_resnet_blocks = no_resnet_blocks + self.no_angles = no_angles + self.trans_scale_factor = trans_scale_factor + self.epsilon = epsilon + self.inf = inf + + # Buffers to be lazily initialized later + # self.default_frames + # self.group_idx + # self.atom_mask + # self.lit_positions + + self.layer_norm_s = LayerNorm(self.c_s) + self.layer_norm_z = LayerNorm(self.c_z) + + self.linear_in = Linear(self.c_s, self.c_s) + + self.ipa = InvariantPointAttention( + self.c_s, + self.c_z, + self.c_ipa, + self.no_heads_ipa, + self.no_qk_points, + self.no_v_points, + inf=self.inf, + eps=self.epsilon, + ) + + self.ipa_dropout = nn.Dropout(self.dropout_rate) + self.layer_norm_ipa = LayerNorm(self.c_s) + + self.transition = StructureModuleTransition( + self.c_s, + self.no_transition_layers, + self.dropout_rate, + ) + + self.bb_update = BackboneUpdate(self.c_s) + + self.angle_resnet = AngleResnet( + self.c_s, + self.c_resnet, + self.no_resnet_blocks, + self.no_angles, + self.epsilon, + ) + + def forward( + self, + evoformer_output_dict, + aatype, + mask=None, + inplace_safe=False, + ): + """ + Args: + evoformer_output_dict: + Dictionary containing: + "single": + [*, N_res, C_s] single representation + "pair": + [*, N_res, N_res, C_z] pair representation + aatype: + [*, N_res] amino acid indices + mask: + Optional [*, N_res] sequence mask + Returns: + A dictionary of outputs + """ + s = evoformer_output_dict["single"] + + if mask is None: + # [*, N] + mask = s.new_ones(s.shape[:-1]) + + # [*, N, C_s] + s = self.layer_norm_s(s) + + # [*, N, N, C_z] + z = self.layer_norm_z(evoformer_output_dict["pair"]) + + # [*, N, C_s] + s_initial = s + s = self.linear_in(s) + + # [*, N] + rigids = Rigid.identity( + s.shape[:-1], + s.dtype, + s.device, + self.training, + fmt="quat", + ) + outputs = [] + for i in range(self.no_blocks): + # [*, N, C_s] + s = s + self.ipa( + s, + z, + rigids, + mask, + inplace_safe=inplace_safe, + ) + s = self.ipa_dropout(s) + s = self.layer_norm_ipa(s) + s = self.transition(s) + + # [*, N] + + # [*, N_res, 6] vector of translations and rotations + bb_update_output = self.bb_update(s) + + rigids = rigids.compose_q_update_vec(bb_update_output) + + + # To hew as closely as possible to AlphaFold, we convert our + # quaternion-based transformations to rotation-matrix ones + # here + backb_to_global = Rigid( + Rotation( + rot_mats=rigids.get_rots().get_rot_mats(), + quats=None + ), + rigids.get_trans(), + ) + + backb_to_global 
= backb_to_global.scale_translation( + self.trans_scale_factor + ) + + # [*, N, 7, 2] + unnormalized_angles, angles = self.angle_resnet(s, s_initial) + + all_frames_to_global = self.torsion_angles_to_frames( + backb_to_global, + angles, + aatype, + ) + + pred_xyz = self.frames_and_literature_positions_to_atom14_pos( + all_frames_to_global, + aatype, + ) + + scaled_rigids = rigids.scale_translation(self.trans_scale_factor) + + preds = { + "frames": scaled_rigids.to_tensor_7(), + "sidechain_frames": all_frames_to_global.to_tensor_4x4(), + "unnormalized_angles": unnormalized_angles, + "angles": angles, + "positions": pred_xyz, + "states": s, + } + + outputs.append(preds) + + rigids = rigids.stop_rot_gradient() + + del z + + outputs = dict_multimap(torch.stack, outputs) + outputs["single"] = s + + return outputs + + def _init_residue_constants(self, float_dtype, device): + if not hasattr(self, "default_frames"): + self.register_buffer( + "default_frames", + torch.tensor( + restype_rigid_group_default_frame, + dtype=float_dtype, + device=device, + requires_grad=False, + ), + persistent=False, + ) + if not hasattr(self, "group_idx"): + self.register_buffer( + "group_idx", + torch.tensor( + restype_atom14_to_rigid_group, + device=device, + requires_grad=False, + ), + persistent=False, + ) + if not hasattr(self, "atom_mask"): + self.register_buffer( + "atom_mask", + torch.tensor( + restype_atom14_mask, + dtype=float_dtype, + device=device, + requires_grad=False, + ), + persistent=False, + ) + if not hasattr(self, "lit_positions"): + self.register_buffer( + "lit_positions", + torch.tensor( + restype_atom14_rigid_group_positions, + dtype=float_dtype, + device=device, + requires_grad=False, + ), + persistent=False, + ) + + def torsion_angles_to_frames(self, r, alpha, f): + # Lazily initialize the residue constants on the correct device + self._init_residue_constants(alpha.dtype, alpha.device) + # Separated purely to make testing less annoying + return torsion_angles_to_frames(r, alpha, f, self.default_frames) + + def frames_and_literature_positions_to_atom14_pos( + self, r, f # [*, N, 8] # [*, N] + ): + # Lazily initialize the residue constants on the correct device + self._init_residue_constants(r.dtype, r.device) + return frames_and_literature_positions_to_atom14_pos( + r, + f, + self.default_frames, + self.group_idx, + self.atom_mask, + self.lit_positions, + ) diff --git a/dockformer/model/torchscript.py b/dockformer/model/torchscript.py new file mode 100644 index 0000000000000000000000000000000000000000..19fb30337745b11615ff0a439accc306bde60e3d --- /dev/null +++ b/dockformer/model/torchscript.py @@ -0,0 +1,171 @@ +# Copyright 2021 AlQuraishi Laboratory +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
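+"""Helpers for TorchScript-compiling known-scriptable submodules in
+place, with an optional tracing fallback for modules that fail to
+script."""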
+ +from typing import Optional, Sequence, Tuple + +import torch +import torch.nn as nn + +from dockformer.model.evoformer import ( + EvoformerBlock, + EvoformerStack, +) +from dockformer.model.single_attention import SingleRowAttentionWithPairBias +from dockformer.model.primitives import Attention, GlobalAttention + + +def script_preset_(model: torch.nn.Module): + """ + TorchScript a handful of low-level but frequently used submodule types + that are known to be scriptable. + + Args: + model: + A torch.nn.Module. It should contain at least some modules from + this repository, or this function won't do anything. + """ + script_submodules_( + model, + [ + nn.Dropout, + Attention, + GlobalAttention, + EvoformerBlock, + ], + attempt_trace=False, + batch_dims=None, + ) + + +def _get_module_device(module: torch.nn.Module) -> torch.device: + """ + Fetches the device of a module, assuming that all of the module's + parameters reside on a single device + + Args: + module: A torch.nn.Module + Returns: + The module's device + """ + return next(module.parameters()).device + + +def _trace_module(module, batch_dims=None): + if(batch_dims is None): + batch_dims = () + + # Stand-in values + n_seq = 10 + n_res = 10 + + device = _get_module_device(module) + + def msa(channel_dim): + return torch.rand( + (*batch_dims, n_seq, n_res, channel_dim), + device=device, + ) + + def pair(channel_dim): + return torch.rand( + (*batch_dims, n_res, n_res, channel_dim), + device=device, + ) + + if(isinstance(module, SingleRowAttentionWithPairBias)): + inputs = { + "forward": ( + msa(module.c_in), # m + pair(module.c_z), # z + torch.randint( + 0, 2, + (*batch_dims, n_seq, n_res) + ), # mask + ), + } + else: + raise TypeError( + f"tracing is not supported for modules of type {type(module)}" + ) + + return torch.jit.trace_module(module, inputs) + + +def _script_submodules_helper_( + model, + types, + attempt_trace, + to_trace, +): + for name, child in model.named_children(): + if(types is None or any(isinstance(child, t) for t in types)): + try: + scripted = torch.jit.script(child) + setattr(model, name, scripted) + continue + except (RuntimeError, torch.jit.frontend.NotSupportedError) as e: + if(attempt_trace): + to_trace.add(type(child)) + else: + raise e + + _script_submodules_helper_(child, types, attempt_trace, to_trace) + + +def _trace_submodules_( + model, + types, + batch_dims=None, +): + for name, child in model.named_children(): + if(any(isinstance(child, t) for t in types)): + traced = _trace_module(child, batch_dims=batch_dims) + setattr(model, name, traced) + else: + _trace_submodules_(child, types, batch_dims=batch_dims) + + +def script_submodules_( + model: nn.Module, + types: Optional[Sequence[type]] = None, + attempt_trace: Optional[bool] = True, + batch_dims: Optional[Tuple[int]] = None, +): + """ + Convert all submodules whose types match one of those in the input + list to recursively scripted equivalents in place. To script the entire + model, just call torch.jit.script on it directly. + + When types is None, all submodules are scripted. + + Args: + model: + A torch.nn.Module + types: + A list of types of submodules to script + attempt_trace: + Whether to attempt to trace specified modules if scripting + fails. Recall that tracing eliminates all conditional + logic---with great tracing comes the mild responsibility of + having to remember to ensure that the modules in question + perform the same computations no matter what. + """ + to_trace = set() + + # Aggressively script as much as possible first... 
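+    # (Submodules whose scripting raises are collected in `to_trace`
+    # when attempt_trace is set, to be handled by the tracing pass
+    # below.)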
+ _script_submodules_helper_(model, types, attempt_trace, to_trace) + + # ... and then trace stragglers. + if(attempt_trace and len(to_trace) > 0): + _trace_submodules_(model, to_trace, batch_dims=batch_dims) diff --git a/dockformer/model/triangular_attention.py b/dockformer/model/triangular_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..a4d4413adcbf2afc304754bf78c84f293b942e9d --- /dev/null +++ b/dockformer/model/triangular_attention.py @@ -0,0 +1,104 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partialmethod, partial +import math +from typing import Optional, List + +import torch +import torch.nn as nn + +from dockformer.model.primitives import Linear, LayerNorm, Attention +from dockformer.utils.tensor_utils import permute_final_dims + + +class TriangleAttention(nn.Module): + def __init__( + self, c_in, c_hidden, no_heads, starting=True, inf=1e9 + ): + """ + Args: + c_in: + Input channel dimension + c_hidden: + Overall hidden channel dimension (not per-head) + no_heads: + Number of attention heads + """ + super(TriangleAttention, self).__init__() + + self.c_in = c_in + self.c_hidden = c_hidden + self.no_heads = no_heads + self.starting = starting + self.inf = inf + + self.layer_norm = LayerNorm(self.c_in) + + self.linear = Linear(c_in, self.no_heads, bias=False, init="normal") + + self.mha = Attention( + self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads + ) + + def forward(self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + use_memory_efficient_kernel: bool = False, + use_lma: bool = False, + ) -> torch.Tensor: + """ + Args: + x: + [*, I, J, C_in] input tensor (e.g. 
the pair representation) + Returns: + [*, I, J, C_in] output tensor + """ + if mask is None: + # [*, I, J] + mask = x.new_ones( + x.shape[:-1], + ) + + if(not self.starting): + x = x.transpose(-2, -3) + mask = mask.transpose(-1, -2) + + # [*, I, J, C_in] + x = self.layer_norm(x) + + # [*, I, 1, 1, J] + mask_bias = (self.inf * (mask - 1))[..., :, None, None, :] + + # [*, H, I, J] + triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1)) + + # [*, 1, H, I, J] + triangle_bias = triangle_bias.unsqueeze(-4) + + biases = [mask_bias, triangle_bias] + + x = self.mha( + q_x=x, + kv_x=x, + biases=biases, + use_memory_efficient_kernel=use_memory_efficient_kernel, + use_lma=use_lma + ) + + if(not self.starting): + x = x.transpose(-2, -3) + + return x diff --git a/dockformer/model/triangular_multiplicative_update.py b/dockformer/model/triangular_multiplicative_update.py new file mode 100644 index 0000000000000000000000000000000000000000..d6d2c59594845ae5cc8599251172446469f1a08a --- /dev/null +++ b/dockformer/model/triangular_multiplicative_update.py @@ -0,0 +1,173 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partialmethod +from typing import Optional +from abc import ABC, abstractmethod + +import torch +import torch.nn as nn + +from dockformer.model.primitives import Linear, LayerNorm +from dockformer.utils.precision_utils import is_fp16_enabled +from dockformer.utils.tensor_utils import permute_final_dims + + +class BaseTriangleMultiplicativeUpdate(nn.Module, ABC): + """ + Implements Algorithms 11 and 12. + """ + @abstractmethod + def __init__(self, c_z, c_hidden, _outgoing): + """ + Args: + c_z: + Input channel dimension + c: + Hidden channel dimension + """ + super(BaseTriangleMultiplicativeUpdate, self).__init__() + self.c_z = c_z + self.c_hidden = c_hidden + self._outgoing = _outgoing + + self.linear_g = Linear(self.c_z, self.c_z, init="gating") + self.linear_z = Linear(self.c_hidden, self.c_z, init="final") + + self.layer_norm_in = LayerNorm(self.c_z) + self.layer_norm_out = LayerNorm(self.c_hidden) + + self.sigmoid = nn.Sigmoid() + + def _combine_projections(self, + a: torch.Tensor, + b: torch.Tensor, + ) -> torch.Tensor: + if(self._outgoing): + a = permute_final_dims(a, (2, 0, 1)) + b = permute_final_dims(b, (2, 1, 0)) + else: + a = permute_final_dims(a, (2, 1, 0)) + b = permute_final_dims(b, (2, 0, 1)) + + p = torch.matmul(a, b) + + return permute_final_dims(p, (1, 2, 0)) + + @abstractmethod + def forward(self, + z: torch.Tensor, + mask: Optional[torch.Tensor] = None, + inplace_safe: bool = False, + _add_with_inplace: bool = False + ) -> torch.Tensor: + """ + Args: + x: + [*, N_res, N_res, C_z] input tensor + mask: + [*, N_res, N_res] input mask + Returns: + [*, N_res, N_res, C_z] output tensor + """ + pass + + +class TriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate): + """ + Implements Algorithms 11 and 12. 
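+
+    A shape-level usage sketch (editor's example; toy dimensions):
+
+        >>> tmu = TriangleMultiplicationOutgoing(c_z=8, c_hidden=16)
+        >>> z = torch.randn(1, 16, 16, 8)
+        >>> tmu(z).shape
+        torch.Size([1, 16, 16, 8])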
+ """ + def __init__(self, c_z, c_hidden, _outgoing=True): + """ + Args: + c_z: + Input channel dimension + c: + Hidden channel dimension + """ + super(TriangleMultiplicativeUpdate, self).__init__(c_z=c_z, + c_hidden=c_hidden, + _outgoing=_outgoing) + + self.linear_a_p = Linear(self.c_z, self.c_hidden) + self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating") + self.linear_b_p = Linear(self.c_z, self.c_hidden) + self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating") + + def forward(self, + z: torch.Tensor, + mask: Optional[torch.Tensor] = None, + inplace_safe: bool = False, + _add_with_inplace: bool = False, + ) -> torch.Tensor: + """ + Args: + x: + [*, N_res, N_res, C_z] input tensor + mask: + [*, N_res, N_res] input mask + Returns: + [*, N_res, N_res, C_z] output tensor + """ + + if mask is None: + mask = z.new_ones(z.shape[:-1]) + + mask = mask.unsqueeze(-1) + + z = self.layer_norm_in(z) + a = mask + a = a * self.sigmoid(self.linear_a_g(z)) + a = a * self.linear_a_p(z) + b = mask + b = b * self.sigmoid(self.linear_b_g(z)) + b = b * self.linear_b_p(z) + + # Prevents overflow of torch.matmul in combine projections in + # reduced-precision modes + a_std = a.std() + b_std = b.std() + if(is_fp16_enabled() and a_std != 0. and b_std != 0.): + a = a / a.std() + b = b / b.std() + + if(is_fp16_enabled()): + with torch.cuda.amp.autocast(enabled=False): + x = self._combine_projections(a.float(), b.float()) + else: + x = self._combine_projections(a, b) + + del a, b + x = self.layer_norm_out(x) + x = self.linear_z(x) + g = self.sigmoid(self.linear_g(z)) + x = x * g + + return x + + +class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate): + """ + Implements Algorithm 11. + """ + __init__ = partialmethod(TriangleMultiplicativeUpdate.__init__, _outgoing=True) + + +class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate): + """ + Implements Algorithm 12. 
+ """ + __init__ = partialmethod(TriangleMultiplicativeUpdate.__init__, _outgoing=False) + diff --git a/dockformer/resources/__init__.py b/dockformer/resources/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/dockformer/resources/stereo_chemical_props.txt b/dockformer/resources/stereo_chemical_props.txt new file mode 100644 index 0000000000000000000000000000000000000000..25262efd7689fc76f4d6ea1a7c75e342e840ea1d --- /dev/null +++ b/dockformer/resources/stereo_chemical_props.txt @@ -0,0 +1,345 @@ +Bond Residue Mean StdDev +CA-CB ALA 1.520 0.021 +N-CA ALA 1.459 0.020 +CA-C ALA 1.525 0.026 +C-O ALA 1.229 0.019 +CA-CB ARG 1.535 0.022 +CB-CG ARG 1.521 0.027 +CG-CD ARG 1.515 0.025 +CD-NE ARG 1.460 0.017 +NE-CZ ARG 1.326 0.013 +CZ-NH1 ARG 1.326 0.013 +CZ-NH2 ARG 1.326 0.013 +N-CA ARG 1.459 0.020 +CA-C ARG 1.525 0.026 +C-O ARG 1.229 0.019 +CA-CB ASN 1.527 0.026 +CB-CG ASN 1.506 0.023 +CG-OD1 ASN 1.235 0.022 +CG-ND2 ASN 1.324 0.025 +N-CA ASN 1.459 0.020 +CA-C ASN 1.525 0.026 +C-O ASN 1.229 0.019 +CA-CB ASP 1.535 0.022 +CB-CG ASP 1.513 0.021 +CG-OD1 ASP 1.249 0.023 +CG-OD2 ASP 1.249 0.023 +N-CA ASP 1.459 0.020 +CA-C ASP 1.525 0.026 +C-O ASP 1.229 0.019 +CA-CB CYS 1.526 0.013 +CB-SG CYS 1.812 0.016 +N-CA CYS 1.459 0.020 +CA-C CYS 1.525 0.026 +C-O CYS 1.229 0.019 +CA-CB GLU 1.535 0.022 +CB-CG GLU 1.517 0.019 +CG-CD GLU 1.515 0.015 +CD-OE1 GLU 1.252 0.011 +CD-OE2 GLU 1.252 0.011 +N-CA GLU 1.459 0.020 +CA-C GLU 1.525 0.026 +C-O GLU 1.229 0.019 +CA-CB GLN 1.535 0.022 +CB-CG GLN 1.521 0.027 +CG-CD GLN 1.506 0.023 +CD-OE1 GLN 1.235 0.022 +CD-NE2 GLN 1.324 0.025 +N-CA GLN 1.459 0.020 +CA-C GLN 1.525 0.026 +C-O GLN 1.229 0.019 +N-CA GLY 1.456 0.015 +CA-C GLY 1.514 0.016 +C-O GLY 1.232 0.016 +CA-CB HIS 1.535 0.022 +CB-CG HIS 1.492 0.016 +CG-ND1 HIS 1.369 0.015 +CG-CD2 HIS 1.353 0.017 +ND1-CE1 HIS 1.343 0.025 +CD2-NE2 HIS 1.415 0.021 +CE1-NE2 HIS 1.322 0.023 +N-CA HIS 1.459 0.020 +CA-C HIS 1.525 0.026 +C-O HIS 1.229 0.019 +CA-CB ILE 1.544 0.023 +CB-CG1 ILE 1.536 0.028 +CB-CG2 ILE 1.524 0.031 +CG1-CD1 ILE 1.500 0.069 +N-CA ILE 1.459 0.020 +CA-C ILE 1.525 0.026 +C-O ILE 1.229 0.019 +CA-CB LEU 1.533 0.023 +CB-CG LEU 1.521 0.029 +CG-CD1 LEU 1.514 0.037 +CG-CD2 LEU 1.514 0.037 +N-CA LEU 1.459 0.020 +CA-C LEU 1.525 0.026 +C-O LEU 1.229 0.019 +CA-CB LYS 1.535 0.022 +CB-CG LYS 1.521 0.027 +CG-CD LYS 1.520 0.034 +CD-CE LYS 1.508 0.025 +CE-NZ LYS 1.486 0.025 +N-CA LYS 1.459 0.020 +CA-C LYS 1.525 0.026 +C-O LYS 1.229 0.019 +CA-CB MET 1.535 0.022 +CB-CG MET 1.509 0.032 +CG-SD MET 1.807 0.026 +SD-CE MET 1.774 0.056 +N-CA MET 1.459 0.020 +CA-C MET 1.525 0.026 +C-O MET 1.229 0.019 +CA-CB PHE 1.535 0.022 +CB-CG PHE 1.509 0.017 +CG-CD1 PHE 1.383 0.015 +CG-CD2 PHE 1.383 0.015 +CD1-CE1 PHE 1.388 0.020 +CD2-CE2 PHE 1.388 0.020 +CE1-CZ PHE 1.369 0.019 +CE2-CZ PHE 1.369 0.019 +N-CA PHE 1.459 0.020 +CA-C PHE 1.525 0.026 +C-O PHE 1.229 0.019 +CA-CB PRO 1.531 0.020 +CB-CG PRO 1.495 0.050 +CG-CD PRO 1.502 0.033 +CD-N PRO 1.474 0.014 +N-CA PRO 1.468 0.017 +CA-C PRO 1.524 0.020 +C-O PRO 1.228 0.020 +CA-CB SER 1.525 0.015 +CB-OG SER 1.418 0.013 +N-CA SER 1.459 0.020 +CA-C SER 1.525 0.026 +C-O SER 1.229 0.019 +CA-CB THR 1.529 0.026 +CB-OG1 THR 1.428 0.020 +CB-CG2 THR 1.519 0.033 +N-CA THR 1.459 0.020 +CA-C THR 1.525 0.026 +C-O THR 1.229 0.019 +CA-CB TRP 1.535 0.022 +CB-CG TRP 1.498 0.018 +CG-CD1 TRP 1.363 0.014 +CG-CD2 TRP 1.432 0.017 +CD1-NE1 TRP 1.375 0.017 +NE1-CE2 TRP 1.371 0.013 +CD2-CE2 TRP 1.409 0.012 +CD2-CE3 TRP 1.399 0.015 +CE2-CZ2 TRP 1.393 
0.017 +CE3-CZ3 TRP 1.380 0.017 +CZ2-CH2 TRP 1.369 0.019 +CZ3-CH2 TRP 1.396 0.016 +N-CA TRP 1.459 0.020 +CA-C TRP 1.525 0.026 +C-O TRP 1.229 0.019 +CA-CB TYR 1.535 0.022 +CB-CG TYR 1.512 0.015 +CG-CD1 TYR 1.387 0.013 +CG-CD2 TYR 1.387 0.013 +CD1-CE1 TYR 1.389 0.015 +CD2-CE2 TYR 1.389 0.015 +CE1-CZ TYR 1.381 0.013 +CE2-CZ TYR 1.381 0.013 +CZ-OH TYR 1.374 0.017 +N-CA TYR 1.459 0.020 +CA-C TYR 1.525 0.026 +C-O TYR 1.229 0.019 +CA-CB VAL 1.543 0.021 +CB-CG1 VAL 1.524 0.021 +CB-CG2 VAL 1.524 0.021 +N-CA VAL 1.459 0.020 +CA-C VAL 1.525 0.026 +C-O VAL 1.229 0.019 +- + +Angle Residue Mean StdDev +N-CA-CB ALA 110.1 1.4 +CB-CA-C ALA 110.1 1.5 +N-CA-C ALA 111.0 2.7 +CA-C-O ALA 120.1 2.1 +N-CA-CB ARG 110.6 1.8 +CB-CA-C ARG 110.4 2.0 +CA-CB-CG ARG 113.4 2.2 +CB-CG-CD ARG 111.6 2.6 +CG-CD-NE ARG 111.8 2.1 +CD-NE-CZ ARG 123.6 1.4 +NE-CZ-NH1 ARG 120.3 0.5 +NE-CZ-NH2 ARG 120.3 0.5 +NH1-CZ-NH2 ARG 119.4 1.1 +N-CA-C ARG 111.0 2.7 +CA-C-O ARG 120.1 2.1 +N-CA-CB ASN 110.6 1.8 +CB-CA-C ASN 110.4 2.0 +CA-CB-CG ASN 113.4 2.2 +CB-CG-ND2 ASN 116.7 2.4 +CB-CG-OD1 ASN 121.6 2.0 +ND2-CG-OD1 ASN 121.9 2.3 +N-CA-C ASN 111.0 2.7 +CA-C-O ASN 120.1 2.1 +N-CA-CB ASP 110.6 1.8 +CB-CA-C ASP 110.4 2.0 +CA-CB-CG ASP 113.4 2.2 +CB-CG-OD1 ASP 118.3 0.9 +CB-CG-OD2 ASP 118.3 0.9 +OD1-CG-OD2 ASP 123.3 1.9 +N-CA-C ASP 111.0 2.7 +CA-C-O ASP 120.1 2.1 +N-CA-CB CYS 110.8 1.5 +CB-CA-C CYS 111.5 1.2 +CA-CB-SG CYS 114.2 1.1 +N-CA-C CYS 111.0 2.7 +CA-C-O CYS 120.1 2.1 +N-CA-CB GLU 110.6 1.8 +CB-CA-C GLU 110.4 2.0 +CA-CB-CG GLU 113.4 2.2 +CB-CG-CD GLU 114.2 2.7 +CG-CD-OE1 GLU 118.3 2.0 +CG-CD-OE2 GLU 118.3 2.0 +OE1-CD-OE2 GLU 123.3 1.2 +N-CA-C GLU 111.0 2.7 +CA-C-O GLU 120.1 2.1 +N-CA-CB GLN 110.6 1.8 +CB-CA-C GLN 110.4 2.0 +CA-CB-CG GLN 113.4 2.2 +CB-CG-CD GLN 111.6 2.6 +CG-CD-OE1 GLN 121.6 2.0 +CG-CD-NE2 GLN 116.7 2.4 +OE1-CD-NE2 GLN 121.9 2.3 +N-CA-C GLN 111.0 2.7 +CA-C-O GLN 120.1 2.1 +N-CA-C GLY 113.1 2.5 +CA-C-O GLY 120.6 1.8 +N-CA-CB HIS 110.6 1.8 +CB-CA-C HIS 110.4 2.0 +CA-CB-CG HIS 113.6 1.7 +CB-CG-ND1 HIS 123.2 2.5 +CB-CG-CD2 HIS 130.8 3.1 +CG-ND1-CE1 HIS 108.2 1.4 +ND1-CE1-NE2 HIS 109.9 2.2 +CE1-NE2-CD2 HIS 106.6 2.5 +NE2-CD2-CG HIS 109.2 1.9 +CD2-CG-ND1 HIS 106.0 1.4 +N-CA-C HIS 111.0 2.7 +CA-C-O HIS 120.1 2.1 +N-CA-CB ILE 110.8 2.3 +CB-CA-C ILE 111.6 2.0 +CA-CB-CG1 ILE 111.0 1.9 +CB-CG1-CD1 ILE 113.9 2.8 +CA-CB-CG2 ILE 110.9 2.0 +CG1-CB-CG2 ILE 111.4 2.2 +N-CA-C ILE 111.0 2.7 +CA-C-O ILE 120.1 2.1 +N-CA-CB LEU 110.4 2.0 +CB-CA-C LEU 110.2 1.9 +CA-CB-CG LEU 115.3 2.3 +CB-CG-CD1 LEU 111.0 1.7 +CB-CG-CD2 LEU 111.0 1.7 +CD1-CG-CD2 LEU 110.5 3.0 +N-CA-C LEU 111.0 2.7 +CA-C-O LEU 120.1 2.1 +N-CA-CB LYS 110.6 1.8 +CB-CA-C LYS 110.4 2.0 +CA-CB-CG LYS 113.4 2.2 +CB-CG-CD LYS 111.6 2.6 +CG-CD-CE LYS 111.9 3.0 +CD-CE-NZ LYS 111.7 2.3 +N-CA-C LYS 111.0 2.7 +CA-C-O LYS 120.1 2.1 +N-CA-CB MET 110.6 1.8 +CB-CA-C MET 110.4 2.0 +CA-CB-CG MET 113.3 1.7 +CB-CG-SD MET 112.4 3.0 +CG-SD-CE MET 100.2 1.6 +N-CA-C MET 111.0 2.7 +CA-C-O MET 120.1 2.1 +N-CA-CB PHE 110.6 1.8 +CB-CA-C PHE 110.4 2.0 +CA-CB-CG PHE 113.9 2.4 +CB-CG-CD1 PHE 120.8 0.7 +CB-CG-CD2 PHE 120.8 0.7 +CD1-CG-CD2 PHE 118.3 1.3 +CG-CD1-CE1 PHE 120.8 1.1 +CG-CD2-CE2 PHE 120.8 1.1 +CD1-CE1-CZ PHE 120.1 1.2 +CD2-CE2-CZ PHE 120.1 1.2 +CE1-CZ-CE2 PHE 120.0 1.8 +N-CA-C PHE 111.0 2.7 +CA-C-O PHE 120.1 2.1 +N-CA-CB PRO 103.3 1.2 +CB-CA-C PRO 111.7 2.1 +CA-CB-CG PRO 104.8 1.9 +CB-CG-CD PRO 106.5 3.9 +CG-CD-N PRO 103.2 1.5 +CA-N-CD PRO 111.7 1.4 +N-CA-C PRO 112.1 2.6 +CA-C-O PRO 120.2 2.4 +N-CA-CB SER 110.5 1.5 +CB-CA-C SER 110.1 1.9 +CA-CB-OG SER 111.2 2.7 +N-CA-C SER 111.0 2.7 +CA-C-O SER 120.1 
2.1
+N-CA-CB THR 110.3 1.9
+CB-CA-C THR 111.6 2.7
+CA-CB-OG1 THR 109.0 2.1
+CA-CB-CG2 THR 112.4 1.4
+OG1-CB-CG2 THR 110.0 2.3
+N-CA-C THR 111.0 2.7
+CA-C-O THR 120.1 2.1
+N-CA-CB TRP 110.6 1.8
+CB-CA-C TRP 110.4 2.0
+CA-CB-CG TRP 113.7 1.9
+CB-CG-CD1 TRP 127.0 1.3
+CB-CG-CD2 TRP 126.6 1.3
+CD1-CG-CD2 TRP 106.3 0.8
+CG-CD1-NE1 TRP 110.1 1.0
+CD1-NE1-CE2 TRP 109.0 0.9
+NE1-CE2-CD2 TRP 107.3 1.0
+CE2-CD2-CG TRP 107.3 0.8
+CG-CD2-CE3 TRP 133.9 0.9
+NE1-CE2-CZ2 TRP 130.4 1.1
+CE3-CD2-CE2 TRP 118.7 1.2
+CD2-CE2-CZ2 TRP 122.3 1.2
+CE2-CZ2-CH2 TRP 117.4 1.0
+CZ2-CH2-CZ3 TRP 121.6 1.2
+CH2-CZ3-CE3 TRP 121.2 1.1
+CZ3-CE3-CD2 TRP 118.8 1.3
+N-CA-C TRP 111.0 2.7
+CA-C-O TRP 120.1 2.1
+N-CA-CB TYR 110.6 1.8
+CB-CA-C TYR 110.4 2.0
+CA-CB-CG TYR 113.4 1.9
+CB-CG-CD1 TYR 121.0 0.6
+CB-CG-CD2 TYR 121.0 0.6
+CD1-CG-CD2 TYR 117.9 1.1
+CG-CD1-CE1 TYR 121.3 0.8
+CG-CD2-CE2 TYR 121.3 0.8
+CD1-CE1-CZ TYR 119.8 0.9
+CD2-CE2-CZ TYR 119.8 0.9
+CE1-CZ-CE2 TYR 119.8 1.6
+CE1-CZ-OH TYR 120.1 2.7
+CE2-CZ-OH TYR 120.1 2.7
+N-CA-C TYR 111.0 2.7
+CA-C-O TYR 120.1 2.1
+N-CA-CB VAL 111.5 2.2
+CB-CA-C VAL 111.4 1.9
+CA-CB-CG1 VAL 110.9 1.5
+CA-CB-CG2 VAL 110.9 1.5
+CG1-CB-CG2 VAL 110.9 1.6
+N-CA-C VAL 111.0 2.7
+CA-C-O VAL 120.1 2.1
+-
+
+Non-bonded distance Minimum Dist Tolerance
+C-C 3.4 1.5
+C-N 3.25 1.5
+C-S 3.5 1.5
+C-O 3.22 1.5
+N-N 3.1 1.5
+N-S 3.35 1.5
+N-O 3.07 1.5
+O-S 3.32 1.5
+O-O 3.04 1.5
+S-S 2.03 1.0
+-
diff --git a/dockformer/utils/__init__.py b/dockformer/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/dockformer/utils/callbacks.py b/dockformer/utils/callbacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..1db267346dc8ae8ccdd3ed9526626e9b45f2c211
--- /dev/null
+++ b/dockformer/utils/callbacks.py
@@ -0,0 +1,15 @@
+from lightning.pytorch.callbacks import EarlyStopping
+from lightning_utilities.core.rank_zero import rank_zero_info
+
+
+class EarlyStoppingVerbose(EarlyStopping):
+    """
+    The default EarlyStopping callback's verbose mode is too verbose.
+    This class outputs a message only when it's getting ready to stop.
+    """
+    def _evaluate_stopping_criteria(self, *args, **kwargs):
+        should_stop, reason = super()._evaluate_stopping_criteria(*args, **kwargs)
+        if(should_stop):
+            rank_zero_info(f"{reason}\n")
+
+        return should_stop, reason
diff --git a/dockformer/utils/checkpointing.py b/dockformer/utils/checkpointing.py
new file mode 100644
index 0000000000000000000000000000000000000000..61886a24fd097835db9fe24ae01e85e100d03143
--- /dev/null
+++ b/dockformer/utils/checkpointing.py
@@ -0,0 +1,78 @@
+# Copyright 2021 AlQuraishi Laboratory
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
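# A minimal usage sketch for the EarlyStoppingVerbose callback above, assuming
# a standard Lightning training setup; the monitored metric name "val/loss"
# and the patience value are assumptions, not taken from this repository.
from lightning import Trainer
from dockformer.utils.callbacks import EarlyStoppingVerbose

early_stop = EarlyStoppingVerbose(monitor="val/loss", patience=10, mode="min")
trainer = Trainer(callbacks=[early_stop])  # logs a message only when stopping triggers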
+import importlib +from typing import Any, Tuple, List, Callable, Optional + + +import torch +import torch.utils.checkpoint + + +BLOCK_ARG = Any +BLOCK_ARGS = List[BLOCK_ARG] + + +@torch.jit.ignore +def checkpoint_blocks( + blocks: List[Callable], + args: BLOCK_ARGS, + blocks_per_ckpt: Optional[int], +) -> BLOCK_ARGS: + """ + Chunk a list of blocks and run each chunk with activation + checkpointing. We define a "block" as a callable whose only inputs are + the outputs of the previous block. + + Implements Subsection 1.11.8 + + Args: + blocks: + List of blocks + args: + Tuple of arguments for the first block. + blocks_per_ckpt: + Size of each chunk. A higher value corresponds to fewer + checkpoints, and trades memory for speed. If None, no checkpointing + is performed. + Returns: + The output of the final block + """ + def wrap(a): + return (a,) if type(a) is not tuple else a + + def exec(b, a): + for block in b: + a = wrap(block(*a)) + return a + + def chunker(s, e): + def exec_sliced(*a): + return exec(blocks[s:e], a) + + return exec_sliced + + # Avoids mishaps when the blocks take just one argument + args = wrap(args) + + if blocks_per_ckpt is None or not torch.is_grad_enabled(): + return exec(blocks, args) + elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks): + raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)") + + for s in range(0, len(blocks), blocks_per_ckpt): + e = s + blocks_per_ckpt + args = torch.utils.checkpoint.checkpoint(chunker(s, e), *args, use_reentrant=True) + args = wrap(args) + + return args diff --git a/dockformer/utils/config_tools.py b/dockformer/utils/config_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..fde076dd87f90bdab864e267d6c9f0d9fd97c874 --- /dev/null +++ b/dockformer/utils/config_tools.py @@ -0,0 +1,32 @@ +import importlib.util + +import ml_collections as mlc + + +def set_inf(c, inf): + for k, v in c.items(): + if isinstance(v, mlc.ConfigDict): + set_inf(v, inf) + elif k == "inf": + c[k] = inf + + +def enforce_config_constraints(config): + def string_to_setting(s): + path = s.split('.') + setting = config + for p in path: + setting = setting.get(p) + + return setting + + mutually_exclusive_bools = [ + ( + "globals.use_lma", + ), + ] + + for options in mutually_exclusive_bools: + option_settings = [string_to_setting(o) for o in options] + if sum(option_settings) > 1: + raise ValueError(f"Only one of {', '.join(options)} may be set at a time") diff --git a/dockformer/utils/consts.py b/dockformer/utils/consts.py new file mode 100644 index 0000000000000000000000000000000000000000..0c376f5196d12f81b9fe48f04d877a6cc691fcc3 --- /dev/null +++ b/dockformer/utils/consts.py @@ -0,0 +1,25 @@ +from rdkit.Chem.rdchem import ChiralType, BondType + +# Survey of atom types in the PDBBind +# {'C': 403253, 'O': 101283, 'N': 81325, 'S': 6262, 'F': 5256, 'P': 3378, 'Cl': 2920, 'Br': 552, 'B': 237, 'I': 185, +# 'H': 181, 'Fe': 19, 'Se': 15, 'Ru': 10, 'Si': 5, 'Co': 4, 'Ir': 4, 'As': 2, 'Pt': 2, 'V': 1, 'Mg': 1, 'Be': 1, +# 'Rh': 1, 'Cu': 1, 'Re': 1} +# I have changed the uncommon types to common ions for the plinder dataset +# {'As': "Zn", 'Pt': "Mn", 'V': "Ca", 'Mg': "Mg", 'Be': "Na", 'Rh': "Al", 'Cu': "K", 'Re': "Ni"} + +POSSIBLE_ATOM_TYPES = ['C', 'O', 'N', 'S', 'F', 'P', 'Cl', 'Br', 'B', 'I', 'H', 'Fe', 'Se', 'Ru', 'Si', 'Co', 'Ir', + 'Zn', 'Mn', 'Ca', 'Mg', 'Na', 'Al', 'K', 'Ni'] + +# bonds Counter({BondType.SINGLE: 366857, BondType.AROMATIC: 214238, BondType.DOUBLE: 59725, BondType.TRIPLE: 866, +# BondType.UNSPECIFIED: 
18, BondType.DATIVE: 8}) +POSSIBLE_BOND_TYPES = [BondType.SINGLE, BondType.DOUBLE, BondType.TRIPLE, BondType.AROMATIC, BondType.UNSPECIFIED, + BondType.DATIVE] + +# {0: 580061, 1: 13273, -1: 11473, 2: 44, 7: 17, -2: 8, 9: 7, 10: 7, 5: 3, 3: 3, 4: 1, 6: 1, 8: 1} +POSSIBLE_CHARGES = [-1, 0, 1] + +# {ChiralType.CHI_UNSPECIFIED: 551374, ChiralType.CHI_TETRAHEDRAL_CCW: 27328, ChiralType.CHI_TETRAHEDRAL_CW: 26178, +# ChiralType.CHI_OCTAHEDRAL: 13, ChiralType.CHI_SQUAREPLANAR: 3, ChiralType.CHI_TRIGONALBIPYRAMIDAL: 3} +POSSIBLE_CHIRALITIES = [ChiralType.CHI_UNSPECIFIED, ChiralType.CHI_TETRAHEDRAL_CCW, ChiralType.CHI_TETRAHEDRAL_CW, + ChiralType.CHI_OCTAHEDRAL, ChiralType.CHI_SQUAREPLANAR, ChiralType.CHI_TRIGONALBIPYRAMIDAL] + diff --git a/dockformer/utils/exponential_moving_average.py b/dockformer/utils/exponential_moving_average.py new file mode 100644 index 0000000000000000000000000000000000000000..5b7c2ec3d022e7a82181c32d475d0dd22df3c9df --- /dev/null +++ b/dockformer/utils/exponential_moving_average.py @@ -0,0 +1,71 @@ +from collections import OrderedDict +import copy +import torch +import torch.nn as nn + +from dockformer.utils.tensor_utils import tensor_tree_map + + +class ExponentialMovingAverage: + """ + Maintains moving averages of parameters with exponential decay + + At each step, the stored copy `copy` of each parameter `param` is + updated as follows: + + `copy = decay * copy + (1 - decay) * param` + + where `decay` is an attribute of the ExponentialMovingAverage object. + """ + + def __init__(self, model: nn.Module, decay: float): + """ + Args: + model: + A torch.nn.Module whose parameters are to be tracked + decay: + A value (usually close to 1.) by which updates are + weighted as part of the above formula + """ + super(ExponentialMovingAverage, self).__init__() + + clone_param = lambda t: t.clone().detach() + self.params = tensor_tree_map(clone_param, model.state_dict()) + self.decay = decay + self.device = next(model.parameters()).device + + def to(self, device): + self.params = tensor_tree_map(lambda t: t.to(device), self.params) + self.device = device + + def _update_state_dict_(self, update, state_dict): + with torch.no_grad(): + for k, v in update.items(): + stored = state_dict[k] + if not isinstance(v, torch.Tensor): + self._update_state_dict_(v, stored) + else: + diff = stored - v + diff *= 1 - self.decay + stored -= diff + + def update(self, model: torch.nn.Module) -> None: + """ + Updates the stored parameters using the state dict of the provided + module. The module should have the same structure as that used to + initialize the ExponentialMovingAverage object. + """ + self._update_state_dict_(model.state_dict(), self.params) + + def load_state_dict(self, state_dict: OrderedDict) -> None: + for k in state_dict["params"].keys(): + self.params[k] = state_dict["params"][k].clone() + self.decay = state_dict["decay"] + + def state_dict(self) -> OrderedDict: + return OrderedDict( + { + "params": self.params, + "decay": self.decay, + } + ) diff --git a/dockformer/utils/feats.py b/dockformer/utils/feats.py new file mode 100644 index 0000000000000000000000000000000000000000..790b211a0dc7a5fecac74af35397011eacf99ec2 --- /dev/null +++ b/dockformer/utils/feats.py @@ -0,0 +1,174 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+from typing import Dict, Union
+
+from dockformer.utils import protein
+import dockformer.utils.residue_constants as rc
+from dockformer.utils.geometry import rigid_matrix_vector, rotation_matrix, vector
+from dockformer.utils.rigid_utils import Rotation, Rigid
+from dockformer.utils.tensor_utils import (
+    batched_gather,
+    one_hot,
+    tree_map,
+    tensor_tree_map,
+)
+
+
+def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
+    # rc.restype_order["Z"] marks a ligand token; ligands, like glycine, have
+    # no CB atom, so the CA position is used for them instead.
+    is_gly_or_lig = (aatype == rc.restype_order["G"]) | (aatype == rc.restype_order["Z"])
+    ca_idx = rc.atom_order["CA"]
+    cb_idx = rc.atom_order["CB"]
+    pseudo_beta = torch.where(
+        is_gly_or_lig[..., None].expand(*((-1,) * len(is_gly_or_lig.shape)), 3),
+        all_atom_positions[..., ca_idx, :],
+        all_atom_positions[..., cb_idx, :],
+    )
+
+    if all_atom_masks is not None:
+        pseudo_beta_mask = torch.where(
+            is_gly_or_lig,
+            all_atom_masks[..., ca_idx],
+            all_atom_masks[..., cb_idx],
+        )
+        return pseudo_beta, pseudo_beta_mask
+    else:
+        return pseudo_beta
+
+
+def atom14_to_atom37(atom14, batch):
+    atom37_data = batched_gather(
+        atom14,
+        batch["residx_atom37_to_atom14"],
+        dim=-2,
+        no_batch_dims=len(atom14.shape[:-2]),
+    )
+
+    atom37_data = atom37_data * batch["atom37_atom_exists"][..., None]
+
+    return atom37_data
+
+
+def torsion_angles_to_frames(
+    r: Union[Rigid, rigid_matrix_vector.Rigid3Array],
+    alpha: torch.Tensor,
+    aatype: torch.Tensor,
+    rrgdf: torch.Tensor,
+):
+
+    rigid_type = type(r)
+
+    # [*, N, 8, 4, 4]
+    default_4x4 = rrgdf[aatype, ...]
+
+    # [*, N, 8] transformations, i.e.
+    # One [*, N, 8, 3, 3] rotation matrix and
+    # One [*, N, 8, 3] translation vector
+    default_r = rigid_type.from_tensor_4x4(default_4x4)
+
+    bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
+    bb_rot[..., 1] = 1
+
+    # [*, N, 8, 2]
+    alpha = torch.cat(
+        [bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2
+    )
+
+    # [*, N, 8, 3, 3]
+    # Produces rotation matrices of the form:
+    # [
+    #   [1, 0  , 0  ],
+    #   [0, a_2, -a_1],
+    #   [0, a_1, a_2 ]
+    # ]
+    # This follows the original code rather than the supplement, which uses
+    # different indices.
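+    # (Annotation: with alpha packed as [..., sin(phi), cos(phi)] = [a_1, a_2],
+    # the assignments below fill a 4x4 homogeneous transform whose rotation
+    # block is the x-axis rotation above and whose translation stays zero, so
+    # each torsion contributes a pure rotation about the preceding bond axis.)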
+ + all_rots = alpha.new_zeros(default_r.shape + (4, 4)) + all_rots[..., 0, 0] = 1 + all_rots[..., 1, 1] = alpha[..., 1] + all_rots[..., 1, 2] = -alpha[..., 0] + all_rots[..., 2, 1:3] = alpha + + all_rots = rigid_type.from_tensor_4x4(all_rots) + all_frames = default_r.compose(all_rots) + + chi2_frame_to_frame = all_frames[..., 5] + chi3_frame_to_frame = all_frames[..., 6] + chi4_frame_to_frame = all_frames[..., 7] + + chi1_frame_to_bb = all_frames[..., 4] + chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame) + chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame) + chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame) + + all_frames_to_bb = rigid_type.cat( + [ + all_frames[..., :5], + chi2_frame_to_bb.unsqueeze(-1), + chi3_frame_to_bb.unsqueeze(-1), + chi4_frame_to_bb.unsqueeze(-1), + ], + dim=-1, + ) + + all_frames_to_global = r[..., None].compose(all_frames_to_bb) + + return all_frames_to_global + + +def frames_and_literature_positions_to_atom14_pos( + r: Union[Rigid, rigid_matrix_vector.Rigid3Array], + aatype: torch.Tensor, + default_frames, + group_idx, + atom_mask, + lit_positions, +): + # [*, N, 14, 4, 4] + default_4x4 = default_frames[aatype, ...] + + # [*, N, 14] + group_mask = group_idx[aatype, ...] + + # [*, N, 14, 8] + group_mask = nn.functional.one_hot( + group_mask, + num_classes=default_frames.shape[-3], + ) + + # [*, N, 14, 8] + t_atoms_to_global = r[..., None, :] * group_mask + + # [*, N, 14] + t_atoms_to_global = t_atoms_to_global.map_tensor_fn( + lambda x: torch.sum(x, dim=-1) + ) + + # [*, N, 14] + atom_mask = atom_mask[aatype, ...].unsqueeze(-1) + + # [*, N, 14, 3] + lit_positions = lit_positions[aatype, ...] + pred_positions = t_atoms_to_global.apply(lit_positions) + pred_positions = pred_positions * atom_mask + + return pred_positions diff --git a/dockformer/utils/geometry/__init__.py b/dockformer/utils/geometry/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3064f53e04ed5b4c25604ee97a13437e7372f413 --- /dev/null +++ b/dockformer/utils/geometry/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
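# A hedged shape walk-through of the two frame helpers above, assuming N
# tokens and arbitrary leading batch dimensions; the variable names are
# illustrative only:
#
#   r     : Rigid with shape [*, N]   -- backbone frames
#   alpha : [*, N, 7, 2]              -- (sin, cos) of the seven torsion angles
#
#   frames = torsion_angles_to_frames(r, alpha, aatype, rrgdf)
#   # frames : Rigid [*, N, 8] -- the backbone frame plus seven torsion frames
#
#   xyz14 = frames_and_literature_positions_to_atom14_pos(
#       frames, aatype, default_frames, group_idx, atom_mask, lit_positions
#   )
#   # xyz14 : [*, N, 14, 3] -- atom14 coordinates in the global frame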
+"""Geometry Module.""" + +from dockformer.utils.geometry import rigid_matrix_vector +from dockformer.utils.geometry import rotation_matrix +from dockformer.utils.geometry import vector + +Rot3Array = rotation_matrix.Rot3Array +Rigid3Array = rigid_matrix_vector.Rigid3Array + +Vec3Array = vector.Vec3Array +square_euclidean_distance = vector.square_euclidean_distance +euclidean_distance = vector.euclidean_distance +dihedral_angle = vector.dihedral_angle +dot = vector.dot +cross = vector.cross diff --git a/dockformer/utils/geometry/quat_rigid.py b/dockformer/utils/geometry/quat_rigid.py new file mode 100644 index 0000000000000000000000000000000000000000..e771fd70aa032a91dca6d9c3363b4442647c260e --- /dev/null +++ b/dockformer/utils/geometry/quat_rigid.py @@ -0,0 +1,38 @@ +import torch +import torch.nn as nn + +from dockformer.model.primitives import Linear +from dockformer.utils.geometry.rigid_matrix_vector import Rigid3Array +from dockformer.utils.geometry.rotation_matrix import Rot3Array +from dockformer.utils.geometry.vector import Vec3Array + + +class QuatRigid(nn.Module): + def __init__(self, c_hidden, full_quat): + super().__init__() + self.full_quat = full_quat + if self.full_quat: + rigid_dim = 7 + else: + rigid_dim = 6 + + self.linear = Linear(c_hidden, rigid_dim, init="final", precision=torch.float32) + + def forward(self, activations: torch.Tensor) -> Rigid3Array: + # NOTE: During training, this needs to be run in higher precision + rigid_flat = self.linear(activations) + + rigid_flat = torch.unbind(rigid_flat, dim=-1) + if(self.full_quat): + qw, qx, qy, qz = rigid_flat[:4] + translation = rigid_flat[4:] + else: + qx, qy, qz = rigid_flat[:3] + qw = torch.ones_like(qx) + translation = rigid_flat[3:] + + rotation = Rot3Array.from_quaternion( + qw, qx, qy, qz, normalize=True, + ) + translation = Vec3Array(*translation) + return Rigid3Array(rotation, translation) diff --git a/dockformer/utils/geometry/rigid_matrix_vector.py b/dockformer/utils/geometry/rigid_matrix_vector.py new file mode 100644 index 0000000000000000000000000000000000000000..aded26976d780c0ff4421402e44df57a50c775c5 --- /dev/null +++ b/dockformer/utils/geometry/rigid_matrix_vector.py @@ -0,0 +1,181 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Rigid3Array Transformations represented by a Matrix and a Vector.""" + +from __future__ import annotations +import dataclasses +from typing import Union, List + +import torch + +from dockformer.utils.geometry import rotation_matrix +from dockformer.utils.geometry import vector + + +Float = Union[float, torch.Tensor] + + +@dataclasses.dataclass(frozen=True) +class Rigid3Array: + """Rigid Transformation, i.e. 
element of special euclidean group.""" + + rotation: rotation_matrix.Rot3Array + translation: vector.Vec3Array + + def __matmul__(self, other: Rigid3Array) -> Rigid3Array: + new_rotation = self.rotation @ other.rotation # __matmul__ + new_translation = self.apply_to_point(other.translation) + return Rigid3Array(new_rotation, new_translation) + + def __getitem__(self, index) -> Rigid3Array: + return Rigid3Array( + self.rotation[index], + self.translation[index], + ) + + def __mul__(self, other: torch.Tensor) -> Rigid3Array: + return Rigid3Array( + self.rotation * other, + self.translation * other, + ) + + def map_tensor_fn(self, fn) -> Rigid3Array: + return Rigid3Array( + self.rotation.map_tensor_fn(fn), + self.translation.map_tensor_fn(fn), + ) + + def inverse(self) -> Rigid3Array: + """Return Rigid3Array corresponding to inverse transform.""" + inv_rotation = self.rotation.inverse() + inv_translation = inv_rotation.apply_to_point(-self.translation) + return Rigid3Array(inv_rotation, inv_translation) + + def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Apply Rigid3Array transform to point.""" + return self.rotation.apply_to_point(point) + self.translation + + def apply(self, point: torch.Tensor) -> torch.Tensor: + return self.apply_to_point(vector.Vec3Array.from_array(point)).to_tensor() + + def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Apply inverse Rigid3Array transform to point.""" + new_point = point - self.translation + return self.rotation.apply_inverse_to_point(new_point) + + def invert_apply(self, point: torch.Tensor) -> torch.Tensor: + return self.apply_inverse_to_point(vector.Vec3Array.from_array(point)).to_tensor() + + def compose_rotation(self, other_rotation): + rot = self.rotation @ other_rotation + return Rigid3Array(rot, self.translation.clone()) + + def compose(self, other_rigid): + return self @ other_rigid + + def unsqueeze(self, dim: int): + return Rigid3Array( + self.rotation.unsqueeze(dim), + self.translation.unsqueeze(dim), + ) + + @property + def shape(self) -> torch.Size: + return self.rotation.xx.shape + + @property + def dtype(self) -> torch.dtype: + return self.rotation.xx.dtype + + @property + def device(self) -> torch.device: + return self.rotation.xx.device + + @classmethod + def identity(cls, shape, device) -> Rigid3Array: + """Return identity Rigid3Array of given shape.""" + return cls( + rotation_matrix.Rot3Array.identity(shape, device), + vector.Vec3Array.zeros(shape, device) + ) + + @classmethod + def cat(cls, rigids: List[Rigid3Array], dim: int) -> Rigid3Array: + return cls( + rotation_matrix.Rot3Array.cat( + [r.rotation for r in rigids], dim=dim + ), + vector.Vec3Array.cat( + [r.translation for r in rigids], dim=dim + ), + ) + + def scale_translation(self, factor: Float) -> Rigid3Array: + """Scale translation in Rigid3Array by 'factor'.""" + return Rigid3Array(self.rotation, self.translation * factor) + + def to_tensor(self) -> torch.Tensor: + rot_array = self.rotation.to_tensor() + vec_array = self.translation.to_tensor() + array = torch.zeros( + rot_array.shape[:-2] + (4, 4), + device=rot_array.device, + dtype=rot_array.dtype + ) + array[..., :3, :3] = rot_array + array[..., :3, 3] = vec_array + array[..., 3, 3] = 1. 
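+        # (Annotation: the packed matrix uses the standard homogeneous layout
+        # [[R, t], [0, 1]], with R the 3x3 rotation and t the translation,
+        # which is exactly what from_array / from_tensor_4x4 below unpack.)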
+ return array + + def to_tensor_4x4(self) -> torch.Tensor: + return self.to_tensor() + + def reshape(self, new_shape) -> Rigid3Array: + rots = self.rotation.reshape(new_shape) + trans = self.translation.reshape(new_shape) + return Rigid3Array(rots, trans) + + def stop_rot_gradient(self) -> Rigid3Array: + return Rigid3Array( + self.rotation.stop_gradient(), + self.translation, + ) + + @classmethod + def from_array(cls, array): + rot = rotation_matrix.Rot3Array.from_array( + array[..., :3, :3], + ) + vec = vector.Vec3Array.from_array(array[..., :3, 3]) + return cls(rot, vec) + + @classmethod + def from_tensor_4x4(cls, array): + return cls.from_array(array) + + @classmethod + def from_array4x4(cls, array: torch.tensor) -> Rigid3Array: + """Construct Rigid3Array from homogeneous 4x4 array.""" + rotation = rotation_matrix.Rot3Array( + array[..., 0, 0], array[..., 0, 1], array[..., 0, 2], + array[..., 1, 0], array[..., 1, 1], array[..., 1, 2], + array[..., 2, 0], array[..., 2, 1], array[..., 2, 2] + ) + translation = vector.Vec3Array( + array[..., 0, 3], array[..., 1, 3], array[..., 2, 3] + ) + return cls(rotation, translation) + + def cuda(self) -> Rigid3Array: + return Rigid3Array.from_tensor_4x4(self.to_tensor_4x4().cuda()) diff --git a/dockformer/utils/geometry/rotation_matrix.py b/dockformer/utils/geometry/rotation_matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..76d1aa2700eec90d3512c18628b2e1b917f13975 --- /dev/null +++ b/dockformer/utils/geometry/rotation_matrix.py @@ -0,0 +1,208 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
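# A minimal sketch, assuming the Rigid3Array API above, of building identity
# frames and checking that composing a transform with its inverse is a no-op;
# illustrative only, not taken from the repository's tests.
import torch
from dockformer.utils.geometry.rigid_matrix_vector import Rigid3Array
from dockformer.utils.geometry.vector import Vec3Array

frames = Rigid3Array.identity((5,), device="cpu")   # five identity frames
points = Vec3Array.zeros((5,), device="cpu")
moved = frames.apply_to_point(points)               # identity keeps the origin
round_trip = frames.inverse() @ frames              # composes back to identity
assert torch.allclose(round_trip.to_tensor_4x4(), frames.to_tensor_4x4())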
+"""Rot3Array Matrix Class.""" + +from __future__ import annotations +import dataclasses +from typing import List + +import torch + +from dockformer.utils.geometry import utils +from dockformer.utils.geometry import vector +from dockformer.utils.tensor_utils import tensor_tree_map + + +COMPONENTS = ['xx', 'xy', 'xz', 'yx', 'yy', 'yz', 'zx', 'zy', 'zz'] + +@dataclasses.dataclass(frozen=True) +class Rot3Array: + """Rot3Array Matrix in 3 dimensional Space implemented as struct of arrays.""" + xx: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32}) + xy: torch.Tensor + xz: torch.Tensor + yx: torch.Tensor + yy: torch.Tensor + yz: torch.Tensor + zx: torch.Tensor + zy: torch.Tensor + zz: torch.Tensor + + __array_ufunc__ = None + + def __getitem__(self, index): + field_names = utils.get_field_names(Rot3Array) + return Rot3Array( + **{ + name: getattr(self, name)[index] + for name in field_names + } + ) + + def __mul__(self, other: torch.Tensor): + field_names = utils.get_field_names(Rot3Array) + return Rot3Array( + **{ + name: getattr(self, name) * other + for name in field_names + } + ) + + def __matmul__(self, other: Rot3Array) -> Rot3Array: + """Composes two Rot3Arrays.""" + c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx)) + c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy)) + c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz)) + return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z) + + def map_tensor_fn(self, fn) -> Rot3Array: + field_names = utils.get_field_names(Rot3Array) + return Rot3Array( + **{ + name: fn(getattr(self, name)) + for name in field_names + } + ) + + def inverse(self) -> Rot3Array: + """Returns inverse of Rot3Array.""" + return Rot3Array( + self.xx, self.yx, self.zx, + self.xy, self.yy, self.zy, + self.xz, self.yz, self.zz + ) + + def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Applies Rot3Array to point.""" + return vector.Vec3Array( + self.xx * point.x + self.xy * point.y + self.xz * point.z, + self.yx * point.x + self.yy * point.y + self.yz * point.z, + self.zx * point.x + self.zy * point.y + self.zz * point.z + ) + + def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Applies inverse Rot3Array to point.""" + return self.inverse().apply_to_point(point) + + + def unsqueeze(self, dim: int): + return Rot3Array( + *tensor_tree_map( + lambda t: t.unsqueeze(dim), + [getattr(self, c) for c in COMPONENTS] + ) + ) + + def stop_gradient(self) -> Rot3Array: + return Rot3Array( + *[getattr(self, c).detach() for c in COMPONENTS] + ) + + @classmethod + def identity(cls, shape, device) -> Rot3Array: + """Returns identity of given shape.""" + ones = torch.ones(shape, dtype=torch.float32, device=device) + zeros = torch.zeros(shape, dtype=torch.float32, device=device) + return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones) + + @classmethod + def from_two_vectors( + cls, e0: vector.Vec3Array, + e1: vector.Vec3Array + ) -> Rot3Array: + """Construct Rot3Array from two Vectors. + + Rot3Array is constructed such that in the corresponding frame 'e0' lies on + the positive x-Axis and 'e1' lies in the xy plane with positive sign of y. + + Args: + e0: Vector + e1: Vector + Returns: + Rot3Array + """ + # Normalize the unit vector for the x-axis, e0. + e0 = e0.normalized() + # make e1 perpendicular to e0. + c = e1.dot(e0) + e1 = (e1 - c * e0).normalized() + # Compute e2 as cross product of e0 and e1. 
+ e2 = e0.cross(e1) + return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z) + + @classmethod + def from_array(cls, array: torch.Tensor) -> Rot3Array: + """Construct Rot3Array Matrix from array of shape. [..., 3, 3].""" + rows = torch.unbind(array, dim=-2) + rc = [torch.unbind(e, dim=-1) for e in rows] + return cls(*[e for row in rc for e in row]) + + def to_tensor(self) -> torch.Tensor: + """Convert Rot3Array to array of shape [..., 3, 3].""" + return torch.stack( + [ + torch.stack([self.xx, self.xy, self.xz], dim=-1), + torch.stack([self.yx, self.yy, self.yz], dim=-1), + torch.stack([self.zx, self.zy, self.zz], dim=-1) + ], + dim=-2 + ) + + @classmethod + def from_quaternion(cls, + w: torch.Tensor, + x: torch.Tensor, + y: torch.Tensor, + z: torch.Tensor, + normalize: bool = True, + eps: float = 1e-6 + ) -> Rot3Array: + """Construct Rot3Array from components of quaternion.""" + if normalize: + inv_norm = torch.rsqrt(torch.clamp(w**2 + x**2 + y**2 + z**2, min=eps)) + w = w * inv_norm + x = x * inv_norm + y = y * inv_norm + z = z * inv_norm + xx = 1.0 - 2.0 * (y ** 2 + z ** 2) + xy = 2.0 * (x * y - w * z) + xz = 2.0 * (x * z + w * y) + yx = 2.0 * (x * y + w * z) + yy = 1.0 - 2.0 * (x ** 2 + z ** 2) + yz = 2.0 * (y * z - w * x) + zx = 2.0 * (x * z - w * y) + zy = 2.0 * (y * z + w * x) + zz = 1.0 - 2.0 * (x ** 2 + y ** 2) + return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz) + + def reshape(self, new_shape): + field_names = utils.get_field_names(Rot3Array) + reshape_fn = lambda t: t.reshape(new_shape) + return Rot3Array( + **{ + name: reshape_fn(getattr(self, name)) + for name in field_names + } + ) + + @classmethod + def cat(cls, rots: List[Rot3Array], dim: int) -> Rot3Array: + field_names = utils.get_field_names(Rot3Array) + cat_fn = lambda l: torch.cat(l, dim=dim) + return cls( + **{ + name: cat_fn([getattr(r, name) for r in rots]) + for name in field_names + } + ) diff --git a/dockformer/utils/geometry/test_utils.py b/dockformer/utils/geometry/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c6d61cb70278afd3eb42a73ed1d34b6b7c768d80 --- /dev/null +++ b/dockformer/utils/geometry/test_utils.py @@ -0,0 +1,97 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
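# A small sanity check, assuming the Rot3Array API above: the identity
# quaternion (w, x, y, z) = (1, 0, 0, 0) should produce the identity matrix;
# illustrative only.
import torch
from dockformer.utils.geometry.rotation_matrix import Rot3Array

one, zero = torch.ones(()), torch.zeros(())
rot = Rot3Array.from_quaternion(one, zero, zero, zero, normalize=True)
assert torch.allclose(rot.to_tensor(), torch.eye(3))
assert torch.allclose((rot @ rot.inverse()).to_tensor(), torch.eye(3))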
+"""Shared utils for tests.""" + +import dataclasses +import torch + +from dockformer.utils.geometry import rigid_matrix_vector +from dockformer.utils.geometry import rotation_matrix +from dockformer.utils.geometry import vector + + +def assert_rotation_matrix_equal(matrix1: rotation_matrix.Rot3Array, + matrix2: rotation_matrix.Rot3Array): + for field in dataclasses.fields(rotation_matrix.Rot3Array): + field = field.name + assert torch.equal( + getattr(matrix1, field), getattr(matrix2, field)) + + +def assert_rotation_matrix_close(mat1: rotation_matrix.Rot3Array, + mat2: rotation_matrix.Rot3Array): + assert torch.allclose(mat1.to_tensor(), mat2.to_tensor(), atol=1e-6) + + +def assert_array_equal_to_rotation_matrix(array: torch.Tensor, + matrix: rotation_matrix.Rot3Array): + """Check that array and Matrix match.""" + assert torch.equal(matrix.xx, array[..., 0, 0]) + assert torch.equal(matrix.xy, array[..., 0, 1]) + assert torch.equal(matrix.xz, array[..., 0, 2]) + assert torch.equal(matrix.yx, array[..., 1, 0]) + assert torch.equal(matrix.yy, array[..., 1, 1]) + assert torch.equal(matrix.yz, array[..., 1, 2]) + assert torch.equal(matrix.zx, array[..., 2, 0]) + assert torch.equal(matrix.zy, array[..., 2, 1]) + assert torch.equal(matrix.zz, array[..., 2, 2]) + + +def assert_array_close_to_rotation_matrix(array: torch.Tensor, + matrix: rotation_matrix.Rot3Array): + assert torch.allclose(matrix.to_tensor(), array, atol=1e-6) + + +def assert_vectors_equal(vec1: vector.Vec3Array, vec2: vector.Vec3Array): + assert torch.equal(vec1.x, vec2.x) + assert torch.equal(vec1.y, vec2.y) + assert torch.equal(vec1.z, vec2.z) + + +def assert_vectors_close(vec1: vector.Vec3Array, vec2: vector.Vec3Array): + assert torch.allclose(vec1.x, vec2.x, atol=1e-6, rtol=0.) + assert torch.allclose(vec1.y, vec2.y, atol=1e-6, rtol=0.) + assert torch.allclose(vec1.z, vec2.z, atol=1e-6, rtol=0.) + + +def assert_array_close_to_vector(array: torch.Tensor, vec: vector.Vec3Array): + assert torch.allclose(vec.to_tensor(), array, atol=1e-6, rtol=0.) + + +def assert_array_equal_to_vector(array: torch.Tensor, vec: vector.Vec3Array): + assert torch.equal(vec.to_tensor(), array) + + +def assert_rigid_equal_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array, + rigid2: rigid_matrix_vector.Rigid3Array): + assert_rot_trans_equal_to_rigid(rigid1.rotation, rigid1.translation, rigid2) + + +def assert_rigid_close_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array, + rigid2: rigid_matrix_vector.Rigid3Array): + assert_rot_trans_close_to_rigid(rigid1.rotation, rigid1.translation, rigid2) + + +def assert_rot_trans_equal_to_rigid(rot: rotation_matrix.Rot3Array, + trans: vector.Vec3Array, + rigid: rigid_matrix_vector.Rigid3Array): + assert_rotation_matrix_equal(rot, rigid.rotation) + assert_vectors_equal(trans, rigid.translation) + + +def assert_rot_trans_close_to_rigid(rot: rotation_matrix.Rot3Array, + trans: vector.Vec3Array, + rigid: rigid_matrix_vector.Rigid3Array): + assert_rotation_matrix_close(rot, rigid.rotation) + assert_vectors_close(trans, rigid.translation) diff --git a/dockformer/utils/geometry/utils.py b/dockformer/utils/geometry/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6c4d52ba9969b50fa2945720b30effb8d35cc683 --- /dev/null +++ b/dockformer/utils/geometry/utils.py @@ -0,0 +1,22 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utils for geometry library.""" + +import dataclasses + + +def get_field_names(cls): + fields = dataclasses.fields(cls) + field_names = [f.name for f in fields] + return field_names diff --git a/dockformer/utils/geometry/vector.py b/dockformer/utils/geometry/vector.py new file mode 100644 index 0000000000000000000000000000000000000000..c25fad8fdaf388110cb5e40f7f3e3082ac554884 --- /dev/null +++ b/dockformer/utils/geometry/vector.py @@ -0,0 +1,261 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Vec3Array Class.""" + +from __future__ import annotations +import dataclasses +from typing import Union, List + +import torch + +Float = Union[float, torch.Tensor] + +@dataclasses.dataclass(frozen=True) +class Vec3Array: + x: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32}) + y: torch.Tensor + z: torch.Tensor + + def __post_init__(self): + if hasattr(self.x, 'dtype'): + assert self.x.dtype == self.y.dtype + assert self.x.dtype == self.z.dtype + assert all([x == y for x, y in zip(self.x.shape, self.y.shape)]) + assert all([x == z for x, z in zip(self.x.shape, self.z.shape)]) + + def __add__(self, other: Vec3Array) -> Vec3Array: + return Vec3Array( + self.x + other.x, + self.y + other.y, + self.z + other.z, + ) + + def __sub__(self, other: Vec3Array) -> Vec3Array: + return Vec3Array( + self.x - other.x, + self.y - other.y, + self.z - other.z, + ) + + def __mul__(self, other: Float) -> Vec3Array: + return Vec3Array( + self.x * other, + self.y * other, + self.z * other, + ) + + def __rmul__(self, other: Float) -> Vec3Array: + return self * other + + def __truediv__(self, other: Float) -> Vec3Array: + return Vec3Array( + self.x / other, + self.y / other, + self.z / other, + ) + + def __neg__(self) -> Vec3Array: + return self * -1 + + def __pos__(self) -> Vec3Array: + return self * 1 + + def __getitem__(self, index) -> Vec3Array: + return Vec3Array( + self.x[index], + self.y[index], + self.z[index], + ) + + def __iter__(self): + return iter((self.x, self.y, self.z)) + + @property + def shape(self): + return self.x.shape + + def map_tensor_fn(self, fn) -> Vec3Array: + return Vec3Array( + fn(self.x), + fn(self.y), + fn(self.z), + ) + + def cross(self, other: Vec3Array) -> Vec3Array: + """Compute cross product between 'self' and 'other'.""" + new_x = self.y * other.z - self.z * other.y + new_y = self.z * other.x - self.x * other.z + new_z = self.x * other.y - self.y * other.x + return Vec3Array(new_x, new_y, new_z) + + def dot(self, other: Vec3Array) -> Float: + """Compute dot product between 'self' and 'other'.""" + return self.x 
* other.x + self.y * other.y + self.z * other.z + + def norm(self, epsilon: float = 1e-6) -> Float: + """Compute Norm of Vec3Array, clipped to epsilon.""" + # To avoid NaN on the backward pass, we must use maximum before the sqrt + norm2 = self.dot(self) + if epsilon: + norm2 = torch.clamp(norm2, min=epsilon**2) + return torch.sqrt(norm2) + + def norm2(self): + return self.dot(self) + + def normalized(self, epsilon: float = 1e-6) -> Vec3Array: + """Return unit vector with optional clipping.""" + return self / self.norm(epsilon) + + def clone(self) -> Vec3Array: + return Vec3Array( + self.x.clone(), + self.y.clone(), + self.z.clone(), + ) + + def reshape(self, new_shape) -> Vec3Array: + x = self.x.reshape(new_shape) + y = self.y.reshape(new_shape) + z = self.z.reshape(new_shape) + + return Vec3Array(x, y, z) + + def sum(self, dim: int) -> Vec3Array: + return Vec3Array( + torch.sum(self.x, dim=dim), + torch.sum(self.y, dim=dim), + torch.sum(self.z, dim=dim), + ) + + def unsqueeze(self, dim: int): + return Vec3Array( + self.x.unsqueeze(dim), + self.y.unsqueeze(dim), + self.z.unsqueeze(dim), + ) + + @classmethod + def zeros(cls, shape, device="cpu"): + """Return Vec3Array corresponding to zeros of given shape.""" + return cls( + torch.zeros(shape, dtype=torch.float32, device=device), + torch.zeros(shape, dtype=torch.float32, device=device), + torch.zeros(shape, dtype=torch.float32, device=device) + ) + + def to_tensor(self) -> torch.Tensor: + return torch.stack([self.x, self.y, self.z], dim=-1) + + @classmethod + def from_array(cls, tensor): + return cls(*torch.unbind(tensor, dim=-1)) + + @classmethod + def cat(cls, vecs: List[Vec3Array], dim: int) -> Vec3Array: + return cls( + torch.cat([v.x for v in vecs], dim=dim), + torch.cat([v.y for v in vecs], dim=dim), + torch.cat([v.z for v in vecs], dim=dim), + ) + + +def square_euclidean_distance( + vec1: Vec3Array, + vec2: Vec3Array, + epsilon: float = 1e-6 +) -> Float: + """Computes square of euclidean distance between 'vec1' and 'vec2'. + + Args: + vec1: Vec3Array to compute distance to + vec2: Vec3Array to compute distance from, should be + broadcast compatible with 'vec1' + epsilon: distance is clipped from below to be at least epsilon + + Returns: + Array of square euclidean distances; + shape will be result of broadcasting 'vec1' and 'vec2' + """ + difference = vec1 - vec2 + distance = difference.dot(difference) + if epsilon: + distance = torch.clamp(distance, min=epsilon) + return distance + + +def dot(vector1: Vec3Array, vector2: Vec3Array) -> Float: + return vector1.dot(vector2) + + +def cross(vector1: Vec3Array, vector2: Vec3Array) -> Float: + return vector1.cross(vector2) + + +def norm(vector: Vec3Array, epsilon: float = 1e-6) -> Float: + return vector.norm(epsilon) + + +def normalized(vector: Vec3Array, epsilon: float = 1e-6) -> Vec3Array: + return vector.normalized(epsilon) + + +def euclidean_distance( + vec1: Vec3Array, + vec2: Vec3Array, + epsilon: float = 1e-6 +) -> Float: + """Computes euclidean distance between 'vec1' and 'vec2'. 
+ + Args: + vec1: Vec3Array to compute euclidean distance to + vec2: Vec3Array to compute euclidean distance from, should be + broadcast compatible with 'vec1' + epsilon: distance is clipped from below to be at least epsilon + + Returns: + Array of euclidean distances; + shape will be result of broadcasting 'vec1' and 'vec2' + """ + distance_sq = square_euclidean_distance(vec1, vec2, epsilon**2) + distance = torch.sqrt(distance_sq) + return distance + + +def dihedral_angle(a: Vec3Array, b: Vec3Array, c: Vec3Array, + d: Vec3Array) -> Float: + """Computes torsion angle for a quadruple of points. + + For points (a, b, c, d), this is the angle between the planes defined by + points (a, b, c) and (b, c, d). It is also known as the dihedral angle. + + Arguments: + a: A Vec3Array of coordinates. + b: A Vec3Array of coordinates. + c: A Vec3Array of coordinates. + d: A Vec3Array of coordinates. + + Returns: + A tensor of angles in radians: [-pi, pi]. + """ + v1 = a - b + v2 = b - c + v3 = d - c + + c1 = v1.cross(v2) + c2 = v3.cross(v2) + c3 = c2.cross(c1) + + v2_mag = v2.norm() + return torch.atan2(c3.dot(v2), v2_mag * c1.dot(c2)) diff --git a/dockformer/utils/kernel/__init__.py b/dockformer/utils/kernel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/dockformer/utils/kernel/attention_core.py b/dockformer/utils/kernel/attention_core.py new file mode 100644 index 0000000000000000000000000000000000000000..0577479d274c631ceb79856084ae99376e99e9e0 --- /dev/null +++ b/dockformer/utils/kernel/attention_core.py @@ -0,0 +1,107 @@ +# Copyright 2021 AlQuraishi Laboratory +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
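# A hedged sanity check for dihedral_angle above: four co-planar points in a
# trans arrangement should give a torsion angle of +/- pi; illustrative only.
import math
import torch
from dockformer.utils.geometry.vector import Vec3Array, dihedral_angle

def _vec(x, y, z):
    return Vec3Array(torch.tensor(x), torch.tensor(y), torch.tensor(z))

angle = dihedral_angle(_vec(-1., 1., 0.), _vec(-1., 0., 0.),
                       _vec(1., 0., 0.), _vec(1., -1., 0.))
assert abs(abs(angle.item()) - math.pi) < 1e-5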
+import importlib +from functools import reduce +from operator import mul + +import torch + +# TODO bshor: solve attn_core_is_installed in mac +attn_core_is_installed = importlib.util.find_spec("attn_core_inplace_cuda") is not None +attn_core_inplace_cuda = None +if attn_core_is_installed: + attn_core_inplace_cuda = importlib.import_module("attn_core_inplace_cuda") + + +SUPPORTED_DTYPES = [torch.float32, torch.bfloat16] + + +class AttentionCoreFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, bias_1=None, bias_2=None): + if(bias_1 is None and bias_2 is not None): + raise ValueError("bias_1 must be specified before bias_2") + if(q.dtype not in SUPPORTED_DTYPES): + raise ValueError("Unsupported datatype") + + q = q.contiguous() + k = k.contiguous() + + # [*, H, Q, K] + attention_logits = torch.matmul( + q, k.transpose(-1, -2), + ) + + if(bias_1 is not None): + attention_logits += bias_1 + if(bias_2 is not None): + attention_logits += bias_2 + + attn_core_inplace_cuda.forward_( + attention_logits, + reduce(mul, attention_logits.shape[:-1]), + attention_logits.shape[-1], + ) + + o = torch.matmul(attention_logits, v) + + ctx.bias_1_shape = bias_1.shape if bias_1 is not None else None + ctx.bias_2_shape = bias_2.shape if bias_2 is not None else None + ctx.save_for_backward(q, k, v, attention_logits) + + return o + + @staticmethod + def backward(ctx, grad_output): + q, k, v, attention_logits = ctx.saved_tensors + grad_q = grad_k = grad_v = grad_bias_1 = grad_bias_2 = None + + grad_v = torch.matmul( + attention_logits.transpose(-1, -2), + grad_output + ) + + attn_core_inplace_cuda.backward_( + attention_logits, + grad_output.contiguous(), + v.contiguous(), # v is implicitly transposed in the kernel + reduce(mul, attention_logits.shape[:-1]), + attention_logits.shape[-1], + grad_output.shape[-1], + ) + + if(ctx.bias_1_shape is not None): + grad_bias_1 = torch.sum( + attention_logits, + dim=tuple(i for i,d in enumerate(ctx.bias_1_shape) if d == 1), + keepdim=True, + ) + + if(ctx.bias_2_shape is not None): + grad_bias_2 = torch.sum( + attention_logits, + dim=tuple(i for i,d in enumerate(ctx.bias_2_shape) if d == 1), + keepdim=True, + ) + + grad_q = torch.matmul( + attention_logits, k + ) + grad_k = torch.matmul( + q.transpose(-1, -2), attention_logits, + ).transpose(-1, -2) + + return grad_q, grad_k, grad_v, grad_bias_1, grad_bias_2 + +attention_core = AttentionCoreFunction.apply diff --git a/dockformer/utils/kernel/csrc/compat.h b/dockformer/utils/kernel/csrc/compat.h new file mode 100644 index 0000000000000000000000000000000000000000..bfab6aa5234de9b77669b3aa23b0cac0125b1b27 --- /dev/null +++ b/dockformer/utils/kernel/csrc/compat.h @@ -0,0 +1,11 @@ +// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h + +#ifndef TORCH_CHECK +#define TORCH_CHECK AT_CHECK +#endif + +#ifdef VERSION_GE_1_3 +#define DATA_PTR data_ptr +#else +#define DATA_PTR data +#endif diff --git a/dockformer/utils/kernel/csrc/softmax_cuda.cpp b/dockformer/utils/kernel/csrc/softmax_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f31eeec6154ee22b3d03ce86d28dbb2be55fa0f9 --- /dev/null +++ b/dockformer/utils/kernel/csrc/softmax_cuda.cpp @@ -0,0 +1,44 @@ +// Copyright 2021 AlQuraishi Laboratory +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// modified from fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda.cpp
+
+#include <torch/extension.h>
+
+void attn_softmax_inplace_forward_(
+    at::Tensor input,
+    long long rows, int cols
+);
+void attn_softmax_inplace_backward_(
+    at::Tensor output,
+    at::Tensor d_ov,
+    at::Tensor values,
+    long long rows,
+    int cols_output,
+    int cols_values
+);
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def(
+        "forward_",
+        &attn_softmax_inplace_forward_,
+        "Softmax forward (CUDA)"
+    );
+    m.def(
+        "backward_",
+        &attn_softmax_inplace_backward_,
+        "Softmax backward (CUDA)"
+    );
+}
diff --git a/dockformer/utils/kernel/csrc/softmax_cuda_kernel.cu b/dockformer/utils/kernel/csrc/softmax_cuda_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..985093654d90180578cdd25108ef9049a0c8a5dd
--- /dev/null
+++ b/dockformer/utils/kernel/csrc/softmax_cuda_kernel.cu
@@ -0,0 +1,241 @@
+// Copyright 2021 AlQuraishi Laboratory
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
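# A minimal sketch of how the binding above is reached from Python, assuming
# the extension builds under the module name attn_core_inplace_cuda used by
# attention_core.py; the concrete tensor sizes are assumptions.
import torch
from dockformer.utils.kernel.attention_core import (
    attn_core_is_installed,
    attention_core,
)

if attn_core_is_installed and torch.cuda.is_available():
    q = torch.randn(1, 4, 8, 16, device="cuda")  # [*, H, Q, C_hidden]
    k = torch.randn(1, 4, 8, 16, device="cuda")  # [*, H, K, C_hidden]
    v = torch.randn(1, 4, 8, 16, device="cuda")  # [*, H, K, C_hidden]
    out = attention_core(q, k, v, None, None)    # fused softmax(QK^T)V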
+ +// modified from fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu + +#include +#include +#include + +#include + +#include "ATen/ATen.h" +#include "ATen/cuda/CUDAContext.h" +#include "compat.h" + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +__inline__ __device__ float WarpAllReduceMax(float val) { + for (int mask = 1; mask < 32; mask *= 2) { + val = max(val, __shfl_xor_sync(0xffffffff, val, mask)); + } + return val; +} + +__inline__ __device__ float WarpAllReduceSum(float val) { + for (int mask = 1; mask < 32; mask *= 2) { + val += __shfl_xor_sync(0xffffffff, val, mask); + } + return val; +} + + +template +__global__ void attn_softmax_inplace_( + T *input, + long long rows, int cols +) { + int threadidx_x = threadIdx.x / 32; + int threadidx_y = threadIdx.x % 32; + long long row_offset = (long long)(blockIdx.x * 4 + threadidx_x); + int cols_per_thread = (cols + 31) / 32; + int cols_this_thread = cols_per_thread; + + int last_y = (cols / cols_per_thread); + + if (threadidx_y == last_y) { + cols_this_thread = cols - cols_per_thread * last_y; + } + else if (threadidx_y > last_y) { + cols_this_thread = 0; + } + + float buf[32]; + + int lane_id = threadidx_y; + + if (row_offset < rows) { + T *row_input = input + row_offset * cols; + T *row_output = row_input; + + #pragma unroll + for (int i = 0; i < cols_this_thread; i++) { + int idx = lane_id * cols_per_thread + i; + buf[i] = static_cast(row_input[idx]); + } + + float thread_max = -1 * CUDART_INF_F; + #pragma unroll + for (int i = 0; i < cols_this_thread; i++) { + thread_max = max(thread_max, buf[i]); + } + + float warp_max = WarpAllReduceMax(thread_max); + + float thread_sum = 0.f; + #pragma unroll + for (int i = 0; i < cols_this_thread; i++) { + buf[i] = __expf(buf[i] - warp_max); + thread_sum += buf[i]; + } + + float warp_sum = WarpAllReduceSum(thread_sum); + #pragma unroll + for (int i = 0; i < cols_this_thread; i++) { + row_output[lane_id * cols_per_thread + i] = + static_cast(__fdividef(buf[i], warp_sum)); + } + } +} + + +void attn_softmax_inplace_forward_( + at::Tensor input, + long long rows, int cols +) { + CHECK_INPUT(input); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + + int grid = (rows + 3) / 4; + dim3 block(128); + + if (input.dtype() == torch::kFloat32) { + attn_softmax_inplace_<<>>( + (float *)input.data_ptr(), + rows, cols + ); + } + else { + attn_softmax_inplace_<<>>( + (at::BFloat16 *)input.data_ptr(), + rows, cols + ); + } +} + + +template +__global__ void attn_softmax_inplace_grad_( + T *output, + T *d_ov, + T *values, + long long rows, + int cols_output, + int cols_values +) { + int threadidx_x = threadIdx.x / 32; + int threadidx_y = threadIdx.x % 32; + long long row_offset = (long long)(blockIdx.x * 4 + threadidx_x); + int cols_per_thread = (cols_output + 31) / 32; + int cols_this_thread = cols_per_thread; + int rows_values = cols_output; + // values are set to the beginning of the current + // rows_values x cols_values leaf matrix + long long value_row_offset = row_offset - row_offset % rows_values; + int last_y = (cols_output / cols_per_thread); + + if (threadidx_y == last_y) { + cols_this_thread = cols_output - cols_per_thread * last_y; + } + else if (threadidx_y > last_y) { + cols_this_thread = 0; + } + + float y_buf[32]; + float dy_buf[32]; + + int lane_id = threadidx_y; + + if 
+template <typename T>
+__global__ void attn_softmax_inplace_grad_(
+    T *output,
+    T *d_ov,
+    T *values,
+    long long rows,
+    int cols_output,
+    int cols_values
+) {
+    int threadidx_x = threadIdx.x / 32;
+    int threadidx_y = threadIdx.x % 32;
+    long long row_offset = (long long)(blockIdx.x * 4 + threadidx_x);
+    int cols_per_thread = (cols_output + 31) / 32;
+    int cols_this_thread = cols_per_thread;
+    int rows_values = cols_output;
+    // values are set to the beginning of the current
+    // rows_values x cols_values leaf matrix
+    long long value_row_offset = row_offset - row_offset % rows_values;
+    int last_y = (cols_output / cols_per_thread);
+
+    if (threadidx_y == last_y) {
+        cols_this_thread = cols_output - cols_per_thread * last_y;
+    }
+    else if (threadidx_y > last_y) {
+        cols_this_thread = 0;
+    }
+
+    float y_buf[32];
+    float dy_buf[32];
+
+    int lane_id = threadidx_y;
+
+    if (row_offset < rows) {
+        T *row_output = output + row_offset * cols_output;
+        T *row_d_ov = d_ov + row_offset * cols_values;
+        T *row_values = values + value_row_offset * cols_values;
+
+        float thread_max = -1 * CUDART_INF_F;  // unused, kept from the FastFold original
+
+        // Compute a chunk of the output gradient on the fly
+        int value_row_idx = 0;
+        int value_idx = 0;
+        #pragma unroll
+        for (int i = 0; i < cols_this_thread; i++) {
+            T sum = 0.;
+            #pragma unroll
+            for (int j = 0; j < cols_values; j++) {
+                value_row_idx = ((lane_id * cols_per_thread) + i);
+                value_idx = value_row_idx * cols_values + j;
+                sum += row_d_ov[j] * row_values[value_idx];
+            }
+            dy_buf[i] = static_cast<float>(sum);
+        }
+
+        #pragma unroll
+        for (int i = 0; i < cols_this_thread; i++) {
+            y_buf[i] = static_cast<float>(row_output[lane_id * cols_per_thread + i]);
+        }
+
+        float thread_sum = 0.;
+
+        #pragma unroll
+        for (int i = 0; i < cols_this_thread; i++) {
+            thread_sum += y_buf[i] * dy_buf[i];
+        }
+
+        float warp_sum = WarpAllReduceSum(thread_sum);
+
+        #pragma unroll
+        for (int i = 0; i < cols_this_thread; i++) {
+            row_output[lane_id * cols_per_thread + i] = static_cast<T>(
+                (dy_buf[i] - warp_sum) * y_buf[i]
+            );
+        }
+    }
+}
+
+
+void attn_softmax_inplace_backward_(
+    at::Tensor output,
+    at::Tensor d_ov,
+    at::Tensor values,
+    long long rows,
+    int cols_output,
+    int cols_values
+) {
+    CHECK_INPUT(output);
+    CHECK_INPUT(d_ov);
+    CHECK_INPUT(values);
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
+
+    int grid = (rows + 3) / 4;
+    dim3 block(128);
+
+    if (output.dtype() == torch::kFloat32) {
+        attn_softmax_inplace_grad_<float><<<grid, block>>>(
+            (float *)output.data_ptr(),
+            (float *)d_ov.data_ptr(),
+            (float *)values.data_ptr(),
+            rows, cols_output, cols_values
+        );
+    } else {
+        attn_softmax_inplace_grad_<at::BFloat16><<<grid, block>>>(
+            (at::BFloat16 *)output.data_ptr(),
+            (at::BFloat16 *)d_ov.data_ptr(),
+            (at::BFloat16 *)values.data_ptr(),
+            rows, cols_output, cols_values
+        );
+    }
+}
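Editorial note, not part of the patch: the backward kernel fuses two steps. It recomputes the upstream gradient of the softmax output, d_y = d_ov @ valuesᵀ, and then applies the softmax Jacobian product, out = y * (d_y - Σ y·d_y), overwriting the stored softmax output. A sketch of a numerical check against plain PyTorch, under the same assumed module name as the earlier sketch; note the kernel expects each block of cols_output consecutive rows to share one [cols_output, cols_values] values "leaf" matrix, which the shapes below satisfy.

```python
# Sketch: check the fused backward against an explicit PyTorch reference.
import torch
# `attn_core_inplace_cuda` is assumed to be the module object returned by
# the load(...) call in the earlier sketch.

n_q, n_k, c_v = 16, 64, 32
y = torch.softmax(torch.randn(n_q, n_k, device="cuda"), dim=-1)
v = torch.randn(n_k, c_v, device="cuda")
d_ov = torch.randn(n_q, c_v, device="cuda")  # upstream grad of (y @ v)

# Reference: recompute d_y, then the softmax Jacobian product per row.
d_y = d_ov @ v.t()
ref = y * (d_y - (y * d_y).sum(dim=-1, keepdim=True))

out = y.clone()  # backward_ overwrites the softmax output with the grad
attn_core_inplace_cuda.backward_(
    out, d_ov.contiguous(), v.contiguous(), n_q, n_k, c_v
)
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)
```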
diff --git a/dockformer/utils/kernel/csrc/softmax_cuda_stub.cpp b/dockformer/utils/kernel/csrc/softmax_cuda_stub.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..4539c19fb617ea56a86e4b08efec02b8ea490ff9
--- /dev/null
+++ b/dockformer/utils/kernel/csrc/softmax_cuda_stub.cpp
@@ -0,0 +1,36 @@
+// Copyright 2021 AlQuraishi Laboratory
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// modified from fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda.cpp
+
+#include <torch/extension.h>
+
+// CPU-only builds link this stub in place of the CUDA kernels; both entry
+// points fail loudly if they are ever reached.
+void attn_softmax_inplace_forward_(
+    at::Tensor input,
+    long long rows, int cols
+)
+{
+    throw std::runtime_error("attn_softmax_inplace_forward_ not implemented on CPU");
+}
+void attn_softmax_inplace_backward_(
+    at::Tensor output,
+    at::Tensor d_ov,
+    at::Tensor values,
+    long long rows,
+    int cols_output,
+    int cols_values
+)
+{
+    throw std::runtime_error("attn_softmax_inplace_backward_ not implemented on CPU");
+}
\ No newline at end of file
diff --git a/dockformer/utils/logger.py b/dockformer/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec3d3b7552930b68efcc96f4ce8392dbeb3da0f7
--- /dev/null
+++ b/dockformer/utils/logger.py
@@ -0,0 +1,80 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import operator
+import time
+
+import dllogger as logger
+from dllogger import JSONStreamBackend, StdOutBackend, Verbosity
+import numpy as np
+from lightning import Callback
+import torch.cuda.profiler as profiler
+
+
+def is_main_process():
+    return int(os.getenv("LOCAL_RANK", "0")) == 0
+
+
+class PerformanceLoggingCallback(Callback):
+    def __init__(self, log_file, global_batch_size, warmup_steps: int = 0, profile: bool = False):
+        logger.init(backends=[JSONStreamBackend(Verbosity.VERBOSE, log_file), StdOutBackend(Verbosity.VERBOSE)])
+        self.warmup_steps = warmup_steps
+        self.global_batch_size = global_batch_size
+        self.step = 0
+        self.profile = profile
+        self.timestamps = []
+
+    def do_step(self):
+        self.step += 1
+        if self.profile and self.step == self.warmup_steps:
+            profiler.start()
+        if self.step > self.warmup_steps:
+            self.timestamps.append(time.time())
+
+    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
+        self.do_step()
+
+    def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx: int = 0):
+        self.do_step()
+
+    def process_performance_stats(self, deltas):
+        def _round3(val):
+            return round(val, 3)
+
+        throughput_imgps = _round3(self.global_batch_size / np.mean(deltas))
+        timestamps_ms = 1000 * deltas
+        stats = {
+            "throughput": throughput_imgps,
+            "latency_mean": _round3(timestamps_ms.mean()),
+        }
+        for level in [90, 95, 99]:
+            stats.update({f"latency_{level}": _round3(np.percentile(timestamps_ms, level))})
+
+        return stats
+
+    def _log(self):
+        if is_main_process():
+            diffs = list(map(operator.sub, self.timestamps[1:], self.timestamps[:-1]))
+            deltas = np.array(diffs)
+            stats = self.process_performance_stats(deltas)
+            logger.log(step=(), data=stats)
+            logger.flush()
+
+    def on_train_end(self, trainer, pl_module):
+        if self.profile:
+            profiler.stop()
+        self._log()
+
+    # `on_epoch_end` was removed from Lightning's Callback API; use the
+    # train-epoch hook so this actually fires.
+    def on_train_epoch_end(self, trainer, pl_module):
+        self._log()
diff --git a/dockformer/utils/loss.py b/dockformer/utils/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9ecf2d1dac3cc04302ac9788fe1aeadfd6c6764
--- /dev/null
+++ b/dockformer/utils/loss.py
@@ -0,0 +1,1171 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import time
+
+import ml_collections
+import numpy as np
+import torch
+import torch.nn as nn
+from typing import Dict, Optional, Tuple
+
+from dockformer.utils import residue_constants
+from dockformer.utils.feats import pseudo_beta_fn
+from dockformer.utils.rigid_utils import Rotation, Rigid
+from dockformer.utils.geometry.vector import Vec3Array, euclidean_distance
+from dockformer.utils.tensor_utils import (
+    tree_map,
+    masked_mean,
+    permute_final_dims,
+)
+import logging
+from dockformer.utils.tensor_utils import tensor_tree_map
+
+logger = logging.getLogger(__name__)
+
+
+def softmax_cross_entropy(logits, labels):
+    loss = -1 * torch.sum(
+        labels * torch.nn.functional.log_softmax(logits, dim=-1),
+        dim=-1,
+    )
+    return loss
+
+
+def sigmoid_cross_entropy(logits, labels):
+    logits_dtype = logits.dtype
+    try:
+        logits = logits.double()
+        labels = labels.double()
+    except Exception:
+        # double precision is unavailable on some backends; fall back to fp32
+        logits = logits.to(dtype=torch.float32)
+        labels = labels.to(dtype=torch.float32)
+
+    log_p = torch.nn.functional.logsigmoid(logits)
+    # log_p = torch.log(torch.sigmoid(logits))
+    log_not_p = torch.nn.functional.logsigmoid(-1 * logits)
+    # log_not_p = torch.log(torch.sigmoid(-logits))
+    loss = (-1. * labels) * log_p - (1. - labels) * log_not_p
+    loss = loss.to(dtype=logits_dtype)
+    return loss
+
+
+def torsion_angle_loss(
+    a,  # [*, N, 7, 2]
+    a_gt,  # [*, N, 7, 2]
+    a_alt_gt,  # [*, N, 7, 2]
+):
+    # [*, N, 7]
+    norm = torch.norm(a, dim=-1)
+
+    # [*, N, 7, 2]
+    a = a / norm.unsqueeze(-1)
+
+    # [*, N, 7]
+    diff_norm_gt = torch.norm(a - a_gt, dim=-1)
+    diff_norm_alt_gt = torch.norm(a - a_alt_gt, dim=-1)
+    min_diff = torch.minimum(diff_norm_gt ** 2, diff_norm_alt_gt ** 2)
+
+    # [*]
+    l_torsion = torch.mean(min_diff, dim=(-1, -2))
+    l_angle_norm = torch.mean(torch.abs(norm - 1), dim=(-1, -2))
+
+    an_weight = 0.02
+    return l_torsion + an_weight * l_angle_norm
+
+
+def compute_fape(
+    pred_frames: Rigid,
+    target_frames: Rigid,
+    frames_mask: torch.Tensor,
+    pred_positions: torch.Tensor,
+    target_positions: torch.Tensor,
+    positions_mask: torch.Tensor,
+    length_scale: float,
+    pair_mask: Optional[torch.Tensor] = None,
+    l1_clamp_distance: Optional[float] = None,
+    eps=1e-8,
+) -> torch.Tensor:
+    """
+    Computes FAPE loss.
+
+    Args:
+        pred_frames:
+            [*, N_frames] Rigid object of predicted frames
+        target_frames:
+            [*, N_frames] Rigid object of ground truth frames
+        frames_mask:
+            [*, N_frames] binary mask for the frames
+        pred_positions:
+            [*, N_pts, 3] predicted atom positions
+        target_positions:
+            [*, N_pts, 3] ground truth positions
+        positions_mask:
+            [*, N_pts] positions mask
+        length_scale:
+            Length scale by which the loss is divided
+        pair_mask:
+            [*, N_frames, N_pts] mask to use for
+            separating intra- from inter-chain losses.
+        l1_clamp_distance:
+            Cutoff above which distance errors are disregarded
+        eps:
+            Small value used to regularize denominators
+    Returns:
+        [*] loss tensor
+    """
+    # [*, N_frames, N_pts, 3]
+    local_pred_pos = pred_frames.invert()[..., None].apply(
+        pred_positions[..., None, :, :],
+    )
+    local_target_pos = target_frames.invert()[..., None].apply(
+        target_positions[..., None, :, :],
+    )
+
+    error_dist = torch.sqrt(
+        torch.sum((local_pred_pos - local_target_pos) ** 2, dim=-1) + eps
+    )
+
+    if l1_clamp_distance is not None:
+        error_dist = torch.clamp(error_dist, min=0, max=l1_clamp_distance)
+
+    normed_error = error_dist / length_scale
+    normed_error = normed_error * frames_mask[..., None]
+    normed_error = normed_error * positions_mask[..., None, :]
+
+    if pair_mask is not None:
+        normed_error = normed_error * pair_mask
+        normed_error = torch.sum(normed_error, dim=(-1, -2))
+
+        mask = frames_mask[..., None] * positions_mask[..., None, :] * pair_mask
+        norm_factor = torch.sum(mask, dim=(-2, -1))
+
+        normed_error = normed_error / (eps + norm_factor)
+    else:
+        # FP16-friendly averaging. Roughly equivalent to:
+        #
+        # norm_factor = (
+        #     torch.sum(frames_mask, dim=-1) *
+        #     torch.sum(positions_mask, dim=-1)
+        # )
+        # normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
+        #
+        # ("roughly" because eps is necessarily duplicated in the latter)
+        normed_error = torch.sum(normed_error, dim=-1)
+        normed_error = (
+            normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
+        )
+        normed_error = torch.sum(normed_error, dim=-1)
+        normed_error = normed_error / (eps + torch.sum(positions_mask, dim=-1))
+
+    return normed_error
+
+
+def backbone_loss(
+    backbone_rigid_tensor: torch.Tensor,
+    backbone_rigid_mask: torch.Tensor,
+    traj: torch.Tensor,
+    pair_mask: Optional[torch.Tensor] = None,
+    use_clamped_fape: Optional[torch.Tensor] = None,
+    clamp_distance: float = 10.0,
+    loss_unit_distance: float = 10.0,
+    eps: float = 1e-4,
+    **kwargs,
+) -> torch.Tensor:
+    # traj is either a [*, N, 7] quaternion-and-translation tensor or a
+    # [*, N, 4, 4] homogeneous transform; dispatch on the trailing dimension
+    if traj.shape[-1] == 7:
+        pred_aff = Rigid.from_tensor_7(traj)
+    elif traj.shape[-1] == 4:
+        pred_aff = Rigid.from_tensor_4x4(traj)
+
+    pred_aff = Rigid(
+        Rotation(rot_mats=pred_aff.get_rots().get_rot_mats(), quats=None),
+        pred_aff.get_trans(),
+    )
+
+    # DISCREPANCY: DeepMind somehow gets a hold of a tensor_7 version of
+    # backbone tensor, normalizes it, and then turns it back to a rotation
+    # matrix. To avoid a potentially numerically unstable rotation matrix
+    # to quaternion conversion, we just use the original rotation matrix
+    # outright. This one hasn't been composed a bunch of times, though, so
+    # it might be fine.
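+    # Added note: backbone_rigid_tensor below holds the ground-truth frames
+    # as [*, N, 4, 4] homogeneous transforms, and when use_clamped_fape is
+    # given, the loss linearly mixes clamped and unclamped FAPE per sample.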
+ gt_aff = Rigid.from_tensor_4x4(backbone_rigid_tensor) + + fape_loss = compute_fape( + pred_aff, + gt_aff[None], + backbone_rigid_mask[None], + pred_aff.get_trans(), + gt_aff[None].get_trans(), + backbone_rigid_mask[None], + pair_mask=pair_mask, + l1_clamp_distance=clamp_distance, + length_scale=loss_unit_distance, + eps=eps, + ) + if use_clamped_fape is not None: + unclamped_fape_loss = compute_fape( + pred_aff, + gt_aff[None], + backbone_rigid_mask[None], + pred_aff.get_trans(), + gt_aff[None].get_trans(), + backbone_rigid_mask[None], + pair_mask=pair_mask, + l1_clamp_distance=None, + length_scale=loss_unit_distance, + eps=eps, + ) + + fape_loss = fape_loss * use_clamped_fape + unclamped_fape_loss * ( + 1 - use_clamped_fape + ) + + # Average over the batch dimension + fape_loss = torch.mean(fape_loss) + + return fape_loss + + +def sidechain_loss( + pred_sidechain_frames: torch.Tensor, + pred_sidechain_atom_pos: torch.Tensor, + rigidgroups_gt_frames: torch.Tensor, + rigidgroups_alt_gt_frames: torch.Tensor, + rigidgroups_gt_exists: torch.Tensor, + renamed_atom14_gt_positions: torch.Tensor, + renamed_atom14_gt_exists: torch.Tensor, + alt_naming_is_better: torch.Tensor, + ligand_mask: torch.Tensor, + clamp_distance: float = 10.0, + length_scale: float = 10.0, + eps: float = 1e-4, + only_include_ligand_atoms: bool = False, + **kwargs, +) -> torch.Tensor: + renamed_gt_frames = ( + 1.0 - alt_naming_is_better[..., None, None, None] + ) * rigidgroups_gt_frames + alt_naming_is_better[ + ..., None, None, None + ] * rigidgroups_alt_gt_frames + + # Steamroll the inputs + pred_sidechain_frames = pred_sidechain_frames[-1] # get only the last layer of the strcuture module + batch_dims = pred_sidechain_frames.shape[:-4] + pred_sidechain_frames = pred_sidechain_frames.view(*batch_dims, -1, 4, 4) + pred_sidechain_frames = Rigid.from_tensor_4x4(pred_sidechain_frames) + renamed_gt_frames = renamed_gt_frames.view(*batch_dims, -1, 4, 4) + renamed_gt_frames = Rigid.from_tensor_4x4(renamed_gt_frames) + rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(*batch_dims, -1) + pred_sidechain_atom_pos = pred_sidechain_atom_pos[-1] + pred_sidechain_atom_pos = pred_sidechain_atom_pos.view(*batch_dims, -1, 3) + renamed_atom14_gt_positions = renamed_atom14_gt_positions.view( + *batch_dims, -1, 3 + ) + renamed_atom14_gt_exists = renamed_atom14_gt_exists.view(*batch_dims, -1) + + atom_mask_to_apply = renamed_atom14_gt_exists + if only_include_ligand_atoms: + ligand_atom14_mask = torch.repeat_interleave(ligand_mask, 14, dim=-1) + atom_mask_to_apply = atom_mask_to_apply * ligand_atom14_mask + + fape = compute_fape( + pred_sidechain_frames, + renamed_gt_frames, + rigidgroups_gt_exists, + pred_sidechain_atom_pos, + renamed_atom14_gt_positions, + atom_mask_to_apply, + pair_mask=None, + l1_clamp_distance=clamp_distance, + length_scale=length_scale, + eps=eps, + ) + + return fape + + +def fape_bb( + out: Dict[str, torch.Tensor], + batch: Dict[str, torch.Tensor], + config: ml_collections.ConfigDict, +) -> torch.Tensor: + traj = out["sm"]["frames"] + bb_loss = backbone_loss( + traj=traj, + **{**batch, **config}, + ) + loss = torch.mean(bb_loss) + return loss + + +def fape_sidechain( + out: Dict[str, torch.Tensor], + batch: Dict[str, torch.Tensor], + config: ml_collections.ConfigDict, +) -> torch.Tensor: + sc_loss = sidechain_loss( + out["sm"]["sidechain_frames"], + out["sm"]["positions"], + **{**batch, **config}, + ) + loss = torch.mean(sc_loss) + return loss + + +def fape_interface( + out: Dict[str, torch.Tensor], + batch: 
Dict[str, torch.Tensor], + config: ml_collections.ConfigDict, +) -> torch.Tensor: + sc_loss = sidechain_loss( + out["sm"]["sidechain_frames"], + out["sm"]["positions"], + only_include_ligand_atoms=True, + **{**batch, **config}, + ) + loss = torch.mean(sc_loss) + return loss + + +def supervised_chi_loss( + angles_sin_cos: torch.Tensor, + unnormalized_angles_sin_cos: torch.Tensor, + aatype: torch.Tensor, + protein_mask: torch.Tensor, + chi_mask: torch.Tensor, + chi_angles_sin_cos: torch.Tensor, + chi_weight: float, + angle_norm_weight: float, + eps=1e-6, + **kwargs, +) -> torch.Tensor: + """ + Implements Algorithm 27 (torsionAngleLoss) + + Args: + angles_sin_cos: + [*, N, 7, 2] predicted angles + unnormalized_angles_sin_cos: + The same angles, but unnormalized + aatype: + [*, N] residue indices + protein_mask: + [*, N] protein mask + chi_mask: + [*, N, 7] angle mask + chi_angles_sin_cos: + [*, N, 7, 2] ground truth angles + chi_weight: + Weight for the angle component of the loss + angle_norm_weight: + Weight for the normalization component of the loss + Returns: + [*] loss tensor + """ + pred_angles = angles_sin_cos[..., 3:, :] + residue_type_one_hot = torch.nn.functional.one_hot( + aatype, + residue_constants.restype_num + 1, + ) + chi_pi_periodic = torch.einsum( + "...ij,jk->ik", + residue_type_one_hot.type(angles_sin_cos.dtype), + angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic), + ) + + true_chi = chi_angles_sin_cos[None] + + shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1) + true_chi_shifted = shifted_mask * true_chi + sq_chi_error = torch.sum((true_chi - pred_angles) ** 2, dim=-1) + sq_chi_error_shifted = torch.sum( + (true_chi_shifted - pred_angles) ** 2, dim=-1 + ) + sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted) + + # The ol' switcheroo + sq_chi_error = sq_chi_error.permute( + *range(len(sq_chi_error.shape))[1:-2], 0, -2, -1 + ) + + sq_chi_loss = masked_mean( + chi_mask[..., None, :, :], sq_chi_error, dim=(-1, -2, -3) + ) + + loss = chi_weight * sq_chi_loss + + angle_norm = torch.sqrt( + torch.sum(unnormalized_angles_sin_cos ** 2, dim=-1) + eps + ) + norm_error = torch.abs(angle_norm - 1.0) + norm_error = norm_error.permute( + *range(len(norm_error.shape))[1:-2], 0, -2, -1 + ) + angle_norm_loss = masked_mean( + protein_mask[..., None, :, None], norm_error, dim=(-1, -2, -3) + ) + + loss = loss + angle_norm_weight * angle_norm_loss + + # Average over the batch dimension + loss = torch.mean(loss) + + return loss + + +def compute_plddt(logits: torch.Tensor) -> torch.Tensor: + num_bins = logits.shape[-1] + bin_width = 1.0 / num_bins + bounds = torch.arange( + start=0.5 * bin_width, end=1.0, step=bin_width, device=logits.device + ) + probs = torch.nn.functional.softmax(logits, dim=-1) + pred_lddt_ca = torch.sum( + probs * bounds.view(*((1,) * len(probs.shape[:-1])), *bounds.shape), + dim=-1, + ) + return pred_lddt_ca * 100 + + +def lddt( + all_atom_pred_pos: torch.Tensor, + all_atom_positions: torch.Tensor, + all_atom_mask: torch.Tensor, + cutoff: float = 15.0, + eps: float = 1e-10, + per_residue: bool = True, +) -> torch.Tensor: + n = all_atom_mask.shape[-2] + dmat_true = torch.sqrt( + eps + + torch.sum( + ( + all_atom_positions[..., None, :] + - all_atom_positions[..., None, :, :] + ) + ** 2, + dim=-1, + ) + ) + + dmat_pred = torch.sqrt( + eps + + torch.sum( + ( + all_atom_pred_pos[..., None, :] + - all_atom_pred_pos[..., None, :, :] + ) + ** 2, + dim=-1, + ) + ) + dists_to_score = ( + (dmat_true < cutoff) + * all_atom_mask + * 
permute_final_dims(all_atom_mask, (1, 0)) + * (1.0 - torch.eye(n, device=all_atom_mask.device)) + ) + + dist_l1 = torch.abs(dmat_true - dmat_pred) + + score = ( + (dist_l1 < 0.5).type(dist_l1.dtype) + + (dist_l1 < 1.0).type(dist_l1.dtype) + + (dist_l1 < 2.0).type(dist_l1.dtype) + + (dist_l1 < 4.0).type(dist_l1.dtype) + ) + score = score * 0.25 + + dims = (-1,) if per_residue else (-2, -1) + norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims)) + score = norm * (eps + torch.sum(dists_to_score * score, dim=dims)) + + return score + + +def lddt_ca( + all_atom_pred_pos: torch.Tensor, + all_atom_positions: torch.Tensor, + all_atom_mask: torch.Tensor, + cutoff: float = 15.0, + eps: float = 1e-10, + per_residue: bool = True, +) -> torch.Tensor: + ca_pos = residue_constants.atom_order["CA"] + all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :] + all_atom_positions = all_atom_positions[..., ca_pos, :] + all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)] # keep dim + + return lddt( + all_atom_pred_pos, + all_atom_positions, + all_atom_mask, + cutoff=cutoff, + eps=eps, + per_residue=per_residue, + ) + + +def lddt_loss( + logits: torch.Tensor, + all_atom_pred_pos: torch.Tensor, + atom37_gt_positions: torch.Tensor, + atom37_atom_exists_in_gt: torch.Tensor, + resolution: torch.Tensor, + cutoff: float = 15.0, + no_bins: int = 50, + min_resolution: float = 0.1, + max_resolution: float = 3.0, + eps: float = 1e-10, + **kwargs, +) -> torch.Tensor: + # remove ligand + logits = logits[:, :atom37_atom_exists_in_gt.shape[1], :] + + ca_pos = residue_constants.atom_order["CA"] + all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :] + atom37_gt_positions = atom37_gt_positions[..., ca_pos, :] + atom37_atom_exists_in_gt = atom37_atom_exists_in_gt[..., ca_pos: (ca_pos + 1)] # keep dim + + score = lddt( + all_atom_pred_pos, + atom37_gt_positions, + atom37_atom_exists_in_gt, + cutoff=cutoff, + eps=eps + ) + + # TODO: Remove after initial pipeline testing + score = torch.nan_to_num(score, nan=torch.nanmean(score)) + score[score < 0] = 0 + + score = score.detach() + bin_index = torch.floor(score * no_bins).long() + bin_index = torch.clamp(bin_index, max=(no_bins - 1)) + lddt_ca_one_hot = torch.nn.functional.one_hot( + bin_index, num_classes=no_bins + ) + + errors = softmax_cross_entropy(logits, lddt_ca_one_hot) + atom37_atom_exists_in_gt = atom37_atom_exists_in_gt.squeeze(-1) + loss = torch.sum(errors * atom37_atom_exists_in_gt, dim=-1) / ( + eps + torch.sum(atom37_atom_exists_in_gt, dim=-1) + ) + + loss = loss * ( + (resolution >= min_resolution) & (resolution <= max_resolution) + ) + + # Average over the batch dimension + loss = torch.mean(loss) + + return loss + + +def distogram_loss( + logits, + gt_pseudo_beta_with_lig, + gt_pseudo_beta_with_lig_mask, + min_bin=2.3125, + max_bin=21.6875, + no_bins=64, + eps=1e-6, + **kwargs, +): + boundaries = torch.linspace( + min_bin, + max_bin, + no_bins - 1, + device=logits.device, + ) + boundaries = boundaries ** 2 + + dists = torch.sum( + (gt_pseudo_beta_with_lig[..., None, :] - gt_pseudo_beta_with_lig[..., None, :, :]) ** 2, + dim=-1, + keepdims=True, + ) + + true_bins = torch.sum(dists > boundaries, dim=-1) + errors = softmax_cross_entropy( + logits, + torch.nn.functional.one_hot(true_bins, no_bins), + ) + + square_mask = gt_pseudo_beta_with_lig_mask[..., None] * gt_pseudo_beta_with_lig_mask[..., None, :] + + # FP16-friendly sum. 
Equivalent to: + # mean = (torch.sum(errors * square_mask, dim=(-1, -2)) / + # (eps + torch.sum(square_mask, dim=(-1, -2)))) + denom = eps + torch.sum(square_mask, dim=(-1, -2)) + mean = errors * square_mask + mean = torch.sum(mean, dim=-1) + mean = mean / denom[..., None] + mean = torch.sum(mean, dim=-1) + + # Average over the batch dimensions + mean = torch.mean(mean) + + return mean + + +def inter_contact_loss( + logits: torch.Tensor, + gt_inter_contacts: torch.Tensor, + inter_pair_mask: torch.Tensor, + pos_class_weight: float = 200.0, + contact_distance: float = 5.0, + **kwargs, +): + logits = logits.squeeze(-1) + bce_loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, gt_inter_contacts, reduction='none', + pos_weight=logits.new_tensor([pos_class_weight])) + masked_loss = bce_loss * inter_pair_mask + final_loss = masked_loss.sum() / inter_pair_mask.sum() + + return final_loss + + +def affinity_loss( + logits, + affinity, + affinity_loss_factor, + min_bin=0, + max_bin=15, + no_bins=32, + **kwargs, +): + boundaries = torch.linspace( + min_bin, + max_bin, + no_bins - 1, + device=logits.device, + ) + + true_bins = torch.sum(affinity > boundaries, dim=-1) + errors = softmax_cross_entropy( + logits, + torch.nn.functional.one_hot(true_bins, no_bins), + ) + + # print("errors dim", errors.shape, affinity_loss_factor.shape, errors) + after_factor = errors * affinity_loss_factor.squeeze() + if affinity_loss_factor.sum() > 0.1: + mean_val = after_factor.sum() / affinity_loss_factor.sum() + else: + # If no affinity in batch - get a very small loss. the factor also makes the loss small + mean_val = after_factor.sum() * 1e-3 + # print("after factor", after_factor.shape, after_factor, affinity_loss_factor.sum(), mean_val) + return mean_val + + +def positions_inter_distogram_loss( + out, + aatype: torch.Tensor, + inter_pair_mask: torch.Tensor, + gt_pseudo_beta_with_lig: torch.Tensor, + max_dist=20., + length_scale=10., + eps: float = 1e-10, + **kwargs, +): + + predicted_atoms = pseudo_beta_fn(aatype, out["final_atom_positions"], None) + pred_dists = torch.sum( + (predicted_atoms[..., None, :] - predicted_atoms[..., None, :, :]) ** 2, + dim=-1, + keepdims=True, + ) + + gt_dists = torch.sum( + (gt_pseudo_beta_with_lig[..., None, :] - gt_pseudo_beta_with_lig[..., None, :, :]) ** 2, + dim=-1, + keepdims=True, + ) + + pred_dists = pred_dists.clamp(max=max_dist ** 2) + gt_dists = gt_dists.clamp(max=max_dist ** 2) + + dists_diff = torch.abs(pred_dists - gt_dists) / (length_scale ** 2) + dists_diff = dists_diff * inter_pair_mask.unsqueeze(-1) + + dists_diff_sum_per_batch = torch.sum(torch.sqrt(eps + dists_diff), dim=(-1, -2, -3)) + mask_size_per_batch = torch.sum(inter_pair_mask, dim=(-1, -2)) + inter_loss = torch.mean(dists_diff_sum_per_batch / (eps + mask_size_per_batch)) + + return inter_loss + + +def positions_intra_ligand_distogram_loss( + out, + aatype: torch.Tensor, + ligand_mask: torch.Tensor, + gt_pseudo_beta_with_lig: torch.Tensor, + max_dist=20., + length_scale=4., # similar to RosettaFoldAA + eps=1e-10, + **kwargs, +): + intra_ligand_pair_mask = ligand_mask[..., None] * ligand_mask[..., None, :] + predicted_atoms = pseudo_beta_fn(aatype, out["final_atom_positions"], None) + pred_dists = torch.sum( + (predicted_atoms[..., None, :] - predicted_atoms[..., None, :, :]) ** 2, + dim=-1, + keepdims=True, + ) + + gt_dists = torch.sum( + (gt_pseudo_beta_with_lig[..., None, :] - gt_pseudo_beta_with_lig[..., None, :, :]) ** 2, + dim=-1, + keepdims=True, + ) + + pred_dists = torch.sqrt(eps + 
pred_dists.clamp(max=max_dist ** 2)) / length_scale + gt_dists = torch.sqrt(eps + gt_dists.clamp(max=max_dist ** 2)) / length_scale + + # Apply L2 loss + dists_diff = (pred_dists - gt_dists) ** 2 + + dists_diff = dists_diff * intra_ligand_pair_mask.unsqueeze(-1) + + dists_diff_sum_per_batch = torch.sum(dists_diff, dim=(-1, -2, -3)) + mask_size_per_batch = torch.sum(intra_ligand_pair_mask, dim=(-1, -2)) + intra_ligand_loss = torch.mean(dists_diff_sum_per_batch / (eps + mask_size_per_batch)) + + return intra_ligand_loss + + +def _calculate_bin_centers(boundaries: torch.Tensor): + step = boundaries[1] - boundaries[0] + bin_centers = boundaries + step / 2 + bin_centers = torch.cat( + [bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0 + ) + return bin_centers + + +def _calculate_expected_aligned_error( + alignment_confidence_breaks: torch.Tensor, + aligned_distance_error_probs: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + bin_centers = _calculate_bin_centers(alignment_confidence_breaks) + return ( + torch.sum(aligned_distance_error_probs * bin_centers, dim=-1), + bin_centers[-1], + ) + + +def compute_predicted_aligned_error( + logits: torch.Tensor, + max_bin: int = 31, + no_bins: int = 64, + **kwargs, +) -> Dict[str, torch.Tensor]: + """Computes aligned confidence metrics from logits. + + Args: + logits: [*, num_res, num_res, num_bins] the logits output from + PredictedAlignedErrorHead. + max_bin: Maximum bin value + no_bins: Number of bins + Returns: + aligned_confidence_probs: [*, num_res, num_res, num_bins] the predicted + aligned error probabilities over bins for each residue pair. + predicted_aligned_error: [*, num_res, num_res] the expected aligned distance + error for each pair of residues. + max_predicted_aligned_error: [*] the maximum predicted error possible. 
+ """ + boundaries = torch.linspace( + 0, max_bin, steps=(no_bins - 1), device=logits.device + ) + + aligned_confidence_probs = torch.nn.functional.softmax(logits, dim=-1) + ( + predicted_aligned_error, + max_predicted_aligned_error, + ) = _calculate_expected_aligned_error( + alignment_confidence_breaks=boundaries, + aligned_distance_error_probs=aligned_confidence_probs, + ) + + return { + "aligned_confidence_probs": aligned_confidence_probs, + "predicted_aligned_error": predicted_aligned_error, + "max_predicted_aligned_error": max_predicted_aligned_error, + } + + +def compute_tm( + logits: torch.Tensor, + residue_weights: Optional[torch.Tensor] = None, + asym_id: Optional[torch.Tensor] = None, + interface: bool = False, + max_bin: int = 31, + no_bins: int = 64, + eps: float = 1e-8, + **kwargs, +) -> torch.Tensor: + if residue_weights is None: + residue_weights = logits.new_ones(logits.shape[-2]) + + boundaries = torch.linspace( + 0, max_bin, steps=(no_bins - 1), device=logits.device + ) + + bin_centers = _calculate_bin_centers(boundaries) + clipped_n = max(torch.sum(residue_weights), 19) + + d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8 + + probs = torch.nn.functional.softmax(logits, dim=-1) + + tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2)) + predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1) + + n = residue_weights.shape[-1] + pair_mask = residue_weights.new_ones((n, n), dtype=torch.int32) + if interface and (asym_id is not None): + if len(asym_id.shape) > 1: + assert len(asym_id.shape) <= 2 + batch_size = asym_id.shape[0] + pair_mask = residue_weights.new_ones((batch_size, n, n), dtype=torch.int32) + pair_mask *= (asym_id[..., None] != asym_id[..., None, :]).to(dtype=pair_mask.dtype) + + predicted_tm_term *= pair_mask + + pair_residue_weights = pair_mask * ( + residue_weights[..., None, :] * residue_weights[..., :, None] + ) + denom = eps + torch.sum(pair_residue_weights, dim=-1, keepdims=True) + normed_residue_mask = pair_residue_weights / denom + per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1) + + weighted = per_alignment * residue_weights + + argmax = (weighted == torch.max(weighted)).nonzero()[0] + return per_alignment[tuple(argmax)] + + +def compute_renamed_ground_truth( + batch: Dict[str, torch.Tensor], + atom14_pred_positions: torch.Tensor, + eps=1e-10, +) -> Dict[str, torch.Tensor]: + """ + Find optimal renaming of ground truth based on the predicted positions. + + Alg. 26 "renameSymmetricGroundTruthAtoms" + + This renamed ground truth is then used for all losses, + such that each loss moves the atoms in the same direction. + + Args: + batch: Dictionary containing: + * atom14_gt_positions: Ground truth positions. + * atom14_alt_gt_positions: Ground truth positions with renaming swaps. + * atom14_atom_is_ambiguous: 1.0 for atoms that are affected by + renaming swaps. + * atom14_gt_exists: Mask for which atoms exist in ground truth. + * atom14_alt_gt_exists: Mask for which atoms exist in ground truth + after renaming. + * atom14_atom_exists: Mask for whether each atom is part of the given + amino acid type. + atom14_pred_positions: Array of atom positions in global frame with shape + Returns: + Dictionary containing: + alt_naming_is_better: Array with 1.0 where alternative swap is better. + renamed_atom14_gt_positions: Array of optimal ground truth positions + after renaming swaps are performed. + renamed_atom14_gt_exists: Mask after renaming swap is performed. 
+ """ + + pred_dists = torch.sqrt( + eps + + torch.sum( + ( + atom14_pred_positions[..., None, :, None, :] + - atom14_pred_positions[..., None, :, None, :, :] + ) + ** 2, + dim=-1, + ) + ) + + atom14_gt_positions = batch["atom14_gt_positions"] + gt_dists = torch.sqrt( + eps + + torch.sum( + ( + atom14_gt_positions[..., None, :, None, :] + - atom14_gt_positions[..., None, :, None, :, :] + ) + ** 2, + dim=-1, + ) + ) + + atom14_alt_gt_positions = batch["atom14_alt_gt_positions"] + alt_gt_dists = torch.sqrt( + eps + + torch.sum( + ( + atom14_alt_gt_positions[..., None, :, None, :] + - atom14_alt_gt_positions[..., None, :, None, :, :] + ) + ** 2, + dim=-1, + ) + ) + + lddt = torch.sqrt(eps + (pred_dists - gt_dists) ** 2) + alt_lddt = torch.sqrt(eps + (pred_dists - alt_gt_dists) ** 2) + + atom14_gt_exists = batch["atom14_atom_exists_in_gt"] + atom14_atom_is_ambiguous = batch["atom14_atom_is_ambiguous"] + mask = ( + atom14_gt_exists[..., None, :, None] + * atom14_atom_is_ambiguous[..., None, :, None] + * atom14_gt_exists[..., None, :, None, :] + * (1.0 - atom14_atom_is_ambiguous[..., None, :, None, :]) + ) + + per_res_lddt = torch.sum(mask * lddt, dim=(-1, -2, -3)) + alt_per_res_lddt = torch.sum(mask * alt_lddt, dim=(-1, -2, -3)) + + fp_type = atom14_pred_positions.dtype + alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).type(fp_type) + + renamed_atom14_gt_positions = ( + 1.0 - alt_naming_is_better[..., None, None] + ) * atom14_gt_positions + alt_naming_is_better[ + ..., None, None + ] * atom14_alt_gt_positions + + renamed_atom14_gt_mask = ( + 1.0 - alt_naming_is_better[..., None] + ) * atom14_gt_exists + alt_naming_is_better[..., None] * batch[ + "atom14_alt_gt_exists" + ] + + return { + "alt_naming_is_better": alt_naming_is_better, + "renamed_atom14_gt_positions": renamed_atom14_gt_positions, + "renamed_atom14_gt_exists": renamed_atom14_gt_mask, + } + + +def binding_site_loss( + logits: torch.Tensor, + binding_site_mask: torch.Tensor, + protein_mask: torch.Tensor, + pos_class_weight: float, + **kwargs, +) -> torch.Tensor: + logits = logits.squeeze(-1) + bce_loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, binding_site_mask, reduction='none', + pos_weight=logits.new_tensor([pos_class_weight])) + masked_loss = bce_loss * protein_mask + final_loss = masked_loss.sum() / protein_mask.sum() + + return final_loss + + +def chain_center_of_mass_loss( + all_atom_pred_pos: torch.Tensor, + all_atom_positions: torch.Tensor, + all_atom_mask: torch.Tensor, + asym_id: torch.Tensor, + clamp_distance: float = -4.0, + weight: float = 0.05, + eps: float = 1e-10, **kwargs +) -> torch.Tensor: + """ + Computes chain centre-of-mass loss. Implements section 2.5, eqn 1 in the Multimer paper. 
+ + Args: + all_atom_pred_pos: + [*, N_pts, 37, 3] All-atom predicted atom positions + all_atom_positions: + [*, N_pts, 37, 3] Ground truth all-atom positions + all_atom_mask: + [*, N_pts, 37] All-atom positions mask + asym_id: + [*, N_pts] Chain asym IDs + clamp_distance: + Cutoff above which distance errors are disregarded + weight: + Weight for loss + eps: + Small value used to regularize denominators + Returns: + [*] loss tensor + """ + ca_pos = residue_constants.atom_order["CA"] + all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :] + all_atom_positions = all_atom_positions[..., ca_pos, :] + all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)] # keep dim + + one_hot = torch.nn.functional.one_hot(asym_id.long()).to(dtype=all_atom_mask.dtype) + one_hot = one_hot * all_atom_mask + chain_pos_mask = one_hot.transpose(-2, -1) + chain_exists = torch.any(chain_pos_mask, dim=-1).to(dtype=all_atom_positions.dtype) + + def get_chain_center_of_mass(pos): + center_sum = (chain_pos_mask[..., None] * pos[..., None, :, :]).sum(dim=-2) + centers = center_sum / (torch.sum(chain_pos_mask, dim=-1, keepdim=True) + eps) + return Vec3Array.from_array(centers) + + pred_centers = get_chain_center_of_mass(all_atom_pred_pos) # [B, NC, 3] + true_centers = get_chain_center_of_mass(all_atom_positions) # [B, NC, 3] + + pred_dists = euclidean_distance(pred_centers[..., None, :], pred_centers[..., :, None], epsilon=eps) + true_dists = euclidean_distance(true_centers[..., None, :], true_centers[..., :, None], epsilon=eps) + losses = torch.clamp((weight * (pred_dists - true_dists - clamp_distance)), max=0) ** 2 + loss_mask = chain_exists[..., :, None] * chain_exists[..., None, :] + + loss = masked_mean(loss_mask, losses, dim=(-1, -2)) + return loss + + +class AlphaFoldLoss(nn.Module): + """Aggregation of the various losses described in the supplement""" + + def __init__(self, config): + super(AlphaFoldLoss, self).__init__() + self.config = config + + def loss(self, out, batch, _return_breakdown=False): + """ + Rename previous forward() as loss() + so that can be reused in the subclass + """ + if "renamed_atom14_gt_positions" not in out.keys(): + batch.update( + compute_renamed_ground_truth( + batch, + out["sm"]["positions"][-1], + ) + ) + + loss_fns = { + "distogram": lambda: distogram_loss( + logits=out["distogram_logits"], + **{**batch, **self.config.distogram}, + ), + "positions_inter_distogram": lambda: positions_inter_distogram_loss( + out, + **{**batch, **self.config.positions_inter_distogram}, + ), + "positions_intra_distogram": lambda: positions_intra_ligand_distogram_loss( + out, + **{**batch, **self.config.positions_intra_distogram}, + ), + + "affinity1d": lambda: affinity_loss( + logits=out["affinity_1d_logits"], + **{**batch, **self.config.affinity1d}, + ), + "affinity2d": lambda: affinity_loss( + logits=out["affinity_2d_logits"], + **{**batch, **self.config.affinity2d}, + ), + "affinity_cls": lambda: affinity_loss( + logits=out["affinity_cls_logits"], + **{**batch, **self.config.affinity_cls}, + ), + "binding_site": lambda: binding_site_loss( + logits=out["binding_site_logits"], + **{**batch, **self.config.binding_site}, + ), + "inter_contact": lambda: inter_contact_loss( + logits=out["inter_contact_logits"], + **{**batch, **self.config.inter_contact}, + ), + # backbone is based on frames so only works on protein + "fape_backbone": lambda: fape_bb( + out, + batch, + self.config.fape_backbone, + ), + "fape_sidechain": lambda: fape_sidechain( + out, + batch, + self.config.fape_sidechain, + ), + 
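+            # same sidechain FAPE as above, but sidechain_loss is restricted
+            # to ligand atoms via only_include_ligand_atoms=True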
"fape_interface": lambda: fape_interface( + out, + batch, + self.config.fape_interface, + ), + "plddt_loss": lambda: lddt_loss( + logits=out["lddt_logits"], + all_atom_pred_pos=out["final_atom_positions"], + **{**batch, **self.config.plddt_loss}, + ), + "supervised_chi": lambda: supervised_chi_loss( + out["sm"]["angles"], + out["sm"]["unnormalized_angles"], + **{**batch, **self.config.supervised_chi}, + ), + } + + if self.config.chain_center_of_mass.enabled: + loss_fns["chain_center_of_mass"] = lambda: chain_center_of_mass_loss( + all_atom_pred_pos=out["final_atom_positions"], + **{**batch, **self.config.chain_center_of_mass}, + ) + + cum_loss = 0. + losses = {} + loss_time_took = {} + for loss_name, loss_fn in loss_fns.items(): + start_time = time.time() + weight = self.config[loss_name].weight + loss = loss_fn() + if torch.isnan(loss) or torch.isinf(loss): + # for k,v in batch.items(): + # if torch.any(torch.isnan(v)) or torch.any(torch.isinf(v)): + # logging.warning(f"{k}: is nan") + # logging.warning(f"{loss_name}: {loss}") + logging.warning(f"{loss_name} loss is NaN. Skipping...") + loss = loss.new_tensor(0., requires_grad=True) + # else: + cum_loss = cum_loss + weight * loss + losses[loss_name] = loss.detach().clone() + loss_time_took[loss_name] = time.time() - start_time + losses["unscaled_loss"] = cum_loss.detach().clone() + # print("loss took: ", round(time.time() % 10000, 3), + # sorted(loss_time_took.items(), key=lambda x: x[1], reverse=True)) + + # Scale the loss by the square root of the minimum of the crop size and + # the (average) sequence length. See subsection 1.9. + seq_len = torch.mean(batch["seq_length"].float()) + crop_len = batch["aatype"].shape[-1] + cum_loss = cum_loss * torch.sqrt(min(seq_len, crop_len)) + + losses["loss"] = cum_loss.detach().clone() + + if not _return_breakdown: + return cum_loss + + return cum_loss, losses + + def forward(self, out, batch, _return_breakdown=False): + if not _return_breakdown: + cum_loss = self.loss(out, batch, _return_breakdown) + return cum_loss + else: + cum_loss, losses = self.loss(out, batch, _return_breakdown) + return cum_loss, losses diff --git a/dockformer/utils/lr_schedulers.py b/dockformer/utils/lr_schedulers.py new file mode 100644 index 0000000000000000000000000000000000000000..b4107ecabddd406a25b21f94bf7d0447150e5525 --- /dev/null +++ b/dockformer/utils/lr_schedulers.py @@ -0,0 +1,82 @@ +import torch + + +class AlphaFoldLRScheduler(torch.optim.lr_scheduler._LRScheduler): + """ Implements the learning rate schedule defined in the AlphaFold 2 + supplement. A linear warmup is followed by a plateau at the maximum + learning rate and then exponential decay. + + Note that the initial learning rate of the optimizer in question is + ignored; use this class' base_lr parameter to specify the starting + point of the warmup. 
+ """ + def __init__(self, + optimizer, + last_epoch: int = -1, + verbose: bool = False, + base_lr: float = 0., + max_lr: float = 0.001, + warmup_no_steps: int = 1000, + start_decay_after_n_steps: int = 50000, + decay_every_n_steps: int = 50000, + decay_factor: float = 0.95, + ): + step_counts = { + "warmup_no_steps": warmup_no_steps, + "start_decay_after_n_steps": start_decay_after_n_steps, + } + + for k,v in step_counts.items(): + if(v < 0): + raise ValueError(f"{k} must be nonnegative") + + if(warmup_no_steps > start_decay_after_n_steps): + raise ValueError( + "warmup_no_steps must not exceed start_decay_after_n_steps" + ) + + self.optimizer = optimizer + self.last_epoch = last_epoch + self.verbose = verbose + self.base_lr = base_lr + self.max_lr = max_lr + self.warmup_no_steps = warmup_no_steps + self.start_decay_after_n_steps = start_decay_after_n_steps + self.decay_every_n_steps = decay_every_n_steps + self.decay_factor = decay_factor + + super(AlphaFoldLRScheduler, self).__init__( + optimizer, + last_epoch=last_epoch, + verbose=verbose, + ) + + def state_dict(self): + state_dict = { + k:v for k,v in self.__dict__.items() if k not in ["optimizer"] + } + + return state_dict + + def load_state_dict(self, state_dict): + self.__dict__.update(state_dict) + + def get_lr(self): + if(not self._get_lr_called_within_step): + raise RuntimeError( + "To get the last learning rate computed by the scheduler, use " + "get_last_lr()" + ) + + step_no = self.last_epoch + + if(step_no <= self.warmup_no_steps): + lr = self.base_lr + (step_no / self.warmup_no_steps) * self.max_lr + elif(step_no > self.start_decay_after_n_steps): + steps_since_decay = step_no - self.start_decay_after_n_steps + exp = (steps_since_decay // self.decay_every_n_steps) + 1 + lr = self.max_lr * (self.decay_factor ** exp) + else: # plateau + lr = self.max_lr + + return [lr for group in self.optimizer.param_groups] diff --git a/dockformer/utils/precision_utils.py b/dockformer/utils/precision_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..81ef568637a1f6ee30a8fc295edbb60a70f4c602 --- /dev/null +++ b/dockformer/utils/precision_utils.py @@ -0,0 +1,23 @@ +# Copyright 2022 AlQuraishi Laboratory +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib + +import torch + +def is_fp16_enabled(): + # Autocast world + fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16 + fp16_enabled = fp16_enabled and torch.is_autocast_enabled() + + return fp16_enabled diff --git a/dockformer/utils/protein.py b/dockformer/utils/protein.py new file mode 100644 index 0000000000000000000000000000000000000000..56a70863007ac2b0854ead32e24751e92996594e --- /dev/null +++ b/dockformer/utils/protein.py @@ -0,0 +1,638 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Protein data type.""" +import dataclasses +import io +from typing import Any, Sequence, Mapping, Optional +import re +import string + +from dockformer.utils import residue_constants +from Bio.PDB import PDBParser +import numpy as np +import modelcif +import modelcif.model +import modelcif.dumper +import modelcif.reference +import modelcif.protocol +import modelcif.alignment +import modelcif.qa_metric + + +FeatureDict = Mapping[str, np.ndarray] +PICO_TO_ANGSTROM = 0.01 + +PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" +PDB_MAX_CHAINS = len(PDB_CHAIN_IDS) +assert(PDB_MAX_CHAINS == 62) + + +@dataclasses.dataclass(frozen=True) +class Protein: + """Protein structure representation.""" + + # Cartesian coordinates of atoms in angstroms. The atom types correspond to + # residue_constants.atom_types, i.e. the first three are N, CA, CB. + atom_positions: np.ndarray # [num_res, num_atom_type, 3] + + # Amino-acid type for each residue represented as an integer between 0 and + # 20, where 20 is 'X'. + aatype: np.ndarray # [num_res] + + # Binary float mask to indicate presence of a particular atom. 1.0 if an atom + # is present and 0.0 if not. This should be used for loss masking. + atom_mask: np.ndarray # [num_res, num_atom_type] + + # Residue index as used in PDB. It is not necessarily continuous or 0-indexed. + residue_index: np.ndarray # [num_res] + + # B-factors, or temperature factors, of each residue (in sq. angstroms units), + # representing the displacement of the residue from its ground truth mean + # value. + b_factors: np.ndarray # [num_res, num_atom_type] + + # Chain indices for multi-chain predictions + chain_index: Optional[np.ndarray] = None + + # Optional remark about the protein. Included as a comment in output PDB + # files + remark: Optional[str] = None + + # Templates used to generate this protein (prediction-only) + parents: Optional[Sequence[str]] = None + + # Chain corresponding to each parent + parents_chain_index: Optional[Sequence[int]] = None + + def __post_init__(self): + if(len(np.unique(self.chain_index)) > PDB_MAX_CHAINS): + raise ValueError( + f"Cannot build an instance with more than {PDB_MAX_CHAINS} " + "chains because these cannot be written to PDB format" + ) + + +def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein: + """Takes a PDB string and constructs a Protein object. + + WARNING: All non-standard residue types will be converted into UNK. All + non-standard atoms will be ignored. + + Args: + pdb_str: The contents of the pdb file + chain_id: If None, then the whole pdb file is parsed. If chain_id is specified (e.g. A), then only that chain + is parsed. + + Returns: + A new `Protein` parsed from the pdb contents. + """ + pdb_fh = io.StringIO(pdb_str) + parser = PDBParser(QUIET=True) + structure = parser.get_structure("none", pdb_fh) + models = list(structure.get_models()) + if len(models) != 1: + raise ValueError( + f"Only single model PDBs are supported. Found {len(models)} models." 
+ ) + model = models[0] + + atom_positions = [] + aatype = [] + atom_mask = [] + residue_index = [] + chain_ids = [] + b_factors = [] + + for chain in model: + if(chain_id is not None and chain.id != chain_id): + continue + + for res in chain: + if res.id[2] != " ": + raise ValueError( + f"PDB contains an insertion code at chain {chain.id} and residue " + f"index {res.id[1]}. These are not supported." + ) + + res_shortname = residue_constants.restype_3to1.get(res.resname, "X") + if res_shortname not in residue_constants.restypes: + print("Unknown residue type, skipping", res.resname) + continue + + restype_idx = residue_constants.restype_order.get( + res_shortname, residue_constants.restype_num + ) + pos = np.zeros((residue_constants.atom_type_num, 3)) + mask = np.zeros((residue_constants.atom_type_num,)) + res_b_factors = np.zeros((residue_constants.atom_type_num,)) + for atom in res: + if atom.name not in residue_constants.atom_types: + continue + pos[residue_constants.atom_order[atom.name]] = atom.coord + mask[residue_constants.atom_order[atom.name]] = 1.0 + res_b_factors[ + residue_constants.atom_order[atom.name] + ] = atom.bfactor + if np.sum(mask) < 0.5: + # If no known atom positions are reported for the residue then skip it. + continue + + aatype.append(restype_idx) + atom_positions.append(pos) + atom_mask.append(mask) + residue_index.append(res.id[1]) + chain_ids.append(chain.id) + b_factors.append(res_b_factors) + + parents = None + parents_chain_index = None + if("PARENT" in pdb_str): + parents = [] + parents_chain_index = [] + chain_id = 0 + for l in pdb_str.split("\n"): + if("PARENT" in l): + if(not "N/A" in l): + parent_names = l.split()[1:] + parents.extend(parent_names) + parents_chain_index.extend([ + chain_id for _ in parent_names + ]) + chain_id += 1 + + unique_chain_ids = np.unique(chain_ids) + chain_id_mapping = {cid: n for n, cid in enumerate(string.ascii_uppercase + string.digits + string.ascii_lowercase)} + chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids]) + + return Protein( + atom_positions=np.array(atom_positions), + atom_mask=np.array(atom_mask), + aatype=np.array(aatype), + residue_index=np.array(residue_index), + chain_index=chain_index, + b_factors=np.array(b_factors), + parents=parents, + parents_chain_index=parents_chain_index, + ) + + +def from_proteinnet_string(proteinnet_str: str) -> Protein: + tag_re = r'(\[[A-Z]+\]\n)' + tags = [ + tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0 + ] + groups = zip(tags[0::2], [l.split('\n') for l in tags[1::2]]) + + atoms = ['N', 'CA', 'C'] + aatype = None + atom_positions = None + atom_mask = None + for g in groups: + if("[PRIMARY]" == g[0]): + seq = g[1][0].strip() + for i in range(len(seq)): + if(seq[i] not in residue_constants.restypes): + seq[i] = 'X' + aatype = np.array([ + residue_constants.restype_order.get( + res_symbol, residue_constants.restype_num + ) for res_symbol in seq + ]) + elif("[TERTIARY]" == g[0]): + tertiary = [] + for axis in range(3): + tertiary.append(list(map(float, g[1][axis].split()))) + tertiary_np = np.array(tertiary) + atom_positions = np.zeros( + (len(tertiary[0])//3, residue_constants.atom_type_num, 3) + ).astype(np.float32) + for i, atom in enumerate(atoms): + atom_positions[:, residue_constants.atom_order[atom], :] = ( + np.transpose(tertiary_np[:, i::3]) + ) + atom_positions *= PICO_TO_ANGSTROM + elif("[MASK]" == g[0]): + mask = np.array(list(map({'-': 0, '+': 1}.get, g[1][0].strip()))) + atom_mask = np.zeros( + (len(mask), 
residue_constants.atom_type_num,) + ).astype(np.float32) + for i, atom in enumerate(atoms): + atom_mask[:, residue_constants.atom_order[atom]] = 1 + atom_mask *= mask[..., None] + + return Protein( + atom_positions=atom_positions, + atom_mask=atom_mask, + aatype=aatype, + residue_index=np.arange(len(aatype)), + b_factors=None, + ) + + +def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str: + chain_end = 'TER' + return( + f'{chain_end:<6}{atom_index:>5} {end_resname:>3} ' + f'{chain_name:>1}{residue_index:>4}' + ) + + +def get_pdb_headers(prot: Protein, chain_id: int = 0) -> Sequence[str]: + pdb_headers = [] + + remark = prot.remark + if(remark is not None): + pdb_headers.append(f"REMARK {remark}") + + parents = prot.parents + parents_chain_index = prot.parents_chain_index + if(parents_chain_index is not None): + parents = [ + p for i, p in zip(parents_chain_index, parents) if i == chain_id + ] + + if(parents is None or len(parents) == 0): + parents = ["N/A"] + + pdb_headers.append(f"PARENT {' '.join(parents)}") + + return pdb_headers + + +def add_pdb_headers(prot: Protein, pdb_str: str) -> str: + """ Add pdb headers to an existing PDB string. Useful during multi-chain + recycling + """ + out_pdb_lines = [] + lines = pdb_str.split('\n') + + remark = prot.remark + if(remark is not None): + out_pdb_lines.append(f"REMARK {remark}") + + parents_per_chain = None + if(prot.parents is not None and len(prot.parents) > 0): + parents_per_chain = [] + if(prot.parents_chain_index is not None): + cur_chain = prot.parents_chain_index[0] + parent_dict = {} + for p, i in zip(prot.parents, prot.parents_chain_index): + parent_dict.setdefault(str(i), []) + parent_dict[str(i)].append(p) + + max_idx = max([int(chain_idx) for chain_idx in parent_dict]) + for i in range(max_idx + 1): + chain_parents = parent_dict.get(str(i), ["N/A"]) + parents_per_chain.append(chain_parents) + else: + parents_per_chain.append(prot.parents) + else: + parents_per_chain = [["N/A"]] + + make_parent_line = lambda p: f"PARENT {' '.join(p)}" + + out_pdb_lines.append(make_parent_line(parents_per_chain[0])) + + chain_counter = 0 + for i, l in enumerate(lines): + if("PARENT" not in l and "REMARK" not in l): + out_pdb_lines.append(l) + if("TER" in l and not "END" in lines[i + 1]): + chain_counter += 1 + if(not chain_counter >= len(parents_per_chain)): + chain_parents = parents_per_chain[chain_counter] + else: + chain_parents = ["N/A"] + + out_pdb_lines.append(make_parent_line(chain_parents)) + + return '\n'.join(out_pdb_lines) + + +def to_pdb(prot: Protein) -> str: + """Converts a `Protein` instance to a PDB string. + + Args: + prot: The protein to convert to PDB. + + Returns: + PDB string. + """ + restypes = residue_constants.restypes + ["X"] + res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], "UNK") + atom_types = residue_constants.atom_types + + pdb_lines = [] + + atom_mask = prot.atom_mask + aatype = prot.aatype + atom_positions = prot.atom_positions + residue_index = prot.residue_index.astype(np.int32) + b_factors = prot.b_factors + chain_index = prot.chain_index.astype(np.int32) + + if np.any(aatype > residue_constants.restype_num): + raise ValueError("Invalid aatypes.") + + # Construct a mapping from chain integer indices to chain ID strings. + chain_ids = {} + for i in np.unique(chain_index): # np.unique gives sorted output. + if i >= PDB_MAX_CHAINS: + raise ValueError( + f"The PDB format supports at most {PDB_MAX_CHAINS} chains." 
+ ) + chain_ids[i] = PDB_CHAIN_IDS[i] + + headers = get_pdb_headers(prot) + if (len(headers) > 0): + pdb_lines.extend(headers) + + pdb_lines.append("MODEL 1") + n = aatype.shape[0] + atom_index = 1 + last_chain_index = chain_index[0] + prev_chain_index = 0 + chain_tags = string.ascii_uppercase + + # Add all atom sites. + for i in range(aatype.shape[0]): + # Close the previous chain if in a multichain PDB. + if last_chain_index != chain_index[i]: + pdb_lines.append( + _chain_end( + atom_index, + res_1to3(aatype[i - 1]), + chain_ids[chain_index[i - 1]], + residue_index[i - 1] + ) + ) + last_chain_index = chain_index[i] + atom_index += 1 # Atom index increases at the TER symbol. + + res_name_3 = res_1to3(aatype[i]) + for atom_name, pos, mask, b_factor in zip( + atom_types, atom_positions[i], atom_mask[i], b_factors[i] + ): + if mask < 0.5: + continue + + record_type = "ATOM" + name = atom_name if len(atom_name) == 4 else f" {atom_name}" + alt_loc = "" + insertion_code = "" + occupancy = 1.00 + element = atom_name[ + 0 + ] # Protein supports only C, N, O, S, this works. + charge = "" + + chain_tag = "A" + if(chain_index is not None): + chain_tag = chain_tags[chain_index[i]] + + # PDB is a columnar format, every space matters here! + atom_line = ( + f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}" + #TODO: check this refactor, chose main branch version + #f"{res_name_3:>3} {chain_ids[chain_index[i]]:>1}" + f"{res_name_3:>3} {chain_tag:>1}" + f"{residue_index[i]:>4}{insertion_code:>1} " + f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}" + f"{occupancy:>6.2f}{b_factor:>6.2f} " + f"{element:>2}{charge:>2}" + ) + pdb_lines.append(atom_line) + atom_index += 1 + + should_terminate = (i == n - 1) + if(chain_index is not None): + if(i != n - 1 and chain_index[i + 1] != prev_chain_index): + should_terminate = True + prev_chain_index = chain_index[i + 1] + + if(should_terminate): + # Close the chain. + chain_end = "TER" + chain_termination_line = ( + f"{chain_end:<6}{atom_index:>5} " + f"{res_1to3(aatype[i]):>3} " + f"{chain_tag:>1}{residue_index[i]:>4}" + ) + pdb_lines.append(chain_termination_line) + atom_index += 1 + + if(i != n - 1): + # "prev" is a misnomer here. This happens at the beginning of + # each new chain. + pdb_lines.extend(get_pdb_headers(prot, prev_chain_index)) + + pdb_lines.append("ENDMDL") + pdb_lines.append("END") + + # Pad all lines to 80 characters + pdb_lines = [line.ljust(80) for line in pdb_lines] + return '\n'.join(pdb_lines) + '\n' # Add terminating newline. + + +def to_modelcif(prot: Protein) -> str: + """ + Converts a `Protein` instance to a ModelCIF string. Chains with identical modelled coordinates + will be treated as the same polymer entity. But note that if chains differ in modelled regions, + no attempt is made at identifying them as a single polymer entity. + + Args: + prot: The protein to convert to PDB. + + Returns: + ModelCIF string. 
+ """ + + restypes = residue_constants.restypes + ["X"] + atom_types = residue_constants.atom_types + + atom_mask = prot.atom_mask + aatype = prot.aatype + atom_positions = prot.atom_positions + residue_index = prot.residue_index.astype(np.int32) + b_factors = prot.b_factors + chain_index = prot.chain_index + + n = aatype.shape[0] + if chain_index is None: + chain_index = [0 for i in range(n)] + + system = modelcif.System(title='Prediction') + + # Finding chains and creating entities + seqs = {} + seq = [] + last_chain_idx = None + for i in range(n): + if last_chain_idx is not None and last_chain_idx != chain_index[i]: + seqs[last_chain_idx] = seq + seq = [] + seq.append(restypes[aatype[i]]) + last_chain_idx = chain_index[i] + # finally add the last chain + seqs[last_chain_idx] = seq + + # now reduce sequences to unique ones (note this won't work if different asyms have different unmodelled regions) + unique_seqs = {} + for chain_idx, seq_list in seqs.items(): + seq = "".join(seq_list) + if seq in unique_seqs: + unique_seqs[seq].append(chain_idx) + else: + unique_seqs[seq] = [chain_idx] + + # adding 1 entity per unique sequence + entities_map = {} + for key, value in unique_seqs.items(): + model_e = modelcif.Entity(key, description='Model subunit') + for chain_idx in value: + entities_map[chain_idx] = model_e + + chain_tags = string.ascii_uppercase + asym_unit_map = {} + for chain_idx in set(chain_index): + # Define the model assembly + chain_id = chain_tags[chain_idx] + asym = modelcif.AsymUnit(entities_map[chain_idx], details='Model subunit %s' % chain_id, id=chain_id) + asym_unit_map[chain_idx] = asym + modeled_assembly = modelcif.Assembly(asym_unit_map.values(), name='Modeled assembly') + + class _LocalPLDDT(modelcif.qa_metric.Local, modelcif.qa_metric.PLDDT): + name = "pLDDT" + software = None + description = "Predicted lddt" + + class _GlobalPLDDT(modelcif.qa_metric.Global, modelcif.qa_metric.PLDDT): + name = "pLDDT" + software = None + description = "Global pLDDT, mean of per-residue pLDDTs" + + class _MyModel(modelcif.model.AbInitioModel): + def get_atoms(self): + # Add all atom sites. + for i in range(n): + for atom_name, pos, mask, b_factor in zip( + atom_types, atom_positions[i], atom_mask[i], b_factors[i] + ): + if mask < 0.5: + continue + element = atom_name[0] # Protein supports only C, N, O, S, this works. 
+                    yield modelcif.model.Atom(
+                        asym_unit=asym_unit_map[chain_index[i]], type_symbol=element,
+                        seq_id=residue_index[i], atom_id=atom_name,
+                        x=pos[0], y=pos[1], z=pos[2],
+                        het=False, biso=b_factor, occupancy=1.00)
+
+        def add_scores(self):
+            # local scores
+            plddt_per_residue = {}
+            for i in range(n):
+                for mask, b_factor in zip(atom_mask[i], b_factors[i]):
+                    if mask < 0.5:
+                        continue
+                    # add 1 entry per residue, not 1 per atom
+                    if chain_index[i] not in plddt_per_residue:
+                        # first time a chain index is seen: add the key and start the residue dict
+                        plddt_per_residue[chain_index[i]] = {residue_index[i]: b_factor}
+                    if residue_index[i] not in plddt_per_residue[chain_index[i]]:
+                        plddt_per_residue[chain_index[i]][residue_index[i]] = b_factor
+            plddts = []
+            for chain_idx in plddt_per_residue:
+                for residue_idx in plddt_per_residue[chain_idx]:
+                    plddt = plddt_per_residue[chain_idx][residue_idx]
+                    plddts.append(plddt)
+                    self.qa_metrics.append(
+                        _LocalPLDDT(asym_unit_map[chain_idx].residue(residue_idx), plddt))
+            # global score
+            self.qa_metrics.append(_GlobalPLDDT(np.mean(plddts)))
+
+    # Add the model and modeling protocol to the file and write them out:
+    model = _MyModel(assembly=modeled_assembly, name='Best scoring model')
+    model.add_scores()
+
+    model_group = modelcif.model.ModelGroup([model], name='All models')
+    system.model_groups.append(model_group)
+
+    fh = io.StringIO()
+    modelcif.dumper.write(fh, [system])
+    return fh.getvalue()
+
+
+def ideal_atom_mask(prot: Protein) -> np.ndarray:
+    """Computes an ideal atom mask.
+
+    `Protein.atom_mask` typically is defined according to the atoms that are
+    reported in the PDB. This function computes a mask according to heavy atoms
+    that should be present in the given sequence of amino acids.
+
+    Args:
+      prot: `Protein` whose fields are `numpy.ndarray` objects.
+
+    Returns:
+      An ideal atom mask.
+    """
+    return residue_constants.STANDARD_ATOM_MASK[prot.aatype]
+
+
+def from_prediction(
+    aatype: np.ndarray,
+    residue_index: np.ndarray,
+    chain_index: np.ndarray,
+    atom_positions: np.ndarray,
+    atom_mask: np.ndarray,
+    b_factors: Optional[np.ndarray] = None,
+    remove_leading_feature_dimension: bool = True,
+    remark: Optional[str] = None,
+    parents: Optional[Sequence[str]] = None,
+    parents_chain_index: Optional[Sequence[int]] = None
+) -> Protein:
+    """Assembles a protein from a prediction.
+
+    Args:
+      aatype: Residue types, encoded as integers.
+      residue_index: Residue indices.
+      chain_index: Chain indices for multi-chain predictions.
+      atom_positions: Predicted atom positions.
+      atom_mask: Mask marking which atoms are present.
+      b_factors: (Optional) B-factors to use for the protein.
+      remove_leading_feature_dimension: Whether to remove the leading (batch)
+        dimension of `aatype`, `residue_index` and `chain_index`.
+      remark: (Optional) Remark about the prediction.
+      parents: (Optional) List of template names.
+      parents_chain_index: (Optional) Chain index of each parent.
+    Returns:
+      A protein instance.
+ """ + def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray: + return arr[0] if remove_leading_feature_dimension else arr + + chain_index = _maybe_remove_leading_dim(chain_index) + + if b_factors is None: + b_factors = np.zeros_like(atom_mask) + + return Protein( + aatype=_maybe_remove_leading_dim(aatype), + atom_positions=atom_positions, + atom_mask=atom_mask, + residue_index=_maybe_remove_leading_dim(residue_index), + b_factors=b_factors, + chain_index=chain_index, + remark=remark, + parents=parents, + parents_chain_index=parents_chain_index, + ) diff --git a/dockformer/utils/relax.py b/dockformer/utils/relax.py new file mode 100644 index 0000000000000000000000000000000000000000..be4786e54b2ffc01b7e303eaeb69227c928c429d --- /dev/null +++ b/dockformer/utils/relax.py @@ -0,0 +1,289 @@ +import sys +import os.path +import time +from random import shuffle + +import numpy as np +import pdbfixer +import openmm as mm +import openmm.app as mm_app +import openmm.unit as mm_unit +from openmm import CustomExternalForce +from openmm.app import Modeller +from openmmforcefields.generators import SystemGenerator +from openff.toolkit import Molecule +from openff.toolkit.utils.exceptions import UndefinedStereochemistryError, RadicalsNotSupportedError +import mdtraj +from rdkit import Chem +from rdkit.Chem import AllChem +from rdkit.Geometry import Point3D +import Bio.PDB +from Bio.SVDSuperimposer import SVDSuperimposer + + +# -- Relax protein and ligand. Code adapted from: +# https://github.com/patrickbryant1/Umol/blob/f7cd2b4de09b4e7cc1b68606791dd1cc81deeebc/src/relax/openmm_relax.py +def fix_pdb(pdb_path, hydrogen_added_pdb_path): + """Add hydrogens to the PDB file + """ + fixer = pdbfixer.PDBFixer(pdb_path) + fixer.findMissingResidues() + fixer.findNonstandardResidues() + fixer.replaceNonstandardResidues() + fixer.findMissingAtoms() + fixer.addMissingAtoms() + fixer.addMissingHydrogens(7.0) + mm_app.PDBFile.writeFile(fixer.topology, fixer.positions, open(hydrogen_added_pdb_path, 'w')) + return fixer.topology, fixer.positions + + +def minimize_energy(topology, system, positions, output_pdb_path): + '''Function that minimizes energy, given topology, OpenMM system, and positions ''' + # Use a Brownian Integrator + integrator = mm.BrownianIntegrator( + 100 * mm.unit.kelvin, + 100. 
/ mm.unit.picoseconds, + 2.0 * mm.unit.femtoseconds + ) + simulation = mm.app.Simulation(topology, system, integrator) + + # Initialize the DCDReporter + reportInterval = 100 # Adjust this value as needed + reporter = mdtraj.reporters.DCDReporter('positions.dcd', reportInterval) + + # Add the reporter to the simulation + simulation.reporters.append(reporter) + + simulation.context.setPositions(positions) + + simulation.minimizeEnergy(1, 1000) + # Save positions + minpositions = simulation.context.getState(getPositions=True).getPositions() + mm_app.PDBFile.writeFile(topology, minpositions, open(output_pdb_path, "w")) + + reporter.close() + + return topology, minpositions + + +def add_restraints(system, topology, positions, restraint_type): + # Code adapted from https://gist.github.com/peastman/ad8cda653242d731d75e18c836b2a3a5 + restraint = CustomExternalForce('k*periodicdistance(x, y, z, x0, y0, z0)^2') + system.addForce(restraint) + restraint.addGlobalParameter('k', 100.0*mm_unit.kilojoules_per_mole/mm_unit.nanometer**2) + restraint.addPerParticleParameter('x0') + restraint.addPerParticleParameter('y0') + restraint.addPerParticleParameter('z0') + + for atom in topology.atoms(): + if restraint_type == 'protein': + if 'x' not in atom.name: + restraint.addParticle(atom.index, positions[atom.index]) + elif restraint_type == 'CA+ligand': + if ('x' in atom.name) or (atom.name == "CA"): + restraint.addParticle(atom.index, positions[atom.index]) + + return system + + +def create_joined_relaxed(protein_pdb_path: str, ligand_sdf_path: str, hydorgen_added_protein_pdb_path: str, + relaxed_joined_path: str): + restraint_type = 'CA+ligand' + + start_time = time.time() + print('Reading ligand') + try: + ligand_mol = Molecule.from_file(ligand_sdf_path) + # Check for undefined stereochemistry, allow undefined stereochemistry to be loaded + except UndefinedStereochemistryError: + print('Undefined Stereochemistry Error found! Trying with undefined stereo flag True') + ligand_mol = Molecule.from_file(ligand_sdf_path, allow_undefined_stereo=True) + # Check for radicals -- break out of script if radical is encountered + except RadicalsNotSupportedError: + print('OpenFF does not currently support radicals -- use unrelaxed structure') + sys.exit() + # Assigning partial charges first because the default method (am1bcc) does not work + ligand_mol.assign_partial_charges(partial_charge_method='gasteiger') + + # Read protein PDB and add hydrogens + protein_topology, protein_positions = fix_pdb(protein_pdb_path, hydorgen_added_protein_pdb_path) + print('Added all atoms...') + + modeller = Modeller(protein_topology, protein_positions) + print('System has %d atoms' % modeller.topology.getNumAtoms()) + + print('Adding ligand...') + lig_top = ligand_mol.to_topology() + modeller.add(lig_top.to_openmm(), lig_top.get_positions().to_openmm()) + print('System has %d atoms' % modeller.topology.getNumAtoms()) + + print('Preparing system') + # Initialize a SystemGenerator using the GAFF for the ligand and implicit water. 
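+    # Note: the Gasteiger charges assigned above should be reused by the GAFF
+    # template generator (openmmforcefields documents that pre-assigned partial
+    # charges take precedence over the default AM1-BCC calculation), and
+    # 'implicit/gbn2.xml' selects OpenMM's GBn2 generalized Born solvent model.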
+ # forcefield_kwargs = {'constraints': mm_app.HBonds, 'rigidWater': True, 'removeCMMotion': False, + # 'hydrogenMass': 4*mm_unit.amu } + system_generator = SystemGenerator( + forcefields=['amber14-all.xml', 'implicit/gbn2.xml'], + small_molecule_forcefield='gaff-2.11', + molecules=[ligand_mol], + # forcefield_kwargs=forcefield_kwargs + ) + + system = system_generator.create_system(modeller.topology, molecules=ligand_mol) + + print('Adding restraints on protein CAs and ligand atoms') + + system = add_restraints(system, modeller.topology, modeller.positions, restraint_type=restraint_type) + + minimize_energy(modeller.topology, system, modeller.positions, relaxed_joined_path) + + print(f'Time taken for relax calculation is {time.time() - start_time:.1f} seconds') + + +# -- Fix ligand changed structure. Code adapted from: +# https://github.com/patrickbryant1/Umol/blob/f7cd2b4de09b4e7cc1b68606791dd1cc81deeebc/src/relax/align_ligand_conformer.py +def generate_best_conformer(pred_coords, ligand_smiles, max_confs=100): + """Generate conformers and compare the coords with the predicted atom positions + + Generating with constraints doesn't seem to work. + cids = Chem.rdDistGeom.EmbedMultipleConfs(m,max_confs,ps) + if len([x for x in m.GetConformers()])<1: + print('Could not generate conformer with constraints') + """ + # Generate conformers + m = Chem.AddHs(Chem.MolFromSmiles(ligand_smiles)) + # Embed in 3D to get distance matrix + AllChem.EmbedMolecule(m, maxAttempts=500) + bounds = AllChem.Get3DDistanceMatrix(m) + # Get pred distance matrix + pred_dmat = np.sqrt(1e-10 + np.sum((pred_coords[:, None] - pred_coords[None, :]) ** 2 ,axis=-1)) + # Go through the atom types and add the constraints if not H + # The order here will be the same as for the pred ligand as the smiles are identical + ai, mi = 0, 0 + bounds_mapping = {} + for atom in m.GetAtoms(): + if atom.GetSymbol() != 'H': + bounds_mapping[ai] = mi + ai += 1 + mi += 1 + + # Assign available pred bound atoms + bounds_keys = [*bounds_mapping.keys()] + for i in range(len(bounds_keys)): + key_i = bounds_keys[i] + for j in range(i+1, len(bounds_keys)): + key_j = bounds_keys[j] + try: + bounds[bounds_mapping[key_i], bounds_mapping[key_j]] = pred_dmat[i, j] + bounds[bounds_mapping[key_j], bounds_mapping[key_i]] = pred_dmat[j, i] + except: + continue + # Now generate conformers using the bounds + ps = Chem.rdDistGeom.ETKDGv3() + ps.randomSeed = 0xf00d + ps.SetBoundsMat(bounds) + cids = Chem.rdDistGeom.EmbedMultipleConfs(m, max_confs) + # Get all conformer dmats + nonH_inds = [*bounds_mapping.values()] + conf_errs = [] + for conf in m.GetConformers(): + pos = conf.GetPositions() + nonH_pos = pos[nonH_inds] + conf_dmat = np.sqrt(1e-10 + np.sum((nonH_pos[:,None]-nonH_pos[None,:])**2,axis=-1)) + err = np.mean(np.sqrt(1e-10 + (conf_dmat-pred_dmat)**2)) + conf_errs.append(err) + + # Get the best + best_conf_id = np.argmin(conf_errs) + best_conf_err = conf_errs[best_conf_id] + best_conf = [x for x in m.GetConformers()][best_conf_id] + best_conf_pos = best_conf.GetPositions() + + return best_conf, best_conf_pos, best_conf_err, [atom.GetSymbol() for atom in m.GetAtoms()], nonH_inds, m, best_conf_id + + +def align_coords_transform(pred_pos, conf_pos, nonH_inds): + """Align the predicted and conformer positions + """ + sup = SVDSuperimposer() + + sup.set(pred_pos, conf_pos[nonH_inds]) # (reference_coords, coords) + sup.run() + rot, tran = sup.get_rotran() + + # Rotate coords from new chain to its new relative position/orientation + tr_coords = 
np.dot(conf_pos, rot) + tran + + return tr_coords + + +def write_sdf(mol, conf, aligned_conf_pos, best_conf_id, outname): + for i in range(mol.GetNumAtoms()): + x, y, z = aligned_conf_pos[i] + conf.SetAtomPosition(i, Point3D(x, y, z)) + + writer = Chem.SDWriter(outname) + writer.write(mol, confId=int(best_conf_id)) + + +# Main function +def relax_complex(protein_pdb_path: str, ligand_sdf_path: str, relaxed_protein_path: str, relaxed_ligand_path: str): + hydorgen_added_protein_pdb_path = protein_pdb_path + "_hydrogen_added.pdb" + relaxed_joined_path = protein_pdb_path + "_joined_relaxed.pdb" + + create_joined_relaxed(protein_pdb_path, ligand_sdf_path, hydorgen_added_protein_pdb_path, relaxed_joined_path) + + parser = Bio.PDB.PDBParser(QUIET=True) + joined_structure = next(iter(parser.get_structure('', relaxed_joined_path))) + + # save the relaxed protein + io = Bio.PDB.PDBIO() + io.set_structure(joined_structure["A"]) + io.save(relaxed_protein_path) + + relaxed_ligand_coords = np.array([atom.get_coord() for atom in joined_structure["B"].get_atoms() + if atom.get_id()[0] != "H"]) + original_ligand = Chem.SDMolSupplier(ligand_sdf_path)[0] + ligand_smiles = Chem.MolToSmiles(original_ligand) + + best_conf, best_conf_pos, best_conf_err, atoms, nonH_inds, mol, best_conf_id = generate_best_conformer( + relaxed_ligand_coords, ligand_smiles, max_confs=100 + ) + + aligned_conf_pos = align_coords_transform(relaxed_ligand_coords, best_conf_pos, nonH_inds) + + write_sdf(mol, best_conf, aligned_conf_pos, best_conf_id, relaxed_ligand_path) + + +def relax_folder(folder_path: str): + all_jobnames = [] + filenames = os.listdir(folder_path) + shuffle(filenames) + for filename in filenames: + if filename.endswith("_predicted_protein.pdb"): + jobname = filename.split("_predicted_protein.pdb")[0] + ligand_path = os.path.join(folder_path, jobname + "_predicted_ligand_0.sdf") + if not os.path.exists(ligand_path): + continue + all_jobnames.append(jobname) + + success = 0 + for jobname in all_jobnames: + protein_pdb_path = os.path.join(folder_path, jobname + "_predicted_protein.pdb") + ligand_sdf_path = os.path.join(folder_path, jobname + "_predicted_ligand_0.sdf") + relaxed_protein_path = os.path.join(folder_path, jobname + "_protein_relaxed.pdb") + relaxed_ligand_path = os.path.join(folder_path, jobname + "_ligand_relaxed.sdf") + if os.path.exists(relaxed_protein_path) and os.path.exists(relaxed_ligand_path): + print("Already has relaxed", jobname) + success += 1 + continue + print("Relaxing", jobname) + try: + relax_complex(protein_pdb_path, ligand_sdf_path, relaxed_protein_path, relaxed_ligand_path) + success += 1 + except Exception as e: + print("Failed to relax", jobname, e) + print(f"Relaxed {success}/{len(all_jobnames)}") + + +if __name__ == "__main__": + relax_folder(os.path.abspath(sys.argv[1])) diff --git a/dockformer/utils/residue_constants.py b/dockformer/utils/residue_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..9bc6e6ccb8eb2a099b9b21f952fe051e3b997d81 --- /dev/null +++ b/dockformer/utils/residue_constants.py @@ -0,0 +1,1430 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Constants used in AlphaFold.""" + +import collections +import functools +import os +from typing import Mapping, List, Tuple +from importlib import resources + +import numpy as np +import tree + + +# Distance from one CA to next CA [trans configuration: omega = 180]. +ca_ca = 3.80209737096 + +# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in +# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have +# chi angles so their chi angle lists are empty. +chi_angles_atoms = { + "ALA": [], + # Chi5 in arginine is always 0 +- 5 degrees, so ignore it. + "ARG": [ + ["N", "CA", "CB", "CG"], + ["CA", "CB", "CG", "CD"], + ["CB", "CG", "CD", "NE"], + ["CG", "CD", "NE", "CZ"], + ], + "ASN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]], + "ASP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]], + "CYS": [["N", "CA", "CB", "SG"]], + "GLN": [ + ["N", "CA", "CB", "CG"], + ["CA", "CB", "CG", "CD"], + ["CB", "CG", "CD", "OE1"], + ], + "GLU": [ + ["N", "CA", "CB", "CG"], + ["CA", "CB", "CG", "CD"], + ["CB", "CG", "CD", "OE1"], + ], + "GLY": [], + "HIS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "ND1"]], + "ILE": [["N", "CA", "CB", "CG1"], ["CA", "CB", "CG1", "CD1"]], + "LEU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], + "LYS": [ + ["N", "CA", "CB", "CG"], + ["CA", "CB", "CG", "CD"], + ["CB", "CG", "CD", "CE"], + ["CG", "CD", "CE", "NZ"], + ], + "MET": [ + ["N", "CA", "CB", "CG"], + ["CA", "CB", "CG", "SD"], + ["CB", "CG", "SD", "CE"], + ], + "PHE": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], + "PRO": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"]], + "SER": [["N", "CA", "CB", "OG"]], + "THR": [["N", "CA", "CB", "OG1"]], + "TRP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], + "TYR": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], + "VAL": [["N", "CA", "CB", "CG1"]], + "XXX": [], +} + +# If chi angles given in fixed-length array, this matrix determines how to mask +# them for each AA type. The order is as per restype_order (see below). +chi_angles_mask = [ + [0.0, 0.0, 0.0, 0.0], # ALA + [1.0, 1.0, 1.0, 1.0], # ARG + [1.0, 1.0, 0.0, 0.0], # ASN + [1.0, 1.0, 0.0, 0.0], # ASP + [1.0, 0.0, 0.0, 0.0], # CYS + [1.0, 1.0, 1.0, 0.0], # GLN + [1.0, 1.0, 1.0, 0.0], # GLU + [0.0, 0.0, 0.0, 0.0], # GLY + [1.0, 1.0, 0.0, 0.0], # HIS + [1.0, 1.0, 0.0, 0.0], # ILE + [1.0, 1.0, 0.0, 0.0], # LEU + [1.0, 1.0, 1.0, 1.0], # LYS + [1.0, 1.0, 1.0, 0.0], # MET + [1.0, 1.0, 0.0, 0.0], # PHE + [1.0, 1.0, 0.0, 0.0], # PRO + [1.0, 0.0, 0.0, 0.0], # SER + [1.0, 0.0, 0.0, 0.0], # THR + [1.0, 1.0, 0.0, 0.0], # TRP + [1.0, 1.0, 0.0, 0.0], # TYR + [1.0, 0.0, 0.0, 0.0], # VAL + [0.0, 0.0, 0.0, 0.0], # XXX +] + +# The following chi angles are pi periodic: they can be rotated by a multiple +# of pi without affecting the structure. 
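+# For example, ASP chi2 is pi periodic: OD1 and OD2 are chemically equivalent,
+# so rotating the carboxylate by 180 degrees maps the side chain onto itself.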
+chi_pi_periodic = [ + [0.0, 0.0, 0.0, 0.0], # ALA + [0.0, 0.0, 0.0, 0.0], # ARG + [0.0, 0.0, 0.0, 0.0], # ASN + [0.0, 1.0, 0.0, 0.0], # ASP + [0.0, 0.0, 0.0, 0.0], # CYS + [0.0, 0.0, 0.0, 0.0], # GLN + [0.0, 0.0, 1.0, 0.0], # GLU + [0.0, 0.0, 0.0, 0.0], # GLY + [0.0, 0.0, 0.0, 0.0], # HIS + [0.0, 0.0, 0.0, 0.0], # ILE + [0.0, 0.0, 0.0, 0.0], # LEU + [0.0, 0.0, 0.0, 0.0], # LYS + [0.0, 0.0, 0.0, 0.0], # MET + [0.0, 1.0, 0.0, 0.0], # PHE + [0.0, 0.0, 0.0, 0.0], # PRO + [0.0, 0.0, 0.0, 0.0], # SER + [0.0, 0.0, 0.0, 0.0], # THR + [0.0, 0.0, 0.0, 0.0], # TRP + [0.0, 1.0, 0.0, 0.0], # TYR + [0.0, 0.0, 0.0, 0.0], # VAL + [0.0, 0.0, 0.0, 0.0], # XXX + [0.0, 0.0, 0.0, 0.0], # UNK +] + +# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi, +# psi and chi angles: +# 0: 'backbone group', +# 1: 'pre-omega-group', (empty) +# 2: 'phi-group', (currently empty, because it defines only hydrogens) +# 3: 'psi-group', +# 4,5,6,7: 'chi1,2,3,4-group' +# The atom positions are relative to the axis-end-atom of the corresponding +# rotation axis. The x-axis is in direction of the rotation axis, and the y-axis +# is defined such that the dihedral-angle-definiting atom (the last entry in +# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate). +# format: [atomname, group_idx, rel_position] +rigid_group_atom_positions = { + "ALA": [ + ["N", 0, (-0.525, 1.363, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.526, -0.000, -0.000)], + ["CB", 0, (-0.529, -0.774, -1.205)], + ["O", 3, (0.627, 1.062, 0.000)], + ], + "ARG": [ + ["N", 0, (-0.524, 1.362, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.525, -0.000, -0.000)], + ["CB", 0, (-0.524, -0.778, -1.209)], + ["O", 3, (0.626, 1.062, 0.000)], + ["CG", 4, (0.616, 1.390, -0.000)], + ["CD", 5, (0.564, 1.414, 0.000)], + ["NE", 6, (0.539, 1.357, -0.000)], + ["NH1", 7, (0.206, 2.301, 0.000)], + ["NH2", 7, (2.078, 0.978, -0.000)], + ["CZ", 7, (0.758, 1.093, -0.000)], + ], + "ASN": [ + ["N", 0, (-0.536, 1.357, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.526, -0.000, -0.000)], + ["CB", 0, (-0.531, -0.787, -1.200)], + ["O", 3, (0.625, 1.062, 0.000)], + ["CG", 4, (0.584, 1.399, 0.000)], + ["ND2", 5, (0.593, -1.188, 0.001)], + ["OD1", 5, (0.633, 1.059, 0.000)], + ], + "ASP": [ + ["N", 0, (-0.525, 1.362, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.527, 0.000, -0.000)], + ["CB", 0, (-0.526, -0.778, -1.208)], + ["O", 3, (0.626, 1.062, -0.000)], + ["CG", 4, (0.593, 1.398, -0.000)], + ["OD1", 5, (0.610, 1.091, 0.000)], + ["OD2", 5, (0.592, -1.101, -0.003)], + ], + "CYS": [ + ["N", 0, (-0.522, 1.362, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.524, 0.000, 0.000)], + ["CB", 0, (-0.519, -0.773, -1.212)], + ["O", 3, (0.625, 1.062, -0.000)], + ["SG", 4, (0.728, 1.653, 0.000)], + ], + "GLN": [ + ["N", 0, (-0.526, 1.361, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.526, 0.000, 0.000)], + ["CB", 0, (-0.525, -0.779, -1.207)], + ["O", 3, (0.626, 1.062, -0.000)], + ["CG", 4, (0.615, 1.393, 0.000)], + ["CD", 5, (0.587, 1.399, -0.000)], + ["NE2", 6, (0.593, -1.189, -0.001)], + ["OE1", 6, (0.634, 1.060, 0.000)], + ], + "GLU": [ + ["N", 0, (-0.528, 1.361, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.526, -0.000, -0.000)], + ["CB", 0, (-0.526, -0.781, -1.207)], + ["O", 3, (0.626, 1.062, 0.000)], + ["CG", 4, (0.615, 1.392, 0.000)], + ["CD", 5, (0.600, 1.397, 0.000)], + ["OE1", 6, (0.607, 1.095, -0.000)], + ["OE2", 6, (0.589, -1.104, -0.001)], + ], + "GLY": [ + 
["N", 0, (-0.572, 1.337, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.517, -0.000, -0.000)], + ["O", 3, (0.626, 1.062, -0.000)], + ], + "HIS": [ + ["N", 0, (-0.527, 1.360, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.525, 0.000, 0.000)], + ["CB", 0, (-0.525, -0.778, -1.208)], + ["O", 3, (0.625, 1.063, 0.000)], + ["CG", 4, (0.600, 1.370, -0.000)], + ["CD2", 5, (0.889, -1.021, 0.003)], + ["ND1", 5, (0.744, 1.160, -0.000)], + ["CE1", 5, (2.030, 0.851, 0.002)], + ["NE2", 5, (2.145, -0.466, 0.004)], + ], + "ILE": [ + ["N", 0, (-0.493, 1.373, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.527, -0.000, -0.000)], + ["CB", 0, (-0.536, -0.793, -1.213)], + ["O", 3, (0.627, 1.062, -0.000)], + ["CG1", 4, (0.534, 1.437, -0.000)], + ["CG2", 4, (0.540, -0.785, -1.199)], + ["CD1", 5, (0.619, 1.391, 0.000)], + ], + "LEU": [ + ["N", 0, (-0.520, 1.363, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.525, -0.000, -0.000)], + ["CB", 0, (-0.522, -0.773, -1.214)], + ["O", 3, (0.625, 1.063, -0.000)], + ["CG", 4, (0.678, 1.371, 0.000)], + ["CD1", 5, (0.530, 1.430, -0.000)], + ["CD2", 5, (0.535, -0.774, 1.200)], + ], + "LYS": [ + ["N", 0, (-0.526, 1.362, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.526, 0.000, 0.000)], + ["CB", 0, (-0.524, -0.778, -1.208)], + ["O", 3, (0.626, 1.062, -0.000)], + ["CG", 4, (0.619, 1.390, 0.000)], + ["CD", 5, (0.559, 1.417, 0.000)], + ["CE", 6, (0.560, 1.416, 0.000)], + ["NZ", 7, (0.554, 1.387, 0.000)], + ], + "MET": [ + ["N", 0, (-0.521, 1.364, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.525, 0.000, 0.000)], + ["CB", 0, (-0.523, -0.776, -1.210)], + ["O", 3, (0.625, 1.062, -0.000)], + ["CG", 4, (0.613, 1.391, -0.000)], + ["SD", 5, (0.703, 1.695, 0.000)], + ["CE", 6, (0.320, 1.786, -0.000)], + ], + "PHE": [ + ["N", 0, (-0.518, 1.363, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.524, 0.000, -0.000)], + ["CB", 0, (-0.525, -0.776, -1.212)], + ["O", 3, (0.626, 1.062, -0.000)], + ["CG", 4, (0.607, 1.377, 0.000)], + ["CD1", 5, (0.709, 1.195, -0.000)], + ["CD2", 5, (0.706, -1.196, 0.000)], + ["CE1", 5, (2.102, 1.198, -0.000)], + ["CE2", 5, (2.098, -1.201, -0.000)], + ["CZ", 5, (2.794, -0.003, -0.001)], + ], + "PRO": [ + ["N", 0, (-0.566, 1.351, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.527, -0.000, 0.000)], + ["CB", 0, (-0.546, -0.611, -1.293)], + ["O", 3, (0.621, 1.066, 0.000)], + ["CG", 4, (0.382, 1.445, 0.0)], + # ['CD', 5, (0.427, 1.440, 0.0)], + ["CD", 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger + ], + "SER": [ + ["N", 0, (-0.529, 1.360, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.525, -0.000, -0.000)], + ["CB", 0, (-0.518, -0.777, -1.211)], + ["O", 3, (0.626, 1.062, -0.000)], + ["OG", 4, (0.503, 1.325, 0.000)], + ], + "THR": [ + ["N", 0, (-0.517, 1.364, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.526, 0.000, -0.000)], + ["CB", 0, (-0.516, -0.793, -1.215)], + ["O", 3, (0.626, 1.062, 0.000)], + ["CG2", 4, (0.550, -0.718, -1.228)], + ["OG1", 4, (0.472, 1.353, 0.000)], + ], + "TRP": [ + ["N", 0, (-0.521, 1.363, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.525, -0.000, 0.000)], + ["CB", 0, (-0.523, -0.776, -1.212)], + ["O", 3, (0.627, 1.062, 0.000)], + ["CG", 4, (0.609, 1.370, -0.000)], + ["CD1", 5, (0.824, 1.091, 0.000)], + ["CD2", 5, (0.854, -1.148, -0.005)], + ["CE2", 5, (2.186, -0.678, -0.007)], + ["CE3", 5, (0.622, -2.530, -0.007)], + ["NE1", 5, (2.140, 0.690, -0.004)], + ["CH2", 5, (3.028, -2.890, -0.013)], + 
["CZ2", 5, (3.283, -1.543, -0.011)], + ["CZ3", 5, (1.715, -3.389, -0.011)], + ], + "TYR": [ + ["N", 0, (-0.522, 1.362, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.524, -0.000, -0.000)], + ["CB", 0, (-0.522, -0.776, -1.213)], + ["O", 3, (0.627, 1.062, -0.000)], + ["CG", 4, (0.607, 1.382, -0.000)], + ["CD1", 5, (0.716, 1.195, -0.000)], + ["CD2", 5, (0.713, -1.194, -0.001)], + ["CE1", 5, (2.107, 1.200, -0.002)], + ["CE2", 5, (2.104, -1.201, -0.003)], + ["OH", 5, (4.168, -0.002, -0.005)], + ["CZ", 5, (2.791, -0.001, -0.003)], + ], + "VAL": [ + ["N", 0, (-0.494, 1.373, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.527, -0.000, -0.000)], + ["CB", 0, (-0.533, -0.795, -1.213)], + ["O", 3, (0.627, 1.062, -0.000)], + ["CG1", 4, (0.540, 1.429, -0.000)], + ["CG2", 4, (0.533, -0.776, 1.203)], + ], + "XXX": [ + ["N", 0, (0.000, 1.000, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (0.000, 0.000, 1.000)], + ], +} + +# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention. +residue_atoms = { + "ALA": ["C", "CA", "CB", "N", "O"], + "ARG": ["C", "CA", "CB", "CG", "CD", "CZ", "N", "NE", "O", "NH1", "NH2"], + "ASP": ["C", "CA", "CB", "CG", "N", "O", "OD1", "OD2"], + "ASN": ["C", "CA", "CB", "CG", "N", "ND2", "O", "OD1"], + "CYS": ["C", "CA", "CB", "N", "O", "SG"], + "GLU": ["C", "CA", "CB", "CG", "CD", "N", "O", "OE1", "OE2"], + "GLN": ["C", "CA", "CB", "CG", "CD", "N", "NE2", "O", "OE1"], + "GLY": ["C", "CA", "N", "O"], + "HIS": ["C", "CA", "CB", "CG", "CD2", "CE1", "N", "ND1", "NE2", "O"], + "ILE": ["C", "CA", "CB", "CG1", "CG2", "CD1", "N", "O"], + "LEU": ["C", "CA", "CB", "CG", "CD1", "CD2", "N", "O"], + "LYS": ["C", "CA", "CB", "CG", "CD", "CE", "N", "NZ", "O"], + "MET": ["C", "CA", "CB", "CG", "CE", "N", "O", "SD"], + "PHE": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O"], + "PRO": ["C", "CA", "CB", "CG", "CD", "N", "O"], + "SER": ["C", "CA", "CB", "N", "O", "OG"], + "THR": ["C", "CA", "CB", "CG2", "N", "O", "OG1"], + "TRP": [ + "C", + "CA", + "CB", + "CG", + "CD1", + "CD2", + "CE2", + "CE3", + "CZ2", + "CZ3", + "CH2", + "N", + "NE1", + "O", + ], + "TYR": [ + "C", + "CA", + "CB", + "CG", + "CD1", + "CD2", + "CE1", + "CE2", + "CZ", + "N", + "O", + "OH", + ], + "VAL": ["C", "CA", "CB", "CG1", "CG2", "N", "O"], + "XXX": ["N", "CA", "C"], +} + +# Naming swaps for ambiguous atom names. +# Due to symmetries in the amino acids the naming of atoms is ambiguous in +# 4 of the 20 amino acids. +# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities +# in LEU, VAL and ARG can be resolved by using the 3d constellations of +# the 'ambiguous' atoms and their neighbours) +# Because for LEU, VAL and ARG, no ambiguous exist when the prediction output is chi angle instead of the location of individual atoms. +# For the rest, ASP and others, when you rotate the bond 180 degree, you get the same configuraiton due to symmetry. 
+ +residue_atom_renaming_swaps = { + "ASP": {"OD1": "OD2"}, + "GLU": {"OE1": "OE2"}, + "PHE": {"CD1": "CD2", "CE1": "CE2"}, + "TYR": {"CD1": "CD2", "CE1": "CE2"}, +} + +# Van der Waals radii [Angstroem] of the atoms (from Wikipedia) +van_der_waals_radius = { + "C": 1.7, + "N": 1.55, + "O": 1.52, + "S": 1.8, +} + +Bond = collections.namedtuple( + "Bond", ["atom1_name", "atom2_name", "length", "stddev"] +) +BondAngle = collections.namedtuple( + "BondAngle", + ["atom1_name", "atom2_name", "atom3name", "angle_rad", "stddev"], +) + + +@functools.lru_cache(maxsize=None) +def load_stereo_chemical_props() -> Tuple[ + Mapping[str, List[Bond]], + Mapping[str, List[Bond]], + Mapping[str, List[BondAngle]], +]: + """Load stereo_chemical_props.txt into a nice structure. + + Load literature values for bond lengths and bond angles and translate + bond angles into the length of the opposite edge of the triangle + ("residue_virtual_bonds"). + + Returns: + residue_bonds: Dict that maps resname -> list of Bond tuples + residue_virtual_bonds: Dict that maps resname -> list of Bond tuples + residue_bond_angles: Dict that maps resname -> list of BondAngle tuples + """ + # TODO: this file should be downloaded in a setup script + stereo_chemical_props = resources.read_text("dockformer.resources", "stereo_chemical_props.txt") + + lines_iter = iter(stereo_chemical_props.splitlines()) + # Load bond lengths. + residue_bonds = {} + next(lines_iter) # Skip header line. + for line in lines_iter: + if line.strip() == "-": + break + bond, resname, length, stddev = line.split() + atom1, atom2 = bond.split("-") + if resname not in residue_bonds: + residue_bonds[resname] = [] + residue_bonds[resname].append( + Bond(atom1, atom2, float(length), float(stddev)) + ) + residue_bonds["UNK"] = [] + + # Load bond angles. + residue_bond_angles = {} + next(lines_iter) # Skip empty line. + next(lines_iter) # Skip header line. + for line in lines_iter: + if line.strip() == "-": + break + bond, resname, angle_degree, stddev_degree = line.split() + atom1, atom2, atom3 = bond.split("-") + if resname not in residue_bond_angles: + residue_bond_angles[resname] = [] + residue_bond_angles[resname].append( + BondAngle( + atom1, + atom2, + atom3, + float(angle_degree) / 180.0 * np.pi, + float(stddev_degree) / 180.0 * np.pi, + ) + ) + residue_bond_angles["UNK"] = [] + + def make_bond_key(atom1_name, atom2_name): + """Unique key to lookup bonds.""" + return "-".join(sorted([atom1_name, atom2_name])) + + # Translate bond angles into distances ("virtual bonds"). + residue_virtual_bonds = {} + for resname, bond_angles in residue_bond_angles.items(): + # Create a fast lookup dict for bond lengths. + bond_cache = {} + for b in residue_bonds[resname]: + bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b + residue_virtual_bonds[resname] = [] + for ba in bond_angles: + bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)] + bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)] + + # Compute distance between atom1 and atom3 using the law of cosines + # c^2 = a^2 + b^2 - 2ab*cos(gamma). + gamma = ba.angle_rad + length = np.sqrt( + bond1.length ** 2 + + bond2.length ** 2 + - 2 * bond1.length * bond2.length * np.cos(gamma) + ) + + # Propagation of uncertainty assuming uncorrelated errors. 
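+            # Since length = sqrt(length^2), each derivative below is the
+            # derivative of length^2 (e.g. d/dgamma = 2ab*sin(gamma)) scaled by
+            # dl_outer = 0.5 / length; the uncorrelated uncertainties are then
+            # combined in quadrature.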
+ dl_outer = 0.5 / length + dl_dgamma = ( + 2 * bond1.length * bond2.length * np.sin(gamma) + ) * dl_outer + dl_db1 = ( + 2 * bond1.length - 2 * bond2.length * np.cos(gamma) + ) * dl_outer + dl_db2 = ( + 2 * bond2.length - 2 * bond1.length * np.cos(gamma) + ) * dl_outer + stddev = np.sqrt( + (dl_dgamma * ba.stddev) ** 2 + + (dl_db1 * bond1.stddev) ** 2 + + (dl_db2 * bond2.stddev) ** 2 + ) + residue_virtual_bonds[resname].append( + Bond(ba.atom1_name, ba.atom3name, length, stddev) + ) + + return (residue_bonds, residue_virtual_bonds, residue_bond_angles) + + +# Between-residue bond lengths for general bonds (first element) and for Proline +# (second element). +between_res_bond_length_c_n = [1.329, 1.341] +between_res_bond_length_stddev_c_n = [0.014, 0.016] + +# Between-residue cos_angles. +between_res_cos_angles_c_n_ca = [-0.5203, 0.0353] # degrees: 121.352 +- 2.315 +between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995 + +# This mapping is used when we need to store atom data in a format that requires +# fixed atom data size for every residue (e.g. a numpy array). +atom_types = [ + "N", + "CA", + "C", + "CB", + "O", + "CG", + "CG1", + "CG2", + "OG", + "OG1", + "SG", + "CD", + "CD1", + "CD2", + "ND1", + "ND2", + "OD1", + "OD2", + "SD", + "CE", + "CE1", + "CE2", + "CE3", + "NE", + "NE1", + "NE2", + "OE1", + "OE2", + "CH2", + "NH1", + "NH2", + "OH", + "CZ", + "CZ2", + "CZ3", + "NZ", + "OXT", +] +atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)} +atom_type_num = len(atom_types) # := 37. + +# A compact atom encoding with 14 columns +# pylint: disable=line-too-long +# pylint: disable=bad-whitespace +restype_name_to_atom14_names = { + "ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", ""], + "ARG": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD", + "NE", + "CZ", + "NH1", + "NH2", + "", + "", + "", + ], + "ASN": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "OD1", + "ND2", + "", + "", + "", + "", + "", + "", + ], + "ASP": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "OD1", + "OD2", + "", + "", + "", + "", + "", + "", + ], + "CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", ""], + "GLN": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD", + "OE1", + "NE2", + "", + "", + "", + "", + "", + ], + "GLU": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD", + "OE1", + "OE2", + "", + "", + "", + "", + "", + ], + "GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", ""], + "HIS": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "ND1", + "CD2", + "CE1", + "NE2", + "", + "", + "", + "", + ], + "ILE": [ + "N", + "CA", + "C", + "O", + "CB", + "CG1", + "CG2", + "CD1", + "", + "", + "", + "", + "", + "", + ], + "LEU": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD1", + "CD2", + "", + "", + "", + "", + "", + "", + ], + "LYS": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD", + "CE", + "NZ", + "", + "", + "", + "", + "", + ], + "MET": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "SD", + "CE", + "", + "", + "", + "", + "", + "", + ], + "PHE": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD1", + "CD2", + "CE1", + "CE2", + "CZ", + "", + "", + "", + ], + "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", "", "", ""], + "SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", ""], + "THR": [ + "N", + "CA", + "C", + "O", + "CB", + "OG1", + "CG2", + "", + "", + "", + "", + "", + "", + "", + ], + "TRP": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD1", + "CD2", + 
"NE1", + "CE2", + "CE3", + "CZ2", + "CZ3", + "CH2", + ], + "TYR": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD1", + "CD2", + "CE1", + "CE2", + "CZ", + "OH", + "", + "", + ], + "VAL": [ + "N", + "CA", + "C", + "O", + "CB", + "CG1", + "CG2", + "", + "", + "", + "", + "", + "", + "", + ], + "XXX": ["N", "CA", "C", "", "", "", "", "", "", "", "", "", "", ""], + "UNK": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""], +} +# pylint: enable=line-too-long +# pylint: enable=bad-whitespace + + +# This is the standard residue order when coding AA type as a number. +# Reproduce it by taking 3-letter AA codes and sorting them alphabetically. +restypes = [ + "A", + "R", + "N", + "D", + "C", + "Q", + "E", + "G", + "H", + "I", + "L", + "K", + "M", + "F", + "P", + "S", + "T", + "W", + "Y", + "V", + "Z" +] +restype_order = {restype: i for i, restype in enumerate(restypes)} +restype_num = len(restypes) # := 20. +unk_restype_index = restype_num # Catch-all index for unknown restypes. + +restypes_with_x = restypes + ["X"] +restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)} + + +def sequence_to_onehot( + sequence: str, mapping: Mapping[str, int], map_unknown_to_x: bool = False +) -> np.ndarray: + """Maps the given sequence into a one-hot encoded matrix. + + Args: + sequence: An amino acid sequence. + mapping: A dictionary mapping amino acids to integers. + map_unknown_to_x: If True, any amino acid that is not in the mapping will be + mapped to the unknown amino acid 'X'. If the mapping doesn't contain + amino acid 'X', an error will be thrown. If False, any amino acid not in + the mapping will throw an error. + + Returns: + A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of + the sequence. + + Raises: + ValueError: If the mapping doesn't contain values from 0 to + num_unique_aas - 1 without any gaps. + """ + num_entries = max(mapping.values()) + 1 + + if sorted(set(mapping.values())) != list(range(num_entries)): + raise ValueError( + "The mapping must have values from 0 to num_unique_aas-1 " + "without any gaps. Got: %s" % sorted(mapping.values()) + ) + + one_hot_arr = np.zeros((len(sequence), num_entries), dtype=np.int32) + + for aa_index, aa_type in enumerate(sequence): + if map_unknown_to_x: + if aa_type.isalpha() and aa_type.isupper(): + aa_id = mapping.get(aa_type, mapping["X"]) + else: + raise ValueError( + f"Invalid character in the sequence: {aa_type}" + ) + else: + aa_id = mapping[aa_type] + one_hot_arr[aa_index, aa_id] = 1 + + return one_hot_arr + + +restype_1to3 = { + "A": "ALA", + "R": "ARG", + "N": "ASN", + "D": "ASP", + "C": "CYS", + "Q": "GLN", + "E": "GLU", + "G": "GLY", + "H": "HIS", + "I": "ILE", + "L": "LEU", + "K": "LYS", + "M": "MET", + "F": "PHE", + "P": "PRO", + "S": "SER", + "T": "THR", + "W": "TRP", + "Y": "TYR", + "V": "VAL", + "Z": "XXX", +} + + +# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple +# 1-to-1 mapping of 3 letter names to one letter names. The latter contains +# many more, and less common, three letter names as keys and maps many of these +# to the same one letter name (including 'X' and 'U' which we don't use here). +restype_3to1 = {v: k for k, v in restype_1to3.items()} + +# Define a restype name for all unknown residues. 
+unk_restype = "UNK" + +resnames = [restype_1to3[r] for r in restypes] + [unk_restype] +resname_to_idx = {resname: i for i, resname in enumerate(resnames)} + + +def _make_standard_atom_mask() -> np.ndarray: + """Returns [num_res_types, num_atom_types] mask array.""" + # +1 to account for unknown (all 0s). + mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32) + for restype, restype_letter in enumerate(restypes): + restype_name = restype_1to3[restype_letter] + atom_names = residue_atoms[restype_name] + for atom_name in atom_names: + atom_type = atom_order[atom_name] + mask[restype, atom_type] = 1 + return mask + + +STANDARD_ATOM_MASK = _make_standard_atom_mask() + + +# A one hot representation for the first and second atoms defining the axis +# of rotation for each chi-angle in each residue. +def chi_angle_atom(atom_index: int) -> np.ndarray: + """Define chi-angle rigid groups via one-hot representations.""" + chi_angles_index = {} + one_hots = [] + + for k, v in chi_angles_atoms.items(): + indices = [atom_types.index(s[atom_index]) for s in v] + indices.extend([-1] * (4 - len(indices))) + chi_angles_index[k] = indices + + for r in restypes: + res3 = restype_1to3[r] + one_hot = np.eye(atom_type_num)[chi_angles_index[res3]] + one_hots.append(one_hot) + + one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`. + one_hot = np.stack(one_hots, axis=0) + one_hot = np.transpose(one_hot, [0, 2, 1]) + + return one_hot + + +chi_atom_1_one_hot = chi_angle_atom(1) +chi_atom_2_one_hot = chi_angle_atom(2) + +# An array like chi_angles_atoms but using indices rather than names. +chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes] +chi_angles_atom_indices = tree.map_structure( + lambda atom_name: atom_order[atom_name], chi_angles_atom_indices +) +chi_angles_atom_indices = np.array( + [ + chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms))) + for chi_atoms in chi_angles_atom_indices + ] +) + +# Mapping from (res_name, atom_name) pairs to the atom's chi group index +# and atom index within that group. +chi_groups_for_atom = collections.defaultdict(list) +for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items(): + for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res): + for atom_i, atom in enumerate(chi_group): + chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i)) +chi_groups_for_atom = dict(chi_groups_for_atom) + + +def _make_rigid_transformation_4x4(ex, ey, translation): + """Create a rigid 4x4 transformation matrix from two axes and transl.""" + # Normalize ex. 
+ ex_normalized = ex / np.linalg.norm(ex) + + # make ey perpendicular to ex + ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized + ey_normalized /= np.linalg.norm(ey_normalized) + + # compute ez as cross product + eznorm = np.cross(ex_normalized, ey_normalized) + m = np.stack( + [ex_normalized, ey_normalized, eznorm, translation] + ).transpose() + m = np.concatenate([m, [[0.0, 0.0, 0.0, 1.0]]], axis=0) + return m + + +# create an array with (restype, atomtype) --> rigid_group_idx +# and an array with (restype, atomtype, coord) for the atom positions +# and compute affine transformation matrices (4,4) from one rigid group to the +# previous group +restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int) +restype_atom37_mask = np.zeros([21, 37], dtype=np.float32) +restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32) +restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int) +restype_atom14_mask = np.zeros([21, 14], dtype=np.float32) +restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32) +restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32) + + +def _make_rigid_group_constants(): + """Fill the arrays above.""" + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + for atomname, group_idx, atom_position in rigid_group_atom_positions[ + resname + ]: + atomtype = atom_order[atomname] + restype_atom37_to_rigid_group[restype, atomtype] = group_idx + restype_atom37_mask[restype, atomtype] = 1 + restype_atom37_rigid_group_positions[ + restype, atomtype, : + ] = atom_position + + atom14idx = restype_name_to_atom14_names[resname].index(atomname) + restype_atom14_to_rigid_group[restype, atom14idx] = group_idx + restype_atom14_mask[restype, atom14idx] = 1 + restype_atom14_rigid_group_positions[ + restype, atom14idx, : + ] = atom_position + + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + atom_positions = { + name: np.array(pos) + for name, _, pos in rigid_group_atom_positions[resname] + } + + # backbone to backbone is the identity transform + restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4) + + # pre-omega-frame to backbone (currently dummy identity matrix) + restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4) + + # phi-frame to backbone + mat = _make_rigid_transformation_4x4( + ex=atom_positions["N"] - atom_positions["CA"], + ey=np.array([1.0, 0.0, 0.0]), + translation=atom_positions["N"], + ) + restype_rigid_group_default_frame[restype, 2, :, :] = mat + + # psi-frame to backbone + mat = _make_rigid_transformation_4x4( + ex=atom_positions["C"] - atom_positions["CA"], + ey=atom_positions["CA"] - atom_positions["N"], + translation=atom_positions["C"], + ) + restype_rigid_group_default_frame[restype, 3, :, :] = mat + + # chi1-frame to backbone + if chi_angles_mask[restype][0]: + base_atom_names = chi_angles_atoms[resname][0] + base_atom_positions = [ + atom_positions[name] for name in base_atom_names + ] + mat = _make_rigid_transformation_4x4( + ex=base_atom_positions[2] - base_atom_positions[1], + ey=base_atom_positions[0] - base_atom_positions[1], + translation=base_atom_positions[2], + ) + restype_rigid_group_default_frame[restype, 4, :, :] = mat + + # chi2-frame to chi1-frame + # chi3-frame to chi2-frame + # chi4-frame to chi3-frame + # luckily all rotation axes for the next frame start at (0,0,0) of the + # previous frame + for chi_idx in range(1, 4): + if chi_angles_mask[restype][chi_idx]: 
+ axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2] + axis_end_atom_position = atom_positions[axis_end_atom_name] + mat = _make_rigid_transformation_4x4( + ex=axis_end_atom_position, + ey=np.array([-1.0, 0.0, 0.0]), + translation=axis_end_atom_position, + ) + restype_rigid_group_default_frame[ + restype, 4 + chi_idx, :, : + ] = mat + + +_make_rigid_group_constants() + + +def make_atom14_dists_bounds( + overlap_tolerance=1.5, bond_length_tolerance_factor=15 +): + """compute upper and lower bounds for bonds to assess violations.""" + restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32) + restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32) + restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32) + residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props() + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + atom_list = restype_name_to_atom14_names[resname] + + # create lower and upper bounds for clashes + for atom1_idx, atom1_name in enumerate(atom_list): + if not atom1_name: + continue + atom1_radius = van_der_waals_radius[atom1_name[0]] + for atom2_idx, atom2_name in enumerate(atom_list): + if (not atom2_name) or atom1_idx == atom2_idx: + continue + atom2_radius = van_der_waals_radius[atom2_name[0]] + lower = atom1_radius + atom2_radius - overlap_tolerance + upper = 1e10 + restype_atom14_bond_lower_bound[ + restype, atom1_idx, atom2_idx + ] = lower + restype_atom14_bond_lower_bound[ + restype, atom2_idx, atom1_idx + ] = lower + restype_atom14_bond_upper_bound[ + restype, atom1_idx, atom2_idx + ] = upper + restype_atom14_bond_upper_bound[ + restype, atom2_idx, atom1_idx + ] = upper + + # overwrite lower and upper bounds for bonds and angles + for b in residue_bonds[resname] + residue_virtual_bonds[resname]: + atom1_idx = atom_list.index(b.atom1_name) + atom2_idx = atom_list.index(b.atom2_name) + lower = b.length - bond_length_tolerance_factor * b.stddev + upper = b.length + bond_length_tolerance_factor * b.stddev + restype_atom14_bond_lower_bound[ + restype, atom1_idx, atom2_idx + ] = lower + restype_atom14_bond_lower_bound[ + restype, atom2_idx, atom1_idx + ] = lower + restype_atom14_bond_upper_bound[ + restype, atom1_idx, atom2_idx + ] = upper + restype_atom14_bond_upper_bound[ + restype, atom2_idx, atom1_idx + ] = upper + restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev + restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev + return { + "lower_bound": restype_atom14_bond_lower_bound, # shape (21,14,14) + "upper_bound": restype_atom14_bond_upper_bound, # shape (21,14,14) + "stddev": restype_atom14_bond_stddev, # shape (21,14,14) + } + + +restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32) +restype_atom14_ambiguous_atoms_swap_idx = np.tile( + np.arange(14, dtype=int), (21, 1) +) + + +def _make_atom14_ambiguity_feats(): + for res, pairs in residue_atom_renaming_swaps.items(): + res_idx = restype_order[restype_3to1[res]] + for atom1, atom2 in pairs.items(): + atom1_idx = restype_name_to_atom14_names[res].index(atom1) + atom2_idx = restype_name_to_atom14_names[res].index(atom2) + restype_atom14_ambiguous_atoms[res_idx, atom1_idx] = 1 + restype_atom14_ambiguous_atoms[res_idx, atom2_idx] = 1 + restype_atom14_ambiguous_atoms_swap_idx[ + res_idx, atom1_idx + ] = atom2_idx + restype_atom14_ambiguous_atoms_swap_idx[ + res_idx, atom2_idx + ] = atom1_idx + + +_make_atom14_ambiguity_feats() + + +def aatype_to_str_sequence(aatype): + return 
''.join([ + restypes_with_x[aatype[i]] + for i in range(len(aatype)) + ]) + + +### ALPHAFOLD MULTIMER STUFF ### +def _make_chi_atom_indices(): + """Returns atom indices needed to compute chi angles for all residue types. + + Returns: + A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are + in the order specified in residue_constants.restypes + unknown residue type + at the end. For chi angles which are not defined on the residue, the + positions indices are by default set to 0. + """ + chi_atom_indices = [] + for residue_name in restypes: + residue_name = restype_1to3[residue_name] + residue_chi_angles = chi_angles_atoms[residue_name] + atom_indices = [] + for chi_angle in residue_chi_angles: + atom_indices.append( + [atom_order[atom] for atom in chi_angle]) + for _ in range(4 - len(atom_indices)): + atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA. + chi_atom_indices.append(atom_indices) + + chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue. + + return np.array(chi_atom_indices) + + +def _make_renaming_matrices(): + """Matrices to map atoms to symmetry partners in ambiguous case.""" + # As the atom naming is ambiguous for 7 of the 20 amino acids, provide + # alternative groundtruth coordinates where the naming is swapped + restype_3 = [ + restype_1to3[res] for res in restypes + ] + restype_3 += ['UNK'] + # Matrices for renaming ambiguous atoms. + all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3} + for resname, swap in residue_atom_renaming_swaps.items(): + correspondences = np.arange(14) + for source_atom_swap, target_atom_swap in swap.items(): + source_index = restype_name_to_atom14_names[ + resname].index(source_atom_swap) + target_index = restype_name_to_atom14_names[ + resname].index(target_atom_swap) + correspondences[source_index] = target_index + correspondences[target_index] = source_index + renaming_matrix = np.zeros((14, 14), dtype=np.float32) + for index, correspondence in enumerate(correspondences): + renaming_matrix[index, correspondence] = 1. + all_matrices[resname] = renaming_matrix.astype(np.float32) + renaming_matrices = np.stack([all_matrices[restype] for restype in restype_3]) + return renaming_matrices + + +def _make_restype_atom37_mask(): + """Mask of which atoms are present for which residue type in atom37.""" + # create the corresponding mask + restype_atom37_mask = np.zeros([21, 37], dtype=np.float32) + for restype, restype_letter in enumerate(restypes): + restype_name = restype_1to3[restype_letter] + atom_names = residue_atoms[restype_name] + for atom_name in atom_names: + atom_type = atom_order[atom_name] + restype_atom37_mask[restype, atom_type] = 1 + return restype_atom37_mask + + +def _make_restype_atom14_mask(): + """Mask of which atoms are present for which residue type in atom14.""" + restype_atom14_mask = [] + + for rt in restypes: + atom_names = restype_name_to_atom14_names[ + restype_1to3[rt]] + restype_atom14_mask.append([(1. if name else 0.) for name in atom_names]) + + restype_atom14_mask.append([0.] 
* 14) + restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32) + return restype_atom14_mask + + +def _make_restype_atom37_to_atom14(): + """Map from atom37 to atom14 per residue type.""" + restype_atom37_to_atom14 = [] # mapping (restype, atom37) --> atom14 + for rt in restypes: + atom_names = restype_name_to_atom14_names[ + restype_1to3[rt]] + atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} + restype_atom37_to_atom14.append([ + (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) + for name in atom_types + ]) + + restype_atom37_to_atom14.append([0] * 37) + restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32) + return restype_atom37_to_atom14 + + +def _make_restype_atom14_to_atom37(): + """Map from atom14 to atom37 per residue type.""" + restype_atom14_to_atom37 = [] # mapping (restype, atom14) --> atom37 + for rt in restypes: + atom_names = restype_name_to_atom14_names[ + restype_1to3[rt]] + restype_atom14_to_atom37.append([ + (atom_order[name] if name else 0) + for name in atom_names + ]) + # Add dummy mapping for restype 'UNK' + restype_atom14_to_atom37.append([0] * 14) + restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32) + return restype_atom14_to_atom37 + + +def _make_restype_atom14_is_ambiguous(): + """Mask which atoms are ambiguous in atom14.""" + # create an ambiguous atoms mask. shape: (21, 14) + restype_atom14_is_ambiguous = np.zeros((21, 14), dtype=np.float32) + for resname, swap in residue_atom_renaming_swaps.items(): + for atom_name1, atom_name2 in swap.items(): + restype = restype_order[ + restype_3to1[resname]] + atom_idx1 = restype_name_to_atom14_names[resname].index( + atom_name1) + atom_idx2 = restype_name_to_atom14_names[resname].index( + atom_name2) + restype_atom14_is_ambiguous[restype, atom_idx1] = 1 + restype_atom14_is_ambiguous[restype, atom_idx2] = 1 + + return restype_atom14_is_ambiguous + + +def _make_restype_rigidgroup_base_atom37_idx(): + """Create Map from rigidgroups to atom37 indices.""" + # Create an array with the atom names. + # shape (num_restypes, num_rigidgroups, 3_atoms): (21, 8, 3) + base_atom_names = np.full([21, 8, 3], '', dtype=object) + + # 0: backbone frame + base_atom_names[:, 0, :] = ['C', 'CA', 'N'] + + # 3: 'psi-group' + base_atom_names[:, 3, :] = ['CA', 'C', 'O'] + + # 4,5,6,7: 'chi1,2,3,4-group' + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + for chi_idx in range(4): + if chi_angles_mask[restype][chi_idx]: + atom_names = chi_angles_atoms[resname][chi_idx] + base_atom_names[restype, chi_idx + 4, :] = atom_names[1:] + + # Translate atom names into atom37 indices. + lookuptable = atom_order.copy() + lookuptable[''] = 0 + restype_rigidgroup_base_atom37_idx = np.vectorize(lambda x: lookuptable[x])( + base_atom_names) + return restype_rigidgroup_base_atom37_idx + + +CHI_ATOM_INDICES = _make_chi_atom_indices() +RENAMING_MATRICES = _make_renaming_matrices() +RESTYPE_ATOM14_TO_ATOM37 = _make_restype_atom14_to_atom37() +RESTYPE_ATOM37_TO_ATOM14 = _make_restype_atom37_to_atom14() +RESTYPE_ATOM37_MASK = _make_restype_atom37_mask() +RESTYPE_ATOM14_MASK = _make_restype_atom14_mask() +RESTYPE_ATOM14_IS_AMBIGUOUS = _make_restype_atom14_is_ambiguous() +RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX = _make_restype_rigidgroup_base_atom37_idx() + +# Create mask for existing rigid groups. +# maybe should change RESTYPE_RIGIDGROUP_MASK to [22, 8], but currently not used? 
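+# Rows cover the 21 residue types (20 standard amino acids plus the XXX
+# placeholder); columns are the 8 rigid groups: backbone (0) and psi (3) always
+# exist, pre-omega (1) and phi (2) carry no heavy atoms, and chi1-4 (4-7)
+# follow chi_angles_mask.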
+RESTYPE_RIGIDGROUP_MASK = np.zeros([21, 8], dtype=np.float32)
+RESTYPE_RIGIDGROUP_MASK[:, 0] = 1
+RESTYPE_RIGIDGROUP_MASK[:, 3] = 1
+RESTYPE_RIGIDGROUP_MASK[:len(restypes), 4:] = chi_angles_mask
diff --git a/dockformer/utils/rigid_utils.py b/dockformer/utils/rigid_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6220c608da7cb7d48b855255742dc9a9db6b5324
--- /dev/null
+++ b/dockformer/utils/rigid_utils.py
@@ -0,0 +1,1391 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+from functools import lru_cache
+from typing import Tuple, Any, Sequence, Callable, Optional
+
+import numpy as np
+import torch
+
+
+def rot_matmul(
+    a: torch.Tensor,
+    b: torch.Tensor
+) -> torch.Tensor:
+    """
+    Performs matrix multiplication of two rotation matrix tensors. Written
+    out by hand to avoid AMP downcasting.
+
+    Args:
+        a: [*, 3, 3] left multiplicand
+        b: [*, 3, 3] right multiplicand
+    Returns:
+        The product ab
+    """
+    def row_mul(i):
+        return torch.stack(
+            [
+                a[..., i, 0] * b[..., 0, 0]
+                + a[..., i, 1] * b[..., 1, 0]
+                + a[..., i, 2] * b[..., 2, 0],
+                a[..., i, 0] * b[..., 0, 1]
+                + a[..., i, 1] * b[..., 1, 1]
+                + a[..., i, 2] * b[..., 2, 1],
+                a[..., i, 0] * b[..., 0, 2]
+                + a[..., i, 1] * b[..., 1, 2]
+                + a[..., i, 2] * b[..., 2, 2],
+            ],
+            dim=-1,
+        )
+
+    return torch.stack(
+        [
+            row_mul(0),
+            row_mul(1),
+            row_mul(2),
+        ],
+        dim=-2
+    )
+
+
+def rot_vec_mul(
+    r: torch.Tensor,
+    t: torch.Tensor
+) -> torch.Tensor:
+    """
+    Applies a rotation to a vector. Written out by hand to avoid AMP
+    downcasting.
+ + Args: + r: [*, 3, 3] rotation matrices + t: [*, 3] coordinate tensors + Returns: + [*, 3] rotated coordinates + """ + x, y, z = torch.unbind(t, dim=-1) + return torch.stack( + [ + r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z, + r[..., 1, 0] * x + r[..., 1, 1] * y + r[..., 1, 2] * z, + r[..., 2, 0] * x + r[..., 2, 1] * y + r[..., 2, 2] * z, + ], + dim=-1, + ) + +@lru_cache(maxsize=None) +def identity_rot_mats( + batch_dims: Tuple[int], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + requires_grad: bool = True, +) -> torch.Tensor: + rots = torch.eye( + 3, dtype=dtype, device=device, requires_grad=requires_grad + ) + rots = rots.view(*((1,) * len(batch_dims)), 3, 3) + rots = rots.expand(*batch_dims, -1, -1) + rots = rots.contiguous() + + return rots + + +@lru_cache(maxsize=None) +def identity_trans( + batch_dims: Tuple[int], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + requires_grad: bool = True, +) -> torch.Tensor: + trans = torch.zeros( + (*batch_dims, 3), + dtype=dtype, + device=device, + requires_grad=requires_grad + ) + return trans + + +@lru_cache(maxsize=None) +def identity_quats( + batch_dims: Tuple[int], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + requires_grad: bool = True, +) -> torch.Tensor: + quat = torch.zeros( + (*batch_dims, 4), + dtype=dtype, + device=device, + requires_grad=requires_grad + ) + + with torch.no_grad(): + quat[..., 0] = 1 + + return quat + + +_quat_elements = ["a", "b", "c", "d"] +_qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements] +_qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)} + + +def _to_mat(pairs): + mat = np.zeros((4, 4)) + for pair in pairs: + key, value = pair + ind = _qtr_ind_dict[key] + mat[ind // 4][ind % 4] = value + + return mat + + +_QTR_MAT = np.zeros((4, 4, 3, 3)) +_QTR_MAT[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)]) +_QTR_MAT[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)]) +_QTR_MAT[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)]) +_QTR_MAT[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)]) +_QTR_MAT[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)]) +_QTR_MAT[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)]) +_QTR_MAT[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)]) +_QTR_MAT[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)]) +_QTR_MAT[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)]) + + +def quat_to_rot(quat: torch.Tensor) -> torch.Tensor: + """ + Converts a quaternion to a rotation matrix. + + Args: + quat: [*, 4] quaternions + Returns: + [*, 3, 3] rotation matrices + """ + # [*, 4, 4] + quat = quat[..., None] * quat[..., None, :] + + # [4, 4, 3, 3] + mat = _get_quat("_QTR_MAT", dtype=quat.dtype, device=quat.device) + + # [*, 4, 4, 3, 3] + shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape) + quat = quat[..., None, None] * shaped_qtr_mat + + # [*, 3, 3] + return torch.sum(quat, dim=(-3, -4)) + + +def rot_to_quat( + rot: torch.Tensor, +): + if(rot.shape[-2:] != (3, 3)): + raise ValueError("Input rotation is incorrectly shaped") + + rot = [[rot[..., i, j] for j in range(3)] for i in range(3)] + [[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot + + k = [ + [ xx + yy + zz, zy - yz, xz - zx, yx - xy,], + [ zy - yz, xx - yy - zz, xy + yx, xz + zx,], + [ xz - zx, xy + yx, yy - xx - zz, yz + zy,], + [ yx - xy, xz + zx, yz + zy, zz - xx - yy,] + ] + + k = (1./3.) 
* torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2) + + _, vectors = torch.linalg.eigh(k) + return vectors[..., -1] + + +_QUAT_MULTIPLY = np.zeros((4, 4, 4)) +_QUAT_MULTIPLY[:, :, 0] = [[ 1, 0, 0, 0], + [ 0,-1, 0, 0], + [ 0, 0,-1, 0], + [ 0, 0, 0,-1]] + +_QUAT_MULTIPLY[:, :, 1] = [[ 0, 1, 0, 0], + [ 1, 0, 0, 0], + [ 0, 0, 0, 1], + [ 0, 0,-1, 0]] + +_QUAT_MULTIPLY[:, :, 2] = [[ 0, 0, 1, 0], + [ 0, 0, 0,-1], + [ 1, 0, 0, 0], + [ 0, 1, 0, 0]] + +_QUAT_MULTIPLY[:, :, 3] = [[ 0, 0, 0, 1], + [ 0, 0, 1, 0], + [ 0,-1, 0, 0], + [ 1, 0, 0, 0]] + +_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :] + +_CACHED_QUATS = { + "_QTR_MAT": _QTR_MAT, + "_QUAT_MULTIPLY": _QUAT_MULTIPLY, + "_QUAT_MULTIPLY_BY_VEC": _QUAT_MULTIPLY_BY_VEC +} + +@lru_cache(maxsize=None) +def _get_quat(quat_key, dtype, device): + return torch.tensor(_CACHED_QUATS[quat_key], dtype=dtype, device=device) + + +def quat_multiply(quat1, quat2): + """Multiply a quaternion by another quaternion.""" + mat = _get_quat("_QUAT_MULTIPLY", dtype=quat1.dtype, device=quat1.device) + reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape) + return torch.sum( + reshaped_mat * + quat1[..., :, None, None] * + quat2[..., None, :, None], + dim=(-3, -2) + ) + + +def quat_multiply_by_vec(quat, vec): + """Multiply a quaternion by a pure-vector quaternion.""" + mat = _get_quat("_QUAT_MULTIPLY_BY_VEC", dtype=quat.dtype, device=quat.device) + reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape) + return torch.sum( + reshaped_mat * + quat[..., :, None, None] * + vec[..., None, :, None], + dim=(-3, -2) + ) + + +def invert_rot_mat(rot_mat: torch.Tensor): + return rot_mat.transpose(-1, -2) + + +def invert_quat(quat: torch.Tensor): + quat_prime = quat.clone() + quat_prime[..., 1:] *= -1 + inv = quat_prime / torch.sum(quat ** 2, dim=-1, keepdim=True) + return inv + + +class Rotation: + """ + A 3D rotation. Depending on how the object is initialized, the + rotation is represented by either a rotation matrix or a + quaternion, though both formats are made available by helper functions. + To simplify gradient computation, the underlying format of the + rotation cannot be changed in-place. Like Rigid, the class is designed + to mimic the behavior of a torch Tensor, almost as if each Rotation + object were a tensor of rotations, in one format or another. + """ + def __init__(self, + rot_mats: Optional[torch.Tensor] = None, + quats: Optional[torch.Tensor] = None, + normalize_quats: bool = True, + ): + """ + Args: + rot_mats: + A [*, 3, 3] rotation matrix tensor. Mutually exclusive with + quats + quats: + A [*, 4] quaternion. Mutually exclusive with rot_mats. 
If
+                normalize_quats is not True, must be a unit quaternion
+            normalize_quats:
+                If quats is specified, whether to normalize quats
+        """
+        if((rot_mats is None and quats is None) or
+           (rot_mats is not None and quats is not None)):
+            raise ValueError("Exactly one input argument must be specified")
+
+        if((rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or
+           (quats is not None and quats.shape[-1] != 4)):
+            raise ValueError(
+                "Incorrectly shaped rotation matrix or quaternion"
+            )
+
+        # Force full-precision
+        if(quats is not None):
+            quats = quats.to(dtype=torch.float32)
+        if(rot_mats is not None):
+            rot_mats = rot_mats.to(dtype=torch.float32)
+
+        if(quats is not None and normalize_quats):
+            quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)
+
+        self._rot_mats = rot_mats
+        self._quats = quats
+
+    @staticmethod
+    def identity(
+        shape,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+        requires_grad: bool = True,
+        fmt: str = "quat",
+    ) -> Rotation:
+        """
+        Returns an identity Rotation.
+
+        Args:
+            shape:
+                The "shape" of the resulting Rotation object. See documentation
+                for the shape property
+            dtype:
+                The torch dtype for the rotation
+            device:
+                The torch device for the new rotation
+            requires_grad:
+                Whether the underlying tensors in the new rotation object
+                should require gradient computation
+            fmt:
+                One of "quat" or "rot_mat". Determines the underlying format
+                of the new object's rotation
+        Returns:
+            A new identity rotation
+        """
+        if(fmt == "rot_mat"):
+            rot_mats = identity_rot_mats(
+                shape, dtype, device, requires_grad,
+            )
+            return Rotation(rot_mats=rot_mats, quats=None)
+        elif(fmt == "quat"):
+            quats = identity_quats(shape, dtype, device, requires_grad)
+            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
+        else:
+            raise ValueError(f"Invalid format: {fmt}")
+
+    # Magic methods
+
+    def __getitem__(self, index: Any) -> Rotation:
+        """
+        Allows torch-style indexing over the virtual shape of the rotation
+        object. See documentation for the shape property.
+
+        Args:
+            index:
+                A torch index. E.g. (1, 3, 2), or (slice(None,))
+        Returns:
+            The indexed rotation
+        """
+        if type(index) != tuple:
+            index = (index,)
+
+        if(self._rot_mats is not None):
+            rot_mats = self._rot_mats[index + (slice(None), slice(None))]
+            return Rotation(rot_mats=rot_mats)
+        elif(self._quats is not None):
+            quats = self._quats[index + (slice(None),)]
+            return Rotation(quats=quats, normalize_quats=False)
+        else:
+            raise ValueError("Both rotations are None")
+
+    def __mul__(self,
+        right: torch.Tensor,
+    ) -> Rotation:
+        """
+        Pointwise left multiplication of the rotation with a tensor. Can be
+        used to e.g. mask the Rotation.
+
+        Args:
+            right:
+                The tensor multiplicand
+        Returns:
+            The product
+        """
+        if not(isinstance(right, torch.Tensor)):
+            raise TypeError("The other multiplicand must be a Tensor")
+
+        if(self._rot_mats is not None):
+            rot_mats = self._rot_mats * right[..., None, None]
+            return Rotation(rot_mats=rot_mats, quats=None)
+        elif(self._quats is not None):
+            quats = self._quats * right[..., None]
+            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
+        else:
+            raise ValueError("Both rotations are None")
+
+    def __rmul__(self,
+        left: torch.Tensor,
+    ) -> Rotation:
+        """
+        Reverse pointwise multiplication of the rotation with a tensor.
+ + Args: + left: + The left multiplicand + Returns: + The product + """ + return self.__mul__(left) + + # Properties + + @property + def shape(self) -> torch.Size: + """ + Returns the virtual shape of the rotation object. This shape is + defined as the batch dimensions of the underlying rotation matrix + or quaternion. If the Rotation was initialized with a [10, 3, 3] + rotation matrix tensor, for example, the resulting shape would be + [10]. + + Returns: + The virtual shape of the rotation object + """ + s = None + if(self._quats is not None): + s = self._quats.shape[:-1] + else: + s = self._rot_mats.shape[:-2] + + return s + + @property + def dtype(self) -> torch.dtype: + """ + Returns the dtype of the underlying rotation. + + Returns: + The dtype of the underlying rotation + """ + if(self._rot_mats is not None): + return self._rot_mats.dtype + elif(self._quats is not None): + return self._quats.dtype + else: + raise ValueError("Both rotations are None") + + @property + def device(self) -> torch.device: + """ + The device of the underlying rotation + + Returns: + The device of the underlying rotation + """ + if(self._rot_mats is not None): + return self._rot_mats.device + elif(self._quats is not None): + return self._quats.device + else: + raise ValueError("Both rotations are None") + + @property + def requires_grad(self) -> bool: + """ + Returns the requires_grad property of the underlying rotation + + Returns: + The requires_grad property of the underlying tensor + """ + if(self._rot_mats is not None): + return self._rot_mats.requires_grad + elif(self._quats is not None): + return self._quats.requires_grad + else: + raise ValueError("Both rotations are None") + + def get_rot_mats(self) -> torch.Tensor: + """ + Returns the underlying rotation as a rotation matrix tensor. + + Returns: + The rotation as a rotation matrix tensor + """ + rot_mats = self._rot_mats + if(rot_mats is None): + if(self._quats is None): + raise ValueError("Both rotations are None") + else: + rot_mats = quat_to_rot(self._quats) + + return rot_mats + + def get_quats(self) -> torch.Tensor: + """ + Returns the underlying rotation as a quaternion tensor. + + Depending on whether the Rotation was initialized with a + quaternion, this function may call torch.linalg.eigh. + + Returns: + The rotation as a quaternion tensor. + """ + quats = self._quats + if(quats is None): + if(self._rot_mats is None): + raise ValueError("Both rotations are None") + else: + quats = rot_to_quat(self._rot_mats) + + return quats + + def get_cur_rot(self) -> torch.Tensor: + """ + Return the underlying rotation in its current form + + Returns: + The stored rotation + """ + if(self._rot_mats is not None): + return self._rot_mats + elif(self._quats is not None): + return self._quats + else: + raise ValueError("Both rotations are None") + + # Rotation functions + + def compose_q_update_vec(self, + q_update_vec: torch.Tensor, + normalize_quats: bool = True + ) -> Rotation: + """ + Returns a new quaternion Rotation after updating the current + object's underlying rotation with a quaternion update, formatted + as a [*, 3] tensor whose final three columns represent x, y, z such + that (1, x, y, z) is the desired (not necessarily unit) quaternion + update. 
+ + Args: + q_update_vec: + A [*, 3] quaternion update tensor + normalize_quats: + Whether to normalize the output quaternion + Returns: + An updated Rotation + """ + quats = self.get_quats() + new_quats = quats + quat_multiply_by_vec(quats, q_update_vec) + return Rotation( + rot_mats=None, + quats=new_quats, + normalize_quats=normalize_quats, + ) + + def compose_r(self, r: Rotation) -> Rotation: + """ + Compose the rotation matrices of the current Rotation object with + those of another. + + Args: + r: + An update rotation object + Returns: + An updated rotation object + """ + r1 = self.get_rot_mats() + r2 = r.get_rot_mats() + new_rot_mats = rot_matmul(r1, r2) + return Rotation(rot_mats=new_rot_mats, quats=None) + + def compose_q(self, r: Rotation, normalize_quats: bool = True) -> Rotation: + """ + Compose the quaternions of the current Rotation object with those + of another. + + Depending on whether either Rotation was initialized with + quaternions, this function may call torch.linalg.eigh. + + Args: + r: + An update rotation object + Returns: + An updated rotation object + """ + q1 = self.get_quats() + q2 = r.get_quats() + new_quats = quat_multiply(q1, q2) + return Rotation( + rot_mats=None, quats=new_quats, normalize_quats=normalize_quats + ) + + def apply(self, pts: torch.Tensor) -> torch.Tensor: + """ + Apply the current Rotation as a rotation matrix to a set of 3D + coordinates. + + Args: + pts: + A [*, 3] set of points + Returns: + [*, 3] rotated points + """ + rot_mats = self.get_rot_mats() + return rot_vec_mul(rot_mats, pts) + + def invert_apply(self, pts: torch.Tensor) -> torch.Tensor: + """ + The inverse of the apply() method. + + Args: + pts: + A [*, 3] set of points + Returns: + [*, 3] inverse-rotated points + """ + rot_mats = self.get_rot_mats() + inv_rot_mats = invert_rot_mat(rot_mats) + return rot_vec_mul(inv_rot_mats, pts) + + def invert(self) -> Rotation: + """ + Returns the inverse of the current Rotation. + + Returns: + The inverse of the current Rotation + """ + if(self._rot_mats is not None): + return Rotation( + rot_mats=invert_rot_mat(self._rot_mats), + quats=None + ) + elif(self._quats is not None): + return Rotation( + rot_mats=None, + quats=invert_quat(self._quats), + normalize_quats=False, + ) + else: + raise ValueError("Both rotations are None") + + # "Tensor" stuff + + def unsqueeze(self, + dim: int, + ) -> Rigid: + """ + Analogous to torch.unsqueeze. The dimension is relative to the + shape of the Rotation object. + + Args: + dim: A positive or negative dimension index. + Returns: + The unsqueezed Rotation. + """ + if dim >= len(self.shape): + raise ValueError("Invalid dimension") + + if(self._rot_mats is not None): + rot_mats = self._rot_mats.unsqueeze(dim if dim >= 0 else dim - 2) + return Rotation(rot_mats=rot_mats, quats=None) + elif(self._quats is not None): + quats = self._quats.unsqueeze(dim if dim >= 0 else dim - 1) + return Rotation(rot_mats=None, quats=quats, normalize_quats=False) + else: + raise ValueError("Both rotations are None") + + @staticmethod + def cat( + rs: Sequence[Rotation], + dim: int, + ) -> Rigid: + """ + Concatenates rotations along one of the batch dimensions. Analogous + to torch.cat(). + + Note that the output of this operation is always a rotation matrix, + regardless of the format of input rotations. 
+ + Args: + rs: + A list of rotation objects + dim: + The dimension along which the rotations should be + concatenated + Returns: + A concatenated Rotation object in rotation matrix format + """ + rot_mats = [r.get_rot_mats() for r in rs] + rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2) + + return Rotation(rot_mats=rot_mats, quats=None) + + def map_tensor_fn(self, + fn: Callable[torch.Tensor, torch.Tensor] + ) -> Rotation: + """ + Apply a Tensor -> Tensor function to underlying rotation tensors, + mapping over the rotation dimension(s). Can be used e.g. to sum out + a one-hot batch dimension. + + Args: + fn: + A Tensor -> Tensor function to be mapped over the Rotation + Returns: + The transformed Rotation object + """ + if(self._rot_mats is not None): + rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,)) + rot_mats = torch.stack( + list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1 + ) + rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3)) + return Rotation(rot_mats=rot_mats, quats=None) + elif(self._quats is not None): + quats = torch.stack( + list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1 + ) + return Rotation(rot_mats=None, quats=quats, normalize_quats=False) + else: + raise ValueError("Both rotations are None") + + def cuda(self) -> Rotation: + """ + Analogous to the cuda() method of torch Tensors + + Returns: + A copy of the Rotation in CUDA memory + """ + if(self._rot_mats is not None): + return Rotation(rot_mats=self._rot_mats.cuda(), quats=None) + elif(self._quats is not None): + return Rotation( + rot_mats=None, + quats=self._quats.cuda(), + normalize_quats=False + ) + else: + raise ValueError("Both rotations are None") + + def to(self, + device: Optional[torch.device], + dtype: Optional[torch.dtype] + ) -> Rotation: + """ + Analogous to the to() method of torch Tensors + + Args: + device: + A torch device + dtype: + A torch dtype + Returns: + A copy of the Rotation using the new device and dtype + """ + if(self._rot_mats is not None): + return Rotation( + rot_mats=self._rot_mats.to(device=device, dtype=dtype), + quats=None, + ) + elif(self._quats is not None): + return Rotation( + rot_mats=None, + quats=self._quats.to(device=device, dtype=dtype), + normalize_quats=False, + ) + else: + raise ValueError("Both rotations are None") + + def detach(self) -> Rotation: + """ + Returns a copy of the Rotation whose underlying Tensor has been + detached from its torch graph. + + Returns: + A copy of the Rotation whose underlying Tensor has been detached + from its torch graph + """ + if(self._rot_mats is not None): + return Rotation(rot_mats=self._rot_mats.detach(), quats=None) + elif(self._quats is not None): + return Rotation( + rot_mats=None, + quats=self._quats.detach(), + normalize_quats=False, + ) + else: + raise ValueError("Both rotations are None") + + +class Rigid: + """ + A class representing a rigid transformation. Little more than a wrapper + around two objects: a Rotation object and a [*, 3] translation + Designed to behave approximately like a single torch tensor with the + shape of the shared batch dimensions of its component parts. + """ + def __init__(self, + rots: Optional[Rotation], + trans: Optional[torch.Tensor], + ): + """ + Args: + rots: A [*, 3, 3] rotation tensor + trans: A corresponding [*, 3] translation tensor + """ + # (we need device, dtype, etc. 
from at least one input) + + batch_dims, dtype, device, requires_grad = None, None, None, None + if(trans is not None): + batch_dims = trans.shape[:-1] + dtype = trans.dtype + device = trans.device + requires_grad = trans.requires_grad + elif(rots is not None): + batch_dims = rots.shape + dtype = rots.dtype + device = rots.device + requires_grad = rots.requires_grad + else: + raise ValueError("At least one input argument must be specified") + + if(rots is None): + rots = Rotation.identity( + batch_dims, dtype, device, requires_grad, + ) + elif(trans is None): + trans = identity_trans( + batch_dims, dtype, device, requires_grad, + ) + + if((rots.shape != trans.shape[:-1]) or + (rots.device != trans.device)): + raise ValueError("Rots and trans incompatible") + + # Force full precision. Happens to the rotations automatically. + trans = trans.to(dtype=torch.float32) + + self._rots = rots + self._trans = trans + + @staticmethod + def identity( + shape: Tuple[int], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + requires_grad: bool = True, + fmt: str = "quat", + ) -> Rigid: + """ + Constructs an identity transformation. + + Args: + shape: + The desired shape + dtype: + The dtype of both internal tensors + device: + The device of both internal tensors + requires_grad: + Whether grad should be enabled for the internal tensors + Returns: + The identity transformation + """ + return Rigid( + Rotation.identity(shape, dtype, device, requires_grad, fmt=fmt), + identity_trans(shape, dtype, device, requires_grad), + ) + + def __getitem__(self, + index: Any, + ) -> Rigid: + """ + Indexes the affine transformation with PyTorch-style indices. + The index is applied to the shared dimensions of both the rotation + and the translation. + + E.g.:: + + r = Rotation(rot_mats=torch.rand(10, 10, 3, 3), quats=None) + t = Rigid(r, torch.rand(10, 10, 3)) + indexed = t[3, 4:6] + assert(indexed.shape == (2,)) + assert(indexed.get_rots().shape == (2,)) + assert(indexed.get_trans().shape == (2, 3)) + + Args: + index: A standard torch tensor index. E.g. 8, (10, None, 3), + or (3, slice(0, 1, None)) + Returns: + The indexed tensor + """ + if type(index) != tuple: + index = (index,) + + return Rigid( + self._rots[index], + self._trans[index + (slice(None),)], + ) + + def __mul__(self, + right: torch.Tensor, + ) -> Rigid: + """ + Pointwise left multiplication of the transformation with a tensor. + Can be used to e.g. mask the Rigid. + + Args: + right: + The tensor multiplicand + Returns: + The product + """ + if not(isinstance(right, torch.Tensor)): + raise TypeError("The other multiplicand must be a Tensor") + + new_rots = self._rots * right + new_trans = self._trans * right[..., None] + + return Rigid(new_rots, new_trans) + + def __rmul__(self, + left: torch.Tensor, + ) -> Rigid: + """ + Reverse pointwise multiplication of the transformation with a + tensor. + + Args: + left: + The left multiplicand + Returns: + The product + """ + return self.__mul__(left) + + @property + def shape(self) -> torch.Size: + """ + Returns the shape of the shared dimensions of the rotation and + the translation. + + Returns: + The shape of the transformation + """ + s = self._trans.shape[:-1] + return s + + @property + def device(self) -> torch.device: + """ + Returns the device on which the Rigid's tensors are located. 
+ + Returns: + The device on which the Rigid's tensors are located + """ + return self._trans.device + + @property + def dtype(self) -> torch.dtype: + """ + Returns the dtype of the Rigid tensors. + + Returns: + The dtype of the Rigid tensors + """ + return self._rots.dtype + + def get_rots(self) -> Rotation: + """ + Getter for the rotation. + + Returns: + The rotation object + """ + return self._rots + + def get_trans(self) -> torch.Tensor: + """ + Getter for the translation. + + Returns: + The stored translation + """ + return self._trans + + def compose_q_update_vec(self, + q_update_vec: torch.Tensor, + ) -> Rigid: + """ + Composes the transformation with a quaternion update vector of + shape [*, 6], where the final 6 columns represent the x, y, and + z values of a quaternion of form (1, x, y, z) followed by a 3D + translation. + + Args: + q_vec: The quaternion update vector. + Returns: + The composed transformation. + """ + q_vec, t_vec = q_update_vec[..., :3], q_update_vec[..., 3:] + new_rots = self._rots.compose_q_update_vec(q_vec) + + trans_update = self._rots.apply(t_vec) + new_translation = self._trans + trans_update + + return Rigid(new_rots, new_translation) + + def compose(self, + r: Rigid, + ) -> Rigid: + """ + Composes the current rigid object with another. + + Args: + r: + Another Rigid object + Returns: + The composition of the two transformations + """ + new_rot = self._rots.compose_r(r._rots) + new_trans = self._rots.apply(r._trans) + self._trans + return Rigid(new_rot, new_trans) + + def apply(self, + pts: torch.Tensor, + ) -> torch.Tensor: + """ + Applies the transformation to a coordinate tensor. + + Args: + pts: A [*, 3] coordinate tensor. + Returns: + The transformed points. + """ + rotated = self._rots.apply(pts) + return rotated + self._trans + + def invert_apply(self, + pts: torch.Tensor + ) -> torch.Tensor: + """ + Applies the inverse of the transformation to a coordinate tensor. + + Args: + pts: A [*, 3] coordinate tensor + Returns: + The transformed points. + """ + pts = pts - self._trans + return self._rots.invert_apply(pts) + + def invert(self) -> Rigid: + """ + Inverts the transformation. + + Returns: + The inverse transformation. + """ + rot_inv = self._rots.invert() + trn_inv = rot_inv.apply(self._trans) + + return Rigid(rot_inv, -1 * trn_inv) + + def map_tensor_fn(self, + fn: Callable[torch.Tensor, torch.Tensor] + ) -> Rigid: + """ + Apply a Tensor -> Tensor function to underlying translation and + rotation tensors, mapping over the translation/rotation dimensions + respectively. + + Args: + fn: + A Tensor -> Tensor function to be mapped over the Rigid + Returns: + The transformed Rigid object + """ + new_rots = self._rots.map_tensor_fn(fn) + new_trans = torch.stack( + list(map(fn, torch.unbind(self._trans, dim=-1))), + dim=-1 + ) + + return Rigid(new_rots, new_trans) + + def to_tensor_4x4(self) -> torch.Tensor: + """ + Converts a transformation to a homogenous transformation tensor. + + Returns: + A [*, 4, 4] homogenous transformation tensor + """ + tensor = self._trans.new_zeros((*self.shape, 4, 4)) + tensor[..., :3, :3] = self._rots.get_rot_mats() + tensor[..., :3, 3] = self._trans + tensor[..., 3, 3] = 1 + return tensor + + @staticmethod + def from_tensor_4x4( + t: torch.Tensor + ) -> Rigid: + """ + Constructs a transformation from a homogenous transformation + tensor. 
+ + Args: + t: [*, 4, 4] homogenous transformation tensor + Returns: + T object with shape [*] + """ + if(t.shape[-2:] != (4, 4)): + raise ValueError("Incorrectly shaped input tensor") + + rots = Rotation(rot_mats=t[..., :3, :3], quats=None) + trans = t[..., :3, 3] + + return Rigid(rots, trans) + + def to_tensor_7(self) -> torch.Tensor: + """ + Converts a transformation to a tensor with 7 final columns, four + for the quaternion followed by three for the translation. + + Returns: + A [*, 7] tensor representation of the transformation + """ + tensor = self._trans.new_zeros((*self.shape, 7)) + tensor[..., :4] = self._rots.get_quats() + tensor[..., 4:] = self._trans + + return tensor + + @staticmethod + def from_tensor_7( + t: torch.Tensor, + normalize_quats: bool = False, + ) -> Rigid: + if(t.shape[-1] != 7): + raise ValueError("Incorrectly shaped input tensor") + + quats, trans = t[..., :4], t[..., 4:] + + rots = Rotation( + rot_mats=None, + quats=quats, + normalize_quats=normalize_quats + ) + + return Rigid(rots, trans) + + @staticmethod + def from_3_points( + p_neg_x_axis: torch.Tensor, + origin: torch.Tensor, + p_xy_plane: torch.Tensor, + eps: float = 1e-8 + ) -> Rigid: + """ + Implements algorithm 21. Constructs transformations from sets of 3 + points using the Gram-Schmidt algorithm. + + Args: + p_neg_x_axis: [*, 3] coordinates + origin: [*, 3] coordinates used as frame origins + p_xy_plane: [*, 3] coordinates + eps: Small epsilon value + Returns: + A transformation object of shape [*] + """ + p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1) + origin = torch.unbind(origin, dim=-1) + p_xy_plane = torch.unbind(p_xy_plane, dim=-1) + + e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)] + e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)] + + denom = torch.sqrt(sum((c * c for c in e0)) + eps) + e0 = [c / denom for c in e0] + dot = sum((c1 * c2 for c1, c2 in zip(e0, e1))) + e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)] + denom = torch.sqrt(sum((c * c for c in e1)) + eps) + e1 = [c / denom for c in e1] + e2 = [ + e0[1] * e1[2] - e0[2] * e1[1], + e0[2] * e1[0] - e0[0] * e1[2], + e0[0] * e1[1] - e0[1] * e1[0], + ] + + rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1) + rots = rots.reshape(rots.shape[:-1] + (3, 3)) + + rot_obj = Rotation(rot_mats=rots, quats=None) + + return Rigid(rot_obj, torch.stack(origin, dim=-1)) + + def unsqueeze(self, + dim: int, + ) -> Rigid: + """ + Analogous to torch.unsqueeze. The dimension is relative to the + shared dimensions of the rotation/translation. + + Args: + dim: A positive or negative dimension index. + Returns: + The unsqueezed transformation. + """ + if dim >= len(self.shape): + raise ValueError("Invalid dimension") + rots = self._rots.unsqueeze(dim) + trans = self._trans.unsqueeze(dim if dim >= 0 else dim - 1) + + return Rigid(rots, trans) + + @staticmethod + def cat( + ts: Sequence[Rigid], + dim: int, + ) -> Rigid: + """ + Concatenates transformations along a new dimension. + + Args: + ts: + A list of T objects + dim: + The dimension along which the transformations should be + concatenated + Returns: + A concatenated transformation object + """ + rots = Rotation.cat([t._rots for t in ts], dim) + trans = torch.cat( + [t._trans for t in ts], dim=dim if dim >= 0 else dim - 1 + ) + + return Rigid(rots, trans) + + def apply_rot_fn(self, fn: Callable[Rotation, Rotation]) -> Rigid: + """ + Applies a Rotation -> Rotation function to the stored rotation + object. 
+ + Args: + fn: A function of type Rotation -> Rotation + Returns: + A transformation object with a transformed rotation. + """ + return Rigid(fn(self._rots), self._trans) + + def apply_trans_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rigid: + """ + Applies a Tensor -> Tensor function to the stored translation. + + Args: + fn: + A function of type Tensor -> Tensor to be applied to the + translation + Returns: + A transformation object with a transformed translation. + """ + return Rigid(self._rots, fn(self._trans)) + + def scale_translation(self, trans_scale_factor: float) -> Rigid: + """ + Scales the translation by a constant factor. + + Args: + trans_scale_factor: + The constant factor + Returns: + A transformation object with a scaled translation. + """ + fn = lambda t: t * trans_scale_factor + return self.apply_trans_fn(fn) + + def stop_rot_gradient(self) -> Rigid: + """ + Detaches the underlying rotation object + + Returns: + A transformation object with detached rotations + """ + fn = lambda r: r.detach() + return self.apply_rot_fn(fn) + + @staticmethod + def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20): + """ + Returns a transformation object from reference coordinates. + + Note that this method does not take care of symmetries. If you + provide the atom positions in the non-standard way, the N atom will + end up not at [-0.527250, 1.359329, 0.0] but instead at + [-0.527250, -1.359329, 0.0]. You need to take care of such cases in + your code. + + Args: + n_xyz: A [*, 3] tensor of nitrogen xyz coordinates. + ca_xyz: A [*, 3] tensor of carbon alpha xyz coordinates. + c_xyz: A [*, 3] tensor of carbon xyz coordinates. + Returns: + A transformation object. After applying the translation and + rotation to the reference backbone, the coordinates will + approximately equal to the input coordinates. 
+ """ + translation = -1 * ca_xyz + n_xyz = n_xyz + translation + c_xyz = c_xyz + translation + + c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)] + norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2) + sin_c1 = -c_y / norm + cos_c1 = c_x / norm + zeros = sin_c1.new_zeros(sin_c1.shape) + ones = sin_c1.new_ones(sin_c1.shape) + + c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3)) + c1_rots[..., 0, 0] = cos_c1 + c1_rots[..., 0, 1] = -1 * sin_c1 + c1_rots[..., 1, 0] = sin_c1 + c1_rots[..., 1, 1] = cos_c1 + c1_rots[..., 2, 2] = 1 + + norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2 + c_z ** 2) + sin_c2 = c_z / norm + cos_c2 = torch.sqrt(c_x ** 2 + c_y ** 2) / norm + + c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3)) + c2_rots[..., 0, 0] = cos_c2 + c2_rots[..., 0, 2] = sin_c2 + c2_rots[..., 1, 1] = 1 + c2_rots[..., 2, 0] = -1 * sin_c2 + c2_rots[..., 2, 2] = cos_c2 + + c_rots = rot_matmul(c2_rots, c1_rots) + n_xyz = rot_vec_mul(c_rots, n_xyz) + + _, n_y, n_z = [n_xyz[..., i] for i in range(3)] + norm = torch.sqrt(eps + n_y ** 2 + n_z ** 2) + sin_n = -n_z / norm + cos_n = n_y / norm + + n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3)) + n_rots[..., 0, 0] = 1 + n_rots[..., 1, 1] = cos_n + n_rots[..., 1, 2] = -1 * sin_n + n_rots[..., 2, 1] = sin_n + n_rots[..., 2, 2] = cos_n + + rots = rot_matmul(n_rots, c_rots) + + rots = rots.transpose(-1, -2) + translation = -1 * translation + + rot_obj = Rotation(rot_mats=rots, quats=None) + + return Rigid(rot_obj, translation) + + def cuda(self) -> Rigid: + """ + Moves the transformation object to GPU memory + + Returns: + A version of the transformation on GPU + """ + return Rigid(self._rots.cuda(), self._trans.cuda()) diff --git a/dockformer/utils/script_utils.py b/dockformer/utils/script_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1ffeedeb355e9c5b043410283a6f752cdda142f1 --- /dev/null +++ b/dockformer/utils/script_utils.py @@ -0,0 +1,217 @@ +import json +import logging +import os +import re +import time +from typing import List, Tuple + +import numpy +import torch +from rdkit import Chem + +from dockformer.model.model import AlphaFold +from dockformer.utils import residue_constants, protein +from dockformer.utils.consts import POSSIBLE_ATOM_TYPES, POSSIBLE_BOND_TYPES, POSSIBLE_CHARGES, POSSIBLE_CHIRALITIES + +logging.basicConfig() +logger = logging.getLogger(__file__) +logger.setLevel(level=logging.INFO) + + +def count_models_to_evaluate(model_checkpoint_path): + model_count = 0 + if model_checkpoint_path: + model_count += len(model_checkpoint_path.split(",")) + return model_count + + +def get_model_basename(model_path): + return os.path.splitext( + os.path.basename( + os.path.normpath(model_path) + ) + )[0] + + +def make_output_directory(output_dir, model_name, multiple_model_mode): + if multiple_model_mode: + prediction_dir = os.path.join(output_dir, "predictions", model_name) + else: + prediction_dir = os.path.join(output_dir, "predictions") + os.makedirs(prediction_dir, exist_ok=True) + return prediction_dir + + +# Function to get the latest checkpoint +def get_latest_checkpoint(checkpoint_dir): + if not os.path.exists(checkpoint_dir): + return None + checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.ckpt')] + if not checkpoints: + return None + latest_checkpoint = max(checkpoints, key=lambda x: os.path.getctime(os.path.join(checkpoint_dir, x))) + return os.path.join(checkpoint_dir, latest_checkpoint) + + +def load_models_from_command_line(config, model_device, model_checkpoint_path, output_dir): + # Create the 
output directory + + multiple_model_mode = count_models_to_evaluate(model_checkpoint_path) > 1 + if multiple_model_mode: + logger.info(f"evaluating multiple models") + + if model_checkpoint_path: + for path in model_checkpoint_path.split(","): + model = AlphaFold(config) + model = model.eval() + checkpoint_basename = get_model_basename(path) + assert os.path.isfile(path), f"Model checkpoint not found at {path}" + ckpt_path = path + + try: + d = torch.load(ckpt_path) + except RuntimeError: + print("Loading model on CPU") + d = torch.load(ckpt_path, map_location='cpu') + + if "ema" in d: + # The public weights have had this done to them already + d = d["ema"]["params"] + model.load_state_dict(d) + + + model = model.to(model_device) + logger.info( + f"Loaded Model parameters at {path}..." + ) + output_directory = make_output_directory(output_dir, checkpoint_basename, multiple_model_mode) + yield model, output_directory + + if not model_checkpoint_path: + raise ValueError("model_checkpoint_path must be specified.") + + +def parse_fasta(data): + data = re.sub('>$', '', data, flags=re.M) + lines = [ + l.replace('\n', '') + for prot in data.split('>') for l in prot.strip().split('\n', 1) + ][1:] + tags, seqs = lines[::2], lines[1::2] + + tags = [re.split('\W| \|', t)[0] for t in tags] + + return tags, seqs + + +def update_timings(timing_dict, output_file=os.path.join(os.getcwd(), "timings.json")): + """ + Write dictionary of one or more run step times to a file + """ + if os.path.exists(output_file): + with open(output_file, "r") as f: + try: + timings = json.load(f) + except json.JSONDecodeError: + logger.info(f"Overwriting non-standard JSON in {output_file}.") + timings = {} + else: + timings = {} + timings.update(timing_dict) + with open(output_file, "w") as f: + json.dump(timings, f) + return output_file + + +def run_model(model, batch, tag, output_dir): + with torch.no_grad(): + logger.info(f"Running inference for {tag}...") + t = time.perf_counter() + out = model(batch) + inference_time = time.perf_counter() - t + logger.info(f"Inference time: {inference_time}") + update_timings({tag: {"inference": inference_time}}, os.path.join(output_dir, "timings.json")) + + return out + + +def get_molecule_from_output(atoms_atype: List[int], atom_chiralities: List[int], atom_charges: List[int], + bonds: List[Tuple[int, int, int]], atom_positions: List[Tuple[float, float, float]]): + mol = Chem.RWMol() + + assert len(atoms_atype) == len(atom_chiralities) == len(atom_charges) == len(atom_positions) + for atype_idx, chirality_idx, charge_idx in zip(atoms_atype, atom_chiralities, atom_charges): + new_atom = Chem.Atom(POSSIBLE_ATOM_TYPES[atype_idx]) + new_atom.SetChiralTag(POSSIBLE_CHIRALITIES[chirality_idx]) + new_atom.SetFormalCharge(POSSIBLE_CHARGES[charge_idx]) + + mol.AddAtom(new_atom) + + # Add bonds + for bond in bonds: + atom1, atom2, bond_type_idx = bond + bond_type = POSSIBLE_BOND_TYPES[bond_type_idx] + mol.AddBond(int(atom1), int(atom2), bond_type) + + # Set atom positions + conf = Chem.Conformer(len(atoms_atype)) + for i, pos in enumerate(atom_positions.astype(float)): + conf.SetAtomPosition(i, pos) + mol.AddConformer(conf) + return mol + + +def save_output_structure(aatype, residue_index, chain_index, plddt, final_atom_protein_positions, final_atom_mask, + ligand_atype, ligand_chiralities, ligand_charges, ligand_bonds, ligand_idx, ligand_bonds_idx, + final_ligand_atom_positions, protein_output_path, ligand_output_path): + plddt_b_factors = numpy.repeat( + plddt[..., None], 
residue_constants.atom_type_num, axis=-1 + ) + + unrelaxed_protein = protein.from_prediction( + aatype=aatype, + residue_index=residue_index, + chain_index=chain_index, + atom_mask=final_atom_mask, + atom_positions=final_atom_protein_positions, + b_factors=plddt_b_factors, + remove_leading_feature_dimension=False, + ) + + with open(protein_output_path, 'w') as fp: + fp.write(protein.to_pdb(unrelaxed_protein)) + + # binding_site_b_factors = numpy.repeat( + # binding_site_probs[..., None], residue_constants.atom_type_num, axis=-1 + # ) + # + # protein_binding_site = protein.from_prediction( + # aatype=aatype, + # residue_index=residue_index, + # chain_index=chain_index, + # atom_mask=final_atom_mask, + # atom_positions=final_atom_protein_positions, + # b_factors=binding_site_b_factors, + # remove_leading_feature_dimension=False, + # remark=f"affinity: {affinity:.3f}", + # ) + # + # with open(protein_affinity_output_path, 'w') as fp: + # fp.write(protein.to_pdb(protein_binding_site)) + + all_ligand_idxs = numpy.unique(ligand_idx) + for cur_ligand_idx in all_ligand_idxs: + atom_mask = ligand_idx == cur_ligand_idx + cur_ligand_atype = ligand_atype[atom_mask] + cur_ligand_chiralities = ligand_chiralities[atom_mask] + cur_ligand_charges = ligand_charges[atom_mask] + + cur_ligand_bonds = ligand_bonds[ligand_bonds_idx == cur_ligand_idx] + + cur_ligand_atom_positions = final_ligand_atom_positions[atom_mask] + + ligand = get_molecule_from_output(cur_ligand_atype, cur_ligand_chiralities, cur_ligand_charges, cur_ligand_bonds, + cur_ligand_atom_positions) + with open(ligand_output_path.format(i=cur_ligand_idx), 'w') as f: + f.write(Chem.MolToMolBlock(ligand, kekulize=False)) + print("Output written to", protein_output_path, ligand_output_path) diff --git a/dockformer/utils/superimposition.py b/dockformer/utils/superimposition.py new file mode 100644 index 0000000000000000000000000000000000000000..4992609db8024b6abc0a6900092b7657bd6fd4ce --- /dev/null +++ b/dockformer/utils/superimposition.py @@ -0,0 +1,108 @@ +# Copyright 2021 AlQuraishi Laboratory +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from Bio.SVDSuperimposer import SVDSuperimposer +import numpy as np +import torch + + +def _superimpose_np(reference, coords): + """ + Superimposes coordinates onto a reference by minimizing RMSD using SVD. + + Args: + reference: + [N, 3] reference array + coords: + [N, 3] array + Returns: + A tuple of [N, 3] superimposed coords and the final RMSD. 
+ """ + sup = SVDSuperimposer() + sup.set(reference, coords) + sup.run() + rotran = sup.get_rotran() + return sup.get_transformed(), sup.get_rms(), rotran + + +def _superimpose_single(reference, coords): + reference_np = reference.detach().cpu().numpy() + coords_np = coords.detach().cpu().numpy() + superimposed, rmsd, rotran = _superimpose_np(reference_np, coords_np) + rotran = (coords.new_tensor(rotran[0]), coords.new_tensor(rotran[1])) + return coords.new_tensor(superimposed), coords.new_tensor(rmsd), rotran + + +def superimpose(reference, coords, mask): + """ + Superimposes coordinates onto a reference by minimizing RMSD using SVD. + + Args: + reference: + [*, N, 3] reference tensor + coords: + [*, N, 3] tensor + mask: + [*, N] tensor + Returns: + A tuple of [*, N, 3] superimposed coords and [*] final RMSDs. + """ + def select_unmasked_coords(coords, mask): + return torch.masked_select( + coords, + (mask > 0.)[..., None], + ).reshape(-1, 3) + + batch_dims = reference.shape[:-2] + flat_reference = reference.reshape((-1,) + reference.shape[-2:]) + flat_coords = coords.reshape((-1,) + reference.shape[-2:]) + flat_mask = mask.reshape((-1,) + mask.shape[-1:]) + superimposed_list = [] + rmsds = [] + rotrans = [] + for r, c, m in zip(flat_reference, flat_coords, flat_mask): + r_unmasked_coords = select_unmasked_coords(r, m) + c_unmasked_coords = select_unmasked_coords(c, m) + superimposed, rmsd, rotran = _superimpose_single( + r_unmasked_coords, + c_unmasked_coords + ) + + # This is very inelegant, but idk how else to invert the masking + # procedure. + count = 0 + superimposed_full_size = torch.zeros_like(r) + for i, unmasked in enumerate(m): + if(unmasked): + superimposed_full_size[i] = superimposed[count] + count += 1 + + superimposed_list.append(superimposed_full_size) + rmsds.append(rmsd) + rotrans.append(rotran) + + superimposed_stacked = torch.stack(superimposed_list, dim=0) + rmsds_stacked = torch.stack(rmsds, dim=0) + rots = [r for r, t in rotrans] + rots_stacked = torch.stack(rots, dim=0) + trans = [t for r, t in rotrans] + trans_stacked = torch.stack(trans, dim=0) + + superimposed_reshaped = superimposed_stacked.reshape( + batch_dims + coords.shape[-2:] + ) + rmsds_reshaped = rmsds_stacked.reshape( + batch_dims + ) + + return superimposed_reshaped, rmsds_reshaped, rots_stacked, trans_stacked diff --git a/dockformer/utils/tensor_utils.py b/dockformer/utils/tensor_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4b0dd0b619fb9d630a0f6675e9437b99340bc0a1 --- /dev/null +++ b/dockformer/utils/tensor_utils.py @@ -0,0 +1,122 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from functools import partial +import logging +from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional + +import torch +import torch.nn as nn + + +def add(m1, m2, inplace): + # The first operation in a checkpoint can't be in-place, but it's + # nice to have in-place addition during inference. Thus... + if(not inplace): + m1 = m1 + m2 + else: + m1 += m2 + + return m1 + + +def permute_final_dims(tensor: torch.Tensor, inds: List[int]): + zero_index = -1 * len(inds) + first_inds = list(range(len(tensor.shape[:zero_index]))) + return tensor.permute(first_inds + [zero_index + i for i in inds]) + + +def flatten_final_dims(t: torch.Tensor, no_dims: int): + return t.reshape(t.shape[:-no_dims] + (-1,)) + + +def masked_mean(mask, value, dim, eps=1e-4): + mask = mask.expand(*value.shape) + return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim)) + + +def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64): + boundaries = torch.linspace( + min_bin, max_bin, no_bins - 1, device=pts.device + ) + dists = torch.sqrt( + torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1) + ) + return torch.bucketize(dists, boundaries) + + +def dict_multimap(fn, dicts): + first = dicts[0] + new_dict = {} + for k, v in first.items(): + all_v = [d[k] for d in dicts] + if type(v) is dict: + new_dict[k] = dict_multimap(fn, all_v) + else: + new_dict[k] = fn(all_v) + + return new_dict + + +def one_hot(x, v_bins): + reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),)) + diffs = x[..., None] - reshaped_bins + am = torch.argmin(torch.abs(diffs), dim=-1) + return nn.functional.one_hot(am, num_classes=len(v_bins)).float() + + +def batched_gather(data, inds, dim=0, no_batch_dims=0): + ranges = [] + for i, s in enumerate(data.shape[:no_batch_dims]): + r = torch.arange(s) + r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1)))) + ranges.append(r) + + remaining_dims = [ + slice(None) for _ in range(len(data.shape) - no_batch_dims) + ] + remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds + ranges.extend(remaining_dims) + return data[ranges] + + +# With tree_map, a poor man's JAX tree_map +def dict_map(fn, dic, leaf_type): + new_dict = {} + for k, v in dic.items(): + # print("dictttt", k,type(v), v) + if type(v) is dict: + new_dict[k] = dict_map(fn, v, leaf_type) + else: + new_dict[k] = tree_map(fn, v, leaf_type) + + return new_dict + + +def tree_map(fn, tree, leaf_type): + if isinstance(tree, dict): + return dict_map(fn, tree, leaf_type) + elif isinstance(tree, list): + return [tree_map(fn, x, leaf_type) for x in tree] + elif isinstance(tree, tuple): + return tuple([tree_map(fn, x, leaf_type) for x in tree]) + elif isinstance(tree, leaf_type): + return fn(tree) + else: + print(type(tree)) + raise ValueError("Not supported") + + +tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor) diff --git a/dockformer/utils/validation_metrics.py b/dockformer/utils/validation_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..e24dce20f455019fec6e1eedfe5837eda3ff4adb --- /dev/null +++ b/dockformer/utils/validation_metrics.py @@ -0,0 +1,79 @@ +# Copyright 2021 AlQuraishi Laboratory +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+
+def drmsd(structure_1, structure_2, mask=None):
+    def prep_d(structure):
+        d = structure[..., :, None, :] - structure[..., None, :, :]
+        d = d ** 2
+        d = torch.sqrt(torch.sum(d, dim=-1))
+        return d
+
+    d1 = prep_d(structure_1)
+    d2 = prep_d(structure_2)
+
+    drmsd = d1 - d2
+    drmsd = drmsd ** 2
+    if(mask is not None):
+        drmsd = drmsd * (mask[..., None] * mask[..., None, :])
+    drmsd = torch.sum(drmsd, dim=(-1, -2))
+    n = d1.shape[-1] if mask is None else torch.sum(mask, dim=-1)
+    # Cast n to a tensor so the (n > 1) guard also works when mask is None
+    # and n is a plain int
+    n = torch.as_tensor(n, device=drmsd.device, dtype=drmsd.dtype)
+    drmsd = drmsd * (1 / (n * (n - 1))) if (n > 1).all() else (drmsd * 0.)
+    drmsd = torch.sqrt(drmsd)
+
+    return drmsd
+
+
+def drmsd_np(structure_1, structure_2, mask=None):
+    structure_1 = torch.tensor(structure_1)
+    structure_2 = torch.tensor(structure_2)
+    if(mask is not None):
+        mask = torch.tensor(mask)
+
+    return drmsd(structure_1, structure_2, mask)
+
+
+def rmsd(structure_1, structure_2, mask=None):
+    squared_dists = torch.sum((structure_1 - structure_2) ** 2, dim=-1)
+    if mask is None:
+        return torch.sqrt(torch.sum(squared_dists, dim=1) / squared_dists.shape[-1])
+    squared_dists = squared_dists * mask
+    n = torch.sum(mask, dim=1)
+    return torch.sqrt(torch.sum(squared_dists, dim=1) / n)
+
+
+def gdt(p1, p2, mask, cutoffs):
+    n = torch.sum(mask, dim=-1)
+
+    p1 = p1.float()
+    p2 = p2.float()
+    distances = torch.sqrt(torch.sum((p1 - p2)**2, dim=-1))
+    scores = []
+    for c in cutoffs:
+        score = torch.sum((distances <= c) * mask, dim=-1) / n
+        score = torch.mean(score)
+        scores.append(score)
+
+    return sum(scores) / len(scores)
+
+
+def gdt_ts(p1, p2, mask):
+    return gdt(p1, p2, mask, [1., 2., 4., 8.])
+
+
+def gdt_ha(p1, p2, mask):
+    return gdt(p1, p2, mask, [0.5, 1., 2., 4.])
+
diff --git a/env.yml b/env.yml
new file mode 100644
index 0000000000000000000000000000000000000000..99f4575b16b4b1c7eb8bee96058a5d1a82eb6137
--- /dev/null
+++ b/env.yml
@@ -0,0 +1,35 @@
+name: dockformer-venv
+channels:
+  - conda-forge
+  - pytorch
+dependencies:
+  - python=3.9
+#  - libgcc=7.2
+  - setuptools=59.5.0
+  - pip
+  - numpy==1.21
+  - scipy==1.7
+#  - openmm=7.7
+#  - pdbfixer
+#  - cudatoolkit==11.3.*
+  - lightning=2.*
+  - biopython==1.79
+  - PyYAML==5.4.1
+  - requests
+  - tqdm==4.62.2
+  - typing-extensions
+  - wandb==0.12.21
+  - modelcif==0.7
+  - awscli
+  - ml-collections
+  - aria2
+  - rdkit
+  - git
+  - pytorch::pytorch=2.3.0
+  - pip:
+#    - torch==2.3.0 --index-url https://download.pytorch.org/whl/cu118
+    - dm-tree==0.1.6
+#    - git+https://github.com/NVIDIA/dllogger.git
+    - gradio
+    - gradio_molecule3d
+    - numpy==1.21
diff --git a/env_consts.py b/env_consts.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fb97a12e729d8b9fb5f415029654b0bcf652eb4
--- /dev/null
+++ b/env_consts.py
@@ -0,0 +1,10 @@
+import os
+
+TEST_INPUT_DIR = None
+TEST_OUTPUT_DIR = None
+THIS_FILE_DIR = os.path.dirname(os.path.abspath(__file__))
+CKPT_PATH = os.path.join(THIS_FILE_DIR, "resources", "only_weights_87-172000.ckpt")
+RUN_CONFIG_PATH = os.path.join(THIS_FILE_DIR, "resources", "run_config.json")
+
+OUTPUT_PROT_PATH = os.path.join(THIS_FILE_DIR, "predicted_protein_out.pdb")
+OUTPUT_LIG_PATH =
os.path.join(THIS_FILE_DIR, "predicted_lig_out.sdf") diff --git a/inference_app.py b/inference_app.py index 9a9836910d73a4b851cdf5758ee53f823cd1304b..d4187e7afdf034d2345be15d31fc0c1df59d6971 100644 --- a/inference_app.py +++ b/inference_app.py @@ -1,36 +1,36 @@ - import time import gradio as gr from gradio_molecule3d import Molecule3D +from run_on_seq import run_on_sample_seqs +from env_consts import RUN_CONFIG_PATH, OUTPUT_PROT_PATH, OUTPUT_LIG_PATH - - -def predict (input_sequence, input_ligand,input_msa, input_protein): +def predict(input_sequence, input_ligand, input_msa, input_protein): start_time = time.time() # Do inference here - # return an output pdb file with the protein and ligand with resname LIG or UNK. + # return an output pdb file with the protein and ligand with resname LIG or UNK. # also return any metrics you want to log, metrics will not be used for evaluation but might be useful for users - metrics = {"mean_plddt": 80, "binding_affinity": -2} + metrics = run_on_sample_seqs(input_sequence, input_protein, input_ligand, OUTPUT_PROT_PATH, OUTPUT_LIG_PATH, + RUN_CONFIG_PATH) end_time = time.time() run_time = end_time - start_time - return ["test_out.pdb", "test_docking_pose.sdf"], metrics, run_time -with gr.Blocks() as app: + return [OUTPUT_PROT_PATH, OUTPUT_LIG_PATH], metrics, run_time + - gr.Markdown("# Template for inference") +with gr.Blocks() as app: + gr.Markdown("DockFormer") - gr.Markdown("Title, description, and other information about the model") + # gr.Markdown("Title, description, and other information about the model") with gr.Row(): input_sequence = gr.Textbox(lines=3, label="Input Protein sequence (FASTA)") input_ligand = gr.Textbox(lines=3, label="Input ligand SMILES") with gr.Row(): input_msa = gr.File(label="Input Protein MSA (A3M)") input_protein = gr.File(label="Input protein monomer") - - + # define any options here # for automated inference the default options are used @@ -50,24 +50,25 @@ with gr.Blocks() as app: ], [input_sequence, input_ligand, input_protein], ) - reps = [ - { - "model": 0, - "style": "cartoon", - "color": "whiteCarbon", - }, + reps = [ + { + "model": 0, + "style": "cartoon", + "color": "whiteCarbon", + }, { - "model": 1, - "style": "stick", - "color": "greenCarbon", - } - - ] - + "model": 1, + "style": "stick", + "color": "greenCarbon", + } + + ] + out = Molecule3D(reps=reps) metrics = gr.JSON(label="Metrics") run_time = gr.Textbox(label="Runtime") - btn.click(predict, inputs=[input_sequence, input_ligand, input_msa, input_protein], outputs=[out,metrics, run_time]) + btn.click(predict, inputs=[input_sequence, input_ligand, input_msa, input_protein], + outputs=[out, metrics, run_time]) app.launch() diff --git a/resources/only_weights_87-172000.ckpt b/resources/only_weights_87-172000.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..c5ac785e78544fb9f999c6e469aeeb8b5cc68cdc --- /dev/null +++ b/resources/only_weights_87-172000.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9024331d1e9f39084686cbae36524589dcd1ce26896da4780d462698e6ecb83c +size 53301307 diff --git a/resources/run_config.json b/resources/run_config.json new file mode 100644 index 0000000000000000000000000000000000000000..8d5d62a06f600230138f4f27338d9165f501ca79 --- /dev/null +++ b/resources/run_config.json @@ -0,0 +1,21 @@ +{ + "stage": "initial_training", + "override_conf": { + "model": { + "evoformer_stack": { + "no_blocks": 8 + } + }, + "data": {"common": {"max_recycling_iters": 3}, "data_module": {"data_loaders": 
{"batch_size": 1}}, "train": {"crop_size": 355, "fixed_size": false}}, + "loss": { + "positions_intra_distogram": {"weight": 0.05}, + "inter_contact": {"weight": 0.05, "pos_class_weight": 10.0}, + "binding_site": {"weight": 0.05}, + "affinity1d": {"weight": 0.03}, + "affinity2d": {"weight": 0.03}, + "affinity_cls": {"weight": 0.03}, + "fape_interface": {"weight": 1.0} + }, + "globals": {"max_lr": 0.0001} + } +} \ No newline at end of file diff --git a/run_on_seq.py b/run_on_seq.py new file mode 100644 index 0000000000000000000000000000000000000000..f57b6cd139663e81a39206c6518c81e169374bab --- /dev/null +++ b/run_on_seq.py @@ -0,0 +1,167 @@ +import json +import os +import tempfile + +import Bio.PDB +import Bio.SeqUtils +import numpy as np +from Bio import pairwise2 +from rdkit import Chem +from rdkit.Chem import AllChem, rdMolAlign + +from run_pretrained_model import run_on_folder + + +def get_seq_based_on_template(seq: str, template_path: str, output_path: str): + # get a list of all residues in template + parser = Bio.PDB.PDBParser() + template_structure = parser.get_structure("template", template_path) + chain = template_structure[0].get_chains().__next__() + template_residues = [i for i in chain.get_residues() if "CA" in i + and Bio.SeqUtils.seq1(i.get_resname()) not in ("X", "", " ")] + template_seq = "".join([Bio.SeqUtils.seq1(i.get_resname()) for i in template_residues]) + + # align the sequence to the template + alignment = pairwise2.align.globalxx(seq, template_seq, one_alignment_only=True)[0] + aligned_seq, aligned_template_seq = alignment.seqA, alignment.seqB + + # create a new pdb file with the aligned residues + new_structure = Bio.PDB.Structure.Structure("new_structure") + new_model = Bio.PDB.Model.Model(0) + new_structure.add(new_model) + new_chain = Bio.PDB.Chain.Chain("A") # Using chain ID 'A' for the output + new_model.add(new_chain) + + template_ind = -1 + seq_ind = 0 + print(aligned_seq, aligned_template_seq, len(template_residues)) + for seq_res, template_res in zip(aligned_seq, aligned_template_seq): + if template_res != "-": + template_ind += 1 + + if seq_res != "-": + seq_ind += 1 + + if seq_res == "-": + continue + + if template_res == "-": + seq_res_3_letter = Bio.SeqUtils.seq3(seq_res).upper() + residue = Bio.PDB.Residue.Residue((' ', seq_ind, ' '), seq_res_3_letter, '') + atom = Bio.PDB.Atom.Atom("C", (0.0, 0.0, 0.0), 1.0, 1.0, ' ', "CA", 0, element="C") + residue.add(atom) + new_chain.add(residue) + else: + residue = template_residues[template_ind].copy() + residue.detach_parent() + residue.id = (' ', seq_ind, ' ') + new_chain.add(residue) + io = Bio.PDB.PDBIO() + io.set_structure(new_structure) + io.save(output_path) + + +def create_conformers(smiles, output_path, num_conformers=1, multiplier_samples=1): + target_mol = Chem.MolFromSmiles(smiles) + target_mol = Chem.AddHs(target_mol) + + params = AllChem.ETKDGv3() + params.numThreads = 0 # Use all available threads + params.pruneRmsThresh = 0.1 # Pruning threshold for RMSD + conformer_ids = AllChem.EmbedMultipleConfs(target_mol, numConfs=num_conformers * multiplier_samples, params=params) + + # Optional: Optimize each conformer using MMFF94 force field + # for conf_id in conformer_ids: + # AllChem.UFFOptimizeMolecule(target_mol, confId=conf_id) + + # remove hydrogen atoms + target_mol = Chem.RemoveHs(target_mol) + + # Save aligned conformers to a file (optional) + w = Chem.SDWriter(output_path) + for i, conf_id in enumerate(conformer_ids): + if i >= num_conformers: + break + w.write(target_mol, confId=conf_id) + 
diff --git a/run_on_seq.py b/run_on_seq.py
new file mode 100644
index 0000000000000000000000000000000000000000..f57b6cd139663e81a39206c6518c81e169374bab
--- /dev/null
+++ b/run_on_seq.py
@@ -0,0 +1,167 @@
+import json
+import os
+import tempfile
+
+import Bio.PDB
+import Bio.SeqUtils
+import numpy as np
+from Bio import pairwise2
+from rdkit import Chem
+from rdkit.Chem import AllChem, rdMolAlign
+
+from run_pretrained_model import run_on_folder
+
+
+def get_seq_based_on_template(seq: str, template_path: str, output_path: str):
+    # get a list of all residues in the template
+    parser = Bio.PDB.PDBParser()
+    template_structure = parser.get_structure("template", template_path)
+    chain = next(template_structure[0].get_chains())
+    template_residues = [i for i in chain.get_residues() if "CA" in i
+                         and Bio.SeqUtils.seq1(i.get_resname()) not in ("X", "", " ")]
+    template_seq = "".join([Bio.SeqUtils.seq1(i.get_resname()) for i in template_residues])
+
+    # align the sequence to the template
+    alignment = pairwise2.align.globalxx(seq, template_seq, one_alignment_only=True)[0]
+    aligned_seq, aligned_template_seq = alignment.seqA, alignment.seqB
+
+    # create a new pdb file with the aligned residues
+    new_structure = Bio.PDB.Structure.Structure("new_structure")
+    new_model = Bio.PDB.Model.Model(0)
+    new_structure.add(new_model)
+    new_chain = Bio.PDB.Chain.Chain("A")  # Using chain ID 'A' for the output
+    new_model.add(new_chain)
+
+    template_ind = -1
+    seq_ind = 0
+    print(aligned_seq, aligned_template_seq, len(template_residues))
+    for seq_res, template_res in zip(aligned_seq, aligned_template_seq):
+        if template_res != "-":
+            template_ind += 1
+
+        if seq_res != "-":
+            seq_ind += 1
+
+        if seq_res == "-":
+            continue
+
+        if template_res == "-":
+            # residue missing from the template: add a CA-only placeholder at the origin
+            seq_res_3_letter = Bio.SeqUtils.seq3(seq_res).upper()
+            residue = Bio.PDB.Residue.Residue((' ', seq_ind, ' '), seq_res_3_letter, '')
+            atom = Bio.PDB.Atom.Atom("C", np.array((0.0, 0.0, 0.0)), 1.0, 1.0, ' ', "CA", 0, element="C")
+            residue.add(atom)
+            new_chain.add(residue)
+        else:
+            residue = template_residues[template_ind].copy()
+            residue.detach_parent()
+            residue.id = (' ', seq_ind, ' ')
+            new_chain.add(residue)
+    io = Bio.PDB.PDBIO()
+    io.set_structure(new_structure)
+    io.save(output_path)
+
+
+def create_conformers(smiles, output_path, num_conformers=1, multiplier_samples=1):
+    target_mol = Chem.MolFromSmiles(smiles)
+    target_mol = Chem.AddHs(target_mol)
+
+    params = AllChem.ETKDGv3()
+    params.numThreads = 0  # Use all available threads
+    params.pruneRmsThresh = 0.1  # Pruning threshold for RMSD
+    conformer_ids = AllChem.EmbedMultipleConfs(target_mol, numConfs=num_conformers * multiplier_samples, params=params)
+
+    # Optional: Optimize each conformer using the UFF force field
+    # for conf_id in conformer_ids:
+    #     AllChem.UFFOptimizeMolecule(target_mol, confId=conf_id)
+
+    # remove hydrogen atoms
+    target_mol = Chem.RemoveHs(target_mol)
+
+    # Save the requested number of conformers to a file
+    w = Chem.SDWriter(output_path)
+    for i, conf_id in enumerate(conformer_ids):
+        if i >= num_conformers:
+            break
+        w.write(target_mol, confId=conf_id)
+    w.close()
+
+
+def create_embedded_molecule(ref_mol: Chem.Mol, smiles: str):
+    # Convert SMILES to a molecule
+    target_mol = Chem.MolFromSmiles(smiles)
+    assert target_mol is not None, f"Failed to parse molecule from SMILES {smiles}"
+
+    # Set up parameters for conformer generation
+    params = AllChem.ETKDGv3()
+    params.numThreads = 0  # Use all available threads
+    params.pruneRmsThresh = 0.1  # Pruning threshold for RMSD
+
+    # Generate multiple conformers
+    num_conformers = 1000  # Number of conformers to generate
+    conformer_ids = AllChem.EmbedMultipleConfs(target_mol, numConfs=num_conformers, params=params)
+
+    # Optional: Optimize each conformer using the UFF force field
+    # for conf_id in conformer_ids:
+    #     AllChem.UFFOptimizeMolecule(target_mol, confId=conf_id)
+
+    # Align each generated conformer to the reference molecule and keep the closest one
+    rmsd_list = []
+    for conf_id in conformer_ids:
+        rmsd = rdMolAlign.AlignMol(target_mol, ref_mol, prbCid=conf_id)
+        rmsd_list.append(rmsd)
+
+    best_rmsd_index = int(np.argmin(rmsd_list))
+    return target_mol, conformer_ids[best_rmsd_index], rmsd_list[best_rmsd_index]
+
+
+def run_on_sample_seqs(seq_protein: str, template_protein_path: str, smiles: str, output_prot_path: str,
+                       output_lig_path: str, run_config_path: str):
+    temp_dir = tempfile.TemporaryDirectory()
+    temp_dir_path = temp_dir.name
+    metrics = {}
+
+    get_seq_based_on_template(seq_protein, template_protein_path, f"{temp_dir_path}/prot.pdb")
+    create_conformers(smiles, f"{temp_dir_path}/lig.sdf")
+
+    json_data = {
+        "input_structure": "prot.pdb",
+        "ref_sdf": "lig.sdf",
+    }
+    tmp_json_folder = f"{temp_dir_path}/jsons"
+    os.makedirs(tmp_json_folder, exist_ok=True)
+    json.dump(json_data, open(f"{tmp_json_folder}/input.json", "w"))
+    tmp_output_folder = f"{temp_dir_path}/output"
+
+    run_on_folder(tmp_json_folder, tmp_output_folder, run_config_path, skip_relaxation=True,
+                  long_sequence_inference=False, skip_exists=False)
+    predicted_protein_path = tmp_output_folder + "/predictions/input_predicted_protein.pdb"
+    predicted_ligand_path = tmp_output_folder + "/predictions/input_predicted_ligand_0.sdf"
+    predicted_affinity = json.load(open(tmp_output_folder + "/predictions/input_predicted_affinity.json"))
+    metrics = {**metrics, **predicted_affinity}
+
+    try:
+        original_pred_ligand = Chem.MolFromMolFile(predicted_ligand_path, sanitize=False)
+        try:
+            original_pred_ligand = Chem.RemoveHs(original_pred_ligand)
+        except Exception as e:
+            print("Failed to remove hydrogens", e)
+
+        assert original_pred_ligand is not None, f"Failed to parse ligand from {predicted_ligand_path}"
+        rembed_pred_ligand, conf_id, rmsd = create_embedded_molecule(original_pred_ligand, smiles)
+        metrics["ligand_reembed_rmsd"] = rmsd
+        print("re-embedded with RMSD", rmsd)
+
+        # save the selected conformation to predicted_ligand_path
+        w = Chem.SDWriter(predicted_ligand_path)
+        w.write(rembed_pred_ligand, confId=conf_id)
+        w.close()
+    except Exception as e:
+        print("Failed to re-embed the ligand", e)
+
+    os.rename(predicted_protein_path, output_prot_path)
+    os.rename(predicted_ligand_path, output_lig_path)
+    print("moved output to", output_prot_path, output_lig_path)
+
+    temp_dir.cleanup()
+
+    return metrics
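Note: a tiny illustration of the gap-aware alignment that get_seq_based_on_template() relies on; the sequences are made up. pairwise2.align.globalxx scores matches only (no gap or mismatch penalties) and returns alignments whose seqA/seqB strings carry '-' for gaps:

    from Bio import pairwise2

    aln = pairwise2.align.globalxx("MKTAYIAK", "MKTYIAK", one_alignment_only=True)[0]
    print(aln.seqA)  # MKTAYIAK
    print(aln.seqB)  # MKT-YIAK  (one maximal alignment; the template lacks one residue)

Positions where seqB is '-' are exactly the ones that get_seq_based_on_template() fills with CA-only placeholder residues.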
diff --git a/run_pretrained_model.py b/run_pretrained_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e2b426098fbe26051fa61514cdc01223da34e3d
--- /dev/null
+++ b/run_pretrained_model.py
@@ -0,0 +1,197 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+
+from env_consts import TEST_INPUT_DIR, TEST_OUTPUT_DIR, CKPT_PATH
+import json
+import logging
+import numpy as np
+import os
+import pickle
+
+from dockformer.data.data_modules import OpenFoldSingleDataset
+
+logging.basicConfig()
+logger = logging.getLogger(__file__)
+logger.setLevel(level=logging.INFO)
+
+import torch
+torch_versions = torch.__version__.split(".")
+torch_major_version = int(torch_versions[0])
+torch_minor_version = int(torch_versions[1])
+if (
+    torch_major_version > 1 or
+    (torch_major_version == 1 and torch_minor_version >= 12)
+):
+    # Gives a large speedup on Ampere-class GPUs
+    torch.set_float32_matmul_precision("high")
+
+torch.set_grad_enabled(False)
+
+from dockformer.config import model_config
+from dockformer.utils.script_utils import (load_models_from_command_line, run_model, save_output_structure,
+                                           get_latest_checkpoint)
+from dockformer.utils.tensor_utils import tensor_tree_map
+
+
+def list_files_with_extensions(dir, extensions):
+    return [f for f in os.listdir(dir) if f.endswith(extensions)]
+
+
+def override_config(base_config, overriding_config):
+    for k, v in overriding_config.items():
+        if isinstance(v, dict):
+            base_config[k] = override_config(base_config[k], v)
+        else:
+            base_config[k] = v
+    return base_config
+
+
+def run_on_folder(input_dir: str, output_dir: str, run_config_path: str, skip_relaxation=True,
+                  long_sequence_inference=False, skip_exists=False):
+    config_preset = "initial_training"
+    save_outputs = False
+    device_name = "cuda" if torch.cuda.is_available() else "cpu"
+
+    run_config = json.load(open(run_config_path))
+
+    ckpt_path = CKPT_PATH
+    if ckpt_path is None:
+        ckpt_path = get_latest_checkpoint(os.path.join(run_config["train_output_dir"], "checkpoint"))
+    print("Using checkpoint: ", ckpt_path)
+
+    config = model_config(config_preset, long_sequence_inference=long_sequence_inference)
+    config = override_config(config, run_config.get("override_conf", {}))
+
+    model_generator = load_models_from_command_line(
+        config,
+        model_device=device_name,
+        model_checkpoint_path=ckpt_path,
+        output_dir=output_dir)
+    print("Model loaded")
+    model, output_directory = next(model_generator)
+
+    dataset = OpenFoldSingleDataset(data_dir=input_dir, config=config.data, mode="predict")
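+    # Note (illustrative): each sample in data_dir is a JSON in the shape written by
+    # run_on_seq.run_on_sample_seqs, e.g. {"input_structure": "prot.pdb", "ref_sdf": "lig.sdf"}.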
+    for i, processed_feature_dict in enumerate(dataset):
+        tag = dataset.get_metadata_for_idx(i)["input_name"]
+        print("Processing", tag)
+        output_name = f"{tag}_predicted"
+        protein_output_path = os.path.join(output_directory, f'{output_name}_protein.pdb')
+        if os.path.exists(protein_output_path) and skip_exists:
+            print("skipping existing output", output_name)
+            continue
+
+        # turn into a batch of size 1
+        processed_feature_dict = {key: value.unsqueeze(0).to(device_name)
+                                  for key, value in processed_feature_dict.items()}
+
+        out = run_model(model, processed_feature_dict, tag, output_dir)
+
+        # Toss out the recycling dimensions --- we don't need them anymore
+        processed_feature_dict = tensor_tree_map(
+            lambda x: np.array(x[..., -1].cpu()),
+            processed_feature_dict
+        )
+        out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
+
+        affinity_output_path = os.path.join(output_directory, f'{output_name}_affinity.json')
+        # expected value of each affinity head: softmax over 32 bins spanning [0, 15]
+        affinity_2d = torch.sum(torch.softmax(torch.tensor(out["affinity_2d_logits"]), -1) * torch.linspace(0, 15, 32),
+                                dim=-1).item()
+        affinity_1d = torch.sum(torch.softmax(torch.tensor(out["affinity_1d_logits"]), -1) * torch.linspace(0, 15, 32),
+                                dim=-1).item()
+        affinity_cls = torch.sum(torch.softmax(torch.tensor(out["affinity_cls_logits"]), -1) * torch.linspace(0, 15, 32),
+                                 dim=-1).item()
+
+        # argmax (mode) of each head, for comparison with the expected value
+        affinity_2d_max = torch.linspace(0, 15, 32)[torch.argmax(torch.tensor(out["affinity_2d_logits"]))].item()
+        affinity_1d_max = torch.linspace(0, 15, 32)[torch.argmax(torch.tensor(out["affinity_1d_logits"]))].item()
+        affinity_cls_max = torch.linspace(0, 15, 32)[torch.argmax(torch.tensor(out["affinity_cls_logits"]))].item()
+
+        print("Affinity:", affinity_2d, affinity_cls, affinity_1d)
+        with open(affinity_output_path, "w") as f:
+            json.dump({"affinity_2d": affinity_2d, "affinity_1d": affinity_1d, "affinity_cls": affinity_cls,
+                       "affinity_2d_max": affinity_2d_max, "affinity_1d_max": affinity_1d_max,
+                       "affinity_cls_max": affinity_cls_max}, f)
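+        # Worked example (illustrative): with 32 bins over [0, 15] the bin spacing is
+        # 15/31 (about 0.484), and uniform logits give an expected affinity of exactly 7.5:
+        #   torch.sum(torch.softmax(torch.zeros(32), -1) * torch.linspace(0, 15, 32))  # -> 7.5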
+ logger.info(f"Running relaxation on {protein_output_path}...") + from dockformer.utils.relax import relax_complex + try: + relax_complex(protein_output_path, + ligand_output_path, + os.path.join(output_directory, f'{output_name}_protein_relaxed.pdb'), + os.path.join(output_directory, f'{output_name}_ligand_relaxed.sdf')) + except Exception as e: + logger.error(f"Failed to relax {protein_output_path} due to {e}...") + + if save_outputs: + output_dict_path = os.path.join( + output_directory, f'{output_name}_output_dict.pkl' + ) + with open(output_dict_path, "wb") as fp: + pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL) + + logger.info(f"Model output written to {output_dict_path}...") + + +if __name__ == "__main__": + config_path = sys.argv[1] if len(sys.argv) > 1 else os.path.join(os.path.dirname(__file__), "run_config.json") + input_dir, output_dir = TEST_INPUT_DIR, TEST_OUTPUT_DIR + options = {"skip_relaxation": True, "long_sequence_inference": False} + if len(sys.argv) > 3: + input_dir = sys.argv[2] + output_dir = sys.argv[3] + if "--relax" in sys.argv: + options["skip_relaxation"] = False + if "--long" in sys.argv: + options["long_sequence_inference"] = True + if "--allow-skip" in sys.argv: + options["skip_exists"] = True + + run_on_folder(input_dir, output_dir, config_path, **options) diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..919a04743336c353bd9a1a23c9ef1a56cd70357a --- /dev/null +++ b/train.py @@ -0,0 +1,494 @@ +import json +import sys +from typing import Optional + +# This import must be on top to set the environment variables before importing other modules +import env_consts +import time +import os + +from lightning.pytorch import seed_everything +import lightning.pytorch as pl +import torch +from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint +from lightning.pytorch.loggers import WandbLogger +from lightning.pytorch.profilers import AdvancedProfiler + +from dockformer.config import model_config +from dockformer.data.data_modules import OpenFoldDataModule, DockFormerDataModule +from dockformer.model.model import AlphaFold +from dockformer.utils import residue_constants +from dockformer.utils.exponential_moving_average import ExponentialMovingAverage +from dockformer.utils.loss import AlphaFoldLoss, lddt_ca +from dockformer.utils.lr_schedulers import AlphaFoldLRScheduler +from dockformer.utils.script_utils import get_latest_checkpoint +from dockformer.utils.superimposition import superimpose +from dockformer.utils.tensor_utils import tensor_tree_map +from dockformer.utils.validation_metrics import ( + drmsd, + gdt_ts, + gdt_ha, + rmsd, +) + + +class ModelWrapper(pl.LightningModule): + def __init__(self, config): + super(ModelWrapper, self).__init__() + self.config = config + self.model = AlphaFold(config) + + self.loss = AlphaFoldLoss(config.loss) + + self.ema = ExponentialMovingAverage( + model=self.model, decay=config.ema.decay + ) + + self.cached_weights = None + self.last_lr_step = -1 + + self.aggregated_metrics = {} + self.log_agg_every_n_steps = 50 # match Trainer(log_every_n_steps=50) + + def forward(self, batch): + return self.model(batch) + + def _log(self, loss_breakdown, batch, outputs, train=True): + phase = "train" if train else "val" + for loss_name, indiv_loss in loss_breakdown.items(): + # print("logging loss", loss_name, indiv_loss, flush=True) + self.log( + f"{phase}/{loss_name}", + indiv_loss, + on_step=train, on_epoch=(not train), logger=True, sync_dist=True + ) + + if 
diff --git a/train.py b/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..919a04743336c353bd9a1a23c9ef1a56cd70357a
--- /dev/null
+++ b/train.py
@@ -0,0 +1,494 @@
+import json
+import sys
+from typing import Optional
+
+# This import must be on top to set the environment variables before importing other modules
+import env_consts
+import time
+import os
+
+from lightning.pytorch import seed_everything
+import lightning.pytorch as pl
+import torch
+from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
+from lightning.pytorch.loggers import WandbLogger
+from lightning.pytorch.profilers import AdvancedProfiler
+
+from dockformer.config import model_config
+from dockformer.data.data_modules import OpenFoldDataModule, DockFormerDataModule
+from dockformer.model.model import AlphaFold
+from dockformer.utils import residue_constants
+from dockformer.utils.exponential_moving_average import ExponentialMovingAverage
+from dockformer.utils.loss import AlphaFoldLoss, lddt_ca
+from dockformer.utils.lr_schedulers import AlphaFoldLRScheduler
+from dockformer.utils.script_utils import get_latest_checkpoint
+from dockformer.utils.superimposition import superimpose
+from dockformer.utils.tensor_utils import tensor_tree_map
+from dockformer.utils.validation_metrics import (
+    drmsd,
+    gdt_ts,
+    gdt_ha,
+    rmsd,
+)
+
+
+class ModelWrapper(pl.LightningModule):
+    def __init__(self, config):
+        super(ModelWrapper, self).__init__()
+        self.config = config
+        self.model = AlphaFold(config)
+
+        self.loss = AlphaFoldLoss(config.loss)
+
+        self.ema = ExponentialMovingAverage(
+            model=self.model, decay=config.ema.decay
+        )
+
+        self.cached_weights = None
+        self.last_lr_step = -1
+
+        self.aggregated_metrics = {}
+        self.log_agg_every_n_steps = 50  # match Trainer(log_every_n_steps=50)
+
+    def forward(self, batch):
+        return self.model(batch)
+
+    def _log(self, loss_breakdown, batch, outputs, train=True):
+        phase = "train" if train else "val"
+        for loss_name, indiv_loss in loss_breakdown.items():
+            self.log(
+                f"{phase}/{loss_name}",
+                indiv_loss,
+                on_step=train, on_epoch=(not train), logger=True, sync_dist=True
+            )
+
+            if train:
+                agg_name = f"{phase}/{loss_name}_agg"
+                if agg_name not in self.aggregated_metrics:
+                    self.aggregated_metrics[agg_name] = []
+                self.aggregated_metrics[agg_name].append(float(indiv_loss))
+                self.log(
+                    f"{phase}/{loss_name}_epoch",
+                    indiv_loss,
+                    on_step=False, on_epoch=True, logger=True, sync_dist=True
+                )
+
+        with torch.no_grad():
+            other_metrics = self._compute_validation_metrics(
+                batch,
+                outputs,
+                superimposition_metrics=(not train)
+            )
+
+        for k, v in other_metrics.items():
+            if train:
+                agg_name = f"{phase}/{k}_agg"
+                if agg_name not in self.aggregated_metrics:
+                    self.aggregated_metrics[agg_name] = []
+                self.aggregated_metrics[agg_name].append(float(torch.mean(v)))
+            self.log(
+                f"{phase}/{k}",
+                torch.mean(v),
+                on_step=False, on_epoch=True, logger=True, sync_dist=True
+            )
+
+        if train and any([len(v) >= self.log_agg_every_n_steps for v in self.aggregated_metrics.values()]):
+            for k, v in self.aggregated_metrics.items():
+                print("logging agg", k, len(v), sum(v) / len(v), flush=True)
+                self.log(k, sum(v) / len(v), on_step=True, on_epoch=False, logger=True, sync_dist=True)
+                self.aggregated_metrics[k] = []
+
+    def training_step(self, batch, batch_idx):
+        if self.ema.device != batch["aatype"].device:
+            self.ema.to(batch["aatype"].device)
+
+        # Run the model
+        outputs = self(batch)
+
+        # Remove the recycling dimension
+        batch = tensor_tree_map(lambda t: t[..., -1], batch)
+
+        # Compute loss
+        loss, loss_breakdown = self.loss(
+            outputs, batch, _return_breakdown=True
+        )
+
+        # Log it
+        self._log(loss_breakdown, batch, outputs)
+
+        return loss
+
+    def on_before_zero_grad(self, *args, **kwargs):
+        # keep the EMA copy of the weights in sync after every optimizer step
+        self.ema.update(self.model)
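+    # Note (illustrative): ExponentialMovingAverage keeps a shadow copy of the weights,
+    # conventionally updated as p_ema <- decay * p_ema + (1 - decay) * p on every step.
+    # validation_step below swaps the EMA weights in (caching the raw weights first), and
+    # on_validation_epoch_end restores the cached weights.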
+    def validation_step(self, batch, batch_idx):
+        # At the start of validation, load the EMA weights
+        if self.cached_weights is None:
+            # model.state_dict() contains references to model weights rather
+            # than copies. Therefore, we need to clone them before calling
+            # load_state_dict().
+            clone_param = lambda t: t.detach().clone()
+            self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
+            self.model.load_state_dict(self.ema.state_dict()["params"])
+
+        # Run the model
+        outputs = self(batch)
+        batch = tensor_tree_map(lambda t: t[..., -1], batch)
+
+        batch["use_clamped_fape"] = 0.
+
+        # Compute loss and other metrics
+        _, loss_breakdown = self.loss(
+            outputs, batch, _return_breakdown=True
+        )
+
+        self._log(loss_breakdown, batch, outputs, train=False)
+
+    def on_validation_epoch_end(self):
+        # Restore the model weights to normal
+        self.model.load_state_dict(self.cached_weights)
+        self.cached_weights = None
+
+    def _compute_validation_metrics(self,
+                                    batch,
+                                    outputs,
+                                    superimposition_metrics=False
+                                    ):
+        metrics = {}
+
+        all_gt_coords = batch["atom37_gt_positions"]
+        all_pred_coords = outputs["final_atom_positions"]
+        all_atom_mask = batch["atom37_atom_exists_in_gt"]
+
+        # expand the per-token masks to per-atom masks (37 atom slots per token)
+        rough_protein_atom_mask = torch.repeat_interleave(batch["protein_mask"], 37, dim=-1).view(*all_atom_mask.shape)
+        protein_gt_coords = all_gt_coords * rough_protein_atom_mask[..., None]
+        protein_pred_coords = all_pred_coords * rough_protein_atom_mask[..., None]
+        protein_all_atom_mask = all_atom_mask * rough_protein_atom_mask
+
+        rough_ligand_atom_mask = torch.repeat_interleave(batch["ligand_mask"], 37, dim=-1).view(*all_atom_mask.shape)
+        ligand_gt_coords = all_gt_coords * rough_ligand_atom_mask[..., None]
+        ligand_pred_coords = all_pred_coords * rough_ligand_atom_mask[..., None]
+        ligand_all_atom_mask = all_atom_mask * rough_ligand_atom_mask
+
+        # This is super janky for superimposition. Fix later
+        protein_gt_coords_masked = protein_gt_coords * protein_all_atom_mask[..., None]
+        protein_pred_coords_masked = protein_pred_coords * protein_all_atom_mask[..., None]
+        ca_pos = residue_constants.atom_order["CA"]
+        protein_gt_coords_masked_ca = protein_gt_coords_masked[..., ca_pos, :]
+        protein_pred_coords_masked_ca = protein_pred_coords_masked[..., ca_pos, :]
+        protein_atom_mask_ca = protein_all_atom_mask[..., ca_pos]
+
+        ligand_gt_coords_single_atom = ligand_gt_coords[..., ca_pos, :]
+        ligand_pred_coords_single_atom = ligand_pred_coords[..., ca_pos, :]
+        ligand_gt_mask_single_atom = ligand_all_atom_mask[..., ca_pos]
+
+        lddt_ca_score = lddt_ca(
+            protein_pred_coords,
+            protein_gt_coords,
+            protein_all_atom_mask,
+            eps=self.config.globals.eps,
+            per_residue=False,
+        )
+
+        metrics["lddt_ca"] = lddt_ca_score
+
+        drmsd_ca_score = drmsd(
+            protein_pred_coords_masked_ca,
+            protein_gt_coords_masked_ca,
+            mask=protein_atom_mask_ca,  # still required here to compute n
+        )
+
+        metrics["drmsd_ca"] = drmsd_ca_score
+
+        drmsd_intra_ligand_score = drmsd(
+            ligand_pred_coords_single_atom,
+            ligand_gt_coords_single_atom,
+            mask=ligand_gt_mask_single_atom,
+        )
+
+        metrics["drmsd_intra_ligand"] = drmsd_intra_ligand_score
+
+        # --- inter contacts
+        gt_contacts = batch["gt_inter_contacts"]
+        pred_contacts = torch.sigmoid(outputs["inter_contact_logits"].clone().detach()).squeeze(-1)
+        pred_contacts = (pred_contacts > 0.5).float()
+        pred_contacts = pred_contacts * batch["inter_pair_mask"]
+
+        # Calculate True Positives, False Positives, and False Negatives
+        tp = torch.sum((gt_contacts == 1) & (pred_contacts == 1))
+        fp = torch.sum((gt_contacts == 0) & (pred_contacts == 1))
+        fn = torch.sum((gt_contacts == 1) & (pred_contacts == 0))
+
+        # Calculate Recall and Precision
+        recall = tp / (tp + fn) if (tp + fn) > 0 else tp.float()
+        precision = tp / (tp + fp) if (tp + fp) > 0 else tp.float()
+
+        metrics["inter_contacts_recall"] = recall.clone().detach()
+        metrics["inter_contacts_precision"] = precision.clone().detach()
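+        # Worked example (illustrative): with 3 true contacts, 2 of them predicted,
+        # plus 1 spurious prediction -> tp=2, fp=1, fn=1, so recall=2/3 and precision=2/3.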
print("affinity loss factor", batch["affinity_loss_factor"].sum()) + gt_affinity = batch["affinity"].squeeze(-1) + affinity_linspace = torch.linspace(0, 15, 32, device=batch["affinity"].device) + pred_affinity_1d = torch.sum( + torch.softmax(outputs["affinity_1d_logits"].clone().detach(), -1) * affinity_linspace, dim=-1) + + pred_affinity_2d = torch.sum( + torch.softmax(outputs["affinity_2d_logits"].clone().detach(), -1) * affinity_linspace, dim=-1) + + pred_affinity_cls = torch.sum( + torch.softmax(outputs["affinity_cls_logits"].clone().detach(), -1) * affinity_linspace, dim=-1) + + aff_loss_factor = batch["affinity_loss_factor"].squeeze() + + metrics["affinity_dist_1d"] = (torch.abs(gt_affinity - pred_affinity_1d) * aff_loss_factor).sum() / aff_loss_factor.sum() + metrics["affinity_dist_2d"] = (torch.abs(gt_affinity - pred_affinity_2d) * aff_loss_factor).sum() / aff_loss_factor.sum() + metrics["affinity_dist_cls"] = (torch.abs(gt_affinity - pred_affinity_cls) * aff_loss_factor).sum() / aff_loss_factor.sum() + metrics["affinity_dist_avg"] = (torch.abs(gt_affinity - (pred_affinity_cls + pred_affinity_1d + pred_affinity_2d) / 3) * aff_loss_factor).sum() / aff_loss_factor.sum() + # print("affinity metrics", gt_affinity, pred_affinity_2d, aff_loss_factor, metrics["affinity_dist_1d"], + # metrics["affinity_dist_2d"], metrics["affinity_dist_cls"], metrics["affinity_dist_avg"]) + else: + # print("skipping affinity metrics") + pass + if superimposition_metrics: + superimposed_pred, alignment_rmsd, rots, transs = superimpose( + protein_gt_coords_masked_ca, protein_pred_coords_masked_ca, protein_atom_mask_ca, + ) + gdt_ts_score = gdt_ts( + superimposed_pred, protein_gt_coords_masked_ca, protein_atom_mask_ca + ) + gdt_ha_score = gdt_ha( + superimposed_pred, protein_gt_coords_masked_ca, protein_atom_mask_ca + ) + + metrics["protein_alignment_rmsd"] = alignment_rmsd + metrics["gdt_ts"] = gdt_ts_score + metrics["gdt_ha"] = gdt_ha_score + + superimposed_ligand_coords = ligand_pred_coords_single_atom @ rots + transs[:, None, :] + ligand_alignment_rmsds = rmsd(ligand_gt_coords_single_atom, superimposed_ligand_coords, + mask=ligand_gt_mask_single_atom) + metrics["ligand_alignment_rmsd"] = ligand_alignment_rmsds.mean() + metrics["ligand_alignment_rmsd_under_2"] = torch.mean((ligand_alignment_rmsds < 2).float()) + metrics["ligand_alignment_rmsd_under_5"] = torch.mean((ligand_alignment_rmsds < 5).float()) + + print("ligand rmsd:", ligand_alignment_rmsds) + + return metrics + + def configure_optimizers(self, + learning_rate: Optional[float] = None, + eps: float = 1e-5, + ) -> torch.optim.Adam: + if learning_rate is None: + learning_rate = self.config.globals.max_lr + + optimizer = torch.optim.Adam( + self.model.parameters(), + lr=learning_rate, + eps=eps + ) + + if self.last_lr_step != -1: + for group in optimizer.param_groups: + if 'initial_lr' not in group: + group['initial_lr'] = learning_rate + + lr_scheduler = AlphaFoldLRScheduler( + optimizer, + last_epoch=self.last_lr_step, + max_lr=self.config.globals.max_lr, + start_decay_after_n_steps=10000, + decay_every_n_steps=10000, + ) + + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": lr_scheduler, + "interval": "step", + "name": "AlphaFoldLRScheduler", + } + } + + def on_load_checkpoint(self, checkpoint): + ema = checkpoint["ema"] + self.ema.load_state_dict(ema) + + def on_save_checkpoint(self, checkpoint): + checkpoint["ema"] = self.ema.state_dict() + + def resume_last_lr_step(self, lr_step): + self.last_lr_step = lr_step + + +def 
+def override_config(base_config, overriding_config):
+    for k, v in overriding_config.items():
+        if isinstance(v, dict):
+            base_config[k] = override_config(base_config[k], v)
+        else:
+            base_config[k] = v
+    return base_config
+
+
+def train(override_config_path: str):
+    run_config = json.load(open(override_config_path, "r"))
+    seed = 42
+    seed_everything(seed, workers=True)
+    output_dir = run_config["train_output_dir"]
+    os.makedirs(output_dir, exist_ok=True)
+
+    print("Starting train", time.time())
+    config = model_config(
+        run_config.get("stage", "initial_training"),
+        train=True,
+        low_prec=True
+    )
+    config = override_config(config, run_config.get("override_conf", {}))
+    accumulate_grad_batches = run_config.get("accumulate_grad_batches", 1)
+    print("config loaded", time.time())
+
+    device_name = "cuda" if torch.cuda.is_available() else "cpu"
+    # device_name = "mps" if device_name == "cpu" and torch.backends.mps.is_available() else device_name
+    model_module = ModelWrapper(config)
+    print("model loaded", time.time())
+
+    # for debugging memory:
+    # torch.cuda.memory._record_memory_history()
+
+    # OpenFoldDataModule reads a folder of samples; DockFormerDataModule reads one file per split
+    if "train_input_dir" in run_config:
+        data_module = OpenFoldDataModule(
+            config=config.data,
+            batch_seed=seed,
+            train_data_dir=run_config["train_input_dir"],
+            val_data_dir=run_config["val_input_dir"],
+            train_epoch_len=run_config.get("train_epoch_len", 1000),
+        )
+    else:
+        data_module = DockFormerDataModule(
+            config=config.data,
+            batch_seed=seed,
+            train_data_file=run_config["train_input_file"],
+            val_data_file=run_config["val_input_file"],
+        )
+    print("data module loaded", time.time())
+
+    checkpoint_dir = os.path.join(output_dir, "checkpoint")
+    ckpt_path = run_config.get("ckpt_path", get_latest_checkpoint(checkpoint_dir))
+
+    if ckpt_path:
+        print(f"Resuming from checkpoint: {ckpt_path}")
+        sd = torch.load(ckpt_path)
+        last_global_step = int(sd['global_step'])
+        model_module.resume_last_lr_step(last_global_step)
+    # Do we need this?
+    data_module.prepare_data()
+    data_module.setup("fit")
+
+    callbacks = []
+
+    mc = ModelCheckpoint(
+        dirpath=checkpoint_dir,
+        # every_n_epochs=1,
+        every_n_train_steps=250,
+        auto_insert_metric_name=False,
+        save_top_k=1,
+        save_on_train_epoch_end=True,  # before validation
+    )
+
+    mc2 = ModelCheckpoint(
+        dirpath=checkpoint_dir,  # Directory to save checkpoints
+        filename="step{step}_lig_rmsd{val/ligand_alignment_rmsd:.2f}",  # Filename format for the best model
+        monitor="val/ligand_alignment_rmsd",  # Metric to monitor
+        mode="min",  # We want the lowest ligand alignment RMSD
+        save_top_k=1,  # Save only the best model based on the monitored metric
+        every_n_epochs=1,  # Save a checkpoint every epoch
+        auto_insert_metric_name=False
+    )
+    callbacks.append(mc)
+    callbacks.append(mc2)
+
+    lr_monitor = LearningRateMonitor(logging_interval="step")
+    callbacks.append(lr_monitor)
+
+    loggers = []
+
+    wandb_project_name = "EvoDocker3"
+    wandb_run_id_path = os.path.join(output_dir, "wandb_run_id.txt")
+
+    # Initialize WandbLogger on the main process and share the run_id with the other ranks
+    local_rank = int(os.getenv('LOCAL_RANK', os.getenv("SLURM_PROCID", '0')))
+    global_rank = int(os.getenv('GLOBAL_RANK', os.getenv("SLURM_LOCALID", '0')))
+    print("ranks", os.getenv('LOCAL_RANK', 'd0'), os.getenv('local_rank', 'd0'), os.getenv('GLOBAL_RANK', 'd0'),
+          os.getenv('global_rank', 'd0'), os.getenv("SLURM_PROCID", 'd0'), os.getenv('SLURM_LOCALID', 'd0'), flush=True)
+    if local_rank == 0 and global_rank == 0 and not os.path.exists(wandb_run_id_path):
+        wandb_logger = WandbLogger(project=wandb_project_name, save_dir=output_dir)
+        with open(wandb_run_id_path, 'w') as f:
+            f.write(wandb_logger.experiment.id)
+        wandb_logger.experiment.config.update(run_config, allow_val_change=True)
+    else:
+        # Necessary for multi-node training https://github.com/rstrudel/segmenter/issues/22
+        while not os.path.exists(wandb_run_id_path):
+            print(f"Waiting for run_id file to be created ({local_rank})", flush=True)
+            time.sleep(1)
+        with open(wandb_run_id_path, 'r') as f:
+            run_id = f.read().strip()
+        wandb_logger = WandbLogger(project=wandb_project_name, save_dir=output_dir, resume='must', id=run_id)
+    loggers.append(wandb_logger)
+
+    strategy_params = {"strategy": "auto"}
+    if run_config.get("multi_node", False):
+        strategy_params["strategy"] = "ddp"
+        # strategy_params["strategy"] = "ddp_find_unused_parameters_true"  # this causes issues with checkpointing...
+        strategy_params["num_nodes"] = run_config["multi_node"]["num_nodes"]
+        strategy_params["devices"] = run_config["multi_node"]["devices"]
+
+    trainer = pl.Trainer(
+        accelerator=device_name,
+        default_root_dir=output_dir,
+        **strategy_params,
+        reload_dataloaders_every_n_epochs=1,
+        accumulate_grad_batches=accumulate_grad_batches,
+        check_val_every_n_epoch=run_config.get("check_val_every_n_epoch", 10),
+        callbacks=callbacks,
+        logger=loggers,
+        # profiler=AdvancedProfiler(),
+    )
+
+    print("Starting fit", time.time())
+    trainer.fit(
+        model_module,
+        datamodule=data_module,
+        ckpt_path=ckpt_path,
+    )
+
+    # profiler_results = trainer.profiler.summary()
+    # print(profiler_results)
+
+    # torch.cuda.memory._dump_snapshot("my_train_snapshot.pickle")
+    # view on https://pytorch.org/memory_viz
+
+
+if __name__ == "__main__":
+    if len(sys.argv) > 1:
+        train(sys.argv[1])
+    else:
+        train(os.path.join(os.path.dirname(__file__), "run_config.json"))
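Note: a minimal training run config for train.py above, assuming the DockFormerDataModule path (one file per split); the paths and values are placeholders, and anything under override_conf is merged recursively into the base config exactly as in resources/run_config.json:

    {
      "stage": "initial_training",
      "train_output_dir": "runs/exp1",
      "train_input_file": "data/train.json",
      "val_input_file": "data/val.json",
      "accumulate_grad_batches": 1,
      "check_val_every_n_epoch": 10,
      "override_conf": {"globals": {"max_lr": 0.0001}}
    }

Launch with:

    python train.py path/to/this_config.json

Using "train_input_dir"/"val_input_dir" instead of the *_file keys selects OpenFoldDataModule; omitting the config path falls back to run_config.json next to train.py.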