Skip to content

Commit 9265429

Browse files
committed
musig-spec: Add Session Context to reference implementation
1 parent 01f62b2 commit 9265429

File tree

1 file changed

+37
-23
lines changed

1 file changed

+37
-23
lines changed

doc/musig-reference.py

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections import namedtuple
12
from typing import Any, List, Optional, Tuple
23
import hashlib
34
import secrets
@@ -101,6 +102,23 @@ def pointc(x: bytes) -> Point:
101102
return point_negate(P)
102103
assert False
103104

105+
SessionContext = namedtuple('SessionContext', ['aggnonce', 'pubkeys', 'msg'])
106+
107+
def get_session_values(session_ctx: SessionContext) -> tuple[bytes, List[bytes], bytes]:
108+
(aggnonce, pubkeys, msg) = session_ctx
109+
Q = key_agg_internal(pubkeys)
110+
b = int_from_bytes(tagged_hash('MuSig/noncecoef', aggnonce + bytes_from_point(Q) + msg)) % n
111+
R_1 = pointc(aggnonce[0:33])
112+
R_2 = pointc(aggnonce[33:66])
113+
R = point_add(R_1, point_mul(R_2, b))
114+
assert not is_infinite(R)
115+
e = int_from_bytes(tagged_hash('BIP0340/challenge', bytes_from_point(R) + bytes_from_point(Q) + msg)) % n
116+
return (Q, b, R, e)
117+
118+
def get_session_key_agg_coeff(session_ctx: SessionContext, P: Point) -> int:
119+
(_, pubkeys, _) = session_ctx
120+
return key_agg_coeff(pubkeys, bytes_from_point(P))
121+
104122
def key_agg(pubkeys: List[bytes]) -> bytes:
105123
Q = key_agg_internal(pubkeys)
106124
return bytes_from_point(Q)
@@ -152,13 +170,8 @@ def nonce_agg(pubnonces: List[bytes]) -> bytes:
152170
aggnonce += cbytes(R_i)
153171
return aggnonce
154172

