-
Notifications
You must be signed in to change notification settings - Fork 25
Closed
Description
using Reactant
using Enzyme
# Build the model state for this reproducer: two independent 78×78×31 arrays of
# ones, each converted to a Reactant array. Returned as the tuple `(v, pv02)`.
function double_gyre_model()
    state_dims = (78, 78, 31)
    first_field = Reactant.to_rarray(ones(state_dims...))
    second_field = Reactant.to_rarray(ones(state_dims...))
    return (first_field, second_field)
end
# Create the 63×63 all-ones wind-stress array and hand it back as a Reactant array.
function wind_stress_init()
    return Reactant.to_rarray(ones(63, 63))
end
# Three-argument form of the reproducer's objective.
# Mutates `v` and `pv02` in place and returns `sum(u)` for an internal 63×63×16
# work array `u`.
# The index windows below line up when `v` and `pv02` are 78×78×31 and
# `wind_stress` is 63×63, which is what `double_gyre_model` and
# `wind_stress_init` produce (8:end-8 is then 63 long in dims 1-2, 16 long in dim 3).
# NOTE(review): statement order matters here; this function is the body of a
# bug reproducer, so do not "tidy" it without re-checking the reported behavior.
function estimate_tracer_error(v, pv02, wind_stress)
# Work array allocated from `v` via `similar`, then zero-filled.
u = similar(v, 63, 63, 16)
fill!(u, 0)
# Write the wind stress into level 15 of v's interior window.
copyto!(@view(v[8:end-8, 8:end-8, 15]), wind_stress)
# Snapshot of v taken after the wind stress is written. It is only referenced
# by the commented-out restore near the end of the function.
v0 = copy(v)
# Zero-filled 78×78×31 buffer; its interior is overwritten in the loop and the
# whole buffer is copied into pv02 afterwards.
pv2 = similar(u, 78, 78, 31)
fill!(pv2, 0)
# Three iterations under Reactant's @trace (the dumped IR shows the loop as a
# single stablehlo.while).
@trace track_numbers=false for _ = 1:3
# v[interior] <- v[interior] + u
copyto!(@view(v[8:end-8, 8:end-8, 8:end-8]), Reactant.Ops.add(v[8:end-8, 8:end-8, 8:end-8], u))
# u <- window of v offset by +1 in dim 1 and -1 in dim 2 relative to the interior.
copyto!(u, v[9:end-7, 7:end-9, 8:end-8])
# u[:, :, 2] <- u[:, :, 2] + u[:, :, 8]
copyto!(@view(u[:, :, 2]), Reactant.Ops.add(u[:, :, 2], u[:, :, 8]))
# Take level 9 of v's interior (a 2-D slice) and expand it to the full 3-D
# interior shape.
sVp = Reactant.TracedUtils.broadcast_to_size(v[8:end-8, 8:end-8, 9], size(v[8:end-8, 8:end-8, 8:end-8]))
# Overwrite the interior of both v and pv2 with that expanded slice.
copyto!(@view(v[8:end-8, 8:end-8, 8:end-8]), sVp)
copyto!(@view(pv2[8:end-8, 8:end-8, 8:end-8]), sVp)
end
copyto!(pv02, pv2)
# Author's note: enabling the restore of v from the v0 snapshot below makes the
# reported problem go away (presumably the AD vs. finite-difference mismatch
# shown in the output; confirm). It is left disabled so the issue reproduces.
# adding this fixes
# copyto!(v, v0)
# Despite the variable name, this is a plain sum over u (no squaring, no mean).
mean_sq_surface_u = sum(u)
return mean_sq_surface_u
end
# Convenience method: take the model as an indexable pair and forward its two
# components, plus the wind stress, to the three-argument method.
function estimate_tracer_error(model, wind_stress)
    v, pv02 = model[1], model[2]
    return estimate_tracer_error(v, pv02, wind_stress)
end
# Differentiate `estimate_tracer_error` in reverse mode with strong-zero enabled.
# `model` is indexed as `(v, pv02)`; `J` is the wind-stress input and `dJ` its
# shadow, which Enzyme accumulates the gradient into.
# Returns the raw `autodiff` result together with `dJ`.
function differentiate_tracer_error(model, J, dJ)
    v, pv02 = model[1], model[2]

    # Fresh zero shadows for the two model arrays.
    dv = zero(v)
    dpv02 = zero(pv02)

    mode = set_strong_zero(Enzyme.Reverse)
    v_dup = Duplicated(v, dv)
    pv02_dup = Duplicated(pv02, dpv02)
    J_dup = Duplicated(J, dJ)

    result = autodiff(mode, estimate_tracer_error, Active, v_dup, pv02_dup, J_dup)
    return result, dJ
