Skip to content

Commit 5474465

Browse files
committed
Add PSQLFrontendMessageDecoder
1 parent 8c32013 commit 5474465

File tree

3 files changed

+211
-3
lines changed

3 files changed

+211
-3
lines changed

Sources/PostgresNIO/New/PSQLBackendMessageDecoder.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ struct PSQLBackendMessageDecoder: NIOSingleStepByteToMessageDecoder {
6969

7070
return try PSQLBackendMessage.decode(from: &slice, for: messageID)
7171
} catch let error as PSQLPartialDecodingError {
72-
throw PSQLDecodingError.withPartialError(error, messageID: messageID, messageBytes: completeMessageBuffer)
72+
throw PSQLDecodingError.withPartialError(error, messageID: messageID.rawValue, messageBytes: completeMessageBuffer)
7373
} catch {
7474
preconditionFailure("Expected to only see `PartialDecodingError`s here.")
7575
}
@@ -106,14 +106,14 @@ struct PSQLDecodingError: Error {
106106

107107
static func withPartialError(
108108
_ partialError: PSQLPartialDecodingError,
109-
messageID: PSQLBackendMessage.ID,
109+
messageID: UInt8,
110110
messageBytes: ByteBuffer) -> Self
111111
{
112112
var byteBuffer = messageBytes
113113
let data = byteBuffer.readData(length: byteBuffer.readableBytes)!
114114

115115
return PSQLDecodingError(
116-
messageID: messageID.rawValue,
116+
messageID: messageID,
117117
payload: data.base64EncodedString(),
118118
description: partialError.description,
119119
file: partialError.file,
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
@testable import PostgresNIO
2+
3+
struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder {
4+
typealias InboundOut = PSQLFrontendMessage
5+
6+
private(set) var isInStartup: Bool
7+
8+
init() {
9+
self.isInStartup = true
10+
}
11+
12+
mutating func decode(buffer: inout ByteBuffer) throws -> PSQLFrontendMessage? {
13+
// make sure we have at least one byte to read
14+
guard buffer.readableBytes > 0 else {
15+
return nil
16+
}
17+
18+
if self.isInStartup {
19+
guard let length = buffer.getInteger(at: buffer.readerIndex, as: UInt32.self) else {
20+
return nil
21+
}
22+
23+
guard var messageSlice = buffer.getSlice(at: buffer.readerIndex &+ 4, length: Int(length)) else {
24+
return nil
25+
}
26+
buffer.moveReaderIndex(forwardBy: 4 &+ Int(length))
27+
let finalIndex = buffer.readerIndex
28+
29+
guard let code = buffer.readInteger(as: UInt32.self) else {
30+
throw PSQLPartialDecodingError.fieldNotDecodable(type: UInt32.self)
31+
}
32+
33+
switch code {
34+
case 80877103:
35+
self.isInStartup = true
36+
return .sslRequest(.init())
37+
38+
case 196608:
39+
var user: String?
40+
var database: String?
41+
var options: String?
42+
43+
while let name = messageSlice.readNullTerminatedString(), messageSlice.readerIndex < finalIndex {
44+
let value = messageSlice.readNullTerminatedString()
45+
46+
switch name {
47+
case "user":
48+
user = value
49+
50+
case "database":
51+
database = value
52+
53+
case "options":
54+
options = value
55+
56+
default:
57+
break
58+
}
59+
}
60+
61+
let parameters = PSQLFrontendMessage.Startup.Parameters(
62+
user: user!,
63+
database: database,
64+
options: options,
65+
replication: .false
66+
)
67+
68+
let startup = PSQLFrontendMessage.Startup(
69+
protocolVersion: 0x00_03_00_00,
70+
parameters: parameters
71+
)
72+
73+
precondition(buffer.readerIndex == finalIndex)
74+
self.isInStartup = false
75+
76+
return .startup(startup)
77+
78+
default:
79+
throw PSQLDecodingError.unknownStartupCodeReceived(code: code, messageBytes: messageSlice)
80+
}
81+
}
82+
83+
// all other packages have an Int32 after the identifier that determines their length.
84+
// do we have enough bytes for that?
85+
guard let idByte = buffer.getInteger(at: buffer.readerIndex, as: UInt8.self),
86+
let length = buffer.getInteger(at: buffer.readerIndex + 1, as: Int32.self) else {
87+
return nil
88+
}
89+
90+
// At this point we are sure, that we have enough bytes to decode the next message.
91+
// 1. Create a byteBuffer that represents exactly the next message. This can be force
92+
// unwrapped, since it was verified that enough bytes are available.
93+
guard let completeMessageBuffer = buffer.readSlice(length: 1 + Int(length)) else {
94+
return nil
95+
}
96+
97+
// 2. make sure we have a known message identifier
98+
guard let messageID = PSQLFrontendMessage.ID(rawValue: idByte) else {
99+
throw PSQLDecodingError.unknownMessageIDReceived(messageID: idByte, messageBytes: completeMessageBuffer)
100+
}
101+
102+
// 3. decode the message
103+
do {
104+
// get a mutable byteBuffer copy
105+
var slice = completeMessageBuffer
106+
// move reader index forward by five bytes
107+
slice.moveReaderIndex(forwardBy: 5)
108+
109+
return try PSQLFrontendMessage.decode(from: &slice, for: messageID)
110+
} catch let error as PSQLPartialDecodingError {
111+
throw PSQLDecodingError.withPartialError(error, messageID: messageID.rawValue, messageBytes: completeMessageBuffer)
112+
} catch {
113+
preconditionFailure("Expected to only see `PartialDecodingError`s here.")
114+
}
115+
}
116+
117+
mutating func decodeLast(buffer: inout ByteBuffer, seenEOF: Bool) throws -> PSQLFrontendMessage? {
118+
try self.decode(buffer: &buffer)
119+
}
120+
}
121+
122+
extension PSQLFrontendMessage {
123+
124+
static func decode(from buffer: inout ByteBuffer, for messageID: ID) throws -> PSQLFrontendMessage {
125+
switch messageID {
126+
case .bind:
127+
preconditionFailure("TODO: Unimplemented")
128+
case .close:
129+
preconditionFailure("TODO: Unimplemented")
130+
case .describe:
131+
preconditionFailure("TODO: Unimplemented")
132+
case .execute:
133+
preconditionFailure("TODO: Unimplemented")
134+
case .flush:
135+
return .flush
136+
case .parse:
137+
preconditionFailure("TODO: Unimplemented")
138+
case .password:
139+
guard let password = buffer.readNullTerminatedString() else {
140+
throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self)
141+
}
142+
return .password(.init(value: password))
143+
case .saslInitialResponse:
144+
preconditionFailure("TODO: Unimplemented")
145+
case .saslResponse:
146+
preconditionFailure("TODO: Unimplemented")
147+
case .sync:
148+
return .sync
149+
case .terminate:
150+
return .terminate
151+
}
152+
}
153+
}
154+
155+
extension PSQLDecodingError {
156+
static func unknownStartupCodeReceived(
157+
code: UInt32,
158+
messageBytes: ByteBuffer,
159+
file: String = #file,
160+
line: Int = #line) -> Self
161+
{
162+
var byteBuffer = messageBytes
163+
let data = byteBuffer.readData(length: byteBuffer.readableBytes)!
164+
165+
return PSQLDecodingError(
166+
messageID: 0,
167+
payload: data.base64EncodedString(),
168+
description: "Received a startup code '\(code)'. There is no message associated with this code.",
169+
file: file,
170+
line: line)
171+
}
172+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import NIOCore
2+
3+
/// This is a reverse ``NIOCore/ByteToMessageHandler``. Instead of creating messages from incoming bytes
4+
/// as the normal `ByteToMessageHandler` does, this `ReverseByteToMessageHandler` creates messages
5+
/// from outgoing bytes. This is only important for testing in `EmbeddedChannel`s.
6+
class ReverseByteToMessageHandler<Decoder: NIOSingleStepByteToMessageDecoder>: ChannelOutboundHandler {
7+
typealias OutboundIn = ByteBuffer
8+
typealias OutboundOut = Decoder.InboundOut
9+
10+
let processor: NIOSingleStepByteToMessageProcessor<Decoder>
11+
12+
init(_ decoder: Decoder) {
13+
self.processor = .init(decoder, maximumBufferSize: nil)
14+
}
15+
16+
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
17+
let buffer = self.unwrapOutboundIn(data)
18+
19+
do {
20+
var messages = [Decoder.InboundOut]()
21+
try self.processor.process(buffer: buffer) { message in
22+
messages.append(message)
23+
}
24+
25+
for (index, message) in messages.enumerated() {
26+
if index == messages.index(before: messages.endIndex) {
27+
context.write(self.wrapOutboundOut(message), promise: promise)
28+
} else {
29+
context.write(self.wrapOutboundOut(message), promise: nil)
30+
}
31+
}
32+
} catch {
33+
context.fireErrorCaught(error)
34+
}
35+
}
36+
}

0 commit comments

Comments
 (0)