Skip to content

Commit 81da6b9

Browse files
anandoleecopybara-github
authored andcommitted
Breaking Change: Python setdefault behavior change for map field.
-setdefault will be similar with dict for ScalarMap. But both key and value must be set. -setdefault will be rejected for MessageMap. PiperOrigin-RevId: 695768629
1 parent 54dc8c7 commit 81da6b9

File tree

4 files changed

+129
-2
lines changed

4 files changed

+129
-2
lines changed

python/google/protobuf/internal/containers.py

+13
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,13 @@ def __iter__(self) -> Iterator[_K]:
412412
def __repr__(self) -> str:
413413
return repr(self._values)
414414

415+
def setdefault(self, key: _K, value: Optional[_V] = None) -> _V:
416+
if value == None:
417+
raise ValueError('The value for scalar map setdefault must be set.')
418+
if key not in self._values:
419+
self.__setitem__(key, value)
420+
return self[key]
421+
415422
def MergeFrom(self, other: 'ScalarMap[_K, _V]') -> None:
416423
self._values.update(other._values)
417424
self._message_listener.Modified()
@@ -526,6 +533,12 @@ def __iter__(self) -> Iterator[_K]:
526533
def __repr__(self) -> str:
527534
return repr(self._values)
528535

536+
def setdefault(self, key: _K, value: Optional[_V] = None) -> _V:
537+
raise NotImplementedError(
538+
'Set message map value directly is not supported, call'
539+
' my_map[key].foo = 5'
540+
)
541+
529542
def MergeFrom(self, other: 'MessageMap[_K, _V]') -> None:
530543
# pylint: disable=protected-access
531544
for key in other._values:

python/google/protobuf/internal/message_test.py

+28
Original file line numberDiff line numberDiff line change
@@ -1900,13 +1900,41 @@ def testScalarMapComparison(self):
19001900

19011901
self.assertEqual(msg1.map_int32_int32, msg2.map_int32_int32)
19021902

1903+
def testScalarMapSetdefault(self):
1904+
msg = map_unittest_pb2.TestMap()
1905+
value = msg.map_int32_int32.setdefault(123, 888)
1906+
self.assertEqual(value, 888)
1907+
self.assertEqual(msg.map_int32_int32[123], 888)
1908+
value = msg.map_int32_int32.setdefault(123, 777)
1909+
self.assertEqual(value, 888)
1910+
1911+
with self.assertRaises(ValueError):
1912+
value = msg.map_int32_int32.setdefault(1001)
1913+
self.assertNotIn(1001, msg.map_int32_int32)
1914+
with self.assertRaises(TypeError):
1915+
value = msg.map_int32_int32.setdefault()
1916+
with self.assertRaises(TypeError):
1917+
value = msg.map_int32_int32.setdefault(1, 2, 3)
1918+
with self.assertRaises(TypeError):
1919+
value = msg.map_int32_int32.setdefault("1", 2)
1920+
with self.assertRaises(TypeError):
1921+
value = msg.map_int32_int32.setdefault(1, "2")
1922+
19031923
def testMessageMapComparison(self):
19041924
msg1 = map_unittest_pb2.TestMap()
19051925
msg2 = map_unittest_pb2.TestMap()
19061926

19071927
self.assertEqual(msg1.map_int32_foreign_message,
19081928
msg2.map_int32_foreign_message)
19091929

1930+
def testMessageMapSetdefault(self):
1931+
msg = map_unittest_pb2.TestMap()
1932+
msg.map_int32_foreign_message[123].c = 888
1933+
with self.assertRaises(NotImplementedError):
1934+
msg.map_int32_foreign_message.setdefault(
1935+
1, msg.map_int32_foreign_message[123]
1936+
)
1937+
19101938
def testMapGet(self):
19111939
# Need to test that get() properly returns the default, even though the dict
19121940
# has defaultdict-like semantics.

python/google/protobuf/pyext/map_container.cc

+39
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,35 @@ int MapReflectionFriend::ScalarMapSetItem(PyObject* _self, PyObject* key,
438438
}
439439
}
440440

