|
|
|
|
import os |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.nn import Parameter, Sequential, ModuleList, Linear |
|
|
|
from rdkit import Chem |
|
from rdkit.Chem import AllChem |
|
|
|
from transformers import PretrainedConfig |
|
from transformers import PreTrainedModel |
|
from transformers import AutoModel |
|
|
|
from torch_geometric.data import Data |
|
from torch_geometric.loader import DataLoader |
|
from torch_geometric.utils import remove_self_loops, add_self_loops, sort_edge_index |
|
from torch_scatter import scatter |
|
from torch_geometric.nn import global_add_pool, radius |
|
from torch_sparse import SparseTensor |
|
|
|
from transmxm_model.configuration_transmxm import TransmxmConfig |
|
|
|
from tqdm import tqdm |
|
import numpy as np |
|
import pandas as pd |
|
from typing import List |
|
import math |
|
import inspect |
|
from operator import itemgetter |
|
from collections import OrderedDict |
|
from math import sqrt, pi as PI |
|
from scipy.optimize import brentq |
|
from scipy import special as sp |
|
|
|
try: |
|
import sympy as sym |
|
except ImportError: |
|
sym = None |
|
|
|
|
|
|
|
class SmilesDataset(torch.utils.data.Dataset): |
|
def __init__(self, smiles): |
|
self.smiles_list = smiles |
|
self.data_list = [] |
|
|
|
|
|
def __len__(self): |
|
return len(self.data_list) |
|
|
|
def __getitem__(self, idx): |
|
return self.data_list[idx] |
|
|
|
    def get_data(self, smiles):
        self.smiles_list = smiles
        self.data_list = []  # reset so repeated calls do not accumulate duplicate graphs
|
|
|
|
|
        types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'S': 4}  # supported elements; any other atom raises KeyError below
|
|
|
for i in range(len(self.smiles_list)): |
|
|
|
|
|
mol = Chem.MolFromSmiles(self.smiles_list[i]) |
|
if mol is None: |
|
print("无法创建Mol对象", self.smiles_list[i]) |
|
else: |
|
|
|
                mol3d = Chem.AddHs(mol)
|
if mol3d is None: |
|
print("无法创建mol3d对象", self.smiles_list[i]) |
|
else: |
|
AllChem.EmbedMolecule(mol3d, randomSeed=1) |
|
|
|
N = mol3d.GetNumAtoms() |
|
|
|
if mol3d.GetNumConformers() > 0: |
|
conformer = mol3d.GetConformer() |
|
pos = conformer.GetPositions() |
|
pos = torch.tensor(pos, dtype=torch.float) |
|
|
|
type_idx = [] |
|
|
|
|
|
|
|
|
|
|
|
for atom in mol3d.GetAtoms(): |
|
type_idx.append(types[atom.GetSymbol()]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
row, col, edge_type = [], [], [] |
|
for bond in mol3d.GetBonds(): |
|
start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() |
|
row += [start, end] |
|
col += [end, start] |
|
|
|
|
|
edge_index = torch.tensor([row, col], dtype=torch.long) |
|
|
|
|
|
|
|
perm = (edge_index[0] * N + edge_index[1]).argsort() |
|
edge_index = edge_index[:, perm] |
|
|
|
|
|
|
|
|
|
|
|
|
|
x = torch.tensor(type_idx).to(torch.float) |
|
|
|
|
|
|
|
data = Data(x=x, pos=pos, edge_index=edge_index, smiles=self.smiles_list[i]) |
|
|
|
self.data_list.append(data) |
|
else: |
|
print("无法创建comfor", self.smiles_list[i]) |
|
return self.data_list |
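

# Illustrative usage sketch: convert a couple of SMILES strings into torch_geometric
# Data objects and batch them. The molecules and batch size are arbitrary examples;
# only H/C/N/O/S atoms are supported by the `types` mapping above.
def _demo_smiles_dataset():
    dataset = SmilesDataset(smiles=[])
    dataset.get_data(["CCO", "c1ccccc1O"])  # ethanol and phenol
    loader = DataLoader(dataset, batch_size=2, shuffle=False)
    for batch in loader:
        # batch.x: atom-type indices, batch.pos: RDKit 3D coordinates, batch.edge_index: bonds
        print(batch.x.shape, batch.pos.shape, batch.edge_index.shape)
    return dataset.data_list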
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Additional names used by the Transformer components below
import copy
from typing import Optional

from torch import Tensor
|
|
|
|
|
class PositionEmbeddingSine(nn.Module): |
|
""" |
|
This is a more standard version of the position embedding, very similar to the one |
|
used by the Attention is all you need paper, generalized to work on images. (To 1D sequences) |
|
""" |
|
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): |
|
super().__init__() |
|
self.num_pos_feats = num_pos_feats |
|
self.temperature = temperature |
|
self.normalize = normalize |
|
if scale is not None and normalize is False: |
|
raise ValueError("normalize should be True if scale is passed") |
|
if scale is None: |
|
scale = 2 * math.pi |
|
self.scale = scale |
|
|
|
def forward(self, x, mask): |
|
""" |
|
Args: |
|
x: torch.tensor, (batch_size, L, d) |
|
mask: torch.tensor, (batch_size, L), with 1 as valid |
|
|
|
Returns: |
|
|
|
""" |
|
assert mask is not None |
|
x_embed = mask.cumsum(1, dtype=torch.float32) |
|
if self.normalize: |
|
eps = 1e-6 |
|
x_embed = x_embed / (x_embed[:, -1:] + eps) * self.scale |
|
|
|
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) |
|
|
|
dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='trunc') / self.num_pos_feats) |
|
pos_x = x_embed[:, :, None] / dim_t |
|
pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) |
|
|
|
return pos_x |
|
|
|
def build_position_encoding(x): |
|
N_steps = x |
|
pos_embed = PositionEmbeddingSine(N_steps, normalize=True) |
|
|
|
return pos_embed |
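

# Illustrative shape check: the sine embedding is computed from the cumulative sum of the
# validity mask, and its feature dimension equals num_pos_feats (chosen equal to the model
# width so it can be added to the token features). The sizes below are arbitrary examples.
def _demo_position_encoding():
    pos_embed = build_position_encoding(128)        # num_pos_feats == d_model == 128
    x = torch.zeros(2, 10, 128)                     # (batch_size, L, d); only the device is used
    mask = torch.ones(2, 10)                        # 1 marks a valid position
    pos = pos_embed(x, mask)
    assert pos.shape == (2, 10, 128)
    return pos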
|
|
|
|
|
class Transformer(nn.Module): |
|
|
|
def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, |
|
num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, |
|
activation="relu", normalize_before=False): |
|
super().__init__() |
|
|
|
|
|
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, |
|
dropout, activation, normalize_before) |
|
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None |
|
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) |
|
|
|
self._reset_parameters() |
|
|
|
self.d_model = d_model |
|
self.nhead = nhead |
|
|
|
def _reset_parameters(self): |
|
for p in self.parameters(): |
|
if p.dim() > 1: |
|
nn.init.xavier_uniform_(p) |
|
|
|
def forward(self, src, mask, att_mask, pos_embed): |
|
""" |
|
Args: |
|
src: (batch_size, L, d) |
|
mask: (batch_size, L) |
|
query_embed: (#queries, d) |
|
pos_embed: (batch_size, L, d) the same as src |
|
|
|
Returns: |
|
|
|
""" |
|
src = src.permute(1, 0, 2) |
|
pos_embed = pos_embed.permute(1, 0, 2) |
|
|
|
memory = self.encoder( |
|
src, |
|
mask=att_mask, |
|
src_key_padding_mask=mask, |
|
pos=pos_embed |
|
) |
|
|
|
memory = memory.transpose(0, 1) |
|
return memory |
|
|
|
|
|
class TransformerEncoder(nn.Module): |
|
|
|
def __init__(self, encoder_layer, num_layers, norm=None, return_intermediate=False): |
|
super().__init__() |
|
self.layers = _get_clones(encoder_layer, num_layers) |
|
self.num_layers = num_layers |
|
self.norm = norm |
|
self.return_intermediate = return_intermediate |
|
|
|
def forward(self, src, |
|
mask: Optional[Tensor] = None, |
|
src_key_padding_mask: Optional[Tensor] = None, |
|
pos: Optional[Tensor] = None): |
|
output = src |
|
|
|
intermediate = [] |
|
|
|
for layer in self.layers: |
|
output = layer(output, src_mask=mask, |
|
src_key_padding_mask=src_key_padding_mask, pos=pos) |
|
if self.return_intermediate: |
|
intermediate.append(output) |
|
|
|
if self.norm is not None: |
|
output = self.norm(output) |
|
|
|
if self.return_intermediate: |
|
return torch.stack(intermediate) |
|
|
|
return output |
|
|
|
|
|
class TransformerEncoderLayer(nn.Module): |
|
|
|
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, |
|
activation="relu", normalize_before=False): |
|
super().__init__() |
|
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) |
|
|
|
self.linear1 = nn.Linear(d_model, dim_feedforward) |
|
self.dropout = nn.Dropout(dropout) |
|
self.linear2 = nn.Linear(dim_feedforward, d_model) |
|
|
|
self.norm1 = nn.LayerNorm(d_model) |
|
self.norm2 = nn.LayerNorm(d_model) |
|
self.dropout1 = nn.Dropout(dropout) |
|
self.dropout2 = nn.Dropout(dropout) |
|
|
|
self.activation = _get_activation_fn(activation) |
|
self.normalize_before = normalize_before |
|
|
|
def with_pos_embed(self, tensor, pos: Optional[Tensor]): |
|
return tensor if pos is None else tensor + pos |
|
|
|
def forward_post(self, |
|
src, |
|
src_mask: Optional[Tensor] = None, |
|
src_key_padding_mask: Optional[Tensor] = None, |
|
pos: Optional[Tensor] = None): |
|
q = k = self.with_pos_embed(src, pos) |
|
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, |
|
key_padding_mask=src_key_padding_mask)[0] |
|
src = src + self.dropout1(src2) |
|
src = self.norm1(src) |
|
src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) |
|
src = src + self.dropout2(src2) |
|
src = self.norm2(src) |
|
return src |
|
|
|
def forward_pre(self, src, |
|
src_mask: Optional[Tensor] = None, |
|
src_key_padding_mask: Optional[Tensor] = None, |
|
pos: Optional[Tensor] = None): |
|
src2 = self.norm1(src) |
|
q = k = self.with_pos_embed(src2, pos) |
|
src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, |
|
key_padding_mask=src_key_padding_mask)[0] |
|
src = src + self.dropout1(src2) |
|
src2 = self.norm2(src) |
|
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) |
|
src = src + self.dropout2(src2) |
|
return src |
|
|
|
def forward(self, src, |
|
src_mask: Optional[Tensor] = None, |
|
src_key_padding_mask: Optional[Tensor] = None, |
|
pos: Optional[Tensor] = None): |
|
if self.normalize_before: |
|
return self.forward_pre(src, src_mask, src_key_padding_mask, pos) |
|
return self.forward_post(src, src_mask, src_key_padding_mask, pos) |
|
|
|
|
|
def _get_clones(module, N): |
|
return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) |
|
|
|
|
|
def build_transformer(x): |
|
return Transformer( |
|
d_model=x, |
|
dropout=0.5, |
|
nhead=8, |
|
dim_feedforward=1024, |
|
num_encoder_layers=2, |
|
normalize_before=True, |
|
) |
|
|
|
|
|
def _get_activation_fn(activation): |
|
"""Return an activation function given a string""" |
|
if activation == "relu": |
|
return F.relu |
|
if activation == "gelu": |
|
return F.gelu |
|
if activation == "glu": |
|
return F.glu |
|
raise RuntimeError(F"activation should be relu/gelu, not {activation}.") |
|
|
|
|
|
|
|
class EMA: |
|
def __init__(self, model, decay): |
|
self.decay = decay |
|
self.shadow = {} |
|
self.original = {} |
|
|
|
|
|
for name, param in model.named_parameters(): |
|
if param.requires_grad: |
|
self.shadow[name] = param.data.clone() |
|
|
|
def __call__(self, model, num_updates=99999): |
|
decay = min(self.decay, (1.0 + num_updates) / (10.0 + num_updates)) |
|
for name, param in model.named_parameters(): |
|
if param.requires_grad: |
|
assert name in self.shadow |
|
new_average = \ |
|
(1.0 - decay) * param.data + decay * self.shadow[name] |
|
self.shadow[name] = new_average.clone() |
|
|
|
def assign(self, model): |
|
for name, param in model.named_parameters(): |
|
if param.requires_grad: |
|
assert name in self.shadow |
|
self.original[name] = param.data.clone() |
|
param.data = self.shadow[name] |
|
|
|
def resume(self, model): |
|
for name, param in model.named_parameters(): |
|
if param.requires_grad: |
|
assert name in self.shadow |
|
param.data = self.original[name] |
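

# Illustrative usage sketch: the EMA helper keeps a shadow copy of every trainable parameter,
# is called once per optimizer step, and can temporarily swap the averaged weights in for
# evaluation (assign) and back out again (resume). The tiny linear model is an arbitrary example.
def _demo_ema():
    net = nn.Linear(4, 2)
    ema = EMA(net, decay=0.999)
    opt = torch.optim.SGD(net.parameters(), lr=0.1)
    for step in range(3):
        loss = net(torch.randn(8, 4)).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        ema(net, num_updates=step)   # update the shadow weights after each step
    ema.assign(net)                  # evaluate with the averaged weights
    ema.resume(net)                  # restore the raw training weights
    return net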
|
|
|
|
|
def MLP(channels): |
|
return Sequential(*[ |
|
Sequential(Linear(channels[i - 1], channels[i]), SiLU()) |
|
for i in range(1, len(channels))]) |
|
|
|
|
|
class Res(nn.Module): |
|
def __init__(self, dim): |
|
super(Res, self).__init__() |
|
|
|
self.mlp = MLP([dim, dim, dim]) |
|
|
|
def forward(self, m): |
|
m1 = self.mlp(m) |
|
m_out = m1 + m |
|
return m_out |
|
|
|
|
|
def compute_idx(pos, edge_index): |
|
|
|
pos_i = pos[edge_index[0]] |
|
pos_j = pos[edge_index[1]] |
|
|
|
d_ij = torch.norm(abs(pos_j - pos_i), dim=-1, keepdim=False).unsqueeze(-1) + 1e-5 |
|
v_ji = (pos_i - pos_j) / d_ij |
|
|
|
unique, counts = torch.unique(edge_index[0], sorted=True, return_counts=True) |
|
full_index = torch.arange(0, edge_index[0].size()[0]).cuda().int() |
|
|
|
|
|
|
|
repeat = torch.repeat_interleave(counts, counts) |
|
counts_repeat1 = torch.repeat_interleave(full_index, repeat) |
|
|
|
|
|
split = torch.split(full_index, counts.tolist()) |
|
index2 = list(edge_index[0].data.cpu().numpy()) |
|
counts_repeat2 = torch.cat(itemgetter(*index2)(split), dim=0) |
|
|
|
|
|
v1 = v_ji[counts_repeat1.long()] |
|
v2 = v_ji[counts_repeat2.long()] |
|
|
|
angle = (v1*v2).sum(-1).unsqueeze(-1) |
|
angle = torch.clamp(angle, min=-1.0, max=1.0) + 1e-6 + 1.0 |
|
|
|
return counts_repeat1.long(), counts_repeat2.long(), angle |
|
|
|
|
|
def Jn(r, n): |
|
return np.sqrt(np.pi / (2 * r)) * sp.jv(n + 0.5, r) |
|
|
|
|
|
def Jn_zeros(n, k): |
|
zerosj = np.zeros((n, k), dtype='float32') |
|
zerosj[0] = np.arange(1, k + 1) * np.pi |
|
points = np.arange(1, k + n) * np.pi |
|
racines = np.zeros(k + n - 1, dtype='float32') |
|
for i in range(1, n): |
|
for j in range(k + n - 1 - i): |
|
foo = brentq(Jn, points[j], points[j + 1], (i, )) |
|
racines[j] = foo |
|
points = racines |
|
zerosj[i][:k] = racines[:k] |
|
|
|
return zerosj |
|
|
|
|
|
def spherical_bessel_formulas(n): |
|
x = sym.symbols('x') |
|
|
|
f = [sym.sin(x) / x] |
|
a = sym.sin(x) / x |
|
for i in range(1, n): |
|
b = sym.diff(a, x) / x |
|
f += [sym.simplify(b * (-x)**i)] |
|
a = sym.simplify(b) |
|
return f |
|
|
|
|
|
def bessel_basis(n, k): |
|
zeros = Jn_zeros(n, k) |
|
normalizer = [] |
|
for order in range(n): |
|
normalizer_tmp = [] |
|
for i in range(k): |
|
normalizer_tmp += [0.5 * Jn(zeros[order, i], order + 1)**2] |
|
normalizer_tmp = 1 / np.array(normalizer_tmp)**0.5 |
|
normalizer += [normalizer_tmp] |
|
|
|
f = spherical_bessel_formulas(n) |
|
x = sym.symbols('x') |
|
bess_basis = [] |
|
for order in range(n): |
|
bess_basis_tmp = [] |
|
for i in range(k): |
|
bess_basis_tmp += [ |
|
sym.simplify(normalizer[order][i] * |
|
f[order].subs(x, zeros[order, i] * x)) |
|
] |
|
bess_basis += [bess_basis_tmp] |
|
return bess_basis |
|
|
|
|
|
def sph_harm_prefactor(k, m): |
|
    # math.factorial is used here because np.math was deprecated and removed in NumPy 2.0
    return ((2 * k + 1) * math.factorial(k - abs(m)) /
            (4 * np.pi * math.factorial(k + abs(m))))**0.5
|
|
|
|
|
def associated_legendre_polynomials(k, zero_m_only=True): |
|
z = sym.symbols('z') |
|
P_l_m = [[0] * (j + 1) for j in range(k)] |
|
|
|
P_l_m[0][0] = 1 |
|
    if k > 1:  # P_l_m[1] only exists when there are at least two degrees
|
P_l_m[1][0] = z |
|
|
|
for j in range(2, k): |
|
P_l_m[j][0] = sym.simplify(((2 * j - 1) * z * P_l_m[j - 1][0] - |
|
(j - 1) * P_l_m[j - 2][0]) / j) |
|
if not zero_m_only: |
|
for i in range(1, k): |
|
P_l_m[i][i] = sym.simplify((1 - 2 * i) * P_l_m[i - 1][i - 1]) |
|
if i + 1 < k: |
|
P_l_m[i + 1][i] = sym.simplify( |
|
(2 * i + 1) * z * P_l_m[i][i]) |
|
for j in range(i + 2, k): |
|
P_l_m[j][i] = sym.simplify( |
|
((2 * j - 1) * z * P_l_m[j - 1][i] - |
|
(i + j - 1) * P_l_m[j - 2][i]) / (j - i)) |
|
|
|
return P_l_m |
|
|
|
|
|
def real_sph_harm(k, zero_m_only=True, spherical_coordinates=True): |
|
if not zero_m_only: |
|
S_m = [0] |
|
C_m = [1] |
|
for i in range(1, k): |
|
x = sym.symbols('x') |
|
y = sym.symbols('y') |
|
S_m += [x * S_m[i - 1] + y * C_m[i - 1]] |
|
C_m += [x * C_m[i - 1] - y * S_m[i - 1]] |
|
|
|
P_l_m = associated_legendre_polynomials(k, zero_m_only) |
|
if spherical_coordinates: |
|
theta = sym.symbols('theta') |
|
z = sym.symbols('z') |
|
for i in range(len(P_l_m)): |
|
for j in range(len(P_l_m[i])): |
|
if type(P_l_m[i][j]) != int: |
|
P_l_m[i][j] = P_l_m[i][j].subs(z, sym.cos(theta)) |
|
if not zero_m_only: |
|
phi = sym.symbols('phi') |
|
for i in range(len(S_m)): |
|
S_m[i] = S_m[i].subs(x, |
|
sym.sin(theta) * sym.cos(phi)).subs( |
|
y, |
|
sym.sin(theta) * sym.sin(phi)) |
|
for i in range(len(C_m)): |
|
C_m[i] = C_m[i].subs(x, |
|
sym.sin(theta) * sym.cos(phi)).subs( |
|
y, |
|
sym.sin(theta) * sym.sin(phi)) |
|
|
|
Y_func_l_m = [['0'] * (2 * j + 1) for j in range(k)] |
|
for i in range(k): |
|
Y_func_l_m[i][0] = sym.simplify(sph_harm_prefactor(i, 0) * P_l_m[i][0]) |
|
|
|
if not zero_m_only: |
|
for i in range(1, k): |
|
for j in range(1, i + 1): |
|
Y_func_l_m[i][j] = sym.simplify( |
|
2**0.5 * sph_harm_prefactor(i, j) * C_m[j] * P_l_m[i][j]) |
|
for i in range(1, k): |
|
for j in range(1, i + 1): |
|
Y_func_l_m[i][-j] = sym.simplify( |
|
2**0.5 * sph_harm_prefactor(i, -j) * S_m[j] * P_l_m[i][j]) |
|
|
|
return Y_func_l_m |
|
|
|
|
|
class BesselBasisLayer(torch.nn.Module): |
|
def __init__(self, num_radial, cutoff, envelope_exponent=6): |
|
super(BesselBasisLayer, self).__init__() |
|
self.cutoff = cutoff |
|
self.envelope = Envelope(envelope_exponent) |
|
|
|
self.freq = torch.nn.Parameter(torch.Tensor(num_radial)) |
|
|
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
|
|
|
|
|
|
|
|
|
|
tmp_tensor = torch.arange(1, self.freq.numel() + 1, dtype=self.freq.dtype, device=self.freq.device) |
|
|
|
|
|
self.freq.data = torch.mul(tmp_tensor, PI) |
|
|
|
def forward(self, dist): |
|
dist = dist.unsqueeze(-1) / self.cutoff |
|
return self.envelope(dist) * (self.freq * dist).sin() |
|
|
|
|
|
class SiLU(nn.Module): |
|
def __init__(self): |
|
super().__init__() |
|
|
|
def forward(self, input): |
|
return silu(input) |
|
|
|
|
|
def silu(input): |
|
return input * torch.sigmoid(input) |
|
|
|
|
|
class Envelope(torch.nn.Module): |
|
def __init__(self, exponent): |
|
super(Envelope, self).__init__() |
|
self.p = exponent |
|
self.a = -(self.p + 1) * (self.p + 2) / 2 |
|
self.b = self.p * (self.p + 2) |
|
self.c = -self.p * (self.p + 1) / 2 |
|
|
|
def forward(self, x): |
|
p, a, b, c = self.p, self.a, self.b, self.c |
|
x_pow_p0 = x.pow(p) |
|
x_pow_p1 = x_pow_p0 * x |
|
env_val = 1. / x + a * x_pow_p0 + b * x_pow_p1 + c * x_pow_p1 * x |
|
|
|
zero = torch.zeros_like(x) |
|
return torch.where(x < 1, env_val, zero) |
|
|
|
|
|
class SphericalBasisLayer(torch.nn.Module): |
|
def __init__(self, num_spherical, num_radial, cutoff=5.0, |
|
envelope_exponent=5): |
|
super(SphericalBasisLayer, self).__init__() |
|
assert num_radial <= 64 |
|
self.num_spherical = num_spherical |
|
self.num_radial = num_radial |
|
self.cutoff = cutoff |
|
self.envelope = Envelope(envelope_exponent) |
|
|
|
bessel_forms = bessel_basis(num_spherical, num_radial) |
|
sph_harm_forms = real_sph_harm(num_spherical) |
|
self.sph_funcs = [] |
|
self.bessel_funcs = [] |
|
|
|
x, theta = sym.symbols('x theta') |
|
modules = {'sin': torch.sin, 'cos': torch.cos} |
|
for i in range(num_spherical): |
|
if i == 0: |
|
sph1 = sym.lambdify([theta], sph_harm_forms[i][0], modules)(0) |
|
self.sph_funcs.append(lambda x: torch.zeros_like(x) + sph1) |
|
else: |
|
sph = sym.lambdify([theta], sph_harm_forms[i][0], modules) |
|
self.sph_funcs.append(sph) |
|
for j in range(num_radial): |
|
bessel = sym.lambdify([x], bessel_forms[i][j], modules) |
|
self.bessel_funcs.append(bessel) |
|
|
|
def forward(self, dist, angle, idx_kj): |
|
dist = dist / self.cutoff |
|
rbf = torch.stack([f(dist) for f in self.bessel_funcs], dim=1) |
|
rbf = self.envelope(dist).unsqueeze(-1) * rbf |
|
|
|
cbf = torch.stack([f(angle) for f in self.sph_funcs], dim=1) |
|
|
|
n, k = self.num_spherical, self.num_radial |
|
out = (rbf[idx_kj].view(-1, n, k) * cbf.view(-1, n, 1)).view(-1, n * k) |
|
return out |
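

# Illustrative shape check (requires sympy): the radial basis maps each edge distance to
# num_radial features, while the spherical basis combines distances, triplet angles and the
# idx_kj gather into num_spherical * num_radial features per triplet. The edge and triplet
# counts below are arbitrary examples.
def _demo_basis_layers():
    dist = torch.rand(12) * 4.0 + 0.5                  # 12 edge distances within the cutoff
    rbf = BesselBasisLayer(16, cutoff=5.0, envelope_exponent=5)(dist)
    assert rbf.shape == (12, 16)

    angle = torch.rand(20) * PI                        # 20 triplet angles in [0, pi)
    idx_kj = torch.randint(0, 12, (20,))               # triplet -> edge mapping
    sbf = SphericalBasisLayer(num_spherical=7, num_radial=6, cutoff=5.0)(dist, angle, idx_kj)
    assert sbf.shape == (20, 7 * 6)
    return rbf, sbf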
|
|
|
|
|
|
|
msg_special_args = set([ |
|
'edge_index', |
|
'edge_index_i', |
|
'edge_index_j', |
|
'size', |
|
'size_i', |
|
'size_j', |
|
]) |
|
|
|
aggr_special_args = set([ |
|
'index', |
|
'dim_size', |
|
]) |
|
|
|
update_special_args = set([]) |
|
|
|
|
|
class MessagePassing(torch.nn.Module): |
|
r"""Base class for creating message passing layers |
|
|
|
.. math:: |
|
\mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i, |
|
\square_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}} |
|
\left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{i,j}\right) \right), |
|
|
|
where :math:`\square` denotes a differentiable, permutation invariant |
|
function, *e.g.*, sum, mean or max, and :math:`\gamma_{\mathbf{\Theta}}` |
|
and :math:`\phi_{\mathbf{\Theta}}` denote differentiable functions such as |
|
MLPs. |
|
See `here <https://pytorch-geometric.readthedocs.io/en/latest/notes/ |
|
create_gnn.html>`__ for the accompanying tutorial. |
|
|
|
Args: |
|
aggr (string, optional): The aggregation scheme to use |
|
(:obj:`"add"`, :obj:`"mean"` or :obj:`"max"`). |
|
(default: :obj:`"add"`) |
|
flow (string, optional): The flow direction of message passing |
|
(:obj:`"source_to_target"` or :obj:`"target_to_source"`). |
|
(default: :obj:`"source_to_target"`) |
|
node_dim (int, optional): The axis along which to propagate. |
|
(default: :obj:`0`) |
|
""" |
|
def __init__(self, aggr='add', flow='target_to_source', node_dim=0): |
|
super(MessagePassing, self).__init__() |
|
|
|
self.aggr = aggr |
|
assert self.aggr in ['add', 'mean', 'max'] |
|
|
|
self.flow = flow |
|
assert self.flow in ['source_to_target', 'target_to_source'] |
|
|
|
self.node_dim = node_dim |
|
assert self.node_dim >= 0 |
|
|
|
self.__msg_params__ = inspect.signature(self.message).parameters |
|
self.__msg_params__ = OrderedDict(self.__msg_params__) |
|
|
|
self.__aggr_params__ = inspect.signature(self.aggregate).parameters |
|
self.__aggr_params__ = OrderedDict(self.__aggr_params__) |
|
self.__aggr_params__.popitem(last=False) |
|
|
|
self.__update_params__ = inspect.signature(self.update).parameters |
|
self.__update_params__ = OrderedDict(self.__update_params__) |
|
self.__update_params__.popitem(last=False) |
|
|
|
msg_args = set(self.__msg_params__.keys()) - msg_special_args |
|
aggr_args = set(self.__aggr_params__.keys()) - aggr_special_args |
|
update_args = set(self.__update_params__.keys()) - update_special_args |
|
|
|
self.__args__ = set().union(msg_args, aggr_args, update_args) |
|
|
|
def __set_size__(self, size, index, tensor): |
|
if not torch.is_tensor(tensor): |
|
pass |
|
elif size[index] is None: |
|
size[index] = tensor.size(self.node_dim) |
|
elif size[index] != tensor.size(self.node_dim): |
|
raise ValueError( |
|
(f'Encountered node tensor with size ' |
|
f'{tensor.size(self.node_dim)} in dimension {self.node_dim}, ' |
|
f'but expected size {size[index]}.')) |
|
|
|
def __collect__(self, edge_index, size, kwargs): |
|
i, j = (0, 1) if self.flow == "target_to_source" else (1, 0) |
|
ij = {"_i": i, "_j": j} |
|
|
|
out = {} |
|
for arg in self.__args__: |
|
if arg[-2:] not in ij.keys(): |
|
out[arg] = kwargs.get(arg, inspect.Parameter.empty) |
|
else: |
|
idx = ij[arg[-2:]] |
|
data = kwargs.get(arg[:-2], inspect.Parameter.empty) |
|
|
|
if data is inspect.Parameter.empty: |
|
out[arg] = data |
|
continue |
|
|
|
if isinstance(data, tuple) or isinstance(data, list): |
|
assert len(data) == 2 |
|
self.__set_size__(size, 1 - idx, data[1 - idx]) |
|
data = data[idx] |
|
|
|
if not torch.is_tensor(data): |
|
out[arg] = data |
|
continue |
|
|
|
self.__set_size__(size, idx, data) |
|
out[arg] = data.index_select(self.node_dim, edge_index[idx]) |
|
|
|
size[0] = size[1] if size[0] is None else size[0] |
|
size[1] = size[0] if size[1] is None else size[1] |
|
|
|
|
|
out['edge_index'] = edge_index |
|
out['edge_index_i'] = edge_index[i] |
|
out['edge_index_j'] = edge_index[j] |
|
out['size'] = size |
|
out['size_i'] = size[i] |
|
out['size_j'] = size[j] |
|
|
|
|
|
out['index'] = out['edge_index_i'] |
|
out['dim_size'] = out['size_i'] |
|
|
|
return out |
|
|
|
def __distribute__(self, params, kwargs): |
|
out = {} |
|
for key, param in params.items(): |
|
data = kwargs[key] |
|
if data is inspect.Parameter.empty: |
|
if param.default is inspect.Parameter.empty: |
|
raise TypeError(f'Required parameter {key} is empty.') |
|
data = param.default |
|
out[key] = data |
|
return out |
|
|
|
def propagate(self, edge_index, size=None, **kwargs): |
|
r"""The initial call to start propagating messages. |
|
|
|
Args: |
|
edge_index (Tensor): The indices of a general (sparse) assignment |
|
matrix with shape :obj:`[N, M]` (can be directed or |
|
undirected). |
|
size (list or tuple, optional): The size :obj:`[N, M]` of the |
|
assignment matrix. If set to :obj:`None`, the size will be |
|
automatically inferred and assumed to be quadratic. |
|
(default: :obj:`None`) |
|
**kwargs: Any additional data which is needed to construct and |
|
aggregate messages, and to update node embeddings. |
|
""" |
|
|
|
size = [None, None] if size is None else size |
|
size = [size, size] if isinstance(size, int) else size |
|
size = size.tolist() if torch.is_tensor(size) else size |
|
size = list(size) if isinstance(size, tuple) else size |
|
assert isinstance(size, list) |
|
assert len(size) == 2 |
|
|
|
kwargs = self.__collect__(edge_index, size, kwargs) |
|
|
|
msg_kwargs = self.__distribute__(self.__msg_params__, kwargs) |
|
|
|
m = self.message(**msg_kwargs) |
|
aggr_kwargs = self.__distribute__(self.__aggr_params__, kwargs) |
|
m = self.aggregate(m, **aggr_kwargs) |
|
|
|
update_kwargs = self.__distribute__(self.__update_params__, kwargs) |
|
m = self.update(m, **update_kwargs) |
|
|
|
return m |
|
|
|
def message(self, x_j): |
|
r"""Constructs messages to node :math:`i` in analogy to |
|
:math:`\phi_{\mathbf{\Theta}}` for each edge in |
|
:math:`(j,i) \in \mathcal{E}` if :obj:`flow="source_to_target"` and |
|
:math:`(i,j) \in \mathcal{E}` if :obj:`flow="target_to_source"`. |
|
Can take any argument which was initially passed to :meth:`propagate`. |
|
In addition, tensors passed to :meth:`propagate` can be mapped to the |
|
respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or |
|
        :obj:`_j` to the variable name, *e.g.*, :obj:`x_i` and :obj:`x_j`.
|
""" |
|
|
|
return x_j |
|
|
|
def aggregate(self, inputs, index, dim_size): |
|
r"""Aggregates messages from neighbors as |
|
:math:`\square_{j \in \mathcal{N}(i)}`. |
|
|
|
By default, delegates call to scatter functions that support |
|
"add", "mean" and "max" operations specified in :meth:`__init__` by |
|
the :obj:`aggr` argument. |
|
""" |
|
|
|
return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr) |
|
|
|
def update(self, inputs): |
|
r"""Updates node embeddings in analogy to |
|
:math:`\gamma_{\mathbf{\Theta}}` for each node |
|
:math:`i \in \mathcal{V}`. |
|
Takes in the output of aggregation as first argument and any argument |
|
which was initially passed to :meth:`propagate`. |
|
""" |
|
|
|
return inputs |
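

# Illustrative example: a minimal graph convolution built on the MessagePassing base class
# above. It applies a learned linear map to each neighbour feature and sums the results,
# matching the phi / aggregation / gamma pattern described in the class docstring.
# _DemoSumConv is an example name, not part of the TransMXM model.
class _DemoSumConv(MessagePassing):
    def __init__(self, in_dim, out_dim):
        super(_DemoSumConv, self).__init__(aggr='add')
        self.lin = nn.Linear(in_dim, out_dim)

    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        return self.lin(x_j)


def _demo_message_passing():
    x = torch.randn(4, 8)                                    # 4 nodes with 8 features each
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])  # a directed 4-cycle
    out = _DemoSumConv(8, 16)(x, edge_index)
    assert out.shape == (4, 16)
    return out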
|
|
|
class TransMXMNet(nn.Module): |
|
def __init__(self, dim=128, n_layer=6, cutoff=5.0, num_spherical=7, num_radial=6, envelope_exponent=5): |
|
super(TransMXMNet, self).__init__() |
|
|
|
self.dim = dim |
|
self.n_layer = n_layer |
|
self.cutoff = cutoff |
|
|
|
self.embeddings = nn.Parameter(torch.ones((5, self.dim))) |
|
|
|
self.rbf_l = BesselBasisLayer(16, 5, envelope_exponent) |
|
self.rbf_g = BesselBasisLayer(16, self.cutoff, envelope_exponent) |
|
self.sbf = SphericalBasisLayer(num_spherical, num_radial, 5, envelope_exponent) |
|
|
|
self.rbf_g_mlp = MLP([16, self.dim]) |
|
self.rbf_l_mlp = MLP([16, self.dim]) |
|
|
|
self.sbf_1_mlp = MLP([num_spherical * num_radial, self.dim]) |
|
self.sbf_2_mlp = MLP([num_spherical * num_radial, self.dim]) |
|
|
|
self.global_layers = torch.nn.ModuleList() |
|
for layer in range(self.n_layer): |
|
self.global_layers.append(Global_MP(self.dim)) |
|
|
|
self.local_layers = torch.nn.ModuleList() |
|
for layer in range(self.n_layer): |
|
self.local_layers.append(Local_MP(self.dim)) |
|
|
|
self.pos_embed = build_position_encoding(self.dim) |
|
self.transformer = build_transformer(self.dim) |
|
|
|
self.init() |
|
|
|
def init(self): |
|
stdv = math.sqrt(3) |
|
self.embeddings.data.uniform_(-stdv, stdv) |
|
|
|
def indices(self, edge_index, num_nodes): |
|
row, col = edge_index |
|
|
|
value = torch.arange(row.size(0), device=row.device) |
|
adj_t = SparseTensor(row=col, col=row, value=value, |
|
sparse_sizes=(num_nodes, num_nodes)) |
|
|
|
|
|
adj_t_row = adj_t[row] |
|
num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long) |
|
|
|
idx_i = col.repeat_interleave(num_triplets) |
|
idx_j = row.repeat_interleave(num_triplets) |
|
idx_k = adj_t_row.storage.col() |
|
mask = idx_i != idx_k |
|
idx_i_1, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask] |
|
|
|
idx_kj = adj_t_row.storage.value()[mask] |
|
idx_ji_1 = adj_t_row.storage.row()[mask] |
|
|
|
|
|
adj_t_col = adj_t[col] |
|
|
|
num_pairs = adj_t_col.set_value(None).sum(dim=1).to(torch.long) |
|
idx_i_2 = row.repeat_interleave(num_pairs) |
|
idx_j1 = col.repeat_interleave(num_pairs) |
|
idx_j2 = adj_t_col.storage.col() |
|
|
|
idx_ji_2 = adj_t_col.storage.row() |
|
idx_jj = adj_t_col.storage.value() |
|
|
|
return idx_i_1, idx_j, idx_k, idx_kj, idx_ji_1, idx_i_2, idx_j1, idx_j2, idx_jj, idx_ji_2 |
|
|
|
|
|
def forward_features(self, data): |
|
x = data.x |
|
edge_index = data.edge_index |
|
pos = data.pos |
|
batch = data.batch |
|
|
|
h = torch.index_select(self.embeddings, 0, x.long()).unsqueeze(0) |
|
data_len = torch.bincount(batch) |
|
|
|
diff_tensor = torch.diff(data_len) |
|
indices = torch.nonzero(diff_tensor) + 1 |
|
indices[0] = 0 |
|
|
|
        att_mask = torch.zeros(len(batch), len(batch)).cuda()  # NOTE: the attention masks are built directly on GPU; this forward pass assumes CUDA
|
|
|
att_mask[indices[0]:, indices[0]:] = 1 |
|
i = 0 |
|
for i in range(0, h.size(0) - 1): |
|
att_mask[indices[i]:indices[i + 1], indices[i]:indices[i + 1]] = 1 |
|
att_mask[indices[i]:indices[-1], indices[i]:indices[-1]] = 1 |
|
|
|
mask = torch.ones(1, len(batch)).bool().cuda() |
|
|
|
pos_h = self.pos_embed(h, mask).cuda() |
|
memory = self.transformer(h, ~mask, att_mask, pos_h) |
|
h = memory.squeeze(0) |
|
|
|
        # Local layer --------------------------------------------------------------------
|
|
|
edge_index_l, _ = remove_self_loops(edge_index) |
|
j_l, i_l = edge_index_l |
|
dist_l = (pos[i_l] - pos[j_l]).pow(2).sum(dim=-1).sqrt() |
|
|
|
        # Global layer -------------------------------------------------------------------
|
|
|
|
|
row, col = radius(pos, pos, self.cutoff, batch, batch, max_num_neighbors=500) |
|
edge_index_g = torch.stack([row, col], dim=0) |
|
edge_index_g, _ = remove_self_loops(edge_index_g) |
|
j_g, i_g = edge_index_g |
|
dist_g = (pos[i_g] - pos[j_g]).pow(2).sum(dim=-1).sqrt() |
|
|
|
|
|
idx_i_1, idx_j, idx_k, idx_kj, idx_ji, idx_i_2, idx_j1, idx_j2, idx_jj, idx_ji_2 = self.indices(edge_index_l, num_nodes=h.size(0)) |
|
|
|
|
|
pos_ji_1, pos_kj = pos[idx_j] - pos[idx_i_1], pos[idx_k] - pos[idx_j] |
|
a = (pos_ji_1 * pos_kj).sum(dim=-1) |
|
b = torch.cross(pos_ji_1, pos_kj).norm(dim=-1) |
|
angle_1 = torch.atan2(b, a) |
|
|
|
|
|
pos_ji_2, pos_jj = pos[idx_j1] - pos[idx_i_2], pos[idx_j2] - pos[idx_j1] |
|
a = (pos_ji_2 * pos_jj).sum(dim=-1) |
|
b = torch.cross(pos_ji_2, pos_jj).norm(dim=-1) |
|
angle_2 = torch.atan2(b, a) |
|
|
|
|
|
rbf_g = self.rbf_g(dist_g) |
|
rbf_l = self.rbf_l(dist_l) |
|
sbf_1 = self.sbf(dist_l, angle_1, idx_kj) |
|
sbf_2 = self.sbf(dist_l, angle_2, idx_jj) |
|
|
|
rbf_g = self.rbf_g_mlp(rbf_g) |
|
rbf_l = self.rbf_l_mlp(rbf_l) |
|
sbf_1 = self.sbf_1_mlp(sbf_1) |
|
sbf_2 = self.sbf_2_mlp(sbf_2) |
|
|
|
|
|
node_sum = 0 |
|
|
|
for layer in range(self.n_layer): |
|
h = self.global_layers[layer](h, rbf_g, edge_index_g) |
|
h, t = self.local_layers[layer](h, rbf_l, sbf_1, sbf_2, idx_kj, idx_ji, idx_jj, idx_ji_2, edge_index_l) |
|
node_sum += t |
|
|
|
|
|
output = global_add_pool(node_sum, batch) |
|
return output.view(-1) |
|
|
|
def loss(self, pred, label): |
|
pred, label = pred.reshape(-1), label.reshape(-1) |
|
return F.mse_loss(pred, label) |
|
|
|
|
|
class Global_MP(MessagePassing): |
|
|
|
def __init__(self, dim): |
|
super(Global_MP, self).__init__() |
|
self.dim = dim |
|
|
|
self.h_mlp = MLP([self.dim, self.dim]) |
|
|
|
self.res1 = Res(self.dim) |
|
self.res2 = Res(self.dim) |
|
self.res3 = Res(self.dim) |
|
self.mlp = MLP([self.dim, self.dim]) |
|
|
|
self.x_edge_mlp = MLP([self.dim * 3, self.dim]) |
|
self.linear = nn.Linear(self.dim, self.dim, bias=False) |
|
|
|
def forward(self, h, edge_attr, edge_index): |
|
edge_index, _ = add_self_loops(edge_index, num_nodes=h.size(0)) |
|
|
|
res_h = h |
|
|
|
|
|
h = self.h_mlp(h) |
|
|
|
|
|
h = self.propagate(edge_index, x=h, num_nodes=h.size(0), edge_attr=edge_attr) |
|
|
|
|
|
h = self.res1(h) |
|
h = self.mlp(h) + res_h |
|
h = self.res2(h) |
|
h = self.res3(h) |
|
|
|
|
|
h = self.propagate(edge_index, x=h, num_nodes=h.size(0), edge_attr=edge_attr) |
|
|
|
return h |
|
|
|
def message(self, x_i, x_j, edge_attr, edge_index, num_nodes): |
|
num_edge = edge_attr.size()[0] |
|
|
|
x_edge = torch.cat((x_i[:num_edge], x_j[:num_edge], edge_attr), -1) |
|
x_edge = self.x_edge_mlp(x_edge) |
|
|
|
x_j = torch.cat((self.linear(edge_attr) * x_edge, x_j[num_edge:]), dim=0) |
|
|
|
return x_j |
|
|
|
def update(self, aggr_out): |
|
return aggr_out |
|
|
|
|
|
class Local_MP(torch.nn.Module): |
|
def __init__(self, dim): |
|
super(Local_MP, self).__init__() |
|
self.dim = dim |
|
|
|
self.h_mlp = MLP([self.dim, self.dim]) |
|
|
|
self.mlp_kj = MLP([3 * self.dim, self.dim]) |
|
self.mlp_ji_1 = MLP([3 * self.dim, self.dim]) |
|
self.mlp_ji_2 = MLP([self.dim, self.dim]) |
|
self.mlp_jj = MLP([self.dim, self.dim]) |
|
|
|
self.mlp_sbf1 = MLP([self.dim, self.dim, self.dim]) |
|
self.mlp_sbf2 = MLP([self.dim, self.dim, self.dim]) |
|
self.lin_rbf1 = nn.Linear(self.dim, self.dim, bias=False) |
|
self.lin_rbf2 = nn.Linear(self.dim, self.dim, bias=False) |
|
|
|
self.res1 = Res(self.dim) |
|
self.res2 = Res(self.dim) |
|
self.res3 = Res(self.dim) |
|
|
|
self.lin_rbf_out = nn.Linear(self.dim, self.dim, bias=False) |
|
|
|
self.h_mlp = MLP([self.dim, self.dim]) |
|
|
|
self.y_mlp = MLP([self.dim, self.dim, self.dim, self.dim]) |
|
self.y_W = nn.Linear(self.dim, 1) |
|
|
|
def forward(self, h, rbf, sbf1, sbf2, idx_kj, idx_ji_1, idx_jj, idx_ji_2, edge_index, num_nodes=None): |
|
res_h = h |
|
|
|
|
|
h = self.h_mlp(h) |
|
|
|
|
|
j, i = edge_index |
|
m = torch.cat([h[i], h[j], rbf], dim=-1) |
|
|
|
m_kj = self.mlp_kj(m) |
|
m_kj = m_kj * self.lin_rbf1(rbf) |
|
m_kj = m_kj[idx_kj] * self.mlp_sbf1(sbf1) |
|
m_kj = scatter(m_kj, idx_ji_1, dim=0, dim_size=m.size(0), reduce='add') |
|
|
|
m_ji_1 = self.mlp_ji_1(m) |
|
|
|
m = m_ji_1 + m_kj |
|
|
|
|
|
m_jj = self.mlp_jj(m) |
|
m_jj = m_jj * self.lin_rbf2(rbf) |
|
m_jj = m_jj[idx_jj] * self.mlp_sbf2(sbf2) |
|
m_jj = scatter(m_jj, idx_ji_2, dim=0, dim_size=m.size(0), reduce='add') |
|
|
|
m_ji_2 = self.mlp_ji_2(m) |
|
|
|
m = m_ji_2 + m_jj |
|
|
|
|
|
m = self.lin_rbf_out(rbf) * m |
|
h = scatter(m, i, dim=0, dim_size=h.size(0), reduce='add') |
|
|
|
|
|
h = self.res1(h) |
|
h = self.h_mlp(h) + res_h |
|
h = self.res2(h) |
|
h = self.res3(h) |
|
|
|
|
|
y = self.y_mlp(h) |
|
y = self.y_W(y) |
|
|
|
return h, y |
|
|
|
|
|
class TransmxmConfig(PretrainedConfig): |
|
model_type = "transmxm" |
|
|
|
def __init__( |
|
self, |
|
dim: int=128, |
|
n_layer: int=6, |
|
cutoff: float=5.0, |
|
num_spherical: int=7, |
|
num_radial: int=6, |
|
envelope_exponent: int=5, |
|
|
|
smiles: List[str] = None, |
|
processor_class: str = "SmilesProcessor", |
|
**kwargs, |
|
): |
|
|
|
self.dim = dim |
|
self.n_layer = n_layer |
|
self.cutoff = cutoff |
|
self.num_spherical = num_spherical |
|
self.num_radial = num_radial |
|
self.envelope_exponent = envelope_exponent |
|
|
|
self.smiles = smiles |
|
self.processor_class = processor_class |
|
|
|
|
|
super().__init__(**kwargs) |
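

# Illustrative usage sketch: the config mirrors the TransMXMNet constructor arguments and
# plugs into the standard transformers save/load machinery. The SMILES list is an arbitrary
# example value.
def _demo_config():
    config = TransmxmConfig(dim=128, n_layer=6, cutoff=5.0, smiles=["CCO"])
    print(config.model_type, config.dim, config.n_layer)
    # config.save_pretrained("custom-transmxm")  # writes config.json for later from_pretrained()
    return config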
|
|
|
|
|
|
|
class TransmxmModel(PreTrainedModel): |
|
config_class = TransmxmConfig |
|
|
|
def __init__(self, config): |
|
super().__init__(config) |
|
|
|
self.backbone = TransMXMNet( |
|
dim=config.dim, |
|
n_layer=config.n_layer, |
|
cutoff=config.cutoff, |
|
num_spherical=config.num_spherical, |
|
num_radial=config.num_radial, |
|
envelope_exponent=config.envelope_exponent, |
|
) |
|
self.process = SmilesDataset( |
|
smiles=config.smiles, |
|
) |
|
|
|
self.model = None |
|
self.dataset = None |
|
self.output = None |
|
self.data_loader = None |
|
self.pred_data = None |
|
|
|
    def forward(self, data):
        # `data` is a torch_geometric Data/Batch produced by SmilesDataset and DataLoader
        return self.backbone.forward_features(data)
|
|
|
def SmilesProcessor(self, smiles): |
|
return self.process.get_data(smiles) |
|
|
|
|
|
def predict_smiles(self, smiles, device: str='cpu', result_dir: str='./', **kwargs): |
|
|
|
|
|
batch_size = kwargs.pop('batch_size', 1) |
|
shuffle = kwargs.pop('shuffle', False) |
|
drop_last = kwargs.pop('drop_last', False) |
|
num_workers = kwargs.pop('num_workers', 0) |
|
|
|
self.model = AutoModel.from_pretrained("Huhujingjing/custom-transmxm", trust_remote_code=True).to(device) |
|
self.model.eval() |
|
|
|
self.dataset = self.process.get_data(smiles) |
|
self.output = "" |
|
self.output += ("predicted samples num: {}\n".format(len(self.dataset))) |
|
self.output +=("predicted samples:{}\n".format(self.dataset[0])) |
|
self.data_loader = DataLoader(self.dataset, |
|
batch_size=batch_size, |
|
shuffle=shuffle, |
|
drop_last=drop_last, |
|
num_workers=num_workers |
|
) |
|
self.pred_data = { |
|
'smiles': [], |
|
'pred': [] |
|
} |
|
|
|
for batch in tqdm(self.data_loader): |
|
batch = batch.to(device) |
|
with torch.no_grad(): |
|
self.pred_data['smiles'] += batch['smiles'] |
|
self.pred_data['pred'] += self.model(batch).cpu().tolist() |
|
|
|
        # predictions were already moved to the CPU inside the loop; flatten and store as a plain list
        pred = torch.tensor(self.pred_data['pred']).reshape(-1)
        self.pred_data['pred'] = pred.tolist()
|
pred_df = pd.DataFrame(self.pred_data) |
|
pred_df['pred'] = pred_df['pred'].apply(lambda x: round(x, 2)) |
|
self.output +=('-' * 40 + '\n'+'predicted result: \n'+'{}\n'.format(pred_df)) |
|
self.output +=('-' * 40) |
|
|
|
pred_df.to_csv(os.path.join(result_dir, 'prediction.csv'), index=False) |
|
self.output +=('\nsave predicted result to {}\n'.format(os.path.join(result_dir, 'prediction.csv'))) |
|
|
|
return self.output |
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
transmxm_config = TransmxmConfig.from_pretrained("custom-transmxm") |
|
|
|
transmxmd = TransmxmModel(transmxm_config) |
|
    # self.model is only created inside predict_smiles(), so it is None here; the checkpoint
    # is presumably intended for the TransMXMNet backbone.
    transmxmd.backbone.load_state_dict(torch.load(r'G:\Trans_MXM\runs\model.pt'))
|
transmxmd.save_pretrained("custom-transmxm") |
|
|
|
|