Faster mask_human_targets implementation

#13
by Godricly - opened

I tried to implement a faster version of `mask_human_targets`. Could anyone help me verify it?

        # Vectorized mask_human_targets: produces `target_copy`, a copy of the
        # 2-D `input_ids` (batch, seq) where positions excluded from the loss
        # are set to -100. No Python-level loops over token positions.
        # NOTE(review): 92542 is presumably the turn-end marker and 2 the EOS
        # id — confirm against the tokenizer.
        eoa_id, eos_id, ignore_index = 92542, 2, -100
        bsz, seq_len = input_ids.shape
        device = input_ids.device
        positions = torch.arange(seq_len, device=device).expand(bsz, seq_len)

        is_eoa = input_ids == eoa_id
        eoa_hits = is_eoa.long()
        # Markers strictly before each position (exclusive cumsum). Even means
        # a marker at that position is the 0th, 2nd, 4th, ... one in its row.
        eoa_before = eoa_hits.cumsum(dim=1) - eoa_hits
        even_before = eoa_before % 2 == 0

        # Mark position i + 1 for every odd-numbered (1st, 3rd, ...) marker i.
        prev_is_eoa = torch.zeros_like(is_eoa)
        prev_is_eoa[:, 1:] = is_eoa[:, :-1]
        span_start_marks = prev_is_eoa & even_before
        # Span start = *position* of the most recent mark (0 if none yet).
        # A plain cumsum here would yield the number of marks seen so far,
        # which is not a valid slice start.
        last_eoa = torch.where(
            span_start_marks, positions, torch.zeros_like(positions)
        ).cummax(dim=1).values

        # Each even-numbered marker at i masks [last_eoa[i], i + 6) (the slice
        # end is clamped to seq_len, as Python slicing would). Build the union
        # of those spans with a difference array: +1 at each start, -1 at each
        # end, then a running sum > 0 means "inside some span".
        triggers = (is_eoa & even_before).long()
        span_end = (positions + 6).clamp(max=seq_len)
        delta = torch.zeros(bsz, seq_len + 1, dtype=torch.long, device=device)
        delta.scatter_add_(1, last_eoa, triggers)
        delta.scatter_add_(1, span_end, -triggers)
        human_mask = delta.cumsum(dim=1)[:, :seq_len] > 0

        # Mask everything strictly after any eos_id token. This is equivalent to
        # `target_copy[b, i + 1:] = -100` for every EOS at i. A single pass is
        # enough because every write sets -100. As in the loop version, markers
        # after an EOS are still processed above.
        eos_hits = (input_ids == eos_id).long()
        after_eos = (eos_hits.cumsum(dim=1) - eos_hits) > 0

        target_copy = input_ids.masked_fill(human_mask | after_eos, ignore_index)

Sign up or log in to comment