441+
static PyObject* ScalarMapSetdefault(PyObject* self, PyObject* args) {
442+
PyObject* key = nullptr;
443+
PyObject* default_value = Py_None;
444+
445+
if (!PyArg_UnpackTuple(args, "setdefault", 1, 2, &key, &default_value)) {
446+
return nullptr;
447+
}
448+
449+
if (default_value == Py_None) {
450+
PyErr_Format(PyExc_ValueError,
451+
"The value for scalar map setdefault must be set.");
452+
return nullptr;
453+
}
454+
455+
ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
456+
if (is_present == nullptr) {
457+
return nullptr;
458+
}
459+
if (PyObject_IsTrue(is_present.get())) {
460+
return MapReflectionFriend::ScalarMapGetItem(self, key);
461+
}
462+
463+
if (MapReflectionFriend::ScalarMapSetItem(self, key, default_value) < 0) {
464+
return nullptr;
465+
}
466+
Py_INCREF(default_value);
467+
return default_value;
468+
}
469+
441470
static PyObject* ScalarMapGet(PyObject* self, PyObject* args,
442471
PyObject* kwargs) {
443472
static const char* kwlist[] = {"key", "default", nullptr};
@@ -512,6 +541,8 @@ static PyMethodDef ScalarMapMethods[] = {
512541
"Tests whether a key is a member of the map."},
513542
{"clear", (PyCFunction)Clear, METH_NOARGS,
514543
"Removes all elements from the map."},
544+
{"setdefault", (PyCFunction)ScalarMapSetdefault, METH_VARARGS,
545+
"If the key does not exist, insert the key, with the specified value"},
515546
{"get", (PyCFunction)ScalarMapGet, METH_VARARGS | METH_KEYWORDS,
516547
"Gets the value for the given key if present, or otherwise a default"},
517548
{"GetEntryClass", (PyCFunction)GetEntryClass, METH_NOARGS,
@@ -685,6 +716,12 @@ PyObject* MapReflectionFriend::MessageMapToStr(PyObject* _self) {
685716
return PyObject_Repr(dict.get());
686717
}
687718

719+
static PyObject* MessageMapSetdefault(PyObject* self, PyObject* args) {
720+
PyErr_Format(PyExc_NotImplementedError,
721+
"Set message map value directly is not supported.");
722+
return nullptr;
723+
}
724+
688725
PyObject* MessageMapGet(PyObject* self, PyObject* args, PyObject* kwargs) {
689726
static const char* kwlist[] = {"key", "default", nullptr};
690727
PyObject* key;
@@ -729,6 +766,8 @@ static PyMethodDef MessageMapMethods[] = {
729766
"Tests whether the map contains this element."},
730767
{"clear", (PyCFunction)Clear, METH_NOARGS,
731768
"Removes all elements from the map."},
769+
{"setdefault", (PyCFunction)MessageMapSetdefault, METH_VARARGS,
770+
"setdefault is disallowed in MessageMap."},
732771
{"get", (PyCFunction)MessageMapGet, METH_VARARGS | METH_KEYWORDS,
733772
"Gets the value for the given key if present, or otherwise a default"},
734773
{"get_or_create", MapReflectionFriend::MessageMapGetItem, METH_O,

python/map.c

+49-2
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,48 @@ static PyObject* PyUpb_MapContainer_Clear(PyObject* _self, PyObject* key) {
217217
Py_RETURN_NONE;
218218
}
219219

220+
static PyObject* PyUpb_ScalarMapContainer_Setdefault(PyObject* _self,
221+
PyObject* args) {
222+
PyObject* key;
223+
PyObject* default_value = Py_None;
224+
225+
if (!PyArg_UnpackTuple(args, "setdefault", 1, 2, &key, &default_value)) {
226+
return NULL;
227+
}
228+
229+
if (default_value == Py_None) {
230+
PyErr_Format(PyExc_ValueError,
231+
"The value for scalar map setdefault must be set.");
232+
return NULL;
233+
}
234+
235+
PyUpb_MapContainer* self = (PyUpb_MapContainer*)_self;
236+
upb_Map* map = PyUpb_MapContainer_EnsureReified(_self);
237+
const upb_FieldDef* f = PyUpb_MapContainer_GetField(self);
238+
const upb_MessageDef* entry_m = upb_FieldDef_MessageSubDef(f);
239+
const upb_FieldDef* key_f = upb_MessageDef_Field(entry_m, 0);
240+
const upb_FieldDef* val_f = upb_MessageDef_Field(entry_m, 1);
241+
upb_MessageValue u_key, u_val;
242+
if (!PyUpb_PyToUpb(key, key_f, &u_key, NULL)) return NULL;
243+
if (upb_Map_Get(map, u_key, &u_val)) {
244+
return PyUpb_UpbToPy(u_val, val_f, self->arena);
245+
}
246+
247+
upb_Arena* arena = PyUpb_Arena_Get(self->arena);
248+
if (!PyUpb_PyToUpb(default_value, val_f, &u_val, arena)) return NULL;
249+
if (!PyUpb_MapContainer_Set(self, map, u_key, u_val, arena)) return NULL;
250+
251+
Py_INCREF(default_value);
252+
return default_value;
253+
}
254+
255+
static PyObject* PyUpb_MessageMapContainer_Setdefault(PyObject* self,
256+
PyObject* args) {
257+
PyErr_Format(PyExc_NotImplementedError,
258+
"Set message map value directly is not supported.");
259+
return NULL;
260+
}
261+
220262
static PyObject* PyUpb_MapContainer_Get(PyObject* _self, PyObject* args,
221263
PyObject* kwargs) {
222264
PyUpb_MapContainer* self = (PyUpb_MapContainer*)_self;
@@ -331,6 +373,9 @@ PyObject* PyUpb_MapContainer_GetOrCreateWrapper(upb_Map* map,
331373
static PyMethodDef PyUpb_ScalarMapContainer_Methods[] = {
332374
{"clear", PyUpb_MapContainer_Clear, METH_NOARGS,
333375
"Removes all elements from the map."},
376+
{"setdefault", (PyCFunction)PyUpb_ScalarMapContainer_Setdefault,
377+
METH_VARARGS,
378+
"If the key does not exist, insert the key, with the specified value"},
334379
{"get", (PyCFunction)PyUpb_MapContainer_Get, METH_VARARGS | METH_KEYWORDS,
335380
"Gets the value for the given key if present, or otherwise a default"},
336381
{"GetEntryClass", PyUpb_MapContainer_GetEntryClass, METH_NOARGS,
@@ -373,6 +418,8 @@ static PyType_Spec PyUpb_ScalarMapContainer_Spec = {
373418
static PyMethodDef PyUpb_MessageMapContainer_Methods[] = {
374419
{"clear", PyUpb_MapContainer_Clear, METH_NOARGS,
375420
"Removes all elements from the map."},
421+
{"setdefault", (PyCFunction)PyUpb_MessageMapContainer_Setdefault,
422+
METH_VARARGS, "setdefault is disallowed in MessageMap."},
376423
{"get", (PyCFunction)PyUpb_MapContainer_Get, METH_VARARGS | METH_KEYWORDS,
377424
"Gets the value for the given key if present, or otherwise a default"},
378425
{"get_or_create", PyUpb_MapContainer_Subscript, METH_O,
@@ -480,8 +527,8 @@ bool PyUpb_Map_Init(PyObject* m) {
480527
PyObject* base = GetMutableMappingBase();
481528
if (!base) return false;
482529

483-
const char* methods[] = {"keys", "items", "values", "__eq__", "__ne__",
484-
"pop", "popitem", "update", "setdefault", NULL};
530+
const char* methods[] = {"keys", "items", "values", "__eq__", "__ne__",
531+
"pop", "popitem", "update", NULL};
485532

486533
state->message_map_container_type = PyUpb_AddClassWithRegister(
487534
m, &PyUpb_MessageMapContainer_Spec, base, methods);

0 commit comments

Comments
 (0)