3
3
#include "core/util/prelude.h"

#include "torch/csrc/jit/api/function_impl.h"
#include "torch/csrc/jit/ir/alias_analysis.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "torch/csrc/jit/jit_log.h"
#include "torch/csrc/jit/passes/constant_propagation.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
namespace lowering {
namespace passes {
- void replaceLinearWithBiasNonePattern (std::shared_ptr< torch::jit::Graph> graph ) {
20
+ void replaceLinear ( torch::jit::Block* block ) {
20
21
// Define the decomposition function for aten::linear for the case where bias (mat2) is None.
21
22
static torch::jit::CompilationUnit decompose_funcs (R"SCRIPT(
22
23
def linear(self: Tensor, mat1: Tensor, mat2: Tensor):
23
24
return torch.matmul(self, mat1.t())
24
25
)SCRIPT" );
25
26
26
- // Iterate through nodes and search for aten::linear nodes where bias is not a Tensor (includes bias=None case)
27
- auto block = graph->block ();
27
+ // Define graph format for aten::linear with Tensor-type bias
28
+ std::string fused_linear = R"IR(
29
+ graph(%input, %weight, %bias):
30
+ %1: int = prim::Constant[value=1]()
31
+ %weight = aten::t(%weight)
32
+ %mm: Tensor = aten::matmul(%input, %weight)
33
+ %b_f: Tensor = trt::const(%bias)
34
+ %out: Tensor = aten::add(%b_f, %mm, %1)
35
+ return (%out))IR" ;
36
+
37
+ // Iterate through nodes in block, seaching for aten::linear
28
38
for (auto it = block->nodes ().begin (); it != block->nodes ().end (); it++) {
29
39
auto n = *it;
30
- if (n->kind ().toQualString () == std::string (" aten::linear" )) {
40
+
41
+ // Recursively explore nested blocks, such as those arising from prim::If
42
+ for (auto block : n->blocks ()) {
43
+ replaceLinear (block);
44
+ }
45
+
46
+ if ((n->kind ().toQualString () == std::string (" aten::linear" )) && (n->inputs ().size () >= 3 )) {
31
47
auto input_values = n->inputs ();
32
- // input_values[2] is the bias. If none, replace it with the decomposed linear graph.
48
+
49
+ // input_values[2] is the bias
50
+ // If Tensor, replace with fused-bias decomposed graph
51
+ // Otherwise, replace it with the no-bias decomposed linear graph.
33
52
if (input_values[2 ]->type ()->isSubtypeOf (c10::TensorType::get ())) {
34
- continue ;
53
+ torch::jit::WithInsertPoint guard (*it);
54
+
55
+ // Initialize new fused subgraph from IR code above
56
+ auto fused_g = std::make_shared<torch::jit::Graph>();
57
+ torch::jit::parseIR (fused_linear, fused_g.get ());
58
+
59
+ // Insert subgraph in place of aten::linear, replacing inputs and outputs accordingly
60
+ torch::jit::Value* new_output = insertGraph (*it->owningGraph (), *fused_g, it->inputs ()).at (0 );
61
+ new_output->setType (it->output ()->type ());
62
+ it->output ()->replaceAllUsesWith (new_output);
63
+ it.destroyCurrent ();
35
64
} else {
36
65
torch::jit::WithInsertPoint guard (*it);
66
+
67
+ // Initialized decomposed graph without bias term
37
68
std::shared_ptr<torch::jit::Graph> d_graph = toGraphFunction (decompose_funcs.get_function (" linear" )).graph ();
38
69
torch::jit::Value* new_output = insertGraph (*it->owningGraph (), *d_graph, it->inputs ()).at (0 );
70
+
71
+ // Insert function in place of aten::linear, replacing inputs and outputs accordingly
39
72
new_output->setType (it->output ()->type ());
40
73
it->output ()->replaceAllUsesWith (new_output);
41
74
it.destroyCurrent ();
@@ -45,27 +78,8 @@ void replaceLinearWithBiasNonePattern(std::shared_ptr<torch::jit::Graph> graph)
45
78
}
46
79
47
80
// Entry point of the pass: recursively find and replace all instances of
// aten::linear in `graph` with the corresponding decomposed form
// (matmul [+ bias add]), including occurrences inside nested blocks.
void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
  replaceLinear(graph->block());
}
} // namespace passes