Skip to content

Commit 68f0317

Browse files
committed
feat(//core/lowering): Fuse aten::addmm branches into a single
aten::addmm op that can be expanded by a later pass Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent db20098 commit 68f0317

File tree

5 files changed

+105
-2
lines changed

5 files changed

+105
-2
lines changed

core/lowering/lowering.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
2929
passes::RemoveDropout(g);
3030
passes::FuseFlattenLinear(g);
3131
passes::Conv2DToConvolution(g);
32+
passes::FuseAddMMBranches(g);
3233
passes::UnpackAddMM(g);
3334
//passes::UnpackBatchNorm(g);
3435
passes::UnpackLogSoftmax(g);

core/lowering/passes/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ cc_library(
1515
srcs = [
1616
"conv2d_to_convolution.cpp",
1717
"exception_elimination.cpp",
18+
"fuse_addmm_branches.cpp",
1819
"fuse_flatten_linear.cpp",
1920
"remove_contiguous.cpp",
2021
"remove_dropout.cpp",

core/lowering/passes/exception_elimination.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ struct ExceptionOrPassPatternElimination {
6464
for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
6565
auto n = *it;
6666
if (n->kind() == prim::If && isExceptionOrPassNode(n)) {
67-
LOG_GRAPH("Found that node " << *n << " is an exception or pass node (EliminateChecks)");
67+
LOG_GRAPH("Found that node " << *n << " is an exception or pass node (EliminateChecks)" << std::endl);
6868
it.destroyCurrent();
6969
}
7070
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
#include "torch/csrc/jit/passes/guard_elimination.h"
2+
#include "torch/csrc/jit/ir/alias_analysis.h"
3+
#include "torch/csrc/jit/jit_log.h"
4+
#include "torch/csrc/jit/passes/constant_propagation.h"
5+
#include "torch/csrc/jit/passes/peephole.h"
6+
#include "torch/csrc/jit/runtime/graph_executor.h"
7+
#include "torch/csrc/jit/passes/dead_code_elimination.h"
8+
9+
#include "core/util/prelude.h"
10+
11+
#include <vector>
12+
13+
namespace trtorch {
14+
namespace core {
15+
namespace lowering {
16+
namespace passes {
17+
namespace {
18+
using namespace torch::jit;
19+
struct AddMMBranchFusion {
20+
AddMMBranchFusion(std::shared_ptr<Graph> graph)
21+
: graph_(std::move(graph)) {}
22+
23+
void run() {
24+
findAddMMVariantsNodes(graph_->block());
25+
torch::jit::EliminateDeadCode(graph_);
26+
LOG_GRAPH("Post aten::addmm branch fusion: " << *graph_);
27+
}
28+
29+
private:
30+
bool isAddMMVariantsNode(Node* n) {
31+
/// Check if this Node hosts a pattern like so:
32+
/// %ret : Tensor = prim::If(%622)
33+
/// block0():
34+
/// %ret.1 : Tensor = aten::addmm(%self.fc.bias, %x9.1, %3677, %3, %3)
35+
/// -> (%ret.1)
36+
/// block1():
37+
/// %output.1 : Tensor = aten::matmul(%x9.1, %3677)
38+
/// %output0.1 : Tensor = aten::add_(%output.1, %self.fc.bias, %3)
39+
/// -> (%output0.1)
40+
41+
if (n->blocks().size() != 2) {
42+
return false;
43+
}
44+
auto arm1 = n->blocks()[0];
45+
auto arm2 = n->blocks()[1];
46+
47+
auto arm1_start = arm1->nodes().begin();
48+
if ((*arm1_start)->kind().toQualString() != std::string("aten::addmm")
49+
&& (*(++arm1_start))->kind() != prim::Return) {
50+
// Make sure that block0 is solely just the aten::addmm op and the return
51+
return false;
52+
}
53+
54+
auto arm2_start = arm2->nodes().begin();
55+
if ((*arm2_start)->kind().toQualString() != std::string("aten::matmul")
56+
&& (*(++arm2_start))->kind().toQualString() != std::string("aten::add_")
57+
&& (*(++arm2_start))->kind() != prim::Return) {
58+
// Make sure that block1 is solely the return
59+
return false;
60+
}
61+
62+
return true;
63+
}
64+
65+
void findAddMMVariantsNodes(Block* b) {
66+
for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
67+
auto n = *it;
68+
if (n->kind() == prim::If && isAddMMVariantsNode(n)) {
69+
LOG_GRAPH("Found that node " << *n << " is an AddMM variants node (FuseAddMMBranches)" << std::endl);
70+
auto arm1 = n->blocks()[0];
71+
auto arm1_start = arm1->nodes().begin();
72+
73+
auto input_values = (*arm1_start)->inputs();
74+
75+
auto new_addmm_node = b->owningGraph()->create(c10::Symbol::fromQualString("aten::addmm"), input_values, 1);
76+
n->replaceAllUsesWith(new_addmm_node);
77+
78+
auto old_insert_point = b->owningGraph()->insertPoint();
79+
b->owningGraph()->setInsertPoint(n);
80+
b->owningGraph()->insertNode(new_addmm_node);
81+
b->owningGraph()->setInsertPoint(old_insert_point);
82+
83+
it.destroyCurrent();
84+
}
85+
}
86+
}
87+
88+
std::shared_ptr<Graph> graph_;
89+
};
90+
} // namespace
91+
92+
void FuseAddMMBranches(std::shared_ptr<Graph> graph) {
93+
AddMMBranchFusion ammbf(std::move(graph));
94+
ammbf.run();
95+
}
96+
97+
} // namespace passes
98+
} // namespace lowering
99+
} // namespace core
100+
} // namespace trtorch

core/lowering/passes/passes.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@ namespace lowering {
88
namespace passes {
99

1010
void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
11+
void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);
1112
void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph);
13+
void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
1214
void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
1315
void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
1416
void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
1517
void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
1618
void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
17-
void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
1819

1920
} // namespace irfusers
2021
} // namespace lowering

0 commit comments

Comments
 (0)