Skip to content

Commit 29d61f1

Browse files
authored
Make Room.connect cancellable (livekit#273)
* engine connect * connect flow * cancellable completer * cancellable WebSocket * completer cancel test * comment * check cancel for queue actor
1 parent 3a07312 commit 29d61f1

File tree

12 files changed

+121
-61
lines changed

12 files changed

+121
-61
lines changed

Package.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ let package = Package(
1010
.macOS(.v10_15),
1111
],
1212
products: [
13-
// Products define the executables and libraries a package produces, and make them visible to other packages.
1413
.library(
1514
name: "LiveKit",
1615
targets: ["LiveKit"]

Sources/LiveKit/Broadcast/BroadcastScreenCapturer.swift

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@
1414
* limitations under the License.
1515
*/
1616

17-
import Foundation
17+
#if os(iOS)
1818

19-
#if canImport(UIKit)
20-
import UIKit
21-
#endif
19+
import Foundation
2220

23-
@_implementationOnly import WebRTC
21+
#if canImport(UIKit)
22+
import UIKit
23+
#endif
2424

25-
#if os(iOS)
25+
@_implementationOnly import WebRTC
2626

2727
class BroadcastScreenCapturer: BufferCapturer {
2828
static let kRTCScreensharingSocketFD = "rtc_SSFD"

Sources/LiveKit/Core/Engine.swift

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ class Engine: MulticastDelegate<EngineDelegate> {
145145
}
146146

147147
try await cleanUp()
148+
try Task.checkCancellation()
148149

149150
_state.mutate { $0.connectionState = .connecting }
150151

@@ -154,13 +155,19 @@ class Engine: MulticastDelegate<EngineDelegate> {
154155
// Connect sequence successful
155156
log("Connect sequence completed")
156157

158+
// Final check if cancelled, don't fire connected events
159+
try Task.checkCancellation()
160+
157161
// update internal vars (only if connect succeeded)
158162
_state.mutate {
159163
$0.url = url
160164
$0.token = token
161165
$0.connectionState = .connected
162166
}
163167

168+
} catch is CancellationError {
169+
// Cancelled by .user
170+
try await cleanUp(reason: .user)
164171
} catch {
165172
try await cleanUp(reason: .networkError(error))
166173
}
@@ -344,10 +351,18 @@ extension Engine {
344351
connectOptions: _state.connectOptions,
345352
reconnectMode: _state.reconnectMode,
346353
adaptiveStream: room._state.options.adaptiveStream)
354+
// Check cancellation after WebSocket connected
355+
try Task.checkCancellation()
347356

348357
let jr = try await signalClient.joinResponseCompleter.wait()
358+
// Check cancellation after received join response
359+
try Task.checkCancellation()
360+
349361
_state.mutate { $0.connectStopwatch.split(label: "signal") }
350362
try await configureTransports(joinResponse: jr)
363+
// Check cancellation after configuring transports
364+
try Task.checkCancellation()
365+
351366
try await signalClient.resumeResponseQueue()
352367
try await primaryTransportConnectedCompleter.wait()
353368
_state.mutate { $0.connectStopwatch.split(label: "engine") }

Sources/LiveKit/Core/SignalClient.swift

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,8 @@ class SignalClient: MulticastDelegate<SignalClientDelegate> {
103103
$0.connectionState = .connecting
104104
}
105105

106-
let socket = WebSocket(url: url)
107-
108106
do {
109-
try await socket.connect()
107+
let socket = try await WebSocket(url: url)
110108
_webSocket = socket
111109
_state.mutate { $0.connectionState = .connected }
112110

@@ -156,10 +154,8 @@ class SignalClient: MulticastDelegate<SignalClientDelegate> {
156154
pingIntervalTimer = nil
157155
pingTimeoutTimer = nil
158156

159-
if let socket = _webSocket {
160-
socket.reset()
161-
_webSocket = nil
162-
}
157+
_webSocket?.close()
158+
_webSocket = nil
163159

164160
latestJoinResponse = nil
165161

@@ -311,7 +307,7 @@ private extension SignalClient {
311307

312308
extension SignalClient {
313309
func resumeResponseQueue() async throws {
314-
await _responseQueue.resume { response in
310+
try await _responseQueue.resume { response in
315311
await processSignalResponse(response)
316312
}
317313
}
@@ -321,7 +317,7 @@ extension SignalClient {
321317

322318
extension SignalClient {
323319
func sendQueuedRequests() async throws {
324-
await _requestQueue.resume { element in
320+
try await _requestQueue.resume { element in
325321
do {
326322
try await sendRequest(element, enqueueIfReconnecting: false)
327323
} catch {

Sources/LiveKit/Core/Transport.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ class Transport: MulticastDelegate<TransportDelegate> {
112112
func set(remoteDescription sd: LKRTCSessionDescription) async throws {
113113
try await _pc.setRemoteDescription(sd)
114114

115-
await _pendingCandidatesQueue.resume { candidate in
115+
try await _pendingCandidatesQueue.resume { candidate in
116116
do {
117117
try await add(iceCandidate: candidate)
118118
} catch {

Sources/LiveKit/Errors.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ public enum TrackError: LiveKitError {
9696
}
9797

9898
public enum SignalClientError: LiveKitError {
99+
case cancelled
99100
case state(message: String? = nil)
100101
case socketError(rawError: Error?)
101102
case close(message: String? = nil)
@@ -105,6 +106,7 @@ public enum SignalClientError: LiveKitError {
105106

106107
public var description: String {
107108
switch self {
109+
case .cancelled: return buildDescription("cancelled")
108110
case let .state(message): return buildDescription("state", message)
109111
case let .socketError(rawError): return buildDescription("socketError", rawError: rawError)
110112
case let .close(message): return buildDescription("close", message)

Sources/LiveKit/Support/AsyncCompleter.swift

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ class AsyncCompleter<T>: Loggable {
9696

9797
public func cancel() {
9898
_cancelTimer()
99+
if _continuation != nil {
100+
log("\(label) cancelled")
101+
}
99102
_continuation?.resume(throwing: AsyncCompleterError.cancelled)
100103
_continuation = nil
101104
_returningValue = nil
@@ -140,24 +143,29 @@ class AsyncCompleter<T>: Loggable {
140143
// Cancel any previous waits
141144
cancel()
142145

143-
// Create a timed continuation
144-
return try await withCheckedThrowingContinuation { continuation in
145-
// Store reference to continuation
146-
_continuation = continuation
147-
148-
// Create time-out block
149-
let timeOutBlock = DispatchWorkItem { [weak self] in
150-
guard let self else { return }
151-
self.log("\(self.label) timedOut")
152-
self._continuation?.resume(throwing: AsyncCompleterError.timedOut)
153-
self._continuation = nil
154-
self.cancel()
146+
// Create a cancel-aware timed continuation
147+
return try await withTaskCancellationHandler {
148+
try await withCheckedThrowingContinuation { continuation in
149+
// Store reference to continuation
150+
_continuation = continuation
151+
152+
// Create time-out block
153+
let timeOutBlock = DispatchWorkItem { [weak self] in
154+
guard let self else { return }
155+
self.log("\(self.label) timedOut")
156+
self._continuation?.resume(throwing: AsyncCompleterError.timedOut)
157+
self._continuation = nil
158+
self.cancel()
159+
}
160+
161+
// Schedule time-out block
162+
_queue.asyncAfter(deadline: .now() + _timeOut, execute: timeOutBlock)
163+
// Store reference to time-out block
164+
_timeOutBlock = timeOutBlock
155165
}
156-
157-
// Schedule time-out block
158-
_queue.asyncAfter(deadline: .now() + _timeOut, execute: timeOutBlock)
159-
// Store reference to time-out block
160-
_timeOutBlock = timeOutBlock
166+
} onCancel: {
167+
// Cancel completer when Task gets cancelled
168+
cancel()
161169
}
162170
}
163171
}

Sources/LiveKit/Support/AsyncQueueActor.swift

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,13 @@ actor AsyncQueueActor<T> {
4949
}
5050

5151
/// Mark as `.resumed` and process each element with an async `block`.
52-
func resume(_ block: (T) async -> Void) async {
52+
func resume(_ block: (T) async throws -> Void) async throws {
5353
state = .resumed
5454
if queue.isEmpty { return }
5555
for element in queue {
56-
await block(element)
56+
// Check cancellation before processing next block...
57+
try Task.checkCancellation()
58+
try await block(element)
5759
}
5860
queue.removeAll()
5961
}

Sources/LiveKit/Support/WebSocket.swift

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,24 +42,27 @@ class WebSocket: NSObject, Loggable, AsyncSequence, URLSessionWebSocketDelegate
4242
waitForNextValue()
4343
}
4444

45-
init(url: URL) {
45+
init(url: URL) async throws {
4646
request = URLRequest(url: url,
4747
cachePolicy: .useProtocolCachePolicy,
4848
timeoutInterval: .defaultSocketConnect)
49+
super.init()
50+
try await withTaskCancellationHandler {
51+
try await withCheckedThrowingContinuation { continuation in
52+
connectContinuation = continuation
53+
task.resume()
54+
}
55+
} onCancel: {
56+
// Cancel(reset) when Task gets cancelled
57+
close()
58+
}
4959
}
5060

5161
deinit {
52-
reset()
53-
}
54-
55-
public func connect() async throws {
56-
try await withCheckedThrowingContinuation { continuation in
57-
connectContinuation = continuation
58-
task.resume()
59-
}
62+
close()
6063
}
6164

62-
func reset() {
65+
func close() {
6366
task.cancel(with: .goingAway, reason: nil)
6467
connectContinuation?.resume(throwing: SignalClientError.socketError(rawError: nil))
6568
connectContinuation = nil

Tests/LiveKitTests/CompleterTests.swift

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,48 @@ class CompleterTests: XCTestCase {
2222

2323
override func tearDown() async throws {}
2424

25-
func testCompleter() async throws {}
25+
func testCompleterReuse() async throws {
26+
let completer = AsyncCompleter<Void>(label: "Test01", timeOut: .seconds(1))
27+
do {
28+
try await completer.wait()
29+
} catch AsyncCompleterError.timedOut {
30+
print("Timed out 1")
31+
}
32+
// Re-use
33+
do {
34+
try await completer.wait()
35+
} catch AsyncCompleterError.timedOut {
36+
print("Timed out 2")
37+
}
38+
}
39+
40+
func testCompleterCancel() async throws {
41+
let completer = AsyncCompleter<Void>(label: "cancel-test", timeOut: .never)
42+
do {
43+
// Run Tasks in parallel
44+
try await withThrowingTaskGroup(of: Void.self) { group in
45+
46+
group.addTask {
47+
print("Task 1: Waiting...")
48+
try await completer.wait()
49+
}
50+
51+
group.addTask {
52+
print("Task 2: Started...")
53+
// Cancel after 1 second
54+
try await Task.sleep(until: .now + .seconds(1), clock: .continuous)
55+
print("Task 2: Cancelling completer...")
56+
completer.cancel()
57+
}
58+
59+
try await group.waitForAll()
60+
}
61+
} catch let error as AsyncCompleterError where error == .timedOut {
62+
print("Completer timed out")
63+
} catch let error as AsyncCompleterError where error == .cancelled {
64+
print("Completer cancelled")
65+
} catch {
66+
print("Unknown error: \(error)")
67+
}
68+
}
2669
}

Tests/LiveKitTests/TimerTests.swift

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
*/
1616

1717
@testable import LiveKit
18-
import Promises
1918
import XCTest
2019

2120
class TimerTests: XCTestCase {
@@ -35,10 +34,10 @@ class TimerTests: XCTestCase {
3534
if self.counter == 3 {
3635
print("suspending timer for 3s...")
3736
self.timer.suspend()
38-
Promise(()).delay(3).then {
39-
print("restarting timer...")
40-
self.timer.restart()
41-
}
37+
// Promise(()).delay(3).then {
38+
// print("restarting timer...")
39+
// self.timer.restart()
40+
// }
4241
}
4342

4443
if self.counter == 5 {

Tests/LiveKitTests/WebSocketTests.swift

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,13 @@
1818
import XCTest
1919

2020
class WebSocketTests: XCTestCase {
21-
lazy var socket: WebSocket = {
22-
let url = URL(string: "wss://socketsbay.com/wss/v2/1/demo/")!
23-
return WebSocket(url: url)
24-
}()
25-
2621
override func setUpWithError() throws {}
2722

2823
override func tearDown() async throws {}
2924

30-
func testCompleter1() async throws {
31-
// Read messages
32-
25+
func testWebSocket01() async throws {
3326
print("Connecting...")
34-
try await socket.connect()
27+
let socket = try await WebSocket(url: URL(string: "wss://socketsbay.com/wss/v2/1/demo/")!)
3528

3629
print("Connected. Waiting for messages...")
3730
do {

0 commit comments

Comments
 (0)