Skip to content

Commit 5eaba8c

Browse files
committed
Completed TriplesSource tests for reporting partitioning
1 parent fc944c9 commit 5eaba8c

File tree

1 file changed

+37
-78
lines changed

1 file changed

+37
-78
lines changed

src/test/scala/uk/co/gresearch/spark/dgraph/connector/sources/TestTriplesSource.scala

Lines changed: 37 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -853,48 +853,12 @@ class TestTriplesSource extends AnyFunSpec
853853
)
854854
}
855855

856-
it("should report single partitioning") {
857-
val target = dgraph.target
858-
val df =
859-
reader
860-
.option(PartitionerOption, SingletonPartitionerOption)
861-
.dgraph.triples(target)
862-
.repartition(1)
863-
df.queryExecution.optimizedPlan
864-
print()
865-
}
866-
867856
def containsShuffleExchangeExec(plan: SparkPlan): Boolean = plan match {
868857
case _: ShuffleExchangeExec => true
869858
case p => p.children.exists(containsShuffleExchangeExec)
870859
}
871860

872-
val predicatePartitioningTests = Seq(
873-
("distinct", (df: DataFrame) => df.select($"predicate").distinct(), Seq(
874-
Row("dgraph.type"), Row("director"), Row("name"), Row("release_date"),
875-
Row("revenue"), Row("running_time"), Row("starring"), Row("title")
876-
)),
877-
("groupBy", (df: DataFrame) => df.groupBy($"predicate").count(), Seq(
878-
Row("dgraph.type", 10), Row("director", 3), Row("name", 6), Row("release_date", 4),
879-
Row("revenue", 4), Row("running_time", 4), Row("starring", 9), Row("title", 3)
880-
)),
881-
("Window.partitionBy", (df: DataFrame) => df.select($"predicate", count(lit(1)) over Window.partitionBy($"predicate")), Seq(
882-
Row("dgraph.type", 10), Row("director", 3), Row("name", 6), Row("release_date", 4),
883-
Row("revenue", 4), Row("running_time", 4), Row("starring", 9), Row("title", 3)
884-
).flatMap(row => row * row.getInt(1))), // all rows occur with cardinality of their count
885-
("Window.partitionBy.orderBy", (df: DataFrame) => df.select($"predicate", row_number() over Window.partitionBy($"predicate").orderBy($"subject")), Seq(
886-
Row("dgraph.type", 10), Row("director", 3), Row("name", 6), Row("release_date", 4),
887-
Row("revenue", 4), Row("running_time", 4), Row("starring", 9), Row("title", 3)
888-
).flatMap(row => row ++ row.getInt(1))), // each row occurs with row_number up to their cardinality
889-
)
890-
891861
def testPartitioning(df: () => DataFrame,
892-
tests: Seq[(String, DataFrame => DataFrame, Seq[Row])],
893-
shuffleExpected: Boolean): Unit = {
894-
testPartitioning2(df, tests.map(test => (test._1, test._2, () => test._3)), shuffleExpected = shuffleExpected)
895-
}
896-
897-
def testPartitioning2(df: () => DataFrame,
898862
tests: Seq[(String, DataFrame => DataFrame, () => Seq[Row])],
899863
shuffleExpected: Boolean): Unit = {
900864
val label = if (shuffleExpected) "shuffle" else "reuse partitioning"
@@ -909,6 +873,19 @@ class TestTriplesSource extends AnyFunSpec
909873
}
910874
}
911875

876+
lazy val expectedPredicateCounts = expectedTypedTriples.toSeq.groupBy(_.predicate)
877+
.mapValues(_.length).toSeq.sortBy(_._1).map(e => Row(e._1, e._2))
878+
val predicatePartitioningTests = Seq(
879+
("distinct", (df: DataFrame) => df.select($"predicate").distinct(), () => expectedPredicateCounts.map(row => Row(row.getString(0)))),
880+
("groupBy", (df: DataFrame) => df.groupBy($"predicate").count(), () => expectedPredicateCounts),
881+
("Window.partitionBy", (df: DataFrame) => df.select($"predicate", count(lit(1)) over Window.partitionBy($"predicate")),
882+
() => expectedPredicateCounts.flatMap(row => row * row.getInt(1)) // all rows occur with cardinality of their count
883+
),
884+
("Window.partitionBy.orderBy", (df: DataFrame) => df.select($"predicate", row_number() over Window.partitionBy($"predicate").orderBy($"subject")),
885+
() => expectedPredicateCounts.flatMap(row => row ++ row.getInt(1)) // each row occurs with row_number up to their cardinality
886+
)
887+
)
888+
912889
describe("without predicate partitioning") {
913890
val withoutPartitioning = () =>
914891
reader
@@ -918,7 +895,6 @@ class TestTriplesSource extends AnyFunSpec
918895
MaxLeaseIdEstimatorIdOption -> dgraph.highestUid.toString
919896
))
920897
.dgraph.triples(dgraph.target)
921-
.where(!$"predicate".contains("@"))
922898

