
Commit 9967e97 (parent: d7efaad)

Sketch writing with V2 API
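This commit wires up the DataSource V2 batch write path end to end, if only as a stub: TripleTable now mixes in SupportsWrite and returns a TripleWriteBuilder, which builds a TripleBatchWrite, which hands a serializable TripleDataWriterFactory to the executors, where each task's TripleDataWriter merely counts the rows it is given. For orientation, here is a rough sketch of the call sequence Spark drives against these classes, inferred from the V2 write contract rather than taken from this commit; the ??? stubs stand in for values Spark supplies at runtime:

object WritePathSketch {
  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.connector.write.{LogicalWriteInfo, PhysicalWriteInfo}

  val table: TripleTable = ???            // obtained via NodeSource.getTable
  val info: LogicalWriteInfo = ???        // carries the written DataFrame's schema
  val physical: PhysicalWriteInfo = ???   // carries the number of input partitions
  val rows: Iterator[InternalRow] = ???   // one task's partition of the data

  val write   = table.newWriteBuilder(info)               // driver
  val batch   = write.buildForBatch()                     // driver
  val factory = batch.createBatchWriterFactory(physical)  // driver; serialized to executors
  val writer  = factory.createWriter(0, 0L)               // executor: (partitionId, taskId), one writer per task
  rows.foreach(writer.write)                              // executor, row by row
  val message = writer.commit()                           // executor, on task success
  batch.commit(Array(message))                            // driver, after all tasks committed
}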

8 files changed, +137 -4 lines


src/main/scala/uk/co/gresearch/spark/dgraph/connector/TableBase.scala

Lines changed: 1 addition & 1 deletion

@@ -22,7 +22,7 @@ import java.util
 import java.util.UUID
 import scala.jdk.CollectionConverters._
 
-trait TableBase extends Table with SupportsRead{
+trait TableBase extends Table with SupportsRead {
 
   val cid: UUID
 
src/main/scala/uk/co/gresearch/spark/dgraph/connector/TripleBatchWrite.scala (new file)

Lines changed: 16 additions & 0 deletions

package uk.co.gresearch.spark.dgraph.connector

import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, PhysicalWriteInfo, WriterCommitMessage}
import org.apache.spark.sql.types.StructType
import uk.co.gresearch.spark.dgraph.connector.model.GraphTableModel

case class TripleBatchWrite(schema: StructType, model: GraphTableModel) extends BatchWrite {
  override def createBatchWriterFactory(physicalWriteInfo: PhysicalWriteInfo): DataWriterFactory =
    TripleDataWriterFactory(schema, model)

  override def commit(writerCommitMessages: Array[WriterCommitMessage]): Unit = {
    writerCommitMessages.foreach(msg => Console.println(s"Committed $msg"))
  }

  override def abort(writerCommitMessages: Array[WriterCommitMessage]): Unit = { }
}
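Spark collects one WriterCommitMessage per successful task and passes them all to commit on the driver; this sketch just prints them. abort is the job-level rollback hook invoked when any task fails, and it is a deliberate no-op here since nothing is written yet.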
src/main/scala/uk/co/gresearch/spark/dgraph/connector/TripleDataWriter.scala (new file)

Lines changed: 28 additions & 0 deletions

package uk.co.gresearch.spark.dgraph.connector

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
import org.apache.spark.sql.types.StructType
import uk.co.gresearch.spark.dgraph.connector.model.GraphTableModel

case class TripleDataWriter(schema: StructType, model: GraphTableModel) extends DataWriter[InternalRow] {
  var triples = 0L

  override def write(row: InternalRow): Unit = {
    // Console.println(s"Writing row: $row")
    triples = triples + 1
  }

  override def commit(): WriterCommitMessage = {
    val msg: WriterCommitMessage = new WriterCommitMessage {
      val name: String = s"$triples triples (${Thread.currentThread().getName})"
      override def toString: String = name
    }
    Console.println(s"Committing $msg")
    msg
  }

  override def abort(): Unit = { }

  override def close(): Unit = { }
}
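The writer above is deliberately a stub: it counts triples and reports the count through its commit message, but nothing reaches Dgraph. As a purely hypothetical sketch of where this could go (not part of this commit), a writer might buffer N-Quads and send one mutation per task, assuming the dgraph4j client API and a made-up rowToNQuads mapping for the (subject, str, dbl) rows produced by the test below:

import com.google.protobuf.ByteString
import io.dgraph.{DgraphClient, DgraphProto}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
import scala.collection.mutable.ArrayBuffer

case class WrittenTriples(count: Long) extends WriterCommitMessage

class DgraphMutatingWriter(client: DgraphClient) extends DataWriter[InternalRow] {
  private val nquads = ArrayBuffer.empty[String]

  // hypothetical mapping from a row to RDF N-Quads; a real writer would derive this from the schema
  private def rowToNQuads(row: InternalRow): String = {
    val uid = s"_:n${row.getLong(0)}"
    s"""$uid <str> "${row.getUTF8String(1)}" .""" + "\n" +
      s"""$uid <dbl> "${row.getDouble(2)}"^^<xs:double> ."""
  }

  override def write(row: InternalRow): Unit = nquads += rowToNQuads(row)

  // one transaction per task: commit() only runs after all of the task's rows were written
  override def commit(): WriterCommitMessage = {
    val txn = client.newTransaction()
    try {
      val mutation = DgraphProto.Mutation.newBuilder()
        .setSetNquads(ByteString.copyFromUtf8(nquads.mkString("\n")))
        .build()
      txn.mutate(mutation)
      txn.commit()
    } finally txn.discard() // no-op after a successful commit
    WrittenTriples(nquads.size)
  }

  override def abort(): Unit = ()  // no transaction left open, nothing to roll back
  override def close(): Unit = ()
}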
src/main/scala/uk/co/gresearch/spark/dgraph/connector/TripleDataWriterFactory.scala (new file)

Lines changed: 10 additions & 0 deletions

package uk.co.gresearch.spark.dgraph.connector

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory}
import org.apache.spark.sql.types.StructType
import uk.co.gresearch.spark.dgraph.connector.model.GraphTableModel

case class TripleDataWriterFactory(schema: StructType, model: GraphTableModel) extends DataWriterFactory {
  override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = TripleDataWriter(schema, model)
}

src/main/scala/uk/co/gresearch/spark/dgraph/connector/TripleTable.scala

Lines changed: 15 additions & 3 deletions

@@ -16,19 +16,31 @@
 
 package uk.co.gresearch.spark.dgraph.connector
 
-import java.util.UUID
-
+import org.apache.spark.sql.connector.catalog.{SupportsWrite, TableCapability}
 import org.apache.spark.sql.connector.read.ScanBuilder
+import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import uk.co.gresearch.spark.dgraph.connector.model.GraphTableModel
 import uk.co.gresearch.spark.dgraph.connector.partitioner.Partitioner
 
-case class TripleTable(partitioner: Partitioner, model: GraphTableModel, val cid: UUID) extends TableBase {
+import java.util
+import java.util.UUID
+import scala.jdk.CollectionConverters._
+
+case class TripleTable(partitioner: Partitioner, model: GraphTableModel, cid: UUID)
+  extends TableBase with SupportsWrite {
 
   override def schema(): StructType = model.schema()
 
   override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
     TripleScanBuilder(partitioner, model)
 
+  override def newWriteBuilder(logicalWriteInfo: LogicalWriteInfo): WriteBuilder =
+    TripleWriteBuilder(logicalWriteInfo.schema(), model)
+
+  override def capabilities(): util.Set[TableCapability] = Set(
+    TableCapability.BATCH_READ, TableCapability.BATCH_WRITE, TableCapability.ACCEPT_ANY_SCHEMA
+  ).asJava
+
 }
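A note on the capabilities: BATCH_READ and BATCH_WRITE advertise the two code paths the table now supports. ACCEPT_ANY_SCHEMA is what lets an append like the test below through: without it, Spark validates the incoming DataFrame's schema against this table's schema(), which for this connector describes triples rather than whatever node shape the caller writes.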
src/main/scala/uk/co/gresearch/spark/dgraph/connector/TripleWriteBuilder.scala (new file)

Lines changed: 10 additions & 0 deletions

package uk.co.gresearch.spark.dgraph.connector

import org.apache.spark.sql.connector.write.{BatchWrite, WriteBuilder}
import org.apache.spark.sql.types.StructType
import uk.co.gresearch.spark.dgraph.connector.model.GraphTableModel

case class TripleWriteBuilder(schema: StructType, model: GraphTableModel)
  extends WriteBuilder {
  override def buildForBatch(): BatchWrite = TripleBatchWrite(schema, model)
}

src/main/scala/uk/co/gresearch/spark/dgraph/connector/sources/NodeSource.scala

Lines changed: 9 additions & 0 deletions

@@ -16,8 +16,10 @@
 
 package uk.co.gresearch.spark.dgraph.connector.sources
 
+import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
 import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import uk.co.gresearch.spark.dgraph.connector._
@@ -99,4 +101,11 @@ class NodeSource() extends TableProviderBase
     TripleTable(partitioner, model, clusterState.cid)
   }
 
+  def createRelation(context: SQLContext, mode: SaveMode, parameters: Map[String, String], data: DataFrame): BaseRelation = {
+    new BaseRelation {
+      override def sqlContext: SQLContext = context
+      override def schema: StructType = data.schema
+    }
+  }
+
 }
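The new createRelation method matches the CreatableRelationProvider interface (hence the new import; the widened extends clause sits outside this hunk). Presumably it exists so that DataFrameWriter.save accepts this source, with the actual writing happening on the V2 path: the returned BaseRelation does no work beyond echoing the written data's schema.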
src/test/scala/uk/co/gresearch/spark/dgraph/connector/TestWriter.scala (new file)

Lines changed: 48 additions & 0 deletions

/*
 * Copyright 2020 G-Research
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package uk.co.gresearch.spark.dgraph.connector

import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.types.{DoubleType, StringType}
import org.scalatest.funspec.AnyFunSpec
import uk.co.gresearch.spark.dgraph.DgraphTestCluster

class TestWriter extends AnyFunSpec with ConnectorSparkTestSession with DgraphTestCluster {

  import spark.implicits._

  // we want a fresh cluster that we can mutate, definitely not one that is always running and used by all tests
  override val clusterAlwaysStartUp: Boolean = true

  describe("Connector") {
    it("should write") {
      spark.range(0, 1000000, 1, 10)
        .select(
          $"id".as("subject"),
          $"id".cast(StringType).as("str"),
          $"id".cast(DoubleType).as("dbl")
        )
        .repartition($"subject")
        .sortWithinPartitions($"subject")
        .write
        .mode(SaveMode.Append)
        .option("dgraph.nodes.mode", "wide")
        .format("uk.co.gresearch.spark.dgraph.nodes")
        .save(dgraph.target)
    }
  }
}
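The test writes a million synthetic nodes across ten partitions, repartitioned and sorted by subject so that each task sees a subject's triples contiguously. With the stub writer in place, a run only prints per-task commit messages of the form "<count> triples (<thread name>)"; no data reaches the cluster yet. The fresh-cluster requirement (clusterAlwaysStartUp) still makes sense, since a real write will mutate the cluster.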
