Skip to content

Commit 1102ab5

Browse files
authored
Decouple complex ndarray from stl/complex.h. (#485)
* Check that pass_complex64_const() accepts writable array. * Decouple complex ndarray from stl/complex.h.
1 parent 9c8096c commit 1102ab5

File tree

6 files changed

+29
-17
lines changed

6 files changed

+29
-17
lines changed

docs/ndarray.rst

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,8 @@ The following constraints are available
115115
- A scalar type (``float``, ``uint8_t``, etc.) constrains the representation
116116
of the ndarray.
117117

118-
Complex arrays (i.e., ones based on ``std::complex<float>`` or
119-
``std::complex<double>``) are supported but additionally require including
120-
the header file ``<nanobind/stl/complex.h>``.
118+
Complex arrays (e.g., based on ``std::complex<float>`` or
119+
``std::complex<double>``) are supported.
121120

122121
- This scalar type can be further annotated with ``const``, which is necessary
123122
if you plan to call nanobind functions with arrays that do not permit write

include/nanobind/nb_traits.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,28 @@ template <typename T> using is_class_caster_test = std::enable_if_t<T::IsClass>;
173173
template <typename Caster>
174174
constexpr bool is_class_caster_v = detail::detector<void, is_class_caster_test, Caster>::value;
175175

176+
// Primary template
177+
template<typename T, typename SFINAE = void>
178+
struct is_complex : std::false_type {};
179+
180+
// Specialization if `T` is complex, i.e., `T` has a member type `value_type`,
181+
// member functions `real()` and `imag()` that return such, and the size of
182+
// `T` is twice that of `value_type`.
183+
template<typename T>
184+
struct is_complex<T, std::enable_if_t<std::is_same_v<
185+
decltype(std::declval<T>().real()),
186+
typename T::value_type>
187+
&& std::is_same_v<
188+
decltype(std::declval<T>().imag()),
189+
typename T::value_type>
190+
&& (sizeof(T) ==
191+
2 * sizeof(typename T::value_type))>>
192+
: std::true_type {};
193+
194+
/// True if the type `T` is a complete type representing a complex number.
195+
template<typename T>
196+
inline constexpr bool is_complex_v = is_complex<T>::value;
197+
176198
NAMESPACE_END(detail)
177199

178200
template <typename... Args>

include/nanobind/ndarray.h

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,6 @@ struct dltensor {
6565

6666
NAMESPACE_END(dlpack)
6767

68-
NAMESPACE_BEGIN(detail)
69-
70-
template <typename T> struct is_complex : std::false_type { };
71-
72-
NAMESPACE_END(detail)
73-
7468
template <ssize_t... Is> struct shape {
7569
static_assert(
7670
((Is >= 0 || Is == -1) && ...),
@@ -89,7 +83,7 @@ struct jax { };
8983
struct ro { };
9084

9185
template <typename T> struct ndarray_traits {
92-
static constexpr bool is_complex = detail::is_complex<T>::value;
86+
static constexpr bool is_complex = detail::is_complex_v<T>;
9387
static constexpr bool is_float = std::is_floating_point_v<T>;
9488
static constexpr bool is_bool = std::is_same_v<std::remove_cv_t<T>, bool>;
9589
static constexpr bool is_int = std::is_integral_v<T> && !is_bool;

include/nanobind/stl/complex.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,6 @@
1515
NAMESPACE_BEGIN(NB_NAMESPACE)
1616
NAMESPACE_BEGIN(detail)
1717

18-
template <typename T> struct is_complex;
19-
template<typename T> struct is_complex<const T> : is_complex<T> {};
20-
template<typename T> struct is_complex<std::complex<T>> : std::true_type {};
21-
2218
template <typename T> struct type_caster<std::complex<T>> {
2319
NB_TYPE_CASTER(std::complex<T>, const_name("complex"))
2420

tests/test_ndarray.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#include <nanobind/nanobind.h>
22
#include <nanobind/ndarray.h>
3-
#include <nanobind/stl/complex.h>
43
#include <algorithm>
4+
#include <complex>
55
#include <vector>
66

77
namespace nb = nanobind;
@@ -21,7 +21,7 @@ namespace nanobind {
2121
static constexpr bool is_int = false;
2222
static constexpr bool is_signed = true;
2323
};
24-
};
24+
}
2525
#endif
2626

2727
NB_MODULE(test_ndarray_ext, m) {

tests/test_ndarray.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def test03_constrain_dtype():
8989
t.pass_uint32(a_u32)
9090
t.pass_float32(a_f32)
9191
t.pass_complex64(a_cf64)
92+
t.pass_complex64_const(a_cf64)
9293
t.pass_bool(a_bool)
9394

9495
a_f32_const = a_f32.copy()
@@ -646,7 +647,7 @@ def test_uint32_complex_do_not_convert(variant):
646647
assert np.all(data == data2)
647648

648649
@needs_numpy
649-
def test26_check_generic():
650+
def test36_check_generic():
650651
class DLPackWrapper:
651652
def __init__(self, o):
652653
self.o = o

0 commit comments

Comments
 (0)