44
55//! Tests against peeling
66
7- use std:: { collections:: HashSet , f64:: consts:: TAU } ;
7+ use std:: { collections:: HashSet , f64:: consts:: TAU , num :: NonZeroUsize , path :: PathBuf } ;
88
99use approx:: assert_abs_diff_eq;
10+ use crossbeam_utils:: atomic:: AtomicCell ;
1011use hifitime:: { Duration , Epoch } ;
1112use indexmap:: indexmap;
12- use indicatif:: { MultiProgress , ProgressDrawTarget } ;
13+ use indicatif:: { MultiProgress , ProgressBar , ProgressDrawTarget } ;
1314use itertools:: { izip, Itertools } ;
1415use marlu:: {
1516 constants:: VEL_C ,
@@ -27,6 +28,7 @@ use crate::{
2728 beam:: { Delays , FEEBeam } ,
2829 context:: { ObsContext , Polarisations } ,
2930 io:: read:: VisInputType ,
31+ io:: write:: VisOutputType ,
3032 model:: { new_sky_modeller, SkyModellerCpu } ,
3133 srclist:: { ComponentType , FluxDensity , FluxDensityType , Source , SourceComponent , SourceList } ,
3234} ;
@@ -2924,3 +2926,212 @@ mod gpu_tests {
29242926 }
29252927 }
29262928}
2929+
2930+ #[ test]
2931+ fn test_peel_weight_preservation ( ) {
2932+ // Test that the original weights are preserved and not modified by tapering
2933+ let apply_precession = true ;
2934+
2935+ let obs_context = get_phase1_obs_context ( CPU_TILE_LIMIT ) ;
2936+
2937+ let array_pos = obs_context. array_position ;
2938+ let num_tiles = obs_context. get_total_num_tiles ( ) ;
2939+ let num_times = obs_context. timestamps . len ( ) ;
2940+ let num_baselines = ( num_tiles * ( num_tiles - 1 ) ) / 2 ;
2941+ let num_chans = obs_context. fine_chan_freqs . len ( ) ;
2942+
2943+ let chanblocks = obs_context
2944+ . fine_chan_freqs
2945+ . iter ( )
2946+ . enumerate ( )
2947+ . map ( |( i, f) | Chanblock {
2948+ chanblock_index : i as u16 ,
2949+ unflagged_index : i as u16 ,
2950+ freq : * f as f64 ,
2951+ } )
2952+ . collect_vec ( ) ;
2953+
2954+ let fine_chan_freqs_hz = obs_context
2955+ . fine_chan_freqs
2956+ . iter ( )
2957+ . map ( |& f| f as f64 )
2958+ . collect_vec ( ) ;
2959+ let _lambdas_m = fine_chan_freqs_hz. iter ( ) . map ( |& f| VEL_C / f) . collect_vec ( ) ;
2960+ let avg_freq = 4 ;
2961+ let low_res_lambdas_m = obs_context
2962+ . fine_chan_freqs
2963+ . as_slice ( )
2964+ . chunks ( avg_freq)
2965+ . map ( |chunk| {
2966+ let f = chunk. iter ( ) . sum :: < u64 > ( ) as f64 / chunk. len ( ) as f64 ;
2967+ VEL_C / f
2968+ } )
2969+ . collect_vec ( ) ;
2970+
2971+ let lst_0h_rad = get_lmst (
2972+ array_pos. longitude_rad ,
2973+ obs_context. timestamps [ 0 ] ,
2974+ obs_context. dut1 . unwrap_or_default ( ) ,
2975+ ) ;
2976+ let source_radec =
2977+ RADec :: from_hadec ( HADec :: from_radians ( 0.2 , array_pos. latitude_rad ) , lst_0h_rad) ;
2978+ let source_fd = 1. ;
2979+ let source_list = SourceList :: from ( [ (
2980+ "One" . into ( ) ,
2981+ point_src_i ! ( source_radec, 0. , fine_chan_freqs_hz[ 0 ] , source_fd) ,
2982+ ) ] ) ;
2983+
2984+ let beam = get_beam ( num_tiles) ;
2985+
2986+ // Create original weights with some non-uniform values to make changes detectable
2987+ let original_weights = {
2988+ let mut weights = Array3 :: < f32 > :: ones ( ( num_times, num_chans, num_baselines) ) ;
2989+ // Set some weights to different values to make changes detectable
2990+ for i in 0 ..num_times {
2991+ for j in 0 ..num_chans {
2992+ for k in 0 ..num_baselines {
2993+ weights[ [ i, j, k] ] = 1.0 + ( i + j + k) as f32 * 0.1 ;
2994+ }
2995+ }
2996+ }
2997+ weights
2998+ } ;
2999+
3000+ let timeblock = Timeblock {
3001+ index : 0 ,
3002+ range : 0 ..2 ,
3003+ timestamps : obs_context. timestamps . clone ( ) ,
3004+ timesteps : vec1 ! [ 0 , 1 ] ,
3005+ median : obs_context. timestamps [ 0 ] ,
3006+ } ;
3007+
3008+ let source_weighted_positions = [ source_radec] ;
3009+
3010+ let multi_progress = MultiProgress :: with_draw_target ( ProgressDrawTarget :: hidden ( ) ) ;
3011+
3012+ let peel_weight_params = PeelWeightParams {
3013+ uvw_min_metres : 0.0 ,
3014+ uvw_max_metres : f64:: MAX ,
3015+ short_baseline_sigma : SHORT_BASELINE_SIGMA ,
3016+ } ;
3017+
3018+ let flagged_tiles: HashSet < _ > = obs_context. flagged_tiles . iter ( ) . cloned ( ) . collect ( ) ;
3019+ let tile_baseline_flags = TileBaselineFlags :: new ( num_tiles, flagged_tiles) ;
3020+
3021+ let peel_loop_params = PeelLoopParams {
3022+ num_passes : NonZeroUsize :: try_from ( NUM_PASSES ) . expect ( "NUM_PASSES > 0" ) ,
3023+ num_loops : NonZeroUsize :: try_from ( NUM_LOOPS ) . expect ( "NUM_LOOPS > 0" ) ,
3024+ convergence : CONVERGENCE ,
3025+ } ;
3026+
3027+ // Create a copy of weights for testing
3028+ let mut test_weights = original_weights. clone ( ) ;
3029+
3030+ // Apply tapering to the test weights
3031+ peel_weight_params. apply_tfb (
3032+ test_weights. view_mut ( ) ,
3033+ & obs_context,
3034+ & timeblock,
3035+ apply_precession,
3036+ & chanblocks,
3037+ & tile_baseline_flags,
3038+ ) ;
3039+
3040+ // Verify that the original weights are different from the tapered weights
3041+ // (this proves that apply_tfb actually modifies the weights)
3042+ assert ! ( !original_weights. abs_diff_eq( & test_weights, 1e-6 ) ) ;
3043+
3044+ // Now test the actual peel_thread function
3045+ // We need to set up the channels for peel_thread
3046+ let ( tx_full_residual, rx_full_residual) = crossbeam_channel:: bounded ( 1 ) ;
3047+ let ( tx_write, rx_write) = crossbeam_channel:: bounded ( 10 ) ; // Buffer for written data
3048+ let ( tx_iono_consts, _rx_iono_consts) = crossbeam_channel:: bounded ( 1 ) ;
3049+ let error = AtomicCell :: new ( false ) ;
3050+
3051+ // Clone original_weights before moving into thread
3052+ let original_weights_clone = original_weights. clone ( ) ;
3053+
3054+ // Send test data to the channel
3055+ let test_vis_residual: Array3 < Jones < f32 > > =
3056+ Array3 :: zeros ( ( num_times, num_chans, num_baselines) ) ;
3057+ let test_weights_for_thread = original_weights. clone ( ) ;
3058+ let timeblock_owned = timeblock. clone ( ) ;
3059+ tx_full_residual
3060+ . send ( ( test_vis_residual, test_weights_for_thread, timeblock_owned) )
3061+ . unwrap ( ) ;
3062+ drop ( tx_full_residual) ; // Close the sender
3063+
3064+ // Spawn peel_thread in a separate thread so we can read from rx_write
3065+ let output_vis_params = OutputVisParams {
3066+ output_files : vec1 ! [ (
3067+ PathBuf :: from( "/tmp/test_output.uvfits" ) ,
3068+ VisOutputType :: Uvfits
3069+ ) ] ,
3070+ output_time_average_factor : NonZeroUsize :: new ( 1 ) . unwrap ( ) ,
3071+ output_freq_average_factor : NonZeroUsize :: new ( 1 ) . unwrap ( ) ,
3072+ output_timeblocks : vec1 ! [ timeblock. clone( ) ] ,
3073+ write_smallest_contiguous_band : false ,
3074+ } ;
3075+ let output_vis_params = output_vis_params;
3076+ let peel_handle = std:: thread:: spawn ( move || {
3077+ // Wrap the receiver to convert owned Timeblock to reference
3078+ let rx_full_residual_ref = rx_full_residual
3079+ . into_iter ( )
3080+ . map ( |( a, b, c) | ( a, b, Box :: new ( c) ) )
3081+ . map ( |( a, b, c) | ( a, b, Box :: leak ( c) as & Timeblock ) ) ;
3082+ // Create a new channel to pass the reference tuple to peel_thread
3083+ let ( tx_ref, rx_ref) = crossbeam_channel:: bounded ( 1 ) ;
3084+ for ( a, b, c_ref) in rx_full_residual_ref {
3085+ tx_ref. send ( ( a, b, c_ref) ) . unwrap ( ) ;
3086+ }
3087+ drop ( tx_ref) ;
3088+ peel_thread (
3089+ & beam,
3090+ & source_list,
3091+ & source_weighted_positions,
3092+ 1 , // num_sources_to_iono_subtract
3093+ & peel_loop_params,
3094+ & obs_context,
3095+ & obs_context. tile_xyzs ,
3096+ & peel_weight_params,
3097+ & tile_baseline_flags,
3098+ & chanblocks,
3099+ & low_res_lambdas_m,
3100+ apply_precession,
3101+ Some ( & output_vis_params) , // pass output_vis_params
3102+ rx_ref,
3103+ tx_write,
3104+ tx_iono_consts,
3105+ & error,
3106+ & multi_progress,
3107+ & multi_progress. add ( ProgressBar :: new ( 1 ) ) ,
3108+ )
3109+ } ) ;
3110+
3111+ // Collect all written weights
3112+ let mut written_weights = Vec :: new ( ) ;
3113+ while let Ok ( vis_timestep) = rx_write. recv ( ) {
3114+ written_weights. push ( vis_timestep. cross_weights_fb ) ;
3115+ }
3116+
3117+ // Wait for peel_thread to complete
3118+ let result = peel_handle. join ( ) . unwrap ( ) ;
3119+
3120+ // Verify peel_thread completed successfully
3121+ assert ! ( result. is_ok( ) , "peel_thread should complete successfully" ) ;
3122+
3123+ // Verify that we received written weights
3124+ assert ! (
3125+ !written_weights. is_empty( ) ,
3126+ "Should have received written weights"
3127+ ) ;
3128+
3129+ // Verify that the written weights match the original weights
3130+ // The written weights should be the original weights, not the tapered ones
3131+ for ( i, written_weight_fb) in written_weights. iter ( ) . enumerate ( ) {
3132+ let original_weight_fb = original_weights_clone. slice ( ndarray:: s![ i, .., ..] ) ;
3133+ let written_slice = written_weight_fb. as_slice ( ) . unwrap ( ) ;
3134+ let original_slice = original_weight_fb. as_slice ( ) . unwrap ( ) ;
3135+ assert_abs_diff_eq ! ( written_slice, original_slice, epsilon = 1e-6 ) ;
3136+ }
3137+ }
0 commit comments