Skip to content

Commit 9f71389

Browse files
bo3zvloncar
authored andcommitted
Quartus Extensions
1 parent 5376181 commit 9f71389

File tree

2 files changed

+49
-49
lines changed

2 files changed

+49
-49
lines changed

hls4ml/backends/quartus/quartus_backend.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -43,29 +43,32 @@ def _register_flows(self):
4343
]
4444
quantization_flow = register_flow('quantization', quantization_passes, requires=[init_flow], backend=self.name)
4545

46+
optimization_passes = []
47+
optimization_flow = register_flow('optimize', optimization_passes, requires=[init_flow], backend=self.name)
48+
4649
templates = self._get_layer_templates()
47-
template_flow = register_flow('apply_templates', templates, requires=[init_flow], backend=self.name)
50+
template_flow = register_flow('apply_templates', self._get_layer_templates, requires=[init_flow], backend=self.name)
4851

4952
writer_passes = [
5053
'make_stamp',
5154
'quartus:write_hls'
5255
]
53-
writer_flow_requirements = ['optimize', quartus_types_flow, template_flow]
54-
self._writer_flow = register_flow('write', writer_passes, requires=writer_flow_requirements, backend=self.name)
56+
57+
self._writer_flow = register_flow('write', writer_passes, requires=['quartus:ip'], backend=self.name)
5558

5659
all_passes = get_backend_passes(self.name)
5760

5861
extras = [
5962
# Ideally this should be empty
60-
opt_pass for opt_pass in all_passes if opt_pass not in initializers + quartus_types + templates + writer_passes
63+
opt_pass for opt_pass in all_passes if opt_pass not in initializers + streaming_passes + quartus_types + quantization_passes + templates + optimization_passes + writer_passes
6164
]
6265

6366
if len(extras) > 0:
6467
extras_flow = register_flow('extras', extras, requires=[init_flow], backend=self.name)
6568
else:
6669
extras_flow = None
6770

68-
ip_flow_requirements = ['optimize', init_flow, streaming_flow, quantization_flow, quartus_types_flow, extras_flow, template_flow]
71+
ip_flow_requirements = ['optimize', init_flow, streaming_flow, quantization_flow, optimization_flow, quartus_types_flow, extras_flow, template_flow]
6972
ip_flow_requirements = list(filter(None, ip_flow_requirements))
7073

7174
self._default_flow = register_flow('ip', None, requires=ip_flow_requirements, backend=self.name)

test/pytest/test_extensions.py

+41-44
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
1+
import pytest
12
import hls4ml
2-
import tensorflow as tf
33
import numpy as np
4-
import pytest
4+
import tensorflow as tf
55
from pathlib import Path
66

77
test_root_path = Path(__file__).parent
88

99
# Keras implementation of a custom layer
10-
1110
class KReverse(tf.keras.layers.Layer):
1211
''' Keras implementation of a hypothetical custom layer '''
1312
def __init__(self):
@@ -16,8 +15,7 @@ def __init__(self):
1615
def call(self, inputs):
1716
return tf.reverse(inputs, axis=[-1])
1817

19-
# hls4ml implementations
20-
18+
# hls4ml layer implementation
2119
class HReverse(hls4ml.model.layers.Layer):
2220
''' hls4ml implementation of a hypothetical custom layer '''
2321

@@ -27,8 +25,35 @@ def initialize(self):
2725
dims = inp.dim_names
2826
self.add_output_variable(shape, dims)
2927

28+
# hls4ml optimizer to remove duplicate optimizer
29+
class RemoveDuplicateReverse(hls4ml.model.optimizer.OptimizerPass):
30+
'''OptimizerPass to remove consecutive HReverse layers.'''
31+
32+
def match(self, node):
33+
return isinstance(node, HReverse) and \
34+
isinstance(node.get_input_node(), HReverse)
35+
36+
def transform(self, model, node):
37+
first = node.get_input_node()
38+
second = node
3039

31-
# Templates
40+
model.remove_node(first, rewire=True)
41+
model.remove_node(second, rewire=True)
42+
return True
43+
44+
# Parser for converter
45+
def parse_reverse_layer(keras_layer, input_names, input_shapes, data_reader, config):
46+
layer = {}
47+
layer['class_name'] = 'HReverse'
48+
layer['name'] = keras_layer['config']['name']
49+
layer['n_in'] = input_shapes[0][1]
50+
51+
if input_names is not None:
52+
layer['inputs'] = input_names
53+
54+
return layer, [shape for shape in input_shapes[0]]
55+
56+
# HLS Templates - No specific pragmas used; generic enough for both Intel and Vivado
3257

