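"""Parallel state manager for VideoSys.

Builds data-parallel (dp), classifier-free-guidance-parallel (cp), and
sequence-parallel (sp) process groups on top of ColossalAI's ProcessGroupMesh,
and exposes module-level accessors for the resulting singleton.
"""
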
from typing import Optional

import torch
import torch.distributed as dist
from colossalai.cluster.process_group_mesh import ProcessGroupMesh
from torch.distributed import ProcessGroup

from videosys.utils.logging import init_dist_logger, logger
from videosys.utils.utils import set_seed

PARALLEL_MANAGER: Optional["ParallelManager"] = None


class ParallelManager(ProcessGroupMesh):
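    """A three-axis process-group mesh: data parallel (dp), cfg parallel (cp), sequence parallel (sp)."""
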
    def __init__(self, dp_size, cp_size, sp_size):
        super().__init__(dp_size, cp_size, sp_size)
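        # mesh axes are laid out as (dp, cp, sp); each group below is sliced along its own axis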
        dp_axis, cp_axis, sp_axis = 0, 1, 2

        self.dp_size = dp_size
        self.dp_group: ProcessGroup = self.get_group_along_axis(dp_axis)
        self.dp_rank = dist.get_rank(self.dp_group)

        self.cp_size = cp_size
        self.cp_group: ProcessGroup = self.get_group_along_axis(cp_axis)
        self.cp_rank = dist.get_rank(self.cp_group)

        self.sp_size = sp_size
        self.sp_group: ProcessGroup = self.get_group_along_axis(sp_axis)
        self.sp_rank = dist.get_rank(self.sp_group)
        self.enable_sp = sp_size > 1

        logger.info(f"Init parallel manager with dp_size: {dp_size}, cp_size: {cp_size}, sp_size: {sp_size}")


def set_parallel_manager(dp_size, cp_size, sp_size):
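    # build the mesh once and cache it as the module-level singleton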
    global PARALLEL_MANAGER
    PARALLEL_MANAGER = ParallelManager(dp_size, cp_size, sp_size)


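# The getters below assume set_parallel_manager()/initialize() has already run and
# raise AttributeError otherwise; enable_sequence_parallel() tolerates the None case.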
def get_data_parallel_group():
    return PARALLEL_MANAGER.dp_group


def get_data_parallel_size():
    return PARALLEL_MANAGER.dp_size


def get_data_parallel_rank():
    return PARALLEL_MANAGER.dp_rank


def get_sequence_parallel_group():
    return PARALLEL_MANAGER.sp_group


def get_sequence_parallel_size():
    return PARALLEL_MANAGER.sp_size


def get_sequence_parallel_rank():
    return PARALLEL_MANAGER.sp_rank


def get_cfg_parallel_group():
    return PARALLEL_MANAGER.cp_group


def get_cfg_parallel_size():
    return PARALLEL_MANAGER.cp_size


def enable_sequence_parallel():
    if PARALLEL_MANAGER is None:
        return False
    return PARALLEL_MANAGER.enable_sp


def get_parallel_manager():
    return PARALLEL_MANAGER


def initialize(
    rank=0,
    world_size=1,
    init_method=None,
    seed: Optional[int] = None,
    sp_size: Optional[int] = None,
    enable_cp: bool = True,
):
    if not dist.is_initialized():
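        # defensive cleanup of any stale process group; destroy_process_group raises
        # when no group exists (always the case on this branch), so it is swallowed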
        try:
            dist.destroy_process_group()
        except Exception:
            pass
        dist.init_process_group(backend="nccl", init_method=init_method, world_size=world_size, rank=rank)
        torch.cuda.set_device(rank)
        init_dist_logger()
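        # trade a little precision for speed: allow TF32 on Ampere+ tensor cores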
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # init sequence parallel
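    # default: one sequence-parallel group spanning all ranks, no data parallelism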
    if sp_size is None:
        sp_size = dist.get_world_size()
        dp_size = 1
    else:
        assert dist.get_world_size() % sp_size == 0, f"world_size {dist.get_world_size()} must be divisible by sp_size {sp_size}"
        dp_size = dist.get_world_size() // sp_size

    # update cfg parallel
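    # classifier-free guidance runs the conditional and unconditional branches on two
    # rank subsets, so borrow a factor of 2 from the sp axis when it divides evenly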
    if enable_cp and sp_size % 2 == 0:
        sp_size = sp_size // 2
        cp_size = 2
    else:
        cp_size = 1

    set_parallel_manager(dp_size, cp_size, sp_size)

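    # offset the seed by dp rank: replicas sample independently, while ranks that
    # share a dp rank (i.e. within one sp/cp group) stay in lockstep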
    if seed is not None:
        set_seed(seed + get_data_parallel_rank())
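

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module); assumes a
    # single-node GPU launch where RANK/WORLD_SIZE/MASTER_ADDR come from torchrun:
    #   torchrun --nproc_per_node=4 parallel_mgr.py
    import os

    _rank = int(os.environ.get("RANK", 0))
    _world_size = int(os.environ.get("WORLD_SIZE", 1))
    initialize(rank=_rank, world_size=_world_size, init_method="env://", sp_size=_world_size, seed=0)
    mgr = get_parallel_manager()
    logger.info(
        f"rank {dist.get_rank()}: dp={mgr.dp_size} cp={mgr.cp_size} sp={mgr.sp_size} "
        f"sp_enabled={enable_sequence_parallel()}"
    )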