diff --git a/src/core/array.rs b/src/core/array.rs index 2d1119213..19519c9fc 100644 --- a/src/core/array.rs +++ b/src/core/array.rs @@ -189,16 +189,20 @@ where /// An example of creating an Array from half::f16 array /// /// ```rust - /// use arrayfire::{Array, Dim4, print}; + /// use arrayfire::{Array, Dim4, is_half_available, print}; /// use half::f16; /// /// let values: [f32; 3] = [1.0, 2.0, 3.0]; /// - /// let half_values = values.iter().map(|&x| f16::from_f32(x)).collect::>(); + /// if is_half_available(0) { // Default device is 0, hence the argument + /// let half_values = values.iter().map(|&x| f16::from_f32(x)).collect::>(); /// - /// let hvals = Array::new(&half_values, Dim4::new(&[3, 1, 1, 1])); + /// let hvals = Array::new(&half_values, Dim4::new(&[3, 1, 1, 1])); /// - /// print(&hvals); + /// print(&hvals); + /// } else { + /// println!("Half support isn't available on this device"); + /// } /// ``` /// pub fn new(slice: &[T], dims: Dim4) -> Self { diff --git a/src/core/data.rs b/src/core/data.rs index 22dfc0425..56ea07195 100644 --- a/src/core/data.rs +++ b/src/core/data.rs @@ -948,9 +948,9 @@ pub fn pad( let err_val = af_pad( &mut temp as *mut af_array, input.get(), - begin.ndims() as c_uint, + 4, begin.get().as_ptr() as *const dim_t, - end.ndims() as c_uint, + 4, end.get().as_ptr() as *const dim_t, fill_type as c_uint, ); @@ -963,7 +963,9 @@ pub fn pad( mod tests { use super::reorder_v2; + use super::super::defines::BorderType; use super::super::random::randu; + use super::pad; use crate::dim4; @@ -976,4 +978,12 @@ mod tests { let _swap_1_2 = reorder_v2(&a, 0, 2, Some(vec![1])); let _swap_0_3 = reorder_v2(&a, 3, 1, Some(vec![2, 0])); } + + #[test] + fn check_pad_api() { + let a = randu::(dim4![3, 3]); + let begin_dims = dim4!(0, 0, 0, 0); + let end_dims = dim4!(2, 2, 0, 0); + let _padded = pad(&a, begin_dims, end_dims, BorderType::ZERO); + } } diff --git a/src/core/device.rs b/src/core/device.rs index 3d313e421..61f856e23 100644 --- a/src/core/device.rs +++ b/src/core/device.rs @@ -36,6 +36,7 @@ extern "C" { fn af_alloc_pinned(non_pagable_ptr: *mut void_ptr, bytes: dim_t) -> c_int; fn af_free_pinned(non_pagable_ptr: void_ptr) -> c_int; + fn af_get_half_support(available: *mut c_int, device: c_int) -> c_int; } /// Get ArrayFire Version Number @@ -331,3 +332,21 @@ pub unsafe fn free_pinned(ptr: void_ptr) { let err_val = af_free_pinned(ptr); HANDLE_ERROR(AfError::from(err_val)); } + +/// Check if a device has half support +/// +/// # Parameters +/// +/// - `device` is the device for which half precision support is checked for +/// +/// # Return Values +/// +/// `True` if `device` device has half support, `False` otherwise. +pub fn is_half_available(device: i32) -> bool { + unsafe { + let mut temp: i32 = 0; + let err_val = af_get_half_support(&mut temp as *mut c_int, device as c_int); + HANDLE_ERROR(AfError::from(err_val)); + temp > 0 + } +} diff --git a/src/core/index.rs b/src/core/index.rs index dc2b711ae..6552d76bd 100644 --- a/src/core/index.rs +++ b/src/core/index.rs @@ -298,10 +298,10 @@ pub fn set_row(inout: &mut Array, new_row: &Array, row_num: u64) where T: HasAfEnum, { - let seqs = [ - Seq::new(row_num as f64, row_num as f64, 1.0), - Seq::default(), - ]; + let mut seqs = vec![Seq::new(row_num as f64, row_num as f64, 1.0)]; + if inout.dims().ndims() > 1 { + seqs.push(Seq::default()); + } assign_seq(inout, &seqs, new_row) }