3358
rev_config_template = """struct config{index} : nnet::reverse_config {{
3459
static const unsigned n_in = {n_in};
@@ -55,8 +80,6 @@ def format(self, node):
5580
params = self._default_function_params(node)
5681
return self.template.format(**params)
5782

58-
59-
# HLS implementation
6083
rev_hls = \
6184
"""#ifndef NNET_REVERSE_H_
6285
#define NNET_REVERSE_H_
@@ -74,8 +97,6 @@ def format(self, node):
7497
data_T input[CONFIG_T::n_in],
7598
data_T reversed[CONFIG_T::n_in]
7699
) {
77-
#pragma HLS PIPELINE
78-
79100
for (int i = 0; i < CONFIG_T::n_in; i++) {
80101
reversed[CONFIG_T::n_in - 1 - i] = input[i];
81102
}
@@ -86,43 +107,19 @@ def format(self, node):
86107
#endif
87108
"""
88109

89-
class RemoveDuplicateReverse(hls4ml.model.optimizer.OptimizerPass):
90-
'''OptimizerPass to remove consecutive HReverse layers.'''
91-
92-
def match(self, node):
93-
return isinstance(node, HReverse) and \
94-
isinstance(node.get_input_node(), HReverse)
95-
96-
def transform(self, model, node):
97-
first = node.get_input_node()
98-
second = node
99-
100-
model.remove_node(first, rewire=True)
101-
model.remove_node(second, rewire=True)
102-
return True
103-
104-
# Parser for converter
105-
def parse_reverse_layer(keras_layer, input_names, input_shapes, data_reader, config):
106-
layer = {}
107-
layer['class_name'] = 'HReverse'
108-
layer['name'] = keras_layer['config']['name']
109-
layer['n_in'] = input_shapes[0][1]
110-
111-
if input_names is not None:
112-
layer['inputs'] = input_names
113-
114-
return layer, [shape for shape in input_shapes[0]]
115-
116-
def test_extensions(tmp_path):
110+
@pytest.fixture(scope='session', autouse=True)
111+
def regsister_custom_layer():
117112
# Register the converter for custom Keras layer
118113
hls4ml.converters.register_keras_layer_handler('KReverse', parse_reverse_layer)
119114

120115
# Register the hls4ml's IR layer
121116
hls4ml.model.layers.register_layer('HReverse', HReverse)
122117

118+
@pytest.mark.parametrize('backend_id', ['Vivado', 'Quartus'])
119+
def test_extensions(tmp_path, backend_id):
123120
# Register the optimization passes (if any)
124-
backend = hls4ml.backends.get_backend('Vivado')
125-
backend.register_pass('remove_duplicate_reverse', RemoveDuplicateReverse, flow='vivado:optimize')
121+
backend = hls4ml.backends.get_backend(backend_id)
122+
backend.register_pass('remove_duplicate_reverse', RemoveDuplicateReverse, flow=f'{backend_id.lower()}:optimize')
126123

127124
# Register template passes for the given backend
128125
backend.register_template(HReverseConfigTemplate)
@@ -148,15 +145,15 @@ def test_extensions(tmp_path):
148145

149146
hmodel = hls4ml.converters.convert_from_keras_model(
150147
kmodel,
151-
output_dir=str(test_root_path / 'hls4mlprj_extensions'),
152-
backend='Vivado',
148+
output_dir=str(test_root_path / f'hls4mlprj_extensions_{backend_id}'),
149+
backend=backend_id,
153150
io_type='io_parallel',
154-
hls_config={ 'Model': { 'Precision': 'ap_int<4>', 'ReuseFactor': 1} })
151+
hls_config={ 'Model': { 'Precision': 'ap_int<6>', 'ReuseFactor': 1} })
155152

156153
hmodel.compile()
157154
hres = hmodel.predict(x.astype('float32'))
158155

159156
# Check if the optimizer pass was applied
160-
assert 'vivado:remove_duplicate_reverse' in hmodel._applied_flows[0]['vivado:optimize']
157+
assert f'{backend_id.lower()}:remove_duplicate_reverse' in hmodel._applied_flows[0][f'{backend_id.lower()}:optimize']
161158

162159
np.testing.assert_array_equal(kres, hres)

0 commit comments

Comments
 (0)