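"""
Process launcher for multi-rank ColBERT jobs: spawns one worker process per
rank via torch.multiprocessing, wires up the distributed environment, runs a
callee under a Run context, and collects per-rank return values via a queue.
"""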
import os
import time
import torch
import random
import torch.multiprocessing as mp
import numpy as np
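
# CUDA state cannot be shared across processes created with the default 'fork'
# start method, so force 'spawn'; the except guards environments where the
# start method is already fixed and cannot be changed.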
try:
    mp.set_start_method('spawn', force=True)
except RuntimeError:
    pass
import colbert.utils.distributed as distributed
from colbert.infra.run import Run
from colbert.infra.config import BaseConfig, RunConfig, RunSettings
from colbert.utils.utils import print_message
class Launcher:
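    """
    Spawns `nranks` worker processes that each invoke `callee(config, *args)`,
    and returns their results (all of them if `return_all=True`, else rank 0's).
    """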
    def __init__(self, callee, run_config=None, return_all=False):
        self.callee = callee
        self.return_all = return_all

        self.run_config = RunConfig.from_existing(Run().config, run_config)
        self.nranks = self.run_config.nranks
    def launch(self, custom_config, *args):
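        """
        Spawn one process per rank, each with its own RunConfig(rank=...), and
        block until every worker has pushed its return value and joined.
        """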
        return_value_queue = mp.Queue()

        rng = random.Random(time.time())
        port = str(12355 + rng.randint(0, 1000))  # Randomize the port to avoid collisions when launching several jobs.

        all_procs = []
        for new_rank in range(0, self.nranks):
            assert isinstance(custom_config, BaseConfig)
            assert isinstance(custom_config, RunSettings)

            new_config = type(custom_config).from_existing(custom_config, self.run_config, RunConfig(rank=new_rank))

            args_ = (self.callee, port, return_value_queue, new_config, *args)
            all_procs.append(mp.Process(target=setup_new_process, args=args_))
        # Clear GPU space (e.g., after a `Searcher` on GPU-0 is deleted).
        # TODO: Generalize this beyond GPU-0!
        # TODO: Move this to a function, and call it from __del__ in a class inherited by Searcher, Indexer, etc.

        # t = torch.cuda.get_device_properties(0).total_memory
        # r = torch.cuda.memory_reserved(0)
        # a = torch.cuda.memory_allocated(0)
        # f = r - a
        # print_message(f"[Pre-Emptying] GPU memory check: r={r}, a={a}, f={f}")

        torch.cuda.empty_cache()

        # t = torch.cuda.get_device_properties(0).total_memory
        # r = torch.cuda.memory_reserved(0)
        # a = torch.cuda.memory_allocated(0)
        # f = r - a
        # print_message(f"[Post-Emptying] GPU memory check: r={r}, a={a}, f={f}")

        print_memory_stats('MAIN')
        for proc in all_procs:
            print("#> Starting...")
            proc.start()

        print_memory_stats('MAIN')

        # TODO: If the processes crash upon join, raise an exception and don't block on .get() below!
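        # Each worker pushes a (rank, return_value) pair; sorting restores rank order.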
        return_values = sorted([return_value_queue.get() for _ in all_procs])
        return_values = [val for rank, val in return_values]

        if not self.return_all:
            return_values = return_values[0]

        for proc in all_procs:
            proc.join()
            print("#> Joined...")

        print_memory_stats('MAIN')

        return return_values

def setup_new_process(callee, port, return_value_queue, config, *args):
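    """
    Entry point for each worker: seeds all RNGs, sets the torch.distributed
    environment variables (MASTER_ADDR/PORT, WORLD_SIZE, RANK), restricts
    CUDA_VISIBLE_DEVICES to this job's GPUs, then runs `callee` under the given
    config and reports its (rank, return_value) through the queue.
    """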
    print_memory_stats()

    random.seed(12345)
    np.random.seed(12345)
    torch.manual_seed(12345)
    torch.cuda.manual_seed_all(12345)

    rank, nranks = config.rank, config.nranks

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = port
    os.environ["WORLD_SIZE"] = str(config.nranks)
    os.environ["RANK"] = str(config.rank)

    # TODO: Ideally the gpus "getter" handles this max-nranks thing!
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, config.gpus_[:nranks]))

    nranks_, distributed_ = distributed.init(rank)
    assert nranks_ == nranks

    # Run.init(args.rank, args.root, args.experiment, args.run)

    with Run().context(config, inherit_config=False):
        return_val = callee(config, *args)

    return_value_queue.put((rank, return_val))

def print_memory_stats(message=''):
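    """Print host-memory stats (via psutil) for debugging; currently disabled."""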
    return  # FIXME: Add this back before release.

    import psutil  # Remove before releases? Or at least make optional with try/except.

    global_info = psutil.virtual_memory()
    total, available, used, free = global_info.total, global_info.available, global_info.used, global_info.free

    info = psutil.Process().memory_info()
    rss, vms, shared = info.rss, info.vms, info.shared
    uss = psutil.Process().memory_full_info().uss

    gib = 1024 ** 3

    summary = f"""
    [PID: {os.getpid()}]
    [{message}]
    Available: {available / gib:,.1f} / {total / gib:,.1f}
    Free: {free / gib:,.1f} / {total / gib:,.1f}
    Usage: {used / gib:,.1f} / {total / gib:,.1f}
    RSS: {rss / gib:,.1f}
    VMS: {vms / gib:,.1f}
    USS: {uss / gib:,.1f}
    SHARED: {shared / gib:,.1f}
    """.strip().replace('\n', '\t')

    print_message(summary, pad=True)
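
# A minimal usage sketch. `my_task` is a hypothetical name standing for any
# picklable, module-level function with signature `my_task(config, *args)`;
# the config passed to launch() must satisfy both BaseConfig and RunSettings
# (e.g., ColBERTConfig):
#
#   from colbert.infra import Run, RunConfig, ColBERTConfig
#   from colbert.infra.launcher import Launcher
#
#   def my_task(config, collection):
#       ...  # runs once per rank; its return value is sent back to the launcher
#
#   with Run().context(RunConfig(nranks=2)):
#       launcher = Launcher(my_task, return_all=True)
#       per_rank_values = launcher.launch(ColBERTConfig(), collection)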