1
1
/* !
2
2
* Copyright 2019 by Contributors
3
3
* \file array_interface.h
4
- * \brief Basic structure holding a reference to arrow columnar data format.
4
+ * \brief View of __array_interface__
5
5
*/
6
6
#ifndef XGBOOST_DATA_ARRAY_INTERFACE_H_
7
7
#define XGBOOST_DATA_ARRAY_INTERFACE_H_
11
11
#include < string>
12
12
#include < utility>
13
13
14
+ #include " xgboost/base.h"
14
15
#include " xgboost/data.h"
15
16
#include " xgboost/json.h"
16
17
#include " xgboost/logging.h"
@@ -113,6 +114,7 @@ class ArrayInterfaceHandler {
113
114
get<Array const >(
114
115
obj.at (" data" ))
115
116
.at (0 ))));
117
+ CHECK (p_data);
116
118
return p_data;
117
119
}
118
120
@@ -186,7 +188,7 @@ class ArrayInterfaceHandler {
186
188
return 0 ;
187
189
}
188
190
189
- static std::pair<size_t , size_t > ExtractShape (
191
+ static std::pair<bst_row_t , bst_feature_t > ExtractShape (
190
192
std::map<std::string, Json> const & column) {
191
193
auto j_shape = get<Array const >(column.at (" shape" ));
192
194
auto typestr = get<String const >(column.at (" typestr" ));
@@ -201,12 +203,12 @@ class ArrayInterfaceHandler {
201
203
}
202
204
203
205
if (j_shape.size () == 1 ) {
204
- return {static_cast <size_t >(get<Integer const >(j_shape.at (0 ))), 1 };
206
+ return {static_cast <bst_row_t >(get<Integer const >(j_shape.at (0 ))), 1 };
205
207
} else {
206
208
CHECK_EQ (j_shape.size (), 2 )
207
209
<< " Only 1D or 2-D arrays currently supported." ;
208
- return {static_cast <size_t >(get<Integer const >(j_shape.at (0 ))),
209
- static_cast <size_t >(get<Integer const >(j_shape.at (1 )))};
210
+ return {static_cast <bst_row_t >(get<Integer const >(j_shape.at (0 ))),
211
+ static_cast <bst_feature_t >(get<Integer const >(j_shape.at (1 )))};
210
212
}
211
213
}
212
214
template <typename T>
@@ -219,7 +221,6 @@ class ArrayInterfaceHandler {
219
221
CHECK_EQ (typestr.at (2 ), static_cast <char >(sizeof (T) + 48 ))
220
222
<< " Input data type and typestr mismatch. typestr: " << typestr;
221
223
222
-
223
224
auto shape = ExtractShape (column);
224
225
225
226
T* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<T*>(column);
@@ -231,8 +232,8 @@ class ArrayInterfaceHandler {
231
232
class ArrayInterface {
232
233
public:
233
234
ArrayInterface () = default ;
234
- explicit ArrayInterface (std::map<std::string, Json> const &column,
235
- bool allow_mask = true ) {
235
+ void Initialize (std::map<std::string, Json> const &column,
236
+ bool allow_mask = true ) {
236
237
ArrayInterfaceHandler::Validate (column);
237
238
data = ArrayInterfaceHandler::GetPtrFromArrayData<void *>(column);
238
239
CHECK (data) << " Column is null" ;
@@ -263,6 +264,25 @@ class ArrayInterface {
263
264
this ->CheckType ();
264
265
}
265
266
267
+ explicit ArrayInterface (std::string const & str, bool allow_mask = true ) {
268
+ auto jinterface = Json::Load ({str.c_str (), str.size ()});
269
+ if (IsA<Object>(jinterface)) {
270
+ this ->Initialize (get<Object const >(jinterface), allow_mask);
271
+ return ;
272
+ }
273
+ if (IsA<Array>(jinterface)) {
274
+ CHECK_EQ (get<Array const >(jinterface).size (), 1 )
275
+ << " Column: " << ArrayInterfaceErrors::Dimension (1 );
276
+ this ->Initialize (get<Object const >(get<Array const >(jinterface)[0 ]), allow_mask);
277
+ return ;
278
+ }
279
+ }
280
+
281
+ explicit ArrayInterface (std::map<std::string, Json> const &column,
282
+ bool allow_mask = true ) {
283
+ this ->Initialize (column, allow_mask);
284
+ }
285
+
266
286
void CheckType () const {
267
287
if (type[1 ] == ' f' && type[2 ] == ' 4' ) {
268
288
return ;
@@ -291,6 +311,7 @@ class ArrayInterface {
291
311
}
292
312
293
313
XGBOOST_DEVICE float GetElement (size_t idx) const {
314
+ SPAN_CHECK (idx < num_cols * num_rows);
294
315
if (type[1 ] == ' f' && type[2 ] == ' 4' ) {
295
316
return reinterpret_cast <float *>(data)[idx];
296
317
} else if (type[1 ] == ' f' && type[2 ] == ' 8' ) {
@@ -318,8 +339,8 @@ class ArrayInterface {
318
339
}
319
340
320
341
RBitField8 valid;
321
- int32_t num_rows;
322
- int32_t num_cols;
342
+ bst_row_t num_rows;
343
+ bst_feature_t num_cols;
323
344
void * data;
324
345
char type[3 ];
325
346
};
0 commit comments