|
14 | 14 | # ==============================================================================
15 | 15 | """Utilities for graph plugin."""
16 | 16 |
   | 17 | +from tensorboard.compat.proto import graph_pb2
17 | 18 |
18 |    | -class _ProtoListDuplicateKeyError(Exception):
19 |    | -    pass
20 | 19 |
   | 20 | +def _prefixed_op_name(prefix, op_name):
   | 21 | +    return "%s/%s" % (prefix, op_name)
21 | 22 |
22 |    | -class _SameKeyDiffContentError(Exception):
23 |    | -    pass
24 | 23 |
   | 24 | +def _prefixed_func_name(prefix, func_name):
   | 25 | +    # TODO(stephanwlee): add business logic to strip "__inference_".
   | 26 | +    return "%s_%s" % (prefix, func_name)
25 | 27 |
26 |    | -def _safe_copy_proto_list_values(dst_proto_list, src_proto_list, get_key):
27 |    | -    """Safely merge values from `src_proto_list` into `dst_proto_list`.
28 | 28 |
29 |    | -    Each element in `dst_proto_list` must be mapped by `get_key` to a key
30 |    | -    value that is unique within that list; likewise for `src_proto_list`.
31 |    | -    If an element of `src_proto_list` has the same key as an existing
32 |    | -    element in `dst_proto_list`, then the elements must also be equal.
   | 29 | +def _prepend_names(prefix, orig_graph_def):
   | 30 | +    mut_graph_def = graph_pb2.GraphDef()
   | 31 | +    for node in orig_graph_def.node:
   | 32 | +        new_node = mut_graph_def.node.add()
   | 33 | +        new_node.CopyFrom(node)
   | 34 | +        new_node.name = _prefixed_op_name(prefix, node.name)
   | 35 | +        new_node.input[:] = [
   | 36 | +            _prefixed_op_name(prefix, input_name) for input_name in node.input
   | 37 | +        ]
33 | 38 |
34 |    | -    Args:
35 |    | -      dst_proto_list: A `RepeatedCompositeContainer` or
36 |    | -        `RepeatedScalarContainer` into which values should be copied.
37 |    | -      src_proto_list: A container holding the same kind of values as in
38 |    | -        `dst_proto_list` from which values should be copied.
39 |    | -      get_key: A function that takes an element of `dst_proto_list` or
40 |    | -        `src_proto_list` and returns a key, such that if two elements have
41 |    | -        the same key then it is required that they be deep-equal. For
42 |    | -        instance, if `dst_proto_list` is a list of nodes, then `get_key`
43 |    | -        might be `lambda node: node.name` to indicate that if two nodes
44 |    | -        have the same name then they must be the same node. All keys must
45 |    | -        be hashable.
   | 39 | +        # Remap the tf.function name referenced by the PartitionedCall op
   | 40 | +        # ('f' is short for function).
   | 41 | +        if new_node.op == "PartitionedCall" and new_node.attr["f"]:
   | 42 | +
   | 43 | +            new_node.attr["f"].func.name = _prefixed_func_name(
   | 44 | +                prefix, new_node.attr["f"].func.name,
   | 45 | +            )
   | 46 | +
   | 47 | +    for func in orig_graph_def.library.function:
   | 48 | +        new_func = mut_graph_def.library.function.add()
   | 49 | +        new_func.CopyFrom(func)
   | 50 | +        # Factored-out functions are not nested under the prefix node; they
   | 51 | +        # already form their own hierarchy, one per graph.
   | 52 | +        new_func.signature.name = _prefixed_func_name(
   | 53 | +            prefix, new_func.signature.name
   | 54 | +        )
   | 55 | +
   | 56 | +    for gradient in orig_graph_def.library.gradient:
   | 57 | +        new_gradient = mut_graph_def.library.gradient.add()
   | 58 | +        new_gradient.CopyFrom(gradient)
   | 59 | +        new_gradient.function_name = _prefixed_func_name(
   | 60 | +            prefix, new_gradient.function_name,
   | 61 | +        )
   | 62 | +        new_gradient.gradient_func = _prefixed_func_name(
   | 63 | +            prefix, new_gradient.gradient_func,
   | 64 | +        )
   | 65 | +
   | 66 | +    return mut_graph_def
46 | 67 |
47 |    | -    Raises:
48 |    | -      _ProtoListDuplicateKeyError: A proto_list contains items with duplicate
49 |    | -        keys.
50 |    | -      _SameKeyDiffContentError: An item with the same key has different contents.
51 |    | -    """
52 | 68 |
53 |    | -    def _assert_proto_container_unique_keys(proto_list, get_key):
54 |    | -        """Asserts proto_list to only contains unique keys.
55 |    | -
56 |    | -        Args:
57 |    | -          proto_list: A `RepeatedCompositeContainer` or `RepeatedScalarContainer`.
58 |    | -          get_key: A function that takes an element of `proto_list` and returns a
59 |    | -            hashable key.
60 |    | -
61 |    | -        Raises:
62 |    | -          _ProtoListDuplicateKeyError: A proto_list contains items with duplicate
63 |    | -            keys.
64 |    | -        """
65 |    | -        keys = set()
66 |    | -        for item in proto_list:
67 |    | -            key = get_key(item)
68 |    | -            if key in keys:
69 |    | -                raise _ProtoListDuplicateKeyError(key)
70 |    | -            keys.add(key)
71 |    | -
72 |    | -    _assert_proto_container_unique_keys(dst_proto_list, get_key)
73 |    | -    _assert_proto_container_unique_keys(src_proto_list, get_key)
74 |    | -
75 |    | -    key_to_proto = {}
76 |    | -    for proto in dst_proto_list:
77 |    | -        key = get_key(proto)
78 |    | -        key_to_proto[key] = proto
79 |    | -
80 |    | -    for proto in src_proto_list:
81 |    | -        key = get_key(proto)
82 |    | -        if key in key_to_proto:
83 |    | -            if proto != key_to_proto.get(key):
84 |    | -                raise _SameKeyDiffContentError(key)
85 |    | -        else:
86 |    | -            dst_proto_list.add().CopyFrom(proto)
87 |    | -
88 |    | -
89 |    | -def combine_graph_defs(to_proto, from_proto):
90 |    | -    """Combines two GraphDefs by adding nodes from from_proto into to_proto.
   | 69 | +def merge_graph_defs(graph_defs):
   | 70 | +    """Merges GraphDefs by adding a unique prefix, `graph_{ind}`, to names.
91 | 71 |
92 | 72 |     All GraphDefs are expected to be of TensorBoard's.
93 |    | -    It assumes node names are unique across GraphDefs if contents differ. The
94 |    | -    names can be the same if the NodeDef content are exactly the same.
   | 73 | +
   | 74 | +    When collecting graphs using the `tf.summary.trace` API, node names are
   | 75 | +    not guaranteed to be unique. If non-unique names are not accounted for,
   | 76 | +    the graph visualization can render distinct nodes as one, which gives an
   | 77 | +    inaccurate picture of the graph's flow (e.g., given A -> B -> C and
   | 78 | +    D -> B -> E, you may see {A, D} -> B -> {C, E}). To prevent this, we used
   | 79 | +    to check for uniqueness while merging, but that resulted in
   | 80 | +    https://github.com/tensorflow/tensorboard/issues/1929.
   | 81 | +
   | 82 | +    To remedy these issues, we simply "apply a name scope" to each graph by
   | 83 | +    prefixing it with a unique name (with a small chance of collision) to
   | 84 | +    create unconnected groups of graphs.
   | 85 | +
   | 86 | +    If only one graph_def is passed, the original graph_def is returned; if
   | 87 | +    none are passed, an empty GraphDef is returned.
95 | 88 |
96 | 89 |     Args:
97 |    | -      to_proto: A destination TensorBoard GraphDef.
98 |    | -      from_proto: A TensorBoard GraphDef to copy contents from.
   | 90 | +      graph_defs: TensorBoard GraphDefs to merge.
99 | 91 |
100 | 92 |     Returns:
101 |    | -      to_proto
   | 93 | +      A TensorBoard GraphDef that merges all graph_defs with unique prefixes.
102 | 94 |
103 | 95 |     Raises:
104 |    | -      ValueError in case any assumption about GraphDef is violated: A
105 |    | -        GraphDef should have unique node, function, and gradient function
106 |    | -        names. Also, when merging GraphDefs, they should have not have nodes,
107 |    | -        functions, or gradient function mappings that share the name but details
108 |    | -        do not match.
   | 96 | +      ValueError: If GraphDef versions mismatch.
109 | 97 |     """
110 |    | -    if from_proto.version != to_proto.version:
111 |    | -        raise ValueError("Cannot combine GraphDefs of different versions.")
   |  98 | +    if len(graph_defs) == 1:
   |  99 | +        return graph_defs[0]
   | 100 | +    elif len(graph_defs) == 0:
   | 101 | +        return graph_pb2.GraphDef()
112 | 102 |
113 |    | -    try:
114 |    | -        _safe_copy_proto_list_values(
115 |    | -            to_proto.node, from_proto.node, lambda n: n.name
116 |    | -        )
117 |    | -    except _ProtoListDuplicateKeyError as exc:
118 |    | -        raise ValueError("A GraphDef contains non-unique node names: %s" % exc)
119 |    | -    except _SameKeyDiffContentError as exc:
120 |    | -        raise ValueError(
121 |    | -            (
122 |    | -                "Cannot combine GraphDefs because nodes share a name "
123 |    | -                "but contents are different: %s"
124 |    | -            )
125 |    | -            % exc
126 |    | -        )
127 |    | -    try:
128 |    | -        _safe_copy_proto_list_values(
129 |    | -            to_proto.library.function,
130 |    | -            from_proto.library.function,
131 |    | -            lambda n: n.signature.name,
132 |    | -        )
133 |    | -    except _ProtoListDuplicateKeyError as exc:
134 |    | -        raise ValueError(
135 |    | -            "A GraphDef contains non-unique function names: %s" % exc
136 |    | -        )
137 |    | -    except _SameKeyDiffContentError as exc:
138 |    | -        raise ValueError(
139 |    | -            (
140 |    | -                "Cannot combine GraphDefs because functions share a name "
141 |    | -                "but are different: %s"
142 |    | -            )
143 |    | -            % exc
144 |    | -        )
   | 103 | +    dst_graph_def = graph_pb2.GraphDef()
145 | 104 |
146 |    | -    try:
147 |    | -        _safe_copy_proto_list_values(
148 |    | -            to_proto.library.gradient,
149 |    | -            from_proto.library.gradient,
150 |    | -            lambda g: g.gradient_func,
151 |    | -        )
152 |    | -    except _ProtoListDuplicateKeyError as exc:
153 |    | -        raise ValueError(
154 |    | -            "A GraphDef contains non-unique gradient function names: %s" % exc
155 |    | -        )
156 |    | -    except _SameKeyDiffContentError as exc:
157 |    | -        raise ValueError(
158 |    | -            (
159 |    | -                "Cannot combine GraphDefs because gradients share a gradient_func name "
160 |    | -                "but map to different functions: %s"
    | 105 | +    if graph_defs[0].versions.producer:
    | 106 | +        dst_graph_def.versions.CopyFrom(graph_defs[0].versions)
    | 107 | +
    | 108 | +    for index, graph_def in enumerate(graph_defs):
    | 109 | +        if dst_graph_def.versions.producer != graph_def.versions.producer:
    | 110 | +            raise ValueError("Cannot combine GraphDefs of different versions.")
    | 111 | +
    | 112 | +        mapped_graph_def = _prepend_names("graph_%d" % (index + 1), graph_def)
    | 113 | +        dst_graph_def.node.extend(mapped_graph_def.node)
    | 114 | +        if mapped_graph_def.library.function:
    | 115 | +            dst_graph_def.library.function.extend(
    | 116 | +                mapped_graph_def.library.function
    | 117 | +            )
    | 118 | +        if mapped_graph_def.library.gradient:
    | 119 | +            dst_graph_def.library.gradient.extend(
    | 120 | +                mapped_graph_def.library.gradient
161 | 121 |             )
162 |    | -            % exc
163 |    | -        )
164 | 122 |
165 |    | -    return to_proto
    | 123 | +    return dst_graph_def
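Usage sketch (reviewer note, not part of the diff): the snippet below builds two GraphDefs whose node names collide and merges them with merge_graph_defs as added above. The _make_graph helper and the producer version number are illustrative assumptions, not part of the change.

from tensorboard.compat.proto import graph_pb2


def _make_graph(producer=27):
    # Hypothetical helper: a two-node graph "x -> out"; both calls below
    # produce the same (colliding) node names.
    graph_def = graph_pb2.GraphDef()
    graph_def.versions.producer = producer
    inp = graph_def.node.add()
    inp.name = "x"
    inp.op = "Placeholder"
    out = graph_def.node.add()
    out.name = "out"
    out.op = "Identity"
    out.input.append("x")
    return graph_def


merged = merge_graph_defs([_make_graph(), _make_graph()])
# Colliding names come back disambiguated by per-graph prefixes:
#   graph_1/x, graph_1/out, graph_2/x, graph_2/out
# and graph_1/out's input is rewritten to graph_1/x, so the two graphs stay
# unconnected in the visualization.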