Skip to content

asyncnet ssl overhaul #24896

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 86 additions & 96 deletions lib/pure/asyncnet.nim
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,6 @@ type
when defineSsl:
sslHandle: SslPtr
sslContext: SslContext
bioIn: BIO
bioOut: BIO
sslNoShutdown: bool
domain: Domain
sockType: SockType
Expand Down Expand Up @@ -210,7 +208,7 @@ when defineSsl:
proc raiseSslHandleError =
raiseSSLError("The SSL Handle is closed/unset")

proc getSslError(socket: AsyncSocket, err: cint): cint =
proc getSslError(socket: AsyncSocket, flags: set[SocketFlag], err: cint): cint =
assert socket.isSsl
assert err < 0
var ret = SSL_get_error(socket.sslHandle, err.cint)
Expand All @@ -223,47 +221,49 @@ when defineSsl:
return ret
of SSL_ERROR_WANT_X509_LOOKUP:
raiseSSLError("Function for x509 lookup has been called.")
of SSL_ERROR_SYSCALL, SSL_ERROR_SSL:
of SSL_ERROR_SYSCALL:
socket.sslNoShutdown = true
let osErr = osLastError()
if not flags.isDisconnectionError(osErr):
var errStr = "IO error has occurred"
let sslErr = ERR_peek_last_error()
if sslErr == 0 and err == 0:
errStr.add ' '
errStr.add "because an EOF was observed that violates the protocol"
elif sslErr == 0 and err == -1:
errStr.add ' '
errStr.add "in the BIO layer"
else:
let errStr = $ERR_error_string(sslErr, nil)
raiseSSLError(errStr & ": " & errStr)
raiseOSError(osErr, errStr)
else:
return ret
of SSL_ERROR_SSL:
socket.sslNoShutdown = true
raiseSSLError()
else: raiseSSLError("Unknown Error")

proc sendPendingSslData(socket: AsyncSocket,
flags: set[SocketFlag]) {.async.} =
if socket.sslHandle == nil:
raiseSslHandleError()
let len = bioCtrlPending(socket.bioOut)
if len > 0:
var data = newString(len)
let read = bioRead(socket.bioOut, cast[cstring](addr data[0]), len)
assert read != 0
if read < 0:
raiseSSLError()
data.setLen(read)
await socket.fd.AsyncFD.send(data, flags)

proc appeaseSsl(socket: AsyncSocket, flags: set[SocketFlag],
sslError: cint): owned(Future[bool]) {.async.} =
proc handleSslFailure(socket: AsyncSocket, flags: set[SocketFlag], sslError: cint): Future[bool] =
## Returns `true` if `socket` is still connected, otherwise `false`.
result = true
let retFut = newFuture[bool]("asyncnet.handleSslFailure")
case sslError
of SSL_ERROR_WANT_WRITE:
await sendPendingSslData(socket, flags)
of SSL_ERROR_WANT_WRITE, SSL_ERROR_WANT_CONNECT, SSL_ERROR_WANT_ACCEPT:
addWrite(socket.fd.AsyncFD, proc (sock: AsyncFD): bool =
retFut.complete(true)
return true
)
of SSL_ERROR_WANT_READ:
var data = await recv(socket.fd.AsyncFD, BufferSize, flags)
if socket.sslHandle == nil:
raiseSslHandleError()
let length = len(data)
if length > 0:
let ret = bioWrite(socket.bioIn, cast[cstring](addr data[0]), length.cint)
if ret < 0:
raiseSSLError()
elif length == 0:
# connection not properly closed by remote side or connection dropped
SSL_set_shutdown(socket.sslHandle, SSL_RECEIVED_SHUTDOWN)
result = false
addRead(socket.fd.AsyncFD, proc (sock: AsyncFD): bool =
retFut.complete(true)
return true
)
of SSL_ERROR_SYSCALL:
assert flags.isDisconnectionError(osLastError())
retFut.complete(false)
else:
raiseSSLError("Cannot appease SSL.")
raiseSSLError("Cannot handle SSL failure.")
return retFut

template sslLoop(socket: AsyncSocket, flags: set[SocketFlag],
op: untyped) =
Expand All @@ -274,20 +274,12 @@ when defineSsl:
ErrClearError()
# Call the desired operation.
opResult = op
let err =
if opResult < 0:
getSslError(socket, opResult.cint)
else:
SSL_ERROR_NONE
# Send any remaining pending SSL data.
await sendPendingSslData(socket, flags)

