@@ -540,5 +540,39 @@ TEST(Evaluators, EqStrResultIsFalseEvaluatesCorrectly) {
540
540
auto jit_results = trtorch::tests::util::EvaluateGraphJIT (g, {});
541
541
auto trt_results = trtorch::tests::util::EvaluateGraph (g->block (), {});
542
542
543
+ ASSERT_TRUE (jit_results[0 ] == trt_results[0 ]);
544
+ }
545
+
546
+ TEST (Evaluators, AndBoolResultIsTrueEvaluatesCorrectly) {
547
+ const auto graph = R"IR(
548
+ graph():
549
+ %1 : bool = prim::Constant[value=1]()
550
+ %2 : bool = prim::Constant[value=1]()
551
+ %3 : bool = aten::__and__(%1, %2)
552
+ return (%3))IR" ;
553
+
554
+ auto g = std::make_shared<torch::jit::Graph>();
555
+ torch::jit::parseIR (graph, g.get ());
556
+
557
+ auto jit_results = trtorch::tests::util::EvaluateGraphJIT (g, {});
558
+ auto trt_results = trtorch::tests::util::EvaluateGraph (g->block (), {});
559
+
560
+ ASSERT_TRUE (jit_results[0 ] == trt_results[0 ]);
561
+ }
562
+
563
+ TEST (Evaluators, AndBoolResultIsFalseEvaluatesCorrectly) {
564
+ const auto graph = R"IR(
565
+ graph():
566
+ %1 : bool = prim::Constant[value=1]()
567
+ %2 : bool = prim::Constant[value=0]()
568
+ %3 : bool = aten::__and__(%1, %2)
569
+ return (%3))IR" ;
570
+
571
+ auto g = std::make_shared<torch::jit::Graph>();
572
+ torch::jit::parseIR (graph, g.get ());
573
+
574
+ auto jit_results = trtorch::tests::util::EvaluateGraphJIT (g, {});
575
+ auto trt_results = trtorch::tests::util::EvaluateGraph (g->block (), {});
576
+
543
577
ASSERT_TRUE (jit_results[0 ] == trt_results[0 ]);
544
578
}
0 commit comments