Skip to content

feat(BA-686): Implement Raftify KVS #3838

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 8 commits into
base: feat/create-raftify-client
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
298 changes: 203 additions & 95 deletions src/ai/backend/manager/raft/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from typing import AsyncIterator, Dict, Final, Optional
from typing import AsyncIterator, Dict, Final, Optional, Self

import aiohttp
from attr import dataclass
from raftify import RaftNode

Expand All @@ -16,26 +17,35 @@ class RaftKVSLockOptions:

class ConnectOptions:
def __init__(self) -> None: ...
def with_user(self, user: str, password: str) -> "ConnectOptions":
def with_user(self, user: str, password: str) -> Self:
return self

def with_keep_alive(self, interval: float, timeout: float) -> "ConnectOptions":
def with_keep_alive(self, interval: float, timeout: float) -> Self:
return self

def with_keep_alive_while_idle(self, enabled: bool) -> "ConnectOptions":
def with_keep_alive_while_idle(self, enabled: bool) -> Self:
return self

def with_connect_timeout(self, connect_timeout: float) -> "ConnectOptions":
def with_connect_timeout(self, connect_timeout: float) -> Self:
return self

def with_timeout(self, timeout: float) -> "ConnectOptions":
def with_timeout(self, timeout: float) -> Self:
return self

def with_tcp_keepalive(self, tcp_keepalive: float) -> "ConnectOptions":
def with_tcp_keepalive(self, tcp_keepalive: float) -> Self:
return self


class RaftKVSClient:
raft_node: RaftNode
endpoints: list[str]
connect_options: Optional["ConnectOptions"]
_state_machine: HashStore
_communicator: "RaftKVSCommunicator"
_watchers: Dict[bytes, list[asyncio.Queue]]
_leases: Dict[bytes, float]
_lease_task: Optional[asyncio.Task]

