KGEditor / src /models /one_shot_learner.py
ChancesYuan's picture
Upload 12 files
06a8327
raw
history blame
5.05 kB
import torch
from allennlp.modules.feedforward import FeedForward
from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper
from higher.patch import monkeypatch as make_functional
class ConditionedParameter(torch.nn.Module):
    """Hypernetwork head that maps a condition vector and a gradient to an update.

    A two-layer weight-normalised MLP (``Linear -> Tanh -> Linear``) reads the
    condition tensor and its output is split into the pieces of the update:

    * 2-D parameter of shape ``(rows, cols)``: column/row factors for ``a`` and
      ``b`` plus one scale logit.  ``a`` and ``b`` are formed as
      ``softmax(row_factors).T @ col_factors``, i.e. a sum over the leading
      (sample) dimension of outer products, giving ``(rows, cols)`` matrices.
    * 1-D parameter of shape ``(n,)``: ``a`` and ``b`` directly, plus one scale
      logit.

    The returned update is
    ``max_scale * mean(sigmoid(scale_logit), dim=0) * (grad * a + b)``.

    Args:
        parameter: Tensor whose update is produced; only its shape is used and
            it must be 1-D or 2-D.
        condition_dim: Size of the last dimension of ``inputs`` in ``forward``.
        hidden_dim: Width of the MLP's hidden layer.
        max_scale: Constant multiplier applied to every returned update.

    Raises:
        RuntimeError: If ``parameter`` is neither 1-D nor 2-D.
    """

    def __init__(self, parameter, condition_dim=1024, hidden_dim=128, max_scale=1):
        super().__init__()
        self.parameter_shape = parameter.shape
        if len(self.parameter_shape) == 2:
            # Two column-factor and two row-factor vectors, plus one scale logit.
            out_dim = 2 * (parameter.shape[0] + parameter.shape[1]) + 1
        elif len(self.parameter_shape) == 1:
            # Full `a` and `b` vectors, plus one scale logit.
            out_dim = 2 * parameter.shape[0] + 1
        else:
            raise RuntimeError(
                "ConditionedParameter supports only 1-D or 2-D parameters, "
                f"got shape {tuple(self.parameter_shape)}"
            )
        # Both shapes use the same layer layout; only the final width differs,
        # so module indices and state_dict keys are identical across cases.
        self.conditioners = torch.nn.Sequential(
            torch.nn.utils.weight_norm(torch.nn.Linear(condition_dim, hidden_dim)),
            torch.nn.Tanh(),
            torch.nn.utils.weight_norm(torch.nn.Linear(hidden_dim, out_dim)),
        )
        self.max_scale = max_scale

    def forward(self, inputs, grad):
        """Return the update for the wrapped parameter.

        Args:
            inputs: Condition tensor whose last dimension is ``condition_dim``;
                presumably ``(batch, condition_dim)`` -- TODO confirm with caller.
            grad: Gradient tensor broadcastable against ``a``/``b`` (normally the
                parameter's own gradient, same shape as the parameter).

        Returns:
            ``max_scale * scale * (grad * a + b)`` where ``scale`` is the
            sigmoid of the predicted logit averaged over dim 0.

        Raises:
            TypeError: If ``grad`` is None.
        """
        if grad is None:
            raise TypeError(
                "ConditionedParameter.forward requires a gradient tensor, got None"
            )
        conditioned = self.conditioners(inputs)
        if len(self.parameter_shape) == 2:
            rows, cols = self.parameter_shape
            col_a, row_a, col_b, row_b, scale_logit = conditioned.split(
                [cols, rows, cols, rows, 1], dim=-1
            )
            a = row_a.softmax(-1).T @ col_a
            b = row_b.softmax(-1).T @ col_b
        elif len(self.parameter_shape) == 1:
            (n,) = self.parameter_shape
            # NOTE(review): with more than one condition row, `a`/`b` keep a
            # leading batch dimension here, so the update is (batch, n) rather
            # than (n,) -- confirm whether callers only ever pass one row.
            a, b, scale_logit = conditioned.split([n, n, 1], dim=-1)
        else:
            raise RuntimeError(
                "ConditionedParameter supports only 1-D or 2-D parameters, "
                f"got shape {tuple(self.parameter_shape)}"
            )
        # With several condition rows the scale is averaged over dim 0.
        scale = torch.mean(scale_logit.sigmoid(), dim=0).squeeze()
        return self.max_scale * scale * (grad * a.squeeze() + b.squeeze())