end
# Inputs shared by the compiled functions below: the wind-stress field, the
# model state, and a zeroed shadow of the wind stress that receives its gradient.
rwind_stress = wind_stress_init()
rmodel = double_gyre_model()
dJ = Enzyme.make_zero(rwind_stress)
@info "Compiling..."
pre_pipeline = "mark-func-memory-effects,inline{default-pipeline=canonicalize max-iterations=4},propagate-constant-bounds,sroa-wrappers{instcombine=false instsimplify=true },canonicalize,sroa-wrappers{instcombine=false instsimplify=true },libdevice-funcs-raise,canonicalize,remove-duplicate-func-def,canonicalize,cse,canonicalize,enzyme-hlo-generate-td{patterns=compare_op_canon<16>;transpose_transpose<16>;broadcast_in_dim_op_canon<16>;convert_op_canon<16>;dynamic_broadcast_in_dim_op_not_actually_dynamic<16>;chained_dynamic_broadcast_in_dim_canonicalization<16>;dynamic_broadcast_in_dim_all_dims_non_expanding<16>;noop_reduce_op_canon<16>;empty_reduce_op_canon<16>;dynamic_reshape_op_canon<16>;get_tuple_element_op_canon<16>;real_op_canon<16>;imag_op_canon<16>;conj_complex_negate<16>;get_dimension_size_op_canon<16>;gather_op_canon<16>;reshape_op_canon<16>;merge_consecutive_reshapes<16>;transpose_is_reshape<16>;zero_extent_tensor_canon<16>;cse_broadcast_in_dim<16>;cse_slice<16>;cse_transpose<16>;cse_convert<16>;cse_pad<16>;cse_dot_general<16>;cse_reshape<16>;cse_mul<16>;cse_div<16>;cse_add<16>;cse_subtract<16>;cse_min<16>;cse_max<16>;cse_neg<16>;cse_abs<16>;cse_concatenate<16>;concatenate_op_canon<16>(1024);select_op_canon<16>(1024);add_simplify<16>;sub_simplify<16>;and_simplify<16>;max_simplify<16>;min_simplify<16>;or_simplify<16>;xor_simplify<16>;mul_simplify<16>;div_simplify<16>;rem_simplify<16>;pow_simplify<16>;simplify_extend<16>;simplify_wrap<16>;simplify_rotate<16>;noop_slice<16>;noop_reverse<16>;slice_slice<16>;shift_right_logical_simplify<16>;pad_simplify<16>(1024);select_pad_to_dus<1>;and_pad_pad<1>;negative_pad_to_slice<16>;slice_simplify<16>;convert_simplify<16>;dynamic_slice_to_static<16>;dynamic_update_slice_elim<16>;concat_to_broadcast<16>;reduce_to_reshape<16>;broadcast_to_reshape<16>;slice_internal;iota_simplify<16>(1024);broadcast_in_dim_simplify<16>(1024);convert_concat<1>;dynamic_update_to_concat<1>;slice_of_dynamic_update<1>;slice_elementwise<1>;slice_p
ad<1>;dot_reshape_dot<1>;concat_fuse<1>;pad_reshape_pad<1>;pad_pad<1>;concat_push_binop_add<1>;concat_push_binop_mul<1>;scatter_to_dynamic_update_slice<1>;reduce_concat<1>;slice_concat<1>;concat_slice<1>;select_op_used_within_if<1>;bin_broadcast_splat_add<1>;bin_broadcast_splat_subtract<1>;bin_broadcast_splat_div<1>;bin_broadcast_splat_mul<1>;dot_general_simplify<16>;transpose_simplify<16>;reshape_empty_broadcast<1>;add_pad_pad_to_concat<1>;broadcast_reshape<1>;concat_pad<1>;reduce_pad<1>;broadcast_pad<1>;zero_product_reshape_pad<1>;mul_zero_pad<1>;div_zero_pad<1>;binop_const_reshape_pad<1>;binop_const_pad_add<1>;binop_const_pad_subtract<1>;binop_const_pad_mul<1>;binop_const_pad_div<1>;binop_binop_pad_pad_add<1>;binop_binop_pad_pad_mul<1>;binop_pad_pad_add<1>;binop_pad_pad_subtract<1>;binop_pad_pad_mul<1>;binop_pad_pad_div<1>;binop_pad_pad_min<1>;binop_pad_pad_max<1>;unary_pad_push_convert<1>;unary_pad_push_tanh<1>;unary_pad_push_exp<1>;transpose_dot_reorder<1>;dot_transpose<1>;transpose_convolution<1>;convolution_transpose<1>;convert_convert_float<1>;concat_to_pad<1>;reshape_iota<1>;broadcast_reduce<1>;slice_dot_general<1>;if_inline<1>;if_to_select<1>;dynamic_gather_op_is_not_dynamic<16>;divide_sqrt_to_multiply_rsqrt<16>;associative_binary_op_reordering<1>;transpose_broadcast_in_dim_to_broadcast_in_dim<16>;replace_neg_add_with_subtract;binop_const_simplify;not_select_simplify;common_compare_expression_rewrite;compare_select_simplify;while_simplify<1>(1);if_remove_unused;transpose_reshape_to_broadcast;reshape_transpose_to_broadcast;dus_dus;dus_dus_concat;abs_positive_simplify;transpose_unary_transpose_abs;transpose_unary_transpose_neg;transpose_unary_transpose_sqrt;transpose_unary_transpose_rsqrt;transpose_unary_transpose_ceil;transpose_unary_transpose_convert;transpose_unary_transpose_cosine;transpose_unary_transpose_exp;transpose_unary_transpose_expm1;transpose_unary_transpose_log;transpose_unary_transpose_log1p;transpose_unary_transpose_sign;transpose_unary_trans
pose_sine;transpose_unary_transpose_tanh;select_comp_iota_const_simplify<1>;sign_abs_simplify<1>;broadcastindim_is_reshape;slice_reduce_window<1>;while_deadresult;while_dus;dus_licm(0);while_op_induction_replacement;dus_pad;dus_concat;slice_dus_to_concat;while_induction_reduction;slice_licm(0);pad_licm(0);elementwise_licm(0);concatenate_licm(0);slice_broadcast;while_pad_induction_reduction;while_licm<1>(1);associative_common_mul_op_reordering;slice_select_to_select_slice;pad_concat_to_concat_pad;slice_if;dus_to_i32;rotate_pad;slice_extend;concat_wrap;cse_extend<16>;cse_wrap<16>;cse_rotate<16>;cse_rotate<16>;concat_concat_axis_swap;concat_multipad;concat_concat_to_dus;speculate_if_pad_to_select;broadcast_iota_simplify;select_comp_iota_to_dus;compare_cleanup;broadcast_compare;not_compare;broadcast_iota;cse_iota;compare_iota_const_simplify;reshuffle_ands_compares;square_abs_simplify;divide_divide_simplify;concat_reshape_slice;full_reduce_reshape_or_transpose;concat_reshape_reduce;concat_elementwise;reduce_reduce;conj_real;select_broadcast_in_dim;if_op_lift_common_ops;involution_neg_simplify;involution_conj_simplify;involution_not_simplify;real_conj_simplify;conj_complex_simplify;split_convolution_into_reverse_convolution;scatter_multiply_simplify;unary_elementwise_scatter_simplify;gather_elementwise;chlo_inf_const_prop<16>;gamma_const_prop<16>;abs_const_prop<16>;log_const_prop<1>;log_plus_one_const_prop<1>;is_finite_const_prop;not_const_prop;neg_const_prop;sqrt_const_prop;rsqrt_const_prop;cos_const_prop;sin_const_prop;exp_const_prop;expm1_const_prop;tanh_const_prop;logistic_const_prop;conj_const_prop;ceil_const_prop;cbrt_const_prop;real_const_prop;imag_const_prop;round_const_prop;round_nearest_even_const_prop;sign_const_prop;floor_const_prop;tan_const_prop;add_const_prop;and_const_prop;atan2_const_prop;complex_const_prop;div_const_prop;max_const_prop;min_const_prop;mul_const_prop;or_const_prop;pow_const_prop;rem_const_prop;sub_const_prop;xor_const_prop;const_prop_throu
gh_barrier<16>;concat_const_prop<1>(1024);dynamic_update_slice_const_prop(1024);scatter_update_computation_const_prop;gather_const_prop;dus_slice_simplify;reshape_concat;reshape_dus;dot_reshape_pad<1>;pad_dot_general<1>(0);pad_dot_general<1>(1);reshape_pad;reshape_wrap;reshape_rotate;reshape_extend;reshape_slice(1);reshape_elementwise(1);transpose_select;transpose_while;transpose_slice;transpose_concat;transpose_iota;transpose_reduce;transpose_reduce_window;transpose_dus;transpose_pad<1>;transpose_einsum<1>;transpose_wrap;transpose_extend;transpose_rotate;transpose_dynamic_slice;transpose_reverse;transpose_batch_norm_training;transpose_batch_norm_inference;transpose_batch_norm_grad;transpose_if;transpose_elementwise(1);no_nan_add_sub_simplify(0);lower_extend;lower_wrap;lower_rotate},transform-interpreter,enzyme-hlo-remove-transform,lower-kernel{backend=cpu},canonicalize,canonicalize,llvm-to-memref-access,canonicalize,convert-llvm-to-cf,canonicalize,enzyme-lift-cf-to-scf,canonicalize,func.func(canonicalize-loops),canonicalize-scf-for,canonicalize,libdevice-funcs-raise,canonicalize,affine-cfg,canonicalize,func.func(canonicalize-loops),canonicalize,llvm-to-affine-access,canonicalize,delinearize-indexing,canonicalize,simplify-affine-exprs,affine-cfg,canonicalize,func.func(affine-loop-invariant-code-motion),canonicalize,sort-memory,raise-affine-to-stablehlo{prefer_while_raising=false dump_failed_lockstep=false},canonicalize,arith-raise{stablehlo=true},inline{default-pipeline=canonicalize max-iterations=4},canonicalize,cse,canonicalize,enzyme-batch,inline{default-pipeline=canonicalize 
max-iterations=4},enzyme-hlo-generate-td{patterns=compare_op_canon<16>;transpose_transpose<16>;broadcast_in_dim_op_canon<16>;convert_op_canon<16>;dynamic_broadcast_in_dim_op_not_actually_dynamic<16>;chained_dynamic_broadcast_in_dim_canonicalization<16>;dynamic_broadcast_in_dim_all_dims_non_expanding<16>;noop_reduce_op_canon<16>;empty_reduce_op_canon<16>;dynamic_reshape_op_canon<16>;get_tuple_element_op_canon<16>;real_op_canon<16>;imag_op_canon<16>;conj_complex_negate<16>;get_dimension_size_op_canon<16>;gather_op_canon<16>;reshape_op_canon<16>;merge_consecutive_reshapes<16>;transpose_is_reshape<16>;zero_extent_tensor_canon<16>;cse_broadcast_in_dim<16>;cse_slice<16>;cse_transpose<16>;cse_convert<16>;cse_pad<16>;cse_dot_general<16>;cse_reshape<16>;cse_mul<16>;cse_div<16>;cse_add<16>;cse_subtract<16>;cse_min<16>;cse_max<16>;cse_neg<16>;cse_abs<16>;cse_concatenate<16>;concatenate_op_canon<16>(1024);select_op_canon<16>(1024);add_simplify<16>;sub_simplify<16>;and_simplify<16>;max_simplify<16>;min_simplify<16>;or_simplify<16>;xor_simplify<16>;mul_simplify<16>;div_simplify<16>;rem_simplify<16>;pow_simplify<16>;simplify_extend<16>;simplify_wrap<16>;simplify_rotate<16>;noop_slice<16>;noop_reverse<16>;slice_slice<16>;shift_right_logical_simplify<16>;pad_simplify<16>(1024);select_pad_to_dus<1>;and_pad_pad<1>;negative_pad_to_slice<16>;slice_simplify<16>;convert_simplify<16>;dynamic_slice_to_static<16>;dynamic_update_slice_elim<16>;concat_to_broadcast<16>;reduce_to_reshape<16>;broadcast_to_reshape<16>;slice_internal;iota_simplify<16>(1024);broadcast_in_dim_simplify<16>(1024);convert_concat<1>;dynamic_update_to_concat<1>;slice_of_dynamic_update<1>;slice_elementwise<1>;slice_pad<1>;dot_reshape_dot<1>;concat_fuse<1>;pad_reshape_pad<1>;pad_pad<1>;concat_push_binop_add<1>;concat_push_binop_mul<1>;scatter_to_dynamic_update_slice<1>;reduce_concat<1>;slice_concat<1>;concat_slice<1>;select_op_used_within_if<1>;bin_broadcast_splat_add<1>;bin_broadcast_splat_subtract<1>;bin_broadcast_splat_d
iv<1>;bin_broadcast_splat_mul<1>;dot_general_simplify<16>;transpose_simplify<16>;reshape_empty_broadcast<1>;add_pad_pad_to_concat<1>;broadcast_reshape<1>;concat_pad<1>;reduce_pad<1>;broadcast_pad<1>;zero_product_reshape_pad<1>;mul_zero_pad<1>;div_zero_pad<1>;binop_const_reshape_pad<1>;binop_const_pad_add<1>;binop_const_pad_subtract<1>;binop_const_pad_mul<1>;binop_const_pad_div<1>;binop_binop_pad_pad_add<1>;binop_binop_pad_pad_mul<1>;binop_pad_pad_add<1>;binop_pad_pad_subtract<1>;binop_pad_pad_mul<1>;binop_pad_pad_div<1>;binop_pad_pad_min<1>;binop_pad_pad_max<1>;unary_pad_push_convert<1>;unary_pad_push_tanh<1>;unary_pad_push_exp<1>;transpose_dot_reorder<1>;dot_transpose<1>;transpose_convolution<1>;convolution_transpose<1>;convert_convert_float<1>;concat_to_pad<1>;reshape_iota<1>;broadcast_reduce<1>;slice_dot_general<1>;if_inline<1>;if_to_select<1>;dynamic_gather_op_is_not_dynamic<16>;divide_sqrt_to_multiply_rsqrt<16>;associative_binary_op_reordering<1>;transpose_broadcast_in_dim_to_broadcast_in_dim<16>;replace_neg_add_with_subtract;binop_const_simplify;not_select_simplify;common_compare_expression_rewrite;compare_select_simplify;while_simplify<1>(1);if_remove_unused;transpose_reshape_to_broadcast;reshape_transpose_to_broadcast;dus_dus;dus_dus_concat;abs_positive_simplify;transpose_unary_transpose_abs;transpose_unary_transpose_neg;transpose_unary_transpose_sqrt;transpose_unary_transpose_rsqrt;transpose_unary_transpose_ceil;transpose_unary_transpose_convert;transpose_unary_transpose_cosine;transpose_unary_transpose_exp;transpose_unary_transpose_expm1;transpose_unary_transpose_log;transpose_unary_transpose_log1p;transpose_unary_transpose_sign;transpose_unary_transpose_sine;transpose_unary_transpose_tanh;select_comp_iota_const_simplify<1>;sign_abs_simplify<1>;broadcastindim_is_reshape;slice_reduce_window<1>;while_deadresult;while_dus;dus_licm(0);while_op_induction_replacement;dus_pad;dus_concat;slice_dus_to_concat;while_induction_reduction;slice_licm(0);pad_licm(0);eleme
ntwise_licm(0);concatenate_licm(0);slice_broadcast;while_pad_induction_reduction;while_licm<1>(1);associative_common_mul_op_reordering;slice_select_to_select_slice;pad_concat_to_concat_pad;slice_if;dus_to_i32;rotate_pad;slice_extend;concat_wrap;cse_extend<16>;cse_wrap<16>;cse_rotate<16>;cse_rotate<16>;concat_concat_axis_swap;concat_multipad;concat_concat_to_dus;speculate_if_pad_to_select;broadcast_iota_simplify;select_comp_iota_to_dus;compare_cleanup;broadcast_compare;not_compare;broadcast_iota;cse_iota;compare_iota_const_simplify;reshuffle_ands_compares;square_abs_simplify;divide_divide_simplify;concat_reshape_slice;full_reduce_reshape_or_transpose;concat_reshape_reduce;concat_elementwise;reduce_reduce;conj_real;select_broadcast_in_dim;if_op_lift_common_ops;involution_neg_simplify;involution_conj_simplify;involution_not_simplify;real_conj_simplify;conj_complex_simplify;split_convolution_into_reverse_convolution;scatter_multiply_simplify;unary_elementwise_scatter_simplify;gather_elementwise;chlo_inf_const_prop<16>;gamma_const_prop<16>;abs_const_prop<16>;log_const_prop<1>;log_plus_one_const_prop<1>;is_finite_const_prop;not_const_prop;neg_const_prop;sqrt_const_prop;rsqrt_const_prop;cos_const_prop;sin_const_prop;exp_const_prop;expm1_const_prop;tanh_const_prop;logistic_const_prop;conj_const_prop;ceil_const_prop;cbrt_const_prop;real_const_prop;imag_const_prop;round_const_prop;round_nearest_even_const_prop;sign_const_prop;floor_const_prop;tan_const_prop;add_const_prop;and_const_prop;atan2_const_prop;complex_const_prop;div_const_prop;max_const_prop;min_const_prop;mul_const_prop;or_const_prop;pow_const_prop;rem_const_prop;sub_const_prop;xor_const_prop;const_prop_through_barrier<16>;concat_const_prop<1>(1024);dynamic_update_slice_const_prop(1024);scatter_update_computation_const_prop;gather_const_prop;dus_slice_simplify;reshape_concat;reshape_dus;dot_reshape_pad<1>;pad_dot_general<1>(0);pad_dot_general<1>(1);reshape_pad;reshape_wrap;reshape_rotate;reshape_extend;reshape_slic
e(1);reshape_elementwise(1);transpose_select;transpose_while;transpose_slice;transpose_concat;transpose_iota;transpose_reduce;transpose_reduce_window;transpose_dus;transpose_pad<1>;transpose_einsum<1>;transpose_wrap;transpose_extend;transpose_rotate;transpose_dynamic_slice;transpose_reverse;transpose_batch_norm_training;transpose_batch_norm_inference;transpose_batch_norm_grad;transpose_if;transpose_elementwise(1);no_nan_add_sub_simplify(0);recognize_extend;recognize_wrap;recognize_rotate},transform-interpreter,enzyme-hlo-remove-transform,canonicalize,cse,canonicalize"
# The `pre_pipeline` literal is meant to be one single line. When it has been
# hard-wrapped (as in a copy/paste of this script), the line breaks end up
# inside pass and pattern names, so strip any CR/LF before the string is used.
# This is a no-op when the literal is intact.
pre_pipeline = replace(pre_pipeline, r"[\r\n]" => "")

