File size: 629 Bytes
c3c908f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
#!/usr/bin/env python
# coding: utf-8

import os
import hostlist

# Per-task layout exported by SLURM for this process. Each variable is
# required: a missing one raises KeyError, a non-integer value ValueError.
# The global rank (SLURM_PROCID) is not needed here; read it the same way
# if it becomes necessary: int(os.environ["SLURM_PROCID"])
local_rank, size, cpus_per_task = (
    int(os.environ[var_name])
    for var_name in ("SLURM_LOCALID", "SLURM_NTASKS", "SLURM_CPUS_PER_TASK")
)

# Expand SLURM's compact node-list string into a list of individual hostnames.
hostnames = hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"])

# IDs of the GPUs reserved for this job step, as a list of strings
# (SLURM_STEP_GPUS is split on commas).
gpu_ids = os.environ["SLURM_STEP_GPUS"].split(",")

# Rendezvous address: the first node of the allocation.
os.environ["MASTER_ADDR"] = hostnames[0]
# Offset the port by the lowest reserved GPU index to avoid port conflicts
# between jobs on the same node. The IDs must be compared as integers:
# min() over the raw strings is lexicographic, so ["9", "10"] would
# select "10" instead of 9.
os.environ["MASTER_PORT"] = str(12345 + min(int(gpu_id) for gpu_id in gpu_ids))