Skip to content

Commit b3fa0a3

Browse files
authored
Add substrait tests for comparison and boolean functions (#629)
* Add substrait tests for comparison and boolean functions * fix failing tests
1 parent e19ec94 commit b3fa0a3

31 files changed

+653
-21
lines changed

src/FlowtideDotNet.Core/ColumnStore/DataValues/BoolValue.cs

+5
Original file line numberDiff line numberDiff line change
@@ -57,5 +57,10 @@ public void CopyToContainer(DataValueContainer container)
5757
container._type = ArrowTypeId.Boolean;
5858
container._boolValue = this;
5959
}
60+
61+
public override string ToString()
62+
{
63+
return value ? "true" : "false";
64+
}
6065
}
6166
}

src/FlowtideDotNet.Core/Compute/Columnar/Functions/BuiltInComparisonFunctions.cs

+200
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,15 @@ public static void AddComparisonFunctions(IFunctionsRegister functionsRegister)
4040
functionsRegister.RegisterScalarMethod(FunctionsComparison.Uri, FunctionsComparison.Equal, typeof(BuiltInComparisonFunctions), nameof(EqualImplementation));
4141
functionsRegister.RegisterScalarMethod(FunctionsComparison.Uri, FunctionsComparison.NotEqual, typeof(BuiltInComparisonFunctions), nameof(NotEqualImplementation));
4242
functionsRegister.RegisterScalarMethod(FunctionsComparison.Uri, FunctionsComparison.GreaterThan, typeof(BuiltInComparisonFunctions), nameof(GreaterThanImplementation));
43+
functionsRegister.RegisterScalarMethod(FunctionsComparison.Uri, FunctionsComparison.GreaterThanOrEqual, typeof(BuiltInComparisonFunctions), nameof(GreaterThanOrEqualImplementation));
44+
functionsRegister.RegisterScalarMethod(FunctionsComparison.Uri, FunctionsComparison.LessThan, typeof(BuiltInComparisonFunctions), nameof(LessThanImplementation));
45+
functionsRegister.RegisterScalarMethod(FunctionsComparison.Uri, FunctionsComparison.LessThanOrEqual, typeof(BuiltInComparisonFunctions), nameof(LessThanOrEqualImplementation));
4346
functionsRegister.RegisterScalarMethod(FunctionsComparison.Uri, FunctionsComparison.IsNull, typeof(BuiltInComparisonFunctions), nameof(IsNullImplementation));
47+
functionsRegister.RegisterScalarMethod(FunctionsComparison.Uri, FunctionsComparison.IsNotNull, typeof(BuiltInComparisonFunctions), nameof(IsNotNullImplementation));
4448
functionsRegister.RegisterScalarMethod(FunctionsComparison.Uri, FunctionsComparison.Between, typeof(BuiltInComparisonFunctions), nameof(BetweenImplementation));
49+
functionsRegister.RegisterScalarMethod(FunctionsComparison.Uri, FunctionsComparison.IsFinite, typeof(BuiltInComparisonFunctions), nameof(IsFiniteImplementation));
50+
functionsRegister.RegisterScalarMethod(FunctionsComparison.Uri, FunctionsComparison.isInfinite, typeof(BuiltInComparisonFunctions), nameof(IsInfiniteImplementation));
51+
functionsRegister.RegisterScalarMethod(FunctionsComparison.Uri, FunctionsComparison.IsNan, typeof(BuiltInComparisonFunctions), nameof(IsNanImplementation));
4552

4653
functionsRegister.RegisterColumnScalarFunction(FunctionsComparison.Uri, FunctionsComparison.Coalesce,
4754
(scalarFunction, parametersInfo, visitor) =>
@@ -143,6 +150,78 @@ private static IDataValue GreaterThanImplementation<T1, T2>(in T1 x, in T2 y, in
143150
}
144151
}
145152

