Skip to content

Commit eacf107

Browse files
committed
feat: combine graph by prefixing with unique name
Previously, the graph plugin combined multiple traced graphs by creating one monolithic GraphDef. In doing so, we checked whether, for example, node names were unique, in order to detect cases where our graph visualization could produce a faulty UI. To alleviate the poor UX, we decided instead to duplicate all nodes into one giant GraphDef container, prefixing all names. While this creates some bloat, (1) users should see the confusing error less often and (2) the combined graph makes it very clear that we have traced multiple graphs.
1 parent da59d4c commit eacf107

File tree

4 files changed

+245
-587
lines changed

4 files changed

+245
-587
lines changed

tensorboard/plugins/graph/BUILD

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@ py_library(
123123
srcs = ["graph_util.py"],
124124
srcs_version = "PY2AND3",
125125
visibility = ["//visibility:private"],
126+
deps = [
127+
"//tensorboard/compat/proto:protos_all_py_pb2",
128+
],
126129
)
127130

128131
py_test(
@@ -136,7 +139,6 @@ py_test(
136139
"//tensorboard:expect_tensorflow_installed",
137140
"//tensorboard/compat/proto:protos_all_py_pb2",
138141
"@com_google_protobuf//:protobuf_python",
139-
"@org_pythonhosted_six",
140142
],
141143
)
142144

tensorboard/plugins/graph/graph_util.py

Lines changed: 85 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -14,152 +14,110 @@
1414
# ==============================================================================
1515
"""Utilities for graph plugin."""
1616

17+
from tensorboard.compat.proto import graph_pb2
1718

18-
class _ProtoListDuplicateKeyError(Exception):
19-
pass
2019

20+
def _prefixed_op_name(prefix, op_name):
21+
return "%s/%s" % (prefix, op_name)
2122

22-
class _SameKeyDiffContentError(Exception):
23-
pass
2423

24+
def _prefixed_func_name(prefix, func_name):
25+
# TODO(stephanwlee): add business logic to strip "__inference_".
26+
return "%s_%s" % (prefix, func_name)
2527

26-
def _safe_copy_proto_list_values(dst_proto_list, src_proto_list, get_key):
27-
"""Safely merge values from `src_proto_list` into `dst_proto_list`.
2828

29-
Each element in `dst_proto_list` must be mapped by `get_key` to a key
30-
value that is unique within that list; likewise for `src_proto_list`.
31-
If an element of `src_proto_list` has the same key as an existing
32-
element in `dst_proto_list`, then the elements must also be equal.
29+
def _prepend_names(prefix, orig_graph_def):
30+
mut_graph_def = graph_pb2.GraphDef()
31+
for node in orig_graph_def.node:
32+
new_node = mut_graph_def.node.add()
33+
new_node.CopyFrom(node)
34+
new_node.name = _prefixed_op_name(prefix, node.name)
35+
new_node.input[:] = [
36+
_prefixed_op_name(prefix, input_name) for input_name in node.input
37+
]
3338

34-
Args:
35-
dst_proto_list: A `RepeatedCompositeContainer` or
36-
`RepeatedScalarContainer` into which values should be copied.
37-
src_proto_list: A container holding the same kind of values as in
38-
`dst_proto_list` from which values should be copied.
39-
get_key: A function that takes an element of `dst_proto_list` or
40-
`src_proto_list` and returns a key, such that if two elements have
41-
the same key then it is required that they be deep-equal. For
42-
instance, if `dst_proto_list` is a list of nodes, then `get_key`
43-
might be `lambda node: node.name` to indicate that if two nodes
44-
have the same name then they must be the same node. All keys must
45-
be hashable.
39+
# Remap tf.function method name in the PartitionedCall. 'f' is short for
40+
# function.
41+
if new_node.op == "PartitionedCall" and new_node.attr["f"]:
42+
43+
new_node.attr["f"].func.name = _prefixed_func_name(
44+
prefix, new_node.attr["f"].func.name,
45+
)
46+
47+
for func in orig_graph_def.library.function:
48+
new_func = mut_graph_def.library.function.add()
49+
new_func.CopyFrom(func)
50+
# Not creating a structure out of factored out function. They already
51+
# create an awkward hierarchy and one for each graph.
52+
new_func.signature.name = _prefixed_func_name(
53+
prefix, new_func.signature.name
54+
)
55+
56+
for gradient in orig_graph_def.library.gradient:
57+
new_gradient = mut_graph_def.library.gradient.add()
58+
new_gradient.CopyFrom(gradient)
59+
new_gradient.function_name = _prefixed_func_name(
60+
prefix, new_gradient.function_name,
61+
)
62+
new_gradient.gradient_func = _prefixed_func_name(
63+
prefix, new_gradient.gradient_func,
64+
)
65+
66+
return mut_graph_def
4667

47-
Raises:
48-
_ProtoListDuplicateKeyError: A proto_list contains items with duplicate
49-
keys.
50-
_SameKeyDiffContentError: An item with the same key has different contents.
51-
"""
5268

