diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0a9fc14..8c1b6c9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,11 +15,6 @@ repos: hooks: - id: end-of-file-fixer - id: trailing-whitespace - - repo: https://github.com/psf/black-pre-commit-mirror - rev: 25.11.0 - hooks: - - id: black - language_version: python3.11 - repo: https://github.com/asottile/pyupgrade rev: v3.21.2 hooks: @@ -32,11 +27,8 @@ repos: - id: isort name: isort (python) - repo: https://github.com/astral-sh/ruff-pre-commit - # Ruff version. rev: v0.14.7 hooks: - # Run the linter. - id: ruff args: [ --fix ] - # Run the formatter. - id: ruff-format diff --git a/deep_field_metadetect/jaxify/jax_metacal.py b/deep_field_metadetect/jaxify/jax_metacal.py new file mode 100644 index 0000000..44f9748 --- /dev/null +++ b/deep_field_metadetect/jaxify/jax_metacal.py @@ -0,0 +1,888 @@ +from functools import partial + +import galsim as galsim +import jax +import jax.numpy as jnp +import jax_galsim +import numpy as np + +from deep_field_metadetect.jaxify.jax_utils import compute_stepk +from deep_field_metadetect.jaxify.observation import ( + DFMdetObservation, + dfmd_obs_to_ngmix_obs, +) +from deep_field_metadetect.metacal import DEFAULT_SHEARS, DEFAULT_STEP + +DEFAULT_FFT_SIZE = 256 + + +def get_shear_tuple(shear, step): + """Convert shear string identifier to (g1, g2) tuple. + + Parameters + ---------- + shear : str + Shear identifier. Valid values are: + - "noshear": No shear applied + - "1p": Positive shear in g1 direction + - "1m": Negative shear in g1 direction + - "2p": Positive shear in g2 direction + - "2m": Negative shear in g2 direction + step : float + Magnitude of the shear step to apply. + Defaults to DEFAULT_STEP. + + Returns + ------- + tuple of float + Two-element tuple (g1, g2) representing the shear components. + """ + if shear == "noshear": + return (0, 0) + elif shear == "1p": + return (step, 0) + elif shear == "1m": + return (-step, 0) + elif shear == "2p": + return (0, step) + elif shear == "2m": + return (0, -step) + else: + raise RuntimeError("Shear value '%s' not recognized!" % shear) + + +@partial(jax.jit, static_argnames=["dk", "nxy_psf", "kim_size"]) +def jax_get_gauss_reconv_psf_galsim( + psf, dk, nxy_psf=53, step=DEFAULT_STEP, flux=1.0, kim_size=None +): + """Gets the target reconvolution PSF for an input PSF object. + + This is taken from galsim/tests/test_metacal.py and assumes the psf is + centered. + Note: Order of parameters differs from the corresponding non-jax versions + + Parameters + ---------- + psf : galsim.GSObject + The input point spread function (PSF) object. + dk : float + The Fourier-space pixel scale. + nxy_psf : int, optional + The size of the PSF image in pixels (default is 53). + step : float, optional + Factor by which to expand the PSF to suppress noise from high-k + fourier modes introduced due to shearing of pre-PSF images. + Defaults to deep_field_metadetect.metacal.DEFAULT_STEP. + flux : float, optional + The total flux of the output PSF (default is 1). + kim_size : int + k image size. + Defaults to None, which sets size as 4*nxy_psf + + Returns + ------- + reconv_psf : JaxGalsim object + The reconvolution PSF. + """ + small_kval = 1.0e-2 # Find the k where the given psf hits this kvalue + smaller_kval = 3.0e-3 # Target PSF will have this kvalue at the same k + + """ + The dk and kim_size are set for jitting purposes. + This will lead to a difference in reconv PSF size between GS and JGS + if similar settings are not used.""" + if kim_size is None: + kim = psf.drawKImage(nx=4 * nxy_psf, ny=4 * nxy_psf, scale=dk) + else: + kim = psf.drawKImage(nx=kim_size, ny=kim_size, scale=dk) + + karr_r = kim.real.array + # Find the smallest r where the kval < small_kval + nk = karr_r.shape[0] + kx, ky = jnp.meshgrid(jnp.arange(-nk / 2, nk / 2), jnp.arange(-nk / 2, nk / 2)) + ksq = (kx**2 + ky**2) * dk**2 + ksq_max = jnp.min(jnp.where(karr_r < small_kval * psf.flux, ksq, jnp.inf)) + + # We take our target PSF to be the (round) Gaussian that is even smaller at + # this ksq + # exp(-0.5 * ksq_max * sigma_sq) = smaller_kval + sigma_sq = -2.0 * jnp.log(smaller_kval) / ksq_max + + dilation = 1.0 + 2.0 * step + return jax_galsim.Gaussian(sigma=jnp.sqrt(sigma_sq) * dilation).withFlux(flux) + + +@partial(jax.jit, static_argnames=["dk", "nxy_psf"]) +def jax_get_gauss_reconv_psf(dfmd_obs, nxy_psf, dk, step=DEFAULT_STEP): + """Get the Gaussian reconv PSF for a DFMdetObs. + + Parameters + ---------- + dfmd_obs : DFMdetObservation + The observation containing the PSF to process. + nxy_psf : int + Size of the PSF image in pixels. + dk : float + Fourier-space pixel scale. + step : float, optional + Factor by which to expand the PSF to suppress noise from high-k + fourier modes introduced due to shearing of pre-PSF images. + Defaults to DEFAULT_STEP. + + Returns + ------- + jax_galsim.Gaussian + The Gaussian reconvolution PSF object. + """ + psf = get_jax_galsim_object_from_dfmd_obs_nopix(dfmd_obs.psf, kind="image") + return jax_get_gauss_reconv_psf_galsim(psf, nxy_psf=nxy_psf, dk=dk, step=step) + + +@partial(jax.jit, static_argnames=["nxy_psf", "scale"]) +def jax_get_max_gauss_reconv_psf_galsim( + psf_w, psf_d, nxy_psf, scale=0.2, step=DEFAULT_STEP +): + """Get the larger of two Gaussian reconvolution PSFs for two galsim objects.""" + dk = compute_stepk(pixel_scale=scale, image_size=nxy_psf) + mc_psf_w = jax_get_gauss_reconv_psf_galsim(psf_w, dk, nxy_psf, step=step) + mc_psf_d = jax_get_gauss_reconv_psf_galsim(psf_d, dk, nxy_psf, step=step) + + return jax.lax.cond( + mc_psf_w.fwhm > mc_psf_d.fwhm, lambda: mc_psf_w, lambda: mc_psf_d + ) + + +@partial(jax.jit, static_argnames=["scale", "nxy_psf"]) +def jax_get_max_gauss_reconv_psf(obs_w, obs_d, nxy_psf, scale=0.2, step=DEFAULT_STEP): + """Get the larger of two reconv PSFs for two DFMdetObservations.""" + psf_w = get_jax_galsim_object_from_dfmd_obs_nopix(obs_w.psf, kind="image") + psf_d = get_jax_galsim_object_from_dfmd_obs_nopix(obs_d.psf, kind="image") + return jax_get_max_gauss_reconv_psf_galsim( + psf_w, psf_d, nxy_psf, scale=scale, step=step + ) + + +@partial(jax.jit, static_argnames=["nxy_psf", "fft_size"]) +def _jax_render_psf_and_build_obs( + image, dfmd_obs, reconv_psf, nxy_psf, weight_fac=1, fft_size=DEFAULT_FFT_SIZE +): + reconv_psf = reconv_psf.withGSParams( + minimum_fft_size=fft_size, + maximum_fft_size=fft_size, + ) + + pim = reconv_psf.drawImage( + nx=nxy_psf, + ny=nxy_psf, + wcs=dfmd_obs.psf.wcs.local(), + offset=jax_galsim.PositionD( + x=dfmd_obs.psf.wcs.origin.x - (nxy_psf + 1) / 2, + y=dfmd_obs.psf.wcs.origin.y - (nxy_psf + 1) / 2, + ), + ).array + + obs_psf = dfmd_obs.psf.replace(image=pim) + return dfmd_obs.replace( + image=jnp.array(image), psf=obs_psf, weight=dfmd_obs.weight * weight_fac + ) + + +@partial(jax.jit, static_argnames=["dims", "fft_size"]) +def _jax_metacal_op_g1g2_impl( + *, wcs, image, noise, psf_inv, dims, reconv_psf, g1, g2, fft_size=DEFAULT_FFT_SIZE +): + """Run metacal on an dfmd observation. + + Note that the noise image should already be rotated by 90 degrees here. + """ + + ims = jax_galsim.Convolve( + [ + jax_galsim.Convolve([image, psf_inv]).shear(g1=g1, g2=g2), + reconv_psf, + ] + ) + + ns = jax_galsim.Convolve( + [ + jax_galsim.Convolve([noise, psf_inv]).shear(g1=g1, g2=g2), + reconv_psf, + ] + ) + + ims = ims.withGSParams( + minimum_fft_size=fft_size, + maximum_fft_size=fft_size, + ) + ims = ims.drawImage(nx=dims[1], ny=dims[0], wcs=wcs).array + + ns = ns.withGSParams( + minimum_fft_size=fft_size, + maximum_fft_size=fft_size, + ) + ns = jnp.rot90( + ns.drawImage(nx=dims[1], ny=dims[0], wcs=wcs).array, + k=-1, + ) + return ims + ns + + +def jax_metacal_op_g1g2( + dfmd_obs, reconv_psf, g1, g2, nxy_psf, fft_size=DEFAULT_FFT_SIZE +): + """Run metacal on an dfmd observation with specified shear. + + Parameters + ---------- + dfmd_obs : DFMdetObservation + The observation to process. + reconv_psf : jax_galsim.GSObject + The reconvolution PSF object. + g1 : float + g1 shear components to apply. + g2 : float + g2 shear components to apply. + nxy_psf : int + Size of the PSF image in pixels. + fft_size : int, optional + FFT size for convolution operations (default is DEFAULT_FFT_SIZE). + + Returns + ------- + DFMdetObservation + New observation with metacal applied. + """ + mcal_image = _jax_metacal_op_g1g2_impl( + wcs=dfmd_obs.wcs.local(), + image=get_jax_galsim_object_from_dfmd_obs(dfmd_obs, kind="image"), + # we rotate by 90 degrees on the way in and then _metacal_op_g1g2_impl + # rotates back after deconv and shearing + noise=get_jax_galsim_object_from_dfmd_obs(dfmd_obs, kind="noise", rot90=1), + psf_inv=jax_galsim.Deconvolve( + get_jax_galsim_object_from_dfmd_obs(dfmd_obs.psf, kind="image") + ), + dims=dfmd_obs.image.shape, + reconv_psf=reconv_psf, + g1=g1, + g2=g2, + fft_size=fft_size, + ) + + return _jax_render_psf_and_build_obs( + mcal_image, + dfmd_obs, + reconv_psf, + nxy_psf=nxy_psf, + weight_fac=0.5, + fft_size=fft_size, + ) + + +@partial(jax.jit, static_argnames=["nxy_psf", "scale", "shears", "fft_size"]) +def jax_metacal_op_shears( + dfmd_obs, + nxy_psf=53, + reconv_psf=jax_galsim.Gaussian(sigma=0.0).withFlux(1.0), + shears=DEFAULT_SHEARS, + step=DEFAULT_STEP, + scale=0.2, + fft_size=DEFAULT_FFT_SIZE, +): + """Run metacal on an dfmd observation with multiple shear values. + + Parameters + ---------- + dfmd_obs : DFMdetObservation + The observation to process. + nxy_psf : int, optional + Size of the PSF image in pixels (default is 53). + reconv_psf : jax_galsim.GSObject, optional + The reconvolution PSF. + Default: a proper reconvolution PSF will be computed automatically. + using jax_get_gauss_reconv_psf function. + shears : tuple of str, optional + Shear identifiers to process (default is DEFAULT_SHEARS). + step : float, optional + Shear step magnitude (default is DEFAULT_STEP). + scale : float, optional + Pixel scale in arcseconds (default is 0.2). + fft_size : int, optional + FFT size for convolution operations (default is DEFAULT_FFT_SIZE). + + Returns + ------- + dict + Dictionary mapping shear identifiers to processed DFMdetObservation objects. + """ + dk = compute_stepk(pixel_scale=scale, image_size=nxy_psf) + + def compute_reconv(): + return jax_get_gauss_reconv_psf(dfmd_obs, dk=dk, nxy_psf=nxy_psf, step=step) + + def use_provided_reconv(): + return reconv_psf + + reconv_psf = jax.lax.cond( + reconv_psf.sigma == 0, compute_reconv, use_provided_reconv + ) + wcs = dfmd_obs.wcs.local() + image = get_jax_galsim_object_from_dfmd_obs(dfmd_obs, kind="image") + # we rotate by 90 degrees on the way in and then _metacal_op_g1g2_impl + # rotates back after deconv and shearing + noise = get_jax_galsim_object_from_dfmd_obs(dfmd_obs, kind="noise", rot90=1) + psf = get_jax_galsim_object_from_dfmd_obs(dfmd_obs.psf, kind="image") + psf_inv = jax_galsim.Deconvolve(psf) + + shear_tuples = jnp.array([get_shear_tuple(shear, step) for shear in shears]) + g1_vals = shear_tuples[:, 0] + g2_vals = shear_tuples[:, 1] + + # Vectorized metacal operation across all shears + def single_shear_op(g1, g2): + mcal_image = _jax_metacal_op_g1g2_impl( + wcs=wcs, + image=image, + noise=noise, + psf_inv=psf_inv, + dims=dfmd_obs.image.shape, + reconv_psf=reconv_psf, + g1=g1, + g2=g2, + fft_size=fft_size, + ) + return _jax_render_psf_and_build_obs( + mcal_image, + dfmd_obs, + reconv_psf, + nxy_psf=nxy_psf, + weight_fac=0.5, + fft_size=fft_size, + ) + + # Use vmap to parallelize across shears + vectorized_shear_op = jax.vmap(single_shear_op) + mcal_obs_list = vectorized_shear_op(g1_vals, g2_vals) + + # Convert back to dictionary format + mcal_res = {} + for i, shear in enumerate(shears): + mcal_res[shear] = jax.tree.map(lambda x: x[i], mcal_obs_list) + + return mcal_res + + +@partial( + jax.jit, + static_argnames=[ + "nxy", + "nxy_psf", + "return_k_info", + "force_stepk_field", + "force_maxk_field", + "force_stepk_psf", + "force_maxk_psf", + "fft_size", + ], +) +def jax_match_psf( + dfmd_obs, + reconv_psf, + nxy, + nxy_psf, + return_k_info=False, + force_stepk_field=0.0, + force_maxk_field=0.0, + force_stepk_psf=0.0, + force_maxk_psf=0.0, + fft_size=DEFAULT_FFT_SIZE, +): + """Match the PSF on an dfmd observation to a new PSF.""" + wcs = dfmd_obs.wcs.local() + image = get_jax_galsim_object_from_dfmd_obs( + dfmd_obs, + kind="image", + force_stepk=force_stepk_field, + force_maxk=force_maxk_field, + ) + psf = get_jax_galsim_object_from_dfmd_obs( + dfmd_obs.psf, + kind="image", + force_stepk=force_stepk_psf, + force_maxk=force_maxk_psf, + ) + + ims = jax_galsim.Convolve( + [image, jax_galsim.Deconvolve(psf), reconv_psf], + gsparams=jax_galsim.GSParams( + minimum_fft_size=fft_size, + maximum_fft_size=fft_size, + ), + ) + + ims = ims.withGSParams( + minimum_fft_size=fft_size, + maximum_fft_size=fft_size, + ) + ims = ims.drawImage(nx=nxy, ny=nxy, wcs=wcs).array + + if return_k_info: + return _jax_render_psf_and_build_obs( + ims, dfmd_obs, reconv_psf, nxy_psf, weight_fac=1 + ), (image.stepk, image.maxk, psf.stepk, psf.maxk) + else: + return _jax_render_psf_and_build_obs( + ims, dfmd_obs, reconv_psf, nxy_psf, weight_fac=1 + ), (np.nan, np.nan, np.nan, np.nan) + + +def _extract_attr(obs, attr, dtype=np.float32): + if getattr(obs, "has_" + attr)(): + return getattr(obs, attr) + else: + return np.zeros_like(obs.image, dtype=dtype) + + +def jax_add_dfmd_psf(psf1, psf2): + """Add two DFMdetPSF objects""" + from deep_field_metadetect.jaxify.observation import DFMdetPSF + + added_image = psf1.image + psf2.image + + new_wgt = jnp.where( + (psf1.weight > 0) & (psf2.weight > 0), + 1 / (1 / psf1.weight + 1 / psf2.weight), + 0, + ) + + return DFMdetPSF( + image=added_image, + weight=new_wgt, + wcs=psf1.wcs, # Assume same WCS + meta={**psf1.meta, **psf2.meta}, + store_pixels=psf1.store_pixels, + ignore_zero_weight=psf1.ignore_zero_weight, + ) + + +def jax_add_dfmd_obs( + dfmd_obs1, dfmd_obs2, ignore_psf=False, skip_mfrac_for_second=False +) -> DFMdetObservation: + """Add two DFMD observations. + + Parameters + ---------- + dfmd_obs1 : DFMdetObservation + The first observation to add. + dfmd_obs2 : DFMdetObservation + The second observation to add. + ignore_psf : bool, optional + If True, the output PSF will be set to zero instead of combining + the input PSFs. Default is False. + skip_mfrac_for_second : bool, optional + If True, only use the mfrac from the first observation instead of + averaging both. Default is False. + + Returns + ------- + DFMdetObservation + A new observation containing the combined data from both inputs. + The image is the sum of input images, weights are combined using + inverse variance weighting, and masks are combined using bitwise OR, + and noise is summed. + """ + + if repr(dfmd_obs1.wcs) != repr(dfmd_obs2.wcs): + # This if statement will not perform any action at runtime + raise RuntimeError( + "AffineTransforms must be equal to add dfmd observations! %s != %s" + % (repr(dfmd_obs1.wcs), repr(dfmd_obs2.wcs)), + ) + + if dfmd_obs1.image.shape != dfmd_obs2.image.shape: + raise RuntimeError( + "Image shapes must be equal to add dfmd observations! %s != %s" + % ( + dfmd_obs1.image.shape, + dfmd_obs2.image.shape, + ), + ) + + # Handle PSF addition using dedicated function + def add_psfs(): + return jax_add_dfmd_psf(dfmd_obs1.psf, dfmd_obs2.psf) + + def no_psf(): + from deep_field_metadetect.jaxify.observation import DFMdetPSF + + return DFMdetPSF( + image=jnp.zeros_like(dfmd_obs1.psf.image, dtype=jnp.float32), + wcs=dfmd_obs1.psf.wcs, + meta=dfmd_obs1.psf.meta, + store_pixels=dfmd_obs1.psf.store_pixels, + ignore_zero_weight=dfmd_obs1.psf.ignore_zero_weight, + ) + + # Add PSFs if both exist and we're not ignoring PSF + has_psf1 = dfmd_obs1.has_psf() + has_psf2 = dfmd_obs2.has_psf() + should_add_psf = (not ignore_psf) & has_psf1 & has_psf2 + + new_psf = jax.lax.cond(should_add_psf, add_psfs, no_psf) + + new_wgt = jnp.where( + (dfmd_obs1.weight > 0) & (dfmd_obs2.weight > 0), + 1 / (1 / dfmd_obs1.weight + 1 / dfmd_obs2.weight), + 0, + ) + + new_meta_data = {} + + # Handle bmask, ormask, noise, and mfrac + # Unlike the non-jax version we do not need to test conditions here + # because now the default values are zeros instead of None + new_bmask = dfmd_obs1.bmask | dfmd_obs2.bmask + new_ormask = dfmd_obs1.ormask | dfmd_obs2.ormask + new_noise = dfmd_obs1.noise + dfmd_obs2.noise + + def mfrac_skip_second(): + return dfmd_obs1.mfrac + + def mfrac_use_both(): + return (dfmd_obs1.mfrac + dfmd_obs2.mfrac) / 2 + + new_mfrac = jax.lax.cond(skip_mfrac_for_second, mfrac_skip_second, mfrac_use_both) + + new_meta_data.update(dfmd_obs1.meta) + new_meta_data.update(dfmd_obs2.meta) + + obs = DFMdetObservation( + image=dfmd_obs1.image + dfmd_obs2.image, + weight=new_wgt, + bmask=new_bmask, + ormask=new_ormask, + noise=new_noise, + wcs=dfmd_obs1.wcs, + psf=new_psf, + meta=new_meta_data, + mfrac=new_mfrac, + store_pixels=getattr(dfmd_obs1, "store_pixels", True), + ignore_zero_weight=getattr(dfmd_obs1, "ignore_zero_weight", True), + ) + + return obs + + +def get_jax_galsim_object_from_dfmd_obs( + dfmd_obs, + kind="image", + rot90=0, + force_stepk=0.0, + force_maxk=0.0, +): + """Make an interpolated image from an dfmd obs.""" + return jax_galsim.InterpolatedImage( + jax_galsim.ImageD( + jnp.rot90(getattr(dfmd_obs, kind).copy(), k=rot90), + wcs=dfmd_obs.wcs.local(), + ), + x_interpolant="lanczos15", + wcs=dfmd_obs.wcs.local(), + _force_stepk=force_stepk, + _force_maxk=force_maxk, + ) + + +def get_jax_galsim_object_from_dfmd_obs_nopix(dfmd_obs, kind="image"): + """Make an interpolated image from an DFMdet obs w/o a pixel.""" + wcs = dfmd_obs.wcs.local() + return jax_galsim.Convolve( + [ + get_jax_galsim_object_from_dfmd_obs(dfmd_obs, kind=kind), + jax_galsim.Deconvolve(wcs.toWorld(jax_galsim.Pixel(scale=1))), + ] + ) + + +@partial( + jax.jit, + static_argnames=[ + "nxy", + "nxy_psf", + "shears", + "skip_obs_wide_corrections", + "skip_obs_deep_corrections", + "return_noshear_deep", + "scale", + "return_k_info", + "force_stepk_field", + "force_maxk_field", + "force_stepk_psf", + "force_maxk_psf", + "fft_size", + ], +) +def _jax_helper_metacal_wide_and_deep_psf_matched( + obs_wide, + obs_deep, + obs_deep_noise, + reconv_psf, + nxy, + nxy_psf, + shears=None, + step=DEFAULT_STEP, + skip_obs_wide_corrections=False, + skip_obs_deep_corrections=False, + return_noshear_deep=False, + scale=0.2, + return_k_info=False, + force_stepk_field=0.0, + force_maxk_field=0.0, + force_stepk_psf=0.0, + force_maxk_psf=0.0, + fft_size=DEFAULT_FFT_SIZE, +): + """Do metacalibration for a combination of wide+deep datasets. + + Parameters + ---------- + obs_wide : DFMdetObservation + The wide-field observation. + obs_deep : DFMdetObservation + The deep-field observation. + obs_deep_noise : DFMdetObservation + The deep-field noise observation. + reconv_psf : JaxGalsim object + The reconvolution PSF. + shears : tuple of strings, optional + The shears to use for the metacalibration, by default DEFAULT_SHEARS + if set to None. + step : float, optional + The step size for the metacalibration, by default DEFAULT_STEP. + skip_obs_wide_corrections : bool, optional + Skip the observation corrections for the wide-field observations, + by default False. + skip_obs_deep_corrections : bool, optional + Skip the observation corrections for the deep-field observations, + by default False. + return_noshear_deep : bool, optional + adds deep field no shear results to the output. Default - False. + This is a static variable so changing it would trigger recompilation. + scale : float, optional + pixel scale. default to 0.2. + Note this parameter is not present in non-jax version. + This is later used for compute_stepk to compute the pixel scale in + fourier space and this is a static variable so changing it would + trigger recompilation. + return_k_info : bool, optional + return _force_stepk and _force_maxk values in the following order + _force_stepk_field, _force_maxk_field, _force_stepk_psf, _force_maxk_psf. + Used mainly for testing. + force_stepk_field : float, optional + Force stepk for drawing field images. + Defaults to 0.0, which lets JaxGalsim choose the value. + Used mainly for testing. + force_maxk_field: float, optional + Force maxk for drawing field images. + Defaults to 0.0, which lets Galsim choose the value. + Used mainly for testing. + force_stepk_psf: float, optional + Force stepk for drawing PSF images. + Defaults to 0.0, which lets Galsim choose the value. + Used mainly for testing. + force_maxk_psf: float, optional + Force stepk for drawing PSF images + Defaults to 0.0, which lets Galsim choose the value. + Used mainly for testing. + fft_size: int, optional + To fix max and min values of FFT size. + Defaults to None which lets Galsim determine the values. + Used mainly to test against JaxGalsim. + + Returns + ------- + mcal_res : dict + Output from metacal_op_shears for shear cases listed by the shears input, + optionally no shear deep field case if return_noshear_deep is True + and kinfo for debugging if return_k_info is set to True. + kinfo is returned in the following order: + _force_stepk_field, _force_maxk_field, _force_stepk_psf, _force_maxk_psf. + """ + # make the wide obs + + mcal_obs_wide, kinfo = jax_match_psf( + obs_wide, + reconv_psf, + nxy, + nxy_psf, + return_k_info=return_k_info, + force_stepk_field=force_stepk_field, + force_maxk_field=force_maxk_field, + force_stepk_psf=force_stepk_psf, + force_maxk_psf=force_maxk_psf, + fft_size=fft_size, + ) + if not skip_obs_wide_corrections: + mcal_obs_wide = jax_add_dfmd_obs( + mcal_obs_wide, + jax_metacal_op_g1g2(obs_deep_noise, reconv_psf, 0, 0, nxy_psf=nxy_psf), + skip_mfrac_for_second=True, + ) + + # get PSF matched noise + obs_wide_noise = obs_wide.replace(image=obs_wide.noise) + wide_noise_corr, _ = jax_match_psf( + obs_wide_noise, + reconv_psf, + nxy, + nxy_psf, + force_stepk_field=force_stepk_field, + force_maxk_field=force_maxk_field, + force_stepk_psf=force_stepk_psf, + force_maxk_psf=force_maxk_psf, + fft_size=fft_size, + ) + + # now run mcal on deep + mcal_res = jax_metacal_op_shears( + obs_deep, + reconv_psf=reconv_psf, + shears=shears, + step=step, + nxy_psf=nxy_psf, + scale=scale, + fft_size=fft_size, + ) + + # now add in noise corr to make it match the wide noise + if not skip_obs_deep_corrections: + for k in shears: + mcal_res[k] = jax_add_dfmd_obs( + mcal_res[k], + wide_noise_corr, + skip_mfrac_for_second=True, + ) + + # we report the wide obs as noshear for later measurements + noshear_res = mcal_res.pop("noshear") + mcal_res["noshear"] = mcal_obs_wide + if return_noshear_deep: + mcal_res["noshear_deep"] = noshear_res + + if return_k_info: + mcal_res["kinfo"] = kinfo + + return mcal_res + + +def jax_metacal_wide_and_deep_psf_matched( + obs_wide, + obs_deep, + obs_deep_noise, + nxy, + nxy_psf, + shears=DEFAULT_SHEARS, + step=DEFAULT_STEP, + skip_obs_wide_corrections=False, + skip_obs_deep_corrections=False, + return_noshear_deep=False, + scale=0.2, + return_k_info=False, + force_stepk_field=0.0, + force_maxk_field=0.0, + force_stepk_psf=0.0, + force_maxk_psf=0.0, + fft_size=DEFAULT_FFT_SIZE, +): + """Do metacalibration for a combination of wide+deep datasets. + + Parameters + ---------- + obs_wide : DFMdetObservation + The wide-field observation. + obs_deep : DFMdetObservation + The deep-field observation. + obs_deep_noise : DFMdetObservation + The deep-field noise observation. + shears : tuple of strings, optional + The shears to use for the metacalibration, by default DEFAULT_SHEARS + if set to None. + step : float, optional + The step size for the metacalibration, by default DEFAULT_STEP. + skip_obs_wide_corrections : bool, optional + Skip the observation corrections for the wide-field observations, + by default False. + skip_obs_deep_corrections : bool, optional + Skip the observation corrections for the deep-field observations, + by default False. + return_noshear_deep : bool, optional + adds deep field no shear results to the output. Default - False. + This is a static variable so changing it would trigger recompilation. + scale : float, optional + pixel scale. default to 0.2. + Note this parameter is not present in non-jax version. + This is later used for compute_stepk to compute the pixel scale in + fourier space and this is a static variable so changing it would + trigger recompilation. + return_k_info : bool, optional + return _force_stepk and _force_maxk values in the following order + _force_stepk_field, _force_maxk_field, _force_stepk_psf, _force_maxk_psf. + Used mainly for testing. + force_stepk_field : float, optional + Force stepk for drawing field images. + Defaults to 0.0, which lets JaxGalsim choose the value. + Used mainly for testing. + force_maxk_field: float, optional + Force maxk for drawing field images. + Defaults to 0.0, which lets Galsim choose the value. + Used mainly for testing. + force_stepk_psf: float, optional + Force stepk for drawing PSF images. + Defaults to 0.0, which lets Galsim choose the value. + Used mainly for testing. + force_maxk_psf: float, optional + Force stepk for drawing PSF images + Defaults to 0.0, which lets Galsim choose the value. + Used mainly for testing. + fft_size: int, optional + To fix max and min values of FFT size. + Defaults to None which lets Galsim determine the values. + Used mainly to test against JaxGalsim. + + Returns + ------- + mcal_res : dict + Output from metacal_op_shears for shear cases listed by the shears input, + optionally no shear deep field case if return_noshear_deep is True + and kinfo for debugging if return_k_info is set to True. + kinfo is returned in the following order: + _force_stepk_field, _force_maxk_field, _force_stepk_psf, _force_maxk_psf. + """ + # first get the biggest reconv PSF of the two + reconv_psf = jax_get_max_gauss_reconv_psf(obs_wide, obs_deep, nxy_psf, scale) + + mcal_res = _jax_helper_metacal_wide_and_deep_psf_matched( + obs_wide=obs_wide, + obs_deep=obs_deep, + obs_deep_noise=obs_deep_noise, + reconv_psf=reconv_psf, + nxy=nxy, + nxy_psf=nxy_psf, + shears=shears, + step=step, + skip_obs_wide_corrections=skip_obs_wide_corrections, + skip_obs_deep_corrections=skip_obs_deep_corrections, + return_noshear_deep=return_noshear_deep, + scale=scale, + return_k_info=return_k_info, + force_stepk_field=force_stepk_field, + force_maxk_field=force_maxk_field, + force_stepk_psf=force_stepk_psf, + force_maxk_psf=force_maxk_psf, + fft_size=fft_size, + ) + + for k in shears: + mcal_res[k] = dfmd_obs_to_ngmix_obs(mcal_res[k]) + mcal_res[k].psf.galsim_obj = reconv_psf + + return mcal_res diff --git a/deep_field_metadetect/jaxify/jax_metadetect.py b/deep_field_metadetect/jaxify/jax_metadetect.py new file mode 100644 index 0000000..d75d6f4 --- /dev/null +++ b/deep_field_metadetect/jaxify/jax_metadetect.py @@ -0,0 +1,176 @@ +import ngmix +import numpy as np + +from deep_field_metadetect.detect import ( + generate_mbobs_for_detections, + run_detection_sep, +) +from deep_field_metadetect.jaxify.jax_metacal import ( + DEFAULT_FFT_SIZE, + DEFAULT_SHEARS, + DEFAULT_STEP, + jax_metacal_wide_and_deep_psf_matched, +) +from deep_field_metadetect.mfrac import compute_mfrac_interp_image +from deep_field_metadetect.utils import fit_gauss_mom_obs, fit_gauss_mom_obs_and_psf + + +def jax_single_band_deep_field_metadetect( + obs_wide, + obs_deep, + obs_deep_noise, + nxy, + nxy_psf, + step=DEFAULT_STEP, + shears=None, + skip_obs_wide_corrections=False, + skip_obs_deep_corrections=False, + nodet_flags=0, + scale=0.2, + return_k_info=False, + force_stepk_field=0.0, + force_maxk_field=0.0, + force_stepk_psf=0.0, + force_maxk_psf=0.0, + fft_size=DEFAULT_FFT_SIZE, +): + """Run deep-field metadetection for a simple scenario of a single band + with a single image per band using only post-PSF Gaussian weighted moments. + + Parameters + ---------- + obs_wide : DFMdetObservation + The wide-field observation. + obs_deep : DFMdetObservation + The deep-field observation. + obs_deep_noise : DFMdetObservation + The deep-field noise observation. + nxy: int + Image size + nxy_psf: int + PSF size + step : float, optional + The step size for the metacalibration, by default DEFAULT_STEP. + shears : list, optional + The shears to use for the metacalibration, by default DEFAULT_SHEARS + if set to None. + skip_obs_wide_corrections : bool, optional + Skip the observation corrections for the wide-field observations, + by default False. + skip_obs_deep_corrections : bool, optional + Skip the observation corrections for the deep-field observations, + by default False. + nodet_flags : int, optional + The bmask flags marking area in the image to skip, by default 0. + scale: float + pixel scale + scale : float, optional + pixel scale. default to 0.2. + Note this parameter is not present in non-jax version. + This is later used for compute_stepk to compute the pixel scale in + fourier space and this is a static variable so changing it would + trigger recompilation. + return_k_info : bool, optional + return _force stepk and maxk values in the following order + _force_stepk_field, _force_maxk_field, _force_stepk_psf, _force_maxk_psf. + Used mainly for testing. + force_stepk_field : float, optional + Force stepk for drawing field images. + Defaults to 0.0, which lets JaxGalsim choose the value. + Used mainly for testing. + force_maxk_field: float, optional + Force maxk for drawing field images. + Defaults to 0.0, which lets Galsim choose the value. + Used mainly for testing. + force_stepk_psf: float, optional + Force stepk for drawing PSF images. + Defaults to 0.0, which lets Galsim choose the value. + Used mainly for testing. + force_maxk_psf: float, optional + Force stepk for drawing PSF images + Defaults to 0.0, which lets Galsim choose the value. + Used mainly for testing. + fft_size: int, optional + To fix max and min values of FFT size. + Defaults to None which lets Galsim determine the values. + Used mainly to test against JaxGalsim. + + Returns + ------- + dfmdet_res : numpy.ndarray + The deep-field metadetection results as a structured array containing + detection and measurement results for all shears. + + Note: If return_k_info is set to True for debugging, + this function returns a tuple containing (dfmdet_res, kinfo). kinfo being: + (_force_stepk_field, _force_maxk_field, _force_stepk_psf, _force_maxk_psf) + """ + if shears is None: + shears = DEFAULT_SHEARS + + mcal_res = jax_metacal_wide_and_deep_psf_matched( + obs_wide=obs_wide, + obs_deep=obs_deep, + obs_deep_noise=obs_deep_noise, + nxy=nxy, + nxy_psf=nxy_psf, + step=step, + shears=shears, + skip_obs_wide_corrections=skip_obs_wide_corrections, + skip_obs_deep_corrections=skip_obs_deep_corrections, + scale=scale, + return_k_info=return_k_info, + force_stepk_field=force_stepk_field, + force_maxk_field=force_maxk_field, + force_stepk_psf=force_stepk_psf, + force_maxk_psf=force_maxk_psf, + fft_size=fft_size, + ) # This returns ngmix Obs for now + + psf_res = fit_gauss_mom_obs(mcal_res["noshear"].psf) + dfmdet_res = [] + for shear in shears: + obs = mcal_res[shear] + detres = run_detection_sep(obs, nodet_flags=nodet_flags) + + ixc = (detres["catalog"]["x"] + 0.5).astype(int) + iyc = (detres["catalog"]["y"] + 0.5).astype(int) + bmask_flags = obs.bmask[iyc, ixc] + + mfrac_vals = np.zeros_like(bmask_flags, dtype="f4") + if np.any(obs.mfrac > 0): + _interp_mfrac = compute_mfrac_interp_image( + obs.mfrac, + obs.jacobian.get_galsim_wcs(), + ) + for i, (x, y) in enumerate( + zip(detres["catalog"]["x"], detres["catalog"]["y"]) + ): + mfrac_vals[i] = _interp_mfrac.xValue(x, y) + + for ind, (obj, mbobs) in enumerate( + generate_mbobs_for_detections( + ngmix.observation.get_mb_obs(obs), + xs=detres["catalog"]["x"], + ys=detres["catalog"]["y"], + ) + ): + fres = fit_gauss_mom_obs_and_psf(mbobs[0][0], psf_res=psf_res) + dfmdet_res.append( + (ind + 1, obj["x"], obj["y"], shear, bmask_flags[ind], mfrac_vals[ind]) + + tuple(fres[0]) + ) + + total_dtype = [ + ("id", "i8"), + ("x", "f8"), + ("y", "f8"), + ("mdet_step", "U7"), + ("bmask_flags", "i4"), + ("mfrac", "f4"), + ] + fres.dtype.descr + + if return_k_info: + return (np.array(dfmdet_res, dtype=total_dtype), mcal_res.get("kinfo")) + + return np.array(dfmdet_res, dtype=total_dtype) diff --git a/deep_field_metadetect/jaxify/jax_utils.py b/deep_field_metadetect/jaxify/jax_utils.py new file mode 100644 index 0000000..c455ad6 --- /dev/null +++ b/deep_field_metadetect/jaxify/jax_utils.py @@ -0,0 +1,21 @@ +import jax.numpy as jnp + + +def compute_stepk(pixel_scale, image_size): + """Compute psf fourier scale based on pixel scale and PSF image dimension. + The size is obtained from galsim.GSObject.getGoodImageSize. + The factor 1/4 from deep_field_metadetect.metacal.get_gauss_reconv_psf_galsim. + + Parameters: + ----------- + pixel_scale : float + The scale of a single pixel in the image. + image_size : int + The dimension of the PSF image (typically a square size). + + Returns: + -------- + float + The computed stepk value, which represents the Fourier-space sampling frequency. + """ + return 2 * jnp.pi / (image_size * pixel_scale) / 4 diff --git a/deep_field_metadetect/jaxify/observation.py b/deep_field_metadetect/jaxify/observation.py new file mode 100644 index 0000000..e4e783e --- /dev/null +++ b/deep_field_metadetect/jaxify/observation.py @@ -0,0 +1,352 @@ +import jax +import jax.numpy as jnp +import jax_galsim +import ngmix +import numpy as np +from ngmix.observation import Observation + + +@jax.tree_util.register_pytree_node_class +class DFMdetPSF: + def __init__( + self, + image, + weight=None, + wcs=None, + meta=None, + store_pixels=True, + ignore_zero_weight=True, + ): + if meta is None: + meta = {} + + if wcs is None: + wcs = jax_galsim.wcs.AffineTransform( + dudx=1.0, + dudy=0.0, + dvdx=0.0, + dvdy=1.0, + origin=jax_galsim.PositionD( + y=(image.shape[0] + 1) / 2, + x=(image.shape[1] + 1) / 2, + ), + ) + + self.image = image + if weight is None: + weight = jnp.ones_like(self.image, dtype=jnp.float32) + self.weight = weight + self.wcs = wcs + self.meta = meta + self.store_pixels = store_pixels + self.ignore_zero_weight = ignore_zero_weight + + def tree_flatten(self): + children = (self.image, self.weight) + aux_data = (self.wcs, self.meta, self.store_pixels, self.ignore_zero_weight) + return children, aux_data + + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls( + image=children[0], + weight=children[1], + wcs=aux_data[0], + meta=aux_data[1], + store_pixels=aux_data[2], + ignore_zero_weight=aux_data[3], + ) + + def has_bmask(self): + return False + + def has_mfrac(self): + return False + + def has_noise(self): + return False + + def has_ormask(self): + return False + + def has_psf(self): + return False + + @jax.jit + def replace(self, **kwargs): + """Create a new instance similar to NamedTuple._replace""" + new_kwargs = { + "image": self.image, + "wcs": self.wcs, + "meta": self.meta, + "store_pixels": self.store_pixels, + "ignore_zero_weight": self.ignore_zero_weight, + } + new_kwargs.update(kwargs) + return DFMdetPSF(**new_kwargs) + + +@jax.tree_util.register_pytree_node_class +class DFMdetObservation: + def __init__( + self, + image, + weight=None, + bmask=None, + ormask=None, + noise=None, + wcs=None, + psf=None, + mfrac=None, + meta=None, + store_pixels=True, + ignore_zero_weight=True, + ): + image = image + if weight is None: + weight = jnp.ones_like(image, dtype=jnp.float32) + if bmask is None: + bmask = jnp.zeros_like(image, dtype=jnp.int32) + if ormask is None: + ormask = jnp.zeros_like(image, dtype=jnp.int32) + if noise is None: + noise = jnp.zeros_like(image, dtype=jnp.float32) + if mfrac is None: + mfrac = jnp.zeros_like(image, dtype=jnp.float32) + if meta is None: + meta = {} + + if psf is None: + psf = DFMdetPSF(image=jnp.zeros_like(image, dtype=jnp.float32)) + + if wcs is None: + wcs = jax_galsim.wcs.AffineTransform( + dudx=1.0, + dudy=0.0, + dvdx=0.0, + dvdy=1.0, + origin=jax_galsim.PositionD( + y=(image.shape[0] + 1) / 2, + x=(image.shape[1] + 1) / 2, + ), + ) + self.image = image + self.weight = weight + self.bmask = bmask + self.ormask = ormask + self.noise = noise + self.wcs = wcs + self.psf = psf + self.mfrac = mfrac + self.meta = meta + self.store_pixels = store_pixels + self.ignore_zero_weight = ignore_zero_weight + + def tree_flatten(self): + children = ( + self.image, + self.weight, + self.bmask, + self.ormask, + self.noise, + self.wcs, + self.psf, + self.mfrac, + ) + + aux_data = (self.meta, self.store_pixels, self.ignore_zero_weight) + + return children, aux_data + + @classmethod + def tree_unflatten(cls, aux_data, children): + # Reconstruct the object from flattened data + return cls( + image=children[0], + weight=children[1], + bmask=children[2], + ormask=children[3], + noise=children[4], + wcs=children[5], + psf=children[6], + mfrac=children[7], + meta=aux_data[0], + store_pixels=aux_data[1], + ignore_zero_weight=aux_data[2], + ) + + def has_bmask(self): + return True + + def has_mfrac(self): + return True + + def has_noise(self): + return True + + def has_ormask(self): + return True + + def has_psf(self): + return jnp.any(self.psf.image != 0) + + @jax.jit + def replace(self, **kwargs): + """Create a new instance similar to NamedTuple._replace""" + new_kwargs = { + "image": self.image, + "weight": self.weight, + "bmask": self.bmask, + "ormask": self.ormask, + "noise": self.noise, + "wcs": self.wcs, + "psf": self.psf, + "mfrac": self.mfrac, + "meta": self.meta, + "store_pixels": self.store_pixels, + "ignore_zero_weight": self.ignore_zero_weight, + } + new_kwargs.update(kwargs) + return DFMdetObservation(**new_kwargs) + + +def ngmix_obs_to_dfmd_obs(obs: ngmix.observation.Observation) -> DFMdetObservation: + """Convert an ngmix Observation to a DFMdetObservation. + Note that unlike the non-jax version, PSF is no longer an instance of + observation and default values of bmask, ormask, mfrac are arrays of zeros. + + Parameters + ---------- + obs: ngmix.observation.Observation + The ngmix observation object to convert. + + Returns + ------- + DFMdetObservation + The converted DFMdetObservation with JAX arrays. + """ + jacobian = obs.get_jacobian() + + psf = None + if obs.has_psf(): + psf_obs = obs.get_psf() + psf_jacobian = psf_obs.get_jacobian() + psf = DFMdetPSF( + image=psf_obs.image, + wcs=jax_galsim.wcs.AffineTransform( + dudx=psf_jacobian.dudcol, + dudy=psf_jacobian.dudrow, + dvdx=psf_jacobian.dvdcol, + dvdy=psf_jacobian.dvdrow, + origin=jax_galsim.PositionD( + y=psf_jacobian.row0 + 1, + x=psf_jacobian.col0 + 1, + ), + ), + meta=psf_obs.meta, + store_pixels=getattr(psf_obs, "store_pixels", True), + ignore_zero_weight=getattr(psf_obs, "ignore_zero_weight", True), + ) + + return DFMdetObservation( + image=obs.image, + weight=obs.weight, + bmask=obs.bmask if obs.has_bmask() else None, + ormask=obs.ormask if obs.has_ormask() else None, + noise=obs.noise if obs.has_noise() else None, + wcs=jax_galsim.wcs.AffineTransform( + dudx=jacobian.dudcol, + dudy=jacobian.dudrow, + dvdx=jacobian.dvdcol, + dvdy=jacobian.dvdrow, + origin=jax_galsim.PositionD( + y=jacobian.row0 + 1, + x=jacobian.col0 + 1, + ), + ), + psf=psf, + meta=obs.meta, + mfrac=obs.mfrac if obs.has_mfrac() else None, + store_pixels=getattr(obs, "store_pixels", True), + ignore_zero_weight=getattr(obs, "ignore_zero_weight", True), + ) + + +def dfmd_psf_to_ngmix_obs(dfmd_psf: DFMdetPSF) -> Observation: + """Convert a DFMdetPSF to an ngmix Observation. + + Parameters + ---------- + dfmd_psf: DFMdetPSF + The Deep Field Metadetect PSF object to convert. + + Returns + ------- + ngmix.observation.Observation + The converted ngmix observation representing the PSF. + """ + psf = Observation( + image=np.array(dfmd_psf.image), + jacobian=ngmix.jacobian.Jacobian( + row=dfmd_psf.wcs.origin.y - 1, + col=dfmd_psf.wcs.origin.x - 1, + dudcol=dfmd_psf.wcs.dudx, + dudrow=dfmd_psf.wcs.dudy, + dvdcol=dfmd_psf.wcs.dvdx, + dvdrow=dfmd_psf.wcs.dvdy, + ), + meta=dfmd_psf.meta, + store_pixels=np.array(dfmd_psf.store_pixels, dtype=np.bool_), + ignore_zero_weight=np.array(dfmd_psf.ignore_zero_weight, dtype=np.bool_), + ) + return psf + + +def dfmd_obs_to_ngmix_obs(dfmd_obs: DFMdetObservation) -> Observation: + """Convert a DFMdetObservation to an ngmix Observation. + + This function transforms a JAX-compatible DFMdetObservation object into + a standard ngmix Observation object, converting all JAX arrays to numpy + arrays and transforming the JAX-galsim WCS to an ngmix Jacobian. + Note: This function never passes None values for the following: + bmask, ormask, mfrac, instead sets default arrays of zeros. + + Parameters + ---------- + dfmd_obs: DFMdetObservation + The Deep Field Metadetect observation object to convert. + + Returns + ------- + ngmix.observation.Observation + The converted ngmix observation with numpy arrays and ngmix Jacobian. + """ + psf = None + if dfmd_obs.has_psf(): + psf = dfmd_psf_to_ngmix_obs(dfmd_obs.psf) + + bmask = np.array(dfmd_obs.bmask) + ormask = np.array(dfmd_obs.ormask) + noise = np.array(dfmd_obs.noise) if dfmd_obs.has_noise() else None + mfrac = np.array(dfmd_obs.mfrac) + + return Observation( + image=np.array(dfmd_obs.image), + weight=np.array(dfmd_obs.weight), + bmask=bmask, + ormask=ormask, + noise=noise, + jacobian=ngmix.jacobian.Jacobian( + row=dfmd_obs.wcs.origin.y - 1, + col=dfmd_obs.wcs.origin.x - 1, + dudcol=dfmd_obs.wcs.dudx, + dudrow=dfmd_obs.wcs.dudy, + dvdcol=dfmd_obs.wcs.dvdx, + dvdrow=dfmd_obs.wcs.dvdy, + ), + psf=psf, + mfrac=mfrac, + meta=dfmd_obs.meta, + store_pixels=np.array(dfmd_obs.store_pixels, dtype=np.bool_), + ignore_zero_weight=np.array(dfmd_obs.ignore_zero_weight, dtype=np.bool_), + ) diff --git a/deep_field_metadetect/jaxify/tests/test_jax_deep_metacal.py b/deep_field_metadetect/jaxify/tests/test_jax_deep_metacal.py new file mode 100644 index 0000000..ab1df67 --- /dev/null +++ b/deep_field_metadetect/jaxify/tests/test_jax_deep_metacal.py @@ -0,0 +1,494 @@ +import multiprocessing + +import numpy as np +import pytest + +from deep_field_metadetect.jaxify.jax_metacal import ( + jax_metacal_op_shears, + jax_metacal_wide_and_deep_psf_matched, +) +from deep_field_metadetect.jaxify.observation import ( + dfmd_obs_to_ngmix_obs, + ngmix_obs_to_dfmd_obs, +) +from deep_field_metadetect.metacal import metacal_wide_and_deep_psf_matched +from deep_field_metadetect.utils import ( + MAX_ABS_C, + MAX_ABS_M, + assert_m_c_ok, + estimate_m_and_c, + fit_gauss_mom_mcal_res, + make_simple_sim, + measure_mcal_shear_quants, + print_m_c, +) + + +def _run_single_sim( + seed, + s2n, + g1, + g2, + deep_noise_fac, + deep_psf_fac, + skip_wide, + skip_deep, +): + nxy = 53 + nxy_psf = 53 + scale = 0.2 + + obs_w, obs_d, obs_dn = make_simple_sim( + seed=seed, + g1=g1, + g2=g2, + s2n=s2n, + dim=nxy, + dim_psf=nxy_psf, + scale=scale, + deep_noise_fac=deep_noise_fac, + deep_psf_fac=deep_psf_fac, + return_dfmd_obs=True, + ) + mcal_res = jax_metacal_wide_and_deep_psf_matched( + obs_w, + obs_d, + obs_dn, + nxy=53, + nxy_psf=53, + skip_obs_wide_corrections=skip_wide, + skip_obs_deep_corrections=skip_deep, + scale=scale, + ) + res = fit_gauss_mom_mcal_res(mcal_res) + return measure_mcal_shear_quants(res) + + +def _run_single_sim_jax_and_ngmix( + seed, + s2n, + g1, + g2, + deep_noise_fac, + deep_psf_fac, + skip_wide, + skip_deep, +): + nxy = 53 + nxy_psf = 53 + scale = 0.2 + + obs_w_ngmix, obs_d_ngmix, obs_dn_ngmix = make_simple_sim( + seed=seed, + g1=g1, + g2=g2, + s2n=s2n, + dim=nxy, + dim_psf=nxy_psf, + scale=scale, + deep_noise_fac=deep_noise_fac, + deep_psf_fac=deep_psf_fac, + return_dfmd_obs=False, + ) + mcal_res_ngmix = metacal_wide_and_deep_psf_matched( + obs_w_ngmix, + obs_d_ngmix, + obs_dn_ngmix, + skip_obs_wide_corrections=skip_wide, + skip_obs_deep_corrections=skip_deep, + ) + res_ngmix = fit_gauss_mom_mcal_res(mcal_res_ngmix) + + obs_w = ngmix_obs_to_dfmd_obs(obs_w_ngmix) + obs_d = ngmix_obs_to_dfmd_obs(obs_d_ngmix) + obs_dn = ngmix_obs_to_dfmd_obs(obs_dn_ngmix) + + mcal_res = jax_metacal_wide_and_deep_psf_matched( + obs_w, + obs_d, + obs_dn, + nxy=53, + nxy_psf=53, + skip_obs_wide_corrections=skip_wide, + skip_obs_deep_corrections=skip_deep, + scale=scale, + ) + res = fit_gauss_mom_mcal_res(mcal_res) + return measure_mcal_shear_quants(res), measure_mcal_shear_quants(res_ngmix) + + +def _run_sim_pair(seed, s2n, deep_noise_fac, deep_psf_fac, skip_wide, skip_deep): + res_p = _run_single_sim( + seed, + s2n, + 0.02, + 0.0, + deep_noise_fac, + deep_psf_fac, + skip_wide, + skip_deep, + ) + + res_m = _run_single_sim( + seed, + s2n, + -0.02, + 0.0, + deep_noise_fac, + deep_psf_fac, + skip_wide, + skip_deep, + ) + + return res_p, res_m + + +def _run_sim_pair_jax_and_ngmix( + seed, s2n, deep_noise_fac, deep_psf_fac, skip_wide, skip_deep +): + res_p, res_p_ngmix = _run_single_sim_jax_and_ngmix( + seed, + s2n, + 0.02, + 0.0, + deep_noise_fac, + deep_psf_fac, + skip_wide, + skip_deep, + ) + + res_m, res_m_ngmix = _run_single_sim_jax_and_ngmix( + seed, + s2n, + -0.02, + 0.0, + deep_noise_fac, + deep_psf_fac, + skip_wide, + skip_deep, + ) + + return (res_p, res_m), (res_p_ngmix, res_m_ngmix) + + +def test_deep_metacal_smoke(): + res_p, res_m = _run_sim_pair(1234, 1e8, 1.0 / np.sqrt(10), 1, False, False) + for col in res_p.dtype.names: + assert np.isfinite(res_p[col]).all() + assert np.isfinite(res_m[col]).all() + + +@pytest.mark.parametrize("deep_psf_ratio", [0.8, 1.2]) +def test_jax_vs_ngmix_comparison(deep_psf_ratio): + nsims = 5 + noise_fac = 1 / np.sqrt(10) + + rng = np.random.RandomState(seed=34132) + seeds = rng.randint(size=nsims, low=1, high=2**29) + + res_p = [] + res_m = [] + res_p_ngmix = [] + res_m_ngmix = [] + for seed in seeds: + res, res_ngmix = _run_sim_pair_jax_and_ngmix( + seed, 1e8, noise_fac, deep_psf_ratio, False, False + ) + if res is not None: + res_p.append(res[0]) + res_m.append(res[1]) + res_p_ngmix.append(res_ngmix[0]) + res_m_ngmix.append(res_ngmix[1]) + + assert np.allclose( + res[0].tolist(), + res_ngmix[0].tolist(), + atol=5e-4, + rtol=0.025, + equal_nan=True, + ) + assert np.allclose( + res[1].tolist(), + res_ngmix[1].tolist(), + atol=1e-5, + rtol=0.025, + equal_nan=True, + ) + + m, merr, c1, c1err, c2, c2err = estimate_m_and_c( + np.concatenate(res_p), + np.concatenate(res_m), + 0.02, + jackknife=len(res_p), + ) + + m_ng, merr_ng, c1_ng, c1err_ng, c2_ng, c2err_ng = estimate_m_and_c( + np.concatenate(res_p_ngmix), + np.concatenate(res_m_ngmix), + 0.02, + jackknife=len(res_p_ngmix), + ) + + assert np.allclose(m, m_ng, atol=1e-4) + assert np.allclose(merr, merr_ng, atol=1e-5) + assert np.allclose(c1, c1_ng, atol=1e-4) + assert np.allclose(c1err, c1err_ng, atol=1e-5) + assert np.allclose(c2, c2_ng, atol=1e-4) + assert np.allclose(c2err, c2err_ng, atol=1e-5) + + print("JAX results:") + print_m_c(m, merr, c1, c1err, c2, c2err) + print("ngmix results:") + print_m_c(m_ng, merr_ng, c1_ng, c1err_ng, c2_ng, c2err_ng) + assert_m_c_ok(m, merr, c1, c1err, c2, c2err) + + +@pytest.mark.parametrize("deep_psf_ratio", [0.8, 1, 1.2]) +def test_deep_metacal(deep_psf_ratio): + nsims = 50 + noise_fac = 1 / np.sqrt(10) + + rng = np.random.RandomState(seed=34132) + seeds = rng.randint(size=nsims, low=1, high=2**29) + + res_p = [] + res_m = [] + for seed in seeds: + res = _run_sim_pair(seed, 1e8, noise_fac, deep_psf_ratio, False, False) + if res is not None: + res_p.append(res[0]) + res_m.append(res[1]) + + m, merr, c1, c1err, c2, c2err = estimate_m_and_c( + np.concatenate(res_p), + np.concatenate(res_m), + 0.02, + jackknife=len(res_p), + ) + + print_m_c(m, merr, c1, c1err, c2, c2err) + assert_m_c_ok(m, merr, c1, c1err, c2, c2err) + + +@pytest.mark.slow +def test_deep_metacal_widelows2n(): + nsims = 500 + noise_fac = 1 / np.sqrt(1000) + + rng = np.random.RandomState(seed=34132) + seeds = rng.randint(size=nsims, low=1, high=2**29) + + res_p = [] + res_m = [] + for seed in seeds: + res = _run_sim_pair(seed, 20, noise_fac, 1, False, False) + if res is not None: + res_p.append(res[0]) + res_m.append(res[1]) + + m, merr, c1, c1err, c2, c2err = estimate_m_and_c( + np.concatenate(res_p), + np.concatenate(res_m), + 0.02, + jackknife=len(res_p), + ) + + print_m_c(m, merr, c1, c1err, c2, c2err) + assert_m_c_ok(m, merr, c1, c1err, c2, c2err) + + +@pytest.mark.slow +@pytest.mark.parametrize( + "skip_wide,skip_deep", [(True, True), (True, False), (False, True), (False, False)] +) +def test_deep_metacal_slow(skip_wide, skip_deep): # pragma: no cover + if not skip_wide and not skip_deep: + nsims = 100_000 + s2n = 20 + else: + nsims = 100_000 + s2n = 10 + chunk_size = multiprocessing.cpu_count() * 100 + nchunks = nsims // chunk_size + 1 + noise_fac = 1 / np.sqrt(10) + nsims = nchunks * chunk_size + + rng = np.random.RandomState(seed=34132) + seeds = rng.randint(size=nsims, low=1, high=2**29) + res_p = [] + res_m = [] + loc = 0 + for chunk in range(nchunks): + _seeds = seeds[loc : loc + chunk_size] + for seed in _seeds: + res = _run_sim_pair(seed, s2n, noise_fac, 0.8, skip_wide, skip_deep) + if res is not None: + res_p.append(res[0]) + res_m.append(res[1]) + + if len(res_p) < 500: + njack = len(res_p) + else: + njack = 100 + + m, merr, c1, c1err, c2, c2err = estimate_m_and_c( + np.concatenate(res_p), + np.concatenate(res_m), + 0.02, + jackknife=njack, + ) + + print("# of sims:", len(res_p), flush=True) + print_m_c(m, merr, c1, c1err, c2, c2err) + + if not skip_wide and not skip_deep: + assert np.abs(m) < max(MAX_ABS_M, 3 * merr), (m, merr) + elif 3 * merr < 5e-3: + assert np.abs(m) >= max(MAX_ABS_M, 3 * merr), (m, merr) + # if we are more than 10 sigma biased, then the test + # has passed for sure + if np.abs(m) / max(MAX_ABS_M / 3, merr) >= 10: + break + assert np.abs(c1) < max(4.0 * c1err, MAX_ABS_C), (c1, c1err) + assert np.abs(c2) < max(4.0 * c2err, MAX_ABS_C), (c2, c2err) + + loc += chunk_size + + print_m_c(m, merr, c1, c1err, c2, c2err) + if not skip_wide and not skip_deep: + assert np.abs(m) < max(MAX_ABS_M, 3 * merr), (m, merr) + else: + assert np.abs(m) >= max(MAX_ABS_M, 3 * merr), (m, merr) + assert np.abs(c1) < max(4.0 * c1err, MAX_ABS_C), (c1, c1err) + assert np.abs(c2) < max(4.0 * c2err, MAX_ABS_C), (c2, c2err) + + +def _run_single_sim_maybe_mcal( + seed, + s2n, + g1, + g2, + deep_noise_fac, + deep_psf_fac, + use_mcal, + zero_flux, +): + nxy = 53 + nxy_psf = 53 + scale = 0.2 + obs_w, obs_d, obs_dn = make_simple_sim( + seed=seed, + g1=g1, + g2=g2, + s2n=s2n, + dim=nxy, + dim_psf=nxy_psf, + scale=scale, + deep_noise_fac=deep_noise_fac, + deep_psf_fac=deep_psf_fac, + obj_flux_factor=0.0 if zero_flux else 1.0, + return_dfmd_obs=True, + ) + if use_mcal: + mcal_res = jax_metacal_op_shears( + obs_w, + scale=scale, + ) + for key, value in mcal_res.items(): + mcal_res[key] = dfmd_obs_to_ngmix_obs(value) + else: + mcal_res, _ = jax_metacal_wide_and_deep_psf_matched( + obs_w, + obs_d, + obs_dn, + nxy=nxy, + nxy_psf=nxy_psf, + scale=scale, + ) + return fit_gauss_mom_mcal_res(mcal_res), mcal_res + + +@pytest.mark.slow +def test_deep_metacal_noise_object_s2n(): + nsims = 100 + noise_fac = 1 / np.sqrt(10) + s2n = 10 + + rng = np.random.RandomState(seed=34132) + seeds = rng.randint(size=nsims, low=1, high=2**29) + + dmcal_res = [] + mcal_res = [] + for seed in seeds: + dmcal_res.append( + _run_single_sim_maybe_mcal( + seed, + s2n, + 0.02, + 0.0, + noise_fac, + 1.0, + False, + False, + ) + ) + mcal_res.append( + _run_single_sim_maybe_mcal( + seed, + s2n, + 0.02, + 0.0, + noise_fac, + 1.0, + True, + False, + ) + ) + + dmcal_res = np.concatenate([d[0] for d in dmcal_res if d is not None], axis=0) + mcal_res = np.concatenate([d[0] for d in mcal_res if d is not None], axis=0) + dmcal_res = dmcal_res[dmcal_res["mdet_step"] == "noshear"] + mcal_res = mcal_res[mcal_res["mdet_step"] == "noshear"] + + ratio = (np.median(dmcal_res["wmom_s2n"]) / np.median(mcal_res["wmom_s2n"])) ** 2 + print("s2n ratio squared:", ratio) + assert np.allclose(ratio, 2, atol=0, rtol=0.2), ratio + + dmcal_res = [] + mcal_res = [] + for seed in seeds: + dmcal_res.append( + _run_single_sim_maybe_mcal( + seed, + s2n, + 0.02, + 0.0, + noise_fac, + 1.0, + False, + True, + ) + ) + mcal_res.append( + _run_single_sim_maybe_mcal( + seed, + s2n, + 0.02, + 0.0, + noise_fac, + 1.0, + True, + True, + ) + ) + + dmcal_res = np.array( + [np.std(d[1]["noshear"].image) for d in dmcal_res if d is not None] + ) + mcal_res = np.array( + [np.std(d[1]["noshear"].image) for d in mcal_res if d is not None] + ) + + ratio = (np.median(dmcal_res) / np.median(mcal_res)) ** 2 + print("noise ratio squared:", ratio) + assert np.allclose(ratio, 0.5, atol=0, rtol=0.2), ratio diff --git a/deep_field_metadetect/jaxify/tests/test_jax_metacal.py b/deep_field_metadetect/jaxify/tests/test_jax_metacal.py new file mode 100644 index 0000000..2f8cbbb --- /dev/null +++ b/deep_field_metadetect/jaxify/tests/test_jax_metacal.py @@ -0,0 +1,615 @@ +import multiprocessing + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from deep_field_metadetect.jaxify.jax_metacal import ( + jax_add_dfmd_obs, + jax_metacal_op_shears, +) +from deep_field_metadetect.jaxify.observation import ( + DFMdetObservation, + DFMdetPSF, + dfmd_obs_to_ngmix_obs, + ngmix_obs_to_dfmd_obs, +) +from deep_field_metadetect.metacal import add_ngmix_obs, metacal_op_shears +from deep_field_metadetect.utils import ( + assert_m_c_ok, + estimate_m_and_c, + fit_gauss_mom_mcal_res, + make_simple_sim, + measure_mcal_shear_quants, + print_m_c, +) + + +def _run_single_sim_pair(seed, s2n): + nxy = 53 + nxy_psf = 53 + scale = 0.2 + obs_plus, *_ = make_simple_sim( + seed=seed, + g1=0.02, + g2=0.0, + s2n=s2n, + dim=nxy, + dim_psf=nxy_psf, + scale=scale, + deep_noise_fac=1.0 / np.sqrt(10), + deep_psf_fac=1.0, + return_dfmd_obs=True, + ) + mcal_res = jax_metacal_op_shears( + obs_plus, + nxy_psf=nxy_psf, + scale=scale, + ) + res_p = fit_gauss_mom_mcal_res(mcal_res) + res_p = measure_mcal_shear_quants(res_p) + + obs_minus, *_ = make_simple_sim( + seed=seed, + g1=-0.02, + g2=0.0, + s2n=s2n, + dim=nxy, + dim_psf=nxy_psf, + scale=scale, + deep_noise_fac=1.0 / np.sqrt(10), + deep_psf_fac=1.0, + return_dfmd_obs=True, + ) + mcal_res = jax_metacal_op_shears( + obs_minus, + nxy_psf=nxy_psf, + scale=scale, + ) + res_m = fit_gauss_mom_mcal_res(mcal_res) + res_m = measure_mcal_shear_quants(res_m) + + return res_p, res_m + + +def _run_single_sim_pair_jax_and_ngmix(seed, s2n): + nxy = 53 + nxy_psf = 53 + scale = 0.2 + obs_plus, *_ = make_simple_sim( + seed=seed, + g1=0.02, + g2=0.0, + s2n=s2n, + dim=nxy, + dim_psf=nxy_psf, + scale=scale, + deep_noise_fac=1.0 / np.sqrt(10), + deep_psf_fac=1.0, + return_dfmd_obs=False, + ) + + mcal_res_ngmix = metacal_op_shears(obs_plus) + + res_p_ngmix = fit_gauss_mom_mcal_res(mcal_res_ngmix) + res_p_ngmix = measure_mcal_shear_quants(res_p_ngmix) + + obs_plus = ngmix_obs_to_dfmd_obs(obs_plus) + + mcal_res = jax_metacal_op_shears( + obs_plus, + nxy_psf=nxy_psf, + scale=scale, + ) + res_p = fit_gauss_mom_mcal_res(mcal_res) + res_p = measure_mcal_shear_quants(res_p) + + obs_minus, *_ = make_simple_sim( + seed=seed, + g1=-0.02, + g2=0.0, + s2n=s2n, + dim=nxy, + dim_psf=nxy_psf, + scale=scale, + deep_noise_fac=1.0 / np.sqrt(10), + deep_psf_fac=1.0, + return_dfmd_obs=False, + ) + + mcal_res_ngmix = metacal_op_shears(obs_minus) + res_m_ngmix = fit_gauss_mom_mcal_res(mcal_res_ngmix) + res_m_ngmix = measure_mcal_shear_quants(res_m_ngmix) + + obs_minus = ngmix_obs_to_dfmd_obs(obs_minus) + mcal_res = jax_metacal_op_shears( + obs_minus, + nxy_psf=nxy_psf, + scale=scale, + ) + res_m = fit_gauss_mom_mcal_res(mcal_res) + res_m = measure_mcal_shear_quants(res_m) + + return (res_p, res_m), (res_p_ngmix, res_m_ngmix) + + +def test_metacal_smoke(): + res_p, res_m = _run_single_sim_pair(1234, 1e8) + for col in res_p.dtype.names: + assert np.isfinite(res_p[col]).all() + assert np.isfinite(res_m[col]).all() + + +def test_metacal_jax_vs_ngmix(): + nsims = 5 + + rng = np.random.RandomState(seed=34132) + seeds = rng.randint(size=nsims, low=1, high=2**29) + res_p = [] + res_m = [] + res_p_ngmix = [] + res_m_ngmix = [] + for seed in seeds: + res, res_ngmix = _run_single_sim_pair_jax_and_ngmix(seed, 1e8) + if res is not None: + res_p.append(res[0]) + res_m.append(res[1]) + + res_p_ngmix.append(res_ngmix[0]) + res_m_ngmix.append(res_ngmix[1]) + + assert np.allclose( + res[0].tolist(), + res_ngmix[0].tolist(), + atol=1e-3, + rtol=0.01, + equal_nan=True, + ) + assert np.allclose( + res[1].tolist(), + res_ngmix[1].tolist(), + atol=1e-3, + rtol=0.01, + equal_nan=True, + ) + + m, merr, c1, c1err, c2, c2err = estimate_m_and_c( + np.concatenate(res_p), + np.concatenate(res_m), + 0.02, + jackknife=len(res_p), + ) + + m_ng, merr_ng, c1_ng, c1err_ng, c2_ng, c2err_ng = estimate_m_and_c( + np.concatenate(res_p_ngmix), + np.concatenate(res_m_ngmix), + 0.02, + jackknife=len(res_p_ngmix), + ) + + print("JAX results:") + print_m_c(m, merr, c1, c1err, c2, c2err) + print("ngmix results:") + print_m_c(m_ng, merr_ng, c1_ng, c1err_ng, c2_ng, c2err_ng) + assert_m_c_ok(m, merr, c1, c1err, c2, c2err) + + assert np.allclose(m, m_ng, atol=1e-4) + assert np.allclose(merr, merr_ng, atol=1e-4) + assert np.allclose(c1err, c1err_ng, atol=1e-6) + assert np.allclose(c1, c1_ng, atol=1e-6) + assert np.allclose(c2err, c2err_ng, atol=1e-6) + assert np.allclose(c2, c2_ng, atol=1e-6) + + +def test_metacal(): + nsims = 50 + + rng = np.random.RandomState(seed=34132) + seeds = rng.randint(size=nsims, low=1, high=2**29) + res_p = [] + res_m = [] + for seed in seeds: + res = _run_single_sim_pair(seed, 1e8) + if res is not None: + res_p.append(res[0]) + res_m.append(res[1]) + + m, merr, c1, c1err, c2, c2err = estimate_m_and_c( + np.concatenate(res_p), + np.concatenate(res_m), + 0.02, + jackknife=len(res_p), + ) + + print_m_c(m, merr, c1, c1err, c2, c2err) + assert_m_c_ok(m, merr, c1, c1err, c2, c2err) + + +@pytest.mark.slow +def test_metacal_slow(): # pragma: no cover + nsims = 100_000 + chunk_size = multiprocessing.cpu_count() * 100 + nchunks = nsims // chunk_size + 1 + nsims = nchunks * chunk_size + + rng = np.random.RandomState(seed=34132) + seeds = rng.randint(size=nsims, low=1, high=2**29) + res_p = [] + res_m = [] + loc = 0 + for chunk in range(nchunks): + _seeds = seeds[loc : loc + chunk_size] + for seed in _seeds: + res = _run_single_sim_pair(seed, 20) + if res is not None: + res_p.append(res[0]) + res_m.append(res[1]) + + if len(res_p) < 500: + njack = len(res_p) + else: + njack = 100 + + m, merr, c1, c1err, c2, c2err = estimate_m_and_c( + np.concatenate(res_p), + np.concatenate(res_m), + 0.02, + jackknife=njack, + ) + + print("# of sims:", len(res_p), flush=True) + print_m_c(m, merr, c1, c1err, c2, c2err) + + loc += chunk_size + + print_m_c(m, merr, c1, c1err, c2, c2err) + assert_m_c_ok(m, merr, c1, c1err, c2, c2err) + + +def test_jax_vs_ngmix_render_psf_and_build_obs(): + """Test _jax_render_psf_and_build_obs vs render_psf_and_build_obs""" + import galsim + import jax_galsim + + from deep_field_metadetect.jaxify.jax_metacal import ( + _jax_render_psf_and_build_obs, + jax_get_gauss_reconv_psf_galsim, + ) + from deep_field_metadetect.jaxify.jax_utils import compute_stepk + from deep_field_metadetect.jaxify.observation import ngmix_obs_to_dfmd_obs + from deep_field_metadetect.metacal import ( + _render_psf_and_build_obs, + get_gauss_reconv_psf_galsim, + ) + from deep_field_metadetect.utils import make_simple_sim + + # Create test observations + nxy = 53 + nxy_psf = 21 + scale = 0.2 + + ngmix_obs, _, _ = make_simple_sim( + seed=12345, + g1=0.0, + g2=0.0, + s2n=1e8, + dim=nxy, + dim_psf=nxy_psf, + scale=scale, + return_dfmd_obs=False, + ) + + # Convert to dfmd observation + dfmd_obs = ngmix_obs_to_dfmd_obs(ngmix_obs) + + # Create test reconv PSFs + test_image = jnp.ones((nxy, nxy)) + + # JAX version + jax_psf = jax_galsim.Gaussian(sigma=1.0).withFlux(1.0) + dk = compute_stepk(pixel_scale=scale, image_size=nxy_psf) + jax_reconv_psf = jax_get_gauss_reconv_psf_galsim( + jax_psf, dk=dk, nxy_psf=nxy_psf, kim_size=256 + ) + jax_result = _jax_render_psf_and_build_obs( + test_image, dfmd_obs, jax_reconv_psf, nxy_psf=nxy_psf, weight_fac=1 + ) + + # ngmix version + ngmix_psf = galsim.Gaussian(sigma=1.0).withFlux(1.0) + ngmix_reconv_psf = get_gauss_reconv_psf_galsim(ngmix_psf, dk=dk, kim_size=256) + + ngmix_result = _render_psf_and_build_obs( + test_image, ngmix_obs, ngmix_reconv_psf, weight_fac=1 + ) + + assert jnp.isclose(ngmix_reconv_psf.sigma, jax_reconv_psf.sigma), ( + "reconv psf sigmas are different" + ) + # Check if shapes match + assert jax_result.psf.image.shape == ngmix_result.psf.image.shape, ( + f"PSF shapes don't match: JAX {jax_result.psf.image.shape} " + f"vs ngmix {ngmix_result.psf.image.shape}" + ) + + # Compare PSF images with some tolerance + diff = jnp.abs(jax_result.psf.image - ngmix_result.psf.image) + max_diff = jnp.max(diff) + rel_diff = max_diff / jnp.max(jax_result.psf.image) + + print(f"Max absolute difference: {max_diff}") + print(f"Max relative difference: {rel_diff}") + + # Test that PSF images are reasonably close + assert jnp.allclose( + jax_result.psf.image, ngmix_result.psf.image, atol=1e-10, rtol=1e-6 + ), f"PSF images differ significantly. Max diff: {max_diff}, Rel diff: {rel_diff}" + + +def _create_test_dfmd_obs( + has_bmask=False, + has_ormask=False, + has_noise=False, + has_mfrac=False, + has_psf=True, + has_wcs=True, + seed=42, +): + """Create a test DFMdetObservation with specified attributes. + If has_* is set to False, the defaults are used.""" + import jax_galsim + + key = jax.random.PRNGKey(seed) + + image = jnp.ones((10, 10)) + weight = jnp.ones((10, 10)) + + wcs = jax_galsim.wcs.AffineTransform( + dudx=0.2, dudy=0.01, dvdx=0.02, dvdy=0.3, origin=jax_galsim.PositionD(5.5, 4.5) + ) + + obs = DFMdetObservation( + image=image, + weight=weight, + bmask=jax.random.randint(key, (10, 10), 0, 2, dtype=jnp.int32) + if has_bmask + else None, + ormask=jax.random.randint( + jax.random.split(key)[0], (10, 10), 0, 2, dtype=jnp.int32 + ) + if has_ormask + else None, + noise=jax.random.uniform( + jax.random.split(key)[1], (10, 10), minval=0.05, maxval=0.15 + ) + if has_noise + else None, + mfrac=jax.random.uniform( + jax.random.split(jax.random.split(key)[0])[1], + (10, 10), + minval=0.01, + maxval=0.5, + ) + if has_mfrac + else None, + wcs=wcs if has_wcs else None, + psf=DFMdetPSF( + image=jax.random.uniform( + jax.random.split(jax.random.split(jax.random.split(key)[0])[0])[1], + (5, 5), + minval=0.8, + maxval=1.2, + ).astype(jnp.float32), + wcs=wcs, + ) + if has_psf + else None, + ) + return obs + + +def test_jax_add_dfmd_obs_vs_add_dfmd_obs(): + """Test that jax_add_dfmd_obs and add_dfmd_obs + They should return the same values for image, psf, weight, wcs.""" + + obs1 = _create_test_dfmd_obs( + has_bmask=True, + has_ormask=True, + has_noise=True, + has_mfrac=True, + has_psf=True, + has_wcs=False, + seed=11, + ) + obs2 = _create_test_dfmd_obs( + has_bmask=True, + has_ormask=True, + has_noise=True, + has_mfrac=True, + has_psf=True, + has_wcs=False, + seed=13, + ) + ngmix_obs1 = dfmd_obs_to_ngmix_obs(obs1) + ngmix_obs2 = dfmd_obs_to_ngmix_obs(obs2) + + jax_result = jax_add_dfmd_obs( + obs1, obs2, ignore_psf=False, skip_mfrac_for_second=False + ) + non_jax_result = add_ngmix_obs( + ngmix_obs1, ngmix_obs2, ignore_psf=False, skip_mfrac_for_second=False + ) + + assert jnp.allclose(jax_result.image, non_jax_result.image, atol=1e-10), ( + "Images do not match" + ) + assert jnp.allclose(jax_result.weight, non_jax_result.weight, atol=1e-10), ( + "Weights do not match" + ) + + assert jax_result.wcs.dudx == non_jax_result.jacobian.dudcol, ( + "wcs dudx does not match" + ) + assert jax_result.wcs.dudy == non_jax_result.jacobian.dudrow, ( + "wcs dudy does not match" + ) + assert jax_result.wcs.dvdx == non_jax_result.jacobian.dvdcol, ( + "wcs dvdx does not match" + ) + assert jax_result.wcs.dvdy == non_jax_result.jacobian.dvdrow, ( + "wcs dvdy does not match" + ) + assert jax_result.wcs.origin.x == non_jax_result.jacobian.col0 + 1, ( + "wcs does not match" + ) + assert jax_result.wcs.origin.y == non_jax_result.jacobian.row0 + 1, ( + "wcs does not match" + ) + + # Compare PSF if both have PSF + if jax_result.has_psf() and non_jax_result.has_psf(): + assert jnp.allclose( + jax_result.psf.image, non_jax_result.psf.image, atol=1e-10 + ), "PSF images do not match" + assert jnp.allclose( + jax_result.psf.weight, non_jax_result.psf.weight, atol=1e-10 + ), "PSF weights do not match" + assert jax_result.psf.wcs.dudx == non_jax_result.psf.jacobian.dudcol, ( + "PSF wcs does not match" + ) + assert jax_result.psf.wcs.dudy == non_jax_result.psf.jacobian.dudrow, ( + "PSF wcs does not match" + ) + assert jax_result.psf.wcs.dvdx == non_jax_result.psf.jacobian.dvdcol, ( + "PSF wcs does not match" + ) + assert jax_result.psf.wcs.dvdy == non_jax_result.psf.jacobian.dvdrow, ( + "PSF wcs does not match" + ) + + assert jax_result.psf.wcs.origin.x == non_jax_result.psf.jacobian.col0 + 1, ( + "PSF wcs does not match" + ) + assert jax_result.psf.wcs.origin.y == non_jax_result.psf.jacobian.row0 + 1, ( + "PSF wcs does not match" + ) + + if jax_result.has_bmask(): + assert jnp.allclose(jax_result.bmask, non_jax_result.bmask), ( + "bmasks do not match" + ) + + if jax_result.has_ormask(): + assert jnp.allclose(jax_result.ormask, non_jax_result.ormask), ( + "ormasks do not match" + ) + + if jax_result.has_noise(): + assert jnp.allclose(jax_result.noise, non_jax_result.noise, atol=1e-10), ( + "noise do not match" + ) + + if jax_result.has_mfrac(): + assert jnp.allclose(jax_result.mfrac, non_jax_result.mfrac, atol=1e-10), ( + "mfrac do not match" + ) + + +def test_jax_add_dfmd_obs_vs_add_dfmd_obs_ignore_psf(): + """Test that both functions work correctly when ignoring PSF.""" + obs1 = _create_test_dfmd_obs( + has_bmask=True, + has_ormask=True, + has_noise=True, + has_mfrac=True, + has_psf=True, + seed=17, + ) + obs2 = _create_test_dfmd_obs( + has_bmask=True, + has_ormask=True, + has_noise=True, + has_mfrac=True, + has_psf=True, + seed=19, + ) + ngmix_obs1 = dfmd_obs_to_ngmix_obs(obs1) + ngmix_obs2 = dfmd_obs_to_ngmix_obs(obs2) + + jax_result = jax_add_dfmd_obs( + obs1, obs2, ignore_psf=True, skip_mfrac_for_second=False + ) + non_jax_result = add_ngmix_obs( + ngmix_obs1, ngmix_obs2, ignore_psf=True, skip_mfrac_for_second=False + ) + + assert jnp.allclose(jax_result.image, non_jax_result.image, atol=1e-10), ( + "Images do not match with ignore_psf=True" + ) + assert jnp.allclose(jax_result.weight, non_jax_result.weight, atol=1e-10), ( + "Weights do not match with ignore_psf=True" + ) + + assert jax_result.wcs.dudx == non_jax_result.jacobian.dudcol, ( + "wcs dudx does not match with ignore_psf=True" + ) + assert jax_result.wcs.dudy == non_jax_result.jacobian.dudrow, ( + "wcs dudy does not match with ignore_psf=True" + ) + assert jax_result.wcs.dvdx == non_jax_result.jacobian.dvdcol, ( + "wcs dvdx does not match with ignore_psf=True" + ) + assert jax_result.wcs.dvdy == non_jax_result.jacobian.dvdrow, ( + "wcs dvdy does not match with ignore_psf=True" + ) + assert jax_result.wcs.origin.x == non_jax_result.jacobian.col0 + 1, ( + "wcs origin.x does not match with ignore_psf=True" + ) + assert jax_result.wcs.origin.y == non_jax_result.jacobian.row0 + 1, ( + "wcs origin.y does not match with ignore_psf=True" + ) + + assert not jax_result.has_psf(), ( + "JAX result should not have PSF when ignore_psf=True" + ) + assert not non_jax_result.has_psf(), ( + "Non-JAX result should not have PSF when ignore_psf=True" + ) + + +def test_jax_add_dfmd_obs_vs_add_dfmd_obs_skip_mfrac(): + """Test that both functions handle skip_mfrac_for_second correctly.""" + obs1 = _create_test_dfmd_obs(has_mfrac=True, has_psf=True, seed=16) + obs2 = _create_test_dfmd_obs(has_mfrac=True, has_psf=True, seed=12) + ngmix_obs1 = dfmd_obs_to_ngmix_obs(obs1) + ngmix_obs2 = dfmd_obs_to_ngmix_obs(obs2) + + jax_result = jax_add_dfmd_obs( + obs1, obs2, ignore_psf=True, skip_mfrac_for_second=True + ) + non_jax_result = add_ngmix_obs( + ngmix_obs1, ngmix_obs2, ignore_psf=True, skip_mfrac_for_second=True + ) + + assert jnp.allclose(jax_result.mfrac, non_jax_result.mfrac, atol=1e-10), ( + "mfrac do not match with skip_mfrac_for_second=True" + ) + assert jnp.allclose(jax_result.mfrac, obs1.mfrac, atol=1e-10), ( + "mfrac should equal obs1.mfrac when skip_mfrac_for_second=True" + ) + + jax_result = jax_add_dfmd_obs( + obs1, obs2, ignore_psf=True, skip_mfrac_for_second=False + ) + non_jax_result = add_ngmix_obs( + ngmix_obs1, ngmix_obs2, ignore_psf=True, skip_mfrac_for_second=False + ) + + assert jnp.allclose(jax_result.mfrac, non_jax_result.mfrac, atol=1e-10), ( + "mfrac do not match with skip_mfrac_for_second=False" + ) + expected_mfrac = (obs1.mfrac + obs2.mfrac) / 2 + assert jnp.allclose(jax_result.mfrac, expected_mfrac, atol=1e-10), ( + "mfrac should be average when skip_mfrac_for_second=False" + ) diff --git a/deep_field_metadetect/jaxify/tests/test_jax_metadetect.py b/deep_field_metadetect/jaxify/tests/test_jax_metadetect.py new file mode 100644 index 0000000..6c9c5e7 --- /dev/null +++ b/deep_field_metadetect/jaxify/tests/test_jax_metadetect.py @@ -0,0 +1,499 @@ +import multiprocessing + +import numpy as np +import pytest + +from deep_field_metadetect.jaxify.jax_metacal import DEFAULT_FFT_SIZE +from deep_field_metadetect.jaxify.jax_metadetect import ( + jax_single_band_deep_field_metadetect, +) +from deep_field_metadetect.jaxify.observation import ngmix_obs_to_dfmd_obs +from deep_field_metadetect.metadetect import single_band_deep_field_metadetect +from deep_field_metadetect.utils import ( + MAX_ABS_C, + MAX_ABS_M, + assert_m_c_ok, + estimate_m_and_c, + make_simple_sim, + measure_mcal_shear_quants, + print_m_c, +) + + +def _run_single_sim( + seed, + s2n, + g1, + g2, + deep_noise_fac, + deep_psf_fac, + skip_wide, + skip_deep, +): + nxy = 201 + nxy_psf = 53 + scale = 0.2 + obs_w, obs_d, obs_dn = make_simple_sim( + seed=seed, + g1=g1, + g2=g2, + s2n=s2n, + deep_noise_fac=deep_noise_fac, + deep_psf_fac=deep_psf_fac, + dim=nxy, + dim_psf=nxy_psf, + scale=scale, + buff=25, + n_objs=10, + return_dfmd_obs=True, + ) + + res = jax_single_band_deep_field_metadetect( + obs_w, + obs_d, + obs_dn, + nxy=nxy, + nxy_psf=nxy_psf, + skip_obs_wide_corrections=skip_wide, + skip_obs_deep_corrections=skip_deep, + scale=scale, + ) + return measure_mcal_shear_quants(res) + + +def _run_sim_pair(seed, s2n, deep_noise_fac, deep_psf_fac, skip_wide, skip_deep): + res_p = _run_single_sim( + seed, + s2n, + 0.02, + 0.0, + deep_noise_fac, + deep_psf_fac, + skip_wide, + skip_deep, + ) + + res_m = _run_single_sim( + seed, + s2n, + -0.02, + 0.0, + deep_noise_fac, + deep_psf_fac, + skip_wide, + skip_deep, + ) + + return res_p, res_m + + +def _run_single_sim_jax_and_ngmix( + seed, + s2n, + g1, + g2, + deep_noise_fac, + deep_psf_fac, + skip_wide, + skip_deep, +): + nxy = 201 + nxy_psf = 53 + scale = 0.2 + + # Creating ngmix and dfmdet observations + obs_w_ngmix, obs_d_ngmix, obs_dn_ngmix = make_simple_sim( + seed=seed, + g1=g1, + g2=g2, + s2n=s2n, + deep_noise_fac=deep_noise_fac, + deep_psf_fac=deep_psf_fac, + dim=nxy, + dim_psf=nxy_psf, + scale=scale, + buff=25, + n_objs=10, + return_dfmd_obs=False, + ) + + obs_w = ngmix_obs_to_dfmd_obs(obs_w_ngmix) + obs_d = ngmix_obs_to_dfmd_obs(obs_d_ngmix) + obs_dn = ngmix_obs_to_dfmd_obs(obs_dn_ngmix) + + non_jax_results = single_band_deep_field_metadetect( + obs_w_ngmix, + obs_d_ngmix, + obs_dn_ngmix, + skip_obs_wide_corrections=skip_wide, + skip_obs_deep_corrections=skip_deep, + return_k_info=True, + fft_size=DEFAULT_FFT_SIZE, + ) + + res_ngmix = non_jax_results[0] + (force_stepk_field, force_maxk_field, force_stepk_psf, force_maxk_psf) = ( + non_jax_results[1] + ) + + results = jax_single_band_deep_field_metadetect( + obs_w, + obs_d, + obs_dn, + nxy=nxy, + nxy_psf=nxy_psf, + skip_obs_wide_corrections=skip_wide, + skip_obs_deep_corrections=skip_deep, + scale=scale, + return_k_info=True, + force_stepk_field=force_stepk_field, + force_maxk_field=force_maxk_field, + force_stepk_psf=force_stepk_psf, + force_maxk_psf=force_maxk_psf, + fft_size=DEFAULT_FFT_SIZE, + ) + + res = results[0] + kinfo = results[1] + + assert kinfo[0] == force_stepk_field + assert kinfo[1] == force_maxk_field + assert kinfo[2] == force_stepk_psf + assert kinfo[3] == force_maxk_psf + + return measure_mcal_shear_quants(res), measure_mcal_shear_quants(res_ngmix) + + +def _run_sim_pair_jax_and_ngmix( + seed, s2n, deep_noise_fac, deep_psf_fac, skip_wide, skip_deep +): + res_p, res_p_ngmix = _run_single_sim_jax_and_ngmix( + seed, + s2n, + 0.02, + 0.0, + deep_noise_fac, + deep_psf_fac, + skip_wide, + skip_deep, + ) + + res_m, res_m_ngmix = _run_single_sim_jax_and_ngmix( + seed, + s2n, + -0.02, + 0.0, + deep_noise_fac, + deep_psf_fac, + skip_wide, + skip_deep, + ) + + return (res_p, res_m), (res_p_ngmix, res_m_ngmix) + + +def test_metadetect_single_band_deep_field_metadetect_smoke(): + res_p, res_m = _run_sim_pair(1234, 1e4, 1.0 / np.sqrt(10), 1, False, False) + for col in res_p.dtype.names: + assert np.isfinite(res_p[col]).all() + assert np.isfinite(res_m[col]).all() + + +@pytest.mark.parametrize("deep_psf_ratio", [0.8, 1, 1.1]) +def test_metadetect_single_band_deep_field_metadetect_jax_vs_ngmix(deep_psf_ratio): + nsims = 5 + noise_fac = 1 / np.sqrt(30) + + rng = np.random.RandomState(seed=34132) + seeds = rng.randint(size=nsims, low=1, high=2**29) + res_p = [] + res_m = [] + res_p_ngmix = [] + res_m_ngmix = [] + for seed in seeds: + res, res_ngmix = _run_sim_pair_jax_and_ngmix( + seed, 1e4, noise_fac, deep_psf_ratio, False, False + ) + if res is not None: + res_p.append(res[0]) + res_m.append(res[1]) + res_p_ngmix.append(res_ngmix[0]) + res_m_ngmix.append(res_ngmix[1]) + + assert np.allclose( + res[0].tolist(), + res_ngmix[0].tolist(), + atol=1e-5, + rtol=0.025, + equal_nan=True, + ) + assert np.allclose( + res[1].tolist(), + res_ngmix[1].tolist(), + atol=1e-5, + rtol=0.025, + equal_nan=True, + ) + + m, merr, c1, c1err, c2, c2err = estimate_m_and_c( + np.concatenate(res_p), + np.concatenate(res_m), + 0.02, + jackknife=len(res_p), + ) + + m_ng, merr_ng, c1_ng, c1err_ng, c2_ng, c2err_ng = estimate_m_and_c( + np.concatenate(res_p_ngmix), + np.concatenate(res_m_ngmix), + 0.02, + jackknife=len(res_p_ngmix), + ) + + print("JAX results:") + print_m_c(m, merr, c1, c1err, c2, c2err) + print("ngmix results:") + print_m_c(m_ng, merr_ng, c1_ng, c1err_ng, c2_ng, c2err_ng) + assert_m_c_ok(m, merr, c1, c1err, c2, c2err) + + assert np.allclose(m, m_ng, atol=1e-4) + assert np.allclose(merr, merr_ng, atol=1e-4) + assert np.allclose(c1err, c1err_ng, atol=1e-6) + assert np.allclose(c1, c1_ng, atol=1e-6) + assert np.allclose(c2err, c2err_ng, atol=1e-6) + assert np.allclose(c2, c2_ng, atol=1e-6) + + +def test_metadetect_single_band_deep_field_metadetect_bmask(): + nxy = 201 + nxy_psf = 53 + scale = 0.2 + + rng = np.random.RandomState(seed=1234) + obs_w, obs_d, obs_dn = make_simple_sim( + seed=1234, + g1=0.02, + g2=0.00, + s2n=1000, + deep_noise_fac=1.0 / np.sqrt(10), + deep_psf_fac=1, + dim=nxy, + dim_psf=nxy_psf, + scale=scale, + buff=25, + n_objs=10, + return_dfmd_obs=True, + ) + obs_w = obs_w.replace( + bmask=rng.choice([0, 1, 3], p=[0.5, 0.25, 0.25], size=obs_w.image.shape) + ) + + res = jax_single_band_deep_field_metadetect( + obs_w, + obs_d, + obs_dn, + nxy=nxy, + nxy_psf=nxy_psf, + skip_obs_wide_corrections=False, + skip_obs_deep_corrections=False, + scale=scale, + ) + + xc = (res["x"] + 0.5).astype(int) + yc = (res["y"] + 0.5).astype(int) + msk = res["mdet_step"] == "noshear" + assert np.array_equal(obs_w.bmask[yc[msk], xc[msk]], res["bmask_flags"][msk]) + assert np.any(res["bmask_flags"][msk] != 0) + + for step in ["1p", "1m", "2p", "2m"]: + msk = res["mdet_step"] == step + assert not np.array_equal( + obs_d.bmask[yc[msk], xc[msk]] | obs_dn.bmask[yc[msk], xc[msk]], + res["bmask_flags"][msk], + ) + + +def test_metadetect_single_band_deep_field_metadetect_mfrac_wide(): + nxy = 201 + nxy_psf = 53 + scale = 0.2 + rng = np.random.RandomState(seed=1234) + obs_w, obs_d, obs_dn = make_simple_sim( + seed=1234, + g1=0.02, + g2=0.00, + s2n=1000, + deep_noise_fac=1.0 / np.sqrt(10), + deep_psf_fac=1, + dim=nxy, + dim_psf=nxy_psf, + scale=scale, + buff=25, + n_objs=10, + return_dfmd_obs=True, + ) + obs_w = obs_w.replace( + mfrac=np.float32(rng.uniform(0.5, 0.7, size=obs_w.image.shape)) + ) + + res = jax_single_band_deep_field_metadetect( + obs_w, + obs_d, + obs_dn, + nxy=nxy, + nxy_psf=nxy_psf, + skip_obs_wide_corrections=False, + skip_obs_deep_corrections=False, + scale=scale, + ) + + msk = (res["wmom_flags"] == 0) & (res["mdet_step"] == "noshear") + assert np.all(res["mfrac"][msk] >= 0.5) + assert np.all(res["mfrac"][msk] <= 0.7) + + msk = (res["wmom_flags"] == 0) & (res["mdet_step"] != "noshear") + assert np.all(res["mfrac"][msk] == 0) + + +def test_metadetect_single_band_deep_field_metadetect_mfrac_deep(): + nxy = 201 + nxy_psf = 53 + scale = 0.2 + rng = np.random.RandomState(seed=1234) + obs_w, obs_d, obs_dn = make_simple_sim( + seed=1234, + g1=0.02, + g2=0.00, + s2n=1000, + deep_noise_fac=1.0 / np.sqrt(10), + deep_psf_fac=1, + dim=nxy, + dim_psf=nxy_psf, + scale=scale, + buff=25, + n_objs=10, + return_dfmd_obs=True, + ) + obs_d = obs_d.replace( + mfrac=np.float32(rng.uniform(0.5, 0.7, size=obs_w.image.shape)) + ) + + res = jax_single_band_deep_field_metadetect( + obs_w, + obs_d, + obs_dn, + nxy=nxy, + nxy_psf=nxy_psf, + skip_obs_wide_corrections=False, + skip_obs_deep_corrections=False, + scale=scale, + ) + + msk = (res["wmom_flags"] == 0) & (res["mdet_step"] != "noshear") + assert np.all(res["mfrac"][msk] >= 0.5) + assert np.all(res["mfrac"][msk] <= 0.7) + + msk = (res["wmom_flags"] == 0) & (res["mdet_step"] == "noshear") + assert np.all(res["mfrac"][msk] == 0) + + +@pytest.mark.parametrize("deep_psf_ratio", [0.8, 1, 1.1]) +def test_metadetect_single_band_deep_field_metadetect(deep_psf_ratio): + nsims = 100 + noise_fac = 1 / np.sqrt(30) + + rng = np.random.RandomState(seed=34132) + seeds = rng.randint(size=nsims, low=1, high=2**29) + res_p = [] + res_m = [] + for seed in seeds: + res = _run_sim_pair(seed, 1e4, noise_fac, deep_psf_ratio, False, False) + if res is not None: + res_p.append(res[0]) + res_m.append(res[1]) + + m, merr, c1, c1err, c2, c2err = estimate_m_and_c( + np.concatenate(res_p), + np.concatenate(res_m), + 0.02, + jackknife=len(res_p), + ) + + assert np.isfinite(m) + assert np.isfinite(merr) + print_m_c(m, merr, c1, c1err, c2, c2err) + assert_m_c_ok(m, merr, c1, c1err, c2, c2err) + + +@pytest.mark.slow +@pytest.mark.parametrize( + "skip_wide,skip_deep", [(True, True), (True, False), (False, True), (False, False)] +) +def test_metadetect_single_band_deep_field_metadetect_slow( + skip_wide, skip_deep +): # pragma: no cover + if not skip_wide and not skip_deep: + nsims = 1_000_000 + s2n = 20 + else: + nsims = 100_000 + s2n = 10 + chunk_size = multiprocessing.cpu_count() * 100 + nchunks = nsims // chunk_size + 1 + noise_fac = 1 / np.sqrt(10) + nsims = nchunks * chunk_size + + rng = np.random.RandomState(seed=34132) + seeds = rng.randint(size=nsims, low=1, high=2**29) + res_p = [] + res_m = [] + loc = 0 + for chunk in range(nchunks): + _seeds = seeds[loc : loc + chunk_size] + # jobs = [ + # joblib.delayed(_run_sim_pair)( + # seed, s2n, noise_fac, 0.8, skip_wide, skip_deep + # ) + # for seed in _seeds + # ] + # outputs = joblib.Parallel(n_jobs=-1, verbose=10)(jobs) + for seed in _seeds: + res = _run_sim_pair(seed, s2n, noise_fac, 0.8, skip_wide, skip_deep) + if res is not None: + res_p.append(res[0]) + res_m.append(res[1]) + + if len(res_p) < 500: + njack = len(res_p) + else: + njack = 100 + + m, merr, c1, c1err, c2, c2err = estimate_m_and_c( + np.concatenate(res_p), + np.concatenate(res_m), + 0.02, + jackknife=njack, + ) + + print("# of sims:", len(res_p), flush=True) + print_m_c(m, merr, c1, c1err, c2, c2err) + + if not skip_wide and not skip_deep: + assert np.abs(m) < max(MAX_ABS_M, 3 * merr), (m, merr) + elif 3 * merr < 5e-3: + assert np.abs(m) >= max(MAX_ABS_M, 3 * merr), (m, merr) + # if we are more than 10 sigma biased, then the test + # has passed for sure + if np.abs(m) / max(MAX_ABS_M / 3, merr) >= 10: + break + assert np.abs(c1) < max(4.0 * c1err, MAX_ABS_C), (c1, c1err) + assert np.abs(c2) < max(4.0 * c2err, MAX_ABS_C), (c2, c2err) + + loc += chunk_size + + print_m_c(m, merr, c1, c1err, c2, c2err) + if not skip_wide and not skip_deep: + assert np.abs(m) < max(MAX_ABS_M, 3 * merr), (m, merr) + else: + assert np.abs(m) >= max(MAX_ABS_M, 3 * merr), (m, merr) + assert np.abs(c1) < max(4.0 * c1err, MAX_ABS_C), (c1, c1err) + assert np.abs(c2) < max(4.0 * c2err, MAX_ABS_C), (c2, c2err) diff --git a/deep_field_metadetect/jaxify/tests/test_jax_ngmix_intermediates.py b/deep_field_metadetect/jaxify/tests/test_jax_ngmix_intermediates.py new file mode 100644 index 0000000..e8731b7 --- /dev/null +++ b/deep_field_metadetect/jaxify/tests/test_jax_ngmix_intermediates.py @@ -0,0 +1,210 @@ +import numpy as np +import pytest + +from deep_field_metadetect.jaxify.jax_metacal import ( + get_jax_galsim_object_from_dfmd_obs_nopix, + jax_get_gauss_reconv_psf_galsim, + jax_get_max_gauss_reconv_psf_galsim, + jax_metacal_op_g1g2, + jax_metacal_op_shears, +) +from deep_field_metadetect.jaxify.jax_utils import compute_stepk +from deep_field_metadetect.jaxify.observation import ( + ngmix_obs_to_dfmd_obs, +) +from deep_field_metadetect.metacal import ( + get_galsim_object_from_ngmix_obs_nopix, + get_gauss_reconv_psf_galsim, + get_max_gauss_reconv_psf_galsim, + metacal_op_g1g2, + metacal_op_shears, +) +from deep_field_metadetect.utils import make_simple_sim + + +class TestJaxNgmixIntermediates: + """Test is the versions produce the same intermediate values.""" + + @pytest.fixture(scope="class") + def simple_obs_pair(self): + """Create a simple observation pair for testing.""" + nxy = 53 + nxy_psf = 53 + scale = 0.2 + + obs_w_ngmix, obs_d_ngmix, obs_dn_ngmix = make_simple_sim( + seed=12345, + g1=0.02, + g2=0.0, + s2n=1e8, + dim=nxy, + dim_psf=nxy_psf, + scale=scale, + deep_noise_fac=1.0 / np.sqrt(10), + deep_psf_fac=1.0, + return_dfmd_obs=False, + ) + + # Convert to JAX observations + obs_w_jax = ngmix_obs_to_dfmd_obs(obs_w_ngmix) + obs_d_jax = ngmix_obs_to_dfmd_obs(obs_d_ngmix) + obs_dn_jax = ngmix_obs_to_dfmd_obs(obs_dn_ngmix) + + return { + "ngmix": (obs_w_ngmix, obs_d_ngmix, obs_dn_ngmix), + "jax": (obs_w_jax, obs_d_jax, obs_dn_jax), + "params": {"nxy": nxy, "nxy_psf": nxy_psf, "scale": scale}, + } + + def test_gauss_reconv_psf_consistency(self, simple_obs_pair): + """Test Gaussian reconvolution PSF.""" + obs_w_ngmix, obs_d_ngmix, _ = simple_obs_pair["ngmix"] + obs_w_jax, obs_d_jax, _ = simple_obs_pair["jax"] + nxy_psf = simple_obs_pair["params"]["nxy_psf"] + scale = simple_obs_pair["params"]["scale"] + + # Test single PSF + psf_ngmix = get_galsim_object_from_ngmix_obs_nopix( + obs_w_ngmix.psf, kind="image" + ) + psf_jax = get_jax_galsim_object_from_dfmd_obs_nopix(obs_w_jax.psf, kind="image") + + dk = compute_stepk(pixel_scale=scale, image_size=nxy_psf) + + kim_size = 173 + reconv_psf_jax = jax_get_gauss_reconv_psf_galsim( + psf_jax, dk=dk, kim_size=kim_size + ) + reconv_psf_ngmix = get_gauss_reconv_psf_galsim( + psf_ngmix, dk=dk, kim_size=kim_size + ) + + # Test PSF properties - relax tolerance for small numerical differences + assert np.allclose( + reconv_psf_ngmix.fwhm, reconv_psf_jax.fwhm, rtol=1e-6, atol=1e-10 + ), f"FWHM mismatch: {reconv_psf_ngmix.fwhm} vs {reconv_psf_jax.fwhm}" + + assert np.allclose( + reconv_psf_ngmix.flux, reconv_psf_jax.flux, rtol=1e-10, atol=1e-12 + ), f"Flux mismatch: {reconv_psf_ngmix.flux} vs {reconv_psf_jax.flux}" + + def test_max_gauss_reconv_psf_consistency(self, simple_obs_pair): + """Test max Gaussian reconvolution PSF. + kim_size and dk are not get for Galsim in this case.""" + obs_w_ngmix, obs_d_ngmix, _ = simple_obs_pair["ngmix"] + obs_w_jax, obs_d_jax, _ = simple_obs_pair["jax"] + nxy_psf = simple_obs_pair["params"]["nxy_psf"] + scale = simple_obs_pair["params"]["scale"] + + # Get PSFs + psf_w_ngmix = get_galsim_object_from_ngmix_obs_nopix( + obs_w_ngmix.psf, kind="image" + ) + psf_d_ngmix = get_galsim_object_from_ngmix_obs_nopix( + obs_d_ngmix.psf, kind="image" + ) + psf_w_jax = get_jax_galsim_object_from_dfmd_obs_nopix( + obs_w_jax.psf, kind="image" + ) + psf_d_jax = get_jax_galsim_object_from_dfmd_obs_nopix( + obs_d_jax.psf, kind="image" + ) + + # Compare maximum reconvolution PSFs + max_reconv_psf_ngmix = get_max_gauss_reconv_psf_galsim(psf_w_ngmix, psf_d_ngmix) + max_reconv_psf_jax = jax_get_max_gauss_reconv_psf_galsim( + psf_w_jax, psf_d_jax, nxy_psf, scale + ) + + # Test PSF properties + assert np.allclose( + max_reconv_psf_ngmix.fwhm, max_reconv_psf_jax.fwhm, rtol=0.02, atol=1e-6 + ), f"Max FWHM: {max_reconv_psf_ngmix.fwhm} vs {max_reconv_psf_jax.fwhm}" + + def test_metacal_single_shear_consistency(self, simple_obs_pair): + """Test single shear operations.""" + obs_w_ngmix, _, _ = simple_obs_pair["ngmix"] + obs_w_jax, _, _ = simple_obs_pair["jax"] + nxy_psf = simple_obs_pair["params"]["nxy_psf"] + scale = simple_obs_pair["params"]["scale"] + + # Test single shear transformation + g1, g2 = 0.01, 0.0 + + # Get reconvolution PSFs for both versions + dk = compute_stepk(pixel_scale=scale, image_size=nxy_psf) + psf_jax = get_jax_galsim_object_from_dfmd_obs_nopix(obs_w_jax.psf, kind="image") + psf_ngmix = get_galsim_object_from_ngmix_obs_nopix( + obs_w_ngmix.psf, kind="image" + ) + + kim_size = 173 + reconv_psf_jax = jax_get_gauss_reconv_psf_galsim(psf_jax, dk, kim_size=kim_size) + reconv_psf_ngmix = get_gauss_reconv_psf_galsim( + psf_ngmix, dk=dk, kim_size=kim_size + ) + + # Run metacal operations + mcal_obs_ngmix = metacal_op_g1g2(obs_w_ngmix, reconv_psf_ngmix, g1, g2) + mcal_obs_jax = jax_metacal_op_g1g2(obs_w_jax, reconv_psf_jax, g1, g2, nxy_psf) + + # Convert JAX result to ngmix for comparison + from deep_field_metadetect.jaxify.observation import dfmd_obs_to_ngmix_obs + + mcal_obs_jax_ngmix = dfmd_obs_to_ngmix_obs(mcal_obs_jax) + + # Compare image statistics + assert np.allclose( + np.mean(mcal_obs_ngmix.image), + np.mean(mcal_obs_jax_ngmix.image), + rtol=1e-5, + atol=1e-9, + ), "Image mean mismatch" + + assert np.allclose( + np.std(mcal_obs_ngmix.image), + np.std(mcal_obs_jax_ngmix.image), + rtol=1e-5, + atol=1e-9, + ), "Image std mismatch" + + def test_metacal_shears_intermediate_values(self, simple_obs_pair): + """Test intermediate values in metacal shears operations.""" + obs_w_ngmix, _, _ = simple_obs_pair["ngmix"] + obs_w_jax, _, _ = simple_obs_pair["jax"] + scale = simple_obs_pair["params"]["scale"] + + test_shears = ("noshear", "1p", "1m") + + # Run metacal operations + mcal_res_ngmix = metacal_op_shears(obs_w_ngmix, shears=test_shears) + mcal_res_jax = jax_metacal_op_shears(obs_w_jax, shears=test_shears, scale=scale) + + # Convert JAX results to ngmix for comparison + from deep_field_metadetect.jaxify.observation import dfmd_obs_to_ngmix_obs + + mcal_res_jax_ngmix = {} + for shear in test_shears: + mcal_res_jax_ngmix[shear] = dfmd_obs_to_ngmix_obs(mcal_res_jax[shear]) + + # Compare results for each shear + for shear in test_shears: + obs_ngmix = mcal_res_ngmix[shear] + obs_jax_ngmix = mcal_res_jax_ngmix[shear] + + # Compare image statistics + img_mean_diff = abs(np.mean(obs_ngmix.image) - np.mean(obs_jax_ngmix.image)) + img_std_ratio = np.std(obs_ngmix.image) / np.std(obs_jax_ngmix.image) + + assert img_mean_diff < 1e-4, ( + f"Shear {shear}: Image mean difference too large: {img_mean_diff}" + ) + assert 0.99 < img_std_ratio < 1.01, ( + f"Shear {shear}: Image std ratio out of range: {img_std_ratio}" + ) + + # Compare weight statistics + weight_ratio = np.mean(obs_ngmix.weight) / np.mean(obs_jax_ngmix.weight) + assert 0.99 < weight_ratio < 1.01, ( + f"Shear {shear}: Weight ratio out of range: {weight_ratio}" + ) diff --git a/deep_field_metadetect/metacal.py b/deep_field_metadetect/metacal.py index 9774e80..3e9381c 100644 --- a/deep_field_metadetect/metacal.py +++ b/deep_field_metadetect/metacal.py @@ -2,7 +2,7 @@ import ngmix import numpy as np -DEFAULT_SHEARS = ["noshear", "1p", "1m", "2p", "2m"] +DEFAULT_SHEARS = ("noshear", "1p", "1m", "2p", "2m") DEFAULT_STEP = 0.01 @@ -21,7 +21,7 @@ def get_shear_tuple(shear, step): raise RuntimeError("Shear value '%s' not regonized!" % shear) -def get_gauss_reconv_psf_galsim(psf, step=DEFAULT_STEP, flux=1): +def get_gauss_reconv_psf_galsim(psf, step=DEFAULT_STEP, flux=1, dk=None, kim_size=None): """Gets the target reconvolution PSF for an input PSF object. This is taken from galsim/tests/test_metacal.py and assumes the psf is @@ -31,22 +31,30 @@ def get_gauss_reconv_psf_galsim(psf, step=DEFAULT_STEP, flux=1): ---------- psf : galsim object The PSF. + step : float, optional + Factor by which to expand the PSF to supress noise from high-k + fourirer modes introduced due to shearing of pre-PSF images. + Defaults to deep_field_metadetect.metacal.DEFAULT_STEP. flux : float The output flux of the PSF. Defaults to 1. + dk : float + The Fourier-space pixel scale. + kim_size : int + k image size. + Defaults to None, which lets galsim set the size Returns ------- reconv_psf : galsim object The reconvolution PSF. - sigma : float - The width of the reconv PSF befor dilation. """ - dk = psf.stepk / 4.0 + if dk is None: + dk = psf.stepk / 4.0 small_kval = 1.0e-2 # Find the k where the given psf hits this kvalue smaller_kval = 3.0e-3 # Target PSF will have this kvalue at the same k - kim = psf.drawKImage(scale=dk) + kim = psf.drawKImage(nx=kim_size, ny=kim_size, scale=dk) karr_r = kim.real.array # Find the smallest r where the kval < small_kval nk = karr_r.shape[0] @@ -86,7 +94,13 @@ def get_max_gauss_reconv_psf(obs_w, obs_d, step=DEFAULT_STEP): return get_max_gauss_reconv_psf_galsim(psf_w, psf_d, step=step) -def _render_psf_and_build_obs(image, obs, reconv_psf, weight_fac=1): +def _render_psf_and_build_obs(image, obs, reconv_psf, weight_fac=1, fft_size=None): + if fft_size is not None: + reconv_psf = reconv_psf.withGSParams( + minimum_fft_size=fft_size, + maximum_fft_size=fft_size, + ) + pim = reconv_psf.drawImage( nx=obs.psf.image.shape[1], ny=obs.psf.image.shape[0], @@ -105,7 +119,9 @@ def _render_psf_and_build_obs(image, obs, reconv_psf, weight_fac=1): return obs -def _metacal_op_g1g2_impl(*, wcs, image, noise, psf_inv, dims, reconv_psf, g1, g2): +def _metacal_op_g1g2_impl( + *, wcs, image, noise, psf_inv, dims, reconv_psf, g1, g2, fft_size=None +): """Run metacal on an ngmix observation. Note that the noise image should already be rotated by 90 degrees here. @@ -125,6 +141,17 @@ def _metacal_op_g1g2_impl(*, wcs, image, noise, psf_inv, dims, reconv_psf, g1, g ] ) + if fft_size is not None: + ims = ims.withGSParams( + minimum_fft_size=fft_size, + maximum_fft_size=fft_size, + ) + + ns = ns.withGSParams( + minimum_fft_size=fft_size, + maximum_fft_size=fft_size, + ) + ims = ims.drawImage(nx=dims[1], ny=dims[0], wcs=wcs).array ns = np.rot90( ns.drawImage(nx=dims[1], ny=dims[0], wcs=wcs).array, @@ -133,7 +160,7 @@ def _metacal_op_g1g2_impl(*, wcs, image, noise, psf_inv, dims, reconv_psf, g1, g return ims + ns -def metacal_op_g1g2(obs, reconv_psf, g1, g2): +def metacal_op_g1g2(obs, reconv_psf, g1, g2, fft_size=None): """Run metacal on an ngmix observation.""" mcal_image = _metacal_op_g1g2_impl( wcs=obs.jacobian.get_galsim_wcs(), @@ -148,11 +175,16 @@ def metacal_op_g1g2(obs, reconv_psf, g1, g2): reconv_psf=reconv_psf, g1=g1, g2=g2, + fft_size=fft_size, + ) + return _render_psf_and_build_obs( + mcal_image, obs, reconv_psf, weight_fac=0.5, fft_size=fft_size ) - return _render_psf_and_build_obs(mcal_image, obs, reconv_psf, weight_fac=0.5) -def metacal_op_shears(obs, reconv_psf=None, shears=None, step=DEFAULT_STEP): +def metacal_op_shears( + obs, reconv_psf=None, shears=None, step=DEFAULT_STEP, fft_size=None +): """Run metacal on an ngmix observation.""" if shears is None: shears = DEFAULT_SHEARS @@ -182,21 +214,65 @@ def metacal_op_shears(obs, reconv_psf=None, shears=None, step=DEFAULT_STEP): g2=g2, ) mcal_res[shear] = _render_psf_and_build_obs( - mcal_image, obs, reconv_psf, weight_fac=0.5 + mcal_image, + obs, + reconv_psf, + weight_fac=0.5, + fft_size=fft_size, ) return mcal_res -def match_psf(obs, reconv_psf): +def match_psf( + obs, + reconv_psf, + return_k_info=False, + force_stepk_field=0.0, + force_maxk_field=0.0, + force_stepk_psf=0.0, + force_maxk_psf=0.0, + fft_size=None, +): """Match the PSF on an ngmix observation to a new PSF.""" wcs = obs.jacobian.get_galsim_wcs() - image = get_galsim_object_from_ngmix_obs(obs, kind="image") - psf = get_galsim_object_from_ngmix_obs(obs.psf, kind="image") + image = get_galsim_object_from_ngmix_obs( + obs, + kind="image", + _force_stepk=force_stepk_field, + _force_maxk=force_maxk_field, + ) + + psf = get_galsim_object_from_ngmix_obs( + obs.psf, + kind="image", + _force_stepk=force_stepk_psf, + _force_maxk=force_maxk_psf, + ) + + ims = galsim.Convolve( + [image, galsim.Deconvolve(psf), reconv_psf], + ) + + if fft_size is not None: + ims = ims.withGSParams( + minimum_fft_size=fft_size, + maximum_fft_size=fft_size, + ) - ims = galsim.Convolve([image, galsim.Deconvolve(psf), reconv_psf]) ims = ims.drawImage(nx=obs.image.shape[1], ny=obs.image.shape[0], wcs=wcs).array + if return_k_info: + return _render_psf_and_build_obs( + ims, obs, reconv_psf, weight_fac=1, fft_size=fft_size + ), ( + image._stepk, + image._maxk, + psf._stepk, + psf._maxk, + ) - return _render_psf_and_build_obs(ims, obs, reconv_psf, weight_fac=1) + return _render_psf_and_build_obs( + ims, obs, reconv_psf, weight_fac=1, fft_size=fft_size + ) def _extract_attr(obs, attr, dtype): @@ -281,7 +357,9 @@ def add_ngmix_obs(obs1, obs2, ignore_psf=False, skip_mfrac_for_second=False): return obs -def get_galsim_object_from_ngmix_obs(obs, kind="image", rot90=0): +def get_galsim_object_from_ngmix_obs( + obs, kind="image", rot90=0, _force_stepk=0.0, _force_maxk=0.0 +): """Make an interpolated image from an ngmix obs.""" return galsim.InterpolatedImage( galsim.ImageD( @@ -289,6 +367,8 @@ def get_galsim_object_from_ngmix_obs(obs, kind="image", rot90=0): wcs=obs.jacobian.get_galsim_wcs(), ), x_interpolant="lanczos15", + _force_stepk=_force_stepk, + _force_maxk=_force_maxk, ) @@ -312,26 +392,108 @@ def metacal_wide_and_deep_psf_matched( skip_obs_wide_corrections=False, skip_obs_deep_corrections=False, return_noshear_deep=False, + return_k_info=False, + force_stepk_field=0.0, + force_maxk_field=0.0, + force_stepk_psf=0.0, + force_maxk_psf=0.0, + fft_size=None, ): - """Do metacalibration for a combination of wide+deep datasets.""" + """Do metacalibration for a combination of wide+deep datasets. + + Parameters + ---------- + obs_wide : ngmix.Observation + The wide-field observation. + obs_deep : ngmix.Observation + The deep-field observation. + obs_deep_noise : ngmix.Observation + The deep-field noise observation. + shears : list, optional + The shears to use for the metacalibration, by default DEFAULT_SHEARS + if set to None. + step : float, optional + The step size for the metacalibration, by default DEFAULT_STEP. + skip_obs_wide_corrections : bool, optional + Skip the observation corrections for the wide-field observations, + by default False. + skip_obs_deep_corrections : bool, optional + Skip the observation corrections for the deep-field observations, + by default False. + return_noshear_deep : bool, optional + adds deep field no shear results to the output. Default - False. + return_k_info : bool, optional + return _force stepk and maxk values in the following order + _force_stepk_field, _force_maxk_field, _force_stepk_psf, _force_maxk_psf. + Used mainly for testing. + force_stepk_field : float, optional + Force stepk for drawing field images. + Defaults to 0.0, which lets Galsim choose the value. + Used mainly for testing. + force_maxk_field: float, optional + Force maxk for drawing field images. + Defaults to 0.0, which lets Galsim choose the value. + Used mainly for testing. + force_stepk_psf: float, optional + Force stepk for drawing PSF images. + Defaults to 0.0, which lets Galsim choose the value. + Used mainly for testing. + force_maxk_psf: float, optional + Force stepk for drawing PSF images + Defaults to 0.0, which lets Galsim choose the value. + Used mainly for testing. + fft_size: int, optional + To fix max and min values of FFT size. + Defaults to None which lets Galsim determine the values. + Used mainly to test against Galsim. + + Returns + ------- + mcal_res : dict + Output from metacal_op_shears for shear cases listed by the shears input, + optionaly no shear deep field case if return_noshear_deep is True + and kinfo for debugging if return_k_info is set to True. + kinfo is returned in the following order: + _force_stepk_field, _force_maxk_field, _force_stepk_psf, _force_maxk_psf. + """ # first get the biggest reconv PSF of the two reconv_psf = get_max_gauss_reconv_psf(obs_wide, obs_deep) + mcal_obs_wide = match_psf( + obs_wide, + reconv_psf, + return_k_info=return_k_info, + force_stepk_field=force_stepk_field, + force_maxk_field=force_maxk_field, + force_stepk_psf=force_stepk_psf, + force_maxk_psf=force_maxk_psf, + fft_size=fft_size, + ) - # make the wide obs - if skip_obs_wide_corrections: - mcal_obs_wide = match_psf(obs_wide, reconv_psf) - else: + if return_k_info: + mcal_obs_wide, kinfo = mcal_obs_wide + force_stepk_field, force_maxk_field, force_stepk_psf, force_maxk_psf = kinfo + + if not skip_obs_wide_corrections: mcal_obs_wide = add_ngmix_obs( - match_psf(obs_wide, reconv_psf), - metacal_op_g1g2(obs_deep_noise, reconv_psf, 0, 0), + mcal_obs_wide, + metacal_op_g1g2(obs_deep_noise, reconv_psf, 0, 0, fft_size=fft_size), skip_mfrac_for_second=True, ) # get PSF matched noise obs_wide_noise = obs_wide.copy() obs_wide_noise.image = obs_wide.noise - wide_noise_corr = match_psf(obs_wide_noise, reconv_psf) + wide_noise_corr = match_psf( + obs_wide_noise, + reconv_psf, + force_stepk_field=force_stepk_field, + force_maxk_field=force_maxk_field, + force_stepk_psf=force_stepk_psf, + force_maxk_psf=force_maxk_psf, + return_k_info=False, + fft_size=fft_size, + ) # now run mcal on deep mcal_res = metacal_op_shears( @@ -339,6 +501,7 @@ def metacal_wide_and_deep_psf_matched( reconv_psf=reconv_psf, shears=shears, step=step, + fft_size=fft_size, ) # now add in noise corr to make it match the wide noise @@ -359,4 +522,7 @@ def metacal_wide_and_deep_psf_matched( for k in mcal_res: mcal_res[k].psf.galsim_obj = reconv_psf + if return_k_info: + mcal_res["kinfo"] = kinfo + return mcal_res diff --git a/deep_field_metadetect/metadetect.py b/deep_field_metadetect/metadetect.py index f1cd361..f7857f9 100644 --- a/deep_field_metadetect/metadetect.py +++ b/deep_field_metadetect/metadetect.py @@ -23,6 +23,12 @@ def single_band_deep_field_metadetect( skip_obs_wide_corrections=False, skip_obs_deep_corrections=False, nodet_flags=0, + return_k_info=False, + force_stepk_field=0.0, + force_maxk_field=0.0, + force_stepk_psf=0.0, + force_maxk_psf=0.0, + fft_size=None, ): """Run deep-field metadetection for a simple scenario of a single band with a single image per band using only post-PSF Gaussian weighted moments. @@ -48,13 +54,40 @@ def single_band_deep_field_metadetect( by default False. nodet_flags : int, optional The bmask flags marking area in the image to skip, by default 0. + return_k_info : bool, optional + return _force stepk and maxk values in the following order + _force_stepk_field, _force_maxk_field, _force_stepk_psf, _force_maxk_psf. + Used mainly for testing. + force_stepk_field : float, optional + Force stepk for drawing field images. + Defaults to 0.0, which lets Galsim choose the value. + Used mainly for testing. + force_maxk_field: float, optional + Force maxk for drawing field images. + Defaults to 0.0, which lets Galsim choose the value. + Used mainly for testing. + force_stepk_psf: float, optional + Force stepk for drawing PSF images. + Defaults to 0.0, which lets Galsim choose the value. + Used mainly for testing. + force_maxk_psf: float, optional + Force stepk for drawing PSF images + Defaults to 0.0, which lets Galsim choose the value. + Used mainly for testing. + fft_size: int, optional + To fix max and min values of FFT size. + Defaults to None which lets Galsim determine the values. + Used mainly to test against Galsim. Returns ------- - dfmdet_res : dict - The deep-field metadetection results, a dictionary with keys from `shears` - and values containing the detection+measurement results for the corresponding - shear. + dfmdet_res : numpy.ndarray + The deep-field metadetection results as a structured array containing + detection and measurement results for all shears. + + Note: If return_k_info is set to True for debugging, + this function returns a tuple containing (dfmdet_res, kinfo). kinfo being: + (_force_stepk_field, _force_maxk_field, _force_stepk_psf, _force_maxk_psf) """ if shears is None: shears = DEFAULT_SHEARS @@ -67,10 +100,18 @@ def single_band_deep_field_metadetect( shears=shears, skip_obs_wide_corrections=skip_obs_wide_corrections, skip_obs_deep_corrections=skip_obs_deep_corrections, + return_k_info=return_k_info, + force_stepk_field=force_stepk_field, + force_maxk_field=force_maxk_field, + force_stepk_psf=force_stepk_psf, + force_maxk_psf=force_maxk_psf, + fft_size=fft_size, ) + psf_res = fit_gauss_mom_obs(mcal_res["noshear"].psf) dfmdet_res = [] - for shear, obs in mcal_res.items(): + for shear in shears: + obs = mcal_res[shear] detres = run_detection_sep(obs, nodet_flags=nodet_flags) ixc = (detres["catalog"]["x"] + 0.5).astype(int) @@ -110,4 +151,7 @@ def single_band_deep_field_metadetect( ("mfrac", "f4"), ] + fres.dtype.descr + if return_k_info: + return (np.array(dfmdet_res, dtype=total_dtype), mcal_res.get("kinfo")) + return np.array(dfmdet_res, dtype=total_dtype) diff --git a/deep_field_metadetect/tests/test_metadetect.py b/deep_field_metadetect/tests/test_metadetect.py index 9112ec9..157d980 100644 --- a/deep_field_metadetect/tests/test_metadetect.py +++ b/deep_field_metadetect/tests/test_metadetect.py @@ -171,14 +171,27 @@ def test_metadetect_single_band_deep_field_metadetect_mfrac_deep(): ) obs_d.mfrac = rng.uniform(0.5, 0.7, size=obs_w.image.shape) - res = single_band_deep_field_metadetect( + results = single_band_deep_field_metadetect( obs_w, obs_d, obs_dn, skip_obs_wide_corrections=False, skip_obs_deep_corrections=False, + return_k_info=True, + force_stepk_field=0.12403490725241548, + force_maxk_field=8.160777791551611, + force_stepk_psf=0.6815071326229606, + force_maxk_psf=12.640001692177682, ) + res = results[0] + kinfo = results[1] + + assert kinfo[0] == 0.12403490725241548 + assert kinfo[1] == 8.160777791551611 + assert kinfo[2] == 0.6815071326229606 + assert kinfo[3] == 12.640001692177682 + msk = (res["wmom_flags"] == 0) & (res["mdet_step"] != "noshear") assert np.all(res["mfrac"][msk] >= 0.5) assert np.all(res["mfrac"][msk] <= 0.7) diff --git a/deep_field_metadetect/utils.py b/deep_field_metadetect/utils.py index 4c1634a..8c669e5 100644 --- a/deep_field_metadetect/utils.py +++ b/deep_field_metadetect/utils.py @@ -7,6 +7,13 @@ import numpy as np from ngmix.gaussmom import GaussMom +from deep_field_metadetect.jaxify.observation import ( + DFMdetObservation, + DFMdetPSF, + dfmd_obs_to_ngmix_obs, + dfmd_psf_to_ngmix_obs, + ngmix_obs_to_dfmd_obs, +) from deep_field_metadetect.metacal import DEFAULT_SHEARS GLOBAL_START_TIME = time.time() @@ -297,7 +304,12 @@ def fit_gauss_mom_mcal_res(mcal_res, fwhm=1.2): vals = np.zeros(len(mcal_res), dtype=dt) fitter = GaussMom(fwhm) - psf_res = fitter.go(mcal_res["noshear"].psf) + + psf = mcal_res["noshear"].psf + if isinstance(psf, DFMdetPSF): + psf = dfmd_psf_to_ngmix_obs(mcal_res["noshear"].psf) + + psf_res = fitter.go(psf) for i, (shear, obs) in enumerate(mcal_res.items()): vals["mdet_step"][i] = shear @@ -309,7 +321,10 @@ def fit_gauss_mom_mcal_res(mcal_res, fwhm=1.2): vals["wmom_psf_T"][i] = psf_res["T"] + if isinstance(obs, DFMdetObservation): + obs = dfmd_obs_to_ngmix_obs(obs) res = fitter.go(obs) + vals["wmom_flags"][i] = res["flags"] if res["flags"] != 0: @@ -541,14 +556,14 @@ def _gen_hex_grid(*, rng, dim, buff, pixel_scale, n_tot): return shifts -def _make_single_sim(*, dither=None, rng, psf, obj, nse, scale, dim): +def _make_single_sim(*, dither=None, rng, psf, obj, nse, scale, dim, dim_psf=53): cen = (dim - 1) / 2 im = obj.drawImage(nx=dim, ny=dim, scale=scale).array im += rng.normal(size=im.shape, scale=nse) - cen_psf = (53 - 1) / 2 - psf_im = psf.drawImage(nx=53, ny=53, scale=scale).array + cen_psf = (dim_psf - 1) / 2 + psf_im = psf.drawImage(nx=dim_psf, ny=dim_psf, scale=scale).array if dither is not None: jac = ngmix.DiagonalJacobian( @@ -558,7 +573,7 @@ def _make_single_sim(*, dither=None, rng, psf, obj, nse, scale, dim): jac = ngmix.DiagonalJacobian(scale=scale, row=cen, col=cen) psf_jac = ngmix.DiagonalJacobian(scale=scale, row=cen_psf, col=cen_psf) - obs = ngmix.Observation( + obs = ngmix.observation.Observation( image=im, weight=np.ones_like(im) / nse**2, jacobian=jac, @@ -570,6 +585,7 @@ def _make_single_sim(*, dither=None, rng, psf, obj, nse, scale, dim): bmask=np.zeros_like(im, dtype=np.int32), mfrac=np.zeros_like(im), ) + return obs @@ -584,8 +600,10 @@ def make_simple_sim( n_objs=1, scale=0.2, dim=53, + dim_psf=53, buff=26, obj_flux_factor=1, + return_dfmd_obs=False, ): """Make a simple simulation for testing deep-field metadetection. @@ -670,6 +688,7 @@ def make_simple_sim( dither=shifts[0] / scale if n_objs == 1 else None, scale=scale, dim=dim, + dim_psf=dim_psf, ) obs_deep = _make_single_sim( @@ -680,6 +699,7 @@ def make_simple_sim( dither=shifts[0] / scale if n_objs == 1 else None, scale=scale, dim=dim, + dim_psf=dim_psf, ) obs_deep_noise = _make_single_sim( @@ -690,6 +710,14 @@ def make_simple_sim( dither=shifts[0] / scale if n_objs == 1 else None, scale=scale, dim=dim, + dim_psf=dim_psf, ) + if return_dfmd_obs: + return ( + ngmix_obs_to_dfmd_obs(obs_wide), + ngmix_obs_to_dfmd_obs(obs_deep), + ngmix_obs_to_dfmd_obs(obs_deep_noise), + ) + return obs_wide, obs_deep, obs_deep_noise diff --git a/environment.yml b/environment.yml index 9a74dd2..f362fba 100644 --- a/environment.yml +++ b/environment.yml @@ -16,6 +16,7 @@ dependencies: - ngmix - numba - numpy + - jax<0.7.0 - pip: - git+https://github.com/GalSim-developers/JAX-GalSim.git@main