/*!
* Copyright (c) 2016 by Contributors
* \file op_suppl.h
* \brief A supplement and amendment of the operators from op.h
* \author Zhang Chen, zhubuntu, Xin Li
*/
#ifndef MXNET_CPP_OP_SUPPL_H_
#define MXNET_CPP_OP_SUPPL_H_
#include <cassert>
#include <string>
#include <vector>
#include "mxnet-cpp/base.h"
#include "mxnet-cpp/shape.h"
#include "mxnet-cpp/operator.h"
#include "mxnet-cpp/MxNetCpp.h"
namespace mxnet {
namespace cpp {
inline Symbol _Plus(Symbol lhs, Symbol rhs) {
  return Operator("_Plus")(lhs, rhs)
           .CreateSymbol();
}
inline Symbol _Mul(Symbol lhs, Symbol rhs) {
  return Operator("_Mul")(lhs, rhs)
           .CreateSymbol();
}
inline Symbol _Minus(Symbol lhs, Symbol rhs) {
  return Operator("_Minus")(lhs, rhs)
           .CreateSymbol();
}
inline Symbol _Div(Symbol lhs, Symbol rhs) {
  return Operator("_Div")(lhs, rhs)
           .CreateSymbol();
}
inline Symbol _Mod(Symbol lhs, Symbol rhs) {
  return Operator("_Mod")(lhs, rhs)
           .CreateSymbol();
}
inline Symbol _Power(Symbol lhs, Symbol rhs) {
  return Operator("_Power")(lhs, rhs)
           .CreateSymbol();
}
inline Symbol _Maximum(Symbol lhs, Symbol rhs) {
  return Operator("_Maximum")(lhs, rhs)
           .CreateSymbol();
}
inline Symbol _Minimum(Symbol lhs, Symbol rhs) {
  return Operator("_Minimum")(lhs, rhs)
           .CreateSymbol();
}
inline Symbol _PlusScalar(Symbol lhs, mx_float scalar) {
  return Operator("_PlusScalar")(lhs)
           .SetParam("scalar", scalar)
           .CreateSymbol();
}
inline Symbol _MinusScalar(Symbol lhs, mx_float scalar) {
  return Operator("_MinusScalar")(lhs)
           .SetParam("scalar", scalar)
           .CreateSymbol();
}
inline Symbol _RMinusScalar(mx_float scalar, Symbol rhs) {
  return Operator("_RMinusScalar")(rhs)
           .SetParam("scalar", scalar)
           .CreateSymbol();
}
inline Symbol _MulScalar(Symbol lhs, mx_float scalar) {
  return Operator("_MulScalar")(lhs)
           .SetParam("scalar", scalar)
           .CreateSymbol();
}
inline Symbol _DivScalar(Symbol lhs, mx_float scalar) {
  return Operator("_DivScalar")(lhs)
           .SetParam("scalar", scalar)
           .CreateSymbol();
}
inline Symbol _RDivScalar(mx_float scalar, Symbol rhs) {
  return Operator("_RDivScalar")(rhs)
           .SetParam("scalar", scalar)
           .CreateSymbol();
}
inline Symbol _ModScalar(Symbol lhs, mx_float scalar) {
  return Operator("_ModScalar")(lhs)
           .SetParam("scalar", scalar)
           .CreateSymbol();
}
inline Symbol _RModScalar(mx_float scalar, Symbol rhs) {
  return Operator("_RModScalar")(rhs)
           .SetParam("scalar", scalar)
           .CreateSymbol();
}
inline Symbol _PowerScalar(Symbol lhs, mx_float scalar) {
  return Operator("_PowerScalar")(lhs)
           .SetParam("scalar", scalar)
           .CreateSymbol();
}
inline Symbol _RPowerScalar(mx_float scalar, Symbol rhs) {
  return Operator("_RPowerScalar")(rhs)
           .SetParam("scalar", scalar)
           .CreateSymbol();
}
inline Symbol _MaximumScalar(Symbol lhs, mx_float scalar) {
  return Operator("_MaximumScalar")(lhs)
           .SetParam("scalar", scalar)
           .CreateSymbol();
}
inline Symbol _MinimumScalar(Symbol lhs, mx_float scalar) {
  return Operator("_MinimumScalar")(lhs)
           .SetParam("scalar", scalar)
           .CreateSymbol();
}
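// Usage sketch for the element-wise helpers above (the symbols `a`, `b`,
// and `c` are hypothetical and not part of this header). Each helper simply
// wraps the corresponding atomic operator and yields a new Symbol:
//   Symbol a = Symbol::Variable("a");
//   Symbol b = Symbol::Variable("b");
//   Symbol c = _Plus(_Mul(a, b), _RMinusScalar(1.0f, a));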
// TODO(zhangcheng-qinyinghua)
// make crop function run in op.h
// This function is due to [zhubuntu](https://github.com/zhubuntu)
inline Symbol Crop(const std::string& symbol_name,
                   int num_args,
                   Symbol data,
                   Symbol crop_like,
                   Shape offset = Shape(0, 0),
                   Shape h_w = Shape(0, 0),
                   bool center_crop = false) {
  return Operator("Crop")
           .SetParam("num_args", num_args)
           .SetParam("offset", offset)
           .SetParam("h_w", h_w)
           .SetParam("center_crop", center_crop)
           .SetInput("arg0", data)
           .SetInput("arg1", crop_like)
           .CreateSymbol(symbol_name);
}
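// Usage sketch (hypothetical symbols; Crop takes two inputs here, so
// num_args is 2 and the second input supplies the target spatial size):
//   Symbol data = Symbol::Variable("data");
//   Symbol ref  = Symbol::Variable("ref");
//   Symbol cropped = Crop("crop1", 2, data, ref,
//                         Shape(0, 0), Shape(0, 0), true);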
/*!
 * \brief Apply an activation function to the input.
 *        Softmax activation is only available with CUDNN on GPU and will be
 *        computed at each location across channels if the input is 4D.
 * \param symbol_name name of the resulting symbol.
 * \param data Input data to the activation function.
 * \param act_type Activation function to be applied.
 * \return new symbol
 */
inline Symbol Activation(const std::string& symbol_name,
                         Symbol data,
                         const std::string& act_type) {
  assert(act_type == "relu" ||
         act_type == "sigmoid" ||
         act_type == "softrelu" ||
         act_type == "tanh");
  return Operator("Activation")
           .SetParam("act_type", act_type.c_str())
           .SetInput("data", data)
           .CreateSymbol(symbol_name);
}
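// Usage sketch (hypothetical symbol names; any of "relu", "sigmoid",
// "softrelu", or "tanh" passes the assert above):
//   Symbol fc1 = Symbol::Variable("fc1");
//   Symbol act1 = Activation("act1", fc1, "relu");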
} // namespace cpp
} // namespace mxnet
#endif // MXNET_CPP_OP_SUPPL_H_