1
1
import logging
2
- from typing import Callable , Dict , List , Optional , Sequence , Set
2
+ from typing import Dict , List , Optional , Sequence , Set
3
3
4
4
import torch
5
5
@@ -55,10 +55,6 @@ def __init__(
55
55
)
56
56
57
57
self .min_block_size = min_block_size
58
- logger .debug (
59
- "Initialized Capability-Based Partitioner with available Converters:\n "
60
- + f"{ CONVERTERS .display_all_available_converters ()} "
61
- )
62
58
63
59
def propose_partitions (self ) -> List [Partition ]:
64
60
# Propose partitions using the default, then refine the results
@@ -114,8 +110,8 @@ def __init__(self, support_dict=None, torch_executed_ops=set()):
114
110
super ().__init__ (support_dict )
115
111
116
112
# Initialize sets of supported/unsupported operators
117
- self .supported_operators = set ()
118
- self .unsupported_operators = set ()
113
+ self .supported_operators = {}
114
+ self .unsupported_operators = {}
119
115
self .torch_executed_ops = torch_executed_ops
120
116
121
117
def is_node_supported (
@@ -130,12 +126,18 @@ def is_node_supported(
130
126
if node in CONVERTERS and node_name not in self .torch_executed_ops :
131
127
# If node is a proper, supported computational node, store the operator
132
128
if not node .is_impure ():
133
- self .supported_operators .add (node_name )
129
+ if node_name not in self .supported_operators :
130
+ self .supported_operators [node_name ] = 1
131
+ else :
132
+ self .supported_operators [node_name ] += 1
134
133
135
134
return True
136
135
else :
137
136
if not node .is_impure ():
138
- self .unsupported_operators .add (node_name )
137
+ if node_name not in self .unsupported_operators :
138
+ self .unsupported_operators [node_name ] = 1
139
+ else :
140
+ self .unsupported_operators [node_name ] += 1
139
141
140
142
return False
141
143
@@ -147,15 +149,16 @@ def print_support_overview(self, num_trt_blocks: Optional[int] = None):
147
149
148
150
# Reformat support messages for debugger to print node overview as a single string
149
151
supported_nodes_str = "\n Supported Nodes:\n "
150
- for node_name in self .supported_operators :
151
- supported_nodes_str += f"- { node_name } \n "
152
+ for node_name , count in self .supported_operators . items () :
153
+ supported_nodes_str += f"- { node_name } + Operator Count: { count } \n "
152
154
153
155
logger .debug (supported_nodes_str )
154
156
155
- if len ( self .unsupported_operators ) != 0 :
157
+ if self .unsupported_operators :
156
158
unsupported_nodes_str = "\n Unsupported or Excluded Nodes:\n "
157
- for node_name in self .unsupported_operators :
158
- unsupported_nodes_str += f"- { node_name } \n "
159
+ for node_name , count in self .unsupported_operators .items ():
160
+ unsupported_nodes_str += f"- { node_name } + Operator Count: { count } \n "
161
+
159
162
logger .debug (unsupported_nodes_str )
160
163
else :
161
164
logger .debug ("\n All Nodes Supported\n " )
0 commit comments