Skip to content

Commit 8efa8b4

Browse files
committed
Add BufferedMessageEncoder, fix PSQLFrontendMessageEncoder
1 parent ce57b02 commit 8efa8b4

17 files changed

+147
-96
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import NIOCore
2+
3+
struct BufferedMessageEncoder<Encoder: MessageToByteEncoder> {
4+
private enum State {
5+
case flushed
6+
case writable
7+
}
8+
9+
private var buffer: ByteBuffer
10+
private var state: State = .writable
11+
private var encoder: Encoder
12+
13+
init(buffer: ByteBuffer, encoder: Encoder) {
14+
self.buffer = buffer
15+
self.encoder = encoder
16+
}
17+
18+
mutating func encode(_ message: Encoder.OutboundIn) throws {
19+
switch self.state {
20+
case .flushed:
21+
self.state = .writable
22+
self.buffer.clear()
23+
24+
case .writable:
25+
break
26+
}
27+
28+
try self.encoder.encode(data: message, out: &self.buffer)
29+
}
30+
31+
mutating func flush() -> ByteBuffer? {
32+
guard self.buffer.readableBytes > 0 else {
33+
return nil
34+
}
35+
36+
self.state = .flushed
37+
return self.buffer
38+
}
39+
}

Sources/PostgresNIO/New/PSQLConnection.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ final class PSQLConnection {
231231

232232
return channel.pipeline.addHandlers([
233233
decoder,
234-
MessageToByteHandler(PSQLFrontendMessage.Encoder(jsonEncoder: configuration.coders.jsonEncoder)),
234+
MessageToByteHandler(PSQLFrontendMessageEncoder(jsonEncoder: configuration.coders.jsonEncoder)),
235235
PSQLChannelHandler(
236236
authentification: configuration.authentication,
237237
logger: logger,

Sources/PostgresNIO/New/PSQLFrontendMessage.swift

Lines changed: 0 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -129,82 +129,6 @@ extension PSQLFrontendMessage {
129129
}
130130
}
131131

132-
extension PSQLFrontendMessage {
133-
struct Encoder: MessageToByteEncoder {
134-
typealias OutboundIn = PSQLFrontendMessage
135-
136-
let jsonEncoder: PSQLJSONEncoder
137-
138-
init(jsonEncoder: PSQLJSONEncoder) {
139-
self.jsonEncoder = jsonEncoder
140-
}
141-
142-
func encode(data message: PSQLFrontendMessage, out buffer: inout ByteBuffer) throws {
143-
struct EmptyPayload: PSQLMessagePayloadEncodable {
144-
func encode(into buffer: inout ByteBuffer) {}
145-
}
146-
147-
func encode<Payload: PSQLMessagePayloadEncodable>(_ payload: Payload, into buffer: inout ByteBuffer) {
148-
let startIndex = buffer.writerIndex
149-
buffer.writeInteger(Int32(0)) // placeholder for length
150-
payload.encode(into: &buffer)
151-
let length = Int32(buffer.writerIndex - startIndex)
152-
buffer.setInteger(length, at: startIndex)
153-
}
154-
155-
switch message {
156-
case .bind(let bind):
157-
buffer.writeInteger(message.id.rawValue)
158-
let startIndex = buffer.writerIndex
159-
buffer.writeInteger(Int32(0)) // placeholder for length
160-
try bind.encode(into: &buffer, using: self.jsonEncoder)
161-
let length = Int32(buffer.writerIndex - startIndex)
162-
buffer.setInteger(length, at: startIndex)
163-
164-
case .cancel(let cancel):
165-
// cancel requests don't have an identifier
166-
encode(cancel, into: &buffer)
167-
case .close(let close):
168-
buffer.writeFrontendMessageID(message.id)
169-
encode(close, into: &buffer)
170-
case .describe(let describe):
171-
buffer.writeFrontendMessageID(message.id)
172-
encode(describe, into: &buffer)
173-
case .execute(let execute):
174-
buffer.writeFrontendMessageID(message.id)
175-
encode(execute, into: &buffer)
176-
case .flush:
177-
buffer.writeFrontendMessageID(message.id)
178-
encode(EmptyPayload(), into: &buffer)
179-
case .parse(let parse):
180-
buffer.writeFrontendMessageID(message.id)
181-
encode(parse, into: &buffer)
182-
case .password(let password):
183-
buffer.writeFrontendMessageID(message.id)
184-
encode(password, into: &buffer)
185-
case .saslInitialResponse(let saslInitialResponse):
186-
buffer.writeFrontendMessageID(message.id)
187-
encode(saslInitialResponse, into: &buffer)
188-
case .saslResponse(let saslResponse):
189-
buffer.writeFrontendMessageID(message.id)
190-
encode(saslResponse, into: &buffer)
191-
case .sslRequest(let request):
192-
// sslRequests don't have an identifier
193-
encode(request, into: &buffer)
194-
case .startup(let startup):
195-
// startup requests don't have an identifier
196-
encode(startup, into: &buffer)
197-
case .sync:
198-
buffer.writeFrontendMessageID(message.id)
199-
encode(EmptyPayload(), into: &buffer)
200-
case .terminate:
201-
buffer.writeFrontendMessageID(message.id)
202-
encode(EmptyPayload(), into: &buffer)
203-
}
204-
}
205-
}
206-
}
207-
208132
protocol PSQLMessagePayloadEncodable {
209133
func encode(into buffer: inout ByteBuffer)
210134
}
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
2+
struct PSQLFrontendMessageEncoder: MessageToByteEncoder {
3+
typealias OutboundIn = PSQLFrontendMessage
4+
5+
let jsonEncoder: PSQLJSONEncoder
6+
7+
init(jsonEncoder: PSQLJSONEncoder) {
8+
self.jsonEncoder = jsonEncoder
9+
}
10+
11+
func encode(data message: PSQLFrontendMessage, out buffer: inout ByteBuffer) throws {
12+
switch message {
13+
case .bind(let bind):
14+
buffer.writeInteger(message.id.rawValue)
15+
let startIndex = buffer.writerIndex
16+
buffer.writeInteger(Int32(0)) // placeholder for length
17+
try bind.encode(into: &buffer, using: self.jsonEncoder)
18+
let length = Int32(buffer.writerIndex - startIndex)
19+
buffer.setInteger(length, at: startIndex)
20+
21+
case .cancel(let cancel):
22+
// cancel requests don't have an identifier
23+
self.encode(payload: cancel, into: &buffer)
24+
25+
case .close(let close):
26+
self.encode(messageID: message.id, payload: close, into: &buffer)
27+
28+
case .describe(let describe):
29+
self.encode(messageID: message.id, payload: describe, into: &buffer)
30+
31+
case .execute(let execute):
32+
self.encode(messageID: message.id, payload: execute, into: &buffer)
33+
34+
case .flush:
35+
self.encode(messageID: message.id, payload: EmptyPayload(), into: &buffer)
36+
37+
case .parse(let parse):
38+
self.encode(messageID: message.id, payload: parse, into: &buffer)
39+
40+
case .password(let password):
41+
self.encode(messageID: message.id, payload: password, into: &buffer)
42+
43+
case .saslInitialResponse(let saslInitialResponse):
44+
self.encode(messageID: message.id, payload: saslInitialResponse, into: &buffer)
45+
46+
case .saslResponse(let saslResponse):
47+
self.encode(messageID: message.id, payload: saslResponse, into: &buffer)
48+
49+
case .sslRequest(let request):
50+
// sslRequests don't have an identifier
51+
self.encode(payload: request, into: &buffer)
52+
53+
case .startup(let startup):
54+
// startup requests don't have an identifier
55+
self.encode(payload: startup, into: &buffer)
56+
57+
case .sync:
58+
self.encode(messageID: message.id, payload: EmptyPayload(), into: &buffer)
59+
60+
case .terminate:
61+
self.encode(messageID: message.id, payload: EmptyPayload(), into: &buffer)
62+
}
63+
}
64+
65+
private struct EmptyPayload: PSQLMessagePayloadEncodable {
66+
func encode(into buffer: inout ByteBuffer) {}
67+
}
68+
69+
private func encode<Payload: PSQLMessagePayloadEncodable>(
70+
messageID: PSQLFrontendMessage.ID,
71+
payload: Payload,
72+
into buffer: inout ByteBuffer)
73+
{
74+
buffer.writeFrontendMessageID(messageID)
75+
self.encode(payload: payload, into: &buffer)
76+
}
77+
78+
private func encode<Payload: PSQLMessagePayloadEncodable>(
79+
payload: Payload,
80+
into buffer: inout ByteBuffer)
81+
{
82+
let startIndex = buffer.writerIndex
83+
buffer.writeInteger(Int32(0)) // placeholder for length
84+
payload.encode(into: &buffer)
85+
let length = Int32(buffer.writerIndex - startIndex)
86+
buffer.setInteger(length, at: startIndex)
87+
}
88+
}

Tests/PostgresNIOTests/New/Extensions/PSQLCoding+TestUtils.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
@testable import PostgresNIO
22
import Foundation
33

4-
extension PSQLFrontendMessage.Encoder {
4+
extension PSQLFrontendMessageEncoder {
55
static var forTests: Self {
66
Self(jsonEncoder: JSONEncoder())
77
}

Tests/PostgresNIOTests/New/Messages/BindTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import NIOCore
55
class BindTests: XCTestCase {
66

77
func testEncodeBind() {
8-
let encoder = PSQLFrontendMessage.Encoder.forTests
8+
let encoder = PSQLFrontendMessageEncoder.forTests
99
var byteBuffer = ByteBuffer()
1010
let bind = PSQLFrontendMessage.Bind(portalName: "", preparedStatementName: "", parameters: ["Hello", "World"])
1111
let message = PSQLFrontendMessage.bind(bind)

Tests/PostgresNIOTests/New/Messages/CancelTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import NIOCore
55
class CancelTests: XCTestCase {
66

77
func testEncodeCancel() {
8-
let encoder = PSQLFrontendMessage.Encoder.forTests
8+
let encoder = PSQLFrontendMessageEncoder.forTests
99
var byteBuffer = ByteBuffer()
1010
let cancel = PSQLFrontendMessage.Cancel(processID: 1234, secretKey: 4567)
1111
let message = PSQLFrontendMessage.cancel(cancel)

Tests/PostgresNIOTests/New/Messages/CloseTests.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import NIOCore
55
class CloseTests: XCTestCase {
66

77
func testEncodeClosePortal() {
8-
let encoder = PSQLFrontendMessage.Encoder.forTests
8+
let encoder = PSQLFrontendMessageEncoder.forTests
99
var byteBuffer = ByteBuffer()
1010
let message = PSQLFrontendMessage.close(.portal("Hello"))
1111
XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer))
@@ -19,7 +19,7 @@ class CloseTests: XCTestCase {
1919
}
2020

2121
func testEncodeCloseUnnamedStatement() {
22-
let encoder = PSQLFrontendMessage.Encoder.forTests
22+
let encoder = PSQLFrontendMessageEncoder.forTests
2323
var byteBuffer = ByteBuffer()
2424
let message = PSQLFrontendMessage.close(.preparedStatement(""))
2525
XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer))

Tests/PostgresNIOTests/New/Messages/DescribeTests.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import NIOCore
55
class DescribeTests: XCTestCase {
66

77
func testEncodeDescribePortal() {
8-
let encoder = PSQLFrontendMessage.Encoder.forTests
8+
let encoder = PSQLFrontendMessageEncoder.forTests
99
var byteBuffer = ByteBuffer()
1010
let message = PSQLFrontendMessage.describe(.portal("Hello"))
1111
XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer))
@@ -19,7 +19,7 @@ class DescribeTests: XCTestCase {
1919
}
2020

2121
func testEncodeDescribeUnnamedStatement() {
22-
let encoder = PSQLFrontendMessage.Encoder.forTests
22+
let encoder = PSQLFrontendMessageEncoder.forTests
2323
var byteBuffer = ByteBuffer()
2424
let message = PSQLFrontendMessage.describe(.preparedStatement(""))
2525
XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer))

Tests/PostgresNIOTests/New/Messages/ExecuteTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import NIOCore
55
class ExecuteTests: XCTestCase {
66

77
func testEncodeExecute() {
8-
let encoder = PSQLFrontendMessage.Encoder.forTests
8+
let encoder = PSQLFrontendMessageEncoder.forTests
99
var byteBuffer = ByteBuffer()
1010
let message = PSQLFrontendMessage.execute(.init(portalName: "", maxNumberOfRows: 0))
1111
XCTAssertNoThrow(try encoder.encode(data: message, out: &byteBuffer))

Tests/PostgresNIOTests/New/Messages/ParseTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import NIOCore
55
class ParseTests: XCTestCase {
66

77
func testEncode() {
8-
let encoder = PSQLFrontendMessage.Encoder.forTests
8+
let encoder = PSQLFrontendMessageEncoder.forTests
99
var byteBuffer = ByteBuffer()
1010
let parse = PSQLFrontendMessage.Parse(
1111
preparedStatementName: "test",

Tests/PostgresNIOTests/New/Messages/PasswordTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import NIOCore
55
class PasswordTests: XCTestCase {
66

77
func testEncodePassword() {
8-
let encoder = PSQLFrontendMessage.Encoder.forTests
8+
let encoder = PSQLFrontendMessageEncoder.forTests
99
var byteBuffer = ByteBuffer()
1010
// md522d085ed8dc3377968dc1c1a40519a2a = "abc123" with salt 1, 2, 3, 4
1111
let message = PSQLFrontendMessage.password(.init(value: "md522d085ed8dc3377968dc1c1a40519a2a"))

Tests/PostgresNIOTests/New/Messages/SASLInitialResponseTests.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import NIOCore
55
class SASLInitialResponseTests: XCTestCase {
66

77
func testEncodeWithData() {
8-
let encoder = PSQLFrontendMessage.Encoder.forTests
8+
let encoder = PSQLFrontendMessageEncoder.forTests
99
var byteBuffer = ByteBuffer()
1010
let sasl = PSQLFrontendMessage.SASLInitialResponse(
1111
saslMechanism: "hello", initialData: [0, 1, 2, 3, 4, 5, 6, 7])
@@ -30,7 +30,7 @@ class SASLInitialResponseTests: XCTestCase {
3030
}
3131

3232
func testEncodeWithoutData() {
33-
let encoder = PSQLFrontendMessage.Encoder.forTests
33+
let encoder = PSQLFrontendMessageEncoder.forTests
3434
var byteBuffer = ByteBuffer()
3535
let sasl = PSQLFrontendMessage.SASLInitialResponse(
3636
saslMechanism: "hello", initialData: [])

Tests/PostgresNIOTests/New/Messages/SASLResponseTests.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import NIOCore
55
class SASLResponseTests: XCTestCase {
66

77
func testEncodeWithData() {
8-
let encoder = PSQLFrontendMessage.Encoder.forTests
8+
let encoder = PSQLFrontendMessageEncoder.forTests
99
var byteBuffer = ByteBuffer()
1010
let sasl = PSQLFrontendMessage.SASLResponse(data: [0, 1, 2, 3, 4, 5, 6, 7])
1111
let message = PSQLFrontendMessage.saslResponse(sasl)
@@ -21,7 +21,7 @@ class SASLResponseTests: XCTestCase {
2121
}
2222

2323
func testEncodeWithoutData() {
24-
let encoder = PSQLFrontendMessage.Encoder.forTests
24+
let encoder = PSQLFrontendMessageEncoder.forTests
2525
var byteBuffer = ByteBuffer()
2626
let sasl = PSQLFrontendMessage.SASLResponse(data: [])
2727
let message = PSQLFrontendMessage.saslResponse(sasl)

Tests/PostgresNIOTests/New/Messages/SSLRequestTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import NIOCore
55
class SSLRequestTests: XCTestCase {
66

77
func testSSLRequest() {
8-
let encoder = PSQLFrontendMessage.Encoder.forTests
8+
let encoder = PSQLFrontendMessageEncoder.forTests
99
var byteBuffer = ByteBuffer()
1010
let request = PSQLFrontendMessage.SSLRequest()
1111
let message = PSQLFrontendMessage.sslRequest(request)

Tests/PostgresNIOTests/New/Messages/StartupTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import NIOCore
55
class StartupTests: XCTestCase {
66

77
func testStartupMessage() {
8-
let encoder = PSQLFrontendMessage.Encoder.forTests
8+
let encoder = PSQLFrontendMessageEncoder.forTests
99
var byteBuffer = ByteBuffer()
1010

1111
let replicationValues: [PSQLFrontendMessage.Startup.Parameters.Replication] = [

Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class PSQLFrontendMessageTests: XCTestCase {
2323
// MARK: Encoder
2424

2525
func testEncodeFlush() {
26-
let encoder = PSQLFrontendMessage.Encoder.forTests
26+
let encoder = PSQLFrontendMessageEncoder.forTests
2727
var byteBuffer = ByteBuffer()
2828
XCTAssertNoThrow(try encoder.encode(data: .flush, out: &byteBuffer))
2929

@@ -33,7 +33,7 @@ class PSQLFrontendMessageTests: XCTestCase {
3333
}
3434

3535
func testEncodeSync() {
36-
let encoder = PSQLFrontendMessage.Encoder.forTests
36+
let encoder = PSQLFrontendMessageEncoder.forTests
3737
var byteBuffer = ByteBuffer()
3838
XCTAssertNoThrow(try encoder.encode(data: .sync, out: &byteBuffer))
3939

@@ -43,7 +43,7 @@ class PSQLFrontendMessageTests: XCTestCase {
4343
}
4444

4545
func testEncodeTerminate() {
46-
let encoder = PSQLFrontendMessage.Encoder.forTests
46+
let encoder = PSQLFrontendMessageEncoder.forTests
4747
var byteBuffer = ByteBuffer()
4848
XCTAssertNoThrow(try encoder.encode(data: .terminate, out: &byteBuffer))
4949

0 commit comments

Comments
 (0)