# hertz-dev/utils/dist.py (commit 4f90f1b)
import os
import torch as T
import re
from tqdm import tqdm
from datetime import timedelta
import requests
import hashlib
from io import BytesIO
from huggingface_hub import hf_hub_download
def rank0():
    """Return True if this process is global rank 0.

    A missing ``RANK`` environment variable is treated as a plain
    single-process run, which also counts as rank 0.
    """
    return os.environ.get('RANK') in (None, '0')
def local0():
    """Return True if this process is local rank 0 on its node.

    A missing ``LOCAL_RANK`` environment variable is treated as a plain
    single-process run, which also counts as local rank 0.
    """
    return os.environ.get('LOCAL_RANK') in (None, '0')
class tqdm0(tqdm):
    """tqdm progress bar that is only displayed on global rank 0.

    Defaults applied on top of tqdm's own:
      * ``bar_format``: a compact bar without the description/ETA fields.
      * ``miniters``: ``max(1, total // 20)`` when the total is known, to limit
        how often the bar redraws.
      * ``disable``: forced to True on every rank other than 0; on rank 0 it
        defaults to False.

    Callers may pass ``bar_format``, ``miniters`` or ``disable`` explicitly to
    override these defaults. Before, passing ``disable`` or ``bar_format``
    raised a duplicate-keyword TypeError. A caller's ``disable`` cannot
    re-enable the bar on non-zero ranks.
    """

    def __init__(self, *args, **kwargs):
        total = kwargs.get('total', None)
        if total is None and len(args) > 0:
            # The first positional argument is the iterable; it may not
            # support len() (e.g. a generator).
            try:
                total = len(args[0])
            except TypeError:
                pass
        if total is not None:
            kwargs.setdefault('miniters', max(1, total // 20))
        kwargs.setdefault('bar_format', '{bar}| {n_fmt}/{total_fmt} [{rate_fmt}{postfix}]')
        if rank0():
            kwargs.setdefault('disable', False)
        else:
            # Non-zero ranks always stay silent.
            kwargs['disable'] = True
        super().__init__(*args, **kwargs)
def print0(*args, **kwargs):
    """Call ``print(*args, **kwargs)`` only on global rank 0 (or when RANK is unset)."""
    if not rank0():
        return
    print(*args, **kwargs)
# Keys already emitted by printonce(). It lives at module level, so
# deduplication lasts for the lifetime of this process.
_PRINTED_IDS = set()


def printonce(*args, id=None, **kwargs):
    """Print ``args`` the first time a given key is seen, and never again.

    Args:
        *args: values forwarded to ``print``.
        id: deduplication key. When omitted, the key is the space-joined
            ``str`` of ``args``, so identical messages print once.
        **kwargs: forwarded to ``print`` (e.g. ``sep``, ``end``, ``file``).
    """
    key = ' '.join(str(arg) for arg in args) if id is None else id
    if key in _PRINTED_IDS:
        return
    print(*args, **kwargs)
    # Record only after a successful print.
    _PRINTED_IDS.add(key)
def print0once(*args, **kwargs):
    """Combine print0 and printonce: print once per key, on global rank 0 only."""
    if not rank0():
        return
    printonce(*args, **kwargs)
def init_dist():
    """Initialise (or attach to) the torch.distributed process group.

    If a process group already exists, it is reused. Otherwise the function
    reads ``RANK``, ``LOCAL_RANK`` and ``WORLD_SIZE`` from the environment,
    pins this process to ``cuda:<LOCAL_RANK>``, and creates an NCCL process
    group with a 30-minute timeout.

    It falls back to a single-process setup in two cases: the environment
    variables are missing or not integers, or group creation fails for a job
    whose ``WORLD_SIZE`` is 1.

    Returns:
        ``(rank, local_rank, world_size)`` as ints; ``(0, 0, 1)`` on fallback.

    Raises:
        RuntimeError: if ``WORLD_SIZE > 1`` but the process group could not be
            created. Falling back silently there would leave every worker
            believing it is rank 0 of a world of size 1.
    """
    if T.distributed.is_initialized():
        print0('Distributed already initialized')
        rank = T.distributed.get_rank()
        # The local rank is taken from the environment, defaulting to 0.
        local_rank = int(os.environ.get('LOCAL_RANK', 0))
        world_size = T.distributed.get_world_size()
        return rank, local_rank, world_size

    try:
        rank = int(os.environ['RANK'])
        local_rank = int(os.environ['LOCAL_RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
    except (KeyError, ValueError) as e:
        # Not launched by a distributed launcher: run as a single process.
        print0once(f'Not initializing distributed env: {e}')
        return 0, 0, 1

    try:
        device = f'cuda:{local_rank}'
        T.cuda.set_device(device)
        T.distributed.init_process_group(
            backend='nccl',
            timeout=timedelta(minutes=30),
            rank=rank,
            world_size=world_size,
            device_id=T.device(device),
        )
    except Exception as e:
        if world_size > 1:
            raise RuntimeError(
                f'Failed to initialize distributed env for rank {rank} of {world_size}: {e}'
            ) from e
        # A world of size 1 can safely continue without a process group.
        print0once(f'Not initializing distributed env: {e}')
        return 0, 0, 1

    print(f'Rank {rank} of {world_size}.')
    return rank, local_rank, world_size
def _md5_file(path, chunk_size=1 << 24):
    """Return the hex MD5 digest of the file at ``path``, read in ``chunk_size`` pieces.

    Streaming the file avoids holding a whole checkpoint in memory just to hash it.
    """
    digest = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


def _discard_cached_file(path):
    """Delete ``path`` and, if it resolves to a different file, that file too."""
    target = os.path.realpath(path)
    os.remove(path)
    # NOTE(review): hf_hub_download may return a symlink into a shared blob
    # store. Removing only the link would leave the bad blob in place, to be
    # re-linked (not re-downloaded) on the next call.
    if os.path.exists(target):
        os.remove(target)


def _download_verified_ckpt(load_from_location, expected_hash, max_attempts):
    """Download ``<load_from_location>.pt`` from the hub, optionally MD5-verified.

    Every attempt after the first forces a fresh download. Returns the local
    path of a file that matches ``expected_hash``, or of the first download
    when no hash is given.

    Raises:
        RuntimeError: if all ``max_attempts`` downloads had the wrong hash.
    """
    repo_id = "si-pbc/hertz-dev"
    attempts = max(1, max_attempts)
    for attempt in range(attempts):
        print0(f'Loading checkpoint from repo_id {repo_id} and filename {load_from_location}.pt. This may take a while...')
        save_path = hf_hub_download(
            repo_id=repo_id,
            filename=f"{load_from_location}.pt",
            force_download=attempt > 0,
        )
        print0(f'Downloaded checkpoint to {save_path}')
        if expected_hash is None:
            return save_path
        file_hash = _md5_file(save_path)
        if file_hash == expected_hash:
            return save_path
        next_step = 'trying again' if attempt + 1 < attempts else 'giving up'
        print(f'Hash mismatch for {save_path}. Expected {expected_hash} but got {file_hash}. Deleting checkpoint and {next_step}.')
        _discard_cached_file(save_path)
    raise RuntimeError(
        f'Checkpoint {load_from_location}.pt failed MD5 verification '
        f'{attempts} time(s); expected {expected_hash}.'
    )


def load_ckpt(load_from_location, expected_hash=None, *, max_attempts=3):
    """Fetch a checkpoint from the ``si-pbc/hertz-dev`` hub repo and ``T.load`` it.

    Downloads happen on processes with LOCAL_RANK 0, or on every process when
    no process group exists. When torch.distributed is initialised, global
    rank 0's local path is then broadcast to all ranks, so every process loads
    the same file onto the CPU.

    Args:
        load_from_location: checkpoint name in the repo, without the ``.pt`` suffix.
        expected_hash: optional hex MD5 digest the downloaded file must match.
        max_attempts: how many downloads to try before giving up on a hash
            mismatch (at least 1).

    Returns:
        Whatever object ``T.load`` deserialises from the checkpoint file.

    Raises:
        RuntimeError: if the file still mismatches ``expected_hash`` after
            ``max_attempts`` downloads.
    """
    # An explicit user setting is respected; if unset, hf_transfer is enabled.
    # Set HF_HUB_ENABLE_HF_TRANSFER=0 to speed up debugging errors with
    # downloading from the hub.
    # NOTE(review): huggingface_hub may read this variable once at import
    # time, in which case setting it here has no effect; confirm for the
    # installed version.
    os.environ.setdefault('HF_HUB_ENABLE_HF_TRANSFER', '1')

    save_path = None
    # Without a process group nobody can hand this process a path, so it must
    # download for itself even when its LOCAL_RANK is non-zero.
    if local0() or not T.distributed.is_initialized():
        save_path = _download_verified_ckpt(load_from_location, expected_hash, max_attempts)

    if T.distributed.is_initialized():
        # Collective call: every rank must reach it. Afterwards all ranks hold
        # global rank 0's path.
        shared = [save_path]
        T.distributed.broadcast_object_list(shared, src=0)
        save_path = shared[0]

    # weights_only=False unpickles arbitrary objects from the file. This is
    # acceptable only because the repo id above is hard-coded.
    loaded = T.load(save_path, weights_only=False, map_location='cpu')
    print0(f'Loaded checkpoint from {save_path}')
    return loaded