Skip to content

Commit 5645a03

Browse files
authored
Add a md5sum check before overwriting a cached prism file (#34715)
* Check md5 of downloaded/unzipped file and cached before overwriting. * Allow prism location to be a path from supported filesystems like GCS. * Apply yapf
1 parent 7ea4f06 commit 5645a03

File tree

2 files changed

+66
-21
lines changed

2 files changed

+66
-21
lines changed

sdks/python/apache_beam/runners/portability/prism_runner.py

+43-7
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
# sunset it
2323
from __future__ import annotations
2424

25+
import hashlib
2526
import logging
2627
import os
2728
import platform
@@ -81,6 +82,35 @@ def create_job_service_handle(self, job_service, options):
8182
job_service, options, retain_unknown_options=True)
8283

8384

85+
def _md5sum(filename, block_size=8192) -> str:
86+
md5 = hashlib.md5()
87+
with open(filename, 'rb') as f:
88+
while True:
89+
data = f.read(block_size)
90+
if not data:
91+
break
92+
md5.update(data)
93+
return md5.hexdigest()
94+
95+
96+
def _rename_if_different(src, dst):
97+
assert (os.path.isfile(src))
98+
99+
if os.path.isfile(dst):
100+
if _md5sum(src) != _md5sum(dst):
101+
# Remove existing binary to prevent exception on Windows during
102+
# os.rename.
103+
# See: https://docs.python.org/3/library/os.html#os.rename
104+
os.remove(dst)
105+
os.rename(src, dst)
106+
else:
107+
_LOGGER.info(
108+
'Found %s and %s with the same md5. Skipping overwrite.' % (src, dst))
109+
os.remove(src)
110+
else:
111+
os.rename(src, dst)
112+
113+
84114
class PrismJobServer(job_server.SubprocessJobServer):
85115
PRISM_CACHE = os.path.expanduser("~/.apache_beam/cache/prism")
86116
BIN_CACHE = os.path.expanduser("~/.apache_beam/cache/prism/bin")
@@ -117,7 +147,13 @@ def maybe_unzip_and_make_executable(
117147
# True (cache disabled)
118148
_LOGGER.info("Unzipping prism from %s to %s" % (url, target_url))
119149
z = zipfile.ZipFile(url)
120-
target_url = z.extract(target, path=bin_cache)
150+
151+
bin_cache_tmp = os.path.join(bin_cache, 'tmp')
152+
if not os.path.exists(bin_cache_tmp):
153+
os.makedirs(bin_cache_tmp)
154+
target_tmp_url = z.extract(target, path=bin_cache_tmp)
155+
156+
_rename_if_different(target_tmp_url, target_url)
121157
else:
122158
target_url = url
123159

@@ -154,12 +190,8 @@ def local_bin(
154190
url_read = urlopen(url)
155191
with open(cached_file + '.tmp', 'wb') as zip_write:
156192
shutil.copyfileobj(url_read, zip_write, length=1 << 20)
157-
if os.path.isfile(cached_file):
158-
# Remove existing binary to prevent exception on Windows during
159-
# os.rename.
160-
# See: https://docs.python.org/3/library/os.html#os.rename
161-
os.remove(cached_file)
162-
os.rename(cached_file + '.tmp', cached_file)
193+
194+
_rename_if_different(cached_file + '.tmp', cached_file)
163195
except URLError as e:
164196
raise RuntimeError(
165197
'Unable to fetch remote prism binary at %s: %s' % (url, e))
@@ -204,6 +236,10 @@ def path_to_binary(self) -> str:
204236
# The path is local and exists, use directly.
205237
return self._path
206238

239+
if FileSystems.exists(self._path):
240+
# The path is in one of the supported filesystems.
241+
return self._path
242+
207243
# Check if the path is a URL.
208244
url = urllib.parse.urlparse(self._path)
209245
if not url.scheme:

sdks/python/apache_beam/runners/portability/prism_runner_test.py

+23-14
Original file line numberDiff line numberDiff line change
@@ -245,32 +245,37 @@ def tearDown(self) -> None:
245245
rmtree(self.local_dir)
246246
pass
247247

248-
def _make_local_bin(self):
249-
with open(self.local_bin_path, 'wb'):
248+
def _make_local_bin(self, fn=None):
249+
fn = fn or self.local_bin_path
250+
with open(fn, 'wb'):
250251
pass
251252

252-
def _make_local_zip(self):
253-
with zipfile.ZipFile(self.local_zip_path, 'w', zipfile.ZIP_DEFLATED):
253+
def _make_local_zip(self, fn=None):
254+
fn = fn or self.local_zip_path
255+
with zipfile.ZipFile(fn, 'w', zipfile.ZIP_DEFLATED):
254256
pass
255257

256-
def _make_cache_bin(self):
257-
with open(self.cache_bin_path, 'wb'):
258+
def _make_cache_bin(self, fn=None):
259+
fn = fn or self.cache_bin_path
260+
with open(fn, 'wb'):
258261
pass
259262

260-
def _make_cache_zip(self):
261-
with zipfile.ZipFile(self.cache_zip_path, 'w', zipfile.ZIP_DEFLATED):
263+
def _make_cache_zip(self, fn=None):
264+
fn = fn or self.cache_zip_path
265+
with zipfile.ZipFile(fn, 'w', zipfile.ZIP_DEFLATED):
262266
pass
263267

264-
def _extract_side_effect(self, a, path=None):
268+
def _extract_side_effect(self, fn, path=None):
265269
if path is None:
266-
return a
270+
return fn
267271

272+
full_path = os.path.join(str(path), fn)
268273
if path.startswith(self.cache_dir):
269-
self._make_cache_bin()
274+
self._make_cache_bin(full_path)
270275
else:
271-
self._make_local_bin()
276+
self._make_local_bin(full_path)
272277

273-
return os.path.join(str(path), a)
278+
return full_path
274279

275280
@parameterized.expand([[True, True], [True, False], [False, True],
276281
[False, False]])
@@ -352,7 +357,11 @@ def test_with_remote_path(self, has_cache_bin, has_cache_zip, ignore_cache):
352357
with mock.patch(
353358
'apache_beam.runners.portability.prism_runner.urlopen') as mock_urlopen:
354359
mock_response = mock.MagicMock()
355-
mock_response.read.return_value = b''
360+
if has_cache_zip:
361+
with open(self.cache_zip_path, 'rb') as f:
362+
mock_response.read.side_effect = [f.read(), b'']
363+
else:
364+
mock_response.read.return_value = b''
356365
mock_urlopen.return_value = mock_response
357366
with mock.patch('zipfile.is_zipfile') as mock_is_zipfile:
358367
with mock.patch('zipfile.ZipFile') as mock_zipfile_init:

0 commit comments

Comments
 (0)