@@ -207,6 +207,196 @@ TEST(Evaluators, ZerosDataTypeEvaluatesCorrectly) {
207
207
ASSERT_TRUE (at::equal (jit_results[0 ].toTensor ().to (at::kCUDA ), trt_results[0 ].toTensor ()));
208
208
}
209
209
210
+ TEST (Evaluators, NewZerosEvaluatesCorrectly) {
211
+ const auto graph = R"IR(
212
+ graph(%x.1 : Tensor):
213
+ %2 : None = prim::Constant() # :0:0
214
+ %3 : int[] = aten::size(%x.1) # <string>:7:9
215
+ %z.1 : Tensor = aten::new_zeros(%x.1, %3, %2, %2, %2, %2)
216
+ return (%z.1))IR" ;
217
+
218
+ auto in = at::randint (1 , 10 , {1 , 5 , 5 , 5 }, {at::kCUDA });
219
+
220
+ auto g = std::make_shared<torch::jit::Graph>();
221
+ torch::jit::parseIR (graph, g.get ());
222
+
223
+ auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT (g, {in});
224
+ auto trt_results = torch_tensorrt::tests::util::EvaluateGraph (g->block (), {in});
225
+
226
+ ASSERT_TRUE (at::equal (jit_results[0 ].toTensor ().to (at::kCUDA ), trt_results[0 ].toTensor ()));
227
+ }
228
+
229
+ TEST (Evaluators, NewZerosDataTypeEvaluatesCorrectly) {
230
+ const auto graph = R"IR(
231
+ graph(%x.1 : Tensor):
232
+ %2 : int = prim::Constant[value=5]() # :0:0 (Float16)
233
+ %3 : None = prim::Constant() # :0:0
234
+ %4 : int[] = aten::size(%x.1) # <string>:7:9
235
+ %z.1 : Tensor = aten::new_zeros(%x.1, %4, %2, %3, %3, %3)
236
+ return (%z.1))IR" ;
237
+
238
+ auto in = at::randint (1 , 10 , {1 , 5 , 5 , 5 }, {at::kCUDA });
239
+
240
+ auto g = std::make_shared<torch::jit::Graph>();
241
+ torch::jit::parseIR (graph, g.get ());
242
+
243
+ auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT (g, {in});
244
+ auto trt_results = torch_tensorrt::tests::util::EvaluateGraph (g->block (), {in});
245
+
246
+ ASSERT_TRUE (at::equal (jit_results[0 ].toTensor ().to (at::kCUDA ), trt_results[0 ].toTensor ()));
247
+ }
248
+
249
+ TEST (Evaluators, NewOnesEvaluatesCorrectly) {
250
+ const auto graph = R"IR(
251
+ graph(%x.1 : Tensor):
252
+ %2 : None = prim::Constant() # :0:0
253
+ %3 : int[] = aten::size(%x.1) # <string>:7:9
254
+ %z.1 : Tensor = aten::new_ones(%x.1, %3, %2, %2, %2, %2)
255
+ return (%z.1))IR" ;
256
+
257
+ auto in = at::randint (1 , 10 , {1 , 5 , 5 , 5 }, {at::kCUDA });
258
+
259
+ auto g = std::make_shared<torch::jit::Graph>();
260
+ torch::jit::parseIR (graph, g.get ());
261
+
262
+ auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT (g, {in});
263
+ auto trt_results = torch_tensorrt::tests::util::EvaluateGraph (g->block (), {in});
264
+
265
+ ASSERT_TRUE (at::equal (jit_results[0 ].toTensor ().to (at::kCUDA ), trt_results[0 ].toTensor ()));
266
+ }
267
+
268
+ TEST (Evaluators, NewOnesDataTypeEvaluatesCorrectly) {
269
+ const auto graph = R"IR(
270
+ graph(%x.1 : Tensor):
271
+ %2 : int = prim::Constant[value=5]() # :0:0 (Float16)
272
+ %3 : None = prim::Constant() # :0:0
273
+ %4 : int[] = aten::size(%x.1) # <string>:7:9
274
+ %z.1 : Tensor = aten::new_ones(%x.1, %4, %2, %3, %3, %3)
275
+ return (%z.1))IR" ;
276
+
277
+ auto in = at::randint (1 , 10 , {1 , 5 , 5 , 5 }, {at::kCUDA });
278
+
279
+ auto g = std::make_shared<torch::jit::Graph>();
280
+ torch::jit::parseIR (graph, g.get ());
281
+
282
+ auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT (g, {in});
283
+ auto trt_results = torch_tensorrt::tests::util::EvaluateGraph (g->block (), {in});
284
+
285
+ ASSERT_TRUE (at::equal (jit_results[0 ].toTensor ().to (at::kCUDA ), trt_results[0 ].toTensor ()));
286
+ }
287
+
288
+ TEST (Evaluators, ZerosLikeEvaluatesCorrectly) {
289
+ const auto graph = R"IR(
290
+ graph(%x.1 : Tensor):
291
+ %2 : None = prim::Constant() # :0:0
292
+ %z.1 : Tensor = aten::zeros_like(%x.1, %2, %2, %2, %2, %2)
293
+ return (%z.1))IR" ;
294
+
295
+ auto in = at::randint (1 , 10 , {1 , 5 , 5 , 5 }, {at::kCUDA });
296
+
297
+ auto g = std::make_shared<torch::jit::Graph>();
298
+ torch::jit::parseIR (graph, g.get ());
299
+
300
+ auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT (g, {in});
301
+ auto trt_results = torch_tensorrt::tests::util::EvaluateGraph (g->block (), {in});
302
+
303
+ ASSERT_TRUE (at::equal (jit_results[0 ].toTensor ().to (at::kCUDA ), trt_results[0 ].toTensor ()));
304
+ }
305
+
306
+ TEST (Evaluators, ZerosLikeDataTypeEvaluatesCorrectly) {
307
+ const auto graph = R"IR(
308
+ graph(%x.1 : Tensor):
309
+ %2 : int = prim::Constant[value=5]() # :0:0 (Float16)
310
+ %3 : None = prim::Constant()
311
+ %z.1 : Tensor = aten::zeros_like(%x.1, %2, %3, %3, %3, %3)
312
+ return (%z.1))IR" ;
313
+
314
+ auto in = at::randint (1 , 10 , {1 , 5 , 5 , 5 }, {at::kCUDA });
315
+
316
+ auto g = std::make_shared<torch::jit::Graph>();
317
+ torch::jit::parseIR (graph, g.get ());
318
+
319
+ auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT (g, {in});
320
+ auto trt_results = torch_tensorrt::tests::util::EvaluateGraph (g->block (), {in});
321
+
322
+ ASSERT_TRUE (at::equal (jit_results[0 ].toTensor ().to (at::kCUDA ), trt_results[0 ].toTensor ()));
323
+ }
324
+
325
+ TEST (Evaluators, ZerosLikeDynamic) {
326
+ const auto graph = R"IR(
327
+ graph(%x.1 : Tensor):
328
+ %2 : int = prim::Constant[value=5]() # :0:0 (Float16)
329
+ %3 : None = prim::Constant()
330
+ %z.1 : Tensor = aten::zeros_like(%x.1, %2, %3, %3, %3, %3)
331
+ return (%z.1))IR" ;
332
+ auto in = at::randint (1 , 10 , {23 , 17 , 5 , 29 }, {at::kCUDA });
333
+
334
+ auto g = std::make_shared<torch::jit::Graph>();
335
+ torch::jit::parseIR (graph, g.get ());
336
+
337
+ auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT (g, {in});
338
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
339
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic (g, params, {in}, true , true );
340
+
341
+ ASSERT_TRUE (at::equal (jit_results[0 ].toTensor ().to (at::kCUDA ), trt_results[0 ]));
342
+ }
343
+
344
+ TEST (Evaluators, OnesLikeEvaluatesCorrectly) {
345
+ const auto graph = R"IR(
346
+ graph(%x.1 : Tensor):
347
+ %2 : None = prim::Constant() # :0:0
348
+ %z.1 : Tensor = aten::ones_like(%x.1, %2, %2, %2, %2, %2)
349
+ return (%z.1))IR" ;
350
+
351
+ auto in = at::randint (1 , 10 , {1 , 5 , 5 , 5 }, {at::kCUDA });
352
+
353
+ auto g = std::make_shared<torch::jit::Graph>();
354
+ torch::jit::parseIR (graph, g.get ());
355
+
356
+ auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT (g, {in});
357
+ auto trt_results = torch_tensorrt::tests::util::EvaluateGraph (g->block (), {in});
358
+
359
+ ASSERT_TRUE (at::equal (jit_results[0 ].toTensor ().to (at::kCUDA ), trt_results[0 ].toTensor ()));
360
+ }
361
+
362
+ TEST (Evaluators, OnesLikeDataTypeEvaluatesCorrectly) {
363
+ const auto graph = R"IR(
364
+ graph(%x.1 : Tensor):
365
+ %2 : int = prim::Constant[value=5]() # :0:0 (Float16)
366
+ %3 : None = prim::Constant()
367
+ %z.1 : Tensor = aten::ones_like(%x.1, %2, %3, %3, %3, %3)
368
+ return (%z.1))IR" ;
369
+
370
+ auto in = at::randint (1 , 10 , {1 , 5 , 5 , 5 }, {at::kCUDA });
371
+
372
+ auto g = std::make_shared<torch::jit::Graph>();
373
+ torch::jit::parseIR (graph, g.get ());
374
+
375
+ auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT (g, {in});
376
+ auto trt_results = torch_tensorrt::tests::util::EvaluateGraph (g->block (), {in});
377
+
378
+ ASSERT_TRUE (at::equal (jit_results[0 ].toTensor ().to (at::kCUDA ), trt_results[0 ].toTensor ()));
379
+ }
380
+
381
+ TEST (Evaluators, OnesLikeDynamic) {
382
+ const auto graph = R"IR(
383
+ graph(%x.1 : Tensor):
384
+ %2 : int = prim::Constant[value=5]() # :0:0 (Float16)
385
+ %3 : None = prim::Constant()
386
+ %z.1 : Tensor = aten::ones_like(%x.1, %2, %3, %3, %3, %3)
387
+ return (%z.1))IR" ;
388
+ auto in = at::randint (1 , 10 , {3 , 6 }, {at::kCUDA });
389
+
390
+ auto g = std::make_shared<torch::jit::Graph>();
391
+ torch::jit::parseIR (graph, g.get ());
392
+
393
+ auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT (g, {in});
394
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
395
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic (g, params, {in}, true , true );
396
+
397
+ ASSERT_TRUE (at::equal (jit_results[0 ].toTensor ().to (at::kCUDA ), trt_results[0 ]));
398
+ }
399
+
210
400
TEST (Evaluators, ATenArangeIntEvaluatesCorrectly) {
211
401
const auto graph = R"IR(
212
402
graph():
0 commit comments