Skip to content

feat: combine graph by prefixing with unique name #4334

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion tensorboard/plugins/graph/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ py_library(
srcs = ["graph_util.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:private"],
deps = [
"//tensorboard/compat/proto:protos_all_py_pb2",
],
)

py_test(
Expand All @@ -136,7 +139,6 @@ py_test(
"//tensorboard:expect_tensorflow_installed",
"//tensorboard/compat/proto:protos_all_py_pb2",
"@com_google_protobuf//:protobuf_python",
"@org_pythonhosted_six",
],
)

Expand Down
213 changes: 86 additions & 127 deletions tensorboard/plugins/graph/graph_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,152 +14,111 @@
# ==============================================================================
"""Utilities for graph plugin."""

from tensorboard.compat.proto import graph_pb2

class _ProtoListDuplicateKeyError(Exception):
pass

def _prefixed_op_name(prefix, op_name):
return "%s/%s" % (prefix, op_name)

class _SameKeyDiffContentError(Exception):
pass

def _prefixed_func_name(prefix, func_name):
"""Returns function name prefixed with `prefix`.

def _safe_copy_proto_list_values(dst_proto_list, src_proto_list, get_key):
"""Safely merge values from `src_proto_list` into `dst_proto_list`.
Function libraries, which are often created out of autographed Python
functions, are factored out in the graph vis. They are grouped under a
function name which often has the shape
`__inference_[py_func_name]_[numeric_suffix]`.

Each element in `dst_proto_list` must be mapped by `get_key` to a key
value that is unique within that list; likewise for `src_proto_list`.
If an element of `src_proto_list` has the same key as an existing
element in `dst_proto_list`, then the elements must also be equal.
While the function name carries no information about which graph it is from,
creating another wrapping structure with the graph prefix and "/" is less than
ideal, so we join the prefix and func_name using an underscore.

Args:
dst_proto_list: A `RepeatedCompositeContainer` or
`RepeatedScalarContainer` into which values should be copied.
src_proto_list: A container holding the same kind of values as in
`dst_proto_list` from which values should be copied.
get_key: A function that takes an element of `dst_proto_list` or
`src_proto_list` and returns a key, such that if two elements have
the same key then it is required that they be deep-equal. For
instance, if `dst_proto_list` is a list of nodes, then `get_key`
might be `lambda node: node.name` to indicate that if two nodes
have the same name then they must be the same node. All keys must
be hashable.

Raises:
_ProtoListDuplicateKeyError: A proto_list contains items with duplicate
keys.
_SameKeyDiffContentError: An item with the same key has different contents.
TODO(stephanwlee): add business logic to strip "__inference_" for a more
user-friendly name.
"""
return "%s_%s" % (prefix, func_name)


def _assert_proto_container_unique_keys(proto_list, get_key):
"""Asserts proto_list to only contains unique keys.

Args:
proto_list: A `RepeatedCompositeContainer` or `RepeatedScalarContainer`.
get_key: A function that takes an element of `proto_list` and returns a
hashable key.

Raises:
_ProtoListDuplicateKeyError: A proto_list contains items with duplicate
keys.
"""
keys = set()
for item in proto_list:
key = get_key(item)
if key in keys:
raise _ProtoListDuplicateKeyError(key)
keys.add(key)

_assert_proto_container_unique_keys(dst_proto_list, get_key)
_assert_proto_container_unique_keys(src_proto_list, get_key)

key_to_proto = {}
for proto in dst_proto_list:
key = get_key(proto)
key_to_proto[key] = proto

for proto in src_proto_list:
key = get_key(proto)
if key in key_to_proto:
if proto != key_to_proto.get(key):
raise _SameKeyDiffContentError(key)
else:
dst_proto_list.add().CopyFrom(proto)


def combine_graph_defs(to_proto, from_proto):
"""Combines two GraphDefs by adding nodes from from_proto into to_proto.
def _add_with_prepended_names(prefix, graph_to_add, destination_graph):
    """Copies `graph_to_add` into `destination_graph`, prefixing every name.

    Node names and node inputs are prefixed with `_prefixed_op_name`. Function
    names referenced by `PartitionedCall` nodes, by the function library, and
    by the gradient mappings are prefixed with `_prefixed_func_name`.

    Args:
      prefix: String to prepend to every copied name.
      graph_to_add: GraphDef whose nodes, library functions, and library
        gradients are copied. It is only read, never modified.
      destination_graph: GraphDef that receives the renamed copies.
    """
    for node in graph_to_add.node:
        new_node = destination_graph.node.add()
        new_node.CopyFrom(node)
        new_node.name = _prefixed_op_name(prefix, node.name)

        prefixed_inputs = []
        for input_name in node.input:
            if input_name.startswith("^"):
                # A control input is spelled "^node". The caret is a marker,
                # not part of the node name, so it has to stay in front of the
                # prefixed name for the input to keep referring to the renamed
                # node.
                prefixed_inputs.append(
                    "^" + _prefixed_op_name(prefix, input_name[1:])
                )
            else:
                prefixed_inputs.append(_prefixed_op_name(prefix, input_name))
        new_node.input[:] = prefixed_inputs

        # Remap tf.function method name in the PartitionedCall. 'f' is short for
        # function. Use a membership test: indexing a protobuf map with a
        # missing key would insert an empty entry as a side effect.
        if new_node.op == "PartitionedCall" and "f" in new_node.attr:
            func_attr = new_node.attr["f"]
            func_attr.func.name = _prefixed_func_name(
                prefix, func_attr.func.name
            )

    for func in graph_to_add.library.function:
        new_func = destination_graph.library.function.add()
        new_func.CopyFrom(func)
        new_func.signature.name = _prefixed_func_name(
            prefix, new_func.signature.name
        )

    for gradient in graph_to_add.library.gradient:
        new_gradient = destination_graph.library.gradient.add()
        new_gradient.CopyFrom(gradient)
        new_gradient.function_name = _prefixed_func_name(
            prefix, new_gradient.function_name
        )
        new_gradient.gradient_func = _prefixed_func_name(
            prefix, new_gradient.gradient_func
        )


