Skip to content

Further improve ProbabilisticMap on Avx512 #107798

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 28, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,8 @@ private static Vector512<byte> ContainsMask64CharsAvx512(Vector512<byte> charMap
Vector512<ushort> source0 = Vector512.LoadUnsafe(ref searchSpace0);
Vector512<ushort> source1 = Vector512.LoadUnsafe(ref searchSpace1);

Vector512<byte> sourceLower = Avx512BW.PackUnsignedSaturate(
(source0 & Vector512.Create((ushort)255)).AsInt16(),
(source1 & Vector512.Create((ushort)255)).AsInt16());

Vector512<byte> sourceUpper = Avx512BW.PackUnsignedSaturate(
(source0 >>> 8).AsInt16(),
(source1 >>> 8).AsInt16());
Vector512<byte> sourceLower = Avx512Vbmi.PermuteVar64x8x2(source0.AsByte(), Vector512.CreateSequence<byte>(0, 2), source1.AsByte());
Vector512<byte> sourceUpper = Avx512Vbmi.PermuteVar64x8x2(source0.AsByte(), Vector512.CreateSequence<byte>(1, 2), source1.AsByte());

Vector512<byte> resultLower = IsCharBitNotSetAvx512(charMap, sourceLower);
Vector512<byte> resultUpper = IsCharBitNotSetAvx512(charMap, sourceUpper);
Expand All @@ -128,12 +123,17 @@ private static Vector512<byte> ContainsMask64CharsAvx512(Vector512<byte> charMap
[CompExactlyDependsOn(typeof(Avx512Vbmi))]
private static Vector512<byte> IsCharBitNotSetAvx512(Vector512<byte> charMap, Vector512<byte> values)
{
Vector512<byte> shifted = values >>> VectorizedIndexShift;
// X86 does not have an instruction for right shifting 8-bit values, so it's instead emulated
// by using a 32-bit value shift followed by an AND to mask off the bits that should be zeroed.
// We're using PermuteVar64x8, which only looks at the lower 6 bits, so we can skip the AND.
// Bits 4/5/6 will not affect the result as the bit positions vector is duplicated 8 times.
Vector512<byte> shifted = (values.AsInt32() >>> VectorizedIndexShift).AsByte();

Vector512<byte> bitPositions = Avx512BW.Shuffle(Vector512.Create(0x8040201008040201).AsByte(), shifted);
Vector512<byte> bitPositions = Avx512Vbmi.PermuteVar64x8(Vector512.Create(0x8040201008040201).AsByte(), shifted);

Vector512<byte> index = values & Vector512.Create((byte)VectorizedIndexMask);
Vector512<byte> bitMask = Avx512Vbmi.PermuteVar64x8(charMap, index);
// We want to select bytes from 'charMap' based on the low 5 bits of 'values' (values & VectorizedIndexMask).
// PermuteVar64x8 will look at the low 6 bits, but the 6th bit will not affect the result as the 'charMap' is duplicated.
Vector512<byte> bitMask = Avx512Vbmi.PermuteVar64x8(charMap, values);

return Vector512.Equals(bitMask & bitPositions, Vector512<byte>.Zero);
}
Expand All @@ -145,13 +145,8 @@ private static Vector256<byte> ContainsMask32CharsAvx512(Vector256<byte> charMap
Vector256<ushort> source0 = Vector256.LoadUnsafe(ref searchSpace0);
Vector256<ushort> source1 = Vector256.LoadUnsafe(ref searchSpace1);

Vector256<byte> sourceLower = Avx2.PackUnsignedSaturate(
(source0 & Vector256.Create((ushort)255)).AsInt16(),
(source1 & Vector256.Create((ushort)255)).AsInt16());

Vector256<byte> sourceUpper = Avx2.PackUnsignedSaturate(
(source0 >>> 8).AsInt16(),
(source1 >>> 8).AsInt16());
Vector256<byte> sourceLower = Avx512Vbmi.VL.PermuteVar32x8x2(source0.AsByte(), Vector256.CreateSequence<byte>(0, 2), source1.AsByte());
Vector256<byte> sourceUpper = Avx512Vbmi.VL.PermuteVar32x8x2(source0.AsByte(), Vector256.CreateSequence<byte>(1, 2), source1.AsByte());

Vector256<byte> resultLower = IsCharBitNotSetAvx512(charMap, sourceLower);
Vector256<byte> resultUpper = IsCharBitNotSetAvx512(charMap, sourceUpper);
Expand All @@ -163,12 +158,17 @@ private static Vector256<byte> ContainsMask32CharsAvx512(Vector256<byte> charMap
[CompExactlyDependsOn(typeof(Avx512Vbmi.VL))]
private static Vector256<byte> IsCharBitNotSetAvx512(Vector256<byte> charMap, Vector256<byte> values)
{
Vector256<byte> shifted = values >>> VectorizedIndexShift;
// X86 does not have an instruction for right shifting 8-bit values, so it's instead emulated
// by using a 32-bit value shift followed by an AND to mask off the bits that should be zeroed.
// We're using PermuteVar32x8, which only looks at the lower 5 bits, so we can skip the AND.
// Bits 4/5 will not affect the result as the bit positions vector is duplicated 4 times
Vector256<byte> shifted = (values.AsInt32() >>> VectorizedIndexShift).AsByte();

Vector256<byte> bitPositions = Avx2.Shuffle(Vector256.Create(0x8040201008040201).AsByte(), shifted);
Vector256<byte> bitPositions = Avx512Vbmi.VL.PermuteVar32x8(Vector256.Create(0x8040201008040201).AsByte(), shifted);

Vector256<byte> index = values & Vector256.Create((byte)VectorizedIndexMask);
Vector256<byte> bitMask = Avx512Vbmi.VL.PermuteVar32x8(charMap, index);
// We want to select bytes from 'charMap' based on the low 5 bits of 'values' (values & VectorizedIndexMask).
// PermuteVar32x8 already looks only at the low 5 bits, so we can skip the redundant AND.
Vector256<byte> bitMask = Avx512Vbmi.VL.PermuteVar32x8(charMap, values);

return Vector256.Equals(bitMask & bitPositions, Vector256<byte>.Zero);
}
Expand Down Expand Up @@ -436,7 +436,7 @@ private static int IndexOfAnyVectorizedAvx512<TUseFastContains>(ref char searchS

if (result != Vector512<byte>.Zero)
{
if (TryFindMatchAvx512<TUseFastContains>(ref cur, PackedSpanHelpers.FixUpPackedVector512Result(result).ExtractMostSignificantBits(), ref state, out int index))
if (TryFindMatchAvx512<TUseFastContains>(ref cur, result.ExtractMostSignificantBits(), ref state, out int index))
{
return MatchOffset(ref searchSpace, ref cur) + index;
}
Expand Down Expand Up @@ -466,7 +466,7 @@ private static int IndexOfAnyVectorizedAvx512<TUseFastContains>(ref char searchS

if (result != Vector512<byte>.Zero)
{
if (TryFindMatchOverlappedAvx512<TUseFastContains>(ref searchSpace, searchSpaceLength, PackedSpanHelpers.FixUpPackedVector512Result(result).ExtractMostSignificantBits(), ref state, out int index))
if (TryFindMatchOverlappedAvx512<TUseFastContains>(ref searchSpace, searchSpaceLength, result.ExtractMostSignificantBits(), ref state, out int index))
{
return index;
}
Expand All @@ -483,7 +483,7 @@ private static int IndexOfAnyVectorizedAvx512<TUseFastContains>(ref char searchS

if (result != Vector256<byte>.Zero)
{
if (TryFindMatchOverlappedAvx512<TUseFastContains>(ref searchSpace, searchSpaceLength, PackedSpanHelpers.FixUpPackedVector256Result(result).ExtractMostSignificantBits(), ref state, out int index))
if (TryFindMatchOverlappedAvx512<TUseFastContains>(ref searchSpace, searchSpaceLength, result.ExtractMostSignificantBits(), ref state, out int index))
{
return index;
}
Expand Down Expand Up @@ -614,7 +614,7 @@ private static int LastIndexOfAnyVectorizedAvx512<TUseFastContains>(ref char sea

if (result != Vector512<byte>.Zero)
{
if (TryFindLastMatchAvx512<TUseFastContains>(ref cur, PackedSpanHelpers.FixUpPackedVector512Result(result).ExtractMostSignificantBits(), ref state, out int index))
if (TryFindLastMatchAvx512<TUseFastContains>(ref cur, result.ExtractMostSignificantBits(), ref state, out int index))
{
return MatchOffset(ref searchSpace, ref cur) + index;
}
Expand Down Expand Up @@ -643,7 +643,7 @@ private static int LastIndexOfAnyVectorizedAvx512<TUseFastContains>(ref char sea

if (result != Vector512<byte>.Zero)
{
if (TryFindLastMatchOverlappedAvx512<TUseFastContains>(ref searchSpace, searchSpaceLength, PackedSpanHelpers.FixUpPackedVector512Result(result).ExtractMostSignificantBits(), ref state, out int index))
if (TryFindLastMatchOverlappedAvx512<TUseFastContains>(ref searchSpace, searchSpaceLength, result.ExtractMostSignificantBits(), ref state, out int index))
{
return index;
}
Expand All @@ -661,7 +661,7 @@ private static int LastIndexOfAnyVectorizedAvx512<TUseFastContains>(ref char sea

if (result != Vector256<byte>.Zero)
{
if (TryFindLastMatchOverlappedAvx512<TUseFastContains>(ref searchSpace, searchSpaceLength, PackedSpanHelpers.FixUpPackedVector256Result(result).ExtractMostSignificantBits(), ref state, out int index))
if (TryFindLastMatchOverlappedAvx512<TUseFastContains>(ref searchSpace, searchSpaceLength, result.ExtractMostSignificantBits(), ref state, out int index))
{
return index;
}
Expand Down
Loading