""" | |
https://github.com/ProteinDesignLab/protpardelle | |
License: MIT | |
Author: Alex Chu | |
Dataloader from PDB files. | |
""" | |
import copy
import json
import pickle

import numpy as np
import torch
from einops import rearrange, repeat
from torch.utils import data

from core import protein
from core import residue_constants
from core import utils
FEATURES_1D = (
    "coords_in",
    "torsions_in",
    "b_factors",
    "atom_positions",
    "aatype",
    "atom_mask",
    "residue_index",
    "chain_index",
)
FEATURES_FLOAT = (
    "coords_in",
    "torsions_in",
    "b_factors",
    "atom_positions",
    "atom_mask",
    "seq_mask",
)
FEATURES_LONG = ("aatype", "residue_index", "chain_index", "orig_size")
def make_fixed_size_1d(data, fixed_size=128, start_idx=None):
    """Crop or zero-pad data along its first dimension to fixed_size.

    If cropping, start_idx selects the crop window; pass the same start_idx for
    every feature of one example so that the cropped features stay aligned.
    Also returns a mask that is 1 for real positions and 0 for padding.
    """
    data_len = data.shape[0]
    if data_len >= fixed_size:
        # Random (or caller-specified) crop to fixed_size
        extra_len = data_len - fixed_size
        if start_idx is None:
            start_idx = np.random.randint(extra_len + 1)
        new_data = data[start_idx : (start_idx + fixed_size)]
        mask = torch.ones(fixed_size)
    else:
        # Zero-pad up to fixed_size
        pad_size = fixed_size - data_len
        extra_shape = data.shape[1:]
        new_data = torch.cat([data, torch.zeros(pad_size, *extra_shape)], 0)
        mask = torch.cat([torch.ones(data_len), torch.zeros(pad_size)], 0)
    return new_data, mask
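

# Illustrative sanity check (not part of the original module): cropping keeps an
# all-ones mask, while padding zero-masks the tail.
# x = torch.arange(10).float().unsqueeze(-1)          # (10, 1) toy feature
# short, m1 = make_fixed_size_1d(x, fixed_size=4)     # short: (4, 1), m1 all ones
# padded, m2 = make_fixed_size_1d(x, fixed_size=16)   # padded: (16, 1), m2[10:] zero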
def apply_random_se3(coords_in, atom_mask=None, translation_scale=1.0):
    # Unbatched. Center on the mean of the CA coords (atom37 index 1), then apply
    # a random orthogonal transform and a random translation. Note that QR of a
    # Gaussian matrix yields an orthogonal matrix whose determinant may be -1,
    # so this can include a reflection as well as a rotation.
    coords_mean = coords_in[:, 1:2].mean(-3, keepdim=True)
    coords_in -= coords_mean
    random_rot, _ = torch.linalg.qr(torch.randn(3, 3))
    coords_in = coords_in @ random_rot
    random_trans = torch.randn_like(coords_mean) * translation_scale
    coords_in += random_trans
    if atom_mask is not None:
        coords_in = coords_in * atom_mask[..., None]
    return coords_in
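

# If a proper rotation (det = +1) is required, a minimal sketch is to flip one
# column of the QR factor when its determinant is negative. This helper is
# hypothetical and not part of the original module.
def _random_proper_rotation():
    q, _ = torch.linalg.qr(torch.randn(3, 3))
    if torch.linalg.det(q) < 0:
        # Negating one column of an orthogonal matrix flips det from -1 to +1
        q[:, 0] = -q[:, 0]
    return q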
def get_masked_coords_array(coords, atom_mask):
    # Wrap coords in a numpy masked array so that atoms with atom_mask == 0 are
    # ignored by downstream reductions (e.g. mean)
    ma_mask = repeat(1 - atom_mask[..., None].cpu().numpy(), "... 1 -> ... 3")
    return np.ma.array(coords.cpu().numpy(), mask=ma_mask)
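

# Illustrative use of the masked array (assumed shapes, not part of the original
# module): the mean ignores atoms that are absent from the structure.
# arr = get_masked_coords_array(coords, atom_mask)  # coords: (b, n, 37, 3)
# centers = arr.mean((1, 2))                        # (b, 3) means over present atoms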
def make_crop_cond_mask_and_recenter_coords(
    atom_mask,
    atom_coords,
    contiguous_prob=0.05,
    discontiguous_prob=0.9,
    sidechain_only_prob=0.8,
    max_span_len=10,
    max_discontiguous_res=8,
    dist_threshold=8.0,
    recenter_coords=True,
):
    b, n, a = atom_mask.shape
    device = atom_mask.device
    seq_mask = atom_mask[..., 1]
    n_res = seq_mask.sum(-1)
    masks = []
masks = [] | |
for i, nr in enumerate(n_res): | |
nr = nr.int().item() | |
mask = torch.zeros((n, a), device=device) | |
conditioning_type = torch.distributions.Categorical( | |
torch.tensor( | |
[ | |
contiguous_prob, | |
discontiguous_prob, | |
1.0 - contiguous_prob - discontiguous_prob, | |
] | |
) | |
).sample() | |
conditioning_type = ["contiguous", "discontiguous", "none"][conditioning_type] | |
if conditioning_type == "contiguous": | |
span_len = torch.randint( | |
1, min(max_span_len, nr), (1,), device=device | |
).item() | |
span_start = torch.randint(0, nr - span_len, (1,), device=device) | |
mask[span_start : span_start + span_len, :] = 1 | |
elif conditioning_type == "discontiguous": | |
# Extract CB atoms coordinates for the i-th example | |
cb_atoms = atom_coords[i, :, 3] | |
# Pairwise distances between CB atoms | |
cb_distances = torch.cdist(cb_atoms, cb_atoms) | |
close_mask = ( | |
cb_distances <= dist_threshold | |
) # Mask for selecting close CB atoms | |
random_residue = torch.randint(0, nr, (1,), device=device).squeeze() | |
cb_dist_i = cb_distances[random_residue] + 1e3 * (1 - seq_mask[i]) | |
close_mask = cb_dist_i <= dist_threshold | |
n_neighbors = close_mask.sum().int() | |
# pick how many neighbors (up to 10) | |
n_sele = torch.randint( | |
2, | |
n_neighbors.clamp(min=3, max=max_discontiguous_res + 1), | |
(1,), | |
device=device, | |
) | |
# Select the indices of CB atoms that are close together | |
idxs = torch.arange(n, device=device)[close_mask.bool()] | |
idxs = idxs[torch.randperm(len(idxs))[:n_sele]] | |
if len(idxs) > 0: | |
mask[idxs] = 1 | |
if np.random.uniform() < sidechain_only_prob: | |
mask[:, :5] = 0 | |
masks.append(mask) | |
    crop_cond_mask = torch.stack(masks)
    crop_cond_mask = crop_cond_mask * atom_mask
    if recenter_coords:
        # Recenter coordinates on the center of mass of the conditioning motif
        motif_masked_array = get_masked_coords_array(atom_coords, crop_cond_mask)
        cond_coords_center = motif_masked_array.mean((1, 2))
        motif_mask = torch.Tensor(1 - cond_coords_center.mask).to(crop_cond_mask)
        means = torch.Tensor(cond_coords_center.data).to(atom_coords) * motif_mask
        coords_out = atom_coords - rearrange(means, "b c -> b 1 1 c")
    else:
        coords_out = atom_coords
    return coords_out, crop_cond_mask
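

# Illustrative call (assumed shapes, not part of the original module):
# coords: (b, n_res, 37, 3) atom37 coordinates, mask: (b, n_res, 37)
# coords = torch.randn(2, 128, 37, 3)
# mask = torch.ones(2, 128, 37)
# coords_out, cond_mask = make_crop_cond_mask_and_recenter_coords(mask, coords)
# cond_mask marks the atoms to condition on; coords_out is recentered on the
# motif's center of mass whenever a motif was selected.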
class Dataset(data.Dataset):
    """Loads and processes PDBs into tensors."""

    def __init__(
        self,
        pdb_path,
        fixed_size,
        mode="train",
        overfit=-1,
        short_epoch=False,
        se3_data_augment=True,
    ):
        self.pdb_path = pdb_path
        self.fixed_size = fixed_size
        self.mode = mode
        self.overfit = overfit
        self.short_epoch = short_epoch
        self.se3_data_augment = se3_data_augment

        with open(f"{self.pdb_path}/{mode}_pdb_keys.list") as f:
            self.pdb_keys = np.array(f.read().split("\n")[:-1])

        if overfit > 0:
            # Restrict to `overfit` distinct examples, repeated to preserve the
            # original epoch length
            n_data = len(self.pdb_keys)
            self.pdb_keys = np.random.choice(
                self.pdb_keys, min(n_data, overfit), replace=False
            ).repeat(n_data // overfit)

    def __len__(self):
        if self.short_epoch:
            return min(len(self.pdb_keys), 256)
        else:
            return len(self.pdb_keys)
    def __getitem__(self, idx):
        pdb_key = self.pdb_keys[idx]
        data = self.get_item(pdb_key)
        # For now, recover from dataloading errors by retrying with a random pdb,
        # up to 10 times
        for _ in range(10):
            if data is not None:
                return data
            pdb_key = self.pdb_keys[np.random.randint(len(self.pdb_keys))]
            data = self.get_item(pdb_key)
        raise Exception("Failed to load data example after 10 tries.")
    def get_item(self, pdb_key):
        if self.pdb_path.endswith("cath_s40_dataset"):  # CATH pdbs
            data_file = f"{self.pdb_path}/dompdb/{pdb_key}"
        elif self.pdb_path.endswith("ingraham_cath_dataset"):  # Ingraham splits
            data_file = f"{self.pdb_path}/pdb_store/{pdb_key}"
        else:
            raise Exception("Invalid pdb path.")

        try:
            example = utils.load_feats_from_pdb(data_file)
            coords_in = example["atom_positions"]
        except FileNotFoundError:
            raise Exception(f"File {pdb_key} not found. Check if dataset is corrupted?")
        except RuntimeError:
            return None

        # Apply data augmentation
        if self.se3_data_augment:
            coords_in = apply_random_se3(coords_in, atom_mask=example["atom_mask"])

        orig_size = coords_in.shape[0]
        example["coords_in"] = coords_in
        example["orig_size"] = torch.ones(1) * orig_size

        # Crop/pad all per-residue features to fixed_size, drawing one shared
        # crop start so the features stay aligned with each other
        crop_start = np.random.randint(max(orig_size - self.fixed_size, 0) + 1)
        fixed_size_example = {}
        seq_mask = None
        for k, v in example.items():
            if k in FEATURES_1D:
                fixed_size_example[k], seq_mask = make_fixed_size_1d(
                    v, fixed_size=self.fixed_size, start_idx=crop_start
                )
            else:
                fixed_size_example[k] = v
        if seq_mask is not None:
            fixed_size_example["seq_mask"] = seq_mask

        # Cast features to their output dtypes
        example_out = {}
        for k, v in fixed_size_example.items():
            if k in FEATURES_FLOAT:
                example_out[k] = v.float()
            elif k in FEATURES_LONG:
                example_out[k] = v.long()
        return example_out
    def collate(self, example_list):
        out = {}
        for ex in example_list:
            for k, v in ex.items():
                out.setdefault(k, []).append(v)
        return {k: torch.stack(v) for k, v in out.items()}
    def sample(self, n=1, return_data=True, return_keys=False):
        keys = self.pdb_keys[torch.randperm(self.__len__())[:n].long()]
        if return_keys and not return_data:
            return keys
        if n == 1:
            # keys is a length-1 array; unwrap it before loading
            data = self.collate([self.get_item(keys[0])])
        else:
            data = self.collate([self.get_item(key) for key in keys])
        if return_data and return_keys:
            return data, keys
        if return_data and not return_keys:
            return data
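

# Hedged usage sketch (assumed paths and sizes, not part of the original module):
# dataset = Dataset("data/cath_s40_dataset", fixed_size=128, mode="train")
# loader = data.DataLoader(
#     dataset, batch_size=8, shuffle=True, collate_fn=dataset.collate
# )
# batch = next(iter(loader))
# batch["coords_in"] has shape (8, 128, 37, 3); batch["seq_mask"] is (8, 128).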