Add option to pickler dumps() for best-effort determinism #34698

Merged: 5 commits, Apr 25, 2025
10 changes: 9 additions & 1 deletion sdks/python/apache_beam/internal/cloudpickle_pickler.py
@@ -107,8 +107,16 @@ def _pickle_enum_descriptor(obj):
   return _reconstruct_enum_descriptor, (full_name, )
 
 
-def dumps(o, enable_trace=True, use_zlib=False) -> bytes:
+def dumps(
+    o,
+    enable_trace=True,
+    use_zlib=False,
+    enable_best_effort_determinism=False) -> bytes:
   """For internal use only; no backwards-compatibility guarantees."""
+  if enable_best_effort_determinism:
+    # TODO: Add support once https://github.com/cloudpipe/cloudpickle/pull/563
+    # is merged in.
+    raise NotImplementedError('This option has only been implemented for dill')
   with _pickle_lock:
     with io.BytesIO() as file:
       pickler = cloudpickle.CloudPickler(file)
4 changes: 4 additions & 0 deletions sdks/python/apache_beam/internal/cloudpickle_pickler_test.py
@@ -135,6 +135,10 @@ def test_dataclass(self):
 self.assertEqual(DataClass(datum='abc'), loads(dumps(DataClass(datum='abc'))))
 ''')
 
+  def test_best_effort_determinism_not_implemented(self):
+    with self.assertRaises(NotImplementedError):
+      dumps(123, enable_best_effort_determinism=True)
+
 
 if __name__ == '__main__':
   unittest.main()
17 changes: 16 additions & 1 deletion sdks/python/apache_beam/internal/dill_pickler.py
@@ -44,6 +44,9 @@
 
 import dill
 
+from apache_beam.internal.set_pickler import save_frozenset
+from apache_beam.internal.set_pickler import save_set
+
 settings = {'dill_byref': None}
 
 patch_save_code = sys.version_info >= (3, 10) and dill.__version__ == "0.3.1.1"
@@ -376,9 +379,18 @@ def new_log_info(msg, *args, **kwargs):
 logging.getLogger('dill').setLevel(logging.WARN)
 
 
-def dumps(o, enable_trace=True, use_zlib=False) -> bytes:
+def dumps(
+    o,
+    enable_trace=True,
+    use_zlib=False,
+    enable_best_effort_determinism=False) -> bytes:
   """For internal use only; no backwards-compatibility guarantees."""
   with _pickle_lock:
+    if enable_best_effort_determinism:
+      old_save_set = dill.dill.Pickler.dispatch[set]
+      old_save_frozenset = dill.dill.Pickler.dispatch[frozenset]
+      dill.dill.pickle(set, save_set)
+      dill.dill.pickle(frozenset, save_frozenset)
     try:
       s = dill.dumps(o, byref=settings['dill_byref'])
     except Exception:  # pylint: disable=broad-except
@@ -389,6 +401,9 @@ def dumps(o, enable_trace=True, use_zlib=False) -> bytes:
       raise
     finally:
       dill.dill._trace(False)  # pylint: disable=protected-access
+    if enable_best_effort_determinism:
+      dill.dill.pickle(set, old_save_set)
+      dill.dill.pickle(frozenset, old_save_frozenset)
 
   # Compress as compactly as possible (compresslevel=9) to decrease peak memory
   # usage (of multiple in-memory copies) and to avoid hitting protocol buffer
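The swap-and-restore around `dill.dill.Pickler.dispatch` above is the crux of the dill change: the deterministic reducers from `set_pickler` are installed only for the duration of the dump, then the stock ones are restored. As a rough illustration of the same idea using only the standard library (this is not Beam's implementation; `reducer_override` requires Python 3.8+, and sorting by `repr` is a stand-in for the PR's generic comparison):

import io
import pickle


class _DeterministicSetPickler(pickle.Pickler):
  """Illustrative only: serializes sets/frozensets in a stable element order."""
  def reducer_override(self, obj):
    # Rebuild sets from a sorted list so the write order is stable.
    # Sorting by repr() is a simplification of set_pickler's comparison.
    if type(obj) is set:
      return (set, (sorted(obj, key=repr), ))
    if type(obj) is frozenset:
      return (frozenset, (sorted(obj, key=repr), ))
    return NotImplemented  # Defer to normal pickling for everything else.


def _deterministic_dumps(obj) -> bytes:
  buffer = io.BytesIO()
  _DeterministicSetPickler(buffer).dump(obj)
  return buffer.getvalue()


assert _deterministic_dumps({'a', 'b'}) == _deterministic_dumps({'b', 'a'})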
11 changes: 9 additions & 2 deletions sdks/python/apache_beam/internal/pickler.py
@@ -38,10 +38,17 @@
 desired_pickle_lib = dill_pickler
 
 
-def dumps(o, enable_trace=True, use_zlib=False) -> bytes:
+def dumps(
+    o,
+    enable_trace=True,
+    use_zlib=False,
+    enable_best_effort_determinism=False) -> bytes:
 
   return desired_pickle_lib.dumps(
-      o, enable_trace=enable_trace, use_zlib=use_zlib)
+      o,
+      enable_trace=enable_trace,
+      use_zlib=use_zlib,
+      enable_best_effort_determinism=enable_best_effort_determinism)
 
 
 def loads(encoded, enable_trace=True, use_zlib=False):
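With the facade wired through, callers opt in per dumps() call. A minimal usage sketch (assuming the default dill-backed pickler is selected, as in the module above):

from apache_beam.internal import pickler

# Equal sets built in different orders may iterate differently; with the
# flag enabled they serialize to identical bytes.
payload_a = pickler.dumps(
    {'red', 'green', 'blue'}, enable_best_effort_determinism=True)
payload_b = pickler.dumps(
    {'blue', 'green', 'red'}, enable_best_effort_determinism=True)
assert payload_a == payload_b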
47 changes: 47 additions & 0 deletions sdks/python/apache_beam/internal/pickler_test.py
@@ -19,6 +19,7 @@
 
 # pytype: skip-file
 
+import random
 import sys
 import threading
 import types
@@ -115,6 +116,52 @@ def test_dataclass(self):
 self.assertEqual(DataClass(datum='abc'), loads(dumps(DataClass(datum='abc'))))
 ''')
 
+  def maybe_get_sets_with_different_iteration_orders(self):
+    # Use a mix of types in an attempt to create sets with the same elements
+    # whose iteration order is different.
+    elements = [
+        100,
+        'hello',
+        3.14159,
+        True,
+        None,
+        -50,
+        'world',
+        False, (1, 2), (4, 3), ('hello', 'world')
+    ]
+    set1 = set(elements)
+    # Try random addition orders until finding an order that works.
+    for _ in range(100):
+      set2 = set()
+      random.shuffle(elements)
+      for e in elements:
+        set2.add(e)
+      if list(set1) != list(set2):
+        break
+    return set1, set2
+
+  def test_best_effort_determinism(self):
+    set1, set2 = self.maybe_get_sets_with_different_iteration_orders()
+    self.assertEqual(
+        dumps(set1, enable_best_effort_determinism=True),
+        dumps(set2, enable_best_effort_determinism=True))
+    # The test relies on the sets having different iteration orders for the
+    # elements. Iteration order is implementation-dependent and undefined,
+    # meaning the test won't always be able to set up these conditions.
+    if list(set1) == list(set2):
+      self.skipTest('Set iteration orders matched. Test results inconclusive.')
+
+  def test_disable_best_effort_determinism(self):
+    set1, set2 = self.maybe_get_sets_with_different_iteration_orders()
+    # The test relies on the sets having different iteration orders for the
+    # elements. Iteration order is implementation-dependent and undefined,
+    # meaning the test won't always be able to set up these conditions.
+    if list(set1) == list(set2):
+      self.skipTest('Set iteration orders matched. Unable to complete test.')
+    self.assertNotEqual(
+        dumps(set1, enable_best_effort_determinism=False),
+        dumps(set2, enable_best_effort_determinism=False))
 
 
 if __name__ == '__main__':
   unittest.main()
164 changes: 164 additions & 0 deletions sdks/python/apache_beam/internal/set_pickler.py
@@ -0,0 +1,164 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Custom pickling logic for sets to make the serialization semi-deterministic.

To make set serialization semi-deterministic, we must pick an order for the set
elements. Sets may contain elements of types not defining a comparison "<"
operator. To provide an order, we define our own custom comparison function
which supports elements of near-arbitrary types and use that to sort the
contents of each set during serialization. Attempts at determinism are made on a
best-effort basis to improve hit rates for cached workflows and the ordering
does not define a total order for all values.
"""

import enum
import functools


def compare(lhs, rhs):
  """Returns -1, 0, or 1 depending on whether lhs <, =, or > rhs."""
  if lhs < rhs:
    return -1
  elif lhs > rhs:
    return 1
  else:
    return 0


def generic_object_comparison(lhs, rhs, lhs_path, rhs_path, max_depth):
  """Identifies which object goes first in an (almost) total order of objects.

  Args:
    lhs: An arbitrary Python object or built-in type.
    rhs: An arbitrary Python object or built-in type.
    lhs_path: Traversal path from the root lhs object up to, but not including,
      lhs. The original contents of lhs_path are restored before the function
      returns.
    rhs_path: Same as lhs_path, except for the rhs.
    max_depth: Maximum recursion depth.

  Returns:
    -1, 0, or 1 depending on whether lhs or rhs goes first in the total order.
    0 if max_depth is exhausted.
    0 if lhs is in lhs_path or rhs is in rhs_path (there is a cycle).
  """
  if id(lhs) == id(rhs):
    # Fast path.
    return 0
  if type(lhs) != type(rhs):
    return compare(str(type(lhs)), str(type(rhs)))
  if type(lhs) in [int, float, bool, str, bytes, bytearray]:
    return compare(lhs, rhs)
  if isinstance(lhs, enum.Enum):
    # Enums can have values with arbitrary types. The names are strings.
    return compare(lhs.name, rhs.name)

  # Cap the recursion depth to avoid exceeding Python's recursion limit.
  max_depth -= 1
  if max_depth < 0:
    return 0

  # Check for cycles in the traversal path to avoid getting stuck in a loop.
  if id(lhs) in lhs_path or id(rhs) in rhs_path:
    return 0
  lhs_path.append(id(lhs))
  rhs_path.append(id(rhs))
  # The comparison logic is split across two functions to simplify updating
  # and restoring the traversal paths.
  result = _generic_object_comparison_recursive_path(
      lhs, rhs, lhs_path, rhs_path, max_depth)
  lhs_path.pop()
  rhs_path.pop()
  return result


def _generic_object_comparison_recursive_path(
    lhs, rhs, lhs_path, rhs_path, max_depth):
  if type(lhs) == tuple or type(lhs) == list:
    result = compare(len(lhs), len(rhs))
    if result != 0:
      return result
    for i in range(len(lhs)):
      result = generic_object_comparison(
          lhs[i], rhs[i], lhs_path, rhs_path, max_depth)
      if result != 0:
        return result
    return 0
  if type(lhs) == frozenset or type(lhs) == set:
    return generic_object_comparison(
        tuple(sort_if_possible(lhs, lhs_path, rhs_path, max_depth)),
        tuple(sort_if_possible(rhs, lhs_path, rhs_path, max_depth)),
        lhs_path,
        rhs_path,
        max_depth)
  if type(lhs) == dict:
    lhs_keys = list(lhs.keys())
    rhs_keys = list(rhs.keys())
    result = compare(len(lhs_keys), len(rhs_keys))
    if result != 0:
      return result
    lhs_keys = sort_if_possible(lhs_keys, lhs_path, rhs_path, max_depth)
    rhs_keys = sort_if_possible(rhs_keys, lhs_path, rhs_path, max_depth)
    for lhs_key, rhs_key in zip(lhs_keys, rhs_keys):
      result = generic_object_comparison(
          lhs_key, rhs_key, lhs_path, rhs_path, max_depth)
      if result != 0:
        return result
      result = generic_object_comparison(
          lhs[lhs_key], rhs[rhs_key], lhs_path, rhs_path, max_depth)
      if result != 0:
        return result

  lhs_fields = dir(lhs)
  rhs_fields = dir(rhs)
  result = compare(len(lhs_fields), len(rhs_fields))
  if result != 0:
    return result
  for i in range(len(lhs_fields)):
    result = compare(lhs_fields[i], rhs_fields[i])
    if result != 0:
      return result
    result = generic_object_comparison(
        getattr(lhs, lhs_fields[i], None),
        getattr(rhs, rhs_fields[i], None),
        lhs_path,
        rhs_path,
        max_depth)
    if result != 0:
      return result
  return 0


def sort_if_possible(obj, lhs_path=None, rhs_path=None, max_depth=4):
  def cmp(lhs, rhs):
    if lhs_path is None:
      # Start the traversal at the root call to cmp.
      return generic_object_comparison(lhs, rhs, [], [], max_depth)
    else:
      # Continue the existing traversal path for recursive calls to cmp.
      return generic_object_comparison(lhs, rhs, lhs_path, rhs_path, max_depth)

  return sorted(obj, key=functools.cmp_to_key(cmp))


def save_set(pickler, obj):
  pickler.save_set(sort_if_possible(obj))


def save_frozenset(pickler, obj):
  pickler.save_frozenset(sort_if_possible(obj))
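To make the ordering rules concrete, here is how `sort_if_possible` behaves on a few small inputs (expected values reasoned from the comparison logic above, not taken from the PR's tests):

from apache_beam.internal.set_pickler import sort_if_possible

sort_if_possible({3, 1, 2})         # [1, 2, 3]: same type, natural order.
sort_if_possible({'x', 2, 1.5})     # [1.5, 2, 'x']: mixed types order by
                                    # str(type(...)), so float < int < str.
sort_if_possible({(1, 2), (0, 9)})  # [(0, 9), (1, 2)]: tuples compare
                                    # element-wise via the generic comparison.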