Skip to content

Commit aae4488

Browse files
authored
Add reducers for threading.Lock and EnumDescriptor. (#34537)
1 parent 81642eb commit aae4488

File tree

2 files changed

+81
-0
lines changed

2 files changed

+81
-0
lines changed

sdks/python/apache_beam/internal/cloudpickle_pickler.py

+66
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import base64
3131
import bz2
3232
import io
33+
import sys
3334
import threading
3435
import zlib
3536

@@ -40,10 +41,65 @@
4041
except (ImportError, ModuleNotFoundError):
4142
pass
4243

44+
45+
def _get_proto_enum_descriptor_class():
46+
try:
47+
from google.protobuf.internal import api_implementation
48+
except ImportError:
49+
return None
50+
51+
implementation_type = api_implementation.Type()
52+
53+
if implementation_type == 'upb':
54+
try:
55+
from google._upb._message import EnumDescriptor
56+
return EnumDescriptor
57+
except ImportError:
58+
pass
59+
elif implementation_type == 'cpp':
60+
try:
61+
from google.protobuf.pyext._message import EnumDescriptor
62+
return EnumDescriptor
63+
except ImportError:
64+
pass
65+
elif implementation_type == 'python':
66+
try:
67+
from google.protobuf.internal.python_message import EnumDescriptor
68+
return EnumDescriptor
69+
except ImportError:
70+
pass
71+
72+
return None
73+
74+
75+
EnumDescriptor = _get_proto_enum_descriptor_class()
76+
4377
# Pickling, especially unpickling, causes broken module imports on Python 3
4478
# if executed concurrently, see: BEAM-8651, http://bugs.python.org/issue38884.
4579
_pickle_lock = threading.RLock()
4680
RLOCK_TYPE = type(_pickle_lock)
81+
LOCK_TYPE = type(threading.Lock())
82+
83+
84+
def _reconstruct_enum_descriptor(full_name):
85+
for _, module in sys.modules.items():
86+
if not hasattr(module, 'DESCRIPTOR'):
87+
continue
88+
89+
for _, attr_value in vars(module).items():
90+
if not hasattr(attr_value, 'DESCRIPTOR'):
91+
continue
92+
93+
if hasattr(attr_value.DESCRIPTOR, 'enum_types_by_name'):
94+
for (_, enum_desc) in attr_value.DESCRIPTOR.enum_types_by_name.items():
95+
if enum_desc.full_name == full_name:
96+
return enum_desc
97+
raise ImportError(f'Could not find enum descriptor: {full_name}')
98+
99+
100+
def _pickle_enum_descriptor(obj):
101+
full_name = obj.full_name
102+
return _reconstruct_enum_descriptor, (full_name, )
47103

48104

49105
def dumps(o, enable_trace=True, use_zlib=False) -> bytes:
@@ -59,6 +115,12 @@ def dumps(o, enable_trace=True, use_zlib=False) -> bytes:
59115
pickler.dispatch_table[RLOCK_TYPE] = _pickle_rlock
60116
except NameError:
61117
pass
118+
try:
119+
pickler.dispatch_table[LOCK_TYPE] = _lock_reducer
120+
except NameError:
121+
pass
122+
if EnumDescriptor is not None:
123+
pickler.dispatch_table[EnumDescriptor] = _pickle_enum_descriptor
62124
pickler.dump(o)
63125
s = file.getvalue()
64126

@@ -106,6 +168,10 @@ def _pickle_rlock(obj):
106168
return RLOCK_TYPE, tuple([])
107169

108170

171+
def _lock_reducer(obj):
172+
return threading.Lock, tuple([])
173+
174+
109175
def dump_session(file_path):
110176
# It is possible to dump session with cloudpickle. However, since references
111177
# are saved it should not be necessary. See https://s.apache.org/beam-picklers

sdks/python/apache_beam/internal/cloudpickle_pickler_test.py

+15
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,21 @@
2626
from apache_beam.internal import module_test
2727
from apache_beam.internal.cloudpickle_pickler import dumps
2828
from apache_beam.internal.cloudpickle_pickler import loads
29+
from apache_beam.portability.api import beam_runner_api_pb2
2930

3031

3132
class PicklerTest(unittest.TestCase):
3233

3334
NO_MAPPINGPROXYTYPE = not hasattr(types, "MappingProxyType")
3435

36+
def test_pickle_enum_descriptor(self):
37+
TimeDomain = beam_runner_api_pb2.TimeDomain.Enum
38+
39+
def fn():
40+
return TimeDomain.EVENT_TIME
41+
42+
self.assertEqual(fn(), loads(dumps(fn))())
43+
3544
def test_basics(self):
3645
self.assertEqual([1, 'a', ('z', )], loads(dumps([1, 'a', ('z', )])))
3746
fun = lambda x: 'xyz-%s' % x
@@ -97,6 +106,12 @@ def test_pickle_rlock(self):
97106

98107
self.assertIsInstance(loads(dumps(rlock_instance)), rlock_type)
99108

109+
def test_pickle_lock(self):
110+
lock_instance = threading.Lock()
111+
lock_type = type(lock_instance)
112+
113+
self.assertIsInstance(loads(dumps(lock_instance)), lock_type)
114+
100115
@unittest.skipIf(NO_MAPPINGPROXYTYPE, 'test if MappingProxyType introduced')
101116
def test_dump_and_load_mapping_proxy(self):
102117
self.assertEqual(

0 commit comments

Comments
 (0)