Skip to content

Commit 62d4063

Browse files
committed
[SYSTEMDS-3874] Java17 Vectorized LibMM
This commit adds vectorized kernels for matrix multiplication.
1 parent 496a22f commit 62d4063

File tree

7 files changed

+206
-172
lines changed

7 files changed

+206
-172
lines changed

bin/systemds

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,7 @@ if [ $WORKER == 1 ]; then
413413
print_out "# starting Federated worker on port $PORT"
414414
CMD=" \
415415
java $SYSTEMDS_STANDALONE_OPTS \
416+
--add-modules=jdk.incubator.vector \
416417
$LOG4JPROPFULL \
417418
-jar $SYSTEMDS_JAR_FILE \
418419
-w $PORT \
@@ -422,6 +423,7 @@ elif [ "$FEDMONITORING" == 1 ]; then
422423
print_out "# starting Federated backend monitoring on port $PORT"
423424
CMD=" \
424425
java $SYSTEMDS_STANDALONE_OPTS \
426+
--add-modules=jdk.incubator.vector \
425427
$LOG4JPROPFULL \
426428
-jar $SYSTEMDS_JAR_FILE \
427429
-fedMonitoring $PORT \
@@ -433,6 +435,7 @@ elif [ $SYSDS_DISTRIBUTED == 0 ]; then
433435
CMD=" \
434436
java $SYSTEMDS_STANDALONE_OPTS \
435437
$LOG4JPROPFULL \
438+
--add-modules=jdk.incubator.vector \
436439
-jar $SYSTEMDS_JAR_FILE \
437440
-f $SCRIPT_FILE \
438441
-exec $SYSDS_EXEC_MODE \
@@ -442,6 +445,7 @@ else
442445
print_out "# Running script $SCRIPT_FILE distributed with opts: $*"
443446
CMD=" \
444447
spark-submit $SYSTEMDS_DISTRIBUTED_OPTS \
448+
--add-modules=jdk.incubator.vector \
445449
$SYSTEMDS_JAR_FILE \
446450
-f $SCRIPT_FILE \
447451
-exec $SYSDS_EXEC_MODE \

pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
--add-opens=java.base/java.lang.ref=ALL-UNNAMED
9393
--add-opens=java.base/java.util.concurrent=ALL-UNNAMED
9494
--add-opens=java.base/sun.nio.ch=ALL-UNNAMED
95+
--add-modules=jdk.incubator.vector
9596
</jvm.addopens>
9697
</properties>
9798

@@ -357,6 +358,9 @@
357358
<source>${java.level}</source>
358359
<target>${java.level}</target>
359360
<release>${java.level}</release>
361+
<compilerArgs>
362+
<arg>--add-modules=jdk.incubator.vector</arg>
363+
</compilerArgs>
360364
</configuration>
361365
</plugin>
362366

@@ -892,6 +896,7 @@
892896
<notimestamp>true</notimestamp>
893897
<failOnWarnings>false</failOnWarnings>
894898
<quiet>true</quiet>
899+
<additionalJOption>--add-modules=jdk.incubator.vector</additionalJOption>
895900
<skip>${doc.skip}</skip>
896901
<show>public</show>
897902
<source>${java.level}</source>

src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626
import java.util.List;
2727
import java.util.concurrent.ExecutorService;
2828

