Skip to content

Commit 2c234f7

Browse files
committed
chore: Minor fix of aten::full
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 74c90fe commit 2c234f7

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

core/conversion/converters/impl/constant.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ auto constant_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
3030
}})
3131
.pattern({"aten::full(int[] size, Scalar fill_value, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)",
3232
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
33-
auto size = args[0].unwrapToIntList();
33+
auto size = util::toVec(util::toDims(args[0].unwrapToIntList()));
3434
auto scalar = args[1].unwrapToScalar().to<float>();
35-
auto scalar_tensor = torch::full({5}, scalar);
35+
auto scalar_tensor = torch::full(size, scalar);
3636
auto full_tensor = tensor_to_const(ctx, scalar_tensor);
3737
auto output = ctx->AssociateValueAndTensor(n->outputs()[0], full_tensor);
3838

0 commit comments

Comments
 (0)