class LSTMConditioner(torch.nn.Module):
    """Encode a token-id sequence into a single condition vector.

    Pipeline: embedding lookup -> single-layer bidirectional LSTM wrapped as an
    allennlp Seq2Vec encoder -> one ``Tanh`` feed-forward layer of width
    ``output_dim``.

    Args:
        vocab_dim: Number of rows in the embedding table.
        embedding_dim: Width of each token embedding (the LSTM input size).
        hidden_dim: LSTM hidden size per direction; the encoder output is
            ``2 * hidden_dim`` wide because the LSTM is bidirectional.
        output_dim: Width of the produced condition vector.
        embedding_init: Optional initial weight tensor for the embedding table,
            handed to ``torch.nn.Embedding`` via ``_weight``.
    """

    def __init__(
        self,
        vocab_dim=30522,
        embedding_dim=768,
        hidden_dim=256,
        output_dim=1024,
        embedding_init=None,
    ):
        super().__init__()
        # Token id 0 is registered as the embedding's padding index.
        self.embedding = torch.nn.Embedding(
            vocab_dim,
            embedding_dim,
            padding_idx=0,
            _weight=embedding_init,
        )
        bidirectional_lstm = torch.nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            bidirectional=True,
            batch_first=True,
        )
        self.lstm = PytorchSeq2VecWrapper(bidirectional_lstm)
        encoder_width = 2 * hidden_dim  # forward + backward directions
        self.linear = FeedForward(
            input_dim=encoder_width,
            num_layers=1,
            hidden_dims=[output_dim],
            activations=[torch.nn.Tanh()],
        )

    def forward(self, inputs, masks):
        """Embed ``inputs``, encode them with the masked BiLSTM, then project.

        ``inputs`` are token ids and ``masks`` the matching sequence mask; the
        result is presumably ``(batch, output_dim)`` -- TODO confirm.
        """
        embedded = self.embedding(inputs)
        encoded = self.lstm(embedded, masks)
        return self.linear(encoded)
class OneShotLearner(torch.nn.Module):
    """Predict updates for a chosen subset of ``model``'s parameters.

    A shared ``LSTMConditioner`` turns the input sequence into one condition
    vector.  Each selected parameter has its own ``ConditionedParameter``, which
    maps that vector plus the parameter's gradient to an update tensor.

    Args:
        model: Module whose ``named_parameters()`` are scanned.
        vocab_dim: Embedding-table size forwarded to the LSTM encoder.
        embedding_dim: Token-embedding width forwarded to the LSTM encoder.
        hidden_dim: Used both as the conditioners' MLP width and as the LSTM
            hidden size.
        condition_dim: Width of the condition vector shared by all conditioners.
        include_set: Container of parameter names to condition (only membership
            tests are used). Defaults to empty; ``None`` is treated as empty.
        max_scale: Forwarded to every ``ConditionedParameter``.
        embedding_init: Optional initial embedding weights for the encoder.
    """

    def __init__(
        self,
        model,
        vocab_dim=30522,
        embedding_dim=768,
        hidden_dim=128,
        condition_dim=1024,
        include_set=frozenset(),
        max_scale=1e-3,
        embedding_init=None,
    ):
        super().__init__()
        if include_set is None:
            include_set = frozenset()
        # Scan the model once so the name map and the ModuleDict are built from
        # the same ordered selection.
        selected = [
            (name, param)
            for name, param in model.named_parameters()
            if name in include_set
        ]
        # ModuleDict keys may not contain ".", so dotted parameter names are
        # flattened into "<name>_conditioner" with dots replaced by underscores.
        self.param2conditioner_map = {
            name: "{}_conditioner".format(name).replace(".", "_")
            for name, _ in selected
        }
        self.conditioners = torch.nn.ModuleDict(
            {
                self.param2conditioner_map[name]: ConditionedParameter(
                    param,
                    condition_dim,
                    hidden_dim,
                    max_scale=max_scale,
                )
                for name, param in selected
            }
        )
        self.condition = LSTMConditioner(
            vocab_dim,
            embedding_dim,
            hidden_dim,
            condition_dim,
            embedding_init=embedding_init,
        )

    def forward(self, inputs, masks, grads=None):
        """Return ``{param_name: update}`` for every selected parameter.

        Args:
            inputs: Token ids forwarded to the LSTM encoder.
            masks: Sequence mask forwarded to the LSTM encoder.
            grads: Mapping from parameter name to its gradient. When empty or
                None, each conditioner receives ``grad=None``; since
                ``ConditionedParameter.forward`` multiplies by ``grad``, gradients
                must in practice be supplied whenever any parameter is selected.
        """
        # The LSTM encoder produces the shared condition vector.
        condition = self.condition(inputs, masks)
        return {
            name: self.conditioners[key](
                condition,
                grad=grads[name] if grads else None,
            )
            for name, key in self.param2conditioner_map.items()
        }