Skip to content

[automated] Merge branch 'release/8.0' => 'main' #31601

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 31, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 30 additions & 10 deletions src/Microsoft.Data.Sqlite.Core/SqliteConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -821,22 +821,24 @@ private void CreateAggregateCore<TAccumulate, TResult>(
delegate_function_aggregate_step? func_step = null;
if (func != null)
{
func_step = (ctx, user_data, args) =>
func_step = static (ctx, user_data, args) =>
{
var context = (AggregateContext<TAccumulate>)user_data;
var definition = (AggregateDefinition<TAccumulate, TResult>)user_data;
ctx.state ??= new AggregateContext<TAccumulate>(definition.Seed);

var context = (AggregateContext<TAccumulate>)ctx.state;
if (context.Exception != null)
{
return;
}

// TODO: Avoid allocation when niladic
var reader = new SqliteParameterReader(name, args);
var reader = new SqliteParameterReader(definition.Name, args);

try
{
// TODO: Avoid closure by passing func via user_data
// NB: No need to set ctx.state since we just mutate the instance
context.Accumulate = func(context.Accumulate, reader);
context.Accumulate = definition.Func!(context.Accumulate, reader);
}
catch (Exception ex)
{
Expand All @@ -848,16 +850,18 @@ private void CreateAggregateCore<TAccumulate, TResult>(
delegate_function_aggregate_final? func_final = null;
if (resultSelector != null)
{
func_final = (ctx, user_data) =>
func_final = static (ctx, user_data) =>
{
var context = (AggregateContext<TAccumulate>)user_data;
var definition = (AggregateDefinition<TAccumulate, TResult>)user_data;
ctx.state ??= new AggregateContext<TAccumulate>(definition.Seed);

var context = (AggregateContext<TAccumulate>)ctx.state;

if (context.Exception == null)
{
try
{
// TODO: Avoid closure by passing resultSelector via user_data
var result = resultSelector(context.Accumulate);
var result = definition.ResultSelector!(context.Accumulate);

new SqliteResultBinder(ctx, result).Bind();
}
Expand All @@ -881,7 +885,7 @@ private void CreateAggregateCore<TAccumulate, TResult>(
}

var flags = isDeterministic ? SQLITE_DETERMINISTIC : 0;
var state = new AggregateContext<TAccumulate>(seed);
var state = new AggregateDefinition<TAccumulate, TResult>(name, seed, func, resultSelector);

if (State == ConnectionState.Open)
{
Expand Down Expand Up @@ -915,6 +919,22 @@ private void CreateAggregateCore<TAccumulate, TResult>(
return values;
}

private sealed class AggregateDefinition<TAccumulate, TResult>
{
public AggregateDefinition(string name, TAccumulate seed, Func<TAccumulate, SqliteValueReader, TAccumulate>? func, Func<TAccumulate, TResult>? resultSelector)
{
Name = name;
Seed = seed;
Func = func;
ResultSelector = resultSelector;
}

public string Name { get; }
public TAccumulate Seed { get; }
public Func<TAccumulate, SqliteValueReader, TAccumulate>? Func { get; }
public Func<TAccumulate, TResult>? ResultSelector { get; }
}

private sealed class AggregateContext<T>
{
public AggregateContext(T seed)
Expand Down
59 changes: 59 additions & 0 deletions test/Microsoft.Data.Sqlite.Tests/SqliteConnectionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,65 @@ public void CreateAggregate_works()
Assert.Equal("AX1Z", result);
}

[Fact]
public void CreateAggregate_works_called_twice()
{
using var connection = new SqliteConnection("Data Source=:memory:");
connection.Open();
connection.ExecuteNonQuery("CREATE TABLE dual2 (dummy1, dummy2); INSERT INTO dual2 (dummy1, dummy2) VALUES ('X', 'Y');");
connection.CreateAggregate(
"test",
"A",
(string a, string x, string y) => a + x + y,
a => a + "Z");

var result = connection.ExecuteScalar<string>("SELECT test(dummy1, dummy2) FROM dual2;");
Assert.Equal("AXYZ", result);

result = connection.ExecuteScalar<string>("SELECT test(dummy1, dummy2) FROM dual2;");
Assert.Equal("AXYZ", result);
}

[Fact]
public void CreateAggregate_works_called_twice_in_same_query()
{
using var connection = new SqliteConnection("Data Source=:memory:");
connection.Open();
connection.ExecuteNonQuery("CREATE TABLE dual2 (dummy1, dummy2); INSERT INTO dual2 (dummy1, dummy2) VALUES ('X', 'Y');");
connection.CreateAggregate(
"test",
"A",
(string a, string x, string y) => a + x + y,
a => a + "Z");

using (var reader = connection.ExecuteReader("SELECT test(dummy1, dummy2), test(dummy2, dummy1) FROM dual2;"))
{
Assert.True(reader.Read());

Assert.Equal("AXYZ", reader.GetString(0));
Assert.Equal("AYXZ", reader.GetString(1));

Assert.False(reader.Read());
}
}

[Fact]
public void CreateAggregate_works_when_no_rows()
{
using var connection = new SqliteConnection("Data Source=:memory:");
connection.Open();
connection.ExecuteNonQuery("CREATE TABLE dual2 (dummy1, dummy2);");
connection.CreateAggregate(
"test",
"A",
(string a, string x, string y) => a + x + y,
a => a + "Z");

var result = connection.ExecuteScalar<string>("SELECT test(dummy1, dummy2) FROM dual2;");

Assert.Equal("AZ", result);
}

[Fact]
public void CreateAggregate_works_when_params()
{
Expand Down