
Commit 22210c6

Refactor autograd::Variable, add option to disable grad calculations
- autograd::Variable::Shared now a thin layer without methods
- Variable::BackwardFunc_t renamed to Variable::GradFunc_t
- Variable::getData renamed to Variable::array
- Variable::getGrad renamed to Variable::grad
- Variable::backward renamed to Variable::calcGradInputs
1 parent 3f832a0 commit 22210c6
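
For orientation, a minimal usage sketch of the API after this commit (the include paths are assumptions, not shown in the diff; the second constructor argument is the new calc_grad flag):

#include <arrayfire.h>
#include <af/autograd.h>   // assumed umbrella header for the headers changed below

using af::autograd::Variable;
using af::autograd::backward;

int main()
{
    auto x  = Variable(af::randu(5), true);            // participates in grad calculation
    auto y  = x * x;
    auto dy = Variable(af::constant(1.0, 5), false);   // seed gradient, needs no grad of its own
    backward(y, dy);
    af_print(x.grad().array());                        // was x.getGrad().getData() before this commit
    return 0;
}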

File tree

5 files changed: +118 −89 lines


examples/FFNet.cpp

Lines changed: 0 additions & 1 deletion

@@ -9,7 +9,6 @@
 
 #include <af/nn.h>
 
-using namespace af;
 using namespace af;
 using namespace af::nn;
examples/autograd.cpp

Lines changed: 34 additions & 12 deletions

@@ -13,32 +13,54 @@ using af::autograd::Variable;
 using af::autograd::backward;
 void test1()
 {
-    auto x = Variable(af::randu(5));
-    af_print(x.getData());
+    auto x = Variable(af::randu(5), true);
+    af_print(x.array());
     auto y = x * x;
-    af_print(y.getData());
-    auto dy = Variable(af::constant(1.0, 5));
+    af_print(y.array());
+    auto dy = Variable(af::constant(1.0, 5), false);
     backward(y, dy);
-    af_print(x.getGrad().getData() - 2 * x.getData());
+    auto dx = x.grad();
+    af_print(dx.array() - 2 * x.array());
 }
 
 void test2()
 {
-    auto x = Variable(af::randu(5));
-    af_print(x.getData());
-    auto y = Variable(af::randu(5));
-    af_print(y.getData());
+    auto x = Variable(af::randu(5), true);
+    af_print(x.array());
+    auto y = Variable(af::randu(5), true);
+    af_print(y.array());
     auto z = x * x + x * y + y * y;
-    auto dz = Variable(af::constant(1.0, 5));
+    auto dz = Variable(af::constant(1.0, 5), false);
     backward(z, dz);
-    af_print(x.getGrad().getData() - 2 * x.getData() - y.getData());
-    af_print(y.getGrad().getData() - 2 * y.getData() - x.getData());
+    auto dx = x.grad();
+    auto dy = y.grad();
+    af_print(dx.array() - 2 * x.array() - y.array());
+    af_print(dy.array() - 2 * y.array() - x.array());
+}
+
+void test3()
+{
+    auto x = Variable(af::randu(5), false);
+    af_print(x.array());
+    auto y = Variable(af::randu(5), true);
+    af_print(y.array());
+    auto z = x * x + x * y + y * y;
+    auto dz = Variable(af::constant(1.0, 5), false);
+    backward(z, dz);
+    auto dy = y.grad();
+    af_print(dy.array() - 2 * y.array() - x.array());
+    try {
+        auto dx = x.grad();
+    } catch(af::exception &ex) {
+        std::cout << ex.what() << std::endl;
+    }
 }
 
 int main()
 {
     af::info();
     test1();
     test2();
+    test3();
     return 0;
 }
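
A note on test3: x is constructed with calc_grad = false, so during backward the multiply's grad_func calls addGrad on x as a silent no-op, and x.grad() throws an af::exception; the try/catch reports the message instead of aborting. y.grad() remains valid because y requested gradients.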

include/af/autograd/Functions.hpp

Lines changed: 6 additions & 6 deletions

@@ -15,22 +15,22 @@ namespace af {
 
         Variable operator +(const Variable lhs, const Variable rhs)
         {
-            auto result = lhs.getData() + rhs.getData();
-            auto backward = [](std::vector<Variable> inputs, Variable grad_output) {
+            auto result = lhs.array() + rhs.array();
+            auto grad_func = [](std::vector<Variable> inputs, Variable grad_output) {
                 inputs[0].addGrad(grad_output);
                 inputs[1].addGrad(grad_output);
             };
-            return Variable(result, {lhs, rhs}, backward);
+            return Variable(result, {lhs, rhs}, grad_func);
         }
 
         Variable operator *(const Variable lhs, const Variable rhs)
         {
-            auto result = lhs.getData() * rhs.getData();
-            auto backward = [](std::vector<Variable> inputs, Variable grad_output) {
+            auto result = lhs.array() * rhs.array();
+            auto grad_func = [](std::vector<Variable> inputs, Variable grad_output) {
                 inputs[0].addGrad(grad_output * inputs[1]);
                 inputs[1].addGrad(grad_output * inputs[0]);
             };
-            return Variable(result, {lhs, rhs}, backward);
+            return Variable(result, {lhs, rhs}, grad_func);
         }
 
     }
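
The grad_func pattern above generalizes to any operator: compute the forward result on raw arrays, then record how grad_output maps to each input's gradient. As an illustration only (neither function below is part of this commit), a subtraction operator in the same style might look like:

        // Hypothetical sketch following the operator+ / operator* pattern above.
        Variable negate(const Variable input)
        {
            auto result = 0.0 - input.array();
            auto grad_func = [](std::vector<Variable> inputs, Variable grad_output) {
                inputs[0].addGrad(negate(grad_output));   // d(-x)/dx = -1
            };
            return Variable(result, {input}, grad_func);
        }

        Variable operator -(const Variable lhs, const Variable rhs)
        {
            auto result = lhs.array() - rhs.array();
            auto grad_func = [](std::vector<Variable> inputs, Variable grad_output) {
                inputs[0].addGrad(grad_output);           // d(l - r)/dl = 1
                inputs[1].addGrad(negate(grad_output));   // d(l - r)/dr = -1
            };
            return Variable(result, {lhs, rhs}, grad_func);
        }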

include/af/autograd/Grad.hpp

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@ namespace af {
             var.addGrad(grad);
             Variable::DAG_t dag = var.build();
             for (auto iter = dag.rbegin(); iter != dag.rend(); iter++) {
-                iter->backward();
+                iter->calcGradInputs();
             }
         }
     }
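
For context, the loop above works because build() returns the DAG in topological order with inputs before consumers, so reverse iteration calls calcGradInputs() on each node only after every gradient contribution has reached it. A hedged reconstruction of the surrounding function (the signature is inferred from the backward(y, dy) call sites in examples/autograd.cpp; only the loop body's rename appears in the hunk):

namespace af {
    namespace autograd {
        void backward(Variable var, Variable grad)
        {
            var.addGrad(grad);                  // seed the output's gradient
            Variable::DAG_t dag = var.build();  // topological order, inputs first
            for (auto iter = dag.rbegin(); iter != dag.rend(); iter++) {
                iter->calcGradInputs();         // sum grads, then push to inputs
            }
        }
    }
}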

include/af/autograd/Variable.hpp

Lines changed: 77 additions & 69 deletions

@@ -28,80 +28,44 @@ namespace af {
         class Variable
         {
         public:
-            typedef std::function<void(std::vector<Variable>, Variable)> BackwardFunc_t;
+            typedef std::function<void(std::vector<Variable>, Variable)> GradFunc_t;
             typedef std::unordered_map<std::ptrdiff_t, bool> Cache_t;
             typedef std::vector<Variable> DAG_t;
 
         private:
-            class Shared {
-            public:
+            struct Shared {
                 Shared() :
+                    m_calc_grad(true),
                     m_data(),
                     m_inputs(),
                     m_grads(),
-                    m_backward(nullptr)
+                    m_grad_func(nullptr)
                 {}
 
-                Shared(af::array data) :
+                Shared(af::array data, bool calc_grad) :
+                    m_calc_grad(calc_grad),
                     m_data(data),
                     m_inputs(),
                     m_grads(),
-                    m_backward(nullptr)
+                    m_grad_func(nullptr)
                 {}
 
-                Shared(af::array data, std::vector<Variable> inputs, BackwardFunc_t backward) :
+                Shared(af::array data,
+                       std::vector<Variable> inputs,
+                       GradFunc_t grad_func,
+                       bool calc_grad) :
+                    m_calc_grad(calc_grad),
                     m_data(data),
                     m_inputs(inputs.begin(), inputs.end()),
                     m_grads(),
-                    m_backward(backward)
+                    m_grad_func(grad_func)
                 {}
 
-                af::array getData() const
-                {
-                    return m_data;
-                }
-
-                Variable getGrad() const
-                {
-                    if (m_grads.size() == 0) {
-                        throw std::runtime_error("Gradient hasn't been calculated");
-                    }
-                    return m_grads[0];
-                }
-
-                void addGrad(Variable grad)
-                {
-                    m_grads.push_back(grad);
-                }
-
-                std::vector<Variable> getInputs()
-                {
-                    return m_inputs;
-                }
-
-                void evalGrad()
-                {
-                    if (m_grads.size() == 1) return;
-                    Variable grad = m_grads[0];
-                    for (int i = 1; i < (int)m_grads.size(); i++) {
-                        grad = grad + m_grads[i];
-                    }
-                    grad.getData().eval();
-                    m_grads.clear();
-                    m_grads.push_back(grad);
-                }
-
-                void backward()
-                {
-                    this->evalGrad();
-                    if (m_backward) m_backward(m_inputs, m_grads[0]);
-                }
-
-            private:
+                bool m_calc_grad;
                 af::array m_data;
                 std::vector<Variable> m_inputs;
                 std::vector<Variable> m_grads;
-                BackwardFunc_t m_backward;
+                GradFunc_t m_grad_func;
             };
 
         public:
@@ -111,62 +75,106 @@ namespace af {
             {
             }
 
-            Variable(af::array data) :
-                m_shared(new Shared(data))
+            Variable(af::array data, bool calc_grad) :
+                m_shared(new Shared(data, calc_grad))
             {}
 
             Variable(af::array data,
                      std::vector<Variable> inputs,
-                     BackwardFunc_t backward) :
-                m_shared(new Shared(data, inputs, backward))
-            {}
+                     GradFunc_t grad_func) :
+                m_shared(nullptr)
+            {
+                bool calc_grad = false;
+                for (auto input : inputs) {
+                    calc_grad |= input.isCalcGrad();
+                }
+                if (calc_grad) {
+                    m_shared = std::shared_ptr<Shared>(new Shared(data, inputs, grad_func, true));
+                } else {
+                    m_shared = std::shared_ptr<Shared>(new Shared(data, false));
+                }
+            }
+
+            af::array array() const
+            {
+                return m_shared->m_data;
+            }
 
-            af::array getData() const
+            Variable grad() const
             {
-                return m_shared->getData();
+                if (!m_shared->m_calc_grad) {
+                    throw af::exception("Gradient calculation disabled.");
+                }
+                if (m_shared->m_grads.size() == 0) {
+                    throw af::exception("Gradient hasn't been calculated yet.");
+                }
+                return m_shared->m_grads[0];
             }
 
-            Variable getGrad() const
+            bool isCalcGrad()
             {
-                return m_shared->getGrad();
+                return m_shared->m_calc_grad;
+            }
+
+            void setCalcGrad(bool calc_grad)
+            {
+                m_shared->m_calc_grad = calc_grad;
+                if (!calc_grad) {
+                    m_shared->m_grad_func = nullptr;
+                    m_shared->m_inputs.clear();
+                    m_shared->m_grads.clear();
+                }
             }
 
             void addGrad(Variable child_grad)
             {
-                m_shared->addGrad(child_grad);
+                if (m_shared->m_calc_grad) {
+                    m_shared->m_grads.push_back(child_grad);
+                }
             }
 
             std::vector<Variable> getInputs() const
             {
-                return m_shared->getInputs();
+                return m_shared->m_inputs;
             }
 
             void evalGrad()
             {
-                m_shared->evalGrad();
+                // Flag asking not to calculate gradients
+                if (!m_shared->m_calc_grad) return;
+                Variable grad = m_shared->m_grads[0];
+                for (unsigned i = 1; i < m_shared->m_grads.size(); i++) {
+                    grad = grad + m_shared->m_grads[i];
+                }
+                grad.array().eval();
+                m_shared->m_grads.clear();
+                m_shared->m_grads.push_back(grad);
             }
 
-            void backward()
+            void calcGradInputs()
             {
-                m_shared->backward();
+                evalGrad();
+                if (m_shared->m_grad_func) {
+                    m_shared->m_grad_func(m_shared->m_inputs, m_shared->m_grads[0]);
+                }
             }
 
             DAG_t build()
             {
                 Cache_t cache;
                 DAG_t dag;
-                this->buildGraph(cache, dag);
+                this->buildSubGraph(cache, dag);
                 return dag;
             }
 
-            void buildGraph(Cache_t &cache, DAG_t &dag)
+            void buildSubGraph(Cache_t &cache, DAG_t &dag)
             {
                 std::ptrdiff_t id = (std::ptrdiff_t)m_shared.get();
                 if (cache.find(id) != cache.end()) {
                    return;
                 }
-                for (auto input : m_shared->getInputs()) {
-                    input.buildGraph(cache, dag);
+                for (auto input : m_shared->m_inputs) {
+                    input.buildSubGraph(cache, dag);
                 }
                 cache[id] = true;
                 dag.push_back(*this);