File tree Expand file tree Collapse file tree 4 files changed +12
-31
lines changed Expand file tree Collapse file tree 4 files changed +12
-31
lines changed Original file line number Diff line number Diff line change 10
10
#include < af/autograd.h>
11
11
12
12
using af::autograd::Variable;
13
- using af::autograd::backward;
14
13
void test1 ()
15
14
{
16
15
auto x = Variable (af::randu (5 ), true );
17
16
af_print (x.array ());
18
17
auto y = x * x;
19
18
af_print (y.array ());
20
19
auto dy = Variable (af::constant (1.0 , 5 ), false );
21
- backward (y, dy);
20
+ y. backward (dy);
22
21
auto dx = x.grad ();
23
22
af_print (dx.array () - 2 * x.array ());
24
23
}
@@ -31,7 +30,7 @@ void test2()
31
30
af_print (y.array ());
32
31
auto z = x * x + x * y + y * y;
33
32
auto dz = Variable (af::constant (1.0 , 5 ), false );
34
- backward (z, dz);
33
+ z. backward (dz);
35
34
auto dx = x.grad ();
36
35
auto dy = y.grad ();
37
36
af_print (dx.array () - 2 * x.array () - y.array ());
@@ -46,7 +45,7 @@ void test3()
46
45
af_print (y.array ());
47
46
auto z = x * x + x * y + y * y;
48
47
auto dz = Variable (af::constant (1.0 , 5 ), false );
49
- backward (z, dz);
48
+ z. backward (dz);
50
49
auto dy = y.grad ();
51
50
af_print (dy.array () - 2 * y.array () - x.array ());
52
51
try {
Original file line number Diff line number Diff line change 8
8
********************************************************/
9
9
#include <af/autograd/Variable.hpp>
10
10
#include <af/autograd/Functions.hpp>
11
- #include <af/autograd/Grad.hpp>
Load Diff This file was deleted.
Original file line number Diff line number Diff line change @@ -159,6 +159,15 @@ namespace af {
159
159
}
160
160
}
161
161
162
+ void backward (Variable grad)
163
+ {
164
+ this ->addGrad (grad);
165
+ DAG_t dag = this ->build ();
166
+ for (auto iter = dag.rbegin (); iter != dag.rend (); iter++) {
167
+ iter->calcGradInputs ();
168
+ }
169
+ }
170
+
162
171
DAG_t build ()
163
172
{
164
173
Cache_t cache;
You can’t perform that action at this time.
0 commit comments