zhaoyian01's picture
Add application file
6d1366a
raw
history blame
937 Bytes
# Copyright (c) OpenMMLab. All rights reserved.
def nlc_to_nchw(x, hw_shape):
    """Rearrange a sequence-layout tensor [N, L, C] into image layout [N, C, H, W].

    Args:
        x (Tensor): Input tensor of shape [N, L, C] before conversion.
            The sequence length L must equal H * W.
        hw_shape (Sequence[int]): The (height, width) of the output
            feature map.

    Returns:
        Tensor: The output tensor of shape [N, C, H, W] after conversion.

    Raises:
        AssertionError: If ``x`` is not 3-dimensional or if its sequence
            length does not equal ``H * W``.
    """
    height, width = hw_shape
    dims = x.shape
    assert len(dims) == 3
    batch, seq_len, channels = dims
    assert seq_len == height * width, 'The seq_len doesn\'t match H, W'
    # Move channels ahead of the sequence axis, then unfold the sequence
    # axis into the two spatial axes.
    channels_first = x.transpose(1, 2)
    return channels_first.reshape(batch, channels, height, width)
def nchw_to_nlc(x):
    """Flatten an image-layout tensor [N, C, H, W] into sequence layout [N, L, C].

    Args:
        x (Tensor): Input tensor of shape [N, C, H, W] before conversion.

    Returns:
        Tensor: The contiguous output tensor of shape [N, L, C] after
            conversion, where L = H * W.

    Raises:
        AssertionError: If ``x`` is not 4-dimensional.
    """
    assert len(x.shape) == 4
    # Collapse the two spatial axes into a single sequence axis: [N, C, L].
    spatial_flat = x.flatten(2)
    # Put the sequence axis before the channel axis: [N, L, C].
    sequence_first = spatial_flat.transpose(1, 2)
    return sequence_first.contiguous()