923899
testPartitioning(withoutPartitioning, predicatePartitioningTests, shuffleExpected = true)
924900
}
@@ -929,28 +905,21 @@ class TestTriplesSource extends AnyFunSpec
929905
.option(PartitionerOption, PredicatePartitionerOption)
930906
.option(PredicatePartitionerPredicatesOption, "2")
931907
.dgraph.triples(dgraph.target)
932-
.where(!$"predicate".contains("@"))
933908

934909
testPartitioning(withPartitioning, predicatePartitioningTests, shuffleExpected = false)
935910
}
936911

912+
lazy val expectedSubjectCounts = expectedTypedTriples.toSeq.groupBy(_.subject)
913+
.mapValues(_.length).toSeq.sortBy(_._1).map(e => Row(e._1, e._2))
937914
val subjectPartitioningTests = Seq(
938-
("distinct", (df: DataFrame) => df.select($"subject").distinct(), () => dgraph.allUids.sorted.map(Row(_))),
939-
("groupBy", (df: DataFrame) => df.groupBy($"subject").count(), () => Seq(
940-
Row(dgraph.han, 2), Row(dgraph.irvin, 2), Row(dgraph.leia, 2), Row(dgraph.luke, 2),
941-
Row(dgraph.lucas, 2), Row(dgraph.richard, 2),
942-
Row(dgraph.st1, 4), Row(dgraph.sw1, 9), Row(dgraph.sw2, 9), Row(dgraph.sw3, 9)
943-
).sortBy(_.getLong(0))),
944-
("Window.partitionBy", (df: DataFrame) => df.select($"subject", count(lit(1)) over Window.partitionBy($"subject")), () => Seq(
945-
Row(dgraph.han, 2), Row(dgraph.irvin, 2), Row(dgraph.leia, 2), Row(dgraph.luke, 2),
946-
Row(dgraph.lucas, 2), Row(dgraph.richard, 2),
947-
Row(dgraph.st1, 4), Row(dgraph.sw1, 9), Row(dgraph.sw2, 9), Row(dgraph.sw3, 9)
948-
).sortBy(_.getLong(0)).flatMap(row => row * row.getInt(1))), // all rows occur with cardinality of their count
949-
("Window.partitionBy.orderBy", (df: DataFrame) => df.select($"subject", row_number() over Window.partitionBy($"subject").orderBy($"predicate")), () => Seq(
950-
Row(dgraph.han, 2), Row(dgraph.irvin, 2), Row(dgraph.leia, 2), Row(dgraph.luke, 2),
951-
Row(dgraph.lucas, 2), Row(dgraph.richard, 2),
952-
Row(dgraph.st1, 4), Row(dgraph.sw1, 9), Row(dgraph.sw2, 9), Row(dgraph.sw3, 9)
953-
).sortBy(_.getLong(0)).flatMap(row => row ++ row.getInt(1))), // each row occurs with row_number up to their cardinality
915+
("distinct", (df: DataFrame) => df.select($"subject").distinct(), () => expectedSubjectCounts.map(row => Row(row.getLong(0)))),
916+
("groupBy", (df: DataFrame) => df.groupBy($"subject").count(), () => expectedSubjectCounts),
917+
("Window.partitionBy", (df: DataFrame) => df.select($"subject", count(lit(1)) over Window.partitionBy($"subject")),
918+
() => expectedSubjectCounts.flatMap(row => row * row.getInt(1)) // all rows occur with cardinality of their count
919+
),
920+
("Window.partitionBy.orderBy", (df: DataFrame) => df.select($"subject", row_number() over Window.partitionBy($"subject").orderBy($"predicate")),
921+
() => expectedSubjectCounts.flatMap(row => row ++ row.getInt(1)) // each row occurs with row_number up to their cardinality
922+
)
954923
)
955924

956925
describe("without subject partitioning") {
@@ -959,9 +928,8 @@ class TestTriplesSource extends AnyFunSpec
959928
.option(PartitionerOption, PredicatePartitionerOption)
960929
.option(PredicatePartitionerPredicatesOption, "2")
961930
.dgraph.triples(dgraph.target)
962-
.where(!$"predicate".contains("@"))
963931

964-
testPartitioning2(withoutPartitioning, subjectPartitioningTests, shuffleExpected = true)
932+
testPartitioning(withoutPartitioning, subjectPartitioningTests, shuffleExpected = true)
965933
}
966934

967935
describe("with subject partitioning") {
@@ -973,30 +941,21 @@ class TestTriplesSource extends AnyFunSpec
973941
MaxLeaseIdEstimatorIdOption -> dgraph.highestUid.toString
974942
))
975943
.dgraph.triples(dgraph.target)
976-
.where(!$"predicate".contains("@"))
977944

978-
testPartitioning2(withPartitioning, subjectPartitioningTests, shuffleExpected = false)
945+
testPartitioning(withPartitioning, subjectPartitioningTests, shuffleExpected = false)
979946
}
980947