29-
// import jdk.incubator.vector.DoubleVector;
30-
// import jdk.incubator.vector.VectorSpecies;
29+
import jdk.incubator.vector.DoubleVector;
30+
import jdk.incubator.vector.VectorSpecies;
3131
import org.apache.commons.lang3.NotImplementedException;
3232
import org.apache.sysds.runtime.DMLRuntimeException;
3333
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
@@ -75,7 +75,7 @@ public class ColGroupDDC extends APreAgg implements IMapToDataGroup {
7575

7676
protected final AMapToData _data;
7777

78-
// static final VectorSpecies<Double> SPECIES = DoubleVector.SPECIES_PREFERRED;
78+
static final VectorSpecies<Double> SPECIES = DoubleVector.SPECIES_PREFERRED;
7979

8080
private ColGroupDDC(IColIndex colIndexes, IDictionary dict, AMapToData data, int[] cachedCounts) {
8181
super(colIndexes, dict, cachedCounts);
@@ -625,16 +625,16 @@ private void identityRightDecompressingMult(MatrixBlock right, MatrixBlock ret,
625625
final double[] b = right.getDenseBlockValues();
626626
final double[] c = ret.getDenseBlockValues();
627627
final int jd = right.getNumColumns();
628-
final int vLen = 8;
628+
final DoubleVector vVec = DoubleVector.zero(SPECIES);
629+
final int vLen = SPECIES.length();
629630
final int lenJ = cru - crl;
630631
final int end = cru - (lenJ % vLen);
631632
for(int i = rl; i < ru; i++) {
632633
int k = _data.getIndex(i);
633634
final int offOut = i * jd + crl;
634635
final double aa = 1;
635636
final int k_right = _colIndexes.get(k);
636-
vectMM(aa, b, c, end, jd, crl, cru, offOut, k_right, vLen);
637-
637+
vectMM(aa, b, c, end, jd, crl, cru, offOut, k_right, vLen, vVec);
638638
}
639639
}
640640

@@ -644,8 +644,8 @@ private void defaultRightDecompressingMult(MatrixBlock right, MatrixBlock ret, i
644644
final double[] c = ret.getDenseBlockValues();
645645
final int kd = _colIndexes.size();
646646
final int jd = right.getNumColumns();
647-
// final DoubleVector vVec = DoubleVector.zero(SPECIES);
648-
final int vLen = 8;
647+
final DoubleVector vVec = DoubleVector.zero(SPECIES);
648+
final int vLen = SPECIES.length();
649649

650650
final int blkzI = 32;
651651
final int blkzK = 24;
@@ -661,32 +661,22 @@ private void defaultRightDecompressingMult(MatrixBlock right, MatrixBlock ret, i
661661
for(int k = bk; k < bke; k++) {
662662
final double aa = a[offi + k];
663663
final int k_right = _colIndexes.get(k);
664-
vectMM(aa, b, c, end, jd, crl, cru, offOut, k_right, vLen);
664+
vectMM(aa, b, c, end, jd, crl, cru, offOut, k_right, vLen, vVec);
665665
}
666666
}
667667
}
668668
}
669669
}
670670

