Commit 879aaa9

yihonglyu authored and pull[bot] committed
Parallelize Max (microsoft#16745)
It gives up to a 7.5% improvement in the LLaMA 7B case.
1 parent 8a93c98 commit 879aaa9

1 file changed: +6 -1 lines changed

onnxruntime/core/providers/cpu/math/element_wise_ops.cc

Lines changed: 6 additions & 1 deletion
@@ -855,7 +855,12 @@ struct Max_8::ComputeImpl {
     }};
 
     int input_count = inst.Node().InputArgCount().front();
-    UntypedBroadcastVariadic(input_count, *context, typed_allocator, funcs);
+    // TODO: Parallelize across spans in UntypedBroadcastVariadic to avoid specific logic here
+    if (input_count == 2) {
+      UntypedBroadcastTwo(*context, funcs, 1.0);
+    } else {
+      UntypedBroadcastVariadic(input_count, *context, typed_allocator, funcs);
+    }
 
     return Status::OK();
   }
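
The change above special-cases two inputs so that Max can take the path the commit title describes as parallelized, while every other input count keeps the existing variadic path. A minimal standalone sketch of that dispatch pattern follows. It does not use ONNX Runtime: the names MaxTwoParallel, MaxVariadicSequential, and ElementwiseMax are hypothetical, raw std::thread workers stand in for whatever thread pool the library uses, and all inputs are assumed to have the same length (no broadcasting).

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <thread>
#include <vector>

// Hypothetical helper, not an ONNX Runtime API: element-wise max of two
// equal-length arrays, with the index range split across worker threads.
std::vector<float> MaxTwoParallel(const std::vector<float>& a,
                                  const std::vector<float>& b,
                                  std::size_t num_threads = 4) {
  assert(a.size() == b.size());  // assumption: no broadcasting in this sketch
  const std::size_t n = a.size();
  std::vector<float> out(n);
  if (n == 0) return out;

  // Never use more threads than elements, and always at least one.
  num_threads = std::max<std::size_t>(1, std::min(num_threads, n));
  const std::size_t chunk = (n + num_threads - 1) / num_threads;

  std::vector<std::thread> workers;
  for (std::size_t begin = 0; begin < n; begin += chunk) {
    const std::size_t end = std::min(begin + chunk, n);
    // Each worker writes a disjoint slice of `out`, so no locking is needed.
    workers.emplace_back([&a, &b, &out, begin, end] {
      for (std::size_t i = begin; i < end; ++i) out[i] = std::max(a[i], b[i]);
    });
  }
  for (auto& w : workers) w.join();
  return out;
}

// Hypothetical sequential fallback: folds any number of equal-length inputs.
std::vector<float> MaxVariadicSequential(
    const std::vector<std::vector<float>>& inputs) {
  if (inputs.empty()) return {};
  std::vector<float> out = inputs[0];
  for (std::size_t k = 1; k < inputs.size(); ++k) {
    assert(inputs[k].size() == out.size());
    for (std::size_t i = 0; i < out.size(); ++i) {
      out[i] = std::max(out[i], inputs[k][i]);
    }
  }
  return out;
}

// Dispatch mirroring the shape of the diff: two inputs take the parallel
// path, every other input count takes the sequential variadic path.
std::vector<float> ElementwiseMax(const std::vector<std::vector<float>>& inputs) {
  if (inputs.size() == 2) {
    return MaxTwoParallel(inputs[0], inputs[1]);
  }
  return MaxVariadicSequential(inputs);
}
```

Calling ElementwiseMax with exactly two vectors exercises the threaded path; any other count falls back to the sequential fold. The TODO in the diff suggests the special case would become unnecessary if the variadic path were itself parallelized across spans.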