# Full pipeline: everything in `pre_pipeline`, then the Enzyme differentiation
# pass with its post-passes, then symbol DCE.
pass_pipeline = pre_pipeline * ",enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"},symbol-dce"

# Compile the forward objective and its derivative. The timing below also
# includes the three IR dumps.
tic = time()
restimate_tracer_error = @compile optimize=pass_pipeline raise_first=true raise=true sync=true estimate_tracer_error(rmodel, rwind_stress)
rdifferentiate_tracer_error = @compile optimize=pass_pipeline raise_first=true raise=true sync=true differentiate_tracer_error(rmodel, rwind_stress, dJ)

# IR dumps for the report, in this order:
#   1. forward objective, full pipeline
#   2. derivative, stopped before the Enzyme pass (`pre_pipeline` only)
#   3. derivative, full pipeline
println(@code_hlo optimize=pass_pipeline raise_first=true raise=true estimate_tracer_error(rmodel, rwind_stress))
println(@code_hlo optimize=pre_pipeline raise_first=true raise=true differentiate_tracer_error(rmodel, rwind_stress, dJ))
println(@code_hlo optimize=pass_pipeline raise_first=true raise=true differentiate_tracer_error(rmodel, rwind_stress, dJ))
compile_toc = time() - tic
@show compile_toc

