Skip to content

Commit 8af4f12

Browse files
committed
[AutoDiff] Do not propagate same adjoint buffer multiple times
1 parent 53718a3 commit 8af4f12

File tree

2 files changed

+44
-2
lines changed

2 files changed

+44
-2
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "swift/SILOptimizer/PassManager/PrettyStackTrace.h"
3636
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
3737
#include "llvm/ADT/DenseMap.h"
38+
#include "llvm/ADT/SmallSet.h"
3839

3940
namespace swift {
4041

@@ -2381,9 +2382,11 @@ SILBasicBlock *PullbackCloner::Implementation::buildPullbackSuccessor(
23812382
assert(pullbackTrampolineBB->getNumArguments() == 1);
23822383
auto loc = origBB->getParent()->getLocation();
23832384
SmallVector<SILValue, 8> trampolineArguments;
2385+
23842386
// Propagate adjoint values/buffers of active values/buffers to
23852387
// predecessor blocks.
23862388
auto &predBBActiveValues = activeValues[origPredBB];
2389+
llvm::SmallSet<std::pair<SILValue, SILValue>, 32> propagatedAdjoints;
23872390
for (auto activeValue : predBBActiveValues) {
23882391
LLVM_DEBUG(getADDebugStream()
23892392
<< "Propagating adjoint of active value " << activeValue
@@ -2425,12 +2428,14 @@ SILBasicBlock *PullbackCloner::Implementation::buildPullbackSuccessor(
24252428
// Propagate adjoint buffers using `copy_addr`.
24262429
auto adjBuf = getAdjointBuffer(origBB, activeValue);
24272430
auto predAdjBuf = getAdjointBuffer(origPredBB, activeValue);
2428-
builder.createCopyAddr(loc, adjBuf, predAdjBuf, IsNotTake,
2429-
IsNotInitialization);
2431+
if (propagatedAdjoints.insert({adjBuf, predAdjBuf}).second)
2432+
builder.createCopyAddr(loc, adjBuf, predAdjBuf, IsNotTake,
2433+
IsNotInitialization);
24302434
break;
24312435
}
24322436
}
24332437
}
2438+
24342439
// Propagate pullback struct argument.
24352440
TangentBuilder pullbackTrampolineBBBuilder(
24362441
pullbackTrampolineBB, getContext());
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// RUN: %target-swift-frontend -emit-sil -verify -Xllvm -sil-print-after=differentiation -o /dev/null 2>&1 %s | %FileCheck %s -check-prefix=CHECK-SIL
2+
3+
import _Differentiation
4+
5+
struct Test: Differentiable {
6+
var val1: Float
7+
var val2: Float
8+
9+
@differentiable(reverse)
10+
mutating func doSomething(input: Float) {
11+
// CHECK-SIL-LABEL: sil private [ossa] @$s4null4TestV11doSomething5inputySf_tFTJpSSpSr : $@convention(thin)
12+
// Ensure that only two adjoint buffers will be propagated
13+
// CHECK-SIL: copy_addr %0 to %22 : $*Test.TangentVector
14+
// CHECK-SIL-NEXT: debug_value
15+
// CHECK-SIL-NEXT: copy_addr %0 to %18 : $*Test.TangentVector
16+
// CHECK-SIL-NEXT: switch_enum %1
17+
self.val1 *= input
18+
self.val2 *= input
19+
20+
if self.val1 > input {
21+
self.val1 = input
22+
}
23+
if self.val2 > input {
24+
self.val2 = input
25+
}
26+
}
27+
}
28+
29+
@differentiable(reverse)
30+
func wrapper(input: Float, multiplier: Float) -> Float {
31+
var test = Test(val1: input, val2: input)
32+
test.doSomething(input: multiplier)
33+
return test.val1 * test.val2
34+
}
35+
36+
let grad = gradient(at: 2.0, 3.0, of: wrapper)
37+
print("Grad: \(grad)")

0 commit comments

Comments
 (0)