@@ -50,7 +50,41 @@ pnnx.Output output 1 0 out
50
50
}
51
51
};
52
52
53
+ class torch_stft_0 : public torch_stft
54
+ {
55
+ public:
56
+ const char * match_pattern_graph () const
57
+ {
58
+ return R"PNNXIR( 7767517
59
+ 11 10
60
+ pnnx.Input input_0 0 1 input
61
+ pnnx.Input input_1 0 1 window
62
+ prim::Constant op_0 0 1 n_fft value=%n_fft
63
+ prim::Constant op_1 0 1 hop_length value=%hop_length
64
+ prim::Constant op_2 0 1 win_length value=%win_length
65
+ prim::Constant op_3 0 1 normalized value=%normalized
66
+ prim::Constant op_4 0 1 onesided value=%onesided
67
+ prim::Constant op_5 0 1 return_complex value=%return_complex
68
+ prim::Constant op_6 0 1 align_to_window value=%align_to_window
69
+ aten::stft op_7 9 1 input n_fft hop_length win_length window normalized onesided return_complex align_to_window out
70
+ pnnx.Output output 1 0 out
71
+ )PNNXIR" ;
72
+ }
73
+
74
+ void write (Operator* op, const std::map<std::string, Parameter>& captured_params) const
75
+ {
76
+ torch_stft::write (op, captured_params);
77
+
78
+ // keep align_to_window param only when enabled
79
+ if (captured_params.at (" align_to_window" ).type != 1 || captured_params.at (" align_to_window" ).b == false )
80
+ {
81
+ op->params .erase (" align_to_window" );
82
+ }
83
+ }
84
+ };
85
+
53
86
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS (torch_stft, 80 )
87
+ REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS (torch_stft_0, 80 )
54
88
55
89
class torch_stft_1 : public GraphRewriterPass
56
90
{
0 commit comments