Skip to content

Commit 76ce920

Browse files
authored
update pnnx ci with torch 2.7 (#6018)
1 parent 19abe7b commit 76ce920

12 files changed

+52
-14
lines changed

.ci/pnnx.yml

+8-4
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ concurrency:
1919

2020
variables:
2121
protobuf_version: 21.12
22-
libtorch_version: 2.6.0
23-
libtorchvision_version: 0.21.0
24-
onnxruntime_version: 1.21.0
25-
cache_date: 20250402
22+
libtorch_version: 2.7.0
23+
libtorchvision_version: 0.22.0
24+
onnxruntime_version: 1.21.1
25+
cache_date: 20250423
2626

2727
jobs:
2828
ubuntu:
@@ -81,6 +81,10 @@ jobs:
8181
torchvision-version: 0.21.0
8282
torchaudio-version: '2.6.0+cpu'
8383

84+
- torch-version: 2.7.0
85+
torchvision-version: 0.22.0
86+
torchaudio-version: '2.7.0+cpu'
87+
8488
runs-on:
8589
pool-name: docker
8690
container:

tools/pnnx/src/pass_level2/torch_stft.cpp

+34
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,41 @@ pnnx.Output output 1 0 out
5050
}
5151
};
5252

53+
class torch_stft_0 : public torch_stft
54+
{
55+
public:
56+
const char* match_pattern_graph() const
57+
{
58+
return R"PNNXIR(7767517
59+
11 10
60+
pnnx.Input input_0 0 1 input
61+
pnnx.Input input_1 0 1 window
62+
prim::Constant op_0 0 1 n_fft value=%n_fft
63+
prim::Constant op_1 0 1 hop_length value=%hop_length
64+
prim::Constant op_2 0 1 win_length value=%win_length
65+
prim::Constant op_3 0 1 normalized value=%normalized
66+
prim::Constant op_4 0 1 onesided value=%onesided
67+
prim::Constant op_5 0 1 return_complex value=%return_complex
68+
prim::Constant op_6 0 1 align_to_window value=%align_to_window
69+
aten::stft op_7 9 1 input n_fft hop_length win_length window normalized onesided return_complex align_to_window out
70+
pnnx.Output output 1 0 out
71+
)PNNXIR";
72+
}
73+
74+
void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
75+
{
76+
torch_stft::write(op, captured_params);
77+
78+
// keep align_to_window param only when enabled
79+
if (captured_params.at("align_to_window").type != 1 || captured_params.at("align_to_window").b == false)
80+
{
81+
op->params.erase("align_to_window");
82+
}
83+
}
84+
};
85+
5386
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_stft, 80)
87+
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_stft_0, 80)
5488

5589
class torch_stft_1 : public GraphRewriterPass
5690
{

tools/pnnx/tests/onnx/test_F_relu.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def test():
5959
if not torch.allclose(a0, b0, 1e-4, 1e-4):
6060
return False
6161

62-
if version.parse(torch.__version__) < version.parse('2.7'):
62+
if version.parse(torch.__version__) < version.parse('2.8'):
6363
return True
6464

6565
# export dynamo onnx

tools/pnnx/tests/onnx/test_convnext_tiny.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def test():
4343
if not torch.allclose(a, b, 1e-4, 1e-4):
4444
return False
4545

46-
if version.parse(torch.__version__) < version.parse('2.7'):
46+
if version.parse(torch.__version__) < version.parse('2.8'):
4747
return True
4848

4949
# export dynamo onnx

tools/pnnx/tests/onnx/test_mobilenet_v2.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def test():
3939
if not torch.allclose(a, b, 1e-4, 1e-4):
4040
return False
4141

42-
if version.parse(torch.__version__) < version.parse('2.7'):
42+
if version.parse(torch.__version__) < version.parse('2.8'):
4343
return True
4444

4545
# export dynamo onnx

tools/pnnx/tests/onnx/test_mobilenet_v3_small.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def test():
4242
if not torch.allclose(a, b, 1e-4, 1e-4):
4343
return False
4444

45-
if version.parse(torch.__version__) < version.parse('2.7'):
45+
if version.parse(torch.__version__) < version.parse('2.8'):
4646
return True
4747

4848
# export dynamo onnx

tools/pnnx/tests/onnx/test_nn_ReLU.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def test():
6161
if not torch.allclose(a0, b0, 1e-4, 1e-4):
6262
return False
6363

64-
if version.parse(torch.__version__) < version.parse('2.7'):
64+
if version.parse(torch.__version__) < version.parse('2.8'):
6565
return True
6666

6767
# export dynamo onnx

tools/pnnx/tests/onnx/test_resnet18.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def test():
3939
if not torch.allclose(a, b, 1e-4, 1e-4):
4040
return False
4141

42-
if version.parse(torch.__version__) < version.parse('2.7'):
42+
if version.parse(torch.__version__) < version.parse('2.8'):
4343
return True
4444

4545
# export dynamo onnx

tools/pnnx/tests/onnx/test_shufflenet_v2_x1_0.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def test():
3939
if not torch.allclose(a, b, 1e-4, 1e-4):
4040
return False
4141

42-
if version.parse(torch.__version__) < version.parse('2.7'):
42+
if version.parse(torch.__version__) < version.parse('2.8'):
4343
return True
4444

4545
# export dynamo onnx

tools/pnnx/tests/onnx/test_squeezenet1_1.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def test():
3939
if not torch.allclose(a, b, 1e-4, 1e-4):
4040
return False
4141

42-
if version.parse(torch.__version__) < version.parse('2.7'):
42+
if version.parse(torch.__version__) < version.parse('2.8'):
4343
return True
4444

4545
# export dynamo onnx

tools/pnnx/tests/onnx/test_swin_t.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def test():
4343
if not torch.allclose(a, b, 1e-4, 1e-4):
4444
return False
4545

46-
if version.parse(torch.__version__) < version.parse('2.7'):
46+
if version.parse(torch.__version__) < version.parse('2.8'):
4747
return True
4848

4949
# export dynamo onnx

tools/pnnx/tests/onnx/test_vit_b_32.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def test():
4646
if not torch.allclose(a, b, 1e-4, 1e-4):
4747
return False
4848

49-
if version.parse(torch.__version__) < version.parse('2.7'):
49+
if version.parse(torch.__version__) < version.parse('2.8'):
5050
return True
5151

5252
# export dynamo onnx

0 commit comments

Comments
 (0)