def __init__(
self,
raft_node: RaftNode,
Expand All @@ -47,119 +57,213 @@ def __init__(
self.connect_options = connect_options

self._state_machine = HashStore()
self.communicator: RaftKVSCommunicator = RaftKVSCommunicator(self._state_machine)
self._lock_store: Dict[bytes, asyncio.Lock] = {}
self._watchers: Dict[bytes, asyncio.Queue] = {}
self._communicator: RaftKVSCommunicator = RaftKVSCommunicator(self._state_machine)
self._watchers = {}
self._leases: Dict[bytes, float] = {} # Stores TTL expiration timestamps
self._data_store: Dict[bytes, bytes] = {} # KVS storage
self._lease_task: Optional[asyncio.Task] = None

async def connect(self, connect_options: Optional["ConnectOptions"] = None) -> Self:
if connect_options:
self.connect_options = connect_options

if not self._lease_task:
self._lease_task = asyncio.create_task(self._cleanup_expired_leases())

return self

async def close(self) -> None:
if self._lease_task:
self._lease_task.cancel()
self._lease_task = None

self._watchers.clear()
self._leases.clear()

"""
===================== Leadership and Redirection =====================
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh i just added it for my personal readability. i'll remove it before the merge


async def is_leader(self) -> bool:
return await self.raft_node.is_leader()

async def get_leader_id(self) -> Optional[int]:
return await self.raft_node.get_leader_id()

async def _get_leader_addr(self) -> Optional[str]:
leader_id = await self.get_leader_id()
if leader_id is None:
return None

for endpoint in self.endpoints:
if endpoint.startswith(f"{leader_id}:"):
return endpoint.split(":", 1)[1]
return None

async def _redirect_write_to_leader(self, key: bytes, value: bytes, method: str) -> None:
leader_addr = await self._get_leader_addr()
if leader_addr is None:
raise RuntimeError("No leader found in the cluster. Request cannot be redirected.")

if method == "PUT":
url = f"http://{leader_addr}/put/{key.decode()}/{value.decode()}"
elif method == "DELETE":
url = f"http://{leader_addr}/delete/{key.decode()}"
else:
raise RuntimeError(f"Unsupported method: {method}")

try:
async with aiohttp.ClientSession() as session:
async with session.get(url) as resp:
if resp.status != 200:
raise RuntimeError(f"Failed to redirect request to leader: {resp.status}")
except Exception as e:
raise RuntimeError(f"Failed to redirect request to leader: {e}")

"""
===================== Watch Management =====================
"""

async def watch(self, key: bytes, start_revision: Optional[int] = None) -> "Watch":
if key not in self._watchers:
self._watchers[key] = []
queue: asyncio.Queue[WatchEvent] = asyncio.Queue()
self._watchers[key].append(queue)
return Watch(queue)

async def notify_watchers(
self,
key: bytes,
event_type: str,
new_value: Optional[bytes],
revision: int,
prev_value: Optional[bytes] = None,
) -> None:
if key not in self._watchers:
return
event = WatchEvent(key, new_value, event_type, revision, prev_value)
for queue in self._watchers[key]:
await queue.put(event)

"""
===================== KVS Methods =====================
"""

async def put(self, key: bytes, value: bytes) -> None:
if not await self.is_leader():
raise RuntimeError("Writes can only be performed on the Raft leader.")
await self._redirect_write_to_leader(key, value, method="PUT")
return

message = SetCommand(key.decode(), value.decode()).encode()
await self.raft_node.propose(message)
await self.notify_watchers(key, value)

revision = self._state_machine.current_revision()
old_val = self._state_machine.get(key.decode())

await self.notify_watchers(
key,
WatchEventType.PUT,
value,
revision,
prev_value=old_val.encode() if old_val else None,
)

async def get(self, key: bytes) -> Optional[bytes]:
state_machine = await self.raft_node.state_machine()
assert isinstance(state_machine, HashStore)
result = state_machine.get(key.decode())
return result.encode() if result is not None else None
# todo: check if data reads might be stale
val = self._state_machine.get(key.decode())
return val.encode() if val is not None else None

async def delete(self, key: bytes) -> None:
if not await self.is_leader():
raise RuntimeError("Deletes can only be performed on the Raft leader.")
message = SetCommand(key.decode(), "").encode()
await self._redirect_write_to_leader(key, b"", method="DELETE")
return

old_val = self._state_machine.get(key.decode())
message = SetCommand(key.decode(), "").encode() # send empty value to delete
await self.raft_node.propose(message)
await self.notify_watchers(key, None)

revision = self._state_machine.current_revision()
if old_val is not None:
await self.notify_watchers(
key, WatchEventType.DELETE, None, revision, prev_value=old_val.encode()
)

async def get_cluster_size(self) -> int:
return await self.raft_node.get_cluster_size()

async def watch(self, key: bytes) -> asyncio.Queue:
if key not in self._watchers:
self._watchers[key] = asyncio.Queue()
return self._watchers[key]

async def notify_watchers(self, key: bytes, value: Optional[bytes]) -> None:
if key in self._watchers:
await self._watchers[key].put(value)
"""
===================== Lease Management =====================
"""

async def lease_grant(self, key: bytes, ttl: int) -> None:
if not await self.is_leader():
await self._redirect_write_to_leader(key, str(ttl).encode(), method="PUT")
return

self._leases[key] = asyncio.get_event_loop().time() + ttl

async def lease_revoke(self, key: bytes) -> None:
if key in self._leases:
del self._leases[key]
await self.delete(key)
if key not in self._leases:
return

if not await self.is_leader():
await self._redirect_write_to_leader(key, b"", method="DELETE")
return

del self._leases[key]
await self.delete(key)

async def _cleanup_expired_leases(self) -> None:
while True:
now = asyncio.get_event_loop().time()
expired_keys = [key for key, expiry in self._leases.items() if expiry < now]
for key in expired_keys:
await self.delete(key)
del self._leases[key]
await asyncio.sleep(1) # Runs every second

def connect(self, connect_options: Optional["ConnectOptions"] = None) -> "RaftKVSClient":
return self
try:
while True:
await asyncio.sleep(1)
if not await self.is_leader():
continue

now = asyncio.get_event_loop().time()
expired = [k for k, exp in self._leases.items() if exp < now]
for key in expired:
await self.lease_revoke(key)
except asyncio.CancelledError:
pass

"""
===================== Lock =====================
"""

async def with_lock(
self, lock_options: RaftKVSLockOptions, connect_options: Optional["ConnectOptions"] = None
) -> "RaftKVSClient":
lock = RaftKVSLock(self, lock_options, connect_options)
await lock.__aenter__()
self._current_lock = lock
) -> Self:
if not await self.is_leader():
raise RuntimeError("Locks can only be acquired by the Raft leader.")

message = SetCommand(lock_options.lock_name.decode(), "LOCKED").encode()
await self.raft_node.propose(message)

return self

"""
===================== Membership / Cluster =====================
"""

async def add_peer(self, id: int, addr: str) -> None:
if not await self.is_leader():
raise RuntimeError("Only the leader can add peers to the cluster.")

await self.raft_node.add_peer(id, addr)

async def join_cluster(self, tickets: list) -> None:
await self.raft_node.join_cluster(tickets)

async def __aenter__(self) -> "RaftKVSCommunicator":
asyncio.create_task(self._cleanup_expired_leases())
return self.communicator

async def __aexit__(self, *args) -> None:
self._lock_store.clear()
self._watchers.clear()
self._leases.clear()


class RaftKVSLock:
def __init__(
self,
client: RaftKVSClient,
lock_options: RaftKVSLockOptions,
connect_options: Optional["ConnectOptions"] = None,
) -> None:
self.client = client
self.lock_options = lock_options
self.lock_acquired = False

async def __aenter__(self) -> None:
if not await self.client.is_leader():
raise RuntimeError("Locks can only be acquired by the Raft leader.")
"""
===================== Context Manager =====================
"""

lock_command = SetCommand(self.lock_options.lock_name.decode(), "LOCKED").encode()

await self.client.raft_node.propose(lock_command)
self.lock_acquired = True
async def __aenter__(self) -> "RaftKVSCommunicator":
await self.connect()
return self._communicator

async def __aexit__(self, *args) -> None:
if not self.lock_acquired:
return

unlock_command = SetCommand(self.lock_options.lock_name.decode(), "UNLOCKED").encode()

await self.client.raft_node.propose(unlock_command)
self.lock_acquired = False
await self.close()


class RaftKVSCommunicator:
Expand All @@ -179,37 +283,41 @@ async def delete(self, key: bytes) -> None:
await self.state_machine.apply(cmd)


class Watch:
def __init__(self, queue: asyncio.Queue):
self.queue = queue

async def __aiter__(self) -> AsyncIterator[Optional[bytes]]:
while True:
yield await self.queue.get()
class WatchEventType:
PUT: Final[str] = "PUT"
DELETE: Final[str] = "DELETE"


class WatchEvent:
key: bytes
value: bytes
event: "WatchEventType"
value: Optional[bytes]
event_type: str
prev_value: Optional[bytes]

def __init__(
self,
key: bytes,
value: bytes,
event: "WatchEventType",
value: Optional[bytes],
event_type: str,
revision: int,
prev_value: Optional[bytes] = None,
) -> None:
self.key = key
self.value = value
self.event = event
self.event_type = event_type
self.revision = revision
self.prev_value = prev_value


class WatchEventType:
PUT: Final[str] = "PUT"
DELETE: Final[str] = "DELETE"
class Watch:
def __init__(self, queue: asyncio.Queue[WatchEvent]) -> None:
self.queue = queue

def __aiter__(self) -> AsyncIterator[WatchEvent]:
return self

async def __anext__(self) -> WatchEvent:
return await self.queue.get()


class CondVar:
Expand Down
Loading