@@ -28,80 +28,44 @@ namespace af {
     class Variable
     {
     public:
-        typedef std::function<void(std::vector<Variable>, Variable)> BackwardFunc_t;
+        typedef std::function<void(std::vector<Variable>, Variable)> GradFunc_t;
         typedef std::unordered_map<std::ptrdiff_t, bool> Cache_t;
         typedef std::vector<Variable> DAG_t;
 
     private:
-        class Shared {
-        public:
+        struct Shared {
             Shared() :
+                m_calc_grad(true),
                 m_data(),
                 m_inputs(),
                 m_grads(),
-                m_backward(nullptr)
+                m_grad_func(nullptr)
             {}
 
-            Shared(af::array data) :
+            Shared(af::array data, bool calc_grad) :
+                m_calc_grad(calc_grad),
                 m_data(data),
                 m_inputs(),
                 m_grads(),
-                m_backward(nullptr)
+                m_grad_func(nullptr)
             {}
 
-            Shared(af::array data, std::vector<Variable> inputs, BackwardFunc_t backward) :
+            Shared(af::array data,
+                   std::vector<Variable> inputs,
+                   GradFunc_t grad_func,
+                   bool calc_grad) :
+                m_calc_grad(calc_grad),
                 m_data(data),
                 m_inputs(inputs.begin(), inputs.end()),
                 m_grads(),
-                m_backward(backward)
+                m_grad_func(grad_func)
             {}
 
-            af::array getData() const
-            {
-                return m_data;
-            }
-
-            Variable getGrad() const
-            {
-                if (m_grads.size() == 0) {
-                    throw std::runtime_error("Gradient hasn't been calculated");
-                }
-                return m_grads[0];
-            }
-
-            void addGrad(Variable grad)
-            {
-                m_grads.push_back(grad);
-            }
-
-            std::vector<Variable> getInputs()
-            {
-                return m_inputs;
-            }
-
-            void evalGrad()
-            {
-                if (m_grads.size() == 1) return;
-                Variable grad = m_grads[0];
-                for (int i = 1; i < (int)m_grads.size(); i++) {
-                    grad = grad + m_grads[i];
-                }
-                grad.getData().eval();
-                m_grads.clear();
-                m_grads.push_back(grad);
-            }
-
-            void backward()
-            {
-                this->evalGrad();
-                if (m_backward) m_backward(m_inputs, m_grads[0]);
-            }
-
-        private:
+            bool m_calc_grad;
             af::array m_data;
             std::vector<Variable> m_inputs;
             std::vector<Variable> m_grads;
-            BackwardFunc_t m_backward;
+            GradFunc_t m_grad_func;
         };
 
     public:
@@ -111,62 +75,106 @@ namespace af {
         {
         }
 
-        Variable(af::array data) :
-            m_shared(new Shared(data))
+        Variable(af::array data, bool calc_grad) :
+            m_shared(new Shared(data, calc_grad))
         {}
 
         Variable(af::array data,
                  std::vector<Variable> inputs,
-                 BackwardFunc_t backward) :
-            m_shared(new Shared(data, inputs, backward))
-        {}
+                 GradFunc_t grad_func) :
+            m_shared(nullptr)
+        {
+            bool calc_grad = false;
+            for (auto input : inputs) {
+                calc_grad |= input.isCalcGrad();
+            }
+            if (calc_grad) {
+                m_shared = std::shared_ptr<Shared>(new Shared(data, inputs, grad_func, true));
+            } else {
+                m_shared = std::shared_ptr<Shared>(new Shared(data, false));
+            }
+        }
+
+        af::array array() const
+        {
+            return m_shared->m_data;
+        }
 
-        af::array getData() const
+        Variable grad() const
         {
-            return m_shared->getData();
+            if (!m_shared->m_calc_grad) {
+                throw af::exception("Gradient calculation disabled.");
+            }
+            if (m_shared->m_grads.size() == 0) {
+                throw af::exception("Gradient hasn't been calculated yet.");
+            }
+            return m_shared->m_grads[0];
         }
 
-        Variable getGrad() const
+        bool isCalcGrad()
         {
-            return m_shared->getGrad();
+            return m_shared->m_calc_grad;
+        }
+
+        void setCalcGrad(bool calc_grad)
+        {
+            m_shared->m_calc_grad = calc_grad;
+            if (!calc_grad) {
+                m_shared->m_grad_func = nullptr;
+                m_shared->m_inputs.clear();
+                m_shared->m_grads.clear();
+            }
         }
 
         void addGrad(Variable child_grad)
         {
-            m_shared->addGrad(child_grad);
+            if (m_shared->m_calc_grad) {
+                m_shared->m_grads.push_back(child_grad);
+            }
         }
 
         std::vector<Variable> getInputs() const
         {
-            return m_shared->getInputs();
+            return m_shared->m_inputs;
         }
 
         void evalGrad()
         {
-            m_shared->evalGrad();
+            // Flag asking not to calculate gradients
+            if (!m_shared->m_calc_grad) return;
+            Variable grad = m_shared->m_grads[0];
+            for (unsigned i = 1; i < m_shared->m_grads.size(); i++) {
+                grad = grad + m_shared->m_grads[i];
+            }
+            grad.array().eval();
+            m_shared->m_grads.clear();
+            m_shared->m_grads.push_back(grad);
         }
 
-        void backward()
+        void calcGradInputs()
         {
-            m_shared->backward();
+            evalGrad();
+            if (m_shared->m_grad_func) {
+                m_shared->m_grad_func(m_shared->m_inputs, m_shared->m_grads[0]);
+            }
         }
 
         DAG_t build()
         {
             Cache_t cache;
             DAG_t dag;
-            this->buildGraph(cache, dag);
+            this->buildSubGraph(cache, dag);
             return dag;
         }
 
-        void buildGraph(Cache_t &cache, DAG_t &dag)
+        void buildSubGraph(Cache_t &cache, DAG_t &dag)
         {
             std::ptrdiff_t id = (std::ptrdiff_t)m_shared.get();
             if (cache.find(id) != cache.end()) {
                 return;
             }
-            for (auto input : m_shared->getInputs()) {
-                input.buildGraph(cache, dag);
+            for (auto input : m_shared->m_inputs) {
+                input.buildSubGraph(cache, dag);
             }
             cache[id] = true;
             dag.push_back(*this);
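For context, here is a minimal sketch of how an operation could plug into this API. The free function multiply below is hypothetical and not part of this commit; it assumes only the Variable members shown above plus af::array's elementwise operator*. Because copies of a Variable share state through m_shared, calling addGrad() on a copied input accumulates the gradient on the original node.

    // Hypothetical op, sketched against the API in this diff.
    af::Variable multiply(af::Variable lhs, af::Variable rhs)
    {
        auto grad_func = [](std::vector<af::Variable> inputs, af::Variable grad) {
            // d(l*r)/dl = r and d(l*r)/dr = l, scaled by the incoming gradient.
            inputs[0].addGrad(af::Variable(grad.array() * inputs[1].array(), false));
            inputs[1].addGrad(af::Variable(grad.array() * inputs[0].array(), false));
        };
        // The three-argument constructor keeps inputs and grad_func only
        // when some input has isCalcGrad() == true.
        return af::Variable(lhs.array() * rhs.array(), {lhs, rhs}, grad_func);
    }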
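A full backward pass could then be driven by seeding the output's gradient and walking the DAG from build() in reverse. This driver loop is an assumption on top of the diff: buildSubGraph() appends a node only after all of its inputs, so visiting the vector back to front calls calcGradInputs() on each node after its consumers have deposited gradients.

    #include <arrayfire.h>
    // ... plus the Variable header this diff modifies (path assumed).

    int main()
    {
        af::Variable x(af::randu(5), true);   // leaf, gradient requested
        af::Variable w(af::randu(5), true);   // leaf, gradient requested
        af::Variable y = multiply(x, w);      // hypothetical op from the sketch above

        y.addGrad(af::Variable(af::constant(1, 5), false)); // seed dy/dy = 1

        af::Variable::DAG_t dag = y.build();  // inputs first, outputs last
        for (auto it = dag.rbegin(); it != dag.rend(); ++it) {
            it->calcGradInputs();             // propagate to this node's inputs
        }

        af_print(x.grad().array());           // elementwise equal to w's data
        return 0;
    }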