Skip to content

Commit 0d944ae

Browse files
github-actions[bot]jkoritzinskyAaronRobinsonMSFTelinor-fung
authored
[release/7.0] Ensure we cleanup the marshalling for elements of collections (stateful and stateless) (#76693)
* Ensure we cleanup the marshalling for elements of collections (stateful and stateless) * Add tests * Fix bad stackalloc size after we moved to strongly-typed buffers * PR feedback * Update NonBlittable.cs * Propagate details for types. * Update src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs Co-authored-by: Elinor Fung <[email protected]> Co-authored-by: Jeremy Koritzinsky <[email protected]> Co-authored-by: Aaron Robinson <[email protected]> Co-authored-by: Elinor Fung <[email protected]>
1 parent ef70886 commit 0d944ae

File tree

8 files changed

+172
-37
lines changed

8 files changed

+172
-37
lines changed

src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ElementsMarshalling.cs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,33 @@ protected StatementSyntax GenerateByValueOutUnmarshalStatement(TypePositionInfo
315315
StubCodeContext.Stage.Unmarshal));
316316
}
317317

318+
protected StatementSyntax GenerateElementCleanupStatement(TypePositionInfo info, StubCodeContext context)
319+
{
320+
string nativeSpanIdentifier = MarshallerHelpers.GetNativeSpanIdentifier(info, context);
321+
StatementSyntax contentsCleanupStatements = GenerateContentsMarshallingStatement(info, context,
322+
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
323+
IdentifierName(MarshallerHelpers.GetNativeSpanIdentifier(info, context)),
324+
IdentifierName("Length")),
325+
StubCodeContext.Stage.Cleanup);
326+
327+
if (contentsCleanupStatements.IsKind(SyntaxKind.EmptyStatement))
328+
{
329+
return EmptyStatement();
330+
}
331+
332+
return Block(
333+
LocalDeclarationStatement(VariableDeclaration(
334+
GenericName(
335+
Identifier(TypeNames.System_Span),
336+
TypeArgumentList(SingletonSeparatedList(_unmanagedElementType))),
337+
SingletonSeparatedList(
338+
VariableDeclarator(
339+
Identifier(nativeSpanIdentifier))
340+
.WithInitializer(EqualsValueClause(
341+
GetUnmanagedValuesDestination(info, context)))))),
342+
contentsCleanupStatements);
343+
}
344+
318345
protected StatementSyntax GenerateContentsMarshallingStatement(
319346
TypePositionInfo info,
320347
StubCodeContext context,

src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,5 +292,37 @@ public static IEnumerable<TypePositionInfo> GetDependentElementsOfMarshallingInf
292292
}
293293
}
294294
}
295+
296+
public static StatementSyntax SkipInitOrDefaultInit(TypePositionInfo info, StubCodeContext context)
297+
{
298+
(TargetFramework fmk, _) = context.GetTargetFramework();
299+
if (info.ManagedType is not PointerTypeInfo
300+
&& info.ManagedType is not ValueTypeInfo { IsByRefLike: true }
301+
&& fmk is TargetFramework.Net)
302+
{
303+
// Use the Unsafe.SkipInit<T> API when available and
304+
// managed type is usable as a generic parameter.
305+
return ExpressionStatement(
306+
InvocationExpression(
307+
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
308+
ParseName(TypeNames.System_Runtime_CompilerServices_Unsafe),
309+
IdentifierName("SkipInit")))
310+
.WithArgumentList(
311+
ArgumentList(SingletonSeparatedList(
312+
Argument(IdentifierName(info.InstanceIdentifier))
313+
.WithRefOrOutKeyword(Token(SyntaxKind.OutKeyword))))));
314+
}
315+
else
316+
{
317+
// Assign out params to default
318+
return ExpressionStatement(
319+
AssignmentExpression(
320+
SyntaxKind.SimpleAssignmentExpression,
321+
IdentifierName(info.InstanceIdentifier),
322+
LiteralExpression(
323+
SyntaxKind.DefaultLiteralExpression,
324+
Token(SyntaxKind.DefaultKeyword))));
325+
}
326+
}
295327
}
296328
}

src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatefulMarshallingStrategy.cs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,27 @@ public StatefulLinearCollectionNonBlittableElementsMarshalling(
459459
}
460460

