#!/usr/bin/env python # coding: utf-8 import os import hostlist # get SLURM variables # rank = int(os.environ["SLURM_PROCID"]) local_rank = int(os.environ["SLURM_LOCALID"]) size = int(os.environ["SLURM_NTASKS"]) cpus_per_task = int(os.environ["SLURM_CPUS_PER_TASK"]) # get node list from slurm hostnames = hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"]) # get IDs of reserved GPU gpu_ids = os.environ["SLURM_STEP_GPUS"].split(",") # define MASTER_ADD & MASTER_PORT os.environ["MASTER_ADDR"] = hostnames[0] os.environ["MASTER_PORT"] = str( 12345 + int(min(gpu_ids)) ) # to avoid port conflict on the same node