def merge_graph_defs(graph_defs):
"""Merges GraphDefs by adding unique prefix, `graph_{ind}`, to names.

All GraphDefs are expected to be of TensorBoard's.
It assumes node names are unique across GraphDefs if contents differ. The
names can be the same if the NodeDef content are exactly the same.

When collecting graphs using the `tf.summary.trace` API, node names are not
guaranteed to be unique. When non-unique names are not accounted for, the
graph visualization can show distinct nodes as one, which creates an
inaccurate depiction of the flow of the graph (e.g., if there are A -> B -> C
and D -> B -> E, you may see {A, D} -> B -> E). To prevent such a graph, we
checked for uniqueness while merging, but that resulted in
https://github.com/tensorflow/tensorboard/issues/1929.

To remedy these issues, we simply "apply name scope" on each graph by
prefixing it with unique name (with a chance of collision) to create
unconnected group of graphs.

In case there is only one graph def passed, it returns the original
graph_def. In case no graph defs are passed, it returns an empty GraphDef.

Args:
to_proto: A destination TensorBoard GraphDef.
from_proto: A TensorBoard GraphDef to copy contents from.
graph_defs: TensorBoard GraphDefs to merge.

Returns:
to_proto
TensorBoard GraphDef that merges all graph_defs with unique prefixes.

Raises:
ValueError in case any assumption about GraphDef is violated: A
GraphDef should have unique node, function, and gradient function
names. Also, when merging GraphDefs, they should not have nodes,
functions, or gradient function mappings that share the name but details
do not match.
ValueError in case GraphDef versions mismatch.
"""
if from_proto.version != to_proto.version:
raise ValueError("Cannot combine GraphDefs of different versions.")
if len(graph_defs) == 1:
return graph_defs[0]
elif len(graph_defs) == 0:
return graph_pb2.GraphDef()

try:
_safe_copy_proto_list_values(
to_proto.node, from_proto.node, lambda n: n.name
)
except _ProtoListDuplicateKeyError as exc:
raise ValueError("A GraphDef contains non-unique node names: %s" % exc)
except _SameKeyDiffContentError as exc:
raise ValueError(
(
"Cannot combine GraphDefs because nodes share a name "
"but contents are different: %s"
)
% exc
)
try:
_safe_copy_proto_list_values(
to_proto.library.function,
from_proto.library.function,
lambda n: n.signature.name,
)
except _ProtoListDuplicateKeyError as exc:
raise ValueError(
"A GraphDef contains non-unique function names: %s" % exc
)
except _SameKeyDiffContentError as exc:
raise ValueError(
(
"Cannot combine GraphDefs because functions share a name "
"but are different: %s"
)
% exc
)
dst_graph_def = graph_pb2.GraphDef()

try:
_safe_copy_proto_list_values(
to_proto.library.gradient,
from_proto.library.gradient,
lambda g: g.gradient_func,
)
except _ProtoListDuplicateKeyError as exc:
raise ValueError(
"A GraphDef contains non-unique gradient function names: %s" % exc
)
except _SameKeyDiffContentError as exc:
raise ValueError(
(
"Cannot combine GraphDefs because gradients share a gradient_func name "
"but map to different functions: %s"
)
% exc
if graph_defs[0].versions.producer:
dst_graph_def.versions.CopyFrom(graph_defs[0].versions)

for index, graph_def in enumerate(graph_defs):
if dst_graph_def.versions.producer != graph_def.versions.producer:
raise ValueError("Cannot combine GraphDefs of different versions.")

_add_with_prepended_names(
"graph_%d" % (index + 1), graph_def, dst_graph_def,
)

return to_proto
return dst_graph_def
Loading