Skip to content

Commit b30a3c8

Browse files
fohx13pavanky
authored andcommitted
added min function
1 parent 0f170a9 commit b30a3c8

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

include/af/autograd/Functions.hpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,28 @@ namespace af {
1717
Variable operator *(const Variable &lhs, const Variable &rhs);
1818
Variable operator -(const Variable &lhs, const Variable &rhs);
1919
Variable operator /(const Variable &lhs, const Variable &rhs);
20+
Variable operator >(const Variable &lhs, const Variable &rhs);
21+
Variable operator <(const Variable &lhs, const Variable &rhs);
22+
Variable operator >=(const Variable &lhs, const Variable &rhs);
23+
Variable operator <=(const Variable &lhs, const Variable &rhs);
2024

2125
Variable operator +(const double &lhs, const Variable &rhs);
2226
Variable operator *(const double &lhs, const Variable &rhs);
2327
Variable operator -(const double &lhs, const Variable &rhs);
2428
Variable operator /(const double &lhs, const Variable &rhs);
29+
Variable operator >(const double &lhs, const Variable &rhs);
30+
Variable operator <(const double &lhs, const Variable &rhs);
31+
Variable operator >=(const double &lhs, const Variable &rhs);
32+
Variable operator <=(const double &lhs, const Variable &rhs);
2533

2634
Variable operator +(const Variable &lhs, const double &rhs);
2735
Variable operator *(const Variable &lhs, const double &rhs);
2836
Variable operator -(const Variable &lhs, const double &rhs);
2937
Variable operator /(const Variable &lhs, const double &rhs);
38+
Variable operator >(const Variable &lhs, const double &rhs);
39+
Variable operator <(const Variable &lhs, const double &rhs);
40+
Variable operator >=(const Variable &lhs, const double &rhs);
41+
Variable operator <=(const Variable &lhs, const double &rhs);
3042

3143
Variable negate(const Variable &input);
3244
Variable reciprocal(const Variable &input);
@@ -41,6 +53,10 @@ namespace af {
4153
Variable max(const Variable &lhs, const double &rhs);
4254
Variable max(const double &lhs, const Variable &rhs);
4355

56+
Variable min(const Variable &lhs, const Variable &rhs);
57+
Variable min(const Variable &lhs, const double &rhs);
58+
Variable min(const double &lhs, const Variable &rhs);
59+
4460
Variable transpose(const Variable &input);
4561
Variable expandAs(const Variable &input, const Variable &reference);
4662
Variable reduceAs(const Variable &input, const Variable &reference);

src/autograd/Functions.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,26 @@ namespace af {
6161
return Variable(result, false);
6262
}
6363

64+
Variable operator <(const Variable &lhs, const Variable &rhs)
65+
{
66+
auto result = lhs.array() < rhs.array();
67+
return Variable(result, false);
68+
}
69+
70+
Variable operator >=(const Variable &lhs, const Variable &rhs)
71+
{
72+
auto result = lhs.array() >= rhs.array();
73+
return Variable(result, false);
74+
}
75+
6476
Variable operator <=(const Variable &lhs, const Variable &rhs)
6577
{
6678
auto result = lhs.array() <= rhs.array();
6779
return Variable(result, false);
6880
}
6981

82+
83+
7084
#define INSTANTIATE_OPERATOR(OP) \
7185
Variable operator OP(const double &lhs_val, const Variable &rhs) \
7286
{ \
@@ -91,6 +105,8 @@ namespace af {
91105
INSTANTIATE_OPERATOR(*)
92106
INSTANTIATE_OPERATOR(/)
93107
INSTANTIATE_OPERATOR(>)
108+
INSTANTIATE_OPERATOR(<)
109+
INSTANTIATE_OPERATOR(>=)
94110
INSTANTIATE_OPERATOR(<=)
95111

96112
#undef INSTANTIATE_OPERATOR
@@ -113,6 +129,18 @@ namespace af {
113129
return Variable(result, {lhs, rhs, mask}, grad_func);
114130
}
115131

132+
Variable min(const Variable &lhs, const Variable &rhs)
133+
{
134+
auto mask = lhs < rhs;
135+
auto result = min(lhs.array(), rhs.array());
136+
137+
auto grad_func = [](std::vector<Variable> &inputs, const Variable &grad_output) {
138+
inputs[0].addGrad( inputs[2] * grad_output);
139+
inputs[1].addGrad(!inputs[2] * grad_output);
140+
};
141+
return Variable(result, {lhs, rhs, mask}, grad_func);
142+
}
143+
116144
#define INSTANTIATE_FUNCTION(FN) \
117145
Variable FN(const double &lhs_val, const Variable &rhs) \
118146
{ \
@@ -134,6 +162,7 @@ namespace af {
134162

135163

136164
INSTANTIATE_FUNCTION(max);
165+
INSTANTIATE_FUNCTION(min);
137166

138167
#undef INSTANTIATE_FUNCTION
139168

0 commit comments

Comments
 (0)