153+
private static IDataValue GreaterThanOrEqualImplementation<T1, T2>(in T1 x, in T2 y, in DataValueContainer result)
154+
where T1 : IDataValue
155+
where T2 : IDataValue
156+
{
157+
// If either is null, return null
158+
if (x.IsNull || y.IsNull)
159+
{
160+
result._type = ArrowTypeId.Null;
161+
return result;
162+
}
163+
else if (DataValueComparer.CompareTo(x, y) >= 0)
164+
{
165+
result._type = ArrowTypeId.Boolean;
166+
result._boolValue = new BoolValue(true);
167+
return result;
168+
}
169+
else
170+
{
171+
result._type = ArrowTypeId.Boolean;
172+
result._boolValue = new BoolValue(false);
173+
return result;
174+
}
175+
}
176+
177+
private static IDataValue LessThanImplementation<T1, T2>(in T1 x, in T2 y, in DataValueContainer result)
178+
where T1 : IDataValue
179+
where T2 : IDataValue
180+
{
181+
// If either is null, return null
182+
if (x.IsNull || y.IsNull)
183+
{
184+
result._type = ArrowTypeId.Null;
185+
return result;
186+
}
187+
else if (DataValueComparer.CompareTo(x, y) < 0)
188+
{
189+
result._type = ArrowTypeId.Boolean;
190+
result._boolValue = new BoolValue(true);
191+
return result;
192+
}
193+
else
194+
{
195+
result._type = ArrowTypeId.Boolean;
196+
result._boolValue = new BoolValue(false);
197+
return result;
198+
}
199+
}
200+
201+
private static IDataValue LessThanOrEqualImplementation<T1, T2>(in T1 x, in T2 y, in DataValueContainer result)
202+
where T1 : IDataValue
203+
where T2 : IDataValue
204+
{
205+
// If either is null, return null
206+
if (x.IsNull || y.IsNull)
207+
{
208+
result._type = ArrowTypeId.Null;
209+
return result;
210+
}
211+
else if (DataValueComparer.CompareTo(x, y) <= 0)
212+
{
213+
result._type = ArrowTypeId.Boolean;
214+
result._boolValue = new BoolValue(true);
215+
return result;
216+
}
217+
else
218+
{
219+
result._type = ArrowTypeId.Boolean;
220+
result._boolValue = new BoolValue(false);
221+
return result;
222+
}
223+
}
224+
146225
private static IDataValue IsNullImplementation<T>(in T x, in DataValueContainer result)
147226
where T : IDataValue
148227
{
@@ -151,11 +230,24 @@ private static IDataValue IsNullImplementation<T>(in T x, in DataValueContainer
151230
return result;
152231
}
153232

233+
private static IDataValue IsNotNullImplementation<T>(in T x, in DataValueContainer result)
234+
where T : IDataValue
235+
{
236+
result._type = ArrowTypeId.Boolean;
237+
result._boolValue = new BoolValue(!x.IsNull);
238+
return result;
239+
}
240+
154241
private static IDataValue BetweenImplementation<T1, T2, T3>(in T1 expr, in T2 low, in T3 high, in DataValueContainer result)
155242
where T1 : IDataValue
156243
where T2 : IDataValue
157244
where T3 : IDataValue
158245
{
246+
if (expr.IsNull || low.IsNull || high.IsNull)
247+
{
248+
result._type = ArrowTypeId.Null;
249+
return result;
250+
}
159251

160252
if (DataValueComparer.CompareTo(expr, low) >= 0 && DataValueComparer.CompareTo(expr, high) <= 0)
161253
{
@@ -170,5 +262,113 @@ private static IDataValue BetweenImplementation<T1, T2, T3>(in T1 expr, in T2 lo
170262
return result;
171263
}
172264
}
265+
266+
private static IDataValue IsFiniteImplementation<T>(in T x, in DataValueContainer result)
267+
where T : IDataValue
268+
{
269+
if (x.Type == ArrowTypeId.Double)
270+
{
271+
var val = x.AsDouble;
272+
if (val == double.PositiveInfinity || val == double.NegativeInfinity || double.IsNaN(val))
273+
{
274+
result._type = ArrowTypeId.Boolean;
275+
result._boolValue = new BoolValue(false);
276+
return result;
277+
}
278+
else
279+
{
280+
result._type = ArrowTypeId.Boolean;
281+
result._boolValue = new BoolValue(true);
282+
return result;
283+
}
284+
}
285+
else if (x.Type == ArrowTypeId.Int64)
286+
{
287+
result._type = ArrowTypeId.Boolean;
288+
result._boolValue = new BoolValue(true);
289+
return result;
290+
}
291+
else if (x.Type == ArrowTypeId.Decimal128)
292+
{
293+
result._type = ArrowTypeId.Boolean;
294+
result._boolValue = new BoolValue(true);
295+
return result;
296+
}
297+
298+
result._type = ArrowTypeId.Null;
299+
return result;
300+
}
301+
302+
private static IDataValue IsInfiniteImplementation<T>(in T x, in DataValueContainer result)
303+
where T : IDataValue
304+
{
305+
if (x.Type == ArrowTypeId.Double)
306+
{
307+
var val = x.AsDouble;
308+
if (val == double.PositiveInfinity || val == double.NegativeInfinity)
309+
{
310+
result._type = ArrowTypeId.Boolean;
311+
result._boolValue = new BoolValue(true);
312+
return result;
313+
}
314+
else
315+
{
316+
result._type = ArrowTypeId.Boolean;
317+
result._boolValue = new BoolValue(false);
318+
return result;
319+
}
320+
}
321+
else if (x.Type == ArrowTypeId.Int64)
322+
{
323+
result._type = ArrowTypeId.Boolean;
324+
result._boolValue = new BoolValue(false);
325+
return result;
326+
}
327+
else if (x.Type == ArrowTypeId.Decimal128)
328+
{
329+
result._type = ArrowTypeId.Boolean;
330+
result._boolValue = new BoolValue(false);
331+
return result;
332+
}
333+
334+
result._type = ArrowTypeId.Null;
335+
return result;
336+
}
337+
338+
private static IDataValue IsNanImplementation<T>(in T x, in DataValueContainer result)
339+
where T : IDataValue
340+
{
341+
if (x.Type == ArrowTypeId.Double)
342+
{
343+
var val = x.AsDouble;
344+
if (double.IsNaN(val))
345+
{
346+
result._type = ArrowTypeId.Boolean;
347+
result._boolValue = new BoolValue(true);
348+
return result;
349+
}
350+
else
351+
{
352+
result._type = ArrowTypeId.Boolean;
353+
result._boolValue = new BoolValue(false);
354+
return result;
355+
}
356+
}
357+
else if (x.Type == ArrowTypeId.Int64)
358+
{
359+
result._type = ArrowTypeId.Boolean;
360+
result._boolValue = new BoolValue(false);
361+
return result;
362+
}
363+
else if (x.Type == ArrowTypeId.Decimal128)
364+
{
365+
result._type = ArrowTypeId.Boolean;
366+
result._boolValue = new BoolValue(false);
367+
return result;
368+
}
369+
370+
result._type = ArrowTypeId.Null;
371+
return result;
372+
}
173373
}
174374
}

