Skip to content

Commit 99cea1b

Browse files
committed
fix: Resolve issues in exception elmination pass
Signed-off-by: Michael Feliz <[email protected]>
1 parent 10b55d4 commit 99cea1b

File tree

3 files changed

+182
-2
lines changed

3 files changed

+182
-2
lines changed

core/lowering/passes/exception_elimination.cpp

+10-2
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,22 @@ struct ExceptionOrPassPatternElimination {
4141
auto arm1_start = arm1->nodes().begin();
4242
auto arm2_start = arm2->nodes().begin();
4343

44+
bool arm1_starts_with_exception = (*arm1_start)->kind() == prim::RaiseException;
45+
bool arm2_starts_with_exception = (*arm2_start)->kind() == prim::RaiseException;
46+
47+
if (!arm1_starts_with_exception && !arm2_starts_with_exception) {
48+
// Neither arm matches the pattern
49+
return false;
50+
}
51+
4452
/// Check if this Node hosts a pattern like so:
4553
/// = prim::If(%5958)
4654
/// block0():
4755
/// = prim::RaiseException(%45)
4856
/// -> ()
4957
/// block1():
5058
/// -> ()
51-
if ((*arm1_start)->kind() == prim::RaiseException) {
59+
if (arm1_starts_with_exception) {
5260
if ((*(++arm1_start))->kind() != prim::Return) {
5361
// Make sure that block0 is solely just the exception and the return
5462
return false;
@@ -67,7 +75,7 @@ struct ExceptionOrPassPatternElimination {
6775
/// block1():
6876
/// = prim::RaiseException(%45)
6977
/// -> ()
70-
if ((*arm2_start)->kind() == prim::RaiseException) {
78+
if (arm2_starts_with_exception) {
7179
if ((*(++arm2_start))->kind() != prim::Return) {
7280
// Make sure that block1 is solely just the exception and the return
7381
return false;

tests/core/lowering/BUILD

+5
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ lowering_test(
3030
name = "test_conv1d_pass",
3131
)
3232

33+
lowering_test(
34+
name = "test_exception_elimination_pass",
35+
)
36+
3337
lowering_test(
3438
name = "test_remove_contiguous_pass",
3539
)
@@ -82,6 +86,7 @@ test_suite(
8286
name = "lowering_tests",
8387
tests = [
8488
":test_conv1d_pass",
89+
":test_exception_elimination_pass",
8590
":test_linear_to_addmm",
8691
":test_module_fallback_passes",
8792
":test_operator_aliasing_pass",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
#include "core/lowering/passes/passes.h"
2+
#include "gtest/gtest.h"
3+
#include "torch/csrc/jit/ir/irparser.h"
4+
5+
TEST(LoweringPasses, EliminateExceptionOrPassPattern_Block0) {
6+
// parseIR does not support " = prim::If(%51)" with no return value
7+
/*std::string source_ir = R"IR(graph(%x.1 : Tensor, %y.1 : Tensor):
8+
%3 : NoneType = prim::Constant()
9+
%4 : int = prim::Constant[value=0]()
10+
%mod_list.1 : Tensor[] = prim::ListConstruct(%x.1)
11+
%47 : Tensor = aten::sum(%x.1, %3)
12+
%49 : Tensor = aten::sum(%y.1, %3)
13+
%50 : Tensor = aten::gt(%47, %49)
14+
%51 : bool = aten::Bool(%50)
15+
= prim::If(%51)
16+
block0():
17+
= prim::RaiseException(%45)
18+
-> ()
19+
block1():
20+
-> ()
21+
%z.1 : Tensor = aten::cat(%mod_list.1, %4)
22+
return (%z.1))IR";*/
23+
24+
auto g = std::make_shared<torch::jit::Graph>();
25+
auto x = g->insertInput(0, "x");
26+
auto y = g->insertInput(1, "y");
27+
torch::jit::IValue zero(0);
28+
auto zero_const_val = g->insertConstant(zero);
29+
auto none_const_val = g->insertConstant(torch::jit::IValue());
30+
torch::jit::IValue except("EXCEPTION");
31+
auto except_val = g->insertConstant(except);
32+
auto list_node = g->createList(x->type(), torch::jit::ArrayRef<torch::jit::Value*>(x));
33+
g->insertNode(list_node);
34+
auto sum_x_node = g->create(torch::jit::aten::sum, {x, none_const_val});
35+
g->insertNode(sum_x_node);
36+
auto sum_y_node = g->create(torch::jit::aten::sum, {y, none_const_val});
37+
g->insertNode(sum_y_node);
38+
auto gt_node = g->create(torch::jit::aten::gt, {sum_x_node->output(), sum_y_node->output()});
39+
g->insertNode(gt_node);
40+
auto bool_node = g->create(torch::jit::aten::Bool, {gt_node->output()});
41+
bool_node->output()->setType(torch::jit::BoolType::get());
42+
g->insertNode(bool_node);
43+
auto if_node = g->create(torch::jit::prim::If, {bool_node->output()}, 0);
44+
auto if_block0 = if_node->addBlock();
45+
auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val}, 0);
46+
if_block0->appendNode(exception_node);
47+
auto if_block1 = if_node->addBlock();
48+
g->insertNode(if_node);
49+
auto cat_node = g->create(torch::jit::aten::cat, {list_node->output(), zero_const_val});
50+
g->insertNode(cat_node);
51+
g->registerOutput(cat_node->output());
52+
53+
torch_tensorrt::core::lowering::passes::EliminateExceptionOrPassPattern(g);
54+
for (auto node : g->nodes()) {
55+
EXPECT_NE(node, if_node);
56+
}
57+
}
58+
59+
TEST(LoweringPasses, EliminateExceptionOrPassPattern_Block1) {
60+
// parseIR does not support " = prim::If(%51)" with no return value
61+
/*std::string source_ir = R"IR(graph(%x.1 : Tensor, %y.1 : Tensor):
62+
%3 : NoneType = prim::Constant()
63+
%4 : int = prim::Constant[value=0]()
64+
%mod_list.1 : Tensor[] = prim::ListConstruct(%x.1)
65+
%47 : Tensor = aten::sum(%x.1, %3)
66+
%49 : Tensor = aten::sum(%y.1, %3)
67+
%50 : Tensor = aten::gt(%47, %49)
68+
%51 : bool = aten::Bool(%50)
69+
= prim::If(%51)
70+
block0():
71+
-> ()
72+
block1():
73+
= prim::RaiseException(%45)
74+
-> ()
75+
%z.1 : Tensor = aten::cat(%mod_list.1, %4)
76+
return (%z.1))IR";*/
77+
78+
auto g = std::make_shared<torch::jit::Graph>();
79+
auto x = g->insertInput(0, "x");
80+
auto y = g->insertInput(1, "y");
81+
torch::jit::IValue zero(0);
82+
auto zero_const_val = g->insertConstant(zero);
83+
auto none_const_val = g->insertConstant(torch::jit::IValue());
84+
torch::jit::IValue except("EXCEPTION");
85+
auto except_val = g->insertConstant(except);
86+
auto list_node = g->createList(x->type(), torch::jit::ArrayRef<torch::jit::Value*>(x));
87+
g->insertNode(list_node);
88+
auto sum_x_node = g->create(torch::jit::aten::sum, {x, none_const_val});
89+
g->insertNode(sum_x_node);
90+
auto sum_y_node = g->create(torch::jit::aten::sum, {y, none_const_val});
91+
g->insertNode(sum_y_node);
92+
auto gt_node = g->create(torch::jit::aten::gt, {sum_x_node->output(), sum_y_node->output()});
93+
g->insertNode(gt_node);
94+
auto bool_node = g->create(torch::jit::aten::Bool, {gt_node->output()});
95+
bool_node->output()->setType(torch::jit::BoolType::get());
96+
g->insertNode(bool_node);
97+
auto if_node = g->create(torch::jit::prim::If, {bool_node->output()}, 0);
98+
auto if_block0 = if_node->addBlock();
99+
auto if_block1 = if_node->addBlock();
100+
auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val}, 0);
101+
if_block1->appendNode(exception_node);
102+
g->insertNode(if_node);
103+
auto cat_node = g->create(torch::jit::aten::cat, {list_node->output(), zero_const_val});
104+
g->insertNode(cat_node);
105+
g->registerOutput(cat_node->output());
106+
107+
torch_tensorrt::core::lowering::passes::EliminateExceptionOrPassPattern(g);
108+
for (auto node : g->nodes()) {
109+
EXPECT_NE(node, if_node);
110+
}
111+
}
112+
113+
TEST(LoweringPasses, EliminateExceptionOrPassPattern_Negative) {
114+
// parseIR does not support " = prim::If(%51)" with no return value
115+
/*std::string source_ir = R"IR(graph(%x.1 : Tensor, %y.1 : Tensor):
116+
%3 : NoneType = prim::Constant()
117+
%4 : int = prim::Constant[value=0]()
118+
%mod_list.1 : Tensor[] = prim::ListConstruct(%x.1)
119+
%47 : Tensor = aten::sum(%x.1, %3)
120+
%49 : Tensor = aten::sum(%y.1, %3)
121+
%50 : Tensor = aten::gt(%47, %49)
122+
%51 : bool = aten::Bool(%50)
123+
= prim::If(%51)
124+
block0():
125+
%10 : Tensor[] = aten::append(%mod_list.1, %y.1)
126+
-> ()
127+
block1():
128+
-> ()
129+
%z.1 : Tensor = aten::cat(%mod_list.1, %4)
130+
return (%z.1))IR";*/
131+
132+
auto g = std::make_shared<torch::jit::Graph>();
133+
auto x = g->insertInput(0, "x");
134+
auto y = g->insertInput(1, "y");
135+
torch::jit::IValue zero(0);
136+
auto zero_const_val = g->insertConstant(zero);
137+
auto none_const_val = g->insertConstant(torch::jit::IValue());
138+
auto list_node = g->createList(x->type(), torch::jit::ArrayRef<torch::jit::Value*>(x));
139+
g->insertNode(list_node);
140+
auto sum_x_node = g->create(torch::jit::aten::sum, {x, none_const_val});
141+
g->insertNode(sum_x_node);
142+
auto sum_y_node = g->create(torch::jit::aten::sum, {y, none_const_val});
143+
g->insertNode(sum_y_node);
144+
auto gt_node = g->create(torch::jit::aten::gt, {sum_x_node->output(), sum_y_node->output()});
145+
g->insertNode(gt_node);
146+
auto bool_node = g->create(torch::jit::aten::Bool, {gt_node->output()});
147+
bool_node->output()->setType(torch::jit::BoolType::get());
148+
g->insertNode(bool_node);
149+
auto if_node = g->create(torch::jit::prim::If, {bool_node->output()}, 0);
150+
auto if_block0 = if_node->addBlock();
151+
auto append_node = g->create(torch::jit::aten::append, {list_node->output(), y});
152+
if_block0->appendNode(append_node);
153+
auto if_block1 = if_node->addBlock();
154+
g->insertNode(if_node);
155+
auto cat_node = g->create(torch::jit::aten::cat, {list_node->output(), zero_const_val});
156+
g->insertNode(cat_node);
157+
g->registerOutput(cat_node->output());
158+
159+
torch_tensorrt::core::lowering::passes::EliminateExceptionOrPassPattern(g);
160+
int if_count = 0;
161+
for (auto node : g->nodes()) {
162+
if (node == if_node) {
163+
if_count++;
164+
}
165+
}
166+
EXPECT_EQ(1, if_count);
167+
}

0 commit comments

Comments
 (0)