461461
public TypeSyntax AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info);
462-
public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateCleanupStatements(info, context);
462+
public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context)
463+
{
464+
StatementSyntax elementCleanup = GenerateElementCleanupStatement(info, context);
465+
466+
if (!elementCleanup.IsKind(SyntaxKind.EmptyStatement))
467+
{
468+
yield return elementCleanup;
469+
}
470+
471+
if (!_shape.HasFlag(MarshallerShape.Free))
472+
yield break;
473+
474+
string marshaller = StatefulValueMarshalling.GetMarshallerIdentifier(info, context);
475+
// <marshaller>.Free();
476+
yield return ExpressionStatement(
477+
InvocationExpression(
478+
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
479+
IdentifierName(marshaller),
480+
IdentifierName(ShapeMemberNames.Free)),
481+
ArgumentList()));
482+
}
463483
public IEnumerable<StatementSyntax> GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context);
464484

465485
public IEnumerable<StatementSyntax> GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context)

src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,10 @@ public StatelessFreeMarshalling(ICustomTypeMarshallingStrategy innerMarshaller,
251251

252252
public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context)
253253
{
254+
foreach (StatementSyntax statement in _innerMarshaller.GenerateCleanupStatements(info, context))
255+
{
256+
yield return statement;
257+
}
254258
// <marshallerType>.Free(<nativeIdentifier>);
255259
yield return ExpressionStatement(
256260
InvocationExpression(
@@ -372,11 +376,19 @@ public IEnumerable<StatementSyntax> GenerateMarshalStatements(TypePositionInfo i
372376
public IEnumerable<StatementSyntax> GeneratePinStatements(TypePositionInfo info, StubCodeContext context) => Array.Empty<StatementSyntax>();
373377
public IEnumerable<StatementSyntax> GenerateSetupStatements(TypePositionInfo info, StubCodeContext context)
374378
{
379+
string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context);
375380
yield return LocalDeclarationStatement(
376381
VariableDeclaration(
377382
PredefinedType(Token(SyntaxKind.IntKeyword)),
378383
SingletonSeparatedList(
379-
VariableDeclarator(MarshallerHelpers.GetNumElementsIdentifier(info, context)))));
384+
VariableDeclarator(numElementsIdentifier))));
385+
// Use the numElements local to ensure the compiler doesn't give errors for using an uninitialized variable.
386+
// The value will never be used unless it has been initialized, so this is safe.
387+
yield return MarshallerHelpers.SkipInitOrDefaultInit(
388+
new TypePositionInfo(SpecialTypeInfo.Int32, NoMarshallingInfo.Instance)
389+
{
390+
InstanceIdentifier = numElementsIdentifier
391+
}, context);
380392
}
381393

382394
public IEnumerable<StatementSyntax> GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) => Array.Empty<StatementSyntax>();
@@ -512,7 +524,15 @@ public StatelessLinearCollectionNonBlittableElementsMarshalling(
512524

513525
public TypeSyntax AsNativeType(TypePositionInfo info) => _nativeTypeSyntax;
514526

515-
public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) => Array.Empty<StatementSyntax>();
527+
public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context)
528+
{
529+
StatementSyntax elementCleanup = GenerateElementCleanupStatement(info, context);
530+
531+
if (!elementCleanup.IsKind(SyntaxKind.EmptyStatement))
532+
{
533+
yield return elementCleanup;
534+
}
535+
}
516536

517537
public IEnumerable<StatementSyntax> GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context)
518538
{
@@ -588,11 +608,19 @@ public IEnumerable<StatementSyntax> GenerateMarshalStatements(TypePositionInfo i
588608

589609
public IEnumerable<StatementSyntax> GenerateSetupStatements(TypePositionInfo info, StubCodeContext context)
590610
{
611+
string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context);
591612
yield return LocalDeclarationStatement(
592613
VariableDeclaration(
593614
PredefinedType(Token(SyntaxKind.IntKeyword)),
594615
SingletonSeparatedList(
595-
VariableDeclarator(MarshallerHelpers.GetNumElementsIdentifier(info, context)))));
616+
VariableDeclarator(numElementsIdentifier))));
617+
// Use the numElements local to ensure the compiler doesn't give errors for using an uninitialized variable.
618+
// The value will never be used unless it has been initialized, so this is safe.
619+
yield return MarshallerHelpers.SkipInitOrDefaultInit(
620+
new TypePositionInfo(SpecialTypeInfo.Int32, NoMarshallingInfo.Instance)
621+
{
622+
InstanceIdentifier = numElementsIdentifier
623+
}, context);
596624
}
597625