tests/FlowtideDotNet.AcceptanceTests/ComparisonFunctionTests.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ public async Task IsInfiniteStringFalse()
7272
GenerateData();
7373
await StartStream("INSERT INTO output SELECT is_infinite(firstName) FROM users");
7474
await WaitForUpdate();
75-
AssertCurrentDataEqual(Users.Select(x => new { val = false }));
75+
AssertCurrentDataEqual(Users.Select(x => new { val = default(string) }));
7676
}
7777

7878
[Fact]
@@ -140,7 +140,7 @@ INSERT INTO output
140140
is_nan(0/0), is_nan(1/2), is_nan(userkey), is_nan(nullablestring)
141141
FROM users u");
142142
await WaitForUpdate();
143-
AssertCurrentDataEqual(Users.Select(x => new { nan = true, not_nan = false, userkey = false, nullString = x.NullableString == null ? default(bool?) : true}));
143+
AssertCurrentDataEqual(Users.Select(x => new { nan = true, not_nan = false, userkey = false, nullString = default(string)}));
144144
}
145145

146146
[Fact]

tests/FlowtideDotNet.AcceptanceTests/TypeTests.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ INSERT INTO output
4949
SELECT
5050
Money
5151
FROM orders
52-
WHERE money < 500
52+
WHERE money < cast(500 as decimal)
5353
");
5454
await WaitForUpdate();
5555

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
using Antlr4.Runtime;
2+
using Antlr4.Runtime.Misc;
3+
using Microsoft.CodeAnalysis;
4+
using Microsoft.CodeAnalysis.Text;
5+
using System;
6+
using System.Collections.Generic;
7+
using System.IO;
8+
using System.Text;
9+
10+
namespace FlowtideDotNet.ComputeTests.SourceGenerator.Internal
11+
{
12+
class ErrorStrategy : DefaultErrorStrategy
13+
{
14+
public override void ReportError(Parser recognizer, RecognitionException e)
15+
{
16+
NotifyErrorListeners(recognizer, e.Message, e);
17+
//base.ReportError(recognizer, e);
18+
}
19+
}
20+
class ErrorReporter : IAntlrErrorListener<IToken>
21+
{
22+
private readonly string path;
23+
private readonly SourceProductionContext sourceContext;
24+
25+
public bool ErrorReported { get; private set; }
26+
27+
public ErrorReporter(string path, SourceProductionContext sourceContext)
28+
{
29+
this.path = path;
30+
this.sourceContext = sourceContext;
31+
}
32+
33+
public void SyntaxError(TextWriter output, IRecognizer recognizer, IToken offendingSymbol, int line, int charPositionInLine, string msg, RecognitionException e)
34+
{
35+
ErrorReported = true;
36+
sourceContext.ReportDiagnostic(Diagnostic.Create(
37+
"TESTGEN001",
38+
"Test",
39+
new TestLocalizableString(msg),
40+
DiagnosticSeverity.Error,
41+
DiagnosticSeverity.Error,
42+
true,
43+
0,
44+
location: Location.Create(path, new Microsoft.CodeAnalysis.Text.TextSpan(), new Microsoft.CodeAnalysis.Text.LinePositionSpan(new LinePosition(line - 1, charPositionInLine), new LinePosition(line - 1, charPositionInLine)))
45+
));
46+
}
47+
}
48+
}

