Skip to content

Commit 4f152f3

Browse files
Add align_extent(), to align extent but not min (#5829)
* Allow align_bounds() to align extent but not min. This can be handy when you have an intermediate Func that is being tiled inside an outer Func and you want to ensure that it fits an exact multiple of tiles. * Add separate align_extent() method
1 parent bc42da9 commit 4f152f3

File tree

8 files changed

+115
-22
lines changed

8 files changed

+115
-22
lines changed

python_bindings/src/PyFunc.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,7 @@ void define_func(py::module &m) {
325325
.def("set_estimates", &Func::set_estimates, py::arg("estimates"))
326326

327327
.def("align_bounds", &Func::align_bounds, py::arg("var"), py::arg("modulus"), py::arg("remainder") = 0)
328+
.def("align_extent", &Func::align_extent, py::arg("var"), py::arg("modulus"))
328329

329330
.def("bound_extent", &Func::bound_extent, py::arg("var"), py::arg("extent"))
330331

src/AllocationBoundsInference.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -83,16 +83,20 @@ class AllocationInference : public IRMutator {
8383
extent = simplify((max - min) + 1);
8484
}
8585
if (bound.modulus.defined()) {
86-
internal_assert(bound.remainder.defined());
87-
min -= bound.remainder;
88-
min = (min / bound.modulus) * bound.modulus;
89-
min += bound.remainder;
90-
Expr max_plus_one = max + 1;
91-
max_plus_one -= bound.remainder;
92-
max_plus_one = ((max_plus_one + bound.modulus - 1) / bound.modulus) * bound.modulus;
93-
max_plus_one += bound.remainder;
94-
extent = simplify(max_plus_one - min);
95-
max = max_plus_one - 1;
86+
if (bound.remainder.defined()) {
87+
min -= bound.remainder;
88+
min = (min / bound.modulus) * bound.modulus;
89+
min += bound.remainder;
90+
Expr max_plus_one = max + 1;
91+
max_plus_one -= bound.remainder;
92+
max_plus_one = ((max_plus_one + bound.modulus - 1) / bound.modulus) * bound.modulus;
93+
max_plus_one += bound.remainder;
94+
extent = simplify(max_plus_one - min);
95+
max = max_plus_one - 1;
96+
} else {
97+
extent = simplify(((extent + bound.modulus - 1) / bound.modulus) * bound.modulus);
98+
max = simplify(min + extent - 1);
99+
}
96100
}
97101

98102
Expr min_var = Variable::make(Int(32), min_name);

src/BoundsInference.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -578,14 +578,20 @@ class BoundsInference : public IRMutator {
578578
}
579579

580580
if (bound.modulus.defined()) {
581-
min_required -= bound.remainder;
582-
min_required = (min_required / bound.modulus) * bound.modulus;
583-
min_required += bound.remainder;
584-
Expr max_plus_one = max_required + 1;
585-
max_plus_one -= bound.remainder;
586-
max_plus_one = ((max_plus_one + bound.modulus - 1) / bound.modulus) * bound.modulus;
587-
max_plus_one += bound.remainder;
588-
max_required = max_plus_one - 1;
581+
if (bound.remainder.defined()) {
582+
min_required -= bound.remainder;
583+
min_required = (min_required / bound.modulus) * bound.modulus;
584+
min_required += bound.remainder;
585+
Expr max_plus_one = max_required + 1;
586+
max_plus_one -= bound.remainder;
587+
max_plus_one = ((max_plus_one + bound.modulus - 1) / bound.modulus) * bound.modulus;
588+
max_plus_one += bound.remainder;
589+
max_required = max_plus_one - 1;
590+
} else {
591+
Expr extent = (max_required - min_required) + 1;
592+
extent = simplify(((extent + bound.modulus - 1) / bound.modulus) * bound.modulus);
593+
max_required = simplify(min_required + extent - 1);
594+
}
589595
s = LetStmt::make(min_var, min_required, s);
590596
s = LetStmt::make(max_var, max_required, s);
591597
}

src/Func.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2264,7 +2264,6 @@ Func &Func::align_bounds(const Var &var, Expr modulus, Expr remainder) {
22642264

22652265
// Reduce the remainder
22662266
remainder = remainder % modulus;
2267-
22682267
invalidate_cache();
22692268

22702269
bool found = func.is_pure_arg(var.name());
@@ -2279,6 +2278,26 @@ Func &Func::align_bounds(const Var &var, Expr modulus, Expr remainder) {
22792278
return *this;
22802279
}
22812280

2281+
// Schedule directive: attach a bound to the pure variable `var` that carries
// only a modulus, so that the extent of the region computed is rounded up to
// a multiple of `modulus` while the min is left alone.
//
// `var` must be one of this Func's pure variables; `modulus` must be defined
// and representable as int32. Violations are reported via user_assert.
// Returns *this so that scheduling calls can be chained.
Func &Func::align_extent(const Var &var, Expr modulus) {
2282+
// Validate the modulus before it is cast to int32 below.
user_assert(modulus.defined()) << "modulus is undefined\n";
2283+
user_assert(Int(32).can_represent(modulus.type())) << "Can't represent modulus as int32\n";
2284+
2285+
modulus = cast<int32_t>(modulus);
2286+
2287+
// NOTE(review): presumably discards cached state that depends on the
// schedule, since the schedule is about to change -- confirm.
invalidate_cache();
2288+
2289+
// The directive only applies to pure variables of this Func.
bool found = func.is_pure_arg(var.name());
2290+
user_assert(found)
2291+
<< "Can't align extent of variable " << var.name()
2292+
<< " of function " << name()
2293+
<< " because " << var.name()
2294+
<< " is not one of the pure variables of " << name() << ".\n";
2295+
2296+
// Only the modulus slot is populated; the other Expr() fields (presumably
// min and extent, then the trailing remainder) stay undefined. An undefined
// remainder is what makes the bounds-inference passes align the extent
// without moving the min.
Bound b = {var.name(), Expr(), Expr(), modulus, Expr()};
2297+
func.schedule().bounds().push_back(b);
2298+
return *this;
2299+
}
2300+
22822301
Func &Func::tile(const VarOrRVar &x, const VarOrRVar &y,
22832302
const VarOrRVar &xo, const VarOrRVar &yo,
22842303
const VarOrRVar &xi, const VarOrRVar &yi,

src/Func.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1551,9 +1551,19 @@ class Func {
15511551
* f.align_bounds(x, 2, 1) forces the min to be odd and the extent
15521552
* to be even. The region computed always contains the region that
15531553
* would have been computed without this directive, so no
1554-
* assertions are injected. */
1554+
* assertions are injected.
1555+
*/
15551556
Func &align_bounds(const Var &var, Expr modulus, Expr remainder = 0);
15561557

1558+
/** Expand the region computed so that the extent is a
1559+
* multiple of 'modulus'. For example, f.align_extent(x, 2) forces
1560+
* the extent realized to be even. The region computed always contains the
1561+
* region that would have been computed without this directive, so no
1562+
* assertions are injected. (This is essentially equivalent to align_bounds(),
1563+
* but always leaving the min untouched.)
1564+
*/
1565+
Func &align_extent(const Var &var, Expr modulus);
1566+
15571567
/** Bound the extent of a Func's realization, but not its
15581568
* min. This means the dimension can be unrolled or vectorized
15591569
* even when its min is not fixed (for example because it is

src/Generator.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2188,6 +2188,7 @@ class GeneratorOutputBase : public GIOBase {
21882188
// @{
21892189
HALIDE_FORWARD_METHOD(Func, add_trace_tag)
21902190
HALIDE_FORWARD_METHOD(Func, align_bounds)
2191+
HALIDE_FORWARD_METHOD(Func, align_extent)
21912192
HALIDE_FORWARD_METHOD(Func, align_storage)
21922193
HALIDE_FORWARD_METHOD_CONST(Func, args)
21932194
HALIDE_FORWARD_METHOD(Func, bound)

src/Schedule.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,8 @@ struct Bound {
432432

433433
/** If defined, the number of iterations will be a multiple of
434434
* "modulus", and the first iteration will be at a value congruent
435-
* to "remainder" modulo "modulus". Set by Func::align_bounds. */
435+
* to "remainder" modulo "modulus". Set by Func::align_bounds and
436+
* Func::align_extent (which leaves "remainder" undefined, so only the extent is aligned). */
436437
Expr modulus, remainder;
437438
};
438439

@@ -557,7 +558,7 @@ class FuncSchedule {
557558

558559
/** You may explicitly bound some of the dimensions of a function,
559560
* or constrain them to lie on multiples of a given factor. See
560-
* \ref Func::bound and \ref Func::align_bounds */
561+
* \ref Func::bound and \ref Func::align_bounds and \ref Func::align_extent. */
561562
// @{
562563
const std::vector<Bound> &bounds() const;
563564
std::vector<Bound> &bounds();

test/correctness/align_bounds.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,57 @@ int main(int argc, char **argv) {
146146
}
147147
}
148148

149+
// Now try a case where we align the extent but not the min.
150+
{
151+
Func f, g, h;
152+
Var x;
153+
154+
f(x) = 3;
155+
156+
g(x) = select(x % 2 == 0, f(x + 1), f(x - 1) + 8);
157+
158+
Param<int> p;
159+
h(x) = g(x - p) + g(x + p);
160+
161+
f.compute_root();
162+
g.compute_root().align_extent(x, 32).trace_realizations();
163+
164+
p.set(3);
165+
h.set_custom_trace(my_trace);
166+
Buffer<int> result = h.realize({10});
167+
168+
for (int i = 0; i < 10; i++) {
169+
int correct = (i & 1) == 1 ? 6 : 22;
170+
if (result(i) != correct) {
171+
printf("result(%d) = %d instead of %d\n",
172+
i, result(i), correct);
173+
return -1;
174+
}
175+
}
176+
177+
// The min should be left untouched and the extent rounded up to a multiple of 32
178+
if (trace_min != -3 || trace_extent != 32) {
179+
printf("%d: Wrong bounds: [%d, %d]\n", __LINE__, trace_min, trace_extent);
180+
return -1;
181+
}
182+
183+
// Increasing p by one should have no effect
184+
p.set(4);
185+
h.realize(result);
186+
if (trace_min != -4 || trace_extent != 32) {
187+
printf("%d: Wrong bounds: [%d, %d]\n", __LINE__, trace_min, trace_extent);
188+
return -1;
189+
}
190+
191+
// Increasing it again should move the min down by one more, with the extent still 32.
192+
p.set(5);
193+
h.realize(result);
194+
if (trace_min != -5 || trace_extent != 32) {
195+
printf("%d: Wrong bounds: [%d, %d]\n", __LINE__, trace_min, trace_extent);
196+
return -1;
197+
}
198+
}
199+
149200
printf("Success!\n");
150201
return 0;
151202
}

0 commit comments

Comments
 (0)