import os
import time
import torch
import random

import torch.multiprocessing as mp
import numpy as np

# Use 'spawn' so each worker process can initialize its own CUDA context
# (CUDA cannot be safely re-initialized in a forked child).
try:
    mp.set_start_method('spawn', force=True)
except RuntimeError:
    pass

import colbert.utils.distributed as distributed

from colbert.infra.run import Run
from colbert.infra.config import BaseConfig, RunConfig, RunSettings

from colbert.utils.utils import print_message


class Launcher:
    """
    Spawns one worker process per rank via torch.multiprocessing, runs `callee`
    in each, and gathers the per-rank return values through a shared queue.
    """

    def __init__(self, callee, run_config=None, return_all=False):
        self.callee = callee
        self.return_all = return_all

        self.run_config = RunConfig.from_existing(Run().config, run_config)
        self.nranks = self.run_config.nranks

    def launch(self, custom_config, *args):
        return_value_queue = mp.Queue()

        rng = random.Random(time.time())
        port = str(12355 + rng.randint(0, 1000))  # randomize the port to avoid collision on launching several jobs.

        all_procs = []
        for new_rank in range(0, self.nranks):
            assert isinstance(custom_config, BaseConfig)
            assert isinstance(custom_config, RunSettings)

            new_config = type(custom_config).from_existing(custom_config, self.run_config, RunConfig(rank=new_rank))

            args_ = (self.callee, port, return_value_queue, new_config, *args)
            all_procs.append(mp.Process(target=setup_new_process, args=args_))

        # Clear GPU space (e.g., after a `Searcher` on GPU-0 is deleted)
        # TODO: Generalize this from GPU-0 only!
        # TODO: Move this to a function. And call that function from __del__ in a class that's inherited by Searcher, Indexer, etc.

        # t = torch.cuda.get_device_properties(0).total_memory
        # r = torch.cuda.memory_reserved(0)
        # a = torch.cuda.memory_allocated(0)
        # f = r-a

        # print_message(f"[Pre-Emptying] GPU memory check: r={r}, a={a}, f={f}")

        torch.cuda.empty_cache()

        # t = torch.cuda.get_device_properties(0).total_memory
        # r = torch.cuda.memory_reserved(0)
        # a = torch.cuda.memory_allocated(0)
        # f = r-a

        # print_message(f"[Post-Emptying] GPU memory check: r={r}, a={a}, f={f}")

        print_memory_stats('MAIN')

        for proc in all_procs:
            print("#> Starting...")
            proc.start()

        print_memory_stats('MAIN')

        # TODO: If the processes crash upon join, raise an exception and don't block on .get() below!

        # Drain the queue before joining (joining first can deadlock while queued results
        # are unflushed); results arrive as (rank, value) pairs, so sort to restore rank order.
        return_values = sorted([return_value_queue.get() for _ in all_procs])
        return_values = [val for rank, val in return_values]

        if not self.return_all:
            return_values = return_values[0]
        
        for proc in all_procs:
            proc.join()
            print("#> Joined...")

        print_memory_stats('MAIN')
        
        return return_values


def setup_new_process(callee, port, return_value_queue, config, *args):
    print_memory_stats()

    random.seed(12345)
    np.random.seed(12345)
    torch.manual_seed(12345)
    torch.cuda.manual_seed_all(12345)

    rank, nranks = config.rank, config.nranks

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = port
    os.environ["WORLD_SIZE"] = str(config.nranks)
    os.environ["RANK"] = str(config.rank)

    # TODO: Ideally the gpus "getter" handles this max-nranks thing!
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, config.gpus_[:nranks]))

    nranks_, distributed_ = distributed.init(rank)
    assert nranks_ == nranks

    # Run.init(args.rank, args.root, args.experiment, args.run)

    with Run().context(config, inherit_config=False):
        return_val = callee(config, *args)

    return_value_queue.put((rank, return_val))


def print_memory_stats(message=''):
    return  # FIXME: Add this back before release.

    import psutil  # Remove before releases? Or at least make optional with try/except.

    global_info = psutil.virtual_memory()
    total, available, used, free = global_info.total, global_info.available, global_info.used, global_info.free

    info = psutil.Process().memory_info()
    rss, vms, shared = info.rss, info.vms, info.shared
    uss = psutil.Process().memory_full_info().uss

    gib = 1024 ** 3

    summary = f"""
    "[PID: {os.getpid()}]
    [{message}]
    Available: {available / gib:,.1f} / {total / gib:,.1f}
    Free: {free / gib:,.1f} / {total / gib:,.1f}
    Usage: {used / gib:,.1f} / {total / gib:,.1f}

    RSS: {rss / gib:,.1f}
    VMS: {vms / gib:,.1f}
    USS: {uss / gib:,.1f}
    SHARED: {shared / gib:,.1f}
    """.strip().replace('\n', '\t')

    print_message(summary, pad=True)
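

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). The callee
# `index_shard`, the 'demo' experiment name, and the collection path below are
# hypothetical; Run, RunConfig, and ColBERTConfig are the classes exported by
# colbert.infra. Launcher calls the callee once per rank inside a spawned
# process and returns rank 0's result unless return_all=True.
#
#     from colbert.infra import Run, RunConfig, ColBERTConfig
#
#     def index_shard(config, collection_path):
#         # ... per-rank work; the returned value is sent back to the parent
#         return f"rank {config.rank} processed {collection_path}"
#
#     if __name__ == '__main__':
#         with Run().context(RunConfig(nranks=2, experiment='demo')):
#             launcher = Launcher(index_shard, return_all=True)
#             results = launcher.launch(ColBERTConfig(), '/path/to/collection.tsv')
# ---------------------------------------------------------------------------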