# Wind-stress entry whose sensitivity is inspected here and by the
# finite-difference check further down.
i = 10
j = 10

# Run the compiled derivative and print the AD gradient at (i, j).
dedν, dJ = rdifferentiate_tracer_error(rmodel, rwind_stress, dJ)
@allowscalar @show dJ[i, j]
# Produce finite-difference gradients for comparison with the AD value dJ[i, j].
# Central differences with a relative step ϵ·|x|; smaller steps (1e-4 … 1e-8)
# can be appended to the list if needed.
ϵ_list = [1e-1, 1e-2, 1e-3]
# Placeholder container; nothing in this script populates it. Building an empty
# array does no scalar indexing, so it needs no `@allowscalar`.
gradient_list = Array{Float64}[]
for ϵ in ϵ_list
    # Forward-perturbed run on freshly built inputs (the objective mutates the model).
    rmodelP = double_gyre_model()
    rwind_stressP = wind_stress_init()
    # Take the value *from* `@allowscalar` instead of assigning inside it: a
    # variable assigned inside a macro expansion may be scoped to a block the
    # macro introduces, and the previous name `diff` would then fall back to
    # `Base.diff`. The step is computed before the entry is perturbed.
    fd_step = @allowscalar 2ϵ * abs(rwind_stressP[i, j])
    @allowscalar rwind_stressP[i, j] = rwind_stressP[i, j] + ϵ * abs(rwind_stressP[i, j])
    sq_surface_uP = restimate_tracer_error(rmodelP, rwind_stressP)

    # Backward-perturbed run, again on fresh inputs.
    rmodelM = double_gyre_model()
    rwind_stressM = wind_stress_init()
    @allowscalar rwind_stressM[i, j] = rwind_stressM[i, j] - ϵ * abs(rwind_stressM[i, j])
    sq_surface_uM = restimate_tracer_error(rmodelM, rwind_stressM)

    # Central-difference estimate of d(objective)/d(wind_stress[i, j]).
    dsq_surface_u = (sq_surface_uP - sq_surface_uM) / fd_step
    @show ϵ, dsq_surface_u
