Skip to content

Commit 8a61178

Browse files
committed
address code review feedback: count the arrays themselves
1 parent a5a38fb commit 8a61178

File tree

3 files changed

+41
-34
lines changed

3 files changed

+41
-34
lines changed

src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRecord.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ private protected ArrayRecord(ArrayInfo arrayInfo)
5050

5151
internal long ValuesToRead { get; private protected set; }
5252

53-
private protected ArrayInfo ArrayInfo { get; }
53+
internal ArrayInfo ArrayInfo { get; }
5454

5555
internal bool IsJagged
5656
=> ArrayInfo.ArrayType == BinaryArrayType.Jagged

src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryArrayRecord.cs

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -191,51 +191,42 @@ private static long GetJaggedArrayFlattenedLength(BinaryArrayRecord jaggedArrayR
191191

192192
Debug.Assert(jaggedArrayRecord.IsJagged);
193193

194+
// In theory somebody could create a payload that would represent
195+
// a very nested array with total elements count > long.MaxValue.
196+
// That is why this method is using checked arithmetic.
197+
result = checked(result + jaggedArrayRecord.Length); // count the arrays themselves
198+
194199
foreach (object value in jaggedArrayRecord.Values)
195200
{
196-
object item = value is MemberReferenceRecord referenceRecord
197-
? referenceRecord.GetReferencedRecord()
198-
: value;
199-
200-
if (item is not SerializationRecord record)
201+
if (value is not SerializationRecord record)
201202
{
202-
result++;
203203
continue;
204204
}
205205

206+
if (record.RecordType == SerializationRecordType.MemberReference)
207+
{
208+
record = ((MemberReferenceRecord)record).GetReferencedRecord();
209+
}
210+
206211
switch (record.RecordType)
207212
{
208-
case SerializationRecordType.BinaryArray:
209213
case SerializationRecordType.ArraySinglePrimitive:
210214
case SerializationRecordType.ArraySingleObject:
211215
case SerializationRecordType.ArraySingleString:
216+
case SerializationRecordType.BinaryArray:
212217
ArrayRecord nestedArrayRecord = (ArrayRecord)record;
213218
if (nestedArrayRecord.IsJagged)
214219
{
215220
(jaggedArrayRecords ??= new()).Enqueue((BinaryArrayRecord)nestedArrayRecord);
216221
}
217222
else
218223
{
219-
Debug.Assert(nestedArrayRecord is not BinaryArrayRecord, "Ensure lack of recursive call");
220-
checked
221-
{
222-
// In theory somebody could create a payload that would represent
223-
// a very nested array with total elements count > long.MaxValue.
224-
result += nestedArrayRecord.FlattenedLength;
225-
}
226-
}
227-
break;
228-
case SerializationRecordType.ObjectNull:
229-
case SerializationRecordType.ObjectNullMultiple256:
230-
case SerializationRecordType.ObjectNullMultiple:
231-
// All nulls need to be included, as it's another form of possible attack.
232-
checked
233-
{
234-
result += ((NullsRecord)item).NullCount;
224+
// Don't call nestedArrayRecord.FlattenedLength to avoid any potential recursion,
225+
// just call nestedArrayRecord.ArrayInfo.FlattenedLength that returns pre-computed value.
226+
result = checked(result + nestedArrayRecord.ArrayInfo.FlattenedLength);
235227
}
236228
break;
237229
default:
238-
result++;
239230
break;
240231
}
241232
}

src/libraries/System.Formats.Nrbf/tests/JaggedArraysTests.cs

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,25 @@ namespace System.Formats.Nrbf.Tests;
77

88
public class JaggedArraysTests : ReadTests
99
{
10-
[Fact]
11-
public void CanReadJaggedArraysOfPrimitiveTypes_2D()
10+
[Theory]
11+
[InlineData(true)]
12+
[InlineData(false)]
13+
public void CanReadJaggedArraysOfPrimitiveTypes_2D(bool useReferences)
1214
{
1315
int[][] input = new int[7][];
16+
int[] same = [1, 2, 3];
1417
for (int i = 0; i < input.Length; i++)
1518
{
16-
input[i] = [i, i, i];
19+
input[i] = useReferences
20+
? same // reuse the same object (represented as a single record that is referenced multiple times)
21+
: [i, i, i]; // create new array
1722
}
1823

1924
var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input));
2025

2126
Verify(input, arrayRecord);
2227
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
23-
Assert.Equal(input.Length * 3, arrayRecord.FlattenedLength);
28+
Assert.Equal(input.Length + input.Length * 3, arrayRecord.FlattenedLength);
2429
}
2530

