Skip to content

Commit 326cbb1

Browse files
authored
Add support for lead window function (#757)
This fixes #752
1 parent f6ff57a commit 326cbb1

File tree

6 files changed

+388
-0
lines changed

6 files changed

+388
-0
lines changed

docs/docs/expressions/windowfunctions/arithmetic.md

+21
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,24 @@ This function **requires an `ORDER BY` clause** to determine row position and **
3838
```sql
3939
SELECT ROW_NUMBER(column1) OVER (PARTITION BY column2 ORDER BY column3) FROM ...
4040
```
41+
42+
## Lead
43+
44+
[Substrait definition](https://substrait.io/extensions/functions_arithmetic/#lead)
45+
46+
The `LEAD` window function provides access to a subsequent row’s value within the same result set partition. It returns the value of a specified column at a given offset after the current row.
47+
48+
If no row exists at that offset, a default value (if provided) is returned; otherwise, the result is `NULL`.
49+
50+
This function **requires an `ORDER BY` clause** to establish row sequence and **does not support frame boundaries**.
51+
52+
### SQL Usage
53+
54+
```sql
55+
-- Lead with default offset 1 and null default
56+
SELECT LEAD(column1) OVER (PARTITION BY column2 ORDER BY column3) FROM ...
57+
-- Lead with offset 2 and null default
58+
SELECT LEAD(column1, 2) OVER (PARTITION BY column2 ORDER BY column3) FROM ...
59+
-- Lead with offset 2 and default value set to 'hello'
60+
SELECT LEAD(column1, 2, 'hello') OVER (PARTITION BY column2 ORDER BY column3) FROM ...
61+
```

src/FlowtideDotNet.Core/Compute/Columnar/Functions/WindowFunctions/BuiltInWindowFunctions.cs

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ public static void AddBuiltInWindowFunctions(FunctionsRegister functionsRegister
2525
{
2626
functionsRegister.RegisterWindowFunction(FunctionsArithmetic.Uri, FunctionsArithmetic.Sum, new SumWindowFunctionDefinition());
2727
functionsRegister.RegisterWindowFunction(FunctionsArithmetic.Uri, FunctionsArithmetic.RowNumber, new RowNumberWindowFunctionDefinition());
28+
functionsRegister.RegisterWindowFunction(FunctionsArithmetic.Uri, FunctionsArithmetic.Lead, new LeadWindowFunctionDefinition());
2829
}
2930
}
3031
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
// Licensed under the Apache License, Version 2.0 (the "License")
2+
// you may not use this file except in compliance with the License.
3+
// You may obtain a copy of the License at
4+
//
5+
// http://www.apache.org/licenses/LICENSE-2.0
6+
//
7+
// Unless required by applicable law or agreed to in writing, software
8+
// distributed under the License is distributed on an "AS IS" BASIS,
9+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
// See the License for the specific language governing permissions and
11+
// limitations under the License.
12+
13+
using FlowtideDotNet.Core.ColumnStore;
14+
using FlowtideDotNet.Core.ColumnStore.DataValues;
15+
using FlowtideDotNet.Core.ColumnStore.TreeStorage;
16+
using FlowtideDotNet.Core.Operators.Window;
17+
using FlowtideDotNet.Storage.Memory;
18+
using FlowtideDotNet.Storage.StateManager;
19+
using FlowtideDotNet.Storage.Tree;
20+
using FlowtideDotNet.Substrait.Expressions;
21+
using System;
22+
using System.Collections.Generic;
23+
using System.Diagnostics;
24+
using System.Linq;
25+
using System.Text;
26+
using System.Threading.Tasks;
27+
28+
namespace FlowtideDotNet.Core.Compute.Columnar.Functions.WindowFunctions
29+
{
30+
internal class LeadWindowFunctionDefinition : WindowFunctionDefinition
31+
{
32+
public override IWindowFunction Create(WindowFunction aggregateFunction, IFunctionsRegister functionsRegister)
33+
{
34+
if (aggregateFunction.Arguments.Count < 1)
35+
{
36+
throw new ArgumentException("Lead function requires at least one argument");
37+
}
38+
39+
var leadValueFunc = ColumnProjectCompiler.CompileToValue(aggregateFunction.Arguments[0], functionsRegister);
40+
41+
Func<EventBatchData, int, IDataValue>? leadOffsetFunc = default;
42+
if (aggregateFunction.Arguments.Count > 1)
43+
{
44+
leadOffsetFunc = ColumnProjectCompiler.CompileToValue(aggregateFunction.Arguments[1], functionsRegister);
45+
}
46+
47+
Func<EventBatchData, int, IDataValue> ? defaultFunc = default;
48+
if (aggregateFunction.Arguments.Count > 2)
49+
{
50+
defaultFunc = ColumnProjectCompiler.CompileToValue(aggregateFunction.Arguments[2], functionsRegister);
51+
}
52+
53+
return new LeadWindowFunction(leadValueFunc, leadOffsetFunc, defaultFunc);
54+
}
55+
}
56+
57+
internal class LeadWindowFunction : IWindowFunction
58+
{
59+
private IWindowAddOutputRow? _addOutputRow;
60+
private IBPlusTreeIterator<ColumnRowReference, WindowValue, ColumnKeyStorageContainer, WindowValueContainer>? _updateIterator;
61+
private IBPlusTreeIterator<ColumnRowReference, WindowValue, ColumnKeyStorageContainer, WindowValueContainer>? _windowIterator;
62+
private PartitionIterator? _updatePartitionIterator;
63+
private PartitionIterator? _windowPartitionIterator;
64+
65+
private readonly Func<EventBatchData, int, IDataValue> _leadValueFunc;
66+
private readonly Func<EventBatchData, int, IDataValue>? _leadOffsetFunc;
67+
private readonly Func<EventBatchData, int, IDataValue>? _defaultValueFunc;
68+
69+
public LeadWindowFunction(
70+
Func<EventBatchData, int, IDataValue> leadValueFunc,
71+
Func<EventBatchData, int, IDataValue>? leadOffsetFunc,
72+
Func<EventBatchData, int, IDataValue>? defaultValueFunc)
73+
{
74+
_leadValueFunc = leadValueFunc;
75+
_leadOffsetFunc = leadOffsetFunc;
76+
_defaultValueFunc = defaultValueFunc;
77+
}
78+
79+
public async IAsyncEnumerable<EventBatchWeighted> ComputePartition(ColumnRowReference partitionValues)
80+
{
81+
Debug.Assert(_addOutputRow != null);
82+
Debug.Assert(_windowPartitionIterator != null);
83+
Debug.Assert(_updatePartitionIterator != null);
84+
85+
await _windowPartitionIterator.Reset(partitionValues);
86+
_updatePartitionIterator.ResetCopyFrom(_windowPartitionIterator);
87+
88+
var windowEnumerator = _windowPartitionIterator.GetAsyncEnumerator();
89+
var updateEnumerator = _updatePartitionIterator.GetAsyncEnumerator();
90+
91+
long updateRowIndex = 0;
92+
long windowRowIndex = 0;
93+
94+
var currentValue = new DataValueContainer();
95+
currentValue._type = ArrowTypeId.Null;
96+
97+
while (await updateEnumerator.MoveNextAsync())
98+
{
99+
int rowOffset = 1;
100+
if (_leadOffsetFunc != null)
101+
{
102+
var offsetValue = _leadOffsetFunc(updateEnumerator.Current.Key.referenceBatch, updateEnumerator.Current.Key.RowIndex);
103+
if (offsetValue is Int64Value int64Value)
104+
{
105+
rowOffset = (int)int64Value.AsLong;
106+
}
107+
}
108+
109+
110+
if (windowRowIndex > (updateRowIndex + rowOffset))
111+
{
112+
// Must reset the window enumerator to the beginning, this can be done faster, but at this
113+
// time it is an edge case since it requires dynamic row offset
114+
_windowPartitionIterator.ResetCopyFrom(_updatePartitionIterator);
115+
windowEnumerator = _windowPartitionIterator.GetAsyncEnumerator();
116+
windowRowIndex = 0;
117+
}
118+
119+
bool movedNext = false;
120+
while (windowRowIndex <= (updateRowIndex + rowOffset))
121+
{
122+
movedNext = await windowEnumerator.MoveNextAsync();
123+
if (!movedNext)
124+
{
125+
break;
126+
}
127+
windowRowIndex++;
128+
}
129+
130+
IDataValue? val;
131+
if (!movedNext)
132+
{
133+
if (_defaultValueFunc != null)
134+
{
135+
val = _defaultValueFunc(updateEnumerator.Current.Key.referenceBatch, updateEnumerator.Current.Key.RowIndex);
136+
}
137+
else
138+
{
139+
val = NullValue.Instance;
140+
}
141+
}
142+
else
143+
{
144+
val = _leadValueFunc(windowEnumerator.Current.Key.referenceBatch, windowEnumerator.Current.Key.RowIndex);
145+
}
146+
147+
updateRowIndex++;
148+
149+
updateEnumerator.Current.Value.UpdateStateValue(val);
150+
151+
if (_addOutputRow.Count >= 100)
152+
{
153+
yield return _addOutputRow.GetCurrentBatch();
154+
}
155+
}
156+
157+
if (_addOutputRow.Count > 0)
158+
{
159+
yield return _addOutputRow.GetCurrentBatch();
160+
}
161+
}
162+
163+
public Task Initialize(IBPlusTree<ColumnRowReference, WindowValue, ColumnKeyStorageContainer, WindowValueContainer> persistentTree, List<int> partitionColumns, IMemoryAllocator memoryAllocator, IStateManagerClient stateManagerClient, IWindowAddOutputRow addOutputRow)
164+
{
165+
_addOutputRow = addOutputRow;
166+
_windowIterator = persistentTree.CreateIterator();
167+
_updateIterator = persistentTree.CreateIterator();
168+
169+
_updatePartitionIterator = new PartitionIterator(_updateIterator, partitionColumns, addOutputRow);
170+
_windowPartitionIterator = new PartitionIterator(_windowIterator, partitionColumns);
171+
172+
return Task.CompletedTask;
173+
}
174+
}
175+
}

src/FlowtideDotNet.Substrait/FunctionExtensions/FunctionsArithmetic.cs

+1
Original file line numberDiff line numberDiff line change
@@ -50,5 +50,6 @@ public static class FunctionsArithmetic
5050

5151
// Window
5252
public const string RowNumber = "row_number";
53+
public const string Lead = "lead";
5354
}
5455
}