end
2025-06-17 23:01:37.406072: I external/xla/xla/pjrt/pjrt_api.cc:93] PJRT_Api is set for device type tpu
[ Info: Compiling...
module @reactant_estimat... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
func.func @main(%arg0: tensor<31x78x78xf64> {tf.aliasing_output = 1 : i32}, %arg1: tensor<31x78x78xf64> {tf.aliasing_output = 2 : i32}, %arg2: tensor<63x63xf64>) -> (tensor<f64>, tensor<31x78x78xf64>, tensor<31x78x78xf64>) {
%cst = stablehlo.constant dense<0.000000e+00> : tensor<31x78x78xf64>
%cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<16x63x63xf64>
%cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f64>
%c = stablehlo.constant dense<0> : tensor<i64>
%c_2 = stablehlo.constant dense<3> : tensor<i64>
%c_3 = stablehlo.constant dense<1> : tensor<i64>
%c_4 = stablehlo.constant dense<7> : tensor<i32>
%cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<63x63x16xf64>
%0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<31x78x78xf64>) -> tensor<78x78x31xf64>
%1 = stablehlo.reshape %arg2 : (tensor<63x63xf64>) -> tensor<1x63x63xf64>
%2 = stablehlo.slice %arg0 [7:14, 7:70, 7:70] : (tensor<31x78x78xf64>) -> tensor<7x63x63xf64>
%3 = stablehlo.slice %arg0 [15:23, 7:70, 7:70] : (tensor<31x78x78xf64>) -> tensor<8x63x63xf64>
%4 = stablehlo.concatenate %2, %1, %3, dim = 0 : (tensor<7x63x63xf64>, tensor<1x63x63xf64>, tensor<8x63x63xf64>) -> tensor<16x63x63xf64>
%5:4 = stablehlo.while(%iterArg = %c, %iterArg_6 = %4, %iterArg_7 = %cst_5, %iterArg_8 = %cst_0) : tensor<i64>, tensor<16x63x63xf64>, tensor<63x63x16xf64>, tensor<16x63x63xf64> attributes {enzymexla.disable_min_cut}
cond {
%9 = stablehlo.compare LT, %iterArg, %c_2 : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %9 : tensor<i1>
} do {
%9 = stablehlo.transpose %iterArg_6, dims = [2, 1, 0] : (tensor<16x63x63xf64>) -> tensor<63x63x16xf64>
%10 = stablehlo.add %iterArg, %c_3 : tensor<i64>
%11 = stablehlo.add %9, %iterArg_7 : tensor<63x63x16xf64>
%12 = stablehlo.dynamic_update_slice %0, %11, %c_4, %c_4, %c_4 : (tensor<78x78x31xf64>, tensor<63x63x16xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<78x78x31xf64>
%13 = stablehlo.transpose %12, dims = [2, 1, 0] : (tensor<78x78x31xf64>) -> tensor<31x78x78xf64>
%14 = stablehlo.slice %12 [8:71, 6:69, 8:9] : (tensor<78x78x31xf64>) -> tensor<63x63x1xf64>
%15 = stablehlo.slice %12 [8:71, 6:69, 14:15] : (tensor<78x78x31xf64>) -> tensor<63x63x1xf64>
%16 = stablehlo.add %14, %15 : tensor<63x63x1xf64>
%17 = stablehlo.slice %12 [8:71, 6:69, 7:8] : (tensor<78x78x31xf64>) -> tensor<63x63x1xf64>
%18 = stablehlo.slice %12 [8:71, 6:69, 9:23] : (tensor<78x78x31xf64>) -> tensor<63x63x14xf64>
%19 = stablehlo.concatenate %17, %16, %18, dim = 2 : (tensor<63x63x1xf64>, tensor<63x63x1xf64>, tensor<63x63x14xf64>) -> tensor<63x63x16xf64>
%20 = stablehlo.slice %13 [8:9, 7:70, 7:70] : (tensor<31x78x78xf64>) -> tensor<1x63x63xf64>
%21 = stablehlo.reshape %20 : (tensor<1x63x63xf64>) -> tensor<63x63xf64>
%22 = stablehlo.broadcast_in_dim %21, dims = [1, 2] : (tensor<63x63xf64>) -> tensor<16x63x63xf64>
stablehlo.return %10, %22, %19, %22 : tensor<i64>, tensor<16x63x63xf64>, tensor<63x63x16xf64>, tensor<16x63x63xf64>
}
%6 = stablehlo.reduce(%5#2 init: %cst_1) applies stablehlo.add across dimensions = [0, 1, 2] : (tensor<63x63x16xf64>, tensor<f64>) -> tensor<f64>
%7 = stablehlo.dynamic_update_slice %arg0, %5#1, %c_4, %c_4, %c_4 : (tensor<31x78x78xf64>, tensor<16x63x63xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<31x78x78xf64>
%8 = stablehlo.dynamic_update_slice %cst, %5#3, %c_4, %c_4, %c_4 : (tensor<31x78x78xf64>, tensor<16x63x63xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<31x78x78xf64>
return %6, %7, %8 : tensor<f64>, tensor<31x78x78xf64>, tensor<31x78x78xf64>
}
}
module @reactant_differe... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
func.func private @"Const{typeof(estimate_tracer_error)}(Main.estimate_tracer_error)_autodiff"(%arg0: tensor<31x78x78xf64>, %arg1: tensor<31x78x78xf64>, %arg2: tensor<63x63xf64>) -> (tensor<f64>, tensor<31x78x78xf64>, tensor<31x78x78xf64>, tensor<63x63xf64>) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
%cst = stablehlo.constant dense<0.000000e+00> : tensor<31x78x78xf64>
%cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<16x63x63xf64>
%cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f64>
%c = stablehlo.constant dense<0> : tensor<i64>
%c_2 = stablehlo.constant dense<3> : tensor<i64>
%c_3 = stablehlo.constant dense<1> : tensor<i64>
%c_4 = stablehlo.constant dense<7> : tensor<i32>
%cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<63x63x16xf64>
%0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<31x78x78xf64>) -> tensor<78x78x31xf64>
%1 = stablehlo.reshape %arg2 : (tensor<63x63xf64>) -> tensor<1x63x63xf64>
%2 = stablehlo.slice %arg0 [7:14, 7:70, 7:70] : (tensor<31x78x78xf64>) -> tensor<7x63x63xf64>
%3 = stablehlo.slice %arg0 [15:23, 7:70, 7:70] : (tensor<31x78x78xf64>) -> tensor<8x63x63xf64>
%4 = stablehlo.concatenate %2, %1, %3, dim = 0 : (tensor<7x63x63xf64>, tensor<1x63x63xf64>, tensor<8x63x63xf64>) -> tensor<16x63x63xf64>
%5:4 = stablehlo.while(%iterArg = %c, %iterArg_6 = %4, %iterArg_7 = %cst_5, %iterArg_8 = %cst_0) : tensor<i64>, tensor<16x63x63xf64>, tensor<63x63x16xf64>, tensor<16x63x63xf64> attributes {enzymexla.disable_min_cut}
cond {
%9 = stablehlo.compare LT, %iterArg, %c_2 : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %9 : tensor<i1>
} do {
%9 = stablehlo.transpose %iterArg_6, dims = [2, 1, 0] : (tensor<16x63x63xf64>) -> tensor<63x63x16xf64>
%10 = stablehlo.add %iterArg, %c_3 : tensor<i64>
%11 = stablehlo.add %9, %iterArg_7 : tensor<63x63x16xf64>
%12 = stablehlo.dynamic_update_slice %0, %11, %c_4, %c_4, %c_4 : (tensor<78x78x31xf64>, tensor<63x63x16xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<78x78x31xf64>
%13 = stablehlo.transpose %12, dims = [2, 1, 0] : (tensor<78x78x31xf64>) -> tensor<31x78x78xf64>
%14 = stablehlo.slice %12 [8:71, 6:69, 8:9] : (tensor<78x78x31xf64>) -> tensor<63x63x1xf64>
%15 = stablehlo.slice %12 [8:71, 6:69, 14:15] : (tensor<78x78x31xf64>) -> tensor<63x63x1xf64>
%16 = stablehlo.add %14, %15 : tensor<63x63x1xf64>
%17 = stablehlo.slice %12 [8:71, 6:69, 7:8] : (tensor<78x78x31xf64>) -> tensor<63x63x1xf64>
%18 = stablehlo.slice %12 [8:71, 6:69, 9:23] : (tensor<78x78x31xf64>) -> tensor<63x63x14xf64>
%19 = stablehlo.concatenate %17, %16, %18, dim = 2 : (tensor<63x63x1xf64>, tensor<63x63x1xf64>, tensor<63x63x14xf64>) -> tensor<63x63x16xf64>
%20 = stablehlo.slice %13 [8:9, 7:70, 7:70] : (tensor<31x78x78xf64>) -> tensor<1x63x63xf64>
%21 = stablehlo.reshape %20 : (tensor<1x63x63xf64>) -> tensor<63x63xf64>
%22 = stablehlo.broadcast_in_dim %21, dims = [1, 2] : (tensor<63x63xf64>) -> tensor<16x63x63xf64>
stablehlo.return %10, %22, %19, %22 : tensor<i64>, tensor<16x63x63xf64>, tensor<63x63x16xf64>, tensor<16x63x63xf64>
}
%6 = stablehlo.reduce(%5#2 init: %cst_1) applies stablehlo.add across dimensions = [0, 1, 2] : (tensor<63x63x16xf64>, tensor<f64>) -> tensor<f64>
%7 = stablehlo.dynamic_update_slice %arg0, %5#1, %c_4, %c_4, %c_4 : (tensor<31x78x78xf64>, tensor<16x63x63xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<31x78x78xf64>
%8 = stablehlo.dynamic_update_slice %cst, %5#3, %c_4, %c_4, %c_4 : (tensor<31x78x78xf64>, tensor<16x63x63xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<31x78x78xf64>
return %6, %7, %8, %arg2 : tensor<f64>, tensor<31x78x78xf64>, tensor<31x78x78xf64>, tensor<63x63xf64>
}
func.func @main(%arg0: tensor<31x78x78xf64> {tf.aliasing_output = 1 : i32}, %arg1: tensor<31x78x78xf64> {tf.aliasing_output = 2 : i32}, %arg2: tensor<63x63xf64> {tf.aliasing_output = 3 : i32}, %arg3: tensor<63x63xf64> {tf.aliasing_output = 0 : i32}) -> (tensor<63x63xf64>, tensor<31x78x78xf64>, tensor<31x78x78xf64>, tensor<63x63xf64>) {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f64>
%cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<31x78x78xf64>
%0:4 = enzyme.autodiff @"Const{typeof(estimate_tracer_error)}(Main.estimate_tracer_error)_autodiff"(%arg0, %arg1, %arg2, %cst, %cst_0, %cst_0, %arg3) {activity = [#enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_active>], ret_activity = [#enzyme<activity enzyme_activenoneed>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>], strong_zero = true} : (tensor<31x78x78xf64>, tensor<31x78x78xf64>, tensor<63x63xf64>, tensor<f64>, tensor<31x78x78xf64>, tensor<31x78x78xf64>, tensor<63x63xf64>) -> (tensor<31x78x78xf64>, tensor<31x78x78xf64>, tensor<63x63xf64>, tensor<63x63xf64>)
return %0#3, %0#0, %0#1, %0#2 : tensor<63x63xf64>, tensor<31x78x78xf64>, tensor<31x78x78xf64>, tensor<63x63xf64>
}
}
module @reactant_differe... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
func.func private @"diffeConst{typeof(estimate_tracer_error)}(Main.estimate_tracer_error)_autodiff"(%arg0: tensor<31x78x78xf64>, %arg1: tensor<31x78x78xf64>, %arg2: tensor<63x63xf64>, %arg3: tensor<f64>, %arg4: tensor<31x78x78xf64>, %arg5: tensor<31x78x78xf64>, %arg6: tensor<63x63xf64>) -> (tensor<31x78x78xf64>, tensor<31x78x78xf64>, tensor<63x63xf64>, tensor<63x63xf64>) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
%c = stablehlo.constant dense<2> : tensor<i64>
%c_0 = stablehlo.constant dense<7> : tensor<i32>
%c_1 = stablehlo.constant dense<1> : tensor<i64>
%c_2 = stablehlo.constant dense<3> : tensor<i64>
%c_3 = stablehlo.constant dense<0> : tensor<i64>
%cst = stablehlo.constant dense<0.000000e+00> : tensor<f64>
%cst_4 = stablehlo.constant dense<0.000000e+00> : tensor<31x78x78xf64>
%cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<63x63x16xf64>
%cst_6 = stablehlo.constant dense<0.000000e+00> : tensor<16x63x63xf64>
%0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<31x78x78xf64>) -> tensor<78x78x31xf64>
%1 = stablehlo.reshape %arg2 : (tensor<63x63xf64>) -> tensor<1x63x63xf64>
%2 = stablehlo.slice %arg0 [7:14, 7:70, 7:70] : (tensor<31x78x78xf64>) -> tensor<7x63x63xf64>
%3 = stablehlo.slice %arg0 [15:23, 7:70, 7:70] : (tensor<31x78x78xf64>) -> tensor<8x63x63xf64>
%4 = stablehlo.concatenate %2, %1, %3, dim = 0 : (tensor<7x63x63xf64>, tensor<1x63x63xf64>, tensor<8x63x63xf64>) -> tensor<16x63x63xf64>
%5:4 = stablehlo.while(%iterArg = %c_3, %iterArg_7 = %4, %iterArg_8 = %cst_5, %iterArg_9 = %cst_6) : tensor<i64>, tensor<16x63x63xf64>, tensor<63x63x16xf64>, tensor<16x63x63xf64>
cond {
%16 = stablehlo.compare LT, %iterArg, %c_2 : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %16 : tensor<i1>
} do {
%16 = stablehlo.transpose %iterArg_7, dims = [2, 1, 0] : (tensor<16x63x63xf64>) -> tensor<63x63x16xf64>
%17 = stablehlo.add %iterArg, %c_1 : tensor<i64>
%18 = stablehlo.add %16, %iterArg_8 : tensor<63x63x16xf64>
%19 = stablehlo.dynamic_update_slice %0, %18, %c_0, %c_0, %c_0 : (tensor<78x78x31xf64>, tensor<63x63x16xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<78x78x31xf64>
%20 = stablehlo.transpose %19, dims = [2, 1, 0] : (tensor<78x78x31xf64>) -> tensor<31x78x78xf64>
%21 = stablehlo.slice %19 [8:71, 6:69, 8:9] : (tensor<78x78x31xf64>) -> tensor<63x63x1xf64>
%22 = stablehlo.slice %19 [8:71, 6:69, 14:15] : (tensor<78x78x31xf64>) -> tensor<63x63x1xf64>
%23 = stablehlo.add %21, %22 : tensor<63x63x1xf64>
%24 = stablehlo.slice %19 [8:71, 6:69, 7:8] : (tensor<78x78x31xf64>) -> tensor<63x63x1xf64>
%25 = stablehlo.slice %19 [8:71, 6:69, 9:23] : (tensor<78x78x31xf64>) -> tensor<63x63x14xf64>
%26 = stablehlo.concatenate %24, %23, %25, dim = 2 : (tensor<63x63x1xf64>, tensor<63x63x1xf64>, tensor<63x63x14xf64>) -> tensor<63x63x16xf64>
%27 = stablehlo.slice %20 [8:9, 7:70, 7:70] : (tensor<31x78x78xf64>) -> tensor<1x63x63xf64>
%28 = stablehlo.reshape %27 : (tensor<1x63x63xf64>) -> tensor<63x63xf64>
%29 = stablehlo.broadcast_in_dim %28, dims = [1, 2] : (tensor<63x63xf64>) -> tensor<16x63x63xf64>
stablehlo.return %17, %29, %26, %29 : tensor<i64>, tensor<16x63x63xf64>, tensor<63x63x16xf64>, tensor<16x63x63xf64>
}
%6 = stablehlo.dynamic_update_slice %arg0, %5#1, %c_0, %c_0, %c_0 : (tensor<31x78x78xf64>, tensor<16x63x63xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<31x78x78xf64>
%7 = stablehlo.dynamic_update_slice %cst_4, %5#3, %c_0, %c_0, %c_0 : (tensor<31x78x78xf64>, tensor<16x63x63xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<31x78x78xf64>
%8 = stablehlo.dynamic_slice %arg5, %c_0, %c_0, %c_0, sizes = [16, 63, 63] : (tensor<31x78x78xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<16x63x63xf64>
%9 = stablehlo.dynamic_slice %arg4, %c_0, %c_0, %c_0, sizes = [16, 63, 63] : (tensor<31x78x78xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<16x63x63xf64>
%10 = stablehlo.broadcast_in_dim %arg3, dims = [] : (tensor<f64>) -> tensor<63x63x16xf64>
%11:5 = stablehlo.while(%iterArg = %c_3, %iterArg_7 = %9, %iterArg_8 = %10, %iterArg_9 = %8, %iterArg_10 = %c) : tensor<i64>, tensor<16x63x63xf64>, tensor<63x63x16xf64>, tensor<16x63x63xf64>, tensor<i64>
cond {
%16 = stablehlo.compare LT, %iterArg, %c_2 : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %16 : tensor<i1>
} do {
%16 = stablehlo.add %iterArg, %c_1 : tensor<i64>
%17 = stablehlo.reduce(%iterArg_9 init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<16x63x63xf64>, tensor<f64>) -> tensor<63x63xf64>
%18 = stablehlo.reshape %17 : (tensor<63x63xf64>) -> tensor<1x63x63xf64>
%19 = stablehlo.transpose %18, dims = [1, 2, 0] : (tensor<1x63x63xf64>) -> tensor<63x63x1xf64>
%20 = stablehlo.reshape %19 : (tensor<63x63x1xf64>) -> tensor<63x63xf64>
%21 = stablehlo.reshape %20 : (tensor<63x63xf64>) -> tensor<1x63x63xf64>
%22 = stablehlo.pad %21, %cst, low = [8, 7, 7], high = [22, 8, 8], interior = [0, 0, 0] : (tensor<1x63x63xf64>, tensor<f64>) -> tensor<31x78x78xf64>
%23 = stablehlo.slice %iterArg_8 [0:63, 0:63, 0:1] : (tensor<63x63x16xf64>) -> tensor<63x63x1xf64>
%24 = stablehlo.reshape %23 : (tensor<63x63x1xf64>) -> tensor<63x63x1xf64>
%25 = stablehlo.slice %iterArg_8 [0:63, 0:63, 1:2] : (tensor<63x63x16xf64>) -> tensor<63x63x1xf64>
%26 = stablehlo.reshape %25 : (tensor<63x63x1xf64>) -> tensor<63x63x1xf64>
%27 = stablehlo.slice %iterArg_8 [0:63, 0:63, 2:16] : (tensor<63x63x16xf64>) -> tensor<63x63x14xf64>
%28 = stablehlo.reshape %27 : (tensor<63x63x14xf64>) -> tensor<63x63x14xf64>
%29 = stablehlo.pad %28, %cst, low = [8, 6, 9], high = [7, 9, 8], interior = [0, 0, 0] : (tensor<63x63x14xf64>, tensor<f64>) -> tensor<78x78x31xf64>
%30 = stablehlo.pad %24, %cst, low = [8, 6, 7], high = [7, 9, 23], interior = [0, 0, 0] : (tensor<63x63x1xf64>, tensor<f64>) -> tensor<78x78x31xf64>
%31 = stablehlo.add %29, %30 : tensor<78x78x31xf64>
%32 = stablehlo.pad %26, %cst, low = [8, 6, 14], high = [7, 9, 16], interior = [0, 0, 0] : (tensor<63x63x1xf64>, tensor<f64>) -> tensor<78x78x31xf64>
%33 = stablehlo.add %31, %32 : tensor<78x78x31xf64>
%34 = stablehlo.pad %26, %cst, low = [8, 6, 8], high = [7, 9, 22], interior = [0, 0, 0] : (tensor<63x63x1xf64>, tensor<f64>) -> tensor<78x78x31xf64>
%35 = stablehlo.add %33, %34 : tensor<78x78x31xf64>
%36 = stablehlo.transpose %22, dims = [2, 1, 0] : (tensor<31x78x78xf64>) -> tensor<78x78x31xf64>
%37 = stablehlo.add %35, %36 : tensor<78x78x31xf64>
%38 = stablehlo.dynamic_slice %37, %c_0, %c_0, %c_0, sizes = [63, 63, 16] : (tensor<78x78x31xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<63x63x16xf64>
%39 = stablehlo.transpose %38, dims = [2, 1, 0] : (tensor<63x63x16xf64>) -> tensor<16x63x63xf64>
%40 = stablehlo.subtract %iterArg_10, %c_1 : tensor<i64>
stablehlo.return %16, %39, %38, %cst_6, %40 : tensor<i64>, tensor<16x63x63xf64>, tensor<63x63x16xf64>, tensor<16x63x63xf64>, tensor<i64>
}
%12 = stablehlo.slice %11#1 [7:8, 0:63, 0:63] : (tensor<16x63x63xf64>) -> tensor<1x63x63xf64>
%13 = stablehlo.reshape %12 : (tensor<1x63x63xf64>) -> tensor<1x63x63xf64>
%14 = stablehlo.reshape %13 : (tensor<1x63x63xf64>) -> tensor<63x63xf64>
%15 = stablehlo.add %arg6, %14 : tensor<63x63xf64>
return %6, %7, %arg2, %15 : tensor<31x78x78xf64>, tensor<31x78x78xf64>, tensor<63x63xf64>, tensor<63x63xf64>
}
func.func @main(%arg0: tensor<31x78x78xf64> {tf.aliasing_output = 1 : i32}, %arg1: tensor<31x78x78xf64> {tf.aliasing_output = 2 : i32}, %arg2: tensor<63x63xf64> {tf.aliasing_output = 3 : i32}, %arg3: tensor<63x63xf64> {tf.aliasing_output = 0 : i32}) -> (tensor<63x63xf64>, tensor<31x78x78xf64>, tensor<31x78x78xf64>, tensor<63x63xf64>) {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f64>
%cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<31x78x78xf64>
%0:4 = call @"diffeConst{typeof(estimate_tracer_error)}(Main.estimate_tracer_error)_autodiff"(%arg0, %arg1, %arg2, %cst, %cst_0, %cst_0, %arg3) : (tensor<31x78x78xf64>, tensor<31x78x78xf64>, tensor<63x63xf64>, tensor<f64>, tensor<31x78x78xf64>, tensor<31x78x78xf64>, tensor<63x63xf64>) -> (tensor<31x78x78xf64>, tensor<31x78x78xf64>, tensor<63x63xf64>, tensor<63x63xf64>)
return %0#3, %0#0, %0#1, %0#2 : tensor<63x63xf64>, tensor<31x78x78xf64>, tensor<31x78x78xf64>, tensor<63x63xf64>
}
}
compile_toc = 5.551862001419067
dJ[i, j] = 4.0
(ϵ, dsq_surface_u) = (0.1, 20.99999999976717)
(ϵ, dsq_surface_u) = (0.01, 20.99999999627471)
(ϵ, dsq_surface_u) = (0.001, 21.000000007916242)
Metadata
Metadata
Assignees
Labels
No labels