forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmodel.h
58 lines (50 loc) · 1.18 KB
/
model.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
/*!
* Copyright (c) 2016 by Contributors
* \file model.h
* \brief MXNET.cpp model module
* \author Zhang Chen
*/
#ifndef MXNET_CPP_MODEL_H_
#define MXNET_CPP_MODEL_H_
#include <string>
#include <vector>
#include "mxnet-cpp/base.h"
#include "mxnet-cpp/symbol.h"
#include "mxnet-cpp/ndarray.h"
namespace mxnet {
namespace cpp {
/*!
 * \brief Configuration values consumed by FeedForward.
 *
 * Plain value type: every field carries an in-class default, and copies are
 * member-wise so a copied configuration keeps the values set by the caller.
 */
struct FeedForwardConfig {
  /*! \brief symbol stored in the configuration (default-constructed unless set) */
  Symbol symbol;
  /*! \brief list of contexts; defaults to a single Context::cpu() entry */
  std::vector<Context> ctx = {Context::cpu()};
  /*! \brief epoch count; defaults to 0 */
  int num_epoch = 0;
  /*! \brief epoch size; defaults to 0 */
  int epoch_size = 0;
  /*! \brief optimizer name; defaults to "sgd" */
  std::string optimizer = "sgd";
  // TODO(zhangchen-qinyinghua): implement the remaining options:
  //   initializer=Uniform(0.01),
  //   numpy_batch_size=128,
  //   arg_params=None, aux_params=None,
  //   allow_extra_params=False,
  //   begin_epoch=0,
  //   **kwargs):

  /*! \brief build a configuration holding only the in-class defaults */
  FeedForwardConfig() {}
  /*!
   * \brief member-wise copy of another configuration
   * \param other configuration to copy from
   *
   * Defaulted on purpose: an empty user-written body would leave every field
   * at its default value instead of taking the values stored in \p other.
   */
  FeedForwardConfig(const FeedForwardConfig &other) = default;
};
/*!
 * \brief Feed-forward model front end built around a FeedForwardConfig.
 *
 * Apart from the constructor, every member function is only declared in this
 * header.
 */
class FeedForward {
 public:
  /*! \brief factory returning a FeedForward by value (declaration only) */
  static FeedForward Create();

  /*!
   * \brief construct a model from a configuration
   * \param conf configuration used to copy-construct the internal conf_ member
   */
  explicit FeedForward(const FeedForwardConfig &conf)
      : conf_(conf) {
  }

  /*! \brief fitting entry point (declaration only) */
  void Fit();
  /*! \brief prediction entry point (declaration only) */
  void Predict();
  /*! \brief scoring entry point (declaration only) */
  void Score();

  /*! \brief save entry point (declaration only) */
  void Save();
  /*! \brief load entry point (declaration only) */
  void Load();

 private:
  /*! \brief private helper for parameters (declaration only) */
  void InitParams();
  /*! \brief private helper for the predictor (declaration only) */
  void InitPredictor();
  /*! \brief private helper for the iterator (declaration only) */
  void InitIter();
  /*! \brief private helper for the evaluation iterator (declaration only) */
  void InitEvalIter();

  /*! \brief configuration copy-constructed from the constructor argument */
  FeedForwardConfig conf_;
};
} // namespace cpp
} // namespace mxnet
#endif // MXNET_CPP_MODEL_H_