598626
public IEnumerable<StatementSyntax> GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) => Array.Empty<StatementSyntax>();

src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,7 @@ private MarshallingInfo CreateNativeMarshallingInfo(
593593
}
594594

595595
int maxIndirectionDepthUsedLocal = maxIndirectionDepthUsed;
596-
Func<ITypeSymbol, MarshallingInfo> getMarshallingInfoForElement = (ITypeSymbol elementType) => GetMarshallingInfo(elementType, new Dictionary<int, AttributeData>(), 1, ImmutableHashSet<string>.Empty, ref maxIndirectionDepthUsedLocal);
596+
Func<ITypeSymbol, MarshallingInfo> getMarshallingInfoForElement = (ITypeSymbol elementType) => GetMarshallingInfo(elementType, useSiteAttributes, indirectionLevel + 1, inspectedElements, ref maxIndirectionDepthUsedLocal);
597597
if (ManualTypeMarshallingHelper.TryGetLinearCollectionMarshallersFromEntryType(entryPointType, type, _compilation, getMarshallingInfoForElement, out CustomTypeMarshallers? collectionMarshallers))
598598
{
599599
maxIndirectionDepthUsed = maxIndirectionDepthUsedLocal;

src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/VariableDeclarations.cs

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -29,34 +29,7 @@ public static VariableDeclarations GenerateDeclarationsForManagedToNative(BoundG
2929

3030
if (info.RefKind == RefKind.Out)
3131
{
32-
(TargetFramework fmk, _) = context.GetTargetFramework();
33-
if (info.ManagedType is not PointerTypeInfo
34-
&& info.ManagedType is not ValueTypeInfo { IsByRefLike: true }
35-
&& fmk is TargetFramework.Net)
36-
{
37-
// Use the Unsafe.SkipInit<T> API when available and
38-
// managed type is usable as a generic parameter.
39-
initializations.Add(ExpressionStatement(
40-
InvocationExpression(
41-
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
42-
ParseName(TypeNames.System_Runtime_CompilerServices_Unsafe),
43-
IdentifierName("SkipInit")))
44-
.WithArgumentList(
45-
ArgumentList(SingletonSeparatedList(
46-
Argument(IdentifierName(info.InstanceIdentifier))
47-
.WithRefOrOutKeyword(Token(SyntaxKind.OutKeyword)))))));
48-
}
49-
else
50-
{
51-
// Assign out params to default
52-
initializations.Add(ExpressionStatement(
53-
AssignmentExpression(
54-
SyntaxKind.SimpleAssignmentExpression,
55-
IdentifierName(info.InstanceIdentifier),
56-
LiteralExpression(
57-
SyntaxKind.DefaultLiteralExpression,
58-
Token(SyntaxKind.DefaultKeyword)))));
59-
}
32+
initializations.Add(MarshallerHelpers.SkipInitOrDefaultInit(info, context));
6033
}
6134

6235
// Declare variables for parameters

src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/CollectionTests.cs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ public partial class Stateless
2525
[LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_int_array")]
2626
public static partial int SumWithBuffer([MarshalUsing(typeof(ListMarshallerWithBuffer<,>))] List<int> values, int numValues);
2727

28+
[LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_int_ptr_array")]
29+
public static unsafe partial int SumWithFreeTracking([MarshalUsing(typeof(ListMarshaller<,>)), MarshalUsing(typeof(IntWrapperMarshallerWithFreeCounts), ElementIndirectionDepth = 1)] List<IntWrapper> values, int numValues);
30+
2831
[LibraryImport(NativeExportsNE_Binary, EntryPoint = "double_values")]
2932
public static partial int DoubleValues([MarshalUsing(typeof(ListMarshallerWithPinning<,>))] List<BlittableIntWrapper> values, int length);
3033

@@ -99,6 +102,9 @@ public partial class Stateful
99102
[LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_int_array")]
100103
public static partial int Sum([MarshalUsing(typeof(ListMarshallerStateful<,>))] List<int> values, int numValues);
101104

105+
[LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_int_ptr_array")]
106+
public static unsafe partial int SumWithFreeTracking([MarshalUsing(typeof(ListMarshallerStateful<,>)), MarshalUsing(typeof(IntWrapperMarshallerWithFreeCounts), ElementIndirectionDepth = 1)] List<IntWrapper> values, int numValues);
107+
102108
[LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_int_array_ref")]
103109
public static partial int SumInArray([MarshalUsing(typeof(ListMarshallerStateful<,>))] in List<int> values, int numValues);
104110

@@ -369,6 +375,30 @@ public void NonBlittableElementCollection_GuaranteedUnmarshal()
369375
Assert.True(NativeExportsNE.Collections.Stateful.ListGuaranteedUnmarshal<BoolStruct, BoolStructMarshaller.BoolStructNative>.Marshaller.ToManagedFinallyCalled);
370376
}
371377

378+
[Fact]
379+
public void ElementsFreed()
380+
{
381+
List<IntWrapper> list = new List<IntWrapper>
382+
{
383+
new IntWrapper { i = 1 },
384+
new IntWrapper { i = 10 },
385+
new IntWrapper { i = 24 },
386+
new IntWrapper { i = 30 },
387+
};
388+
389+
int startingCount = IntWrapperMarshallerWithFreeCounts.NumCallsToFree;
390+
391+
NativeExportsNE.Collections.Stateless.SumWithFreeTracking(list, list.Count);
392+
393+
Assert.Equal(startingCount + list.Count, IntWrapperMarshallerWithFreeCounts.NumCallsToFree);
394+
395+
startingCount = IntWrapperMarshallerWithFreeCounts.NumCallsToFree;
396+
397+
NativeExportsNE.Collections.Stateful.SumWithFreeTracking(list, list.Count);
398+
399+
Assert.Equal(startingCount + list.Count, IntWrapperMarshallerWithFreeCounts.NumCallsToFree);
400+
}
401+
372402
private static List<BoolStruct> GetBoolStructsToAnd(bool result) => new List<BoolStruct>
373403
{
374404
new BoolStruct

src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/NonBlittable.cs

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,31 @@ public static void Free(int* unmanaged)
196196
}
197197
}
198198

199+
[CustomMarshaller(typeof(IntWrapper), MarshalMode.Default, typeof(IntWrapperMarshallerWithFreeCounts))]
200+
public static unsafe class IntWrapperMarshallerWithFreeCounts
201+
{
202+
[ThreadStatic]
203+
public static int NumCallsToFree = 0;
204+
205+
public static int* ConvertToUnmanaged(IntWrapper managed)
206+
{
207+
int* ret = (int*)Marshal.AllocCoTaskMem(sizeof(int));
208+
*ret = managed.i;
209+
return ret;
210+
}
211+
212+
public static IntWrapper ConvertToManaged(int* unmanaged)
213+
{
214+
return new IntWrapper { i = *unmanaged };
215+
}
216+
217+
public static void Free(int* unmanaged)
218+
{
219+
NumCallsToFree++;
220+
Marshal.FreeCoTaskMem((IntPtr)unmanaged);
221+
}
222+
}
223+
199224
[CustomMarshaller(typeof(IntWrapper), MarshalMode.Default, typeof(Marshaller))]
200225
public static unsafe class IntWrapperMarshallerStateful
201226
{
@@ -477,14 +502,14 @@ public void FromManaged(List<T> managed, Span<TUnmanagedElement> buffer)
477502

478503
_list = managed;
479504
// Always allocate at least one byte when the list is zero-length.
480-
int spaceToAllocate = Math.Max(managed.Count * sizeof(TUnmanagedElement), 1);
481-
if (spaceToAllocate <= buffer.Length)
505+
int countToAllocate = Math.Max(managed.Count, 1);
506+
if (countToAllocate <= buffer.Length)
482507
{
483-
_span = buffer[0..spaceToAllocate];
508+
_span = buffer[0..countToAllocate];
484509
}
485510
else
486511
{
487-
_allocatedMemory = Marshal.AllocCoTaskMem(spaceToAllocate);
512+
_allocatedMemory = Marshal.AllocCoTaskMem(countToAllocate * sizeof(TUnmanagedElement));
488513
_span = new Span<TUnmanagedElement>((void*)_allocatedMemory, managed.Count);
489514
}
490515
}

0 commit comments

Comments
 (0)