File size: 812 Bytes
aa5ee46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
import numpy as np

def run_func_in_parts(func, vid_emb, aud_emb, part_len, dim, device):
    """
    Apply func piecewise to matching slices of vid_emb and aud_emb.

    Both inputs are split along dimension dim using part_len; every pair
    of slices is moved to device and passed to func, and the per-slice
    results are concatenated along dimension dim - 1.
    Intended for inputs that are too large to process on the GPU at once.
    """
    vid_parts = vid_emb.split(part_len, dim=dim)
    aud_parts = aud_emb.split(part_len, dim=dim)
    outputs = [
        func(vid_part.to(device), aud_part.to(device))
        for vid_part, aud_part in zip(vid_parts, aud_parts)
    ]
    # Results are joined on dim - 1, not dim (kept exactly as in the original).
    return torch.cat(outputs, dim - 1)

def logsoftmax_2d(logits):
    """
    Log-softmax taken jointly over the last two dimensions of logits.

    The trailing two dimensions are flattened into one, log-softmax is
    applied along that flattened axis, and the result is reshaped back to
    the shape of the input.

    Args:
        logits: tensor whose trailing two dimensions are normalised together.

    Returns:
        Tensor of log-probabilities with the same shape as logits.
    """
    orig_shape = logits.shape
    # Compute the flattened size explicitly instead of using -1: reshape
    # cannot infer -1 when the tensor has zero elements (e.g. empty batch).
    flat_size = orig_shape[-2:].numel()
    flat_logits = logits.reshape(list(orig_shape[:-2]) + [flat_size])
    logprobs = torch.nn.functional.log_softmax(flat_logits, dim=-1)
    return logprobs.reshape(orig_shape)