5
5
#ifndef XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
6
6
#define XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
7
7
8
+ #include < fcntl.h> // for open, O_RDONLY
9
+ #include < sys/mman.h> // for mmap, munmap
10
+ #include < unistd.h> // for close
11
+
8
12
#include < algorithm> // std::min
9
- #include < string>
10
- #include < utility>
11
- #include < vector>
12
13
#include < future>
13
- #include < thread>
14
14
#include < map>
15
15
#include < memory>
16
-
17
- #include " xgboost/base.h"
18
- #include " xgboost/data.h"
19
-
20
- #include " adapter.h"
21
- #include " sparse_page_writer.h"
22
- #include " proxy_dmatrix.h"
16
+ #include < string>
17
+ #include < thread>
18
+ #include < utility>
19
+ #include < vector>
23
20
24
21
#include " ../common/common.h"
22
+ #include " ../common/io.h"
25
23
#include " ../common/timer.h"
24
+ #include " adapter.h"
25
+ #include " proxy_dmatrix.h"
26
+ #include " sparse_page_writer.h"
27
+ #include " xgboost/base.h"
28
+ #include " xgboost/data.h"
26
29
27
30
namespace xgboost {
28
31
namespace data {
@@ -40,6 +43,7 @@ struct Cache {
40
43
std::string format;
41
44
// offset into binary cache file.
42
45
std::vector<size_t > offset;
46
+ std::vector<std::uint64_t > bytes;
43
47
44
48
Cache (bool w, std::string n, std::string fmt)
45
49
: written{w}, name{std::move (n)}, format{std::move (fmt)} {
@@ -54,6 +58,10 @@ struct Cache {
54
58
std::string ShardName () {
55
59
return ShardName (this ->name , this ->format );
56
60
}
61
+ void Push (std::size_t n_bytes) {
62
+ bytes.push_back (n_bytes);
63
+ offset.push_back (n_bytes);
64
+ }
57
65
58
66
// The write is completed.
59
67
void Commit () {
@@ -95,7 +103,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
95
103
uint32_t n_batches_ {0 };
96
104
97
105
std::shared_ptr<Cache> cache_info_;
98
- std::unique_ptr<dmlc::Stream> fo_;
106
+ // std::unique_ptr<dmlc::Stream> fo_;
99
107
100
108
using Ring = std::vector<std::future<std::shared_ptr<S>>>;
101
109
// A ring storing futures to data. Since the DMatrix iterator is forward only, we
@@ -107,8 +115,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
107
115
if (!cache_info_->written ) {
108
116
return false ;
109
117
}
110
- if (fo_) {
111
- fo_.reset (); // flush the data to disk.
118
+ if (ring_->empty ()) {
112
119
ring_->resize (n_batches_);
113
120
}
114
121
// A heuristic for the number of pre-fetched batches. We can make it part of BatchParam
@@ -126,20 +133,39 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
126
133
}
127
134
auto const *self = this ; // make sure it's const
128
135
CHECK_LT (fetch_it, cache_info_->offset .size ());
129
- ring_->at (fetch_it) = std::async (std::launch::async, [fetch_it, self]() {
136
+ dmlc::OMPException exec;
137
+ ring_->at (fetch_it) = std::async (std::launch::async, [&exec, fetch_it, self]() {
138
+ auto page = std::make_shared<S>();
139
+
130
140
common::Timer timer;
131
141
timer.Start ();
132
142
std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>(" raw" )};
133
143
auto n = self->cache_info_ ->ShardName ();
134
- size_t offset = self->cache_info_ ->offset .at (fetch_it);
135
- std::unique_ptr<dmlc::SeekStream> fi{dmlc::SeekStream::CreateForRead (n.c_str ())};
136
- fi->Seek (offset);
137
- CHECK_EQ (fi->Tell (), offset);
138
- auto page = std::make_shared<S>();
139
- CHECK (fmt->Read (page.get (), fi.get ()));
144
+
145
+ std::uint64_t offset = self->cache_info_ ->offset .at (fetch_it);
146
+ std::uint64_t length = self->cache_info_ ->bytes .at (fetch_it);
147
+
148
+ // mmap
149
+ auto fd = open (n.c_str (), O_RDONLY);
150
+ CHECK_GE (fd, 0 ) << " Failed to open:" << n << " . " << strerror (errno);
151
+ auto ptr = mmap64 (nullptr , length, PROT_READ, MAP_PRIVATE, fd, offset);
152
+ if (ptr == MAP_FAILED) {
153
+ LOG (FATAL) << " Failed to map: " << n << " . " << strerror (errno) << " . "
154
+ << " len:" << length << " off:" << offset << " it:" << fetch_it << std::endl;
155
+ }
156
+
157
+ // read page
158
+ auto fi = common::MemoryFixSizeBuffer (ptr, length);
159
+ CHECK (fmt->Read (page.get (), &fi));
140
160
LOG (INFO) << " Read a page in " << timer.ElapsedSeconds () << " seconds." ;
161
+
162
+ // cleanup
163
+ CHECK_NE (close (fd), -1 ) << " Faled to close: " << n << " . " << strerror (errno);
164
+ CHECK_NE (munmap (ptr, length), -1 ) << " Faled to munmap: " << n << " . " << strerror (errno);
165
+
141
166
return page;
142
167
});
168
+ exec.Rethrow ();
143
169
}
144
170
CHECK_EQ (std::count_if (ring_->cbegin (), ring_->cend (), [](auto const & f) { return f.valid (); }),
145
171
n_prefetch_batches)
@@ -153,16 +179,26 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
153
179
common::Timer timer;
154
180
timer.Start ();
155
181
std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>(" raw" )};
156
- if (!fo_) {
157
- auto n = cache_info_->ShardName ();
158
- fo_.reset (dmlc::Stream::Create (n.c_str (), " w" ));
159
- }
160
- auto bytes = fmt->Write (*page_, fo_.get ());
182
+
183
+ auto name = cache_info_->ShardName ();
184
+ std::unique_ptr<dmlc::Stream> fo{dmlc::Stream::Create (name.c_str (), " a" )};
185
+
186
+ auto bytes = fmt->Write (*page_, fo.get ());
187
+
188
+ // align for mmap
189
+ auto page_size = getpagesize ();
190
+ CHECK (page_size != 0 && page_size % 2 == 0 ) << " Failed to get page size on the current system." ;
191
+ auto n = bytes / page_size;
192
+ auto padded = (n + 1 ) * page_size;
193
+ auto padding = padded - bytes;
194
+ std::vector<std::uint8_t > padding_bytes (padding, 0 );
195
+ fo->Write (padding_bytes.data (), padding_bytes.size ());
196
+
161
197
timer.Stop ();
162
198
163
199
LOG (INFO) << static_cast <double >(bytes) / 1024.0 / 1024.0 << " MB written in "
164
200
<< timer.ElapsedSeconds () << " seconds." ;
165
- cache_info_->offset . push_back (bytes );
201
+ cache_info_->Push (padded );
166
202
}
167
203
168
204
virtual void Fetch () = 0;
0 commit comments