Skip to content

Commit 5a3dccd

Browse files
committed
Fix BigInteger.Rotate*
1 parent 7245cf3 commit 5a3dccd

File tree

2 files changed

+49
-31
lines changed

2 files changed

+49
-31
lines changed

src/libraries/System.Runtime.Numerics/src/System/Numerics/BigInteger.cs

Lines changed: 39 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1701,7 +1701,7 @@ private static BigInteger Add(ReadOnlySpan<uint> leftBits, int leftSign, ReadOnl
17011701
}
17021702

17031703
if (bitsFromPool != null)
1704-
ArrayPool<uint>.Shared.Return(bitsFromPool);
1704+
ArrayPool<uint>.Shared.Return(bitsFromPool);
17051705

17061706
return result;
17071707
}
@@ -2636,7 +2636,7 @@ public static implicit operator BigInteger(nuint value)
26362636

26372637
if (zdFromPool != null)
26382638
ArrayPool<uint>.Shared.Return(zdFromPool);
2639-
exit:
2639+
exit:
26402640
if (xdFromPool != null)
26412641
ArrayPool<uint>.Shared.Return(xdFromPool);
26422642

@@ -3239,7 +3239,16 @@ public static BigInteger PopCount(BigInteger value)
32393239
public static BigInteger RotateLeft(BigInteger value, int rotateAmount)
32403240
{
32413241
value.AssertValid();
3242-
int byteCount = (value._bits is null) ? sizeof(int) : (value._bits.Length * 4);
3242+
3243+
bool negx = value._sign < 0;
3244+
ReadOnlySpan<uint> bits = value._bits ?? stackalloc uint[1] { NumericsHelpers.Abs(value._sign) };
3245+
int xl = bits.Length;
3246+
3247+
if (negx && bits[^1] >= kuMaskHighBit
3248+
&& !(bits.IndexOfAnyExcept(0u) == bits.Length - 1 && bits[^1] == kuMaskHighBit))
3249+
++xl;
3250+
3251+
int byteCount = xl * 4;
32433252

32443253
// Normalize the rotate amount to drop full rotations
32453254
rotateAmount = (int)(rotateAmount % (byteCount * 8L));
@@ -3256,14 +3265,13 @@ public static BigInteger RotateLeft(BigInteger value, int rotateAmount)
32563265
(int digitShift, int smallShift) = Math.DivRem(rotateAmount, kcbitUint);
32573266

32583267
uint[]? xdFromPool = null;
3259-
int xl = value._bits?.Length ?? 1;
3260-
32613268
Span<uint> xd = (xl <= BigIntegerCalculator.StackAllocThreshold)
32623269
? stackalloc uint[BigIntegerCalculator.StackAllocThreshold]
32633270
: xdFromPool = ArrayPool<uint>.Shared.Rent(xl);
32643271
xd = xd.Slice(0, xl);
3272+
xd[^1] = 0;
32653273

3266-
bool negx = value.GetPartsForBitManipulation(xd);
3274+
bits.CopyTo(xd);
32673275

32683276
int zl = xl;
32693277
uint[]? zdFromPool = null;
@@ -3374,7 +3382,17 @@ public static BigInteger RotateLeft(BigInteger value, int rotateAmount)
33743382
public static BigInteger RotateRight(BigInteger value, int rotateAmount)
33753383
{
33763384
value.AssertValid();
3377-
int byteCount = (value._bits is null) ? sizeof(int) : (value._bits.Length * 4);
3385+
3386+
3387+
bool negx = value._sign < 0;
3388+
ReadOnlySpan<uint> bits = value._bits ?? stackalloc uint[1] { NumericsHelpers.Abs(value._sign) };
3389+
int xl = bits.Length;
3390+
3391+
if (negx && bits[^1] >= kuMaskHighBit
3392+
&& !(bits.IndexOfAnyExcept(0u) == bits.Length - 1 && bits[^1] == kuMaskHighBit))
3393+
++xl;
3394+
3395+
int byteCount = xl * 4;
33783396

33793397
// Normalize the rotate amount to drop full rotations
33803398
rotateAmount = (int)(rotateAmount % (byteCount * 8L));
@@ -3391,14 +3409,13 @@ public static BigInteger RotateRight(BigInteger value, int rotateAmount)
33913409
(int digitShift, int smallShift) = Math.DivRem(rotateAmount, kcbitUint);
33923410

33933411
uint[]? xdFromPool = null;
3394-
int xl = value._bits?.Length ?? 1;
3395-
33963412
Span<uint> xd = (xl <= BigIntegerCalculator.StackAllocThreshold)
33973413
? stackalloc uint[BigIntegerCalculator.StackAllocThreshold]
33983414
: xdFromPool = ArrayPool<uint>.Shared.Rent(xl);
33993415
xd = xd.Slice(0, xl);
3416+
xd[^1] = 0;
34003417

3401-
bool negx = value.GetPartsForBitManipulation(xd);
3418+
bits.CopyTo(xd);
34023419

34033420
int zl = xl;
34043421
uint[]? zdFromPool = null;
@@ -3445,19 +3462,12 @@ public static BigInteger RotateRight(BigInteger value, int rotateAmount)
34453462
{
34463463
int carryShift = kcbitUint - smallShift;
34473464

3448-
int dstIndex = 0;
3449-
int srcIndex = digitShift;
3450-
3451-
uint carry = 0;
3465+
int dstIndex = xd.Length - 1;
3466+
int srcIndex = digitShift == 0
3467+
? xd.Length - 1
3468+
: digitShift - 1;
34523469

3453-
if (digitShift == 0)
3454-
{
3455-
carry = xd[^1] << carryShift;
3456-
}
3457-
else
3458-
{
3459-
carry = xd[srcIndex - 1] << carryShift;
3460-
}
3470+
uint carry = xd[digitShift] << carryShift;
34613471

34623472
do
34633473
{
@@ -3466,22 +3476,22 @@ public static BigInteger RotateRight(BigInteger value, int rotateAmount)
34663476
zd[dstIndex] = (part >> smallShift) | carry;
34673477
carry = part << carryShift;
34683478

3469-
dstIndex++;
3470-
srcIndex++;
3479+
dstIndex--;
3480+
srcIndex--;
34713481
}
3472-
while (srcIndex < xd.Length);
3482+
while ((uint)srcIndex < (uint)xd.Length);
34733483

3474-
srcIndex = 0;
3484+
srcIndex = xd.Length - 1;
34753485

3476-
while (dstIndex < zd.Length)
3486+
while ((uint)dstIndex < (uint)zd.Length)
34773487
{
34783488
uint part = xd[srcIndex];
34793489

34803490
zd[dstIndex] = (part >> smallShift) | carry;
34813491
carry = part << carryShift;
34823492

3483-
dstIndex++;
3484-
srcIndex++;
3493+
dstIndex--;
3494+
srcIndex--;
34853495
}
34863496
}
34873497

src/libraries/System.Runtime.Numerics/tests/BigInteger/Rotate.cs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System.Globalization;
5+
using Microsoft.VisualStudio.CodeCoverage;
56
using Xunit;
67

78
namespace System.Numerics.Tests
@@ -173,9 +174,10 @@ public void RunRotateTests()
173174
[Fact]
174175
public void RunSmallTests()
175176
{
177+
VerifyRotateString(Print(GetRandomSmallByteArray(-32)) + Print(GetRandomSmallByteArray(-2147483649)) + opstring);
176178
foreach (int i in new int[] {
177-
0,
178-
1,
179+
0,
180+
1,
179181
16,
180182
31,
181183
32,
@@ -225,7 +227,13 @@ private static void VerifyRotateString(string opstring)
225227
StackCalc sc = new StackCalc(opstring);
226228
while (sc.DoNextOperation())
227229
{
230+
Eq();
228231
Assert.Equal(sc.snCalc.Peek().ToString(), sc.myCalc.Peek().ToString());
232+
void Eq()
233+
{
234+
if (sc.snCalc.Peek() != sc.myCalc.Peek())
235+
throw new Exception(opstring);
236+
}
229237
}
230238
}
231239

0 commit comments

Comments
 (0)