src/FlowtideDotNet.Substrait/Sql/Internal/BuiltInSqlFunctions.cs

+79
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
using FlowtideDotNet.Substrait.FunctionExtensions;
1616
using FlowtideDotNet.Substrait.Sql.Internal.TableFunctions;
1717
using FlowtideDotNet.Substrait.Type;
18+
using SqlParser;
1819
using SqlParser.Ast;
1920
using System.Diagnostics;
2021
using static SqlParser.Ast.WindowType;
@@ -764,6 +765,84 @@ public static void AddBuiltInFunctions(SqlFunctionRegister sqlFunctionRegister)
764765
// WindowFunction
765766
RegisterSingleVariableWindowFunction(sqlFunctionRegister, "sum", FunctionsArithmetic.Uri, FunctionsArithmetic.Sum, new AnyType(), true, false);
766767
RegisterZeroVariableWindowFunction(sqlFunctionRegister, "row_number", FunctionsArithmetic.Uri, FunctionsArithmetic.RowNumber, new Int64Type(), false, true);
768+
769+
sqlFunctionRegister.RegisterWindowFunction("lead",
770+
(func, visitor, emitData) =>
771+
{
772+
var argList = GetFunctionArguments(func.Args);
773+
if (argList.Args == null || argList.Args.Count < 1)
774+
{
775+
throw new InvalidOperationException($"lead must have exactly at least one argument, and not be '*'");
776+
}
777+
if ((argList.Args[0] is FunctionArg.Unnamed unnamed0 && unnamed0.FunctionArgExpression is FunctionArgExpression.Wildcard))
778+
{
779+
throw new InvalidOperationException($"lead must have at least one argument, and not be '*'");
780+
}
781+
if (argList.Args.Count > 3)
782+
{
783+
throw new InvalidOperationException($"lead must have at most three arguments, and not be '*'");
784+
}
785+
786+
if (func.Over is WindowSpecType windowSpecType)
787+
{
788+
if (windowSpecType.Spec.OrderBy == null)
789+
{
790+
throw new SubstraitParseException($"'lead' function must have an order by clause");
791+
}
792+
if (windowSpecType.Spec.WindowFrame != null)
793+
{
794+
if (windowSpecType.Spec.WindowFrame.Units == WindowFrameUnit.Rows)
795+
{
796+
throw new SubstraitParseException($"'lead' function does not support ROWS frame");
797+
}
798+
}
799+
}
800+
801+
WindowFunction windowFunc = new WindowFunction()
802+
{
803+
Arguments = new List<Expressions.Expression>(),
804+
ExtensionName = FunctionsArithmetic.Lead,
805+
ExtensionUri = FunctionsArithmetic.Uri,
806+
};
807+
808+
SubstraitBaseType? returnType = null;
809+
for (int i = 0; i < argList.Args.Count; i++)
810+
{
811+
var arg = argList.Args[i];
812+
if (arg is FunctionArg.Unnamed unnamed)
813+
{
814+
if (unnamed.FunctionArgExpression is FunctionArgExpression.FunctionExpression funcExpr)
815+
{
816+
var expr = visitor.Visit(funcExpr.Expression, emitData);
817+
windowFunc.Arguments.Add(expr.Expr);
818+
819+
if (i == 0)
820+
{
821+
returnType = expr.Type;
822+
}
823+
else if (returnType != expr.Type && i == 2)
824+
{
825+
returnType = AnyType.Instance;
826+
}
827+
}
828+
else
829+
{
830+
throw new NotImplementedException("lead does not support the input parameter");
831+
}
832+
}
833+
else
834+
{
835+
throw new NotImplementedException("lead does not support the input parameter");
836+
}
837+
}
838+
839+
if (returnType == null)
840+
{
841+
returnType = AnyType.Instance;
842+
}
843+
844+
return new WindowResponse(windowFunc, returnType);
845+
});
767846
}
768847

769848
private static void RegisterSingleVariableFunction(

0 commit comments

Comments
 (0)