# If the operation failed, try to see if SSL has some data to read
# or write.
if opResult < 0:
let fut = appeaseSsl(socket, flags, err.cint)
yield fut
if not fut.read():
let err = getSslError(socket, flags, opResult.cint)
let connected = await handleSslFailure(socket, flags, err.cint)
if not connected:
# Socket disconnected.
if SocketFlag.SafeDisconn in flags:
opResult = 0.cint
Expand Down Expand Up @@ -323,8 +315,7 @@ proc connect*(socket: AsyncSocket, address: string, port: Port) {.async.} =
discard SSL_set_tlsext_host_name(socket.sslHandle, address)

let flags = {SocketFlag.SafeDisconn}
sslSetConnectState(socket.sslHandle)
sslLoop(socket, flags, sslDoHandshake(socket.sslHandle))
sslLoop(socket, flags, SSL_connect(socket.sslHandle))

template readInto(buf: pointer, size: int, socket: AsyncSocket,
flags: set[SocketFlag]): int =
Expand Down Expand Up @@ -461,7 +452,6 @@ proc send*(socket: AsyncSocket, buf: pointer, size: int,
when defineSsl:
sslLoop(socket, flags,
sslWrite(socket.sslHandle, cast[cstring](buf), size.cint))
await sendPendingSslData(socket, flags)
else:
await send(socket.fd.AsyncFD, buf, size, flags)

Expand All @@ -475,52 +465,9 @@ proc send*(socket: AsyncSocket, data: string,
var copy = data
sslLoop(socket, flags,
sslWrite(socket.sslHandle, cast[cstring](addr copy[0]), copy.len.cint))
await sendPendingSslData(socket, flags)
else:
await send(socket.fd.AsyncFD, data, flags)

proc acceptAddr*(socket: AsyncSocket, flags = {SocketFlag.SafeDisconn},
inheritable = defined(nimInheritHandles)):
owned(Future[tuple[address: string, client: AsyncSocket]]) =
## Accepts a new connection. Returns a future containing the client socket
## corresponding to that connection and the remote address of the client.
##
## If `inheritable` is false (the default), the resulting client socket will
## not be inheritable by child processes.
##
## The future will complete when the connection is successfully accepted.
var retFuture = newFuture[tuple[address: string, client: AsyncSocket]]("asyncnet.acceptAddr")
var fut = acceptAddr(socket.fd.AsyncFD, flags, inheritable)
fut.callback =
proc (future: Future[tuple[address: string, client: AsyncFD]]) =
assert future.finished
if future.failed:
retFuture.fail(future.readError)
else:
let resultTup = (future.read.address,
newAsyncSocket(future.read.client, socket.domain,
socket.sockType, socket.protocol, socket.isBuffered, inheritable))
retFuture.complete(resultTup)
return retFuture

proc accept*(socket: AsyncSocket,
flags = {SocketFlag.SafeDisconn}): owned(Future[AsyncSocket]) =
## Accepts a new connection. Returns a future containing the client socket
## corresponding to that connection.
## If `inheritable` is false (the default), the resulting client socket will
## not be inheritable by child processes.
## The future will complete when the connection is successfully accepted.
var retFut = newFuture[AsyncSocket]("asyncnet.accept")
var fut = acceptAddr(socket, flags)
fut.callback =
proc (future: Future[tuple[address: string, client: AsyncSocket]]) =
assert future.finished
if future.failed:
retFut.fail(future.readError)
else:
retFut.complete(future.read.client)
return retFut

proc recvLineInto*(socket: AsyncSocket, resString: FutureVar[string],
flags = {SocketFlag.SafeDisconn}, maxLength = MaxLineLength) {.async.} =
## Reads a line of data from `socket` into `resString`.
Expand Down Expand Up @@ -776,9 +723,8 @@ when defineSsl:
if socket.sslHandle == nil:
raiseSSLError()

socket.bioIn = bioNew(bioSMem())
socket.bioOut = bioNew(bioSMem())
sslSetBio(socket.sslHandle, socket.bioIn, socket.bioOut)
if SSL_set_fd(socket.sslHandle, socket.fd) != 1:
raiseSSLError()

socket.sslNoShutdown = true

Expand All @@ -795,6 +741,8 @@ when defineSsl:
##
## **Disclaimer**: This code is not well tested, may be very unsafe and
## prone to security vulnerabilities.
if socket.isSsl:
return
wrapSocket(ctx, socket)

case handshake
Expand All @@ -818,6 +766,48 @@ when defineSsl:
else:
result = getPeerCertificates(socket.sslHandle)

proc acceptAddr*(socket: AsyncSocket, flags = {SocketFlag.SafeDisconn},
inheritable = defined(nimInheritHandles)):
owned(Future[tuple[address: string, client: AsyncSocket]]) {.async.} =
## Accepts a new connection. Returns a future containing the client socket
## corresponding to that connection and the remote address of the client.
##
## If `inheritable` is false (the default), the resulting client socket will
## not be inheritable by child processes.
##
## The future will complete when the connection is successfully accepted.
let (address, fd) = await acceptAddr(socket.fd.AsyncFD, flags, inheritable)
let client = newAsyncSocket(fd, socket.domain, socket.sockType,
socket.protocol, socket.isBuffered, inheritable)
result = (address, client)
if socket.isSsl:
when defineSsl:
if socket.sslContext == nil:
raiseSSLError("The SSL Context is closed/unset")
wrapSocket(socket.sslContext, result.client)
if result.client.sslHandle == nil:
raiseSslHandleError()
let flags = {SocketFlag.SafeDisconn}
sslLoop(result.client, flags, SSL_accept(result.client.sslHandle))

proc accept*(socket: AsyncSocket,
flags = {SocketFlag.SafeDisconn}): owned(Future[AsyncSocket]) =
## Accepts a new connection. Returns a future containing the client socket
## corresponding to that connection.
## If `inheritable` is false (the default), the resulting client socket will
## not be inheritable by child processes.
## The future will complete when the connection is successfully accepted.
var retFut = newFuture[AsyncSocket]("asyncnet.accept")
var fut = acceptAddr(socket, flags)
fut.callback =
proc (future: Future[tuple[address: string, client: AsyncSocket]]) =
assert future.finished
if future.failed:
retFut.fail(future.readError)
else:
retFut.complete(future.read.client)
return retFut

proc getSockOpt*(socket: AsyncSocket, opt: SOBool, level = SOL_SOCKET): bool {.
tags: [ReadIOEffect].} =
## Retrieves option `opt` as a boolean value.
Expand Down
79 changes: 79 additions & 0 deletions tests/async/t24895.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
discard """
cmd: "nim $target --hints:on --define:ssl $options $file"
"""

{.define: ssl.}

import std/[asyncdispatch, asyncnet, net, openssl]

var port0: Port
var checked = 0

proc server {.async.} =
let sock = newAsyncSocket(AF_INET, SOCK_STREAM, IPPROTO_TCP, buffered = true)
doAssert sock != nil
defer: sock.close()
let sslCtx = newContext(
protSSLv23,
verifyMode = CVerifyNone,
certFile = "tests/testdata/mycert.pem",
keyFile = "tests/testdata/mycert.pem"
)
doAssert sslCtx != nil
defer: sslCtx.destroyContext()
wrapSocket(sslCtx, sock)
#sock.bindAddr(Port 8181)
sock.bindAddr()
port0 = getLocalAddr(sock)[1]
sock.listen()
echo "accept"
let clientSocket = await sock.accept()
defer: clientSocket.close()
wrapConnectedSocket(
sslCtx, clientSocket, handshakeAsServer, "localhost"
)
let sdata = "x" & newString(41)
let sfut = clientSocket.send(sdata)
let rdata = newString(42)
let rfut = clientSocket.recvInto(addr rdata[0], rdata.len)
echo "send"
await sfut
echo "recv"
let rLen = await rfut # it hang here until the client closes the connection or sends more data
doAssert rLen == 42, $rLen
doAssert rdata[0] == 'x', $rdata[0]
echo "ok"
inc checked

proc client {.async.} =
let sock = newAsyncSocket(AF_INET, SOCK_STREAM, IPPROTO_TCP, buffered = true)
doAssert sock != nil
defer: sock.close()
let sslCtx = newContext(
protSSLv23,
verifyMode = CVerifyNone
)
doAssert sslCtx != nil
defer: sslCtx.destroyContext()
wrapSocket(sslCtx, sock)
#await sock.connect("127.0.0.1", Port 8181)
await sock.connect("localhost", port0)
let sdata = "x" & newString(41)
echo "send"
await sock.send(sdata)
let rdata = newString(42)
echo "recv"
let rLen = await sock.recvInto(addr rdata[0], rdata.len)
doAssert rLen == 42, $rLen
doAssert rdata[0] == 'x', $rdata[0]
#await sleepAsync(10_000)
#await sock.send("x")
echo "ok"
inc checked

discard getGlobalDispatcher()
let serverFut = server()
waitFor client()
waitFor serverFut
doAssert checked == 2
doAssert not hasPendingOperations()