9
9
10
10
#include " google/protobuf/pyext/map_container.h"
11
11
12
+ #include < Python.h>
13
+
14
+ #include < cstddef>
12
15
#include < cstdint>
13
16
#include < memory>
14
17
#include < string>
@@ -30,7 +33,7 @@ namespace python {
30
33
class MapReflectionFriend {
31
34
public:
32
35
// Methods that are in common between the map types.
33
- static PyObject* Contains (PyObject* _self, PyObject* key);
36
+ static int Contains (PyObject* _self, PyObject* key);
34
37
static Py_ssize_t Length (PyObject* _self);
35
38
static PyObject* GetIterator (PyObject* _self);
36
39
static PyObject* IterNext (PyObject* _self);
@@ -328,7 +331,7 @@ PyObject* MapReflectionFriend::MergeFrom(PyObject* _self, PyObject* arg) {
328
331
Py_RETURN_NONE;
329
332
}
330
333
331
- PyObject* MapReflectionFriend::Contains (PyObject* _self, PyObject* key) {
334
+ int MapReflectionFriend::Contains (PyObject* _self, PyObject* key) {
332
335
MapContainer* self = GetMap (_self);
333
336
334
337
const Message* message = self->parent ->message ;
@@ -337,14 +340,14 @@ PyObject* MapReflectionFriend::Contains(PyObject* _self, PyObject* key) {
337
340
MapKey map_key;
338
341
339
342
if (!PythonToMapKey (self, key, &map_key, &map_key_string)) {
340
- return nullptr ;
343
+ return - 1 ;
341
344
}
342
345
343
346
if (reflection->ContainsMapKey (*message, self->parent_field_descriptor ,
344
347
map_key)) {
345
- Py_RETURN_TRUE ;
348
+ return 1 ;
346
349
} else {
347
- Py_RETURN_FALSE ;
350
+ return 0 ;
348
351
}
349
352
}
350
353
@@ -450,11 +453,11 @@ static PyObject* ScalarMapSetdefault(PyObject* self, PyObject* args) {
450
453
return nullptr ;
451
454
}
452
455
453
- ScopedPyObjectPtr is_present ( MapReflectionFriend::Contains (self, key) );
454
- if (is_present == nullptr ) {
456
+ int is_present = MapReflectionFriend::Contains (self, key);
457
+ if (is_present < 0 ) {
455
458
return nullptr ;
456
459
}
457
- if (PyObject_IsTrue ( is_present. get ()) ) {
460
+ if (is_present) {
458
461
return MapReflectionFriend::ScalarMapGetItem (self, key);
459
462
}
460
463
@@ -476,12 +479,12 @@ static PyObject* ScalarMapGet(PyObject* self, PyObject* args,
476
479
return nullptr ;
477
480
}
478
481
479
- ScopedPyObjectPtr is_present ( MapReflectionFriend::Contains (self, key) );
480
- if (is_present. get () == nullptr ) {
482
+ int is_present = MapReflectionFriend::Contains (self, key);
483
+ if (is_present < 0 ) {
481
484
return nullptr ;
482
485
}
483
486
484
- if (PyObject_IsTrue ( is_present. get ()) ) {
487
+ if (is_present) {
485
488
return MapReflectionFriend::ScalarMapGetItem (self, key);
486
489
} else {
487
490
if (default_value != nullptr ) {
@@ -534,8 +537,6 @@ static void ScalarMapDealloc(PyObject* _self) {
534
537
}
535
538
536
539
static PyMethodDef ScalarMapMethods[] = {
537
- {" __contains__" , MapReflectionFriend::Contains, METH_O,
538
- " Tests whether a key is a member of the map." },
539
540
{" clear" , (PyCFunction)Clear, METH_NOARGS,
540
541
" Removes all elements from the map." },
541
542
{" setdefault" , (PyCFunction)ScalarMapSetdefault, METH_VARARGS,
@@ -561,6 +562,7 @@ static PyType_Slot ScalarMapContainer_Type_slots[] = {
561
562
{Py_mp_length, (void *)MapReflectionFriend::Length},
562
563
{Py_mp_subscript, (void *)MapReflectionFriend::ScalarMapGetItem},
563
564
{Py_mp_ass_subscript, (void *)MapReflectionFriend::ScalarMapSetItem},
565
+ {Py_sq_contains, (void *)MapReflectionFriend::Contains},
564
566
{Py_tp_methods, (void *)ScalarMapMethods},
565
567
{Py_tp_iter, (void *)MapReflectionFriend::GetIterator},
566
568
{Py_tp_repr, (void *)MapReflectionFriend::ScalarMapToStr},
@@ -727,12 +729,12 @@ PyObject* MessageMapGet(PyObject* self, PyObject* args, PyObject* kwargs) {
727
729
return nullptr ;
728
730
}
729
731
730
- ScopedPyObjectPtr is_present ( MapReflectionFriend::Contains (self, key) );
731
- if (is_present. get () == nullptr ) {
732
+ int is_present = MapReflectionFriend::Contains (self, key);
733
+ if (is_present < 0 ) {
732
734
return nullptr ;
733
735
}
734
736
735
- if (PyObject_IsTrue ( is_present. get ()) ) {
737
+ if (is_present) {
736
738
return MapReflectionFriend::MessageMapGetItem (self, key);
737
739
} else {
738
740
if (default_value != nullptr ) {
@@ -757,8 +759,6 @@ static void MessageMapDealloc(PyObject* _self) {
757
759
}
758
760
759
761
static PyMethodDef MessageMapMethods[] = {
760
- {" __contains__" , (PyCFunction)MapReflectionFriend::Contains, METH_O,
761
- " Tests whether the map contains this element." },
762
762
{" clear" , (PyCFunction)Clear, METH_NOARGS,
763
763
" Removes all elements from the map." },
764
764
{" setdefault" , (PyCFunction)MessageMapSetdefault, METH_VARARGS,
@@ -786,6 +786,7 @@ static PyType_Slot MessageMapContainer_Type_slots[] = {
786
786
{Py_mp_length, (void *)MapReflectionFriend::Length},
787
787
{Py_mp_subscript, (void *)MapReflectionFriend::MessageMapGetItem},
788
788
{Py_mp_ass_subscript, (void *)MapReflectionFriend::MessageMapSetItem},
789
+ {Py_sq_contains, (void *)MapReflectionFriend::Contains},
789
790
{Py_tp_methods, (void *)MessageMapMethods},
790
791
{Py_tp_iter, (void *)MapReflectionFriend::GetIterator},
791
792
{Py_tp_repr, (void *)MapReflectionFriend::MessageMapToStr},
@@ -910,6 +911,30 @@ PyTypeObject MapIterator_Type = {
910
911
nullptr , // tp_init
911
912
};
912
913
914
+ PyTypeObject* Py_AddClassWithRegister (PyType_Spec* spec, PyObject* virtual_base,
915
+ const char ** methods) {
916
+ PyObject* type = PyType_FromSpec (spec);
917
+ PyObject* ret1 = PyObject_CallMethod (virtual_base, " register" , " O" , type);
918
+ if (!ret1) {
919
+ Py_XDECREF (type);
920
+ return nullptr ;
921
+ }
922
+ for (size_t i = 0 ; methods[i] != nullptr ; i++) {
923
+ PyObject* method = PyObject_GetAttrString (virtual_base, methods[i]);
924
+ if (!method) {
925
+ Py_XDECREF (type);
926
+ return nullptr ;
927
+ }
928
+ int ret2 = PyObject_SetAttrString (type, methods[i], method);
929
+ if (ret2 < 0 ) {
930
+ Py_XDECREF (type);
931
+ return nullptr ;
932
+ }
933
+ }
934
+
935
+ return (PyTypeObject*)type;
936
+ }
937
+
913
938
bool InitMapContainers () {
914
939
// ScalarMapContainer_Type derives from our MutableMapping type.
915
940
ScopedPyObjectPtr abc (PyImport_ImportModule (" collections.abc" ));
@@ -923,21 +948,20 @@ bool InitMapContainers() {
923
948
return false ;
924
949
}
925
950
926
- Py_INCREF (mutable_mapping.get ());
927
- ScopedPyObjectPtr bases (PyTuple_Pack (1 , mutable_mapping.get ()));
928
- if (bases == nullptr ) {
929
- return false ;
930
- }
951
+ const char * methods[] = {" keys" , " items" , " values" , " __eq__" , " __ne__" ,
952
+ " pop" , " popitem" , " update" , nullptr };
931
953
932
- ScalarMapContainer_Type = reinterpret_cast <PyTypeObject*>(
933
- PyType_FromSpecWithBases (&ScalarMapContainer_Type_spec, bases.get ()));
954
+ ScalarMapContainer_Type =
955
+ reinterpret_cast <PyTypeObject*>(Py_AddClassWithRegister (
956
+ &ScalarMapContainer_Type_spec, mutable_mapping.get (), methods));
934
957
935
958
if (PyType_Ready (&MapIterator_Type) < 0 ) {
936
959
return false ;
937
960
}
938
961
939
- MessageMapContainer_Type = reinterpret_cast <PyTypeObject*>(
940
- PyType_FromSpecWithBases (&MessageMapContainer_Type_spec, bases.get ()));
962
+ MessageMapContainer_Type =
963
+ reinterpret_cast <PyTypeObject*>(Py_AddClassWithRegister (
964
+ &MessageMapContainer_Type_spec, mutable_mapping.get (), methods));
941
965
return true ;
942
966
}
943
967
0 commit comments