671-
final void vectMM(double aa, double[] b, double[] c, int endT, int jd, int crl, int cru, int offOut, int k,
672-
int vLen) {
673-
// vVec = vVec.broadcast(aa);
671+
final void vectMM(double aa, double[] b, double[] c, int endT, int jd, int crl, int cru, int offOut, int k, int vLen, DoubleVector vVec) {
672+
vVec = vVec.broadcast(aa);
674673
final int offj = k * jd;
675674
final int end = endT + offj;
676675
for(int j = offj + crl; j < end; j += vLen, offOut += vLen) {
677-
// DoubleVector res = DoubleVector.fromArray(SPECIES, c, offOut);
678-
// DoubleVector bVec = DoubleVector.fromArray(SPECIES, b, j);
679-
// res = vVec.fma(bVec, res);
680-
// res.intoArray(c, offOut);
681-
682-
c[offOut] += aa * b[j];
683-
c[offOut + 1] += aa * b[j + 1];
684-
c[offOut + 2] += aa * b[j + 2];
685-
c[offOut + 3] += aa * b[j + 3];
686-
c[offOut + 4] += aa * b[j + 4];
687-
c[offOut + 5] += aa * b[j + 5];
688-
c[offOut + 6] += aa * b[j + 6];
689-
c[offOut + 7] += aa * b[j + 7];
676+
DoubleVector res = DoubleVector.fromArray(SPECIES, c, offOut);
677+
DoubleVector bVec = DoubleVector.fromArray(SPECIES, b, j);
678+
res = vVec.fma(bVec, res);
679+
res.intoArray(c, offOut);
690680
}
691681
for(int j = end; j < cru + offj; j++, offOut++) {
692682
double bb = b[j];

src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
import java.util.Arrays;
2828
import java.util.Set;
2929

30+
import jdk.incubator.vector.DoubleVector;
31+
import jdk.incubator.vector.VectorSpecies;
3032
import org.apache.commons.lang3.NotImplementedException;
3133
import org.apache.sysds.runtime.compress.DMLCompressionException;
3234
import org.apache.sysds.runtime.compress.colgroup.indexes.ArrayIndex;
@@ -65,6 +67,8 @@ public class MatrixBlockDictionary extends ADictionary {
6567

6668
final private MatrixBlock _data;
6769

70+
static final VectorSpecies<Double> SPECIES = DoubleVector.SPECIES_PREFERRED;
71+
6872
/**
6973
* Unsafe private constructor that does not check the data validity. USE WITH CAUTION.
7074
*
@@ -2088,7 +2092,71 @@ private void preaggValuesFromDenseDictDenseAggArray(final int numVals, final ICo
20882092

20892093
private void preaggValuesFromDenseDictDenseAggRange(final int numVals, final IColIndex colIndexes, final int s,
20902094
final int e, final double[] b, final int cut, final double[] ret) {
2091-
preaggValuesFromDenseDictDenseAggRangeGeneric(numVals, colIndexes, s, e, b, cut, ret);
2095+
if(colIndexes instanceof RangeIndex) {
2096+
RangeIndex ri = (RangeIndex) colIndexes;
2097+
preaggValuesFromDenseDictDenseAggRangeRange(numVals, ri.get(0), ri.get(0) + ri.size(), s, e, b, cut, ret);
2098+
}
2099+
else
2100+
preaggValuesFromDenseDictDenseAggRangeGeneric(numVals, colIndexes, s, e, b, cut, ret);
2101+
}
2102+
2103+
private void preaggValuesFromDenseDictDenseAggRangeRange(final int numVals, final int ls, final int le, final int rs,
2104+
final int re, final double[] b, final int cut, final double[] ret) {
2105+
final int cz = le - ls;
2106+
final int az = re - rs;
2107+
// final int nCells = numVals * cz;
2108+
final double[] values = _data.getDenseBlockValues();
2109+
// Correctly named ikj matrix multiplication .
2110+
2111+
final int blkzI = 32;
2112+
final int blkzK = 24;
2113+
final int blkzJ = 1024;
2114+
for(int bi = 0; bi < numVals; bi += blkzI) {
2115+
final int bie = Math.min(numVals, bi + blkzI);
2116+
for(int bk = 0; bk < cz; bk += blkzK) {
2117+
final int bke = Math.min(cz, bk + blkzK);
2118+
for(int bj = 0; bj < az; bj += blkzJ) {
2119+
final int bje = Math.min(az, bj + blkzJ);
2120+
final int sOffT = rs + bj;
2121+
final int eOffT = rs + bje;
2122+
preaggValuesFromDenseDictBlockedIKJ(values, b, ret, bi, bk, bj, bie, bke, cz, az, ls, cut, sOffT, eOffT);
2123+
}
2124+
}
2125+
}
2126+
}
2127+
2128+
private static void preaggValuesFromDenseDictBlockedIKJ(double[] a, double[] b, double[] ret, int bi, int bk, int bj,
2129+
int bie, int bke, int cz, int az, int ls, int cut, int sOffT, int eOffT) {
2130+
final int vLen = SPECIES.length();
2131+
final DoubleVector vVec = DoubleVector.zero(SPECIES);
2132+
final int leftover = sOffT - eOffT % vLen; // leftover not vectorized
2133+
for(int i = bi; i < bie; i++) {
2134+
final int offI = i * cz;
2135+
final int offOutT = i * az + bj;
2136+
for(int k = bk; k < bke; k++) {
2137+
final int idb = (k + ls) * cut;
2138+
final int sOff = sOffT + idb;
2139+
final int eOff = eOffT + idb;
2140+
final double v = a[offI + k];
2141+
vecInnerLoop(v, b, ret, offOutT, eOff, sOff, leftover, vLen, vVec);
2142+
}
2143+
}
2144+
}
2145+
2146+
private static void vecInnerLoop(final double v, final double[] b, final double[] ret, final int offOutT,
2147+
final int eOff, final int sOff, final int leftover, final int vLen, DoubleVector vVec) {
2148+
int offOut = offOutT;
2149+
vVec = vVec.broadcast(v);
2150+
final int end = eOff - leftover;
2151+
for(int j = sOff; j < end; j += vLen, offOut += vLen) {
2152+
DoubleVector res = DoubleVector.fromArray(SPECIES, ret, offOut);
2153+
DoubleVector bVec = DoubleVector.fromArray(SPECIES, b, j);
2154+
vVec.fma(bVec, res).intoArray(ret, offOut);
2155+
}
2156+
for(int j = end; j < eOff; j++, offOut++) {
2157+
ret[offOut] += v * b[j];
2158+
}
2159+
20922160
}
20932161

20942162
private void preaggValuesFromDenseDictDenseAggRangeGeneric(final int numVals, final IColIndex colIndexes,

0 commit comments

Comments
 (0)