Skip to content

Commit c83aa15

Browse files
committed
fix: truncate_long_and_double incurs torchscript inference issues
Signed-off-by: Bo Wang <[email protected]>
1 parent 17490b1 commit c83aa15

File tree

1 file changed

+17
-9
lines changed

1 file changed

+17
-9
lines changed

core/partitioning/shape_analysis.cpp

+17-9
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,13 @@ void getSegmentsOutputByRunning(
8181
jit_inputs_ivalues.push_back(ivalues_maps[input].toList());
8282
} else if (input->type()->kind() == torch::jit::TypeKind::TupleType) {
8383
jit_inputs_ivalues.push_back(ivalues_maps[input].toTuple());
84+
} else if (input->type()->kind() == torch::jit::TypeKind::NumberType) {
85+
jit_inputs_ivalues.push_back(ivalues_maps[input].toScalar());
8486
} else {
85-
TORCHTRT_THROW_ERROR("Unable to find type for value: " << input->debugName() << " to get the ivalues.\n");
87+
TORCHTRT_THROW_ERROR(
88+
"Unable to find type for value: " << input->debugName()
89+
<< " to get the ivalues. The type for this value should be "
90+
<< input->type()->str() << " \n");
8691
}
8792
}
8893

@@ -110,28 +115,31 @@ void getSegmentsOutputByRunning(
110115
for (auto& i : seg_block.raw_inputs()) {
111116
if (ivalues_maps[i].isTensor()) {
112117
// set the input_shape and data_type
113-
at::ScalarType t = ivalues_maps[i].toTensor().scalar_type();
118+
// we can use a temp value here instead of replacing the values in ivalues_map since we only use ivalues_map for
119+
// shape inference
120+
auto cur_ivalue = ivalues_maps[i];
121+
at::ScalarType t = cur_ivalue.toTensor().scalar_type();
114122
if (!partition_info.truncate_long_and_double && (t == at::kLong || t == at::kDouble)) {
115123
TORCHTRT_THROW_ERROR(
116124
"Unable to process subgraph input type of at::kLong/at::kDouble, try to compile model with truncate_long_and_double enabled");
117125
} else if (partition_info.truncate_long_and_double && t == at::kLong) {
118-
ivalues_maps[i] = ivalues_maps[i].toTensor().to(at::kInt);
126+
cur_ivalue = cur_ivalue.toTensor().to(at::kInt);
119127
LOG_WARNING("Truncating graph input type from at::kLong to at::kInt");
120128
} else if (partition_info.truncate_long_and_double && t == at::kDouble) {
121-
ivalues_maps[i] = ivalues_maps[i].toTensor().to(at::kFloat);
129+
cur_ivalue = cur_ivalue.toTensor().to(at::kFloat);
122130
LOG_WARNING("Truncating graph input type from at::kDouble to at::kFloat");
123131
}
124-
c10::optional<nvinfer1::DataType> dtype = util::optTypeMetaToTRTDataType(ivalues_maps[i].toTensor().dtype());
132+
c10::optional<nvinfer1::DataType> dtype = util::optTypeMetaToTRTDataType(cur_ivalue.toTensor().dtype());
125133
if (dtype == c10::nullopt) {
126-
TORCHTRT_THROW_ERROR("Unsupported input data type " << ivalues_maps[i].toTensor().dtype());
134+
TORCHTRT_THROW_ERROR("Unsupported input data type " << cur_ivalue.toTensor().dtype());
127135
}
128-
if (ivalues_maps[i].toTensor().sizes().size() == 0) {
136+
if (cur_ivalue.toTensor().sizes().size() == 0) {
129137
// handle Scalar types, which has sizes of []
130138
input_shapes.push_back(util::toVec(util::toDims(c10::List<long int>({1}))));
131139
} else {
132-
input_shapes.push_back(util::toVec(util::toDims(ivalues_maps[i].toTensor().sizes())));
140+
input_shapes.push_back(util::toVec(util::toDims(cur_ivalue.toTensor().sizes())));
133141
}
134-
input_types.push_back(ivalues_maps[i].toTensor().scalar_type());
142+
input_types.push_back(cur_ivalue.toTensor().scalar_type());
135143
}
136144
}
137145

0 commit comments

Comments (0)