Skip to content

Commit e8ecafb

Browse files
committed
Accept string for ArrayInterface constructor.
1 parent b47b5ac commit e8ecafb

File tree

4 files changed

+90
-12
lines changed

4 files changed

+90
-12
lines changed

src/data/array_interface.h

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
/*!
22
* Copyright 2019 by Contributors
33
* \file array_interface.h
4-
* \brief Basic structure holding a reference to arrow columnar data format.
4+
* \brief View of __array_interface__
55
*/
66
#ifndef XGBOOST_DATA_ARRAY_INTERFACE_H_
77
#define XGBOOST_DATA_ARRAY_INTERFACE_H_
@@ -11,6 +11,7 @@
1111
#include <string>
1212
#include <utility>
1313

14+
#include "xgboost/base.h"
1415
#include "xgboost/data.h"
1516
#include "xgboost/json.h"
1617
#include "xgboost/logging.h"
@@ -113,6 +114,7 @@ class ArrayInterfaceHandler {
113114
get<Array const>(
114115
obj.at("data"))
115116
.at(0))));
117+
CHECK(p_data);
116118
return p_data;
117119
}
118120

@@ -186,7 +188,7 @@ class ArrayInterfaceHandler {
186188
return 0;
187189
}
188190

189-
static std::pair<size_t, size_t> ExtractShape(
191+
static std::pair<bst_row_t, bst_feature_t> ExtractShape(
190192
std::map<std::string, Json> const& column) {
191193
auto j_shape = get<Array const>(column.at("shape"));
192194
auto typestr = get<String const>(column.at("typestr"));
@@ -201,12 +203,12 @@ class ArrayInterfaceHandler {
201203
}
202204

203205
if (j_shape.size() == 1) {
204-
return {static_cast<size_t>(get<Integer const>(j_shape.at(0))), 1};
206+
return {static_cast<bst_row_t>(get<Integer const>(j_shape.at(0))), 1};
205207
} else {
206208
CHECK_EQ(j_shape.size(), 2)
207209
<< "Only 1D or 2-D arrays currently supported.";
208-
return {static_cast<size_t>(get<Integer const>(j_shape.at(0))),
209-
static_cast<size_t>(get<Integer const>(j_shape.at(1)))};
210+
return {static_cast<bst_row_t>(get<Integer const>(j_shape.at(0))),
211+
static_cast<bst_feature_t>(get<Integer const>(j_shape.at(1)))};
210212
}
211213
}
212214
template <typename T>
@@ -219,7 +221,6 @@ class ArrayInterfaceHandler {
219221
CHECK_EQ(typestr.at(2), static_cast<char>(sizeof(T) + 48))
220222
<< "Input data type and typestr mismatch. typestr: " << typestr;
221223

222-
223224
auto shape = ExtractShape(column);
224225

225226
T* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<T*>(column);
@@ -231,8 +232,8 @@ class ArrayInterfaceHandler {
231232
class ArrayInterface {
232233
public:
233234
ArrayInterface() = default;
234-
explicit ArrayInterface(std::map<std::string, Json> const &column,
235-
bool allow_mask = true) {
235+
void Initialize(std::map<std::string, Json> const &column,
236+
bool allow_mask = true) {
236237
ArrayInterfaceHandler::Validate(column);
237238
data = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(column);
238239
CHECK(data) << "Column is null";
@@ -263,6 +264,25 @@ class ArrayInterface {
263264
this->CheckType();
264265
}
265266

267+
explicit ArrayInterface(std::string const& str, bool allow_mask = true) {
268+
auto jinterface = Json::Load({str.c_str(), str.size()});
269+
if (IsA<Object>(jinterface)) {
270+
this->Initialize(get<Object const>(jinterface), allow_mask);
271+
return;
272+
}
273+
if (IsA<Array>(jinterface)) {
274+
CHECK_EQ(get<Array const>(jinterface).size(), 1)
275+
<< "Column: " << ArrayInterfaceErrors::Dimension(1);
276+
this->Initialize(get<Object const>(get<Array const>(jinterface)[0]), allow_mask);
277+
return;
278+
}
279+
}
280+
281+
explicit ArrayInterface(std::map<std::string, Json> const &column,
282+
bool allow_mask = true) {
283+
this->Initialize(column, allow_mask);
284+
}
285+
266286
void CheckType() const {
267287
if (type[1] == 'f' && type[2] == '4') {
268288
return;
@@ -291,6 +311,7 @@ class ArrayInterface {
291311
}
292312

293313
XGBOOST_DEVICE float GetElement(size_t idx) const {
314+
SPAN_CHECK(idx < num_cols * num_rows);
294315
if (type[1] == 'f' && type[2] == '4') {
295316
return reinterpret_cast<float*>(data)[idx];
296317
} else if (type[1] == 'f' && type[2] == '8') {
@@ -318,8 +339,8 @@ class ArrayInterface {
318339
}
319340

320341
RBitField8 valid;
321-
int32_t num_rows;
322-
int32_t num_cols;
342+
bst_row_t num_rows;
343+
bst_feature_t num_cols;
323344
void* data;
324345
char type[3];
325346
};

src/data/data.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
6363
auto const& j_arr = get<Array>(j_interface);
6464
CHECK_EQ(j_arr.size(), 1)
6565
<< "MetaInfo: " << c_key << ". " << ArrayInterfaceErrors::Dimension(1);
66-
ArrayInterface array_interface(get<Object const>(j_arr[0]));
66+
ArrayInterface array_interface(interface_str);
6767
std::string key{c_key};
6868
CHECK(!array_interface.valid.Data())
6969
<< "Meta info " << key << " should be dense, found validity mask";
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*!
2+
* Copyright 2020 by XGBoost Contributors
3+
*/
4+
#include <gtest/gtest.h>
5+
#include <xgboost/host_device_vector.h>
6+
#include "../helpers.h"
7+
#include "../../../src/data/array_interface.h"
8+
9+
namespace xgboost {
10+
// An ArrayInterface built from a generated dense matrix must report the
// generator's shape and point at the host buffer held by `storage`.
TEST(ArrayInterface, Initialize) {
  constexpr size_t kRows = 10;
  constexpr size_t kCols = 10;

  HostDeviceVector<float> storage;
  auto generated =
      RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage);

  ArrayInterface view(generated);
  ASSERT_EQ(view.num_rows, kRows);
  ASSERT_EQ(view.num_cols, kCols);
  ASSERT_EQ(view.data, storage.ConstHostPointer());
}
19+
20+
// ExtractData<float> is expected to throw dmlc::Error until the column carries
// "version", "data" and "typestr" and its data pointer is non-null.  The
// column is filled in one field at a time, so the order of the steps matters.
TEST(ArrayInterface, Error) {
  constexpr size_t kRows = 16;
  constexpr size_t kCols = 10;

  Json column{Object()};
  std::vector<Json> shape_1d{Json(Integer(static_cast<Integer::Int>(kRows)))};
  column["shape"] = Array(shape_1d);

  std::vector<Json> null_data{
      Json(Integer(reinterpret_cast<Integer::Int>(nullptr))),
      Json(Boolean(false))};

  auto const& column_obj = get<Object>(column);
  // Every check below runs the same extraction against the current column.
  auto extract = [&column_obj]() {
    ArrayInterfaceHandler::ExtractData<float>(column_obj);
  };

  // missing version
  EXPECT_THROW(extract(), dmlc::Error);
  column["version"] = Integer(static_cast<Integer::Int>(1));

  // missing data
  EXPECT_THROW(extract(), dmlc::Error);
  column["data"] = null_data;

  // missing typestr
  EXPECT_THROW(extract(), dmlc::Error);
  column["typestr"] = String("<f4");

  // all fields present, but a null data pointer is not valid
  EXPECT_THROW(extract(), dmlc::Error);

  HostDeviceVector<float> storage;
  // The returned interface is unused here; the call is presumably kept for
  // its effect on `storage`, whose host pointer is read right below.
  auto generated =
      RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage);
  std::vector<Json> host_data{
      Json(Integer(reinterpret_cast<Integer::Int>(storage.ConstHostPointer()))),
      Json(Boolean(false))};
  column["data"] = host_data;
  EXPECT_NO_THROW(extract());
}
50+
51+
} // namespace xgboost

tests/cpp/helpers.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,13 @@ Json RandomDataGenerator::ArrayInterfaceImpl(HostDeviceVector<float> *storage,
182182
this->GenerateDense(storage);
183183
Json array_interface {Object()};
184184
array_interface["data"] = std::vector<Json>(2);
185-
array_interface["data"][0] = Integer(reinterpret_cast<int64_t>(storage->DevicePointer()));
185+
if (storage->DeviceCanRead()) {
186+
array_interface["data"][0] =
187+
Integer(reinterpret_cast<int64_t>(storage->ConstDevicePointer()));
188+
} else {
189+
array_interface["data"][0] =
190+
Integer(reinterpret_cast<int64_t>(storage->ConstHostPointer()));
191+
}
186192
array_interface["data"][1] = Boolean(false);
187193

188194
array_interface["shape"] = std::vector<Json>(2);

0 commit comments

Comments
 (0)