"""Score attention heads by how much attention they pay to the key span [35, 611).

For every layer the script loads attention maps saved as
``weights{shard}/{layer_idx}.pth`` (``NUM_SHARDS`` shard directories) and
concatenates them along dim 0. For each of the first ``NUM_SAMPLES`` rows it
computes, per head, the ratio of attention mass on keys
``[SYS_LEN, SYS_LEN + VIS_LEN)`` to the mass on all remaining keys.

A head "passes" for a sample when that ratio exceeds ``HEAD_THRESHOLD``.
Every head passes in the first ``FULL_LAYERS`` layers. Heads that pass in
more than ``VOTE_THRESHOLD`` samples are reported as kept.

The span names (system prefix / visual / text) follow the original variable
names ``res_s`` / ``res_v`` / ``res_t``. The 35 + 576 layout is presumably a
35-token prompt prefix followed by 576 image tokens -- TODO confirm.
"""
import numpy as np
import torch

NUM_LAYERS = 32
NUM_HEADS = 32  # must match dim 1 of the saved attention tensors
NUM_SHARDS = 4  # weights0/ .. weights3/
NUM_SAMPLES = 100  # rows (after concatenation) scored per layer

SYS_LEN = 35  # keys [0, 35)
VIS_LEN = 576  # keys [35, 611); all keys from 611 onward form the "text" span

# Ratio threshold for layers >= FULL_LAYERS. Values tried in earlier
# experiments (from the original comments): headcut5: 0.35, headcut8: 0.38,
# headcut16: 0.11, headcut20: 0.06, headcut24: 0.044, headcut30: 0.03.
HEAD_THRESHOLD = 0.38
FULL_LAYERS = 2  # layers below this keep every head

# A head is kept if it passes in more than this many of NUM_SAMPLES samples.
# The original comment "#headcut5 10 30" hints other values were tried -- TODO.
VOTE_THRESHOLD = 40


def load_layer_weights(layer_idx, num_shards=NUM_SHARDS):
    """Load and concatenate (along dim 0) every shard's tensor for one layer."""
    shards = [
        torch.load(f"weights{shard}/{layer_idx}.pth") for shard in range(num_shards)
    ]
    return torch.cat(shards, dim=0)


def visual_attention_ratio(attn_weights, sys_len=SYS_LEN, vis_len=VIS_LEN):
    """Return visual-span attention mass divided by all other attention mass.

    The last two dims of ``attn_weights`` are summed. For a
    ``(samples, heads, queries, keys)`` tensor the result is
    ``(samples, heads)``. A single ``(heads, queries, keys)`` sample gives
    ``(heads,)``.
    """
    vis_end = sys_len + vis_len
    res_v = attn_weights[..., sys_len:vis_end].sum(dim=(-2, -1))
    res_t = attn_weights[..., vis_end:].sum(dim=(-2, -1))
    res_s = attn_weights[..., :sys_len].sum(dim=(-2, -1))
    return res_v / (res_t + res_s)


def head_mask(ratio, layer_idx, threshold=HEAD_THRESHOLD, full_layers=FULL_LAYERS):
    """Return a boolean mask of heads that pass for ``layer_idx``.

    Layers below ``full_layers`` use ``ratio >= 0``, which keeps every head
    whose ratio is not NaN. Later layers use ``ratio > threshold``.
    """
    if layer_idx >= full_layers:
        return ratio > threshold
    return ratio >= 0


def count_kept_heads(
    num_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
    num_samples=NUM_SAMPLES,
    loader=load_layer_weights,
    threshold=HEAD_THRESHOLD,
    full_layers=FULL_LAYERS,
):
    """Count, per (layer, head), how many samples pass the ratio threshold.

    Args:
        num_layers: number of layers to load, indexed ``0 .. num_layers-1``.
        num_heads: size of dim 1 of each loaded tensor.
        num_samples: number of leading rows of each layer's tensor to score.
        loader: callable ``layer_idx -> tensor`` (default reads from disk).
        threshold, full_layers: forwarded to :func:`head_mask`.

    Returns:
        Float ``np.ndarray`` of shape ``(num_layers, num_heads)``.

    Raises:
        ValueError: if a layer's tensor has fewer than ``num_samples`` rows.
    """
    mask_total = np.zeros((num_layers, num_heads))
    for layer_idx in range(num_layers):
        attn_weights = loader(layer_idx)
        available = attn_weights.shape[0]
        if available < num_samples:
            raise ValueError(
                f"layer {layer_idx}: expected at least {num_samples} samples, "
                f"found {available}"
            )
        ratio = visual_attention_ratio(attn_weights[:num_samples])
        mask = head_mask(ratio, layer_idx, threshold, full_layers).int()
        # Summing over samples equals accumulating each sample's mask in turn.
        mask_total[layer_idx] = mask.sum(dim=0).detach().cpu().numpy()
    return mask_total


def main():
    """Run the analysis over the saved weights and print the kept-head mask."""
    mask_total = count_kept_heads()
    keep = mask_total > VOTE_THRESHOLD
    print(keep)
    print(f"{int(keep.sum())} of {keep.size} heads kept")


if __name__ == "__main__":
    main()