Skip to content

Commit 9ee8c73

Browse files
authored
[AutoDiff] Do not propagate same adjoint buffer multiple times (#64963)
Adjoint buffers of projections (e.g. obtained via begin_access) are the same as the adjoint buffer of the underlying struct value. As a result, when propagating adjoint values to pullback successor blocks, we tend to produce many identical copies of adjoint buffers (essentially one for every struct access, in every basic block). These copy_addr instructions are then lowered to plain loads and stores. While the redundant copies are usually optimized away by subsequent optimization passes, their presence leads to elevated memory consumption and compilation time, because the liveness of the values being copied has to be tracked. Track the values being propagated and simply do not generate an extra copy if the same value was already propagated. One step towards #61773
1 parent 6854bbe commit 9ee8c73

File tree

2 files changed

+44
-2
lines changed

2 files changed

+44
-2
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "swift/SILOptimizer/PassManager/PrettyStackTrace.h"
3636
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
3737
#include "llvm/ADT/DenseMap.h"
38+
#include "llvm/ADT/SmallSet.h"
3839

3940
namespace swift {
4041

@@ -2381,9 +2382,11 @@ SILBasicBlock *PullbackCloner::Implementation::buildPullbackSuccessor(
23812382
assert(pullbackTrampolineBB->getNumArguments() == 1);
23822383
auto loc = origBB->getParent()->getLocation();
23832384
SmallVector<SILValue, 8> trampolineArguments;
2385+
23842386
// Propagate adjoint values/buffers of active values/buffers to
23852387
// predecessor blocks.
23862388
auto &predBBActiveValues = activeValues[origPredBB];
2389+
llvm::SmallSet<std::pair<SILValue, SILValue>, 32> propagatedAdjoints;
23872390
for (auto activeValue : predBBActiveValues) {
23882391
LLVM_DEBUG(getADDebugStream()
23892392
<< "Propagating adjoint of active value " << activeValue
@@ -2425,12 +2428,14 @@ SILBasicBlock *PullbackCloner::Implementation::buildPullbackSuccessor(
24252428
// Propagate adjoint buffers using `copy_addr`.
24262429
auto adjBuf = getAdjointBuffer(origBB, activeValue);
24272430
auto predAdjBuf = getAdjointBuffer(origPredBB, activeValue);
2428-
builder.createCopyAddr(loc, adjBuf, predAdjBuf, IsNotTake,
2429-
IsNotInitialization);
2431+
if (propagatedAdjoints.insert({adjBuf, predAdjBuf}).second)
2432+
builder.createCopyAddr(loc, adjBuf, predAdjBuf, IsNotTake,
2433+
IsNotInitialization);
24302434
break;
24312435
}
24322436
}
24332437
}
2438+
24342439
// Propagate pullback struct argument.
24352440
TangentBuilder pullbackTrampolineBBBuilder(
24362441
pullbackTrampolineBB, getContext());
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// RUN: %target-swift-frontend -emit-sil -verify -Xllvm -sil-print-after=differentiation -o /dev/null 2>&1 %s | %FileCheck %s -check-prefix=CHECK-SIL
2+
3+
import _Differentiation
4+
5+
/// Differentiable test fixture with two `Float` stored properties.
///
/// The FileCheck directives inside `doSomething(input:)` are matched against
/// the SIL printed after the differentiation pass (see the RUN line of this
/// test), so the statements of the method body must stay exactly as written.
struct Test: Differentiable {
6+
var val1: Float
7+
var val2: Float
8+
9+
/// Multiplies both stored properties by `input`, then replaces each one
/// with `input` if it ended up greater than `input`.
///
/// - Parameter input: The multiplier, also used as the upper bound applied
///   to `val1` and `val2` afterwards.
@differentiable(reverse)
10+
mutating func doSomething(input: Float) {
11+
// CHECK-SIL-LABEL: TestV11doSomething5inputySf_tFTJpSSpSr :
12+
// Ensure that only two adjoint buffers will be propagated
13+
// CHECK-SIL: copy_addr %0 to %22 : $*Test.TangentVector
14+
// CHECK-SIL-NEXT: debug_value
15+
// CHECK-SIL-NEXT: copy_addr %0 to %18 : $*Test.TangentVector
16+
// CHECK-SIL-NEXT: switch_enum %1
17+
self.val1 *= input
18+
self.val2 *= input
19+
20+
// NOTE(review): the two separate conditionals below presumably exist to
// give the function several basic blocks that each access `self`, which
// is the situation the adjoint-buffer propagation check targets — confirm.
if self.val1 > input {
21+
self.val1 = input
22+
}
23+
if self.val2 > input {
24+
self.val2 = input
25+
}
26+
}
27+
}
28+
29+
/// Differentiable entry point for the test.
///
/// Builds a `Test` whose two properties both start at `input`, applies
/// `doSomething(input:)` with `multiplier`, and returns the product of the
/// resulting `val1` and `val2`.
///
/// - Parameters:
///   - input: Initial value for both `val1` and `val2`.
///   - multiplier: Value forwarded to `Test.doSomething(input:)`.
/// - Returns: `test.val1 * test.val2` after the mutation.
@differentiable(reverse)
30+
func wrapper(input: Float, multiplier: Float) -> Float {
31+
var test = Test(val1: input, val2: input)
32+
test.doSomething(input: multiplier)
33+
return test.val1 * test.val2
34+
}
35+
36+
// Take the gradient of `wrapper` at (input: 2.0, multiplier: 3.0) and print
// it, so the differentiated code is actually used by the test program.
let grad = gradient(at: 2.0, 3.0, of: wrapper)
37+
print("Grad: \(grad)")

0 commit comments

Comments
 (0)