Skip to content

Commit b1573d4

Browse files
committed
use mmap for external memory.
1 parent 152e2fb commit b1573d4

File tree

1 file changed

+63
-27
lines changed

1 file changed

+63
-27
lines changed

src/data/sparse_page_source.h

Lines changed: 63 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,27 @@
55
#ifndef XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
66
#define XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
77

8+
#include <fcntl.h> // for open, O_RDONLY
9+
#include <sys/mman.h> // for mmap, munmap
10+
#include <unistd.h> // for close
11+
812
#include <algorithm> // std::min
9-
#include <string>
10-
#include <utility>
11-
#include <vector>
1213
#include <future>
13-
#include <thread>
1414
#include <map>
1515
#include <memory>
16-
17-
#include "xgboost/base.h"
18-
#include "xgboost/data.h"
19-
20-
#include "adapter.h"
21-
#include "sparse_page_writer.h"
22-
#include "proxy_dmatrix.h"
16+
#include <string>
17+
#include <thread>
18+
#include <utility>
19+
#include <vector>
2320

2421
#include "../common/common.h"
22+
#include "../common/io.h"
2523
#include "../common/timer.h"
24+
#include "adapter.h"
25+
#include "proxy_dmatrix.h"
26+
#include "sparse_page_writer.h"
27+
#include "xgboost/base.h"
28+
#include "xgboost/data.h"
2629

2730
namespace xgboost {
2831
namespace data {
@@ -40,6 +43,7 @@ struct Cache {
4043
std::string format;
4144
// offset into binary cache file.
4245
std::vector<size_t> offset;
46+
std::vector<std::uint64_t> bytes;
4347

4448
Cache(bool w, std::string n, std::string fmt)
4549
: written{w}, name{std::move(n)}, format{std::move(fmt)} {
@@ -54,6 +58,10 @@ struct Cache {
5458
// Convenience overload: returns the shard name for this cache by forwarding the
// stored `name` and `format` members to the two-argument ShardName.
std::string ShardName() {
5559
return ShardName(this->name, this->format);
5660
}
61+
// Record one written page: appends `n_bytes` to `bytes` and the same value to `offset`.
// The caller visible in this change passes the page-size-padded length of the page.
// NOTE(review): `offset` receives a size here, not a cumulative file position --
// presumably it is converted into real offsets elsewhere (e.g. on commit); verify.
void Push(std::size_t n_bytes) {
62+
bytes.push_back(n_bytes);
63+
offset.push_back(n_bytes);
64+
}
5765

5866
// The write is completed.
5967
void Commit() {
@@ -95,7 +103,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
95103
uint32_t n_batches_ {0};
96104

97105
std::shared_ptr<Cache> cache_info_;
98-
std::unique_ptr<dmlc::Stream> fo_;
106+
// std::unique_ptr<dmlc::Stream> fo_;
99107

100108
using Ring = std::vector<std::future<std::shared_ptr<S>>>;
101109
// A ring storing futures to data. Since the DMatrix iterator is forward only, we
@@ -107,8 +115,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
107115
if (!cache_info_->written) {
108116
return false;
109117
}
110-
if (fo_) {
111-
fo_.reset(); // flush the data to disk.
118+
if (ring_->empty()) {
112119
ring_->resize(n_batches_);
113120
}
114121
// A heuristic for the number of pre-fetched batches. We can make it part of BatchParam
@@ -126,20 +133,39 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
126133
}
127134
auto const *self = this; // make sure it's const
128135
CHECK_LT(fetch_it, cache_info_->offset.size());
129-
ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self]() {
136+
dmlc::OMPException exec;
137+
ring_->at(fetch_it) = std::async(std::launch::async, [&exec, fetch_it, self]() {
138+
auto page = std::make_shared<S>();
139+
130140
common::Timer timer;
131141
timer.Start();
132142
std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};
133143
auto n = self->cache_info_->ShardName();
134-
size_t offset = self->cache_info_->offset.at(fetch_it);
135-
std::unique_ptr<dmlc::SeekStream> fi{dmlc::SeekStream::CreateForRead(n.c_str())};
136-
fi->Seek(offset);
137-
CHECK_EQ(fi->Tell(), offset);
138-
auto page = std::make_shared<S>();
139-
CHECK(fmt->Read(page.get(), fi.get()));
144+
145+
std::uint64_t offset = self->cache_info_->offset.at(fetch_it);
146+
std::uint64_t length = self->cache_info_->bytes.at(fetch_it);
147+
148+
// mmap
149+
auto fd = open(n.c_str(), O_RDONLY);
150+
CHECK_GE(fd, 0) << "Failed to open:" << n << ". " << strerror(errno);
151+
auto ptr = mmap64(nullptr, length, PROT_READ, MAP_PRIVATE, fd, offset);
152+
if (ptr == MAP_FAILED) {
153+
LOG(FATAL) << "Failed to map: " << n << ". " << strerror(errno) << ". "
154+
<< "len:" << length << " off:" << offset << " it:" << fetch_it << std::endl;
155+
}
156+
157+
// read page
158+
auto fi = common::MemoryFixSizeBuffer(ptr, length);
159+
CHECK(fmt->Read(page.get(), &fi));
140160
LOG(INFO) << "Read a page in " << timer.ElapsedSeconds() << " seconds.";
161+
162+
// cleanup
163+
CHECK_NE(close(fd), -1) << "Faled to close: " << n << ". " << strerror(errno);
164+
CHECK_NE(munmap(ptr, length), -1) << "Faled to munmap: " << n << ". " << strerror(errno);
165+
141166
return page;
142167
});
168+
exec.Rethrow();
143169
}
144170
CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(), [](auto const& f) { return f.valid(); }),
145171
n_prefetch_batches)
@@ -153,16 +179,26 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
153179
common::Timer timer;
154180
timer.Start();
155181
std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};
156-
if (!fo_) {
157-
auto n = cache_info_->ShardName();
158-
fo_.reset(dmlc::Stream::Create(n.c_str(), "w"));
159-
}
160-
auto bytes = fmt->Write(*page_, fo_.get());
182+
183+
auto name = cache_info_->ShardName();
184+
std::unique_ptr<dmlc::Stream> fo{dmlc::Stream::Create(name.c_str(), "a")};
185+
186+
auto bytes = fmt->Write(*page_, fo.get());
187+
188+
// align for mmap
189+
auto page_size = getpagesize();
190+
CHECK(page_size != 0 && page_size % 2 == 0) << "Failed to get page size on the current system.";
191+
auto n = bytes / page_size;
192+
auto padded = (n + 1) * page_size;
193+
auto padding = padded - bytes;
194+
std::vector<std::uint8_t> padding_bytes(padding, 0);
195+
fo->Write(padding_bytes.data(), padding_bytes.size());
196+
161197
timer.Stop();
162198

163199
LOG(INFO) << static_cast<double>(bytes) / 1024.0 / 1024.0 << " MB written in "
164200
<< timer.ElapsedSeconds() << " seconds.";
165-
cache_info_->offset.push_back(bytes);
201+
cache_info_->Push(padded);
166202
}
167203

168204
virtual void Fetch() = 0;

0 commit comments

Comments
 (0)