tests/FlowtideDotNet.ComputeTests.SourceGenerator/Internal/TestCaseParser.cs

+7-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ namespace FlowtideDotNet.ComputeTests
1717
{
1818
internal class TestCaseParser
1919
{
20-
public TestDocument Parse(string text)
20+
public TestDocument Parse(string text, IAntlrErrorListener<IToken> errorListener = default)
2121
{
2222
ICharStream stream = CharStreams.fromString(text);
2323
var lexer = new FuncTestCaseLexer(stream);
@@ -26,6 +26,12 @@ public TestDocument Parse(string text)
2626
{
2727
BuildParseTree = true
2828
};
29+
parser.RemoveErrorListeners();
30+
if (errorListener != null)
31+
{
32+
parser.AddErrorListener(errorListener);
33+
}
34+
2935

3036
var context = parser.doc();
3137

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// Licensed under the Apache License, Version 2.0 (the "License")
2+
// you may not use this file except in compliance with the License.
3+
// You may obtain a copy of the License at
4+
//
5+
// http://www.apache.org/licenses/LICENSE-2.0
6+
//
7+
// Unless required by applicable law or agreed to in writing, software
8+
// distributed under the License is distributed on an "AS IS" BASIS,
9+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
// See the License for the specific language governing permissions and
11+
// limitations under the License.
12+
13+
using Microsoft.CodeAnalysis;
14+
using System;
15+
using System.Collections.Generic;
16+
using System.Text;
17+
18+
namespace FlowtideDotNet.ComputeTests.SourceGenerator.Internal
19+
{
20+
internal class TestLocalizableString : LocalizableString
21+
{
22+
private readonly string msg;
23+
24+
public TestLocalizableString(string msg)
25+
{
26+
this.msg = msg;
27+
}
28+
protected override bool AreEqual(object other)
29+
{
30+
return other is TestLocalizableString testLocalizableString && testLocalizableString.msg == msg;
31+
}
32+
33+
protected override int GetHash()
34+
{
35+
return msg.GetHashCode();
36+
}
37+
38+
protected override string GetText(IFormatProvider formatProvider)
39+
{
40+
return msg;
41+
}
42+
}
43+
}

0 commit comments

Comments
 (0)