155-
def sign(secnonce: bytes, sk: bytes, aggnonce: bytes, pubkeys: List[bytes], msg: bytes) -> bytes:
156-
R_1 = pointc(aggnonce[0:33])
157-
R_2 = pointc(aggnonce[33:66])
158-
Q = key_agg_internal(pubkeys)
159-
b = int_from_bytes(tagged_hash('MuSig/noncecoef', aggnonce + bytes_from_point(Q) + msg)) % n
160-
R = point_add(R_1, point_mul(R_2, b))
161-
assert not is_infinite(R)
173+
def sign(secnonce: bytes, sk: bytes, session_ctx: SessionContext) -> bytes:
174+
(Q, b, R, e) = get_session_values(session_ctx)
162175
k_1_ = int_from_bytes(secnonce[0:32])
163176
k_2_ = int_from_bytes(secnonce[32:64])
164177
assert 0 < k_1_ < n
@@ -168,35 +181,30 @@ def sign(secnonce: bytes, sk: bytes, aggnonce: bytes, pubkeys: List[bytes], msg:
168181
d_ = int_from_bytes(sk)
169182
assert 0 < d_ < n
170183
P = point_mul(G, d_)
184+
mu = get_session_key_agg_coeff(session_ctx, P)
171185
d = n - d_ if has_even_y(P) != has_even_y(Q) else d_
172-
e = int_from_bytes(tagged_hash('BIP0340/challenge', bytes_from_point(R) + bytes_from_point(Q) + msg)) % n
173-
mu = key_agg_coeff(pubkeys, bytes_from_point(P))
174186
s = (k_1 + b * k_2 + e * mu * d) % n
175187
psig = bytes_from_int(s)
176188
pubnonce = cbytes(point_mul(G, k_1_)) + cbytes(point_mul(G, k_2_))
177-
assert partial_sig_verify_internal(psig, pubnonce, aggnonce, pubkeys, bytes_from_point(P), msg)
189+
assert partial_sig_verify_internal(psig, pubnonce, bytes_from_point(P), session_ctx)
178190
return psig
179191

180192
def partial_sig_verify(psig: bytes, pubnonces: List[bytes], pubkeys: List[bytes], msg: bytes, i: int) -> bool:
181193
aggnonce = nonce_agg(pubnonces)
182-
return partial_sig_verify_internal(psig, pubnonces[i], aggnonce, pubkeys, pubkeys[i], msg)
194+
session_ctx = SessionContext(aggnonce, pubkeys, msg)
195+
return partial_sig_verify_internal(psig, pubnonces[i], pubkeys[i], session_ctx)
183196

184-
def partial_sig_verify_internal(psig: bytes, pubnonce: bytes, aggnonce: bytes, pubkeys: List[bytes], pk: bytes, msg: bytes) -> bool:
197+
def partial_sig_verify_internal(psig: bytes, pubnonce: bytes, pk: bytes, session_ctx: SessionContext) -> bool:
198+
(Q, b, R, e) = get_session_values(session_ctx)
185199
s = int_from_bytes(psig)
186200
assert s < n
187-
R_1 = pointc(aggnonce[0:33])
188-
R_2 = pointc(aggnonce[33:66])
189-
Q = key_agg_internal(pubkeys)
190-
b = int_from_bytes(tagged_hash('MuSig/noncecoef', aggnonce + bytes_from_point(Q) + msg)) % n
191-
R = point_add(R_1, point_mul(R_2, b))
192201
R_1_ = pointc(pubnonce[0:33])
193202
R_2_ = pointc(pubnonce[33:66])
194203
R__ = point_add(R_1_, point_mul(R_2_, b))
195204
R_ = R__ if has_even_y(R) else point_negate(R__)
196-
e = int_from_bytes(tagged_hash('BIP0340/challenge', bytes_from_point(R) + bytes_from_point(Q) + msg)) % n
197-
mu = key_agg_coeff(pubkeys, pk)
198205
P_ = lift_x(pk)
199206
P = P_ if has_even_y(Q) else point_negate(P_)
207+
mu = get_session_key_agg_coeff(session_ctx, P)
200208
return point_mul(G, s) == point_add(R_, point_mul(P, e * mu % n))
201209

202210
#
@@ -252,9 +260,14 @@ def test_sign_vectors():
252260

253261
pk = bytes_from_point(point_mul(G, int_from_bytes(sk)))
254262

255-
assert sign(secnonce, sk, aggnonce, [pk, X[0], X[1]], msg) == expected[0]
256-
assert sign(secnonce, sk, aggnonce, [X[0], pk, X[1]], msg) == expected[1]
257-
assert sign(secnonce, sk, aggnonce, [X[0], X[1], pk], msg) == expected[2]
263+
session_ctx = SessionContext(aggnonce, [pk, X[0], X[1]], msg)
264+
assert sign(secnonce, sk, session_ctx) == expected[0]
265+
266+
session_ctx = SessionContext(aggnonce, [X[0], pk, X[1]], msg)
267+
assert sign(secnonce, sk, session_ctx) == expected[1]
268+
269+
session_ctx = SessionContext(aggnonce, [X[0], X[1], pk], msg)
270+
assert sign(secnonce, sk, session_ctx) == expected[2]
258271

259272
def test_sign_and_verify_random(iters):
260273
for i in range(iters):
@@ -271,7 +284,8 @@ def test_sign_and_verify_random(iters):
271284

272285
msg = secrets.token_bytes(32)
273286

274-
psig = sign(secnonce_1, sk_1, aggnonce, pubkeys, msg)
287+
session_ctx = SessionContext(aggnonce, pubkeys, msg)
288+
psig = sign(secnonce_1, sk_1, session_ctx)
275289
assert partial_sig_verify(psig, pubnonces, pubkeys, msg, 0)
276290

277291
# Wrong signer index

0 commit comments

Comments
 (0)