981-
// Array([3,dgraph.type], [3,release_date], [3,revenue], [3,running_time], [4,dgraph.type], [4,name], [5,dgraph.type], [5,name], [6,dgraph.type], [6,director], [6,release_date], [6,revenue], [6,running_time], [6,starring], [6,title], [7,dgraph.type], [7,name], [8,dgraph.type], [8,director], [8,release_date], [8,revenue], [8,running_time], [8,starring], [8,title], [9,dgraph.type], [9,director], [9,release_date], [9,revenue], [9,running_time], [9,starring], [9,title], [10,dgraph.type], [10,name], [11,dgraph.type], [11,name], [12,dgraph.type], [12,name])
982-
948+
lazy val expectedSubjectAndPredicateCounts = expectedTypedTriples.toSeq.groupBy(t => (t.subject, t.predicate))
949+
.mapValues(_.length).toSeq.sortBy(_._1).map(e => Row(e._1._1, e._1._2, e._2))
983950
val subjectAndPredicatePartitioningTests = Seq(
984-
("distinct", (df: DataFrame) => df.select($"subject", $"predicate").distinct(),
985-
() => TriplesSourceExpecteds(dgraph).getExpectedTypedTriples.map(t => Row(t.subject, t.predicate))
986-
.toSeq.sortBy(row => (row.getLong(0), row.getString(1)))
951+
("distinct", (df: DataFrame) => df.select($"subject", $"predicate").distinct(), () => expectedSubjectAndPredicateCounts.map(row => Row(row.getLong(0), row.getString(1)))),
952+
("groupBy", (df: DataFrame) => df.groupBy($"subject", $"predicate").count(), () => expectedSubjectAndPredicateCounts),
953+
("Window.partitionBy", (df: DataFrame) => df.select($"subject", $"predicate", count(lit(1)) over Window.partitionBy($"subject", $"predicate")),
954+
() => expectedSubjectAndPredicateCounts.flatMap(row => row * row.getInt(2)) // all rows occur with cardinality of their count
987955
),
988-
("groupBy", (df: DataFrame) => df.groupBy($"subject", $"predicate").count(), () => Seq(
989-
Row("dgraph.type", 11), Row("director", 3), Row("name", 6), Row("release_date", 4),
990-
Row("revenue", 4), Row("running_time", 4), Row("starring", 9), Row("title", 3)
991-
)),
992-
("Window.partitionBy", (df: DataFrame) => df.select($"subject", $"predicate", count(lit(1)) over Window.partitionBy($"subject", $"predicate")), () => Seq(
993-
Row("dgraph.type", 11), Row("director", 3), Row("name", 6), Row("release_date", 4),
994-
Row("revenue", 4), Row("running_time", 4), Row("starring", 9), Row("title", 3)
995-
).flatMap(row => row * row.getInt(1))), // all rows occur with cardinality of their count
996-
("Window.partitionBy.orderBy", (df: DataFrame) => df.select($"subject", $"predicate", row_number() over Window.partitionBy($"subject", $"predicate").orderBy($"objectType")), () => Seq(
997-
Row("dgraph.type", 11), Row("director", 3), Row("name", 6), Row("release_date", 4),
998-
Row("revenue", 4), Row("running_time", 4), Row("starring", 9), Row("title", 3)
999-
).flatMap(row => row ++ row.getInt(1))), // each row occurs with row_number up to their cardinality
956+
("Window.partitionBy.orderBy", (df: DataFrame) => df.select($"subject", $"predicate", row_number() over Window.partitionBy($"subject", $"predicate").orderBy($"objectType")),
957+
() => expectedSubjectAndPredicateCounts.flatMap(row => row ++ row.getInt(2)) // each row occurs with row_number up to their cardinality
958+
)
1000959
)
1001960

1002961
describe("without subject and predicate partitioning") {
@@ -1006,7 +965,7 @@ class TestTriplesSource extends AnyFunSpec
1006965
.option(PredicatePartitionerPredicatesOption, "2")
1007966
.dgraph.triples(dgraph.target)
1008967

1009-
testPartitioning2(withoutPartitioning, subjectAndPredicatePartitioningTests, shuffleExpected = true)
968+
testPartitioning(withoutPartitioning, subjectAndPredicatePartitioningTests, shuffleExpected = true)
1010969
}
1011970

1012971
describe("with subject and predicate partitioning") {
@@ -1020,7 +979,7 @@ class TestTriplesSource extends AnyFunSpec
1020979
))
1021980
.dgraph.triples(dgraph.target)
1022981

1023-
testPartitioning2(withPartitioning, subjectAndPredicatePartitioningTests, shuffleExpected = false)
982+
testPartitioning(withPartitioning, subjectAndPredicatePartitioningTests, shuffleExpected = false)
1024983
}
1025984

1026985
}
@@ -1180,7 +1139,7 @@ object TestTriplesSource {
11801139

11811140
implicit class ExtendedRow(row: Row) {
11821141
def *(n: Int): Seq[Row] = Seq.fill(n)(row)
1183-
def ++(n: Int): Seq[Row] = Seq.fill(n)(row).zipWithIndex.map { case (row, idx) => Row(row.get(0), idx+1) }
1142+
def ++(n: Int): Seq[Row] = Seq.fill(n)(row).zipWithIndex.map { case (row, idx) => Row(row.toSeq.init :+ (idx+1): _*) }
11841143
}
11851144

11861145
}

0 commit comments

Comments
 (0)