File tree Expand file tree Collapse file tree 5 files changed +10
-10
lines changed Expand file tree Collapse file tree 5 files changed +10
-10
lines changed Original file line number Diff line number Diff line change @@ -61,7 +61,7 @@ int main()
61
61
62
62
// Update parameters
63
63
// TODO: Should use optimizer
64
- for (auto param : perceptron.parameters ()) {
64
+ for (auto & param : perceptron.parameters ()) {
65
65
param.array () += lr * param.grad ().array ();
66
66
param.array ().eval ();
67
67
}
Original file line number Diff line number Diff line change @@ -52,7 +52,7 @@ namespace af {
52
52
53
53
af::array& array () const ;
54
54
55
- Variable grad () const ;
55
+ Variable& grad () const ;
56
56
57
57
std::ptrdiff_t id () const ;
58
58
@@ -74,7 +74,7 @@ namespace af {
74
74
private:
75
75
void evalGrad (bool retain_grad_graph = false );
76
76
77
- std::vector<Variable> getInputs () const ;
77
+ std::vector<Variable>& getInputs () const ;
78
78
79
79
static void buildSubGraph (Cache_t &cache, DAG_t &dag, const Variable &var);
80
80
Original file line number Diff line number Diff line change @@ -55,7 +55,7 @@ namespace af {
55
55
m_shared (nullptr )
56
56
{
57
57
bool calc_grad = false ;
58
- for (auto input : inputs) {
58
+ for (const auto & input : inputs) {
59
59
calc_grad |= input.isCalcGrad ();
60
60
}
61
61
if (calc_grad) {
@@ -70,7 +70,7 @@ namespace af {
70
70
return m_shared->m_data ;
71
71
}
72
72
73
- Variable Variable::grad () const
73
+ Variable& Variable::grad () const
74
74
{
75
75
if (!m_shared->m_calc_grad ) {
76
76
throw af::exception (" Gradient calclation disabled." );
@@ -86,7 +86,7 @@ namespace af {
86
86
return (std::ptrdiff_t )m_shared.get ();
87
87
}
88
88
89
- std::vector<Variable> Variable::getInputs () const
89
+ std::vector<Variable>& Variable::getInputs () const
90
90
{
91
91
return m_shared->m_inputs ;
92
92
}
@@ -181,7 +181,7 @@ namespace af {
181
181
if (cache.find (id) != cache.end ()) {
182
182
return ;
183
183
}
184
- for (auto input : var.getInputs ()) {
184
+ for (const auto & input : var.getInputs ()) {
185
185
Variable::buildSubGraph (cache, dag, input);
186
186
}
187
187
cache[id] = true ;
Original file line number Diff line number Diff line change @@ -33,7 +33,7 @@ namespace af
33
33
Variable Sequential::forward (const Variable &input)
34
34
{
35
35
Variable output = input;
36
- for (auto module : m_modules) {
36
+ for (auto & module : m_modules) {
37
37
output = module ->forward (output);
38
38
}
39
39
return output;
Original file line number Diff line number Diff line change @@ -34,14 +34,14 @@ namespace af
34
34
35
35
void Module::train ()
36
36
{
37
- for (auto parameter : m_parameters) {
37
+ for (auto & parameter : m_parameters) {
38
38
parameter.setCalcGrad (true );
39
39
}
40
40
}
41
41
42
42
void Module::eval ()
43
43
{
44
- for (auto parameter : m_parameters) {
44
+ for (auto & parameter : m_parameters) {
45
45
parameter.setCalcGrad (false );
46
46
}
47
47
}
You can’t perform that action at this time.
0 commit comments