Skip to content

Commit 379dd36

Browse files
committed
Ensure that Arm64 correctly handles multiplication of simd by a 64-bit scalar
1 parent f402418 commit 379dd36

File tree

3 files changed

+48
-22
lines changed

3 files changed

+48
-22
lines changed

src/coreclr/jit/gentree.cpp

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20830,21 +20830,14 @@ GenTree* Compiler::gtNewSimdBinOpNode(
2083020830
{
2083120831
GenTree** broadcastOp = nullptr;
2083220832

20833-
#if defined(TARGET_ARM64)
20834-
if (varTypeIsLong(simdBaseType))
20835-
{
20836-
break;
20837-
}
20838-
#endif // TARGET_ARM64
20839-
2084020833
if (varTypeIsArithmetic(op1))
2084120834
{
2084220835
broadcastOp = &op1;
2084320836

2084420837
#if defined(TARGET_ARM64)
2084520838
if (!varTypeIsByte(simdBaseType))
2084620839
{
20847-
// MultiplyByScalar requires the scalar op to be op2fGetHWIntrinsicIdForBinOp
20840+
// MultiplyByScalar requires the scalar op to be op2 for GetHWIntrinsicIdForBinOp
2084820841
needsReverseOps = true;
2084920842
}
2085020843
#endif // TARGET_ARM64
@@ -20857,7 +20850,12 @@ GenTree* Compiler::gtNewSimdBinOpNode(
2085720850
if (broadcastOp != nullptr)
2085820851
{
2085920852
#if defined(TARGET_ARM64)
20860-
if (!varTypeIsByte(simdBaseType))
20853+
if (varTypeIsLong(simdBaseType))
20854+
{
20855+
// This is handled via emulation and the scalar is consumed directly
20856+
break;
20857+
}
20858+
else if (!varTypeIsByte(simdBaseType))
2086120859
{
2086220860
op2ForLookup = *broadcastOp;
2086320861
*broadcastOp = gtNewSimdCreateScalarUnsafeNode(TYP_SIMD8, *broadcastOp, simdBaseJitType, 8);
@@ -21261,24 +21259,26 @@ GenTree* Compiler::gtNewSimdBinOpNode(
2126121259
#elif defined(TARGET_ARM64)
2126221260
if (varTypeIsLong(simdBaseType))
2126321261
{
21264-
GenTree** op1ToDup = &op1;
21265-
GenTree** op2ToDup = &op2;
21262+
GenTree** op2ToDup = nullptr;
2126621263

21267-
if (!varTypeIsArithmetic(op1))
21268-
{
21269-
op1 = gtNewSimdToScalarNode(TYP_LONG, op1, simdBaseJitType, simdSize);
21270-
op1ToDup = &op1->AsHWIntrinsic()->Op(1);
21271-
}
21264+
assert(varTypeIsSIMD(op1));
21265+
op1 = gtNewSimdToScalarNode(TYP_LONG, op1, simdBaseJitType, simdSize);
21266+
GenTree** op1ToDup = &op1->AsHWIntrinsic()->Op(1);
2127221267

21273-
if (!varTypeIsArithmetic(op2))
21268+
if (varTypeIsSIMD(op2))
2127421269
{
2127521270
op2 = gtNewSimdToScalarNode(TYP_LONG, op2, simdBaseJitType, simdSize);
2127621271
op2ToDup = &op2->AsHWIntrinsic()->Op(1);
2127721272
}
2127821273

2127921274
// lower = op1.GetElement(0) * op2.GetElement(0)
2128021275
GenTree* lower = gtNewOperNode(GT_MUL, TYP_LONG, op1, op2);
21281-
lower = gtNewSimdCreateScalarUnsafeNode(type, lower, simdBaseJitType, simdSize);
21276+
21277+
if (op2ToDup == nullptr)
21278+
{
21279+
op2ToDup = &lower->AsOp()->gtOp2;
21280+
}
21281+
lower = gtNewSimdCreateScalarUnsafeNode(type, lower, simdBaseJitType, simdSize);
2128221282

2128321283
if (simdSize == 8)
2128421284
{
@@ -21290,10 +21290,8 @@ GenTree* Compiler::gtNewSimdBinOpNode(
2129021290
GenTree* op1Dup = fgMakeMultiUse(op1ToDup);
2129121291
GenTree* op2Dup = fgMakeMultiUse(op2ToDup);
2129221292

21293-
if (!varTypeIsArithmetic(op1Dup))
21294-
{
21295-
op1Dup = gtNewSimdGetElementNode(TYP_LONG, op1Dup, gtNewIconNode(1), simdBaseJitType, simdSize);
21296-
}
21293+
assert(!varTypeIsArithmetic(op1Dup));
21294+
op1Dup = gtNewSimdGetElementNode(TYP_LONG, op1Dup, gtNewIconNode(1), simdBaseJitType, simdSize);
2129721295

2129821296
if (!varTypeIsArithmetic(op2Dup))
2129921297
{
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System;
5+
using System.Runtime.CompilerServices;
6+
using System.Runtime.Intrinsics;
7+
using Xunit;
8+
9+
public class Runtime_106838
10+
{
11+
[MethodImpl(MethodImplOptions.NoInlining)]
12+
private static Vector128<ulong> Problem(Vector128<ulong> vector) => vector * 5UL;
13+
14+
[Fact]
15+
public static void TestEntryPoint()
16+
{
17+
Vector128<ulong> result = Problem(Vector128.Create<ulong>(5));
18+
Assert.Equal(Vector128.Create<ulong>(25), result);
19+
}
20+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
<Project Sdk="Microsoft.NET.Sdk">
2+
<PropertyGroup>
3+
<Optimize>True</Optimize>
4+
</PropertyGroup>
5+
<ItemGroup>
6+
<Compile Include="$(MSBuildProjectName).cs" />
7+
</ItemGroup>
8+
</Project>

0 commit comments

Comments
 (0)