Spaces:
Runtime error
Runtime error
import torch | |
class MomentumOptim: | |
def __init__(self, step_size=0.01, momentum=0.9): | |
self.step_size = step_size | |
self.momentum = momentum | |
self.m = None # velocity | |
def init(self): | |
self.m = None | |
def upd_m(self, old_m, g): | |
return g + self.momentum * old_m | |
def upd(self, old_x, m): | |
return old_x + self.step_size * m | |
def __call__(self, old_xs, new_xs): | |
pesudo_gs = [new_x - old_x for old_x, new_x in zip(old_xs, new_xs)] | |
if not self.m: | |
self.m = pesudo_gs | |
else: | |
self.m = [self.upd_m(old_m, g) for old_m, g in zip(self.m, pesudo_gs)] | |
updated_kv = [self.upd(old_x, m) for old_x, m in zip(old_xs, self.m)] | |
return updated_kv | |
class AttnOptimWrapper: | |
def __init__(self, llm, model_type, optimizer="momentum", **optimizer_args): | |
self.model = llm | |
self.kv = None | |
self.model_type = model_type | |
if optimizer == "momentum": | |
self.optim_k = MomentumOptim(**optimizer_args) | |
self.optim_v = MomentumOptim(**optimizer_args) | |
else: | |
raise ValueError() | |
def init(self): | |
self.optim_k.init() | |
self.optim_v.init() | |
def step(self, ctx_ids): | |
L = len(ctx_ids) | |
ctx_ids = ctx_ids.unsqueeze(0) # [1, L] | |
mask = torch.ones_like(ctx_ids) | |
if self.kv is not None: | |
mask = mask.repeat(1, 2) # [1, 2*L] | |
next_kv = self.model( | |
input_ids=ctx_ids, | |
attention_mask=mask, | |
past_key_values=self.kv, | |
use_cache=True, | |
).past_key_values # kv @ (old_ctx + new_ctx) | |
cur_kv = [] | |
for layer_k, layer_v in next_kv: | |
# [B, num_head, 2*L, head_hidden] | |
cur_kv.append([layer_k[:, :, -L:, :], layer_v[:, :, -L:, :]]) # kv @ (new_ctx) | |
if not self.kv: | |
self.kv = cur_kv | |
else: | |
old_ks, old_vs = zip(*self.kv) | |
cur_ks, cur_vs = zip(*cur_kv) | |
upd_ks = self.optim_k(old_ks, cur_ks) | |
upd_vs = self.optim_v(old_vs, cur_vs) | |
self.kv = list(zip(upd_ks, upd_vs)) | |
return self.kv | |