|
|
|
|
|
|
|
|
import os
|
|
|
import subprocess
|
|
|
from time import sleep
|
|
|
|
|
|
import fairscale.nn.model_parallel.initialize as fs_init
|
|
|
import torch
|
|
|
import torch.distributed as dist
|
|
|
from datetime import timedelta
|
|
|
|
|
|
|
|
|
def _setup_dist_env_from_slurm(args):
    """Fill in torch.distributed environment variables from SLURM job variables.

    ``MASTER_ADDR`` is resolved by running ``sinfo -Nh -n $SLURM_NODELIST`` and
    taking the first whitespace-separated field of the first output line. The call
    is retried once per second until it yields a non-empty value. ``MASTER_PORT``
    comes from ``args.master_port``. ``WORLD_SIZE``, ``RANK``, ``LOCAL_RANK`` and
    ``LOCAL_WORLD_SIZE`` are copied from ``SLURM_NPROCS``, ``SLURM_PROCID``,
    ``SLURM_LOCALID`` and ``SLURM_NTASKS_PER_NODE`` respectively. A variable that
    is already set to a non-empty value is left untouched.

    Args:
        args: namespace providing ``master_port`` (read only when ``MASTER_PORT``
            is unset or empty).

    Raises:
        KeyError: if a SLURM variable needed to fill a missing value is not set.
    """
    while not os.environ.get("MASTER_ADDR", ""):
        nodelist = os.environ["SLURM_NODELIST"]
        # The node list is passed as a single argv entry instead of through a
        # shell: SLURM node lists such as "node[01-04]" are glob patterns that an
        # unquoted shell command could expand against files in the cwd.
        try:
            output = subprocess.check_output(["sinfo", "-Nh", "-n", nodelist]).decode()
        except subprocess.CalledProcessError:
            # A failed sinfo call is treated like an empty answer and retried.
            output = ""
        lines = output.splitlines()
        fields = lines[0].split() if lines else []
        if fields:
            os.environ["MASTER_ADDR"] = fields[0]
        else:
            sleep(1)

    if not os.environ.get("MASTER_PORT"):
        os.environ["MASTER_PORT"] = str(args.master_port)

    # torch.distributed variable -> SLURM variable it is derived from.
    for torch_var, slurm_var in (
        ("WORLD_SIZE", "SLURM_NPROCS"),
        ("RANK", "SLURM_PROCID"),
        ("LOCAL_RANK", "SLURM_LOCALID"),
        ("LOCAL_WORLD_SIZE", "SLURM_NTASKS_PER_NODE"),
    ):
        if not os.environ.get(torch_var):
            os.environ[torch_var] = os.environ[slurm_var]
|
|
|
|
|
|
|
|
|
# Process group made of the ranks that share this process's node; assigned by
# distributed_init.
_INTRA_NODE_PROCESS_GROUP = None
# Process group made of the ranks that have this process's local rank on every
# node; assigned by distributed_init only when all nodes report the same
# LOCAL_WORLD_SIZE, otherwise it stays None.
_INTER_NODE_PROCESS_GROUP = None

# Rank of this process within its node and number of processes on that node,
# read from LOCAL_RANK / LOCAL_WORLD_SIZE by distributed_init; -1 until then.
_LOCAL_RANK: int = -1
_LOCAL_WORLD_SIZE: int = -1
|
|
|
|
|
|
|
|
|
def get_local_rank() -> int:
    """Return this process's rank within its node.

    The value is the module-level ``_LOCAL_RANK``. ``distributed_init`` sets it
    from the ``LOCAL_RANK`` environment variable; before that it holds its
    initial value of -1.
    """
    rank_on_node = _LOCAL_RANK
    return rank_on_node
|
|
|
|
|
|
|
|
|
def get_local_world_size() -> int:
    """Return the number of processes running on this process's node.

    The value is the module-level ``_LOCAL_WORLD_SIZE``. ``distributed_init``
    sets it from the ``LOCAL_WORLD_SIZE`` environment variable; before that it
    holds its initial value of -1.
    """
    procs_on_node = _LOCAL_WORLD_SIZE
    return procs_on_node
|
|
|
|
|
|
|
|
|
def _partition_ranks_by_node(local_ranks, local_world_sizes):
    """Split global ranks ``0..N-1`` into one list per node.

    ``local_ranks[r]`` and ``local_world_sizes[r]`` are the LOCAL_RANK and
    LOCAL_WORLD_SIZE reported by global rank ``r``. Ranks must be laid out in
    contiguous per-node blocks ordered by local rank: a node running ``k``
    processes owns ``k`` consecutive global ranks whose local ranks are
    ``0..k-1``.

    Returns:
        A list holding, for each node in rank order, the list of its global ranks.

    Raises:
        AssertionError: if the reported values do not match that layout.
    """
    assert local_ranks and len(local_ranks) == len(local_world_sizes)
    node_ranks = []
    for rank, (local_rank, local_world_size) in enumerate(zip(local_ranks, local_world_sizes)):
        if not node_ranks or len(node_ranks[-1]) == local_world_sizes[rank - 1]:
            # First rank overall, or the previous node is full: open a new node.
            node_ranks.append([])
        else:
            assert local_world_size == local_world_sizes[rank - 1], (
                f"rank {rank} reports LOCAL_WORLD_SIZE={local_world_size}, but rank "
                f"{rank - 1} on the same node reports {local_world_sizes[rank - 1]}"
            )
        assert local_rank == len(node_ranks[-1]), (
            f"rank {rank} reports LOCAL_RANK={local_rank}, expected {len(node_ranks[-1])}: "
            "global ranks must form contiguous per-node blocks ordered by local rank"
        )
        node_ranks[-1].append(rank)
    assert len(node_ranks[-1]) == local_world_sizes[-1], (
        f"last node has {len(node_ranks[-1])} ranks but reports "
        f"LOCAL_WORLD_SIZE={local_world_sizes[-1]}"
    )
    return node_ranks


