
Commit a052cf0

Merge pull request #1847 from mfeliz-cruise/michael.feliz/dynamic_zeros_like
feat: Add support for dynamic zeros_like and ones_like
2 parents 92b37e7 + b236435 commit a052cf0

File tree: 5 files changed, +327 -0 lines changed

core/conversion/evaluators/aten.cpp (+57)
@@ -157,6 +157,63 @@ auto aten_registrations TORCHTRT_UNUSED =
               auto out_tensor = torch::ones(args.at(n->input(0)).unwrapToIntList().vec(), options);
               return out_tensor;
             }})
+        .evaluator(
+            {c10::Symbol::fromQualString("aten::new_zeros"),
+             // aten::new_zeros(Tensor self, int[] size, *, int? dtype=None, int? layout=None,
+             // Device? device=None, bool? pin_memory=None) -> (Tensor)
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+               auto tensor_info = newTensorImplementation(n, args);
+               return torch::zeros(tensor_info.first, tensor_info.second);
+             }})
+        .evaluator(
+            {c10::Symbol::fromQualString("aten::new_ones"),
+             // aten::new_ones(Tensor self, int[] size, *, int? dtype=None, int? layout=None,
+             // Device? device=None, bool? pin_memory=None) -> (Tensor)
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+               auto tensor_info = newTensorImplementation(n, args);
+               return torch::ones(tensor_info.first, tensor_info.second);
+             }})
+        .evaluator(
+            {c10::Symbol::fromQualString("aten::zeros_like"),
+             // aten::zeros_like(Tensor self, *, int? dtype=None, int? layout=None,
+             // Device? device=None, bool? pin_memory=None, int? memory_format=None) -> (Tensor)
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+               return newTensorLikeImplementation(
+                   ctx, n, args, [](const std::vector<int64_t>& dims, const torch::TensorOptions& options) {
+                     return torch::zeros(dims, options);
+                   });
+             }})
+        .evaluator(
+            {c10::Symbol::fromQualString("aten::ones_like"),
+             // aten::ones_like(Tensor self, *, int? dtype=None, int? layout=None,
+             // Device? device=None, bool? pin_memory=None, int? memory_format=None) -> (Tensor)
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+               return newTensorLikeImplementation(
+                   ctx, n, args, [](const std::vector<int64_t>& dims, const torch::TensorOptions& options) {
+                     return torch::ones(dims, options);
+                   });
+             }})
+        .evaluator(
+            {c10::Symbol::fromQualString("aten::fill_"),
+             // aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> (Tensor(a!))
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+               auto tensor_var = args.at(n->input(0));
+               auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);
+               std::vector<int64_t> dims;
+               if (tensor_var.isITensor()) {
+                 auto tensor = tensor_var.ITensor();
+                 auto dtype = util::TRTDataTypeToScalarType(tensor->getType());
+                 options = options.dtype(dtype);
+                 dims = util::toVec(tensor->getDimensions());
+               } else {
+                 auto tensor = tensor_var.unwrapToTensor();
+                 options = options.dtype(tensor.dtype());
+                 dims = tensor.sizes().vec();
+               }
+               auto scalar_value = args.at(n->input(1)).unwrapToScalar();
+               auto out_tensor = torch::full(dims, scalar_value, options);
+               return out_tensor;
+             }})
         .evaluator(
             {c10::Symbol::fromQualString("aten::full"),
              // aten::full(int[] size, Scalar fill_value, *, int? dtype=None, int? layout=None,
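
These evaluators run at conversion time: for statically shaped inputs they materialize the constant with libtorch instead of emitting TensorRT layers. A minimal standalone sketch of what the static `aten::fill_` path computes (plain libtorch; the input values here are illustrative assumptions, not part of the diff):

#include <torch/torch.h>

int main() {
  // Stand-in for the evaluator's arguments: `self` supplies dims and dtype,
  // the scalar argument supplies the fill value.
  auto self = torch::randint(1, 10, {1, 5, 5, 5}, torch::kCUDA);
  auto options =
      torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA).dtype(self.dtype());
  // Same call the evaluator makes: a full tensor with self's shape and dtype.
  auto out = torch::full(self.sizes().vec(), /*fill_value=*/3, options);
  return 0;
}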

core/conversion/evaluators/eval_util.cpp (+71)
@@ -367,6 +367,77 @@ at::Tensor createTensorFromList(
   return tensor;
 }
 
+std::pair<std::vector<int64_t>, torch::TensorOptions> newTensorImplementation(const torch::jit::Node* n, kwargs& args) {
+  auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);
+
+  // Input 2 is the dtype
+  if (!args.at(n->input(2)).isNone() && !args.at(n->input(2)).IValue()->isNone()) {
+    options = options.dtype(c10::ScalarType(args.at(n->input(2)).unwrapToInt()));
+  } else {
+    auto tensor_var = args.at(n->input(0));
+    if (tensor_var.isITensor()) {
+      auto tensor = tensor_var.ITensor();
+      options = options.dtype(scalarTypeToTypeMeta(util::TRTDataTypeToScalarType(tensor->getType())));
+    } else {
+      auto tensor = tensor_var.unwrapToTensor();
+      options = options.dtype(tensor.dtype());
+    }
+  }
+  return std::make_pair(args.at(n->input(1)).unwrapToIntList().vec(), options);
+}
+
+c10::optional<torch::jit::IValue> newTensorLikeImplementation(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    kwargs& args,
+    const std::function<torch::Tensor(const std::vector<int64_t>&, const torch::TensorOptions&)>& tensor_builder) {
+  auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);
+  auto tensor_var = args.at(n->input(0));
+
+  if (tensor_var.isITensor()) {
+    auto tensor = tensor_var.ITensor();
+    auto dtype = util::TRTDataTypeToScalarType(tensor->getType());
+    options = options.dtype(dtype);
+  } else {
+    auto tensor = tensor_var.unwrapToTensor();
+    options = options.dtype(tensor.dtype());
+  }
+
+  // Input 1 is the dtype
+  if (!args.at(n->input(1)).isNone() && !args.at(n->input(1)).IValue()->isNone()) {
+    options = options.dtype(c10::ScalarType(args.at(n->input(1)).unwrapToInt()));
+  }
+  std::vector<int64_t> tensor_dims;
+  if (tensor_var.isITensor()) {
+    auto tensor = tensor_var.ITensor();
+    tensor_dims = util::toVec(tensor->getDimensions());
+  } else {
+    auto tensor = tensor_var.unwrapToTensor();
+    tensor_dims = tensor.sizes().vec();
+  }
+  if (ctx->settings.allow_shape_tensors && ctx->input_is_dynamic) {
+    auto self = args.at(n->input(0)).ITensorOrFreeze(ctx);
+    std::vector<int64_t> dims_vec(self->getDimensions().nbDims, 1);
+    auto constant = tensor_builder(dims_vec, options);
+    auto constant_itensor = converters::tensor_to_const(ctx, constant);
+    // broadcast constant to output shape
+    std::vector<int64_t> start_vec(self->getDimensions().nbDims, 0);
+    auto start_offset = util::toDims(c10::IntArrayRef(start_vec));
+    auto shape_layer = ctx->net->addShape(*self);
+    TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
+    shape_layer->setName((util::node_info(n) + "_shape").c_str());
+    // slice implements expand
+    auto slice_layer = ctx->net->addSlice(*constant_itensor, start_offset, self->getDimensions(), start_offset);
+    TORCHTRT_CHECK(slice_layer, "Unable to create slice layer from node: " << *n);
+    slice_layer->setInput(2, *shape_layer->getOutput(0));
+    slice_layer->setName((util::node_info(n) + "_slice").c_str());
+    auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], slice_layer->getOutput(0));
+    LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+    return {};
+  }
+  return tensor_builder(tensor_dims, options);
+}
+
 } // namespace evaluators
 } // namespace conversion
 } // namespace core
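
For dynamic inputs, the helper above builds a one-element constant and expands it to the runtime shape of `self`: a shape layer feeds the slice layer's size input (input 2), and zero strides re-read the single element along every axis. A hedged standalone sketch of that broadcast-by-slice pattern against the raw TensorRT API (the `network`, `constant`, and `self` handles are assumptions for illustration):

#include "NvInfer.h"

// Expand a constant whose dimensions are all 1 to the runtime shape of `self`.
nvinfer1::ITensor* broadcast_like(
    nvinfer1::INetworkDefinition* network,
    nvinfer1::ITensor* constant, // every dimension has size 1
    nvinfer1::ITensor* self) {
  // The runtime shape of `self` becomes the slice size.
  auto* shape_layer = network->addShape(*self);
  // Zero starts and zero strides: stride 0 repeats the lone element along
  // each axis, so the slice acts as an expand.
  nvinfer1::Dims zeros = self->getDimensions();
  for (int32_t i = 0; i < zeros.nbDims; i++) {
    zeros.d[i] = 0;
  }
  auto* slice_layer = network->addSlice(*constant, zeros, self->getDimensions(), zeros);
  slice_layer->setInput(2, *shape_layer->getOutput(0)); // dynamic size override
  return slice_layer->getOutput(0);
}

Returning `{}` from the evaluator after `AssociateValueAndTensor` signals that the output is already bound to a TensorRT tensor, so no compile-time IValue is produced on this path.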

core/conversion/evaluators/eval_util.h (+8)
@@ -2,6 +2,7 @@
 
 #include "core/conversion/evaluators/evaluators.h"
 #include "torch/csrc/jit/ir/ir.h"
+#include "torch/torch.h"
 
 namespace torch_tensorrt {
 namespace core {
@@ -26,6 +27,13 @@ int64_t normalizeIndex(int64_t idx, int64_t list_size);
 
 at::Tensor scalar_to_tensor(const at::Scalar& s, const at::Device device = at::kCPU);
 
+std::pair<std::vector<int64_t>, torch::TensorOptions> newTensorImplementation(const torch::jit::Node* n, kwargs& args);
+c10::optional<torch::jit::IValue> newTensorLikeImplementation(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    kwargs& args,
+    const std::function<torch::Tensor(const std::vector<int64_t>&, const torch::TensorOptions&)>& tensor_builder);
+
 } // namespace evaluators
 } // namespace conversion
 } // namespace core
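
The `tensor_builder` callback is what lets `zeros_like` and `ones_like` share all of the shape and dtype plumbing. A simplified, hedged sketch of the pattern using libtorch only (the `build_like` helper and its signature are illustrative; the real functions also take the conversion context and JIT node, as declared above):

#include <torch/torch.h>

#include <functional>
#include <vector>

using TensorBuilder =
    std::function<torch::Tensor(const std::vector<int64_t>&, const torch::TensorOptions&)>;

// Simplified stand-in for newTensorLikeImplementation's static path.
torch::Tensor build_like(const torch::Tensor& self, const TensorBuilder& builder) {
  auto options =
      torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA).dtype(self.dtype());
  return builder(self.sizes().vec(), options);
}

// Usage: the same plumbing serves both ops.
// auto z = build_like(x, [](const auto& dims, const auto& opts) { return torch::zeros(dims, opts); });
// auto o = build_like(x, [](const auto& dims, const auto& opts) { return torch::ones(dims, opts); });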

tests/core/conversion/evaluators/test_aten_evaluators.cpp (+190)
@@ -207,6 +207,196 @@ TEST(Evaluators, ZerosDataTypeEvaluatesCorrectly) {
   ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
 }
 
+TEST(Evaluators, NewZerosEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : None = prim::Constant() # :0:0
+        %3 : int[] = aten::size(%x.1) # <string>:7:9
+        %z.1 : Tensor = aten::new_zeros(%x.1, %3, %2, %2, %2, %2)
+        return (%z.1))IR";
+
+  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {in});
+
+  ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
+}
+
+TEST(Evaluators, NewZerosDataTypeEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=5]() # :0:0 (Float16)
+        %3 : None = prim::Constant() # :0:0
+        %4 : int[] = aten::size(%x.1) # <string>:7:9
+        %z.1 : Tensor = aten::new_zeros(%x.1, %4, %2, %3, %3, %3)
+        return (%z.1))IR";
+
+  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {in});
+
+  ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
+}
+
+TEST(Evaluators, NewOnesEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : None = prim::Constant() # :0:0
+        %3 : int[] = aten::size(%x.1) # <string>:7:9
+        %z.1 : Tensor = aten::new_ones(%x.1, %3, %2, %2, %2, %2)
+        return (%z.1))IR";
+
+  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {in});
+
+  ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
+}
+
+TEST(Evaluators, NewOnesDataTypeEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=5]() # :0:0 (Float16)
+        %3 : None = prim::Constant() # :0:0
+        %4 : int[] = aten::size(%x.1) # <string>:7:9
+        %z.1 : Tensor = aten::new_ones(%x.1, %4, %2, %3, %3, %3)
+        return (%z.1))IR";
+
+  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {in});
+
+  ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
+}
+
+TEST(Evaluators, ZerosLikeEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : None = prim::Constant() # :0:0
+        %z.1 : Tensor = aten::zeros_like(%x.1, %2, %2, %2, %2, %2)
+        return (%z.1))IR";
+
+  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {in});
+
+  ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
+}
+
+TEST(Evaluators, ZerosLikeDataTypeEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=5]() # :0:0 (Float16)
+        %3 : None = prim::Constant()
+        %z.1 : Tensor = aten::zeros_like(%x.1, %2, %3, %3, %3, %3)
+        return (%z.1))IR";
+
+  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {in});
+
+  ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
+}
+
+TEST(Evaluators, ZerosLikeDynamic) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=5]() # :0:0 (Float16)
+        %3 : None = prim::Constant()
+        %z.1 : Tensor = aten::zeros_like(%x.1, %2, %3, %3, %3, %3)
+        return (%z.1))IR";
+  auto in = at::randint(1, 10, {23, 17, 5, 29}, {at::kCUDA});
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true, true);
+
+  ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0]));
+}
+
+TEST(Evaluators, OnesLikeEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : None = prim::Constant() # :0:0
+        %z.1 : Tensor = aten::ones_like(%x.1, %2, %2, %2, %2, %2)
+        return (%z.1))IR";
+
+  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {in});
+
+  ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
+}
+
+TEST(Evaluators, OnesLikeDataTypeEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=5]() # :0:0 (Float16)
+        %3 : None = prim::Constant()
+        %z.1 : Tensor = aten::ones_like(%x.1, %2, %3, %3, %3, %3)
+        return (%z.1))IR";
+
+  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {in});
+
+  ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
+}
+
+TEST(Evaluators, OnesLikeDynamic) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=5]() # :0:0 (Float16)
+        %3 : None = prim::Constant()
+        %z.1 : Tensor = aten::ones_like(%x.1, %2, %3, %3, %3, %3)
+        return (%z.1))IR";
+  auto in = at::randint(1, 10, {3, 6}, {at::kCUDA});
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true, true);
+
+  ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0]));
+}
+
 TEST(Evaluators, ATenArangeIntEvaluatesCorrectly) {
   const auto graph = R"IR(
     graph():
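
In the `DataType` and `Dynamic` graphs above, `prim::Constant[value=5]` passes a dtype as its integer encoding; `c10::ScalarType` value 5 decodes to `Half`, matching the `(Float16)` comments in the IR. A minimal standalone check of that mapping:

#include <c10/core/ScalarType.h>

#include <iostream>

int main() {
  // The dtype argument in the IR graphs is an int; 5 maps to Half (Float16).
  auto st = static_cast<c10::ScalarType>(5);
  std::cout << std::boolalpha << (st == c10::ScalarType::Half) << "\n"; // true
  return 0;
}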

tests/core/partitioning/test_loop_fallback.cpp (+1)
@@ -53,6 +53,7 @@ TEST(Partitioning, CheckLoopFallbackNoEvalCompilesCorrectly) {
 
   std::vector<torch_tensorrt::core::ir::Input> input_ranges{torch_tensorrt::core::ir::Input({1, 10})};
   torch_tensorrt::core::CompileSpec cfg(input_ranges);
+  cfg.partitioning_info.forced_fallback_operators.push_back("aten::ones_like");
   cfg.partitioning_info.enabled = true;
 
   auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
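
Since `aten::ones_like` is now evaluated at conversion time, this test presumably has to force the op back to Torch so the loop-fallback path is still exercised. The relevant `CompileSpec` knob in isolation (mirrors the one-line addition above; the input shape is carried over from the test):

// Force a specific op to run in Torch instead of being converted.
std::vector<torch_tensorrt::core::ir::Input> input_ranges{torch_tensorrt::core::ir::Input({1, 10})};
torch_tensorrt::core::CompileSpec cfg(input_ranges);
cfg.partitioning_info.enabled = true;
cfg.partitioning_info.forced_fallback_operators.push_back("aten::ones_like");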
