Skip to content

uploader: validate tensors before uploading them #3624

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tensorboard/uploader/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ py_test(
"//tensorboard:expect_grpc_installed",
"//tensorboard:expect_grpc_testing_installed",
"//tensorboard:expect_tensorflow_installed",
"//tensorboard/compat:no_tensorflow",
"//tensorboard/compat/proto:protos_all_py_pb2",
"//tensorboard/plugins/graph:metadata",
"//tensorboard/plugins/histogram:summary_v2",
Expand Down
17 changes: 17 additions & 0 deletions tensorboard/uploader/uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,12 +741,29 @@ def _create_point(self, tag_proto, event, value):
tag_proto.points.pop()
return

self._validate_tensor_value(
value.tensor, value.tag, event.step, event.wall_time
)

try:
self._byte_budget_manager.add_point(point)
except _OutOfSpaceError:
tag_proto.points.pop()
raise

def _validate_tensor_value(self, tensor_proto, tag, step, wall_time):
    """Validate a TensorProto by attempting to parse it.

    The proto is handed to `tensor_util.make_ndarray`; the resulting array
    is discarded, only the success or failure of the conversion matters.

    Args:
      tensor_proto: The tensor proto to check.
      tag: Tag of the value that owns the tensor; used only to build the
        error message.
      step: Step of the event that owns the tensor; used only to build the
        error message (formatted with `%d`).
      wall_time: Wall time of the event that owns the tensor; used only to
        build the error message (formatted with `%.6f`).

    Raises:
      ValueError: If `tensor_util.make_ndarray` raises `ValueError` for
        `tensor_proto`. The message names the tag, step and wall time and
        embeds the original error text; the original exception is attached
        as the explicit cause.
    """
    try:
        tensor_util.make_ndarray(tensor_proto)
    except ValueError as error:
        # Chain explicitly so the traceback presents the parse failure as
        # the direct cause of this more descriptive error.
        raise ValueError(
            "The uploader failed to upload a tensor. This seems to be "
            "due to a malformation in the tensor, which may be caused by "
            "a bug in the process that wrote the tensor.\n\n"
            "The tensor has tag '%s' and is at step %d and wall_time %.6f.\n\n"
            "Original error:\n%s" % (tag, step, wall_time, error)
        ) from error


class _ByteBudgetManager(object):
"""Helper class for managing the request byte budget for certain RPCs.
Expand Down
169 changes: 130 additions & 39 deletions tensorboard/uploader/uploader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import collections
import itertools
import os
import re

import grpc
import grpc_testing
Expand All @@ -36,6 +37,7 @@
from google.protobuf import message
from tensorboard import data_compat
from tensorboard import dataclass_compat
from tensorboard.compat.proto import tensor_shape_pb2
from tensorboard.uploader.proto import experiment_pb2
from tensorboard.uploader.proto import scalar_pb2
from tensorboard.uploader.proto import write_service_pb2
Expand Down Expand Up @@ -1234,10 +1236,6 @@ def test_histogram_event(self):
wall_time=123.456,
summary=histogram_v2.histogram_pb("foo", [1.0]),
)
# Simplify the tensor value a bit. We care that it is copied to the
# request but we don't need it to be an extensive test.
event.summary.value[0].tensor.ClearField("tensor_shape")
event.summary.value[0].tensor.ClearField("tensor_content")

run_proto = self._add_events_and_flush(_apply_compat([event]))
expected_run_proto = write_service_pb2.WriteTensorRequest.Run()
Expand All @@ -1250,13 +1248,65 @@ def test_histogram_event(self):
wall_time=test_util.timestamp_pb(123456000000),
value=tensor_pb2.TensorProto(dtype=types_pb2.DT_DOUBLE),
)
# Simplify the tensor value a bit before making assertions on it.
# We care that it is copied to the request but we don't need it to be
# an extensive test.
run_proto.tags[0].points[0].value.ClearField("tensor_shape")
run_proto.tags[0].points[0].value.ClearField("tensor_content")
self.assertProtoEquals(run_proto, expected_run_proto)

def test_histogram_event_with_empty_tensor_content_errors_out(self):
    """Adding a DT_DOUBLE tensor with empty content raises a ValueError.

    The expected error message must mention the upload failure, the
    malformation, the tag ('one') and the step (42), in that order.
    """
    event = event_pb2.Event(step=42)
    event.summary.value.add(
        tag="one",
        tensor=tensor_pb2.TensorProto(
            dtype=types_pb2.DT_DOUBLE,
            # Use empty tensor content to elicit an error.
            tensor_content=b"",
        ),
    )

    mock_client = _create_mock_client()
    sender = _create_tensor_request_sender("123", mock_client)
    # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`
    # alias. DOTALL lets `.*` span the newlines in the error message.
    with self.assertRaisesRegex(
        ValueError,
        re.compile(
            r"failed to upload a tensor.*malformation.*tag.*\'one\'.*step.*42",
            re.DOTALL,
        ),
    ):
        self._add_events(sender, "run", _apply_compat([event]))

def test_histogram_event_with_incorrect_tensor_shape_errors_out(self):
    """Adding a tensor whose declared shape has an extra dim raises.

    The expected error message must mention the upload failure, the
    malformation, the tag ('two'), the step (1337) and then "shape".
    """
    event = event_pb2.Event(step=1337)
    tensor_proto = tensor_util.make_tensor_proto([1.0, 2.0])
    # Add an extraneous dimension to the tensor shape in order to
    # elicit an error.
    tensor_proto.tensor_shape.dim.append(
        tensor_shape_pb2.TensorShapeProto.Dim(size=2)
    )
    event.summary.value.add(tag="two", tensor=tensor_proto)

    mock_client = _create_mock_client()
    sender = _create_tensor_request_sender("123", mock_client)
    # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`
    # alias. The pattern is the same string as before, only split at a
    # token boundary instead of between `.` and `*`.
    with self.assertRaisesRegex(
        ValueError,
        re.compile(
            r"failed to upload a tensor.*malformation.*"
            r"tag.*\'two\'.*step.*1337.*shape",
            re.DOTALL,
        ),
    ):
        self._add_events(sender, "run", _apply_compat([event]))

def test_aggregation_by_tag(self):
def make_event(step, wall_time, tag):
event = event_pb2.Event(step=step, wall_time=wall_time)
event.summary.value.add(
tag=tag, tensor=tensor_pb2.TensorProto(double_val=[1.0])
tag=tag,
tensor=tensor_pb2.TensorProto(
dtype=types_pb2.DT_DOUBLE, double_val=[1.0]
),
)
return event

Expand Down Expand Up @@ -1285,7 +1335,10 @@ def make_event(step, wall_time, tag):
def test_propagates_experiment_deletion(self):
event = event_pb2.Event(step=1)
event.summary.value.add(
tag="one", tensor=tensor_pb2.TensorProto(double_val=[1.0])
tag="one",
tensor=tensor_pb2.TensorProto(
dtype=types_pb2.DT_DOUBLE, double_val=[1.0]
),
)

mock_client = _create_mock_client()
Expand All @@ -1312,7 +1365,10 @@ def test_no_room_for_single_point(self):
mock_client = _create_mock_client()
event = event_pb2.Event(step=1)
event.summary.value.add(
tag="one", tensor=tensor_pb2.TensorProto(double_val=[1.0])
tag="one",
tensor=tensor_pb2.TensorProto(
dtype=types_pb2.DT_DOUBLE, double_val=[1.0]
),
)
long_run_name = "A" * uploader_lib._MAX_REQUEST_LENGTH_BYTES
with self.assertRaises(RuntimeError) as cm:
Expand All @@ -1328,11 +1384,17 @@ def test_break_at_run_boundary(self):
long_run_2 = "B" * 768
event_1 = event_pb2.Event(step=1)
event_1.summary.value.add(
tag="one", tensor=tensor_pb2.TensorProto(double_val=[1.0])
tag="one",
tensor=tensor_pb2.TensorProto(
dtype=types_pb2.DT_DOUBLE, double_val=[1.0]
),
)
event_2 = event_pb2.Event(step=2)
event_2.summary.value.add(
tag="two", tensor=tensor_pb2.TensorProto(double_val=[2.0])
tag="two",
tensor=tensor_pb2.TensorProto(
dtype=types_pb2.DT_DOUBLE, double_val=[2.0]
),
)

sender = _create_tensor_request_sender("123", mock_client)
Expand All @@ -1356,10 +1418,16 @@ def test_break_at_tag_boundary(self):
long_tag_2 = "b" * 600
event = event_pb2.Event(step=1, wall_time=1)
event.summary.value.add(
tag=long_tag_1, tensor=tensor_pb2.TensorProto(double_val=[1.0])
tag=long_tag_1,
tensor=tensor_pb2.TensorProto(
dtype=types_pb2.DT_DOUBLE, double_val=[1.0]
),
)
event.summary.value.add(
tag=long_tag_2, tensor=tensor_pb2.TensorProto(double_val=[2.0])
tag=long_tag_2,
tensor=tensor_pb2.TensorProto(
dtype=types_pb2.DT_DOUBLE, double_val=[2.0]
),
)

sender = _create_tensor_request_sender("123", mock_client)
Expand Down Expand Up @@ -1387,12 +1455,13 @@ def test_break_at_tensor_point_boundary(self):
events = []
for step in range(point_count):
event = event_pb2.Event(step=step)
event.summary.value.add(
tag="histo",
tensor=tensor_pb2.TensorProto(
double_val=[1.0 * step, -1.0 * step]
),
tensor_proto = tensor_pb2.TensorProto(
dtype=types_pb2.DT_DOUBLE, double_val=[1.0 * step, -1.0 * step]
)
tensor_proto.tensor_shape.dim.append(
tensor_shape_pb2.TensorShapeProto.Dim(size=2)
)
event.summary.value.add(tag="histo", tensor=tensor_proto)
events.append(event)

sender = _create_tensor_request_sender("123", mock_client)
Expand All @@ -1402,7 +1471,7 @@ def test_break_at_tensor_point_boundary(self):

self.assertGreater(len(requests), 1)
self.assertLess(len(requests), point_count)
self.assertEqual(56, len(requests))
self.assertEqual(72, len(requests))

total_points_in_result = 0
for request in requests:
Expand All @@ -1428,32 +1497,33 @@ def test_strip_large_tensors(self):
# Generate test data with varying tensor point sizes. Use raw bytes.
event_1 = event_pb2.Event(step=1)
event_1.summary.value.add(
tag="one", tensor=tensor_pb2.TensorProto(tensor_content=b"\x01\x02")
tag="one",
# This TensorProto has a byte size of 18.
tensor=tensor_util.make_tensor_proto([1.0, 2.0]),
)
event_1.summary.value.add(
tag="two",
tensor=tensor_pb2.TensorProto(
# 6 bytes will be filtered in the second test.
tensor_content=b"\x01\x02\x03\x04\x05\x06"
),
# This TensorProto has a byte size of 22.
tensor=tensor_util.make_tensor_proto([1.0, 2.0, 3.0]),
)
# This TensorProto has a 12-byte tensor_content.
event_2 = event_pb2.Event(step=2)
event_2.summary.value.add(
tag="one", tensor=tensor_pb2.TensorProto(tensor_content=b"\x01\x02")
tag="one",
# This TensorProto has a byte size of 18.
tensor=tensor_util.make_tensor_proto([2.0, 4.0]),
)
event_2.summary.value.add(
tag="two",
tensor=tensor_pb2.TensorProto(
# 7 bytes will be filtered out in both tests.
tensor_content=b"\x01\x02\x03\x04\x05\x06\x07"
),
# This TensorProto has a byte size of 26.
tensor=tensor_util.make_tensor_proto([1.0, 2.0, 3.0, 4.0]),
)

run_proto = self._add_events_and_flush(
_apply_compat([event_1, event_2]),
# Set threshold that will filter out tensor points with 7 bytes
# Set threshold that will filter out the tensor point with 26 bytes
# of data and above. The additional byte is for proto overhead.
max_tensor_point_size=7 + 1,
max_tensor_point_size=24,
)
tag_data = {
tag.name: [(p.step, p.value.tensor_content) for p in tag.points]
Expand All @@ -1463,35 +1533,50 @@ def test_strip_large_tensors(self):
self.assertEqual(
tag_data,
{
"one": [(1, b"\x01\x02"), (2, b"\x01\x02")],
"two": [(1, b"\x01\x02\x03\x04\x05\x06")],
"one": [
(1, b"\x00\x00\x80?\x00\x00\x00@"),
(2, b"\x00\x00\x00@\x00\x00\x80@"),
],
"two": [(1, b"\x00\x00\x80?\x00\x00\x00@\x00\x00@@")],
},
)

run_proto_2 = self._add_events_and_flush(
_apply_compat([event_1, event_2]),
# Set threshold that will filter out tensor points with 6 bytes
# of data and above. The additional byte is for proto overhead.
max_tensor_point_size=6 + 1,
# Set threshold that will filter out the tensor points with 22 and 26
# bytes of data and above. The additional byte is for proto overhead.
max_tensor_point_size=20,
)
tag_data_2 = {
tag.name: [(p.step, p.value.tensor_content) for p in tag.points]
for tag in run_proto_2.tags
}
# All tensor points from the same tag are filtered out, and the tag is pruned.
self.assertEqual(
tag_data_2, {"one": [(1, b"\x01\x02"), (2, b"\x01\x02")],},
tag_data_2,
{
"one": [
(1, b"\x00\x00\x80?\x00\x00\x00@"),
(2, b"\x00\x00\x00@\x00\x00\x80@"),
],
},
)

def test_prunes_tags_and_runs(self):
mock_client = _create_mock_client()
event_1 = event_pb2.Event(step=1)
event_1.summary.value.add(
tag="one", tensor=tensor_pb2.TensorProto(double_val=[1.0])
tag="one",
tensor=tensor_pb2.TensorProto(
dtype=types_pb2.DT_DOUBLE, double_val=[1.0]
),
)
event_2 = event_pb2.Event(step=2)
event_2.summary.value.add(
tag="two", tensor=tensor_pb2.TensorProto(double_val=[2.0])
tag="two",
tensor=tensor_pb2.TensorProto(
dtype=types_pb2.DT_DOUBLE, double_val=[2.0]
),
)

add_point_call_count_box = [0]
Expand Down Expand Up @@ -1530,13 +1615,19 @@ def test_wall_time_precision(self):
# digits to incur error if converted to nanoseconds the naive way (* 1e9).
event_1 = event_pb2.Event(step=1, wall_time=1567808404.765432119)
event_1.summary.value.add(
tag="tag", tensor=tensor_pb2.TensorProto(double_val=[1.0])
tag="tag",
tensor=tensor_pb2.TensorProto(
dtype=types_pb2.DT_DOUBLE, double_val=[1.0]
),
)
# Test a wall time where as a float64, the fractional part on its own will
# introduce error if truncated to 9 decimal places instead of rounded.
event_2 = event_pb2.Event(step=2, wall_time=1.000000002)
event_2.summary.value.add(
tag="tag", tensor=tensor_pb2.TensorProto(double_val=[2.0])
tag="tag",
tensor=tensor_pb2.TensorProto(
dtype=types_pb2.DT_DOUBLE, double_val=[2.0]
),
)
run_proto = self._add_events_and_flush(
_apply_compat([event_1, event_2])
Expand Down