Skip to content

Commit 581531a

Browse files
Support constructing MessageId from results of send() and receive() (#254)
Currently, `Producer.send` returns a `_pulsar.MessageId` instance, `Consumer.receive` returns a `MessageId` whose `message_id()` method returns a `_pulsar.MessageId` instance. This forces users to access the type from the C extension module (`_pulsar`). This patch adds a `MessageId.wrap` class method to convert the type from the C extension to the type in the `pulsar` module. It also exposes the comparison methods for `MessageId`.
1 parent 6df05a1 commit 581531a

File tree

2 files changed

+67
-6
lines changed

2 files changed

+67
-6
lines changed

pulsar/__init__.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ class MessageId:
8181
"""
8282

8383
def __init__(self, partition=-1, ledger_id=-1, entry_id=-1, batch_index=-1):
84-
self._msg_id = _pulsar.MessageId(partition, ledger_id, entry_id, batch_index)
84+
self._msg_id: _pulsar.MessageId = _pulsar.MessageId(partition, ledger_id, entry_id, batch_index)
8585

8686
earliest = _pulsar.MessageId.earliest
8787
latest = _pulsar.MessageId.latest
@@ -111,6 +111,24 @@ def __str__(self) -> str:
111111
"""
112112
return str(self._msg_id)
113113

114+
def __eq__(self, other) -> bool:
115+
return self._msg_id == other._msg_id
116+
117+
def __ne__(self, other) -> bool:
118+
return self._msg_id != other._msg_id
119+
120+
def __le__(self, other) -> bool:
121+
return self._msg_id <= other._msg_id
122+
123+
def __lt__(self, other) -> bool:
124+
return self._msg_id < other._msg_id
125+
126+
def __ge__(self, other) -> bool:
127+
return self._msg_id >= other._msg_id
128+
129+
def __gt__(self, other) -> bool:
130+
return self._msg_id > other._msg_id
131+
114132
@staticmethod
115133
def deserialize(message_id_bytes):
116134
"""
@@ -119,6 +137,14 @@ def deserialize(message_id_bytes):
119137
"""
120138
return _pulsar.MessageId.deserialize(message_id_bytes)
121139

140+
@classmethod
141+
def wrap(cls, msg_id: _pulsar.MessageId):
142+
"""
143+
Wrap the underlying MessageId type from the C extension to the Python type.
144+
"""
145+
self = cls()
146+
self._msg_id = msg_id
147+
return self
122148

123149
class Message:
124150
"""
@@ -170,9 +196,13 @@ def event_timestamp(self):
170196
"""
171197
return self._message.event_timestamp()
172198

173-
def message_id(self):
199+
def message_id(self) -> _pulsar.MessageId:
174200
"""
175201
The message ID that can be used to refer to this particular message.
202+
203+
Returns
204+
----------
205+
A `_pulsar.MessageId` object that represents where the message is persisted.
176206
"""
177207
return self._message.message_id()
178208

@@ -1231,7 +1261,7 @@ def send(self, content,
12311261
event_timestamp=None,
12321262
deliver_at=None,
12331263
deliver_after=None,
1234-
):
1264+
) -> _pulsar.MessageId:
12351265
"""
12361266
Publish a message on the topic. Blocks until the message is acknowledged
12371267
@@ -1264,6 +1294,10 @@ def send(self, content,
12641294
The timestamp is milliseconds and based on UTC
12651295
deliver_after: optional
12661296
Specify a delay in timedelta for the delivery of the messages.
1297+
1298+
Returns
1299+
----------
1300+
A `_pulsar.MessageId` object that represents where the message is persisted.
12671301
"""
12681302
msg = self._build_msg(content, properties, partition_key, ordering_key, sequence_id,
12691303
replication_clusters, disable_replication, event_timestamp,
@@ -1502,7 +1536,7 @@ def batch_receive(self):
15021536
messages.append(m)
15031537
return messages
15041538

1505-
def acknowledge(self, message):
1539+
def acknowledge(self, message: Union[Message, MessageId, _pulsar.Message, _pulsar.MessageId]):
15061540
"""
15071541
Acknowledge the reception of a single message.
15081542
@@ -1511,7 +1545,7 @@ def acknowledge(self, message):
15111545
15121546
Parameters
15131547
----------
1514-
message : Message, _pulsar.Message, _pulsar.MessageId
1548+
message : Message, MessageId, _pulsar.Message, _pulsar.MessageId
15151549
The received message or message id.
15161550
15171551
Raises
@@ -1521,10 +1555,12 @@ def acknowledge(self, message):
15211555
"""
15221556
if isinstance(message, Message):
15231557
self._consumer.acknowledge(message._message)
1558+
elif isinstance(message, MessageId):
1559+
self._consumer.acknowledge(message._msg_id)
15241560
else:
15251561
self._consumer.acknowledge(message)
15261562

1527-
def acknowledge_cumulative(self, message):
1563+
def acknowledge_cumulative(self, message: Union[Message, MessageId, _pulsar.Message, _pulsar.MessageId]):
15281564
"""
15291565
Acknowledge the reception of all the messages in the stream up to (and
15301566
including) the provided message.
@@ -1545,6 +1581,8 @@ def acknowledge_cumulative(self, message):
15451581
"""
15461582
if isinstance(message, Message):
15471583
self._consumer.acknowledge_cumulative(message._message)
1584+
elif isinstance(message, MessageId):
1585+
self._consumer.acknowledge_cumulative(message._msg_id)
15481586
else:
15491587
self._consumer.acknowledge_cumulative(message)
15501588

tests/pulsar_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,6 +1253,29 @@ def test_message_id(self):
12531253
s = MessageId.latest.serialize()
12541254
self.assertEqual(MessageId.deserialize(s), MessageId.latest)
12551255

1256+
client = Client(self.serviceUrl)
1257+
topic = f'test-message-id-compare-{str(time.time())}'
1258+
producer = client.create_producer(topic)
1259+
consumer = client.subscribe(topic, 'sub')
1260+
1261+
sent_ids = []
1262+
received_ids = []
1263+
for i in range(5):
1264+
sent_ids.append(MessageId.wrap(producer.send(b'msg-%d' % i)))
1265+
msg = consumer.receive(TM)
1266+
received_ids.append(MessageId.wrap(msg.message_id()))
1267+
self.assertEqual(sent_ids[i], received_ids[i])
1268+
consumer.acknowledge(received_ids[i])
1269+
consumer.acknowledge_cumulative(received_ids[4])
1270+
1271+
for i in range(4):
1272+
self.assertLess(sent_ids[i], sent_ids[i + 1])
1273+
self.assertLessEqual(sent_ids[i], sent_ids[i + 1])
1274+
self.assertGreater(sent_ids[i + 1], sent_ids[i])
1275+
self.assertGreaterEqual(sent_ids[i + 1], sent_ids[i])
1276+
self.assertNotEqual(sent_ids[i], sent_ids[i + 1])
1277+
client.close()
1278+
12561279
def test_get_topics_partitions(self):
12571280
client = Client(self.serviceUrl)
12581281
topic_partitioned = "persistent://public/default/test_get_topics_partitions"

0 commit comments

Comments
 (0)