Spaces:
Running
Running
File size: 937 Bytes
6d1366a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
# Copyright (c) OpenMMLab. All rights reserved.
def nlc_to_nchw(x, hw_shape):
    """Reshape a sequence-layout tensor [N, L, C] into image layout [N, C, H, W].

    Args:
        x (Tensor): Input tensor of shape [N, L, C] before conversion.
        hw_shape (Sequence[int]): Height and width ``(H, W)`` of the output
            feature map; ``H * W`` must equal ``L``.

    Returns:
        Tensor: Output tensor of shape [N, C, H, W] after conversion.

    Raises:
        AssertionError: If ``x`` is not 3-dimensional, or if its sequence
            length does not equal ``H * W``.
    """
    height, width = hw_shape
    assert x.ndim == 3
    batch, seq_len, channels = x.shape
    assert seq_len == height * width, "The seq_len doesn't match H, W"
    # Move channels ahead of the sequence axis, then unfold the sequence
    # axis into the two spatial axes.
    channels_first = x.transpose(1, 2)
    return channels_first.reshape(batch, channels, height, width)
def nchw_to_nlc(x):
    """Flatten an image-layout tensor [N, C, H, W] into sequence layout [N, L, C].

    Args:
        x (Tensor): Input tensor of shape [N, C, H, W] before conversion.

    Returns:
        Tensor: Contiguous output tensor of shape [N, L, C] after conversion,
            where ``L = H * W``.

    Raises:
        AssertionError: If ``x`` is not 4-dimensional.
    """
    assert x.ndim == 4
    # Collapse the two spatial axes into one sequence axis: [N, C, H*W].
    flattened = x.flatten(start_dim=2)
    # Swap to [N, L, C] and materialise the result in contiguous memory.
    return flattened.transpose(1, 2).contiguous()
|