Skip to content

Commit ca43432

Browse files
committed
address review comments
1 parent 8ff3e54 commit ca43432

File tree

3 files changed

+78
-71
lines changed

3 files changed

+78
-71
lines changed

core/conversion/evaluators/aten.cpp

-71
Original file line numberDiff line numberDiff line change
@@ -125,77 +125,6 @@ DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(
125125
int64_t,
126126
{"aten::__round_to_zero_floordiv(int a, int b) -> (int)"});
127127

128-
// Resolves the target sizes and TensorOptions for aten::new_* style evaluators.
// Output tensors are always strided and placed on CUDA.
//   n    - the JIT node being evaluated; input 0 is the self tensor, input 1 the
//          requested size list, input 2 an optional dtype.
//   args - evaluated arguments for the node's inputs.
// Returns {sizes, options} for the tensor to be created by the caller.
std::pair<std::vector<int64_t>, torch::TensorOptions> newTensorImplementation(const torch::jit::Node* n, kwargs& args) {
  auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);

  // Input 2 is the dtype
  if (!args.at(n->input(2)).isNone() && !args.at(n->input(2)).IValue()->isNone()) {
    // Explicit dtype was supplied as an integer ScalarType code.
    options = options.dtype(c10::ScalarType(args.at(n->input(2)).unwrapToInt()));
  } else {
    // No explicit dtype: inherit it from the self tensor (input 0), which may
    // be either a TRT ITensor or an evaluated at::Tensor.
    auto tensor_var = args.at(n->input(0));
    if (tensor_var.isITensor()) {
      auto tensor = tensor_var.ITensor();
      options = options.dtype(scalarTypeToTypeMeta(util::TRTDataTypeToScalarType(tensor->getType())));
    } else {
      auto tensor = tensor_var.unwrapToTensor();
      options = options.dtype(tensor.dtype());
    }
  }
  // Input 1 carries the requested sizes.
  return std::make_pair(args.at(n->input(1)).unwrapToIntList().vec(), options);
}
146-
147-
// Shared implementation for aten::*_like style evaluators. Determines the
// output dtype and shape from the self tensor (input 0) and the optional dtype
// argument (input 1), then delegates element filling to `tensor_builder`.
// For dynamic-shape inputs, instead of returning an IValue, the result is
// emitted directly into the TRT network and associated with the node's output.
//   ctx            - conversion context owning the TRT network being built.
//   n              - the JIT node being evaluated.
//   args           - evaluated arguments for the node's inputs.
//   tensor_builder - callback that materializes a tensor of the given sizes
//                    and options (e.g. zeros/ones/full).
// Returns the built tensor, or an empty optional when the output was
// registered as a TRT tensor on the dynamic path.
c10::optional<torch::jit::IValue> newTensorLikeImplementation(
    ConversionCtx* ctx,
    const torch::jit::Node* n,
    kwargs& args,
    const std::function<torch::Tensor(const std::vector<int64_t>&, const torch::TensorOptions&)>& tensor_builder) {
  auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);
  auto tensor_var = args.at(n->input(0));

  // Default the dtype to that of the self tensor (ITensor or at::Tensor).
  if (tensor_var.isITensor()) {
    auto tensor = tensor_var.ITensor();
    auto dtype = util::TRTDataTypeToScalarType(tensor->getType());
    options = options.dtype(dtype);
  } else {
    auto tensor = tensor_var.unwrapToTensor();
    options = options.dtype(tensor.dtype());
  }

  // Input 1 is the dtype
  // When provided it overrides the dtype inherited from the self tensor above.
  if (!args.at(n->input(1)).isNone() && !args.at(n->input(1)).IValue()->isNone()) {
    options = options.dtype(c10::ScalarType(args.at(n->input(1)).unwrapToInt()));
  }
  // Static shape of self; used by the non-dynamic fast path at the end.
  std::vector<int64_t> tensor_dims;
  if (tensor_var.isITensor()) {
    auto tensor = tensor_var.ITensor();
    tensor_dims = util::toVec(tensor->getDimensions());
  } else {
    auto tensor = tensor_var.unwrapToTensor();
    tensor_dims = tensor.sizes().vec();
  }
  if (ctx->input_is_dynamic) {
    // Dynamic path: build a rank-matched constant with every dimension equal
    // to 1, then broadcast it to self's runtime shape inside the network.
    auto self = args.at(n->input(0)).ITensorOrFreeze(ctx);
    std::vector<int64_t> dims_vec(self->getDimensions().nbDims, 1);
    auto constant = tensor_builder(dims_vec, options);
    auto constant_itensor = converters::tensor_to_const(ctx, constant);
    // broadcast constant to output shape
    std::vector<int64_t> start_vec(self->getDimensions().nbDims, 0);
    auto start_offset = util::toDims(c10::IntArrayRef(start_vec));
    auto shape_layer = ctx->net->addShape(*self);
    TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
    shape_layer->setName((util::node_info(n) + "_shape").c_str());
    // slice implements expand
    // (start and stride are both all-zeros; the static size argument is
    // superseded by the runtime shape wired into input 2 below)
    auto slice_layer = ctx->net->addSlice(*constant_itensor, start_offset, self->getDimensions(), start_offset);
    TORCHTRT_CHECK(slice_layer, "Unable to create slice layer from node: " << *n);
    slice_layer->setInput(2, *shape_layer->getOutput(0));
    slice_layer->setName((util::node_info(n) + "_slice").c_str());
    auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], slice_layer->getOutput(0));
    LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
    // Output already registered as a TRT tensor; nothing to evaluate.
    return {};
  }
  return tensor_builder(tensor_dims, options);
}
198-
199128
auto aten_registrations TORCHTRT_UNUSED =
200129
RegisterNodeEvaluators()
201130
.evaluator(

core/conversion/evaluators/eval_util.cpp

+71
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,77 @@ at::Tensor createTensorFromList(
349349
return tensor;
350350
}
351351

352+
// Resolves the target sizes and TensorOptions for aten::new_* style
// evaluators. Output tensors are always strided and live on CUDA.
//   n    - the JIT node; input 0 is the self tensor, input 1 the size list,
//          input 2 an optional dtype.
//   args - evaluated arguments for the node's inputs.
// Returns {sizes, options} for the caller to build the tensor from.
std::pair<std::vector<int64_t>, torch::TensorOptions> newTensorImplementation(const torch::jit::Node* n, kwargs& args) {
  auto opts = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);

  // Input 2 is the dtype; when absent, fall back to the dtype of the self
  // tensor (input 0), which may be a TRT ITensor or an evaluated at::Tensor.
  auto dtype_arg = args.at(n->input(2));
  if (dtype_arg.isNone() || dtype_arg.IValue()->isNone()) {
    auto self_arg = args.at(n->input(0));
    if (self_arg.isITensor()) {
      opts = opts.dtype(scalarTypeToTypeMeta(util::TRTDataTypeToScalarType(self_arg.ITensor()->getType())));
    } else {
      opts = opts.dtype(self_arg.unwrapToTensor().dtype());
    }
  } else {
    // Explicit dtype arrives as an integer ScalarType code.
    opts = opts.dtype(c10::ScalarType(dtype_arg.unwrapToInt()));
  }

  // Input 1 carries the requested sizes.
  return {args.at(n->input(1)).unwrapToIntList().vec(), opts};
}
370+
371+
// Shared implementation for aten::*_like style evaluators. Determines the
// output dtype and shape from the self tensor (input 0) and the optional dtype
// argument (input 1), then delegates element filling to `tensor_builder`.
// When shape tensors are allowed and the input is dynamic, the result is
// emitted directly into the TRT network and associated with the node's output
// instead of being returned as an IValue.
//   ctx            - conversion context owning the TRT network being built.
//   n              - the JIT node being evaluated.
//   args           - evaluated arguments for the node's inputs.
//   tensor_builder - callback that materializes a tensor of the given sizes
//                    and options (e.g. zeros/ones/full).
// Returns the built tensor, or an empty optional when the output was
// registered as a TRT tensor on the dynamic path.
c10::optional<torch::jit::IValue> newTensorLikeImplementation(
    ConversionCtx* ctx,
    const torch::jit::Node* n,
    kwargs& args,
    const std::function<torch::Tensor(const std::vector<int64_t>&, const torch::TensorOptions&)>& tensor_builder) {
  auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);
  auto tensor_var = args.at(n->input(0));

  // Default the dtype to that of the self tensor (ITensor or at::Tensor).
  if (tensor_var.isITensor()) {
    auto tensor = tensor_var.ITensor();
    auto dtype = util::TRTDataTypeToScalarType(tensor->getType());
    options = options.dtype(dtype);
  } else {
    auto tensor = tensor_var.unwrapToTensor();
    options = options.dtype(tensor.dtype());
  }

  // Input 1 is the dtype
  // When provided it overrides the dtype inherited from the self tensor above.
  if (!args.at(n->input(1)).isNone() && !args.at(n->input(1)).IValue()->isNone()) {
    options = options.dtype(c10::ScalarType(args.at(n->input(1)).unwrapToInt()));
  }
  // Static shape of self; used by the non-dynamic fast path at the end.
  std::vector<int64_t> tensor_dims;
  if (tensor_var.isITensor()) {
    auto tensor = tensor_var.ITensor();
    tensor_dims = util::toVec(tensor->getDimensions());
  } else {
    auto tensor = tensor_var.unwrapToTensor();
    tensor_dims = tensor.sizes().vec();
  }
  // Dynamic path is only taken when the user opted into shape tensors.
  if (ctx->settings.allow_shape_tensors && ctx->input_is_dynamic) {
    // Build a rank-matched constant with every dimension equal to 1, then
    // broadcast it to self's runtime shape inside the network.
    auto self = args.at(n->input(0)).ITensorOrFreeze(ctx);
    std::vector<int64_t> dims_vec(self->getDimensions().nbDims, 1);
    auto constant = tensor_builder(dims_vec, options);
    auto constant_itensor = converters::tensor_to_const(ctx, constant);
    // broadcast constant to output shape
    std::vector<int64_t> start_vec(self->getDimensions().nbDims, 0);
    auto start_offset = util::toDims(c10::IntArrayRef(start_vec));
    auto shape_layer = ctx->net->addShape(*self);
    TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
    shape_layer->setName((util::node_info(n) + "_shape").c_str());
    // slice implements expand
    // (start and stride are both all-zeros; the static size argument is
    // superseded by the runtime shape wired into input 2 below)
    auto slice_layer = ctx->net->addSlice(*constant_itensor, start_offset, self->getDimensions(), start_offset);
    TORCHTRT_CHECK(slice_layer, "Unable to create slice layer from node: " << *n);
    slice_layer->setInput(2, *shape_layer->getOutput(0));
    slice_layer->setName((util::node_info(n) + "_slice").c_str());
    auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], slice_layer->getOutput(0));
    LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
    // Output already registered as a TRT tensor; nothing to evaluate.
    return {};
  }
  return tensor_builder(tensor_dims, options);
}
422+
352423
} // namespace evaluators
353424
} // namespace conversion
354425
} // namespace core

core/conversion/evaluators/eval_util.h

+7
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@ int64_t normalizeIndex(int64_t idx, int64_t list_size);
2626

2727
at::Tensor scalar_to_tensor(const at::Scalar& s, const at::Device device = at::kCPU);
2828

29+
// Resolves the target sizes and TensorOptions for aten::new_* style
// evaluators; input 1 is the size list, input 2 an optional dtype.
std::pair<std::vector<int64_t>, torch::TensorOptions> newTensorImplementation(const torch::jit::Node* n, kwargs& args);
// Shared implementation for aten::*_like style evaluators; `tensor_builder`
// materializes the filled tensor for a given shape and options. Returns an
// empty optional when the output is emitted directly as TRT layers
// (dynamic-shape path).
c10::optional<torch::jit::IValue> newTensorLikeImplementation(
    ConversionCtx* ctx,
    const torch::jit::Node* n,
    kwargs& args,
    const std::function<torch::Tensor(const std::vector<int64_t>&, const torch::TensorOptions&)>& tensor_builder);
35+
2936
} // namespace evaluators
3037
} // namespace conversion
3138
} // namespace core

0 commit comments

Comments
 (0)