Skip to content

Commit a83397a

Browse files
authored
Bring the C++ ONNX importer on par with onnx_importer.py (#3960)
This PR heavily refactors the C++ importer to make it a viable (and faster) alternative to the Python one for those who need to import ONNX models and want to avoid depending on the python ecosystem. - The C++ importer now outputs the same exact mlir as the Python one (tested on `alt_e2eshark` test suite). Achieving perfect output matches required to introduce an associative map iterable according to insertion order (to mimic `Dict` in Python). - The code tries to mirror 1-to-1 the Python counterpart whenever possible/convenient. - Adds support for embedding ONNX external data in the mlir. This functionality is not part of torch-mlir's `onnx_importer.py` but of IREE's `import_onnx`. - Efforts have been made to remove the direct dependency on LLVM support lib. There is however a transitive dependency on such lib through `MLIRCAPIIR` and `TorchMLIRCAPI` (MLIR libraries uniformly depend on `LLVMSupport`).
1 parent 07ad3c1 commit a83397a

File tree

9 files changed

+2291
-738
lines changed

9 files changed

+2291
-738
lines changed

projects/onnx_c_importer/CMakeLists.txt

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,30 +4,47 @@ include(FetchContent)
44

55
find_package(Protobuf REQUIRED CONFIG)
66

7-
option(ONNX_DISABLE_EXCEPTIONS "For compatibility with LLVM build" ON)
8-
97
FetchContent_Declare(
108
onnx
119
EXCLUDE_FROM_ALL
1210
GIT_REPOSITORY https://github.com/onnx/onnx.git
13-
GIT_TAG v1.15.0
11+
GIT_TAG v1.16.1
1412
GIT_SHALLOW ON
1513
GIT_PROGRESS ON
1614
)
1715
FetchContent_MakeAvailable(onnx)
1816

17+
set(LLVM_REQUIRES_EH ON)
18+
set(LLVM_REQUIRES_RTTI ON)
19+
20+
1921
add_llvm_executable(
2022
torch-mlir-import-onnx
2123
PARTIAL_SOURCES_INTENDED
2224

2325
import-onnx-main.cpp
2426
OnnxImporter.h
2527
OnnxImporter.cpp
28+
SimpleArgParser.hpp
29+
Dict.hpp
30+
Status.hpp
31+
onnx_extras.hpp
2632
)
2733

34+
set_target_properties(torch-mlir-import-onnx PROPERTIES CXX_STANDARD 20)
35+
36+
# Supress compiler warnings from onnx headers
37+
check_cxx_compiler_flag(-Wno-c++98-compat-extra-semi
38+
CXX_SUPPORTS_NO_CXX98_COMPAT_EXTRA_SEMI_FLAG)
39+
if (CXX_SUPPORTS_CXX98_COMPAT_EXTRA_SEMI_FLAG)
40+
target_compile_options(torch-mlir-import-onnx PRIVATE
41+
"-Wno-c++98-compat-extra-semi")
42+
target_compile_options(onnx PRIVATE
43+
"-Wno-c++98-compat-extra-semi")
44+
endif()
45+
2846
target_link_libraries(
2947
torch-mlir-import-onnx
30-
LLVMSupport
3148
MLIRCAPIIR
3249
TorchMLIRCAPI
3350
onnx

projects/onnx_c_importer/Dict.hpp

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
//===------------------------------------------------------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
// Also available under a BSD-style license. See LICENSE.
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
/// (almost) STL-compatible container that implements an associative map
11+
/// iteratable according to insertion order. Mimicks Python Dict.
12+
/// Rationale: to ease testing of the C++ importer against onnx_importer.py we
13+
/// need to compare text outputs. MLIR values corresponding to tensors might be
14+
/// written in different (compatible) orders due to differences in iteration
15+
/// order between C++ STL unordered_map and Python Dict. Therefore we adopt the
16+
/// insertion order here as well.
17+
18+
#pragma once
19+
20+
#include <unordered_map>
21+
#include <vector>
22+
23+
namespace torch_mlir_onnx {
24+
25+
template <typename _Key, typename _Tp> struct DictIterator {
26+
private:
27+
using key_type = _Key;
28+
using mapped_type = _Tp;
29+
using self = DictIterator<key_type, mapped_type>;
30+
using vector = std::vector<key_type>;
31+
using key_value_map = std::unordered_map<key_type, mapped_type>;
32+
using vector_iterator = typename vector::iterator;
33+
34+
vector_iterator v_it_;
35+
key_value_map *m_ = nullptr;
36+
37+
public:
38+
using iterator_category = std::forward_iterator_tag;
39+
using value_type = std::pair<const key_type, mapped_type>;
40+
using difference_type = std::ptrdiff_t;
41+
using pointer = value_type *;
42+
using reference = value_type &;
43+
44+
DictIterator() = default;
45+
46+
explicit DictIterator(const vector_iterator &it, key_value_map *m) noexcept
47+
: v_it_(it), m_(m) {}
48+
49+
reference operator*() const noexcept { return *m_->find(*v_it_); }
50+
51+
pointer operator->() const noexcept { return m_->find(*v_it_).operator->(); }
52+
53+
self &operator++() noexcept {
54+
++v_it_;
55+
return *this;
56+
}
57+
58+
self operator++(int) noexcept {
59+
self _tmp(*this);
60+
++*this;
61+
return _tmp;
62+
}
63+
64+
friend bool operator==(const self &x, const self &y) noexcept {
65+
return x.v_it_ == y.v_it_ && x.m_ == y.m_;
66+
}
67+
};
68+
69+
template <typename _Key, typename _Tp> class DictConstIterator {
70+
private:
71+
using key_type = _Key;
72+
using mapped_type = _Tp;
73+
using self = DictConstIterator<key_type, mapped_type>;
74+
using vector = std::vector<key_type>;
75+
using key_value_map = std::unordered_map<key_type, mapped_type>;
76+
using vector_const_iterator = typename vector::const_iterator;
77+
78+
vector_const_iterator v_it_;
79+
const key_value_map *m_ = nullptr;
80+
81+
public:
82+
using iterator_category = std::forward_iterator_tag;
83+
using value_type = std::pair<const key_type, mapped_type>;
84+
using difference_type = std::ptrdiff_t;
85+
using pointer = const value_type *;
86+
using reference = const value_type &;
87+
88+
DictConstIterator() = default;
89+
90+
explicit DictConstIterator(const vector_const_iterator &it,
91+
const key_value_map *m) noexcept
92+
: v_it_(it), m_(m) {}
93+
94+
reference operator*() const noexcept { return *m_->find(*v_it_); }
95+
96+
pointer operator->() const noexcept { return m_->find(*v_it_).operator->(); }
97+
98+
self &operator++() noexcept {
99+
++v_it_;
100+
return *this;
101+
}
102+
103+
self operator++(int) noexcept {
104+
self _tmp(*this);
105+
++*this;
106+
return _tmp;
107+
}
108+
109+
friend bool operator==(const self &x, const self &y) noexcept {
110+
return x.v_it_ == y.v_it_ && x.m_ == y.m_;
111+
}
112+
};
113+
114+
template <typename _Key, typename _Tp> class Dict {
115+
116+
private:
117+
using key_value_map = std::unordered_map<_Key, _Tp>;
118+
using key_vector = std::vector<_Key>;
119+
using key_index_map =
120+
std::unordered_map<_Key, typename key_vector::iterator::difference_type>;
121+
122+
key_value_map m_;
123+
key_vector k_;
124+
key_index_map i_;
125+
126+
public:
127+
/// Public typedefs.
128+
using key_type = _Key;
129+
using mapped_type = _Tp;
130+
using value_type = std::pair<const _Key, _Tp>;
131+
using size_type = std::size_t;
132+
using allocator_type = std::allocator<value_type>;
133+
134+
/// Iterator-related typedefs.
135+
using reference = mapped_type &;
136+
using const_reference = const mapped_type &;
137+
using pointer = typename std::allocator_traits<allocator_type>::pointer;
138+
using const_pointer =
139+
typename std::allocator_traits<allocator_type>::const_pointer;
140+
using iterator = DictIterator<key_type, mapped_type>;
141+
using const_iterator = DictConstIterator<key_type, mapped_type>;
142+
143+
/* Constructors, assignment and destructor */
144+
Dict() = default;
145+
Dict(const Dict &) = default;
146+
Dict(Dict &&) = default;
147+
148+
Dict &operator=(const Dict &) = default;
149+
Dict &operator=(Dict &&) = default;
150+
151+
~Dict() = default;
152+
153+
/* Selectors */
154+
const_iterator find(const key_type &key) const {
155+
auto ii = i_.find(key);
156+
if (ii == i_.end())
157+
return end();
158+
return const_iterator{k_.cbegin() + (*ii).second, &m_};
159+
}
160+
size_type size() const { return m_.size(); }
161+
bool empty() const { return m_.empty(); }
162+
reference at(const key_type &key) { return m_.at(key); }
163+
const_reference at(const key_type &key) const { return m_.at(key); }
164+
165+
/* Mutators */
166+
iterator find(const key_type &key) {
167+
auto ii = i_.find(key);
168+
if (ii == i_.end())
169+
return end();
170+
return iterator{k_.begin() + (*ii).second, &m_};
171+
}
172+
std::pair<iterator, bool> insert(const value_type &pair) {
173+
auto found_it = find(pair.first);
174+
if (found_it == end()) {
175+
auto ki = k_.insert(k_.end(), pair.first);
176+
i_.emplace(pair.first, ki - k_.begin());
177+
m_.insert(pair);
178+
return {iterator{ki, &m_}, true};
179+
}
180+
return {found_it, false};
181+
}
182+
std::pair<iterator, bool> insert(value_type &&pair) {
183+
auto found_it = find(pair.first);
184+
if (found_it == end()) {
185+
auto ki = k_.insert(k_.end(), pair.first);
186+
i_.emplace(pair.first, ki - k_.begin());
187+
m_.insert(std::move(pair));
188+
return {iterator{ki, &m_}, true};
189+
}
190+
return {found_it, false};
191+
}
192+
193+
template <typename... _Args>
194+
std::pair<iterator, bool> emplace(_Args &&...args) {
195+
return insert(value_type(std::forward<_Args>(args)...));
196+
}
197+
reference operator[](const key_type &key) {
198+
auto ins = emplace(key, mapped_type());
199+
return (*ins.first).second;
200+
}
201+
202+
/* Iterators */
203+
iterator begin() { return iterator(k_.begin(), &m_); }
204+
const_iterator begin() const { return const_iterator(k_.cbegin(), &m_); }
205+
const_iterator cbegin() const { return const_iterator(k_.cbegin(), &m_); }
206+
iterator end() { return iterator(k_.end(), &m_); }
207+
const_iterator end() const { return const_iterator(k_.cend(), &m_); }
208+
const_iterator cend() const { return const_iterator(k_.cend(), &m_); }
209+
};
210+
211+
} // namespace torch_mlir_onnx

0 commit comments

Comments
 (0)