@@ -5,7 +5,7 @@ ARG MPI=openmpi
55ARG PLATFORM=cupy
66
77# Select CUDA version
8- ARG CUDAVERSION=12.4
8+ ARG CUDAVERSION=12.8
99
1010# Pull from mambaforge and install XML and ssh
1111FROM condaforge/mambaforge AS base
@@ -17,57 +17,79 @@ FROM base AS mpi
1717ARG MPI
1818RUN mamba install -n base -c conda-forge ${MPI}
1919
20- # Pull from MPI build install core dependencies
20+ # Pull from MPI build and install core dependencies
2121FROM base AS core
22- COPY ./dependencies_core.yml ./dependencies.yml
23- RUN mamba env update -n base -f dependencies.yml
22+ RUN mamba install -n base -y -c conda-forge \
23+ python numpy scipy h5py pip
2424
2525# Pull from MPI build and install full dependencies
2626FROM mpi AS full
27- COPY ./dependencies_full.yml ./dependencies.yml
28- RUN mamba env update -n base -f dependencies.yml
27+ ARG MPI
28+ RUN mamba install -n base -y -c conda-forge \
29+ python numpy scipy matplotlib h5py \
30+ pyzmq mpi4py[build=*${MPI}*] packaging \
31+ pillow pyfftw pyyaml pip
2932
3033# Pull from MPI build and install accelerate/pycuda dependencies
3134FROM mpi AS pycuda
32- ARG CUDAVERSION
33- COPY ./ptypy/accelerate/cuda_pycuda/dependencies.yml ./dependencies.yml
34- COPY ./cufft/dependencies.yml ./dependencies_cufft.yml
35- RUN mamba env update -n base -f dependencies.yml && \
36- mamba env update -n base -f dependencies_cufft.yml && \
37- mamba install cuda-version=${CUDAVERSION}
35+ ARG CUDAVERSION MPI
36+ RUN mamba install -n base -y -c conda-forge -c nvidia \
37+ python numpy scipy matplotlib h5py pyzmq mpi4py[build=*${MPI}*] \
38+ pillow pyfftw pyyaml compilers pip \
39+ reikna pycuda cuda-nvcc cuda-cudart-dev cuda-version=${CUDAVERSION}
3840
3941# Pull from MPI build and install accelerate/cupy dependencies
4042FROM mpi AS cupy
41- ARG CUDAVERSION
42- COPY ./ptypy/accelerate/cuda_cupy/dependencies.yml ./dependencies.yml
43- COPY ./cufft/dependencies.yml ./dependencies_cufft.yml
44- RUN mamba env update -n base -f dependencies.yml && \
45- mamba env update -n base -f dependencies_cufft.yml && \
46- mamba install cuda-version=${CUDAVERSION}
43+ ARG CUDAVERSION MPI
44+ RUN mamba install -n base -y -c conda-forge \
45+ python numpy scipy matplotlib h5py pyzmq mpi4py[build=*${MPI}*] \
46+ pillow pyfftw pyyaml compilers pip \
47+ cupy cuda-version=${CUDAVERSION}
48+ RUN mamba clean -y -a
4749
4850# Pull from platform specific image and install ptypy
4951FROM ${PLATFORM} AS build
5052COPY pyproject.toml ./
5153COPY ./templates ./templates
5254COPY ./benchmark ./benchmark
53- COPY ./cufft ./cufft
5455COPY ./ptypy ./ptypy
5556RUN pip install .
5657
57- # For core/full build, no post processing needed
58+ # For core build, clean up conda env
5859FROM build AS core-post
60+ RUN mamba clean -y -a
61+
62+ # For full build, clean up conda env
5963FROM build AS full-post
64+ RUN mamba clean -y -a
6065
6166# For pycuda build, install filtered cufft
6267FROM build AS pycuda-post
68+ ARG CUDAVERSION
69+ RUN mamba install -n base -y -c conda-forge -c nvidia \
70+ python cmake>=3.8.0 pybind11 compilers \
71+ cuda-nvcc cuda-cudart-dev libcufft-dev libcufft-static cuda-version=${CUDAVERSION}
72+ COPY ./cufft ./cufft
6373RUN pip install ./cufft
74+ RUN mamba remove -n base -y \
75+ cmake pybind11 cuda-nvcc cuda-cudart-dev libcufft-dev libcufft-static
76+ RUN mamba clean -y -a
6477
65- # For pycuda build, install filtered cufft
78+ # For cupy build, install filtered cufft
6679FROM build AS cupy-post
80+ ARG CUDAVERSION
81+ RUN mamba install -n base -y -c conda-forge -c nvidia \
82+ python cmake>=3.8.0 pybind11 compilers \
83+ cuda-nvcc cuda-cudart-dev libcufft-dev libcufft-static cuda-version=${CUDAVERSION}
84+ COPY ./cufft ./cufft
6785RUN pip install ./cufft
86+ RUN mamba remove -n base -y \
87+ cmake pybind11 libcufft-dev libcufft-static
88+ RUN mamba clean -y -a
6889
6990# Platform specific runtime container
7091FROM ${PLATFORM}-post AS runtime
92+ USER ptypy-user
7193
7294# Run PtyPy run script as entrypoint
7395ENTRYPOINT ["ptypy.cli" ]
0 commit comments