Faster mask_human_targets implementation
#13
by
Godricly
- opened
I tried to implement a faster version of `mask_human_targets`. Could anyone help me verify it?
# Build `target_copy`: a copy of `input_ids` in which positions that should not
# contribute to the loss are overwritten with -100.
#
# Expects `input_ids` to be a 2-D integer tensor (it is indexed as [row, col]).
# Token id 92542 is treated as a turn separator and token id 2 as end-of-sequence
# -- presumably the model's end-of-turn and eos/pad ids; TODO confirm against the
# tokenizer.
#
# NOTE(review): rows that contain no token 2 get no special handling here;
# confirm against the reference mask_human_targets whether such rows need an
# extra masking step.
target_copy = input_ids.clone()
positions = torch.arange(input_ids.shape[1], device=input_ids.device)

# Boolean mask of separator tokens, and for each position the number of
# separators strictly before it (exclusive cumulative count).
inds_92542 = input_ids == 92542
end_count = inds_92542.cumsum(dim=1) - inds_92542.long()

# Separators preceded by an even number of separators trigger masking; the
# odd-numbered ones mark where the next masked span starts.
end_count_mask = end_count % 2 == 0
span_closers = inds_92542 & end_count_mask
span_openers = inds_92542 & ~end_count_mask

# For every position, the index just after the most recent odd-numbered
# separator (0 if there is none yet).  This has to be a running maximum of
# positions -- a running sum of flags would yield a count, not an index, and
# the slice below would then start far too early.
last_eoa = (span_openers * (positions + 1)).cummax(dim=1).values

# Mask each span from its start through the closing separator plus the five
# tokens after it (the `+ 6` -- presumably the header of the following turn;
# TODO confirm).  Only one Python-level iteration per masked span remains.
rows, cols = span_closers.nonzero(as_tuple=True)
starts = last_eoa[rows, cols]
for row, start, col in zip(rows.tolist(), starts.tolist(), cols.tolist()):
    target_copy[row, start:col + 6] = -100

# Mask everything strictly after the first token 2 in each row.  The token 2
# itself is left untouched.  Done as one vectorized write instead of one slice
# assignment per matching token.
is_eos = input_ids == 2
after_eos = (is_eos.cumsum(dim=1) - is_eos.long()) > 0
target_copy[after_eos] = -100