Skip to content

Commit e858202

Browse files
committed
v0.6.1 add peel weight preservation test
1 parent ca3c840 commit e858202

File tree

1 file changed

+213
-2
lines changed

1 file changed

+213
-2
lines changed

src/params/peel/tests.rs

Lines changed: 213 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44

55
//! Tests against peeling
66
7-
use std::{collections::HashSet, f64::consts::TAU};
7+
use std::{collections::HashSet, f64::consts::TAU, num::NonZeroUsize, path::PathBuf};
88

99
use approx::assert_abs_diff_eq;
10+
use crossbeam_utils::atomic::AtomicCell;
1011
use hifitime::{Duration, Epoch};
1112
use indexmap::indexmap;
12-
use indicatif::{MultiProgress, ProgressDrawTarget};
13+
use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget};
1314
use itertools::{izip, Itertools};
1415
use marlu::{
1516
constants::VEL_C,
@@ -27,6 +28,7 @@ use crate::{
2728
beam::{Delays, FEEBeam},
2829
context::{ObsContext, Polarisations},
2930
io::read::VisInputType,
31+
io::write::VisOutputType,
3032
model::{new_sky_modeller, SkyModellerCpu},
3133
srclist::{ComponentType, FluxDensity, FluxDensityType, Source, SourceComponent, SourceList},
3234
};
@@ -2924,3 +2926,212 @@ mod gpu_tests {
29242926
}
29252927
}
29262928
}
2929+
2930+
#[test]
2931+
fn test_peel_weight_preservation() {
2932+
// Test that the original weights are preserved and not modified by tapering
2933+
let apply_precession = true;
2934+
2935+
let obs_context = get_phase1_obs_context(CPU_TILE_LIMIT);
2936+
2937+
let array_pos = obs_context.array_position;
2938+
let num_tiles = obs_context.get_total_num_tiles();
2939+
let num_times = obs_context.timestamps.len();
2940+
let num_baselines = (num_tiles * (num_tiles - 1)) / 2;
2941+
let num_chans = obs_context.fine_chan_freqs.len();
2942+
2943+
let chanblocks = obs_context
2944+
.fine_chan_freqs
2945+
.iter()
2946+
.enumerate()
2947+
.map(|(i, f)| Chanblock {
2948+
chanblock_index: i as u16,
2949+
unflagged_index: i as u16,
2950+
freq: *f as f64,
2951+
})
2952+
.collect_vec();
2953+
2954+
let fine_chan_freqs_hz = obs_context
2955+
.fine_chan_freqs
2956+
.iter()
2957+
.map(|&f| f as f64)
2958+
.collect_vec();
2959+
let _lambdas_m = fine_chan_freqs_hz.iter().map(|&f| VEL_C / f).collect_vec();
2960+
let avg_freq = 4;
2961+
let low_res_lambdas_m = obs_context
2962+
.fine_chan_freqs
2963+
.as_slice()
2964+
.chunks(avg_freq)
2965+
.map(|chunk| {
2966+
let f = chunk.iter().sum::<u64>() as f64 / chunk.len() as f64;
2967+
VEL_C / f
2968+
})
2969+
.collect_vec();
2970+
2971+
let lst_0h_rad = get_lmst(
2972+
array_pos.longitude_rad,
2973+
obs_context.timestamps[0],
2974+
obs_context.dut1.unwrap_or_default(),
2975+
);
2976+
let source_radec =
2977+
RADec::from_hadec(HADec::from_radians(0.2, array_pos.latitude_rad), lst_0h_rad);
2978+
let source_fd = 1.;
2979+
let source_list = SourceList::from([(
2980+
"One".into(),
2981+
point_src_i!(source_radec, 0., fine_chan_freqs_hz[0], source_fd),
2982+
)]);
2983+
2984+
let beam = get_beam(num_tiles);
2985+
2986+
// Create original weights with some non-uniform values to make changes detectable
2987+
let original_weights = {
2988+
let mut weights = Array3::<f32>::ones((num_times, num_chans, num_baselines));
2989+
// Set some weights to different values to make changes detectable
2990+
for i in 0..num_times {
2991+
for j in 0..num_chans {
2992+
for k in 0..num_baselines {
2993+
weights[[i, j, k]] = 1.0 + (i + j + k) as f32 * 0.1;
2994+
}
2995+
}
2996+
}
2997+
weights
2998+
};
2999+
3000+
let timeblock = Timeblock {
3001+
index: 0,
3002+
range: 0..2,
3003+
timestamps: obs_context.timestamps.clone(),
3004+
timesteps: vec1![0, 1],
3005+
median: obs_context.timestamps[0],
3006+
};
3007+
3008+
let source_weighted_positions = [source_radec];
3009+
3010+
let multi_progress = MultiProgress::with_draw_target(ProgressDrawTarget::hidden());
3011+
3012+
let peel_weight_params = PeelWeightParams {
3013+
uvw_min_metres: 0.0,
3014+
uvw_max_metres: f64::MAX,
3015+
short_baseline_sigma: SHORT_BASELINE_SIGMA,
3016+
};
3017+
3018+
let flagged_tiles: HashSet<_> = obs_context.flagged_tiles.iter().cloned().collect();
3019+
let tile_baseline_flags = TileBaselineFlags::new(num_tiles, flagged_tiles);
3020+
3021+
let peel_loop_params = PeelLoopParams {
3022+
num_passes: NonZeroUsize::try_from(NUM_PASSES).expect("NUM_PASSES > 0"),
3023+
num_loops: NonZeroUsize::try_from(NUM_LOOPS).expect("NUM_LOOPS > 0"),
3024+
convergence: CONVERGENCE,
3025+
};
3026+
3027+
// Create a copy of weights for testing
3028+
let mut test_weights = original_weights.clone();
3029+
3030+
// Apply tapering to the test weights
3031+
peel_weight_params.apply_tfb(
3032+
test_weights.view_mut(),
3033+
&obs_context,
3034+
&timeblock,
3035+
apply_precession,
3036+
&chanblocks,
3037+
&tile_baseline_flags,
3038+
);
3039+
3040+
// Verify that the original weights are different from the tapered weights
3041+
// (this proves that apply_tfb actually modifies the weights)
3042+
assert!(!original_weights.abs_diff_eq(&test_weights, 1e-6));
3043+
3044+
// Now test the actual peel_thread function
3045+
// We need to set up the channels for peel_thread
3046+
let (tx_full_residual, rx_full_residual) = crossbeam_channel::bounded(1);
3047+
let (tx_write, rx_write) = crossbeam_channel::bounded(10); // Buffer for written data
3048+
let (tx_iono_consts, _rx_iono_consts) = crossbeam_channel::bounded(1);
3049+
let error = AtomicCell::new(false);
3050+
3051+
// Clone original_weights before moving into thread
3052+
let original_weights_clone = original_weights.clone();
3053+
3054+
// Send test data to the channel
3055+
let test_vis_residual: Array3<Jones<f32>> =
3056+
Array3::zeros((num_times, num_chans, num_baselines));
3057+
let test_weights_for_thread = original_weights.clone();
3058+
let timeblock_owned = timeblock.clone();
3059+
tx_full_residual
3060+
.send((test_vis_residual, test_weights_for_thread, timeblock_owned))
3061+
.unwrap();
3062+
drop(tx_full_residual); // Close the sender
3063+
3064+
// Spawn peel_thread in a separate thread so we can read from rx_write
3065+
let output_vis_params = OutputVisParams {
3066+
output_files: vec1![(
3067+
PathBuf::from("/tmp/test_output.uvfits"),
3068+
VisOutputType::Uvfits
3069+
)],
3070+
output_time_average_factor: NonZeroUsize::new(1).unwrap(),
3071+
output_freq_average_factor: NonZeroUsize::new(1).unwrap(),
3072+
output_timeblocks: vec1![timeblock.clone()],
3073+
write_smallest_contiguous_band: false,
3074+
};
3075+
let output_vis_params = output_vis_params;
3076+
let peel_handle = std::thread::spawn(move || {
3077+
// Wrap the receiver to convert owned Timeblock to reference
3078+
let rx_full_residual_ref = rx_full_residual
3079+
.into_iter()
3080+
.map(|(a, b, c)| (a, b, Box::new(c)))
3081+
.map(|(a, b, c)| (a, b, Box::leak(c) as &Timeblock));
3082+
// Create a new channel to pass the reference tuple to peel_thread
3083+
let (tx_ref, rx_ref) = crossbeam_channel::bounded(1);
3084+
for (a, b, c_ref) in rx_full_residual_ref {
3085+
tx_ref.send((a, b, c_ref)).unwrap();
3086+
}
3087+
drop(tx_ref);
3088+
peel_thread(
3089+
&beam,
3090+
&source_list,
3091+
&source_weighted_positions,
3092+
1, // num_sources_to_iono_subtract
3093+
&peel_loop_params,
3094+
&obs_context,
3095+
&obs_context.tile_xyzs,
3096+
&peel_weight_params,
3097+
&tile_baseline_flags,
3098+
&chanblocks,
3099+
&low_res_lambdas_m,
3100+
apply_precession,
3101+
Some(&output_vis_params), // pass output_vis_params
3102+
rx_ref,
3103+
tx_write,
3104+
tx_iono_consts,
3105+
&error,
3106+
&multi_progress,
3107+
&multi_progress.add(ProgressBar::new(1)),
3108+
)
3109+
});
3110+
3111+
// Collect all written weights
3112+
let mut written_weights = Vec::new();
3113+
while let Ok(vis_timestep) = rx_write.recv() {
3114+
written_weights.push(vis_timestep.cross_weights_fb);
3115+
}
3116+
3117+
// Wait for peel_thread to complete
3118+
let result = peel_handle.join().unwrap();
3119+
3120+
// Verify peel_thread completed successfully
3121+
assert!(result.is_ok(), "peel_thread should complete successfully");
3122+
3123+
// Verify that we received written weights
3124+
assert!(
3125+
!written_weights.is_empty(),
3126+
"Should have received written weights"
3127+
);
3128+
3129+
// Verify that the written weights match the original weights
3130+
// The written weights should be the original weights, not the tapered ones
3131+
for (i, written_weight_fb) in written_weights.iter().enumerate() {
3132+
let original_weight_fb = original_weights_clone.slice(ndarray::s![i, .., ..]);
3133+
let written_slice = written_weight_fb.as_slice().unwrap();
3134+
let original_slice = original_weight_fb.as_slice().unwrap();
3135+
assert_abs_diff_eq!(written_slice, original_slice, epsilon = 1e-6);
3136+
}
3137+
}

0 commit comments

Comments
 (0)