1
+ from collections import namedtuple
1
2
from typing import Any , List , Optional , Tuple
2
3
import hashlib
3
4
import secrets
@@ -101,6 +102,23 @@ def pointc(x: bytes) -> Point:
101
102
return point_negate (P )
102
103
assert False
103
104
105
+ SessionContext = namedtuple ('SessionContext' , ['aggnonce' , 'pubkeys' , 'msg' ])
106
+
107
+ def get_session_values (session_ctx : SessionContext ) -> tuple [bytes , List [bytes ], bytes ]:
108
+ (aggnonce , pubkeys , msg ) = session_ctx
109
+ Q = key_agg_internal (pubkeys )
110
+ b = int_from_bytes (tagged_hash ('MuSig/noncecoef' , aggnonce + bytes_from_point (Q ) + msg )) % n
111
+ R_1 = pointc (aggnonce [0 :33 ])
112
+ R_2 = pointc (aggnonce [33 :66 ])
113
+ R = point_add (R_1 , point_mul (R_2 , b ))
114
+ assert not is_infinite (R )
115
+ e = int_from_bytes (tagged_hash ('BIP0340/challenge' , bytes_from_point (R ) + bytes_from_point (Q ) + msg )) % n
116
+ return (Q , b , R , e )
117
+
118
+ def get_session_key_agg_coeff (session_ctx : SessionContext , P : Point ) -> int :
119
+ (_ , pubkeys , _ ) = session_ctx
120
+ return key_agg_coeff (pubkeys , bytes_from_point (P ))
121
+
104
122
def key_agg (pubkeys : List [bytes ]) -> bytes :
105
123
Q = key_agg_internal (pubkeys )
106
124
return bytes_from_point (Q )
@@ -152,13 +170,8 @@ def nonce_agg(pubnonces: List[bytes]) -> bytes:
152
170
aggnonce += cbytes (R_i )
153
171
return aggnonce
154
172
155
- def sign (secnonce : bytes , sk : bytes , aggnonce : bytes , pubkeys : List [bytes ], msg : bytes ) -> bytes :
156
- R_1 = pointc (aggnonce [0 :33 ])
157
- R_2 = pointc (aggnonce [33 :66 ])
158
- Q = key_agg_internal (pubkeys )
159
- b = int_from_bytes (tagged_hash ('MuSig/noncecoef' , aggnonce + bytes_from_point (Q ) + msg )) % n
160
- R = point_add (R_1 , point_mul (R_2 , b ))
161
- assert not is_infinite (R )
173
+ def sign (secnonce : bytes , sk : bytes , session_ctx : SessionContext ) -> bytes :
174
+ (Q , b , R , e ) = get_session_values (session_ctx )
162
175
k_1_ = int_from_bytes (secnonce [0 :32 ])
163
176
k_2_ = int_from_bytes (secnonce [32 :64 ])
164
177
assert 0 < k_1_ < n
@@ -168,35 +181,30 @@ def sign(secnonce: bytes, sk: bytes, aggnonce: bytes, pubkeys: List[bytes], msg:
168
181
d_ = int_from_bytes (sk )
169
182
assert 0 < d_ < n
170
183
P = point_mul (G , d_ )
184
+ mu = get_session_key_agg_coeff (session_ctx , P )
171
185
d = n - d_ if has_even_y (P ) != has_even_y (Q ) else d_
172
- e = int_from_bytes (tagged_hash ('BIP0340/challenge' , bytes_from_point (R ) + bytes_from_point (Q ) + msg )) % n
173
- mu = key_agg_coeff (pubkeys , bytes_from_point (P ))
174
186
s = (k_1 + b * k_2 + e * mu * d ) % n
175
187
psig = bytes_from_int (s )
176
188
pubnonce = cbytes (point_mul (G , k_1_ )) + cbytes (point_mul (G , k_2_ ))
177
- assert partial_sig_verify_internal (psig , pubnonce , aggnonce , pubkeys , bytes_from_point (P ), msg )
189
+ assert partial_sig_verify_internal (psig , pubnonce , bytes_from_point (P ), session_ctx )
178
190
return psig
179
191
180
192
def partial_sig_verify (psig : bytes , pubnonces : List [bytes ], pubkeys : List [bytes ], msg : bytes , i : int ) -> bool :
181
193
aggnonce = nonce_agg (pubnonces )
182
- return partial_sig_verify_internal (psig , pubnonces [i ], aggnonce , pubkeys , pubkeys [i ], msg )
194
+ session_ctx = SessionContext (aggnonce , pubkeys , msg )
195
+ return partial_sig_verify_internal (psig , pubnonces [i ], pubkeys [i ], session_ctx )
183
196
184
- def partial_sig_verify_internal (psig : bytes , pubnonce : bytes , aggnonce : bytes , pubkeys : List [bytes ], pk : bytes , msg : bytes ) -> bool :
197
+ def partial_sig_verify_internal (psig : bytes , pubnonce : bytes , pk : bytes , session_ctx : SessionContext ) -> bool :
198
+ (Q , b , R , e ) = get_session_values (session_ctx )
185
199
s = int_from_bytes (psig )
186
200
assert s < n
187
- R_1 = pointc (aggnonce [0 :33 ])
188
- R_2 = pointc (aggnonce [33 :66 ])
189
- Q = key_agg_internal (pubkeys )
190
- b = int_from_bytes (tagged_hash ('MuSig/noncecoef' , aggnonce + bytes_from_point (Q ) + msg )) % n
191
- R = point_add (R_1 , point_mul (R_2 , b ))
192
201
R_1_ = pointc (pubnonce [0 :33 ])
193
202
R_2_ = pointc (pubnonce [33 :66 ])
194
203
R__ = point_add (R_1_ , point_mul (R_2_ , b ))
195
204
R_ = R__ if has_even_y (R ) else point_negate (R__ )
196
- e = int_from_bytes (tagged_hash ('BIP0340/challenge' , bytes_from_point (R ) + bytes_from_point (Q ) + msg )) % n
197
- mu = key_agg_coeff (pubkeys , pk )
198
205
P_ = lift_x (pk )
199
206
P = P_ if has_even_y (Q ) else point_negate (P_ )
207
+ mu = get_session_key_agg_coeff (session_ctx , P )
200
208
return point_mul (G , s ) == point_add (R_ , point_mul (P , e * mu % n ))
201
209
202
210
#
@@ -252,9 +260,14 @@ def test_sign_vectors():
252
260
253
261
pk = bytes_from_point (point_mul (G , int_from_bytes (sk )))
254
262
255
- assert sign (secnonce , sk , aggnonce , [pk , X [0 ], X [1 ]], msg ) == expected [0 ]
256
- assert sign (secnonce , sk , aggnonce , [X [0 ], pk , X [1 ]], msg ) == expected [1 ]
257
- assert sign (secnonce , sk , aggnonce , [X [0 ], X [1 ], pk ], msg ) == expected [2 ]
263
+ session_ctx = SessionContext (aggnonce , [pk , X [0 ], X [1 ]], msg )
264
+ assert sign (secnonce , sk , session_ctx ) == expected [0 ]
265
+
266
+ session_ctx = SessionContext (aggnonce , [X [0 ], pk , X [1 ]], msg )
267
+ assert sign (secnonce , sk , session_ctx ) == expected [1 ]
268
+
269
+ session_ctx = SessionContext (aggnonce , [X [0 ], X [1 ], pk ], msg )
270
+ assert sign (secnonce , sk , session_ctx ) == expected [2 ]
258
271
259
272
def test_sign_and_verify_random (iters ):
260
273
for i in range (iters ):
@@ -271,7 +284,8 @@ def test_sign_and_verify_random(iters):
271
284
272
285
msg = secrets .token_bytes (32 )
273
286
274
- psig = sign (secnonce_1 , sk_1 , aggnonce , pubkeys , msg )
287
+ session_ctx = SessionContext (aggnonce , pubkeys , msg )
288
+ psig = sign (secnonce_1 , sk_1 , session_ctx )
275
289
assert partial_sig_verify (psig , pubnonces , pubkeys , msg , 0 )
276
290
277
291
# Wrong signer index
0 commit comments