Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion onnxscript/function_libs/torch_lib/ops/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,9 @@ def aten_linalg_vector_norm(
keepdim = False
else:
dim = op.Reshape(dim, op.Constant(value_ints=[-1]))
self = op.Abs(self)

if math.isinf(ord):
self = op.Abs(self)
if ord > 0:
return op.ReduceMax(self, dim, keepdims=keepdim)
else:
Expand All @@ -345,6 +346,9 @@ def aten_linalg_vector_norm(
elif ord == 2.0:
return op.ReduceL2(self, dim, keepdims=keepdim)
else:
if ord < 0 or ord % 2 != 0:
# Not-even integer, use abs
self = op.Abs(self)
self_pow = op.Pow(self, ord)
exp = op.CastLike(1 / ord, self)
return op.Pow(op.ReduceSum(self_pow, dim, keepdims=keepdim), exp)
Loading