Skip to content

Commit decd0ed

Browse files
committed
feat: Enable sparsity support in TRTorch
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 4771d20 commit decd0ed

File tree

11 files changed

+31
-5
lines changed

11 files changed

+31
-5
lines changed

core/conversion/conversionctx/ConversionCtx.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
8686
cfg->clearFlag(nvinfer1::BuilderFlag::kTF32);
8787
}
8888

89+
if (settings.sparse_weights) {
90+
cfg->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS);
91+
}
92+
8993
if (settings.refit) {
9094
cfg->setFlag(nvinfer1::BuilderFlag::kREFIT);
9195
}

core/conversion/conversionctx/ConversionCtx.h

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ struct Device {
2424

2525
struct BuilderSettings {
2626
nvinfer1::DataType op_precision = nvinfer1::DataType::kFLOAT;
27+
bool sparse_weights = false;
2728
bool disable_tf32 = false;
2829
bool refit = false;
2930
bool debug = false;

cpp/api/include/trtorch/trtorch.h

+5
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,11 @@ struct TRTORCH_API CompileSpec {
248248
*/
249249
bool disable_tf32 = false;
250250

251+
/**
252+
* Enable sparsity for weights of conv and FC layers
253+
*/
254+
bool sparse_weights = false;
255+
251256
/**
252257
* Build a refitable engine
253258
*/

cpp/api/src/compile_spec.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
102102
internal.convert_info.engine_settings.op_precision = nvinfer1::DataType::kFLOAT;
103103
}
104104

105+
internal.convert_info.engine_settings.sparse_weights = external.sparse_weights;
105106
internal.convert_info.engine_settings.disable_tf32 = external.disable_tf32;
106107
internal.convert_info.engine_settings.refit = external.refit;
107108
internal.convert_info.engine_settings.debug = external.debug;

cpp/trtorchexec/main.cpp

+6-5
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,13 @@ int main(int argc, const char* argv[]) {
5757

5858
auto compile_spec = trtorch::CompileSpec(dims);
5959
compile_spec.workspace_size = 1 << 24;
60+
compile_spec.sparse_weights = true;
6061

61-
std::cout << "Checking operator support" << std::endl;
62-
if (!trtorch::CheckMethodOperatorSupport(mod, "forward")) {
63-
std::cerr << "Method is not currently supported by TRTorch" << std::endl;
64-
return -1;
65-
}
62+
// std::cout << "Checking operator support" << std::endl;
63+
// if (!trtorch::CheckMethodOperatorSupport(mod, "forward")) {
64+
// std::cerr << "Method is not currently supported by TRTorch" << std::endl;
65+
// return -1;
66+
// }
6667

6768
std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
6869
auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", compile_spec);

py/trtorch/_compile_spec.py

+5
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,10 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
156156
if "calibrator" in compile_spec:
157157
info.ptq_calibrator = compile_spec["calibrator"]
158158

159+
if "sparse_weights" in compile_spec:
160+
assert isinstance(compile_spec["sparse_weights"], bool)
161+
info.sparse_weights = compile_spec["sparse_weights"]
162+
159163
if "disable_tf32" in compile_spec:
160164
assert isinstance(compile_spec["disable_tf32"], bool)
161165
info.disable_tf32 = compile_spec["disable_tf32"]
@@ -237,6 +241,7 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
237241
"allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
238242
},
239243
"op_precision": torch.half, # Operating precision set to FP16
244+
"sparse_weights": Enable sparsity for convolution and fully connected layers.
240245
"disable_tf32": False, # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
241246
"refit": False, # enable refit
242247
"debug": False, # enable debuggable engine

py/trtorch/_compiler.py

+3
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ def compile(module: torch.jit.ScriptModule, compile_spec: Any) -> torch.jit.Scri
4141
"allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
4242
},
4343
"op_precision": torch.half, # Operating precision set to FP16
44+
"disable_tf32": False, # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
45+
"sparse_weights": Enable sparsity for convolution and fully connected layers.
4446
"refit": false, # enable refit
4547
"debug": false, # enable debuggable engine
4648
"strict_types": false, # kernels should strictly run in operating precision
@@ -107,6 +109,7 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: st
107109
},
108110
"op_precision": torch.half, # Operating precision set to FP16
109111
"disable_tf32": False, # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
112+
"sparse_weights": Enable sparsity for convolution and fully connected layers.
110113
"refit": false, # enable refit
111114
"debug": false, # enable debuggable engine
112115
"strict_types": false, # kernels should strictly run in operating precision

py/trtorch/csrc/register_tensorrt_classes.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ void RegisterTRTCompileSpec() {
4646
.def("__str__", &trtorch::pyapi::CompileSpec::stringify);
4747

4848
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, op_precision);
49+
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, sparse_weights);
4950
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, disable_tf32);
5051
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, refit);
5152
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, debug);

py/trtorch/csrc/tensorrt_classes.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
127127
auto info = core::CompileSpec(internal_input_ranges);
128128
info.convert_info.engine_settings.op_precision = toTRTDataType(op_precision);
129129
info.convert_info.engine_settings.calibrator = ptq_calibrator;
130+
info.convert_info.engine_settings.sparse_weights = sparse_weights;
130131
info.convert_info.engine_settings.disable_tf32 = disable_tf32;
131132
info.convert_info.engine_settings.refit = refit;
132133
info.convert_info.engine_settings.debug = debug;
@@ -163,6 +164,7 @@ std::string CompileSpec::stringify() {
163164
ss << " ]" << std::endl;
164165
ss << " \"Op Precision\": " << to_str(op_precision) << std::endl;
165166
ss << " \"TF32 Disabled\": " << disable_tf32 << std::endl;
167+
ss << " \"Sparsity\": " << sparse_weights << std::endl;
166168
ss << " \"Refit\": " << refit << std::endl;
167169
ss << " \"Debug\": " << debug << std::endl;
168170
ss << " \"Strict Types\": " << strict_types << std::endl;

py/trtorch/csrc/tensorrt_classes.h

+2
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ struct CompileSpec : torch::CustomClassHolder {
126126

127127
ADD_ENUM_GET_SET(op_precision, DataType, static_cast<int64_t>(DataType::kChar));
128128
ADD_FIELD_GET_SET(disable_tf32, bool);
129+
ADD_FIELD_GET_SET(sparse_weights, bool);
129130
ADD_FIELD_GET_SET(refit, bool);
130131
ADD_FIELD_GET_SET(debug, bool);
131132
ADD_FIELD_GET_SET(strict_types, bool);
@@ -142,6 +143,7 @@ struct CompileSpec : torch::CustomClassHolder {
142143
std::vector<InputRange> input_ranges;
143144
nvinfer1::IInt8Calibrator* ptq_calibrator = nullptr;
144145
DataType op_precision = DataType::kFloat;
146+
bool sparse_weights = false;
145147
bool disable_tf32 = false;
146148
bool refit = false;
147149
bool debug = false;

py/trtorch/csrc/trtorch_py.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ PYBIND11_MODULE(_C, m) {
244244
.def_readwrite("op_precision", &CompileSpec::op_precision)
245245
.def_readwrite("ptq_calibrator", &CompileSpec::ptq_calibrator)
246246
.def_readwrite("refit", &CompileSpec::refit)
247+
.def_readwrite("sparse_weights", &CompileSpec::sparse_weights)
247248
.def_readwrite("disable_tf32", &CompileSpec::disable_tf32)
248249
.def_readwrite("debug", &CompileSpec::debug)
249250
.def_readwrite("strict_types", &CompileSpec::strict_types)

0 commit comments

Comments
 (0)