Skip to content

Commit e318768

Browse files
authored
Jh/checktraceargs (#183)
* check permutation validity in tensortrace * add test * update existing test
1 parent ea4f1c6 commit e318768

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

src/implementation/abstractarray.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ function argcheck_tensortrace(C::AbstractArray, A::AbstractArray, p::Index2Tuple
8686
throw(IndexError("invalid selection of length $(ndims(C)): $p"))
8787
2 * numin(q) == 2 * numout(q) == ndims(A) - ndims(C) ||
8888
throw(IndexError("invalid number of trace dimensions"))
89+
argcheck_index2tuple(A, ((p[1]..., q[1]...), (p[2]..., q[2]...)))
8990
return nothing
9091
end
9192

test/methods.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ backendlist = (BaseCopy(), BaseView(), StridedNative(), StridedBLAS())
5353
C4 = ncon(Any[A], Any[[-2, 1, 2, -3, 1, -1, 2]]; backend=b)
5454
@test C1 C2
5555
@test C2 == C3 == C4
56+
@test_throws IndexError tensortrace(randn(2, 2, 2, 2, 2, 2, 2), ((1,), (3, 2)),
57+
((1, 5), (2, 6)), false)
5658
end
5759

5860
@testset "tensorcontract" begin
@@ -171,8 +173,8 @@ end
171173
α, β, b)
172174
@test_throws IndexError tensortrace!(B, A, ((1, 4), ()), ((1, 1), (4,)), false, α,
173175
β, b)
174-
@test_throws DimensionMismatch tensortrace!(B, A, ((1, 4), ()), ((1,), (3,)), false,
175-
α, β, b)
176+
@test_throws IndexError tensortrace!(B, A, ((1, 4), ()), ((1,), (3,)), false,
177+
α, β, b)
176178
end
177179

178180
@testset "tensorcontract! with allocator = $allocator" for allocator in

0 commit comments

Comments
 (0)