Skip to content

Commit 526542c

Browse files
committed
musig-spec: Add naive Python reference implementation
1 parent 73f0cbd commit 526542c

File tree

2 files changed

+287
-0
lines changed

2 files changed

+287
-0
lines changed

doc/musig-reference.py

Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
from typing import Any, List, Optional, Tuple
2+
import hashlib
3+
import secrets
4+
5+
#
6+
# The following helper functions were copied from the BIP-340 reference implementation:
7+
# https://github.com/bitcoin/bips/blob/master/bip-0340/reference.py
8+
#
9+
10+
p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
11+
n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
12+
13+
# Points are tuples of X and Y coordinates and the point at infinity is
14+
# represented by the None keyword.
15+
G = (0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798, 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8)
16+
17+
Point = Tuple[int, int]
18+
19+
# This implementation can be sped up by storing the midstate after hashing
20+
# tag_hash instead of rehashing it all the time.
21+
def tagged_hash(tag: str, msg: bytes) -> bytes:
22+
tag_hash = hashlib.sha256(tag.encode()).digest()
23+
return hashlib.sha256(tag_hash + tag_hash + msg).digest()
24+
25+
def is_infinite(P: Optional[Point]) -> bool:
26+
return P is None
27+
28+
def x(P: Point) -> int:
29+
assert not is_infinite(P)
30+
return P[0]
31+
32+
def y(P: Point) -> int:
33+
assert not is_infinite(P)
34+
return P[1]
35+
36+
def point_add(P1: Optional[Point], P2: Optional[Point]) -> Optional[Point]:
37+
if P1 is None:
38+
return P2
39+
if P2 is None:
40+
return P1
41+
if (x(P1) == x(P2)) and (y(P1) != y(P2)):
42+
return None
43+
if P1 == P2:
44+
lam = (3 * x(P1) * x(P1) * pow(2 * y(P1), p - 2, p)) % p
45+
else:
46+
lam = ((y(P2) - y(P1)) * pow(x(P2) - x(P1), p - 2, p)) % p
47+
x3 = (lam * lam - x(P1) - x(P2)) % p
48+
return (x3, (lam * (x(P1) - x3) - y(P1)) % p)
49+
50+
def point_mul(P: Optional[Point], n: int) -> Optional[Point]:
51+
R = None
52+
for i in range(256):
53+
if (n >> i) & 1:
54+
R = point_add(R, P)
55+
P = point_add(P, P)
56+
return R
57+
58+
def bytes_from_int(x: int) -> bytes:
59+
return x.to_bytes(32, byteorder="big")
60+
61+
def bytes_from_point(P: Point) -> bytes:
62+
return bytes_from_int(x(P))
63+
64+
def lift_x(b: bytes) -> Optional[Point]:
65+
x = int_from_bytes(b)
66+
if x >= p:
67+
return None
68+
y_sq = (pow(x, 3, p) + 7) % p
69+
y = pow(y_sq, (p + 1) // 4, p)
70+
if pow(y, 2, p) != y_sq:
71+
return None
72+
return (x, y if y & 1 == 0 else p-y)
73+
74+
def int_from_bytes(b: bytes) -> int:
75+
return int.from_bytes(b, byteorder="big")
76+
77+
def has_even_y(P: Point) -> bool:
78+
assert not is_infinite(P)
79+
return y(P) % 2 == 0
80+
81+
#
82+
# End of helper functions copied from BIP-340 reference implementation.
83+
#
84+
85+
def cbytes(P: Point) -> bytes:
86+
a = b'\x02' if has_even_y(P) else b'\x03'
87+
return a + bytes_from_point(P)
88+
89+
def point_negate(P: Point) -> Point:
90+
if is_infinite(P):
91+
return P
92+
return (x(P), p - y(P))
93+
94+
def pointc(x: bytes) -> Point:
95+
P = lift_x(x[1:33])
96+
if x[0] == 2:
97+
return P
98+
elif x[0] == 3:
99+
return point_negate(P)
100+
assert False
101+
102+
def key_agg(pubkeys: List[bytes]) -> bytes:
103+
Q = key_agg_internal(pubkeys)
104+
return bytes_from_point(Q)
105+
106+
def key_agg_internal(pubkeys: List[bytes]) -> Point:
107+
u = len(pubkeys)
108+
Q = None
109+
for i in range(u):
110+
a_i = key_agg_coeff(pubkeys, pubkeys[i])
111+
P_i = lift_x(pubkeys[i])
112+
Q = point_add(Q, point_mul(P_i, a_i))
113+
assert not is_infinite(Q)
114+
return Q
115+
116+
def hash_keys(pubkeys: List[bytes]) -> bytes:
117+
return tagged_hash('KeyAgg list', b''.join(pubkeys))
118+
119+
def is_second(pubkeys: List[bytes], pk: bytes) -> bool:
120+
u = len(pubkeys)
121+
for j in range(u):
122+
if pubkeys[j] != pubkeys[0]:
123+
return pubkeys[j] == pk
124+
return False
125+
126+
def key_agg_coeff(pubkeys: List[bytes], pk: bytes) -> int:
127+
if is_second(pubkeys, pk):
128+
return 1
129+
else:
130+
L = hash_keys(pubkeys)
131+
return int_from_bytes(tagged_hash('KeyAgg coefficient', L + pk)) % n
132+
133+
def nonce_gen() -> Tuple[bytes, bytes]:
134+
k_1 = 1 + secrets.randbelow(n - 2)
135+
k_2 = 1 + secrets.randbelow(n - 2)
136+
R_1 = point_mul(G, k_1)
137+
R_2 = point_mul(G, k_2)
138+
pubnonce = cbytes(R_1) + cbytes(R_2)
139+
secnonce = bytes_from_int(k_1) + bytes_from_int(k_2)
140+
return secnonce, pubnonce
141+
142+
def nonce_agg(pubnonces: List[bytes]) -> bytes:
143+
u = len(pubnonces)
144+
aggnonce = b''
145+
for i in (1, 2):
146+
R_i_ = None
147+
for j in range(u):
148+
R_i_ = point_add(R_i_, pointc(pubnonces[j][(i-1)*33:i*33]))
149+
R_i = R_i_ if not is_infinite(R_i_) else G
150+
aggnonce += cbytes(R_i)
151+
return aggnonce
152+
153+
def sign(secnonce: bytes, sk: bytes, aggnonce: bytes, pubkeys: List[bytes], msg: bytes) -> bytes:
154+
R_1 = pointc(aggnonce[0:33])
155+
R_2 = pointc(aggnonce[33:66])
156+
Q = key_agg_internal(pubkeys)
157+
b = int_from_bytes(tagged_hash('MuSig/noncecoef', aggnonce + bytes_from_point(Q) + msg)) % n
158+
R = point_add(R_1, point_mul(R_2, b))
159+
assert not is_infinite(R)
160+
k_1_ = int_from_bytes(secnonce[0:32])
161+
k_2_ = int_from_bytes(secnonce[32:64])
162+
assert 0 < k_1_ < n
163+
assert 0 < k_2_ < n
164+
k_1 = k_1_ if has_even_y(R) else n - k_1_
165+
k_2 = k_2_ if has_even_y(R) else n - k_2_
166+
d_ = int_from_bytes(sk)
167+
assert 0 < d_ < n
168+
P = point_mul(G, d_)
169+
d = n - d_ if has_even_y(P) != has_even_y(Q) else d_
170+
e = int_from_bytes(tagged_hash('BIP0340/challenge', bytes_from_point(R) + bytes_from_point(Q) + msg)) % n
171+
mu = key_agg_coeff(pubkeys, bytes_from_point(P))
172+
s = (k_1 + b * k_2 + e * mu * d) % n
173+
psig = bytes_from_int(s)
174+
pubnonce = cbytes(point_mul(G, k_1_)) + cbytes(point_mul(G, k_2_))
175+
assert partial_sig_verify_internal(psig, pubnonce, aggnonce, pubkeys, bytes_from_point(P), msg)
176+
return psig
177+
178+
def partial_sig_verify(psig: bytes, pubnonces: List[bytes], pubkeys: List[bytes], msg: bytes, i: int) -> bool:
179+
aggnonce = nonce_agg(pubnonces)
180+
return partial_sig_verify_internal(psig, pubnonces[i], aggnonce, pubkeys, pubkeys[i], msg)
181+
182+
def partial_sig_verify_internal(psig: bytes, pubnonce: bytes, aggnonce: bytes, pubkeys: List[bytes], pk: bytes, msg: bytes) -> bool:
183+
s = int_from_bytes(psig)
184+
assert s < n
185+
R_1 = pointc(aggnonce[0:33])
186+
R_2 = pointc(aggnonce[33:66])
187+
Q = key_agg_internal(pubkeys)
188+
b = int_from_bytes(tagged_hash('MuSig/noncecoef', aggnonce + bytes_from_point(Q) + msg)) % n
189+
R = point_add(R_1, point_mul(R_2, b))
190+
R_1_ = pointc(pubnonce[0:33])
191+
R_2_ = pointc(pubnonce[33:66])
192+
R__ = point_add(R_1_, point_mul(R_2_, b))
193+
R_ = R__ if has_even_y(R) else point_negate(R__)
194+
e = int_from_bytes(tagged_hash('BIP0340/challenge', bytes_from_point(R) + bytes_from_point(Q) + msg)) % n
195+
mu = key_agg_coeff(pubkeys, pk)
196+
P_ = lift_x(pk)
197+
P = P_ if has_even_y(Q) else point_negate(P_)
198+
return point_mul(G, s) == point_add(R_, point_mul(P, e * mu % n))
199+
200+
#
201+
# The following code is only used for testing.
202+
# Test vectors were copied from libsecp256k1-zkp's MuSig test file.
203+
# See `musig_test_vectors_keyagg` and `musig_test_vectors_sign` in
204+
# https://github.com/ElementsProject/secp256k1-zkp/blob/master/src/modules/musig/tests_impl.h
205+
#
206+
def fromhex_all(l):
207+
return [bytes.fromhex(l_i) for l_i in l]
208+
209+
def test_key_agg_vectors():
210+
X = fromhex_all([
211+
'F9308A019258C31049344F85F89D5229B531C845836F99B08601F113BCE036F9',
212+
'DFF1D77F2A671C5F36183726DB2341BE58FEAE1DA2DECED843240F7B502BA659',
213+
'3590A94E768F8E1815C2F24B4D80A8E3149316C3518CE7B7AD338368D038CA66',
214+
])
215+
216+
expected = fromhex_all([
217+
'E5830140512195D74C8307E39637CBE5FB730EBEAB80EC514CF88A877CEEEE0B',
218+
'D70CD69A2647F7390973DF48CBFA2CCC407B8B2D60B08C5F1641185C7998A290',
219+
'81A8B093912C9E481408D09776CEFB48AEB8B65481B6BAAFB3C5810106717BEB',
220+
'2EB18851887E7BDC5E830E89B19DDBC28078F1FA88AAD0AD01CA06FE4F80210B',
221+
])
222+
223+
assert key_agg([X[0], X[1], X[2]]) == expected[0]
224+
assert key_agg([X[2], X[1], X[0]]) == expected[1]
225+
assert key_agg([X[0], X[0], X[0]]) == expected[2]
226+
assert key_agg([X[0], X[0], X[1], X[1]]) == expected[3]
227+
228+
def test_sign_vectors():
229+
X = fromhex_all([
230+
'F9308A019258C31049344F85F89D5229B531C845836F99B08601F113BCE036F9',
231+
'DFF1D77F2A671C5F36183726DB2341BE58FEAE1DA2DECED843240F7B502BA659',
232+
])
233+
234+
secnonce = bytes.fromhex(
235+
'508B81A611F100A6B2B6B29656590898AF488BCF2E1F55CF22E5CFB84421FE61' +
236+
'FA27FD49B1D50085B481285E1CA205D55C82CC1B31FF5CD54A489829355901F7')
237+
238+
aggnonce = bytes.fromhex(
239+
'028465FCF0BBDBCF443AABCCE533D42B4B5A10966AC09A49655E8C42DAAB8FCD61' +
240+
'037496A3CC86926D452CAFCFD55D25972CA1675D549310DE296BFF42F72EEEA8C9')
241+
242+
sk = bytes.fromhex('7FB9E0E687ADA1EEBF7ECFE2F21E73EBDB51A7D450948DFE8D76D7F2D1007671')
243+
msg = bytes.fromhex('F95466D086770E689964664219266FE5ED215C92AE20BAB5C9D79ADDDDF3C0CF')
244+
245+
expected = fromhex_all([
246+
'68537CC5234E505BD14061F8DA9E90C220A181855FD8BDB7F127BB12403B4D3B',
247+
'2DF67BFFF18E3DE797E13C6475C963048138DAEC5CB20A357CECA7C8424295EA',
248+
'0D5B651E6DE34A29A12DE7A8B4183B4AE6A7F7FBE15CDCAFA4A3D1BCAABC7517',
249+
])
250+
251+
pk = bytes_from_point(point_mul(G, int_from_bytes(sk)))
252+
253+
assert sign(secnonce, sk, aggnonce, [pk, X[0], X[1]], msg) == expected[0]
254+
assert sign(secnonce, sk, aggnonce, [X[0], pk, X[1]], msg) == expected[1]
255+
assert sign(secnonce, sk, aggnonce, [X[0], X[1], pk], msg) == expected[2]
256+
257+
def test_sign_and_verify_random(iters):
258+
for i in range(iters):
259+
sk_1 = secrets.token_bytes(32)
260+
sk_2 = secrets.token_bytes(32)
261+
pk_1 = bytes_from_point(point_mul(G, int_from_bytes(sk_1)))
262+
pk_2 = bytes_from_point(point_mul(G, int_from_bytes(sk_2)))
263+
pubkeys = [pk_1, pk_2]
264+
265+
secnonce_1, pubnonce_1 = nonce_gen()
266+
secnonce_2, pubnonce_2 = nonce_gen()
267+
pubnonces = [pubnonce_1, pubnonce_2]
268+
aggnonce = nonce_agg(pubnonces)
269+
270+
msg = secrets.token_bytes(32)
271+
272+
psig = sign(secnonce_1, sk_1, aggnonce, pubkeys, msg)
273+
assert partial_sig_verify(psig, pubnonces, pubkeys, msg, 0)
274+
275+
# Wrong signer index
276+
assert not partial_sig_verify(psig, pubnonces, pubkeys, msg, 1)
277+
278+
# Wrong message
279+
assert not partial_sig_verify(psig, pubnonces, pubkeys, secrets.token_bytes(32), 0)
280+
281+
if __name__ == '__main__':
282+
test_key_agg_vectors()
283+
test_sign_vectors()
284+
test_sign_and_verify_random(4)

doc/musig-spec.mediawiki

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,9 @@ Avoiding reuse also implies that the ''NonceGen'' algorithm must compute unbiase
248248
There are some vectors in libsecp256k1's [https://github.com/ElementsProject/secp256k1-zkp/blob/master/src/modules/musig/tests_impl.h MuSig test file].
249249
Search for the ''musig_test_vectors_keyagg'' and ''musig_test_vectors_sign'' functions.
250250

251+
We provide a naive, highly inefficient, and non-constant time [[musig-reference.py|pure Python 3.7 reference implementation of the key aggregation, partial signing, and partial signature verification algorithms]].
252+
The reference implementation is for demonstration purposes only and not to be used in production environments.
253+
251254
== Footnotes ==
252255

253256
<references />

0 commit comments

Comments
 (0)