Using Jean Zay

[!] Set up the Python environment depot in $WORK

On Jean Zay, space in the $HOME folder is extremely limited, so Python and conda environments must be stored in the $WORK directory. To achieve that, you should move your ~/.local, ~/.conda and ~/.cache directories to $WORK and create a symbolic link from the previous location to the new one.

To do that, copy-paste the following commands into the terminal (you only need to do this once):

if [ -d "$HOME/.local" ]; then
    if [ ! -d "$WORK/.local" ]; then
        mv $HOME/.local $WORK
    fi
else
    mkdir $WORK/.local
fi
ln -s $WORK/.local $HOME

then do the same for .conda:

if [ -d "$HOME/.conda" ]; then
    if [ ! -d "$WORK/.conda" ]; then
        mv $HOME/.conda $WORK
    fi
else
    mkdir $WORK/.conda
fi
ln -s $WORK/.conda $HOME

and for .cache:

if [ -d "$HOME/.cache" ]; then
    if [ ! -d "$WORK/.cache" ]; then
        mv $HOME/.cache $WORK
    fi
else
    mkdir $WORK/.cache
fi
 
ln -s $WORK/.cache $HOME
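
Equivalently, the three steps above can be condensed into a single loop. This is just a compact sketch of the same operations, meant to be run once:

# Move each directory to $WORK (if not already there) and leave a symlink in $HOME
for d in .local .conda .cache; do
    if [ -d "$HOME/$d" ] && [ ! -L "$HOME/$d" ]; then
        if [ ! -d "$WORK/$d" ]; then
            mv "$HOME/$d" "$WORK"
        fi
    else
        mkdir -p "$WORK/$d"
    fi
    # -sfn makes the command safe to re-run if the symlink already exists
    ln -sfn "$WORK/$d" "$HOME/$d"
done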

[!] Different CPU types & Segmentation Faults

The compute nodes on Jean Zay have two different, mutually incompatible CPU types.

  • Most nodes have Intel CPUs (notably, all V100 nodes and the A100 40GB nodes, --partition=gpu_p4), but

  • The nodes with A100 80GB SXM GPUs (obtained by running with --constraint=a100) have AMD CPUs.

When running with --constraint=a100 on A100 80GB GPUs, you must module load arch/a100 BEFORE loading ANY other module.

mpi4py and mpi4jax must be compiled for the AMD architecture to work on both CPU types. If they are compiled for the standard Intel architecture, they will only work on Intel nodes.

To compile them, load the arch/a100 module on the login node before anything else, then do something like:

module purge
module load arch/a100
module load gcc/12.2.0 anaconda-py3 cuda/12.2.0 cudnn/9.2.0.82-cuda openmpi/4.1.5-cuda
conda activate ENV

# clear the installation cache, otherwise mpi4py will be installed
# from a cached wheel whose build architecture is unknown
pip cache remove mpi4py
pip cache remove mpi4jax
pip uninstall -y mpi4py mpi4jax

pip install mpi4py
pip install --no-build-isolation mpi4jax

If you are not running with --constraint=a100 you must not load the arch/a100 module!

mpi4py and mpi4jax compiled for AMD will work on Intel, though other packages might not.

You might need to recompile them.
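
A quick way to check whether your current mpi4py build works on a given node type is to run a trivial MPI program there with your environment activated. This is only a sketch (adjust the account, constraint and time limit to your needs); a build compiled for the wrong CPU typically crashes with a segmentation fault or an illegal-instruction error instead of printing one line per rank:

# should print "0 / 2" and "1 / 2" if mpi4py works on this node type
srun --ntasks=2 --constraint=a100 --account=iqu@a100 --qos=qos_gpu-dev \
    --gres=gpu:1 --time=00:05:00 \
    python -c "from mpi4py import MPI; c = MPI.COMM_WORLD; print(c.Get_rank(), '/', c.Get_size())"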

Setting up IMPORTANT software

  1. Install gh (the GitHub command-line tool, useful for adding your credentials for private GitHub repositories):

  • Run curl -sS https://webi.sh/gh | sh ; it will probably tell you to add a line like source ~/.config/envman/PATH.env to your .bash_profile file

  • You should then run the following lines:

    gh auth login
    gh auth setup-git
  2. Install uv, a tool to manage Python environments:

  • curl -LsSf https://astral.sh/uv/install.sh | sh to install uv

  • Once you have installed uv, you should install some associated Python versions, such as 3.12:

uv python install 3.12
  3. You need to install a small custom-made wrapper of uv to work around Jean Zay's limits on the number of files:

mkdir -p ~/.local/bin 
wget -O ~/.local/bin/fuv "https://gist.githubusercontent.com/PhilipVinc/ede9711b752c83dd908226d448d56a73/raw/1f0cbd1e07e0d8690fe73f4d70bfa41ebde2f7ac/fuv"
chmod +x ~/.local/bin/fuv
  4. You need to declare where to store your Python environments:

    • Add the following lines to your ~/.bashrc and ~/.bash_profile files (to add them, you can simply edit those files with nano):

    export XDG_CACHE_HOME=$SCRATCH/.cache
    export UV_DEPOT=$SCRATCH/uv-venvs

To install Python, you should use UV. Look at the guide in Python environments: how to use UV
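
For reference, creating an environment in the depot declared above might look like the following (plain uv commands shown here; the fuv wrapper may change the exact invocation, and the environment name is just a placeholder):

# create a virtual environment inside $UV_DEPOT with a specific Python version
uv python install 3.12
uv venv $UV_DEPOT/my-project --python 3.12

# activate it and install packages into it
source $UV_DEPOT/my-project/bin/activate
uv pip install netket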

PIP Cache management

pip aggressively caches every package it builds and installs, in order to speed up later runs of pip install XYZ .

However, this is at odds with the fact that we sometimes want to recompile a package (such as mpi4py) for a different MPI version or CPU architecture.

To solve this, you can selectively remove packages from the pip cache by running

pip cache remove XYZ

We suggest you always remove mpi4jax and mpi4py from the cache before installing them.
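
For example, before reinstalling them you can clear and then inspect the cache (pip cache list accepts a package name pattern):

pip cache remove mpi4py
pip cache remove mpi4jax
# confirm that no wheels matching these packages are still cached
pip cache list mpi4py
pip cache list mpi4jax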

Running jobs over multiple GPUs

As in Using Cholesky, one can run over multiple GPUs with either sharding or MPI. Sharding is much easier to set up and use, so we recommend this approach.

Sharding (Recommended)

Setting up the environment

One should probably set up two environments, one for AMD CPUs and one for Intel, although I have not checked whether this is strictly necessary.

module purge
module load cpuarch/amd #only for a100 environment
module load gcc/12.2.0 anaconda-py3 openmpi/4.1.5
mamba create -y --name ENV_NAME python=3.11 
conda activate ENV_NAME

pip install --upgrade pip
pip install -U 'jax[cuda]' 'nvidia-cudnn-cu12<9.4' #at the moment there is a bug in the latest version of cudnn
... your other packages ...

Running the script

With sharding one can run either with a single task and multiple GPUs (which will work only on a single node) or with multiple tasks and multiple GPUs (which can work across multiple nodes). See the example scripts below:

Single task - multiple GPUs (single node)

#!/bin/bash
#SBATCH --job-name=test
#SBATCH --output=%j.out
#SBATCH --ntasks=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=64 #see gpu partition information for how many cpus per node
#SBATCH --gres=gpu:8 #number of gpus on a single node
#SBATCH --time=00:05:00
#SBATCH --constraint=a100 #or v100,... can also specify partition=gpu_p2 instead
#SBATCH --account=iqu@a100 #or iqu@v100
#SBATCH --qos=qos_gpu-dev #for debugging, otherwise use qos_gpu-t3 or qos_gpu-t4, see gpu partition info
#SBATCH --mail-type=ALL
#SBATCH --mail-user=your_email

module purge
module load cpuarch/amd #only if on a100
module load gcc/12.2.0 anaconda-py3
conda activate ENV_NAME

export NETKET_EXPERIMENTAL_SHARDING=1
export JAX_PLATFORM_NAME=gpu

srun python yourscript.py

Multiple tasks - multiple GPUs (multiple nodes)

#!/bin/bash
#SBATCH --job-name=test
#SBATCH --output=%j.out
#SBATCH --ntasks=16
#SBATCH --ntasks-per-node=8   #together with --ntasks this specifies 2 nodes with 8 tasks each
#SBATCH --cpus-per-task=8
#SBATCH --gres=gpu:8 #number of gpus on a single node
#SBATCH --time=00:05:00
#SBATCH --constraint=a100 #or v100,... can also specify partition=gpu_p2 instead
#SBATCH --account=iqu@a100 #or iqu@v100
#SBATCH --qos=qos_gpu-dev #for debugging, otherwise use qos_gpu-t3 or qos_gpu-t4, see gpu partition info
#SBATCH --mail-type=ALL
#SBATCH --mail-user=your_email

module purge
module load cpuarch/amd
module load gcc/12.2.0 anaconda-py3 openmpi/4.1.5
conda activate ENV_NAME

export NETKET_EXPERIMENTAL_SHARDING=1
export JAX_PLATFORM_NAME=gpu

srun python yourscript.py

In the Python script

If using a single task with multiple GPUs (i.e. running on a single node), the only necessary modification to your Python script is to use to_jax_operator() when initializing your driver, e.g.

gs = netket.VMC(hamiltonian.to_jax_operator(), optimizer)
....

If using multiple tasks with multiple GPUs (i.e. running on multiple nodes), one also needs to add the following at the beginning of the script:

import jax
jax.distributed.initialize()
import netket as nk
...

To check everything works correctly, you can print the following in your Python script (after jax.distributed.initialize()):

print(f"{jax.process_index()}/{jax.process_count()} : global", jax.devices())
print(f"{jax.process_index()}/{jax.process_count()} : local", jax.local_devices())

If everything works correctly for single task/multi-GPUs, the output will look like:

0/1: global, {n devices...} #for single task with n gpus
0/1: local, {n devices...}

Or for multiple tasks/multiple GPUs:

0/m: global, {m*n (all) devices...} #for m tasks each with n gpus
0/m: local, {n devices...}
1/m: global, {m*n (all) devices...} 
1/m: local, {n devices...}
...
m-1/m: ...

MPI (tricky to set up)

Setting up the environment

This compiles mpi4py/mpi4jax for the AMD architecture, which will also work on Intel CPUs. You could also compile natively for Intel by removing the module load cpuarch/amd line, but that is not needed.

This script will link mpi4jax to the CUDA-enabled OpenMPI distribution of Jean Zay. It will also work if you replace cuda12_local with cuda12_pip, should you prefer to use the JAX-preferred CUDA versions, though I don't think you should be doing that.

# Pick your environment name
ENV_NAME=jax_gpu_mpi_amd

module load cpuarch/amd # load this package only if you are running on A100
module load gcc/12.2.0 cuda/12.2.0 cudnn/9.2.0.82-cuda openmpi/4.1.5-cuda anaconda-py3
  
conda create -y --name $ENV_NAME python=3.11 
conda activate $ENV_NAME

pip install --upgrade pip

# Remove mpi4py and mpi4jax from build cache
pip cache remove mpi4py
pip cache remove mpi4jax

pip install --upgrade "jax[cuda12_local]"==0.4.30
pip install --upgrade mpi4py 
pip install --upgrade mpi4jax
pip install --upgrade netket 

Running with MPI (Stable)

The script is run through MPI, but every task will see all local GPUs. If you import netket before jax, it will automatically limit visibility to only one GPU per task.

#!/bin/sh
#SBATCH --job-name=test_mpi
#SBATCH --output=test_mpi_%j.txt
#SBATCH --hint=nomultithread  # Disable Hyperthreading

#SBATCH --ntasks=4
#SBATCH --cpus-per-task=5
#SBATCH --gres=gpu:4          # here you should insert the total number of gpus per node
#SBATCH --time=01:00:00

#SBATCH --constraint=a100
#SBATCH --account=iqu@a100
#SBATCH --qos=qos_gpu-dev

ENV_NAME=jax_gpu_mpi_amd

module purge
module load cpuarch/amd # load this package only if you are running on A100
module load gcc/12.2.0 anaconda-py3 
module load cuda/12.2.0 cudnn/9.2.0.82-cuda openmpi/4.1.5-cuda
conda activate $ENV_NAME

# This is to use fast direct gpu-to-gpu communication
export MPI4JAX_USE_CUDA_MPI=1
# This is not needed
export NETKET_MPI=1
# This automatically assigns only 1 GPU per rank (MPI cannot use more than 1)
export NETKET_MPI_AUTODETECT_LOCAL_GPU=1

srun python yourscript.py
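
To verify that the MPI setup works, you can temporarily replace yourscript.py with a small check like the one below (a sketch assuming NetKet's MPI utilities are available); each rank should print its own rank number and see exactly one GPU:

# netket is imported before jax so that each rank sees a single GPU
srun python -c "import netket as nk; import jax; print(nk.utils.mpi.rank, '/', nk.utils.mpi.n_nodes, jax.devices())"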

Jean Zay GPU Partition Info

http://www.idris.fr/eng/jean-zay/gpu/jean-zay-gpu-exec_partition_slurm-eng.html

http://www.idris.fr/eng/jean-zay/gpu/jean-zay-gpu-exec_interactif-eng.html

# SLURM OPTIONS EXPLAINED
# Possible options are
# (none)                         # for 4xV100
# --constraint v100-16g          # for 4xV100 + RAM GPU 16 Go
# --constraint v100-32g          # for 4xV100 + RAM GPU 32 Go
# --partition=gpu_p2             # for 8xV100
# --partition=gpu_p2s            # for 8xV100 + RAM CPU 384 Go
# --partition=gpu_p2l            # for 8xV100 + RAM CPU 768 Go
# --partition=gpu_p4             # for 8xA100 PCIe 40 Go
# --constraint=a100              # for 8xA100 SXM4 80 Go
# 
# Quality of services:
# --qos=qos_gpu-t3 (default)     # timelimit  20:00:00 ; max 512 GPUs/job and 1024 GPUs/person
# --qos=qos_gpu-t4 (only V100)   # timelimit 100:00:00 ; max  16 GPUs/job and  128 GPUs/person
# --qos=qos_gpu-dev              # timelimit  02:00:00 ; max  32 GPUs/job and   32 GPUs/person
#
# Account (must be specified): 
# --account=iqu@a100             # If you use A100 GPUs
# --account=iqu@v100             # If you use V100 GPUs
# --account=iqu@cpu              # If you use CPUs
#
# END SLURM OPTIONS EXPLAINED