@@ -28,80 +28,44 @@ namespace af {
28
28
class Variable
29
29
{
30
30
public:
31
- typedef std::function<void (std::vector<Variable>, Variable)> BackwardFunc_t ;
31
+ typedef std::function<void (std::vector<Variable>, Variable)> GradFunc_t ;
32
32
typedef std::unordered_map<std::ptrdiff_t , bool > Cache_t;
33
33
typedef std::vector<Variable> DAG_t;
34
34
35
35
private:
36
- class Shared {
37
- public:
36
+ struct Shared {
38
37
Shared () :
38
+ m_calc_grad (true ),
39
39
m_data (),
40
40
m_inputs (),
41
41
m_grads (),
42
- m_backward (nullptr )
42
+ m_grad_func (nullptr )
43
43
{}
44
44
45
- Shared (af::array data) :
45
+ Shared (af::array data, bool calc_grad) :
46
+ m_calc_grad (calc_grad),
46
47
m_data (data),
47
48
m_inputs (),
48
49
m_grads (),
49
- m_backward (nullptr )
50
+ m_grad_func (nullptr )
50
51
{}
51
52
52
- Shared (af::array data, std::vector<Variable> inputs, BackwardFunc_t backward) :
53
+ Shared (af::array data,
54
+ std::vector<Variable> inputs,
55
+ GradFunc_t grad_func,
56
+ bool calc_grad) :
57
+ m_calc_grad (calc_grad),
53
58
m_data (data),
54
59
m_inputs (inputs.begin(), inputs.end()),
55
60
m_grads (),
56
- m_backward (backward )
61
+ m_grad_func (grad_func )
57
62
{}
58
63
59
- af::array getData () const
60
- {
61
- return m_data;
62
- }
63
-
64
- Variable getGrad () const
65
- {
66
- if (m_grads.size () == 0 ) {
67
- throw std::runtime_error (" Gradient hasn't been calculated" );
68
- }
69
- return m_grads[0 ];
70
- }
71
-
72
- void addGrad (Variable grad)
73
- {
74
- m_grads.push_back (grad);
75
- }
76
-
77
- std::vector<Variable> getInputs ()
78
- {
79
- return m_inputs;
80
- }
81
-
82
- void evalGrad ()
83
- {
84
- if (m_grads.size () == 1 ) return ;
85
- Variable grad = m_grads[0 ];
86
- for (int i = 1 ; i < (int )m_grads.size (); i++) {
87
- grad = grad + m_grads[i];
88
- }
89
- grad.getData ().eval ();
90
- m_grads.clear ();
91
- m_grads.push_back (grad);
92
- }
93
-
94
- void backward ()
95
- {
96
- this ->evalGrad ();
97
- if (m_backward) m_backward (m_inputs, m_grads[0 ]);
98
- }
99
-
100
- private:
64
+ bool m_calc_grad;
101
65
af::array m_data;
102
66
std::vector<Variable> m_inputs;
103
67
std::vector<Variable> m_grads;
104
- BackwardFunc_t m_backward ;
68
+ GradFunc_t m_grad_func ;
105
69
};
106
70
107
71
public:
@@ -111,62 +75,100 @@ namespace af {
111
75
{
112
76
}
113
77
114
- Variable (af::array data) :
115
- m_shared (new Shared(data))
78
+ Variable (af::array data, bool calc_grad ) :
79
+ m_shared (new Shared(data, calc_grad ))
116
80
{}
117
81
118
82
/// Builds a variable that results from an operation on `inputs`.
///
/// Gradient calculation is enabled when at least one input has it enabled;
/// only then are `inputs` and `grad_func` retained. Otherwise the variable
/// keeps just `data`.
/// @param data       array to wrap
/// @param inputs     variables the result was computed from
/// @param grad_func  callable later invoked with the inputs and the gradient
Variable(af::array data,
         std::vector<Variable> inputs,
         GradFunc_t grad_func) :
    m_shared(nullptr)
{
    bool calc_grad = false;
    // Iterate by const reference (no per-input copy) and stop at the first
    // input that needs a gradient; the flag is read straight from the shared
    // state so this works on const elements.
    for (const auto &input : inputs) {
        if (input.m_shared->m_calc_grad) {
            calc_grad = true;
            break;
        }
    }
    if (calc_grad) {
        m_shared = std::make_shared<Shared>(data, inputs, grad_func, true);
    } else {
        m_shared = std::make_shared<Shared>(data, false);
    }
}
97
+
98
+ af::array array () const
99
+ {
100
+ return m_shared->m_data ;
101
+ }
123
102
124
- af::array getData () const
103
+ Variable grad () const
125
104
{
126
- return m_shared->getData ();
105
+ if (!m_shared->m_calc_grad || m_shared->m_grads .size () == 0 ) {
106
+ throw std::runtime_error (" Gradient hasn't been calculated" );
107
+ }
108
+ return m_shared->m_grads [0 ];
127
109
}
128
110
129
- Variable getGrad () const
111
+ bool isCalcGrad ()
130
112
{
131
- return m_shared->getGrad ();
113
+ return m_shared->m_calc_grad ;
114
+ }
115
+
116
+ void setCalcGrad (bool calc_grad)
117
+ {
118
+ m_shared->m_calc_grad = calc_grad;
119
+ if (!calc_grad) {
120
+ m_shared->m_grad_func = nullptr ;
121
+ m_shared->m_inputs .clear ();
122
+ m_shared->m_grads .clear ();
123
+ }
132
124
}
133
125
134
126
/// Appends `child_grad` to the list of gradients stored for this variable.
/// @param child_grad  gradient contribution to record
void addGrad(Variable child_grad)
{
    std::vector<Variable> &pending = m_shared->m_grads;
    pending.push_back(child_grad);
}
138
130
139
131
std::vector<Variable> getInputs () const
140
132
{
141
- return m_shared->getInputs () ;
133
+ return m_shared->m_inputs ;
142
134
}
143
135
144
136
void evalGrad ()
145
137
{
146
- m_shared->evalGrad ();
138
+ if (m_shared->m_grads .size () == 1 ) return ;
139
+ Variable grad = m_shared->m_grads [0 ];
140
+ for (unsigned i = 1 ; i < m_shared->m_grads .size (); i++) {
141
+ grad = grad + m_shared->m_grads [i];
142
+ }
143
+ grad.array ().eval ();
144
+ m_shared->m_grads .clear ();
145
+ m_shared->m_grads .push_back (grad);
147
146
}
148
147
149
- void backward ()
148
+ void calcGradInputs ()
150
149
{
151
- m_shared->backward ();
150
+ evalGrad ();
151
+ if (m_shared->m_grad_func ) {
152
+ m_shared->m_grad_func (m_shared->m_inputs , m_shared->m_grads [0 ]);
153
+ }
152
154
}
153
155
154
156
/// Returns the list of graph nodes collected by buildSubGraph() starting from
/// this variable, using a fresh visited-cache.
DAG_t build()
{
    DAG_t nodes;
    Cache_t visited;
    buildSubGraph(visited, nodes);
    return nodes;
}
161
163
162
- void buildGraph (Cache_t &cache, DAG_t &dag)
164
+ void buildSubGraph (Cache_t &cache, DAG_t &dag)
163
165
{
164
166
std::ptrdiff_t id = (std::ptrdiff_t )m_shared.get ();
165
167
if (cache.find (id) != cache.end ()) {
166
168
return ;
167
169
}
168
- for (auto input : m_shared->getInputs () ) {
169
- input.buildGraph (cache, dag);
170
+ for (auto input : m_shared->m_inputs ) {
171
+ input.buildSubGraph (cache, dag);
170
172
}
171
173
cache[id] = true ;
172
174
dag.push_back (*this );
0 commit comments