def distributed_init(args):
    """Initialize torch.distributed, fairscale model parallelism and per-node groups.

    If any of ``RANK``, ``WORLD_SIZE``, ``MASTER_PORT`` or ``MASTER_ADDR`` is
    missing from the environment, they are first derived from SLURM variables via
    ``_setup_dist_env_from_slurm``. The function then:

    * records ``LOCAL_RANK`` / ``LOCAL_WORLD_SIZE`` in module globals (exposed by
      ``get_local_rank`` / ``get_local_world_size``),
    * binds this process to CUDA device ``local_rank % torch.cuda.device_count()``,
    * creates the NCCL default process group (5 hour timeout) and calls
      ``fs_init.initialize_model_parallel(args.model_parallel_size)``,
    * builds an intra-node process group containing every rank on this node, and
    * when every node reports the same LOCAL_WORLD_SIZE, builds an inter-node
      group containing the rank with this process's local rank on each node.

    Args:
        args: namespace providing ``model_parallel_size`` (and ``master_port``
            when the SLURM fallback runs).

    Raises:
        KeyError: if ``LOCAL_RANK`` / ``LOCAL_WORLD_SIZE`` (or the SLURM variables
            used as fallback) are not available.
        AssertionError: if global ranks are not laid out in contiguous per-node
            blocks ordered by local rank, or if called a second time.
    """
    global _LOCAL_RANK, _LOCAL_WORLD_SIZE
    global _INTRA_NODE_PROCESS_GROUP, _INTER_NODE_PROCESS_GROUP

    if any(key not in os.environ for key in ("RANK", "WORLD_SIZE", "MASTER_PORT", "MASTER_ADDR")):
        _setup_dist_env_from_slurm(args)

    _LOCAL_RANK = int(os.environ["LOCAL_RANK"])
    _LOCAL_WORLD_SIZE = int(os.environ["LOCAL_WORLD_SIZE"])

    # Bind the GPU before NCCL is initialized. The local rank (not the global
    # rank) identifies this process within its node. The modulo keeps the index
    # valid when each process only sees a subset of the node's devices.
    torch.cuda.set_device(_LOCAL_RANK % torch.cuda.device_count())

    dist.init_process_group("nccl", timeout=timedelta(hours=5))
    fs_init.initialize_model_parallel(args.model_parallel_size)

    world_size = dist.get_world_size()
    rank = dist.get_rank()

    # One collective gathers (local_rank, local_world_size) from every rank; the
    # flat output holds rank r's pair at positions 2r and 2r + 1.
    gathered = torch.empty([world_size * 2], dtype=torch.long, device="cuda")
    dist.all_gather_into_tensor(
        gathered,
        torch.tensor([_LOCAL_RANK, _LOCAL_WORLD_SIZE], dtype=torch.long, device="cuda"),
    )
    pairs = gathered.view(world_size, 2).tolist()
    local_ranks = [pair[0] for pair in pairs]
    local_world_sizes = [pair[1] for pair in pairs]

    # Every rank calls dist.new_group for every group in the same order, including
    # groups it does not belong to, and keeps only the one that contains it.
    for ranks in _partition_ranks_by_node(local_ranks, local_world_sizes):
        group = dist.new_group(ranks)
        if rank in ranks:
            assert _INTRA_NODE_PROCESS_GROUP is None
            _INTRA_NODE_PROCESS_GROUP = group
    assert _INTRA_NODE_PROCESS_GROUP is not None

    # Inter-node groups are only well-defined when every node runs the same
    # number of processes. With the validated contiguous layout, the ranks with
    # local rank i are i, i + L, i + 2L, ... where L is the local world size.
    if min(local_world_sizes) == max(local_world_sizes):
        for local_idx in range(_LOCAL_WORLD_SIZE):
            group = dist.new_group(list(range(local_idx, world_size, _LOCAL_WORLD_SIZE)))
            if local_idx == _LOCAL_RANK:
                assert _INTER_NODE_PROCESS_GROUP is None
                _INTER_NODE_PROCESS_GROUP = group
        assert _INTER_NODE_PROCESS_GROUP is not None
|
|
|
|
|
|
|
|
|
def get_intra_node_process_group():
    """Return the process group made of the ranks on this process's node.

    Reads the module-level ``_INTRA_NODE_PROCESS_GROUP``, which
    ``distributed_init`` assigns.

    Raises:
        AssertionError: if the group has not been assigned yet.
    """
    node_group = _INTRA_NODE_PROCESS_GROUP
    assert node_group is not None, "Intra-node process group is not initialized."
    return node_group
|
|
|
|
|
|
|
|
|
def get_inter_node_process_group():
    """Return the process group linking this rank to its peers on other nodes.

    The group, assigned by ``distributed_init``, contains the rank with this
    process's local rank on every node. ``distributed_init`` only builds it when
    all nodes report the same LOCAL_WORLD_SIZE.

    Raises:
        AssertionError: if ``distributed_init`` has not set up the process groups,
            or if it skipped the inter-node group because nodes run different
            numbers of processes. Returning None instead would make torch
            collectives silently fall back to the world group.
    """
    assert _INTRA_NODE_PROCESS_GROUP is not None, "Intra- and inter-node process groups are not initialized."
    assert _INTER_NODE_PROCESS_GROUP is not None, (
        "Inter-node process group is not available: nodes run different numbers of processes."
    )
    return _INTER_NODE_PROCESS_GROUP
|
|
|
|