53-
def _assert_proto_container_unique_keys(proto_list, get_key):
54-
"""Asserts proto_list to only contains unique keys.
55-
56-
Args:
57-
proto_list: A `RepeatedCompositeContainer` or `RepeatedScalarContainer`.
58-
get_key: A function that takes an element of `proto_list` and returns a
59-
hashable key.
60-
61-
Raises:
62-
_ProtoListDuplicateKeyError: A proto_list contains items with duplicate
63-
keys.
64-
"""
65-
keys = set()
66-
for item in proto_list:
67-
key = get_key(item)
68-
if key in keys:
69-
raise _ProtoListDuplicateKeyError(key)
70-
keys.add(key)
71-
72-
_assert_proto_container_unique_keys(dst_proto_list, get_key)
73-
_assert_proto_container_unique_keys(src_proto_list, get_key)
74-
75-
key_to_proto = {}
76-
for proto in dst_proto_list:
77-
key = get_key(proto)
78-
key_to_proto[key] = proto
79-
80-
for proto in src_proto_list:
81-
key = get_key(proto)
82-
if key in key_to_proto:
83-
if proto != key_to_proto.get(key):
84-
raise _SameKeyDiffContentError(key)
85-
else:
86-
dst_proto_list.add().CopyFrom(proto)
87-
88-
89-
def combine_graph_defs(to_proto, from_proto):
90-
"""Combines two GraphDefs by adding nodes from from_proto into to_proto.
69+
def merge_graph_defs(graph_defs):
70+
"""Merges GraphDefs by adding unique prefix, `graph_{ind}`, to names.
9171
9272
All GraphDefs are expected to be of TensorBoard's.
93-
It assumes node names are unique across GraphDefs if contents differ. The
94-
names can be the same if the NodeDef content are exactly the same.
73+
74+
When collecting graphs using the `tf.summary.trace` API, node names are not
75+
guaranteed to be unique. When non-unique names are not considered, it can
76+
lead to graph visualization showing them as one which creates inaccurate
77+
depiction of the flow of the graph (e.g., if there are A -> B -> C and D ->
78+
B -> E, you may see {A, D} -> B -> E). To prevent such graph, we checked
79+
for uniqueness while merging but it resulted in
80+
https://github.com/tensorflow/tensorboard/issues/1929.
81+
82+
To remedy these issues, we simply "apply name scope" on each graph by
83+
prefixing it with unique name (with a chance of collision) to create
84+
unconnected group of graphs.
85+
86+
In case there is only one graph def passed, it returns the original
87+
graph_def. In case no graph defs are passed, it returns an empty GraphDef.
9588
9689
Args:
97-
to_proto: A destination TensorBoard GraphDef.
98-
from_proto: A TensorBoard GraphDef to copy contents from.
90+
graph_defs: TensorBoard GraphDefs to merge.
9991
10092
Returns:
101-
to_proto
93+
TensorBoard GraphDef that merges all graph_defs with unique prefixes.
10294
10395
Raises:
104-
ValueError in case any assumption about GraphDef is violated: A
105-
GraphDef should have unique node, function, and gradient function
106-
names. Also, when merging GraphDefs, they should have not have nodes,
107-
functions, or gradient function mappings that share the name but details
108-
do not match.
96+
ValueError in case GraphDef versions mismatch.
10997
"""
110-
if from_proto.version != to_proto.version:
111-
raise ValueError("Cannot combine GraphDefs of different versions.")
98+
if len(graph_defs) == 1:
99+
return graph_defs[0]
100+
elif len(graph_defs) == 0:
101+
return graph_pb2.GraphDef()
112102

113-
try:
114-
_safe_copy_proto_list_values(
115-
to_proto.node, from_proto.node, lambda n: n.name
116-
)
117-
except _ProtoListDuplicateKeyError as exc:
118-
raise ValueError("A GraphDef contains non-unique node names: %s" % exc)
119-
except _SameKeyDiffContentError as exc:
120-
raise ValueError(
121-
(
122-
"Cannot combine GraphDefs because nodes share a name "
123-
"but contents are different: %s"
124-
)
125-
% exc
126-
)
127-
try:
128-
_safe_copy_proto_list_values(
129-
to_proto.library.function,
130-
from_proto.library.function,
131-
lambda n: n.signature.name,
132-
)
133-
except _ProtoListDuplicateKeyError as exc:
134-
raise ValueError(
135-
"A GraphDef contains non-unique function names: %s" % exc
136-
)
137-
except _SameKeyDiffContentError as exc:
138-
raise ValueError(
139-
(
140-
"Cannot combine GraphDefs because functions share a name "
141-
"but are different: %s"
142-
)
143-
% exc
144-
)
103+
dst_graph_def = graph_pb2.GraphDef()
145104

146-
try:
147-
_safe_copy_proto_list_values(
148-
to_proto.library.gradient,
149-
from_proto.library.gradient,
150-
lambda g: g.gradient_func,
151-
)
152-
except _ProtoListDuplicateKeyError as exc:
153-
raise ValueError(
154-
"A GraphDef contains non-unique gradient function names: %s" % exc
155-
)
156-
except _SameKeyDiffContentError as exc:
157-
raise ValueError(
158-
(
159-
"Cannot combine GraphDefs because gradients share a gradient_func name "
160-
"but map to different functions: %s"
105+
if graph_defs[0].versions.producer:
106+
dst_graph_def.versions.CopyFrom(graph_defs[0].versions)
107+
108+
for index, graph_def in enumerate(graph_defs):
109+
if dst_graph_def.versions.producer != graph_def.versions.producer:
110+
raise ValueError("Cannot combine GraphDefs of different versions.")
111+
112+
mapped_graph_def = _prepend_names("graph_%d" % (index + 1), graph_def)
113+
dst_graph_def.node.extend(mapped_graph_def.node)
114+
if mapped_graph_def.library.function:
115+
dst_graph_def.library.function.extend(
116+
mapped_graph_def.library.function
117+
)
118+
if mapped_graph_def.library.gradient:
119+
dst_graph_def.library.gradient.extend(
120+
mapped_graph_def.library.gradient
161121
)
162-
% exc
163-
)
164122

165-
return to_proto
123+
return dst_graph_def

0 commit comments

Comments
 (0)