2631
[Theory]
@@ -42,13 +47,17 @@ public void FlattenedLengthIncludesNullArrays(int nullCount)
4247
public void ItIsPossibleToHaveBinaryArrayRecordsHaveAnElementTypeOfArrayWithoutBeingMarkedAsJagged()
4348
{
4449
int[][][] input = new int[3][][];
50+
long totalElementsCount = 0;
4551
for (int i = 0; i < input.Length; i++)
4652
{
4753
input[i] = new int[4][];
54+
totalElementsCount++; // count the arrays themselves
4855

4956
for (int j = 0; j < input[i].Length; j++)
5057
{
5158
input[i][j] = [i, j, 0, 1, 2];
59+
totalElementsCount += input[i][j].Length;
60+
totalElementsCount++; // count the arrays themselves
5261
}
5362
}
5463

@@ -67,25 +76,31 @@ public void ItIsPossibleToHaveBinaryArrayRecordsHaveAnElementTypeOfArrayWithoutB
6776

6877
Verify(input, arrayRecord);
6978
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
70-
Assert.Equal(3 * 4 * 5, arrayRecord.FlattenedLength);
79+
Assert.Equal(3 + 3 * 4 + 3 * 4 * 5, totalElementsCount);
80+
Assert.Equal(totalElementsCount, arrayRecord.FlattenedLength);
7181
}
7282

7383
[Fact]
7484
public void CanReadJaggedArraysOfPrimitiveTypes_3D()
7585
{
7686
int[][][] input = new int[7][][];
87+
long totalElementsCount = 0;
7788
for (int i = 0; i < input.Length; i++)
7889
{
90+
totalElementsCount++; // count the arrays themselves
7991
input[i] = new int[1][];
92+
totalElementsCount++; // count the arrays themselves
8093
input[i][0] = [i, i, i];
94+
totalElementsCount += input[i][0].Length;
8195
}
8296

8397
var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input));
8498

8599
Verify(input, arrayRecord);
86100
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
87101
Assert.Equal(1, arrayRecord.Rank);
88-
Assert.Equal(input.Length * 1 * 3, arrayRecord.FlattenedLength);
102+
Assert.Equal(7 + 7 * 1 + 7 * 1 * 3, totalElementsCount);
103+
Assert.Equal(totalElementsCount, arrayRecord.FlattenedLength);
89104
}
90105

91106
[Fact]
@@ -110,7 +125,7 @@ public void CanReadJaggedArrayOfRectangularArrays()
110125
Verify(input, arrayRecord);
111126
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
112127
Assert.Equal(1, arrayRecord.Rank);
113-
Assert.Equal(input.Length * 3 * 3, arrayRecord.FlattenedLength);
128+
Assert.Equal(input.Length + input.Length * 3 * 3, arrayRecord.FlattenedLength);
114129
}
115130

116131
[Fact]
@@ -126,7 +141,7 @@ public void CanReadJaggedArraysOfStrings()
126141

127142
Verify(input, arrayRecord);
128143
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
129-
Assert.Equal(input.Length * 3, arrayRecord.FlattenedLength);
144+
Assert.Equal(input.Length + input.Length * 3, arrayRecord.FlattenedLength);
130145
}
131146

132147
[Fact]
@@ -142,7 +157,7 @@ public void CanReadJaggedArraysOfObjects()
142157

143158
Verify(input, arrayRecord);
144159
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
145-
Assert.Equal(input.Length * 3, arrayRecord.FlattenedLength);
160+
Assert.Equal(input.Length + input.Length * 3, arrayRecord.FlattenedLength);
146161
}
147162

148163
[Serializable]
@@ -160,6 +175,7 @@ public void CanReadJaggedArraysOfComplexTypes()
160175
{
161176
input[i] = Enumerable.Range(0, i + 1).Select(j => new ComplexType { SomeField = j }).ToArray();
162177
totalElementsCount += input[i].Length;
178+
totalElementsCount++; // count the arrays themselves
163179
}
164180

165181
var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input));

0 commit comments

Comments
 (0)