Skip to content

Commit 8129b47

Browse files
pavankyumar456
authored andcommitted
Use references while iterating when possible
1 parent 82d77dd commit 8129b47

File tree

5 files changed

+10
-10
lines changed

5 files changed

+10
-10
lines changed

examples/perceptron.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ int main()
6161

6262
// Update parameters
6363
// TODO: Should use optimizer
64-
for (auto param : perceptron.parameters()) {
64+
for (auto &param : perceptron.parameters()) {
6565
param.array() += lr * param.grad().array();
6666
param.array().eval();
6767
}

include/af/autograd/Variable.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ namespace af {
5252

5353
af::array& array() const;
5454

55-
Variable grad() const;
55+
Variable& grad() const;
5656

5757
std::ptrdiff_t id() const;
5858

@@ -74,7 +74,7 @@ namespace af {
7474
private:
7575
void evalGrad(bool retain_grad_graph = false);
7676

77-
std::vector<Variable> getInputs() const;
77+
std::vector<Variable>& getInputs() const;
7878

7979
static void buildSubGraph(Cache_t &cache, DAG_t &dag, const Variable &var);
8080

src/autograd/Variable.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ namespace af {
5555
m_shared(nullptr)
5656
{
5757
bool calc_grad = false;
58-
for (auto input : inputs) {
58+
for (const auto &input : inputs) {
5959
calc_grad |= input.isCalcGrad();
6060
}
6161
if (calc_grad) {
@@ -70,7 +70,7 @@ namespace af {
7070
return m_shared->m_data;
7171
}
7272

73-
Variable Variable::grad() const
73+
Variable& Variable::grad() const
7474
{
7575
if (!m_shared->m_calc_grad) {
7676
throw af::exception("Gradient calclation disabled.");
@@ -86,7 +86,7 @@ namespace af {
8686
return (std::ptrdiff_t)m_shared.get();
8787
}
8888

89-
std::vector<Variable> Variable::getInputs() const
89+
std::vector<Variable>& Variable::getInputs() const
9090
{
9191
return m_shared->m_inputs;
9292
}
@@ -181,7 +181,7 @@ namespace af {
181181
if (cache.find(id) != cache.end()) {
182182
return;
183183
}
184-
for (auto input : var.getInputs()) {
184+
for (const auto &input : var.getInputs()) {
185185
Variable::buildSubGraph(cache, dag, input);
186186
}
187187
cache[id] = true;

src/nn/Modules/Container.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ namespace af
3333
Variable Sequential::forward(const Variable &input)
3434
{
3535
Variable output = input;
36-
for(auto module : m_modules) {
36+
for (auto &module : m_modules) {
3737
output = module->forward(output);
3838
}
3939
return output;

src/nn/Modules/Module.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,14 @@ namespace af
3434

3535
void Module::train()
3636
{
37-
for (auto parameter : m_parameters) {
37+
for (auto &parameter : m_parameters) {
3838
parameter.setCalcGrad(true);
3939
}
4040
}
4141

4242
void Module::eval()
4343
{
44-
for (auto parameter : m_parameters) {
44+
for (auto &parameter : m_parameters) {
4545
parameter.setCalcGrad(false);
4646
}
4747
}

0 commit comments

Comments
 (0)