Skip to content

Commit 7bb0b6c

Browse files
committed
Changing autograd::backward function to Variable::backward method
1 parent c94ee3d commit 7bb0b6c

File tree

4 files changed

+12
-31
lines changed

4 files changed

+12
-31
lines changed

examples/autograd.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,14 @@
1010
#include <af/autograd.h>
1111

1212
using af::autograd::Variable;
13-
using af::autograd::backward;
1413
void test1()
1514
{
1615
auto x = Variable(af::randu(5), true);
1716
af_print(x.array());
1817
auto y = x * x;
1918
af_print(y.array());
2019
auto dy = Variable(af::constant(1.0, 5), false);
21-
backward(y, dy);
20+
y.backward(dy);
2221
auto dx = x.grad();
2322
af_print(dx.array() - 2 * x.array());
2423
}
@@ -31,7 +30,7 @@ void test2()
3130
af_print(y.array());
3231
auto z = x * x + x * y + y * y;
3332
auto dz = Variable(af::constant(1.0, 5), false);
34-
backward(z, dz);
33+
z.backward(dz);
3534
auto dx = x.grad();
3635
auto dy = y.grad();
3736
af_print(dx.array() - 2 * x.array() - y.array());
@@ -46,7 +45,7 @@ void test3()
4645
af_print(y.array());
4746
auto z = x * x + x * y + y * y;
4847
auto dz = Variable(af::constant(1.0, 5), false);
49-
backward(z, dz);
48+
z.backward(dz);
5049
auto dy = y.grad();
5150
af_print(dy.array() - 2 * y.array() - x.array());
5251
try {

include/af/autograd.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,3 @@
88
********************************************************/
99
#include <af/autograd/Variable.hpp>
1010
#include <af/autograd/Functions.hpp>
11-
#include <af/autograd/Grad.hpp>

include/af/autograd/Grad.hpp

Lines changed: 0 additions & 26 deletions
This file was deleted.

include/af/autograd/Variable.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,15 @@ namespace af {
159159
}
160160
}
161161

162+
void backward(Variable grad)
163+
{
164+
this->addGrad(grad);
165+
DAG_t dag = this->build();
166+
for (auto iter = dag.rbegin(); iter != dag.rend(); iter++) {
167+
iter->calcGradInputs();
168+
}
169+
}
170+
162171
DAG_t build()
163172
{
164173
Cache_t cache;

0 commit comments

Comments
 (0)