Skip to content

Commit 6d73e43

Browse files
committed
feat: support aten::__and__.bool evaluator
Signed-off-by: Ruoqian Guo <[email protected]>
1 parent 5643972 commit 6d73e43

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

core/conversion/evaluators/aten.cpp

+6-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,12 @@ DEFINE_GENERIC_TWO_INPUT_EVALUATOR(
9898
"aten::ge.float_int(float a, int b) -> (bool)",
9999
}));
100100

101-
DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(and, "aten::__and__", a&& b, bool, {"aten::__and__(int a, int b) -> (bool)"});
101+
DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(
102+
and,
103+
"aten::__and__",
104+
a&& b,
105+
bool,
106+
std::set<std::string>({"aten::__and__(int a, int b) -> (bool)", "aten::__and__.bool(bool a, bool b) -> (bool)"}));
102107
DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(or, "aten::__or__", a || b, bool, {"aten::__or__(int a, int b) -> (bool)"});
103108
DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(
104109
xor,

tests/core/conversion/evaluators/test_aten_evaluators.cpp

+34
Original file line numberDiff line numberDiff line change
@@ -540,5 +540,39 @@ TEST(Evaluators, EqStrResultIsFalseEvaluatesCorrectly) {
540540
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
541541
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
542542

543+
ASSERT_TRUE(jit_results[0] == trt_results[0]);
544+
}
545+
546+
TEST(Evaluators, AndBoolResultIsTrueEvaluatesCorrectly) {
547+
const auto graph = R"IR(
548+
graph():
549+
%1 : bool = prim::Constant[value=1]()
550+
%2 : bool = prim::Constant[value=1]()
551+
%3 : bool = aten::__and__(%1, %2)
552+
return (%3))IR";
553+
554+
auto g = std::make_shared<torch::jit::Graph>();
555+
torch::jit::parseIR(graph, g.get());
556+
557+
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
558+
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
559+
560+
ASSERT_TRUE(jit_results[0] == trt_results[0]);
561+
}
562+
563+
TEST(Evaluators, AndBoolResultIsFalseEvaluatesCorrectly) {
564+
const auto graph = R"IR(
565+
graph():
566+
%1 : bool = prim::Constant[value=1]()
567+
%2 : bool = prim::Constant[value=0]()
568+
%3 : bool = aten::__and__(%1, %2)
569+
return (%3))IR";
570+
571+
auto g = std::make_shared<torch::jit::Graph>();
572+
torch::jit::parseIR(graph, g.get());
573+
574+
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
575+
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
576+
543577
ASSERT_TRUE(jit_results[0] == trt_results[0]);
544578
}

0 commit comments

Comments
 (0)