Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 15 additions & 34 deletions src/VectorizeLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,6 @@ class SerializeLoops : public IRMutator {
class PredicateLoadStore : public IRMutator {
string var;
Expr vector_predicate;
bool in_hexagon;
const Target ⌖
int lanes;
bool valid;
Expand All @@ -381,20 +380,12 @@ class PredicateLoadStore : public IRMutator {
using IRMutator::visit;

bool should_predicate_store_load(int bit_size) {
if (in_hexagon) {
internal_assert(target.has_feature(Target::HVX))
<< "We are inside a hexagon loop, but the target doesn't have hexagon's features\n";
return true;
} else if (target.arch == Target::X86) {
// Should only attempt to predicate store/load if the lane size is
// no less than 4
// TODO: disabling for now due to trunk LLVM breakage.
// See: https://github.com/halide/Halide/issues/3534
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#3534 is still open -- probably time to see if it is still active or time to close it. (Also, might be worth a reality check on older LLVM versions to see if this should be disabled for them.)

// return (bit_size == 32) && (lanes >= 4);
return false;
if (target.arch == Target::X86) {
// x86 is slower when using predicated stores for
// types smaller than 32 bits.
return bit_size >= 32;
}
// For other architecture, do not predicate vector load/store
return false;
return true;
}

Expr merge_predicate(Expr pred, const Expr &new_pred) {
Expand All @@ -407,6 +398,7 @@ class PredicateLoadStore : public IRMutator {
}

Expr visit(const Load *op) override {
valid = valid && op->predicate.type().lanes() == lanes;
valid = valid && should_predicate_store_load(op->type.bits());
if (!valid) {
return op;
Expand Down Expand Up @@ -435,6 +427,7 @@ class PredicateLoadStore : public IRMutator {
}

Stmt visit(const Store *op) override {
valid = valid && op->predicate.type().lanes() == lanes;
valid = valid && should_predicate_store_load(op->value.type().bits());
if (!valid) {
return op;
Expand Down Expand Up @@ -472,8 +465,8 @@ class PredicateLoadStore : public IRMutator {
}

public:
PredicateLoadStore(string v, const Expr &vpred, bool in_hexagon, const Target &t)
: var(std::move(v)), vector_predicate(vpred), in_hexagon(in_hexagon), target(t),
PredicateLoadStore(string v, const Expr &vpred, const Target &t)
: var(std::move(v)), vector_predicate(vpred), target(t),
lanes(vpred.type().lanes()), valid(true), vectorized(false) {
internal_assert(lanes > 1);
}
Expand Down Expand Up @@ -503,8 +496,6 @@ class VectorSubs : public IRMutator {

const Target &target;

bool in_hexagon; // Are we inside the hexagon loop?

// A scope containing lets and letstmts whose values became
// vectors. Contains are original, non-vectorized expressions.
Scope<Expr> scope;
Expand Down Expand Up @@ -849,12 +840,12 @@ class VectorSubs : public IRMutator {

Stmt predicated_stmt;
if (vectorize_predicate) {
PredicateLoadStore p(vectorized_vars.front().name, cond, in_hexagon, target);
PredicateLoadStore p(vectorized_vars.front().name, cond, target);
predicated_stmt = p.mutate(then_case);
vectorize_predicate = p.is_vectorized();
}
if (vectorize_predicate && else_case.defined()) {
PredicateLoadStore p(vectorized_vars.front().name, !cond, in_hexagon, target);
PredicateLoadStore p(vectorized_vars.front().name, !cond, target);
predicated_stmt = Block::make(predicated_stmt, p.mutate(else_case));
vectorize_predicate = p.is_vectorized();
}
Expand Down Expand Up @@ -1335,8 +1326,8 @@ class VectorSubs : public IRMutator {
}

public:
VectorSubs(const VectorizedVar &vv, bool in_hexagon, const Target &t)
: target(t), in_hexagon(in_hexagon) {
VectorSubs(const VectorizedVar &vv, const Target &t)
: target(t) {
vectorized_vars.push_back(vv);
update_replacements();
}
Expand Down Expand Up @@ -1516,16 +1507,10 @@ class LiftVectorizableExprsOutOfAllAtomicNodes : public IRMutator {
// Vectorize all loops marked as such in a Stmt
class VectorizeLoops : public IRMutator {
const Target &target;
bool in_hexagon;

using IRMutator::visit;

Stmt visit(const For *for_loop) override {
bool old_in_hexagon = in_hexagon;
if (for_loop->device_api == DeviceAPI::Hexagon) {
in_hexagon = true;
}

Stmt stmt;
if (for_loop->for_type == ForType::Vectorized) {
const IntImm *extent = for_loop->extent.as<IntImm>();
Expand All @@ -1537,21 +1522,17 @@ class VectorizeLoops : public IRMutator {
}

VectorizedVar vectorized_var = {for_loop->name, for_loop->min, (int)extent->value};
stmt = VectorSubs(vectorized_var, in_hexagon, target).mutate(for_loop->body);
stmt = VectorSubs(vectorized_var, target).mutate(for_loop->body);
} else {
stmt = IRMutator::visit(for_loop);
}

if (for_loop->device_api == DeviceAPI::Hexagon) {
in_hexagon = old_in_hexagon;
}

return stmt;
}

public:
VectorizeLoops(const Target &t)
: target(t), in_hexagon(false) {
: target(t) {
}
};

Expand Down
18 changes: 6 additions & 12 deletions test/correctness/predicated_store_load.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,6 @@ class CheckPredicatedStoreLoad : public IRMutator {
public:
CheckPredicatedStoreLoad(const Target &target, int store, int load)
: expected_store_count(store), expected_load_count(load) {
// TODO: disabling for now due to trunk LLVM breakage.
// See: https://github.com/halide/Halide/issues/3534
if (target.arch == Target::X86) {
expected_store_count = 0;
expected_load_count = 0;
}
}
using IRMutator::mutate;

Expand Down Expand Up @@ -95,7 +89,7 @@ int vectorized_predicated_store_scalarized_predicated_load_test(const Target &t)
f.update(0).hexagon().vectorize(r.x, 32);
} else if (t.arch == Target::X86) {
f.update(0).vectorize(r.x, 32);
f.add_custom_lowering_pass(new CheckPredicatedStoreLoad(t, 3, 9));
f.add_custom_lowering_pass(new CheckPredicatedStoreLoad(t, 2, 6));
}

Buffer<int> im = f.realize({170, 170});
Expand Down Expand Up @@ -160,7 +154,7 @@ int multiple_vectorized_predicate_test(const Target &t) {
f.update(0).hexagon().vectorize(r.x, 32);
} else if (t.arch == Target::X86) {
f.update(0).vectorize(r.x, 32);
f.add_custom_lowering_pass(new CheckPredicatedStoreLoad(t, 3, 6));
f.add_custom_lowering_pass(new CheckPredicatedStoreLoad(t, 1, 2));
}

Buffer<int> im = f.realize({size, size});
Expand Down Expand Up @@ -192,7 +186,7 @@ int scalar_load_test(const Target &t) {
f.update(0).hexagon().vectorize(r.x, 32);
} else if (t.arch == Target::X86) {
f.update(0).vectorize(r.x, 32);
f.add_custom_lowering_pass(new CheckPredicatedStoreLoad(t, 1, 2));
f.add_custom_lowering_pass(new CheckPredicatedStoreLoad(t, 0, 0));
}

Buffer<int> im = f.realize({160, 160});
Expand Down Expand Up @@ -226,7 +220,7 @@ int scalar_store_test(const Target &t) {
f.update(0).hexagon().vectorize(r.x, 32);
} else if (t.arch == Target::X86) {
f.update(0).vectorize(r.x, 32);
f.add_custom_lowering_pass(new CheckPredicatedStoreLoad(t, 1, 1));
f.add_custom_lowering_pass(new CheckPredicatedStoreLoad(t, 0, 0));
}

Buffer<int> im = f.realize({160, 160});
Expand Down Expand Up @@ -325,7 +319,7 @@ int vectorized_predicated_predicate_with_pure_call_test(const Target &t) {
f.update(0).hexagon().vectorize(r.x, 32);
} else if (t.arch == Target::X86) {
f.update(0).vectorize(r.x, 32);
f.add_custom_lowering_pass(new CheckPredicatedStoreLoad(t, 3, 6));
f.add_custom_lowering_pass(new CheckPredicatedStoreLoad(t, 2, 4));
}

Buffer<int> im = f.realize({160, 160});
Expand Down Expand Up @@ -402,7 +396,7 @@ int vectorized_predicated_load_lut_test(const Target &t) {
// Ignore the race condition so we can have predicated vectorized
// LUT loads on both LHS and RHS of the predicated vectorized store
dst.update().allow_race_conditions().vectorize(r, vector_size);
dst.add_custom_lowering_pass(new CheckPredicatedStoreLoad(t, 1, 3));
dst.add_custom_lowering_pass(new CheckPredicatedStoreLoad(t, 1, 2));

dst.realize({dst_len});

Expand Down
92 changes: 47 additions & 45 deletions test/performance/vectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ bool test() {
const Target target = get_jit_target_from_environment();
const int vec_width = target.natural_vector_size<A>();

int W = vec_width * 1;
int W = vec_width * 2 - 1;
int H = 10000;

Buffer<A> input(W, H + 20);
Expand All @@ -38,55 +38,57 @@ bool test() {
}
}

Var x, y;
Func f, g;
for (TailStrategy tail_strategy : {TailStrategy::ShiftInwards, TailStrategy::GuardWithIf}) {
Var x, y;
Func f, g;

Expr e = input(x, y);
for (int i = 1; i < 5; i++) {
e = e + input(x, y + i);
}

for (int i = 5; i >= 0; i--) {
e = e + input(x, y + i);
}

f(x, y) = e;
g(x, y) = e;
f.bound(x, 0, vec_width).vectorize(x);

// Stop llvm from auto-vectorizing the scalar case and messing up
// the comparison. Also causes cache effects, but the entire input
// is small enough to fit in cache.
g.reorder(y, x);

Buffer<A> outputg = g.realize({W, H});
Buffer<A> outputf = f.realize({W, H});
Expr e = input(x, y);
for (int i = 1; i < 5; i++) {
e = e + input(x, y + i);
}

double t_g = benchmark([&]() {
g.realize(outputg);
});
double t_f = benchmark([&]() {
f.realize(outputf);
});
for (int i = 5; i >= 0; i--) {
e = e + input(x, y + i);
}

for (int y = 0; y < H; y++) {
for (int x = 0; x < W; x++) {
if (outputf(x, y) != outputg(x, y)) {
printf("%s x %d failed at %d %d: %d vs %d\n",
string_of_type<A>(), vec_width,
x, y,
(int)outputf(x, y),
(int)outputg(x, y));
return false;
f(x, y) = e;
g(x, y) = e;
f.vectorize(x, vec_width, tail_strategy);

// Stop llvm from auto-vectorizing the scalar case and messing up
// the comparison. Also causes cache effects, but the entire input
// is small enough to fit in cache.
g.reorder(y, x);

Buffer<A> outputg = g.realize({W, H});
Buffer<A> outputf = f.realize({W, H});

double t_g = benchmark([&]() {
g.realize(outputg);
});
double t_f = benchmark([&]() {
f.realize(outputf);
});

for (int y = 0; y < H; y++) {
for (int x = 0; x < W; x++) {
if (outputf(x, y) != outputg(x, y)) {
printf("%s x %d failed at %d %d: %d vs %d\n",
string_of_type<A>(), vec_width,
x, y,
(int)outputf(x, y),
(int)outputg(x, y));
return false;
}
}
}
}

printf("Vectorized vs scalar (%s x %d): %1.3gms %1.3gms. Speedup = %1.3f\n",
string_of_type<A>(), vec_width, t_f * 1e3, t_g * 1e3, t_g / t_f);
printf("Vectorized vs scalar (%s x %d): %1.3gms %1.3gms. Speedup = %1.3f\n",
string_of_type<A>(), vec_width, t_f * 1e3, t_g * 1e3, t_g / t_f);

if (t_f > t_g) {
return false;
if (t_f > t_g) {
return false;
}
}

return true;
Expand All @@ -102,14 +104,14 @@ int main(int argc, char **argv) {
bool ok = true;

// Only native vector widths for now
ok = ok && test<float>();
ok = ok && test<double>();
ok = ok && test<uint8_t>();
ok = ok && test<int8_t>();
ok = ok && test<uint16_t>();
ok = ok && test<int16_t>();
ok = ok && test<uint32_t>();
ok = ok && test<int32_t>();
ok = ok && test<float>();
ok = ok && test<double>();

if (!ok) {
return -1;
Expand Down