Skip to content

Commit 39592f1

Browse files
feat: support Substrait 0.3.0 (#14)
Also changed: anchor 0 is now considered legal. Co-authored-by: Jacques Nadeau <[email protected]>
1 parent aa8755f commit 39592f1

File tree

26 files changed

+620
-183
lines changed

26 files changed

+620
-183
lines changed

rs/src/input/traits.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,30 @@ impl<T: ProtoPrimitive> InputNode for T {
127127
}
128128
}
129129

130+
impl InputNode for () {
131+
fn type_to_node() -> tree::Node {
132+
tree::NodeType::ProtoMessage("google.protobuf.Empty").into()
133+
}
134+
135+
fn data_to_node(&self) -> tree::Node {
136+
tree::NodeType::ProtoMessage("google.protobuf.Empty").into()
137+
}
138+
139+
fn oneof_variant(&self) -> Option<&'static str> {
140+
None
141+
}
142+
143+
fn parse_unknown(&self, _context: &mut context::Context<'_>) -> bool {
144+
false
145+
}
146+
}
147+
148+
impl ProtoMessage for () {
149+
fn proto_message_type() -> &'static str {
150+
"google.protobuf.Empty"
151+
}
152+
}
153+
130154
#[cfg(test)]
131155
mod tests {
132156
use super::*;

rs/src/output/diagnostic.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,9 @@ pub enum Classification {
171171
#[strum(props(Description = "illegal glob"))]
172172
IllegalGlob = 5,
173173

174+
#[strum(props(Description = "deprecation"))]
175+
Deprecation = 6,
176+
174177
#[strum(props(HiddenDescription = "experimental"))]
175178
Experimental = 999,
176179

@@ -231,6 +234,9 @@ pub enum Classification {
231234
#[strum(props(Description = "failed to resolve type variation name"))]
232235
LinkMissingTypeVariationName = 3004,
233236

237+
#[strum(props(HiddenDescription = "use of anchor zero"))]
238+
LinkAnchorZero = 3005,
239+
234240
// Type-related diagnostics (group 4).
235241
#[strum(props(HiddenDescription = "type-related diagnostics"))]
236242
Type = 4000,

rs/src/parse/expressions/conditionals.rs

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ pub fn parse_if_then(
5454

5555
// Save to the "arguments" of the function we'll use to describe this
5656
// expression.
57-
args.push(condition);
58-
args.push(value);
57+
args.push(condition.into());
58+
args.push(value.into());
5959

6060
Ok(())
6161
});
@@ -76,14 +76,14 @@ pub fn parse_if_then(
7676

7777
// Save to the "arguments" of the function we'll use to describe this
7878
// expression.
79-
args.push(value);
79+
args.push(value.into());
8080
} else {
8181
// Allow missing else, making the type nullable.
8282
comment!(y, "Otherwise, yield null.");
8383
return_type = return_type.make_nullable();
8484

8585
// Yield null for the else clause.
86-
args.push(expressions::Expression::new_null(return_type.clone()));
86+
args.push(expressions::Expression::new_null(return_type.clone()).into());
8787
}
8888

8989
// Describe node.
@@ -110,7 +110,7 @@ pub fn parse_switch(
110110
// Parse value to match.
111111
let (n, e) = proto_boxed_required_field!(x, y, r#match, expressions::parse_expression);
112112
let mut match_type = n.data_type();
113-
args.push(e.unwrap_or_default());
113+
args.push(e.unwrap_or_default().into());
114114

115115
// Handle branches.
116116
proto_required_repeated_field!(x, y, ifs, |x, y| {
@@ -143,8 +143,8 @@ pub fn parse_switch(
143143

144144
// Save to the "arguments" of the function we'll use to describe this
145145
// expression.
146-
args.push(match_value.into());
147-
args.push(value);
146+
args.push(expressions::Expression::from(match_value).into());
147+
args.push(value.into());
148148

149149
Ok(())
150150
});
@@ -165,14 +165,14 @@ pub fn parse_switch(
165165

166166
// Save to the "arguments" of the function we'll use to describe this
167167
// expression.
168-
args.push(value);
168+
args.push(value.into());
169169
} else {
170170
// Allow missing else, making the type nullable.
171171
comment!(y, "Otherwise, yield null.");
172172
return_type = return_type.make_nullable();
173173

174174
// Yield null for the else clause.
175-
args.push(expressions::Expression::new_null(return_type.clone()));
175+
args.push(expressions::Expression::new_null(return_type.clone()).into());
176176
}
177177

178178
// Describe node.
@@ -200,13 +200,13 @@ pub fn parse_singular_or_list(
200200
// Parse value to match.
201201
let (n, e) = proto_boxed_required_field!(x, y, value, expressions::parse_expression);
202202
let match_type = n.data_type();
203-
args.push(e.unwrap_or_default());
203+
args.push(e.unwrap_or_default().into());
204204

205205
// Handle allowed values.
206206
proto_required_repeated_field!(x, y, options, |x, y| {
207207
let expression = expressions::parse_expression(x, y)?;
208208
let value_type = y.data_type();
209-
args.push(expression);
209+
args.push(expression.into());
210210

211211
// Check that the type is the same as the value.
212212
types::assert_equal(
@@ -249,17 +249,19 @@ pub fn parse_multi_or_list(
249249
// Parse value to match.
250250
let (ns, es) = proto_required_repeated_field!(x, y, value, expressions::parse_expression);
251251
let match_types = ns.iter().map(|x| x.data_type()).collect::<Vec<_>>();
252-
args.push(expressions::Expression::Tuple(
253-
es.into_iter().map(|x| x.unwrap_or_default()).collect(),
254-
));
252+
args.push(
253+
expressions::Expression::Tuple(es.into_iter().map(|x| x.unwrap_or_default()).collect())
254+
.into(),
255+
);
255256

256257
// Handle allowed values.
257258
proto_required_repeated_field!(x, y, options, |x, y| {
258259
let (ns, es) = proto_required_repeated_field!(x, y, fields, expressions::parse_expression);
259260
let value_types = ns.iter().map(|x| x.data_type()).collect::<Vec<_>>();
260-
args.push(expressions::Expression::Tuple(
261-
es.into_iter().map(|x| x.unwrap_or_default()).collect(),
262-
));
261+
args.push(
262+
expressions::Expression::Tuple(es.into_iter().map(|x| x.unwrap_or_default()).collect())
263+
.into(),
264+
);
263265

264266
// Check that the type is the same as the value.
265267
if match_types.len() != value_types.len() {

rs/src/parse/expressions/functions.rs

Lines changed: 159 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,119 @@ use crate::parse::expressions;
1212
use crate::parse::extensions;
1313
use crate::parse::sorts;
1414
use crate::parse::types;
15+
use crate::util;
16+
use crate::util::string::Describe;
1517
use std::sync::Arc;
1618

19+
/// A function argument; either a value, a type, or an enum option.
20+
#[derive(Clone, Debug, PartialEq)]
21+
pub enum FunctionArgument {
22+
/// Used for value arguments or normal expressions.
23+
Value(expressions::Expression),
24+
25+
/// Used for type arguments.
26+
Type(Arc<data_type::DataType>),
27+
28+
/// Used for enum option arguments.
29+
Enum(Option<String>),
30+
}
31+
32+
impl Default for FunctionArgument {
33+
fn default() -> Self {
34+
FunctionArgument::Value(expressions::Expression::default())
35+
}
36+
}
37+
38+
impl From<expressions::Expression> for FunctionArgument {
    /// Wraps an expression as a value argument, so call sites can write
    /// `expression.into()`.
    fn from(expr: expressions::Expression) -> Self {
        Self::Value(expr)
    }
}
43+
44+
impl Describe for FunctionArgument {
45+
fn describe(
46+
&self,
47+
f: &mut std::fmt::Formatter<'_>,
48+
limit: util::string::Limit,
49+
) -> std::fmt::Result {
50+
match self {
51+
FunctionArgument::Value(e) => e.describe(f, limit),
52+
FunctionArgument::Type(e) => e.describe(f, limit),
53+
FunctionArgument::Enum(Some(x)) => util::string::describe_identifier(f, x, limit),
54+
FunctionArgument::Enum(None) => write!(f, "-"),
55+
}
56+
}
57+
}
58+
59+
impl std::fmt::Display for FunctionArgument {
60+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61+
self.display().fmt(f)
62+
}
63+
}
64+
65+
/// Parse an enum option argument type.
66+
fn parse_enum_type(
67+
x: &substrait::function_argument::r#enum::EnumKind,
68+
_y: &mut context::Context,
69+
) -> diagnostic::Result<Option<String>> {
70+
match x {
71+
substrait::function_argument::r#enum::EnumKind::Specified(x) => Ok(Some(x.clone())),
72+
substrait::function_argument::r#enum::EnumKind::Unspecified(_) => Ok(None),
73+
}
74+
}
75+
76+
/// Parse an enum option argument.
77+
fn parse_enum(
78+
x: &substrait::function_argument::Enum,
79+
y: &mut context::Context,
80+
) -> diagnostic::Result<Option<String>> {
81+
Ok(proto_required_field!(x, y, enum_kind, parse_enum_type)
82+
.1
83+
.flatten())
84+
}
85+
86+
/// Parse a 0.3.0+ function argument type.
87+
fn parse_function_argument_type(
88+
x: &substrait::function_argument::ArgType,
89+
y: &mut context::Context,
90+
) -> diagnostic::Result<FunctionArgument> {
91+
match x {
92+
substrait::function_argument::ArgType::Enum(x) => {
93+
Ok(FunctionArgument::Enum(parse_enum(x, y)?))
94+
}
95+
substrait::function_argument::ArgType::Type(x) => {
96+
types::parse_type(x, y)?;
97+
Ok(FunctionArgument::Type(y.data_type()))
98+
}
99+
substrait::function_argument::ArgType::Value(x) => Ok(FunctionArgument::Value(
100+
expressions::parse_expression(x, y)?,
101+
)),
102+
}
103+
}
104+
105+
/// Parse a 0.3.0+ function argument.
106+
fn parse_function_argument(
107+
x: &substrait::FunctionArgument,
108+
y: &mut context::Context,
109+
) -> diagnostic::Result<FunctionArgument> {
110+
Ok(
111+
proto_required_field!(x, y, arg_type, parse_function_argument_type)
112+
.1
113+
.unwrap_or_default(),
114+
)
115+
}
116+
117+
/// Parse a pre-0.3.0 function argument expression.
118+
fn parse_legacy_function_argument(
119+
x: &substrait::Expression,
120+
y: &mut context::Context,
121+
) -> diagnostic::Result<FunctionArgument> {
122+
expressions::parse_legacy_function_argument(x, y).map(|x| match x {
123+
expressions::ExpressionOrEnum::Value(x) => FunctionArgument::Value(x),
124+
expressions::ExpressionOrEnum::Enum(x) => FunctionArgument::Enum(x),
125+
})
126+
}
127+
17128
/// Matches a function call with its YAML definition, yielding its return type.
18129
/// Yields an unresolved type if resolution fails.
19130
pub fn check_function(
@@ -41,7 +152,8 @@ pub fn check_function(
41152
fn parse_function(
42153
y: &mut context::Context,
43154
function: Option<Arc<extension::Reference<extension::Function>>>,
44-
arguments: (Vec<Arc<tree::Node>>, Vec<Option<expressions::Expression>>),
155+
arguments: (Vec<Arc<tree::Node>>, Vec<Option<FunctionArgument>>),
156+
legacy_arguments: (Vec<Arc<tree::Node>>, Vec<Option<FunctionArgument>>),
45157
return_type: Arc<data_type::DataType>,
46158
) -> (Arc<data_type::DataType>, expressions::Expression) {
47159
// Determine the name of the function.
@@ -50,6 +162,36 @@ fn parse_function(
50162
.map(|x| x.name.to_string())
51163
.unwrap_or_else(|| String::from("?"));
52164

165+
// Reconcile v0.3.0+ vs older function argument syntax.
166+
let arguments = if legacy_arguments.1.is_empty() {
167+
arguments
168+
} else if arguments.1.is_empty() {
169+
diagnostic!(
170+
y,
171+
Warning,
172+
Deprecation,
173+
"the args field for specifying function arguments was deprecated Substrait 0.3.0 (#161)"
174+
);
175+
legacy_arguments
176+
} else {
177+
if arguments != legacy_arguments {
178+
diagnostic!(
179+
y,
180+
Error,
181+
IllegalValue,
182+
"mismatch between v0.3+ and legacy function argument specification"
183+
);
184+
comment!(
185+
y,
186+
"If both the v0.3+ and legacy syntax is used to specify function \
187+
arguments, please make sure both map to the same arguments. If \
188+
the argument pack is not representable using the legacy syntax, \
189+
do not use it."
190+
);
191+
}
192+
arguments
193+
};
194+
53195
// Unpack the arguments into the function's enum options and regular
54196
// arguments.
55197
let mut opt_values = vec![];
@@ -61,7 +203,7 @@ fn parse_function(
61203
.into_iter()
62204
.zip(arguments.1.into_iter().map(|x| x.unwrap_or_default()))
63205
{
64-
if let expressions::Expression::EnumVariant(x) = &expr {
206+
if let FunctionArgument::Enum(x) = &expr {
65207
if opt_exprs.is_empty() && !arg_exprs.is_empty() {
66208
diagnostic!(
67209
y,
@@ -122,13 +264,16 @@ pub fn parse_scalar_function(
122264
extensions::simple::parse_function_reference
123265
)
124266
.1;
125-
let arguments = proto_repeated_field!(x, y, args, expressions::parse_function_argument);
267+
#[allow(deprecated)]
268+
let legacy_arguments = proto_repeated_field!(x, y, args, parse_legacy_function_argument);
269+
let arguments = proto_repeated_field!(x, y, arguments, parse_function_argument);
126270
let return_type = proto_required_field!(x, y, output_type, types::parse_type)
127271
.0
128272
.data_type();
129273

130274
// Check function information.
131-
let (return_type, expression) = parse_function(y, function, arguments, return_type);
275+
let (return_type, expression) =
276+
parse_function(y, function, arguments, legacy_arguments, return_type);
132277

133278
// Describe node.
134279
y.set_data_type(return_type);
@@ -168,13 +313,16 @@ pub fn parse_window_function(
168313
extensions::simple::parse_function_reference
169314
)
170315
.1;
171-
let arguments = proto_repeated_field!(x, y, args, expressions::parse_function_argument);
316+
#[allow(deprecated)]
317+
let legacy_arguments = proto_repeated_field!(x, y, args, parse_legacy_function_argument);
318+
let arguments = proto_repeated_field!(x, y, arguments, parse_function_argument);
172319
let return_type = proto_required_field!(x, y, output_type, types::parse_type)
173320
.0
174321
.data_type();
175322

176323
// Check function information.
177-
let (return_type, expression) = parse_function(y, function, arguments, return_type);
324+
let (return_type, expression) =
325+
parse_function(y, function, arguments, legacy_arguments, return_type);
178326

179327
// Parse modifiers.
180328
proto_repeated_field!(x, y, partitions, expressions::parse_expression);
@@ -216,13 +364,16 @@ pub fn parse_aggregate_function(
216364
extensions::simple::parse_function_reference
217365
)
218366
.1;
219-
let arguments = proto_repeated_field!(x, y, args, expressions::parse_function_argument);
367+
#[allow(deprecated)]
368+
let legacy_arguments = proto_repeated_field!(x, y, args, parse_legacy_function_argument);
369+
let arguments = proto_repeated_field!(x, y, arguments, parse_function_argument);
220370
let return_type = proto_required_field!(x, y, output_type, types::parse_type)
221371
.0
222372
.data_type();
223373

224374
// Check function information.
225-
let (return_type, expression) = parse_function(y, function, arguments, return_type);
375+
let (return_type, expression) =
376+
parse_function(y, function, arguments, legacy_arguments, return_type);
226377

227378
// Parse modifiers.
228379
proto_repeated_field!(x, y, sorts, sorts::parse_sort_field);

0 commit comments

Comments
 (0)