
Commit a316af0

Moving autograd from header only lib to a compiled lib

1 parent 7bb0b6c

File tree

6 files changed: +253 −176 lines

CMakeLists.txt

Lines changed: 29 additions & 2 deletions
@@ -1,9 +1,36 @@
-cmake_minimum_required(VERSION 3.5.2)
+cmake_minimum_required(VERSION 3.5.1)
 
 project(ArrayFireML
   VERSION 0.1.0
   LANGUAGES C CXX)
 
 find_package(ArrayFire REQUIRED)
-set(ArrayFireML_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/include)
+
+add_library(afml SHARED "")
+
+target_sources(afml
+  PRIVATE
+    src/autograd/Variable.cpp
+    src/autograd/Functions.cpp
+  )
+
+target_include_directories(afml
+  PUBLIC
+    ${ArrayFire_INCLUDE_DIRS}
+    ${CMAKE_CURRENT_SOURCE_DIR}/include
+  )
+
+target_link_libraries(afml
+  PUBLIC
+    af
+  )
+
+set_target_properties(afml
+  PROPERTIES
+    VERSION "${ArrayFireML_VERSION}"
+    SOVERSION "${ArrayFireML_VERSION_MAJOR}"
+    CXX_STANDARD 11
+  )
+
+
 add_subdirectory(examples)

examples/CMakeLists.txt

Lines changed: 1 addition & 6 deletions
@@ -2,14 +2,9 @@ function(build_example SRC)
   get_filename_component(src_name ${SRC} NAME_WE)
   set(target "${src_name}")
   add_executable(${target} ${SRC})
-  target_include_directories(${target}
-    PRIVATE
-    ${ArrayFire_INCLUDE_DIRS}
-    ${ArrayFireML_INCLUDE_DIRS}
-    )
   target_link_libraries(${target}
     PRIVATE
-    af
+    afml
     )
   target_compile_features(${target}
     PRIVATE cxx_range_for)

include/af/autograd/Functions.hpp

Lines changed: 3 additions & 22 deletions
@@ -8,31 +8,12 @@
  ********************************************************/
 #pragma once
 
-#include <af/autograd/Variable.hpp>
-
 namespace af {
     namespace autograd {
 
-        Variable operator +(const Variable lhs, const Variable rhs)
-        {
-            auto result = lhs.array() + rhs.array();
-            auto grad_func = [](std::vector<Variable> inputs, Variable grad_output) {
-                inputs[0].addGrad(grad_output);
-                inputs[1].addGrad(grad_output);
-            };
-            return Variable(result, {lhs, rhs}, grad_func);
-        }
-
-        Variable operator *(const Variable lhs, const Variable rhs)
-        {
-            auto result = lhs.array() * rhs.array();
-            auto grad_func = [](std::vector<Variable> inputs, Variable grad_output) {
-                inputs[0].addGrad(grad_output * inputs[1]);
-                inputs[1].addGrad(grad_output * inputs[0]);
-            };
-            return Variable(result, {lhs, rhs}, grad_func);
-        }
+        class Variable;
 
+        Variable operator +(const Variable lhs, const Variable rhs);
+        Variable operator *(const Variable lhs, const Variable rhs);
     }
-    namespace ag = autograd;
 }
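
Note that Functions.hpp now only forward-declares Variable, so a translation unit that calls these operators also needs the full class definition from Variable.hpp; the new src/autograd/Functions.cpp further down includes both headers for exactly that reason. A minimal consumer sketch under that assumption (file contents, names, and values are illustrative, not part of this commit):

// sketch of a consumer TU that links against the afml shared library
#include <af/autograd/Variable.hpp>   // full Variable definition
#include <af/autograd/Functions.hpp>  // operator declarations, now compiled symbols

#include <arrayfire.h>

int main()
{
    // calc_grad = true so the operators record a gradient function
    af::autograd::Variable a(af::randu(5), true);
    af::autograd::Variable b(af::randu(5), true);

    auto c = a + b;          // resolved against afml at link time
    af_print(c.array());
    return 0;
}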

include/af/autograd/Variable.hpp

Lines changed: 26 additions & 146 deletions
@@ -14,17 +14,11 @@
 #include <memory>
 #include <vector>
 #include <unordered_map>
-#include <stdexcept>
 
 #include <arrayfire.h>
 
 namespace af {
     namespace autograd {
-
-        // Forward declare the function
-        class Variable;
-        Variable operator +(const Variable lhs, const Variable rhs);
-
         class Variable
         {
         public:
@@ -34,32 +28,12 @@ namespace af {
 
         private:
             struct Shared {
-                Shared() :
-                    m_calc_grad(true),
-                    m_data(),
-                    m_inputs(),
-                    m_grads(),
-                    m_grad_func(nullptr)
-                {}
-
-                Shared(af::array data, bool calc_grad) :
-                    m_calc_grad(calc_grad),
-                    m_data(data),
-                    m_inputs(),
-                    m_grads(),
-                    m_grad_func(nullptr)
-                {}
-
+                Shared();
+                Shared(af::array data, bool calc_grad);
                 Shared(af::array data,
                        std::vector<Variable> inputs,
                        GradFunc_t grad_func,
-                       bool calc_grad) :
-                    m_calc_grad(calc_grad),
-                    m_data(data),
-                    m_inputs(inputs.begin(), inputs.end()),
-                    m_grads(),
-                    m_grad_func(grad_func)
-                {}
+                       bool calc_grad);
 
                 bool m_calc_grad;
                 af::array m_data;
@@ -70,127 +44,33 @@
 
         public:
 
-            Variable() :
-                m_shared(new Shared())
-            {
-            }
-
-            Variable(af::array data, bool calc_grad) :
-                m_shared(new Shared(data, calc_grad))
-            {}
-
+            Variable();
+            Variable(af::array data, bool calc_grad);
             Variable(af::array data,
                      std::vector<Variable> inputs,
-                     GradFunc_t grad_func) :
-                m_shared(nullptr)
-            {
-                bool calc_grad = false;
-                for (auto input : inputs) {
-                    calc_grad |= input.isCalcGrad();
-                }
-                if (calc_grad) {
-                    m_shared = std::shared_ptr<Shared>(new Shared(data, inputs, grad_func, true));
-                } else {
-                    m_shared = std::shared_ptr<Shared>(new Shared(data, false));
-                }
-            }
-
-            af::array array() const
-            {
-                return m_shared->m_data;
-            }
-
-            Variable grad() const
-            {
-                if (!m_shared->m_calc_grad) {
-                    throw af::exception("Gradient calclation disabled.");
-                }
-                if (m_shared->m_grads.size() == 0) {
-                    throw af::exception("Gradient hasn't been calculated yet.");
-                }
-                return m_shared->m_grads[0];
-            }
-
-            bool isCalcGrad()
-            {
-                return m_shared->m_calc_grad;
-            }
-
-            void setCalcGrad(bool calc_grad)
-            {
-                m_shared->m_calc_grad = calc_grad;
-                if (!calc_grad) {
-                    m_shared->m_grad_func = nullptr;
-                    m_shared->m_inputs.clear();
-                    m_shared->m_grads.clear();
-                }
-            }
-
-            void addGrad(Variable child_grad)
-            {
-                if (m_shared->m_calc_grad) {
-                    m_shared->m_grads.push_back(child_grad);
-                }
-            }
-
-            std::vector<Variable> getInputs() const
-            {
-                return m_shared->m_inputs;
-            }
-
-            void evalGrad()
-            {
-                // Flag asking not to calculate gradients
-                if (!m_shared->m_calc_grad) return;
-                Variable grad = m_shared->m_grads[0];
-                for (unsigned i = 1; i < m_shared->m_grads.size(); i++) {
-                    grad = grad + m_shared->m_grads[i];
-                }
-                grad.array().eval();
-                m_shared->m_grads.clear();
-                m_shared->m_grads.push_back(grad);
-            }
-
-            void calcGradInputs()
-            {
-                evalGrad();
-                if (m_shared->m_grad_func) {
-                    m_shared->m_grad_func(m_shared->m_inputs, m_shared->m_grads[0]);
-                }
-            }
-
-            void backward(Variable grad)
-            {
-                this->addGrad(grad);
-                DAG_t dag = this->build();
-                for (auto iter = dag.rbegin(); iter != dag.rend(); iter++) {
-                    iter->calcGradInputs();
-                }
-            }
-
-            DAG_t build()
-            {
-                Cache_t cache;
-                DAG_t dag;
-                this->buildSubGraph(cache, dag);
-                return dag;
-            }
-
-            void buildSubGraph(Cache_t &cache, DAG_t &dag)
-            {
-                std::ptrdiff_t id = (std::ptrdiff_t)m_shared.get();
-                if (cache.find(id) != cache.end()) {
-                    return;
-                }
-                for (auto input : m_shared->m_inputs) {
-                    input.buildSubGraph(cache, dag);
-                }
-                cache[id] = true;
-                dag.push_back(*this);
-            }
+                     GradFunc_t grad_func);
+
+            af::array array() const;
+
+            Variable grad() const;
+
+            bool isCalcGrad();
+
+            void setCalcGrad(bool calc_grad);
+
+            void addGrad(Variable child_grad);
+
+            void evalGrad();
+
+            void calcGradInputs();
+
+            void backward(Variable grad);
+
+            DAG_t build();
+
+            void buildSubGraph(Cache_t &cache, DAG_t &dag);
         private:
             std::shared_ptr<Shared> m_shared;
         };
     }
-    namespace ag = autograd;
 }
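
Only five diffs are reproduced on this page. The sixth changed file is src/autograd/Variable.cpp, which the new target_sources entry compiles; the file-tree stats (+253 total, of which only 96 additions appear above) suggest it carries the remaining ~157 added lines, presumably the member-function bodies deleted from this header moved out of line. A sketch of what two of those definitions would look like, with bodies copied from the inline versions removed above:

// sketch of out-of-line definitions in src/autograd/Variable.cpp
// (that file's diff is not reproduced on this page)
#include <af/autograd/Variable.hpp>

namespace af {
    namespace autograd {

        af::array Variable::array() const
        {
            return m_shared->m_data;
        }

        void Variable::addGrad(Variable child_grad)
        {
            if (m_shared->m_calc_grad) {
                m_shared->m_grads.push_back(child_grad);
            }
        }
    }
}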

src/autograd/Functions.cpp

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+/*******************************************************
+ * Copyright (c) 2017, ArrayFire
+ * All rights reserved.
+ *
+ * This file is distributed under 3-clause BSD license.
+ * The complete license agreement can be obtained at:
+ * http://arrayfire.com/licenses/BSD-3-Clause
+ ********************************************************/
+
+#include <af/autograd/Variable.hpp>
+#include <af/autograd/Functions.hpp>
+
+namespace af {
+    namespace autograd {
+
+        Variable operator +(const Variable lhs, const Variable rhs)
+        {
+            auto result = lhs.array() + rhs.array();
+            auto grad_func = [](std::vector<Variable> inputs, Variable grad_output) {
+                inputs[0].addGrad(grad_output);
+                inputs[1].addGrad(grad_output);
+            };
+            return Variable(result, {lhs, rhs}, grad_func);
+        }
+
+        Variable operator *(const Variable lhs, const Variable rhs)
+        {
+            auto result = lhs.array() * rhs.array();
+            auto grad_func = [](std::vector<Variable> inputs, Variable grad_output) {
+                inputs[0].addGrad(grad_output * inputs[1]);
+                inputs[1].addGrad(grad_output * inputs[0]);
+            };
+            return Variable(result, {lhs, rhs}, grad_func);
+        }
+
+    }
+}
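
The grad_func lambdas stored by these operators are what Variable::backward invokes while replaying the DAG in reverse. A short sketch of the two pieces working together (seed value and variable names are illustrative, not from this commit):

#include <af/autograd/Variable.hpp>
#include <af/autograd/Functions.hpp>

#include <arrayfire.h>

using af::autograd::Variable;

int main()
{
    Variable x(af::randu(5), true);
    Variable y(af::randu(5), true);

    // z = x * y; the recorded lambda routes grad_output * y to x
    // and grad_output * x to y when the graph is replayed.
    Variable z = x * y;

    // Seed the backward pass with dz/dz = 1.
    Variable ones(af::constant(1.0, 5), false);
    z.backward(ones);

    // dz/dx == y and dz/dy == x, elementwise.
    af_print(x.grad().array());
    af_print(y.grad().array());
    return 0;
}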
