|
| 1 | +from typing import Any, List, Optional, Tuple |
| 2 | +import hashlib |
| 3 | +import secrets |
| 4 | + |
| 5 | +# |
| 6 | +# The following helper functions were copied from the BIP-340 reference implementation: |
| 7 | +# https://github.com/bitcoin/bips/blob/master/bip-0340/reference.py |
| 8 | +# |
| 9 | + |
| 10 | +p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F |
| 11 | +n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141 |
| 12 | + |
| 13 | +# Points are tuples of X and Y coordinates and the point at infinity is |
| 14 | +# represented by the None keyword. |
| 15 | +G = (0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798, 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8) |
| 16 | + |
| 17 | +Point = Tuple[int, int] |
| 18 | + |
| 19 | +# This implementation can be sped up by storing the midstate after hashing |
| 20 | +# tag_hash instead of rehashing it all the time. |
| 21 | +def tagged_hash(tag: str, msg: bytes) -> bytes: |
| 22 | + tag_hash = hashlib.sha256(tag.encode()).digest() |
| 23 | + return hashlib.sha256(tag_hash + tag_hash + msg).digest() |
| 24 | + |
| 25 | +def is_infinite(P: Optional[Point]) -> bool: |
| 26 | + return P is None |
| 27 | + |
| 28 | +def x(P: Point) -> int: |
| 29 | + assert not is_infinite(P) |
| 30 | + return P[0] |
| 31 | + |
| 32 | +def y(P: Point) -> int: |
| 33 | + assert not is_infinite(P) |
| 34 | + return P[1] |
| 35 | + |
| 36 | +def point_add(P1: Optional[Point], P2: Optional[Point]) -> Optional[Point]: |
| 37 | + if P1 is None: |
| 38 | + return P2 |
| 39 | + if P2 is None: |
| 40 | + return P1 |
| 41 | + if (x(P1) == x(P2)) and (y(P1) != y(P2)): |
| 42 | + return None |
| 43 | + if P1 == P2: |
| 44 | + lam = (3 * x(P1) * x(P1) * pow(2 * y(P1), p - 2, p)) % p |
| 45 | + else: |
| 46 | + lam = ((y(P2) - y(P1)) * pow(x(P2) - x(P1), p - 2, p)) % p |
| 47 | + x3 = (lam * lam - x(P1) - x(P2)) % p |
| 48 | + return (x3, (lam * (x(P1) - x3) - y(P1)) % p) |
| 49 | + |
| 50 | +def point_mul(P: Optional[Point], n: int) -> Optional[Point]: |
| 51 | + R = None |
| 52 | + for i in range(256): |
| 53 | + if (n >> i) & 1: |
| 54 | + R = point_add(R, P) |
| 55 | + P = point_add(P, P) |
| 56 | + return R |
| 57 | + |
| 58 | +def bytes_from_int(x: int) -> bytes: |
| 59 | + return x.to_bytes(32, byteorder="big") |
| 60 | + |
| 61 | +def bytes_from_point(P: Point) -> bytes: |
| 62 | + return bytes_from_int(x(P)) |
| 63 | + |
| 64 | +def lift_x(b: bytes) -> Optional[Point]: |
| 65 | + x = int_from_bytes(b) |
| 66 | + if x >= p: |
| 67 | + return None |
| 68 | + y_sq = (pow(x, 3, p) + 7) % p |
| 69 | + y = pow(y_sq, (p + 1) // 4, p) |
| 70 | + if pow(y, 2, p) != y_sq: |
| 71 | + return None |
| 72 | + return (x, y if y & 1 == 0 else p-y) |
| 73 | + |
| 74 | +def int_from_bytes(b: bytes) -> int: |
| 75 | + return int.from_bytes(b, byteorder="big") |
| 76 | + |
| 77 | +def has_even_y(P: Point) -> bool: |
| 78 | + assert not is_infinite(P) |
| 79 | + return y(P) % 2 == 0 |
| 80 | + |
| 81 | +# |
| 82 | +# End of helper functions copied from BIP-340 reference implementation. |
| 83 | +# |
| 84 | + |
| 85 | +def cbytes(P: Point) -> bytes: |
| 86 | + a = b'\x02' if has_even_y(P) else b'\x03' |
| 87 | + return a + bytes_from_point(P) |
| 88 | + |
| 89 | +def point_negate(P: Point) -> Point: |
| 90 | + if is_infinite(P): |
| 91 | + return P |
| 92 | + return (x(P), p - y(P)) |
| 93 | + |
| 94 | +def pointc(x: bytes) -> Point: |
| 95 | + P = lift_x(x[1:33]) |
| 96 | + if x[0] == 2: |
| 97 | + return P |
| 98 | + elif x[0] == 3: |
| 99 | + return point_negate(P) |
| 100 | + assert False |
| 101 | + |
| 102 | +def key_agg(pubkeys: List[bytes]) -> bytes: |
| 103 | + Q = key_agg_internal(pubkeys) |
| 104 | + return bytes_from_point(Q) |
| 105 | + |
| 106 | +def key_agg_internal(pubkeys: List[bytes]) -> Point: |
| 107 | + u = len(pubkeys) |
| 108 | + Q = None |
| 109 | + for i in range(u): |
| 110 | + a_i = key_agg_coeff(pubkeys, pubkeys[i]) |
| 111 | + P_i = lift_x(pubkeys[i]) |
| 112 | + Q = point_add(Q, point_mul(P_i, a_i)) |
| 113 | + assert not is_infinite(Q) |
| 114 | + return Q |
| 115 | + |
| 116 | +def hash_keys(pubkeys: List[bytes]) -> bytes: |
| 117 | + return tagged_hash('KeyAgg list', b''.join(pubkeys)) |
| 118 | + |
| 119 | +def is_second(pubkeys: List[bytes], pk: bytes) -> bool: |
| 120 | + u = len(pubkeys) |
| 121 | + for j in range(u): |
| 122 | + if pubkeys[j] != pubkeys[0]: |
| 123 | + return pubkeys[j] == pk |
| 124 | + return False |
| 125 | + |
| 126 | +def key_agg_coeff(pubkeys: List[bytes], pk: bytes) -> int: |
| 127 | + if is_second(pubkeys, pk): |
| 128 | + return 1 |
| 129 | + else: |
| 130 | + L = hash_keys(pubkeys) |
| 131 | + return int_from_bytes(tagged_hash('KeyAgg coefficient', L + pk)) % n |
| 132 | + |
| 133 | +def nonce_gen() -> Tuple[bytes, bytes]: |
| 134 | + k_1 = 1 + secrets.randbelow(n - 2) |
| 135 | + k_2 = 1 + secrets.randbelow(n - 2) |
| 136 | + R_1 = point_mul(G, k_1) |
| 137 | + R_2 = point_mul(G, k_2) |
| 138 | + pubnonce = cbytes(R_1) + cbytes(R_2) |
| 139 | + secnonce = bytes_from_int(k_1) + bytes_from_int(k_2) |
| 140 | + return secnonce, pubnonce |
| 141 | + |
| 142 | +def nonce_agg(pubnonces: List[bytes]) -> bytes: |
| 143 | + u = len(pubnonces) |
| 144 | + aggnonce = b'' |
| 145 | + for i in (1, 2): |
| 146 | + R_i_ = None |
| 147 | + for j in range(u): |
| 148 | + R_i_ = point_add(R_i_, pointc(pubnonces[j][(i-1)*33:i*33])) |
| 149 | + R_i = R_i_ if not is_infinite(R_i_) else G |
| 150 | + aggnonce += cbytes(R_i) |
| 151 | + return aggnonce |
| 152 | + |
| 153 | +def sign(secnonce: bytes, sk: bytes, aggnonce: bytes, pubkeys: List[bytes], msg: bytes) -> bytes: |
| 154 | + R_1 = pointc(aggnonce[0:33]) |
| 155 | + R_2 = pointc(aggnonce[33:66]) |
| 156 | + Q = key_agg_internal(pubkeys) |
| 157 | + b = int_from_bytes(tagged_hash('MuSig/noncecoef', aggnonce + bytes_from_point(Q) + msg)) % n |
| 158 | + R = point_add(R_1, point_mul(R_2, b)) |
| 159 | + assert not is_infinite(R) |
| 160 | + k_1_ = int_from_bytes(secnonce[0:32]) |
| 161 | + k_2_ = int_from_bytes(secnonce[32:64]) |
| 162 | + assert 0 < k_1_ < n |
| 163 | + assert 0 < k_2_ < n |
| 164 | + k_1 = k_1_ if has_even_y(R) else n - k_1_ |
| 165 | + k_2 = k_2_ if has_even_y(R) else n - k_2_ |
| 166 | + d_ = int_from_bytes(sk) |
| 167 | + assert 0 < d_ < n |
| 168 | + P = point_mul(G, d_) |
| 169 | + d = n - d_ if has_even_y(P) != has_even_y(Q) else d_ |
| 170 | + e = int_from_bytes(tagged_hash('BIP0340/challenge', bytes_from_point(R) + bytes_from_point(Q) + msg)) % n |
| 171 | + mu = key_agg_coeff(pubkeys, bytes_from_point(P)) |
| 172 | + s = (k_1 + b * k_2 + e * mu * d) % n |
| 173 | + psig = bytes_from_int(s) |
| 174 | + pubnonce = cbytes(point_mul(G, k_1_)) + cbytes(point_mul(G, k_2_)) |
| 175 | + assert partial_sig_verify_internal(psig, pubnonce, aggnonce, pubkeys, bytes_from_point(P), msg) |
| 176 | + return psig |
| 177 | + |
| 178 | +def partial_sig_verify(psig: bytes, pubnonces: List[bytes], pubkeys: List[bytes], msg: bytes, i: int) -> bool: |
| 179 | + aggnonce = nonce_agg(pubnonces) |
| 180 | + return partial_sig_verify_internal(psig, pubnonces[i], aggnonce, pubkeys, pubkeys[i], msg) |
| 181 | + |
| 182 | +def partial_sig_verify_internal(psig: bytes, pubnonce: bytes, aggnonce: bytes, pubkeys: List[bytes], pk: bytes, msg: bytes) -> bool: |
| 183 | + s = int_from_bytes(psig) |
| 184 | + assert s < n |
| 185 | + R_1 = pointc(aggnonce[0:33]) |
| 186 | + R_2 = pointc(aggnonce[33:66]) |
| 187 | + Q = key_agg_internal(pubkeys) |
| 188 | + b = int_from_bytes(tagged_hash('MuSig/noncecoef', aggnonce + bytes_from_point(Q) + msg)) % n |
| 189 | + R = point_add(R_1, point_mul(R_2, b)) |
| 190 | + R_1_ = pointc(pubnonce[0:33]) |
| 191 | + R_2_ = pointc(pubnonce[33:66]) |
| 192 | + R__ = point_add(R_1_, point_mul(R_2_, b)) |
| 193 | + R_ = R__ if has_even_y(R) else point_negate(R__) |
| 194 | + e = int_from_bytes(tagged_hash('BIP0340/challenge', bytes_from_point(R) + bytes_from_point(Q) + msg)) % n |
| 195 | + mu = key_agg_coeff(pubkeys, pk) |
| 196 | + P_ = lift_x(pk) |
| 197 | + P = P_ if has_even_y(Q) else point_negate(P_) |
| 198 | + return point_mul(G, s) == point_add(R_, point_mul(P, e * mu % n)) |
| 199 | + |
| 200 | +# |
| 201 | +# The following code is only used for testing. |
| 202 | +# Test vectors were copied from libsecp256k1-zkp's MuSig test file. |
| 203 | +# See `musig_test_vectors_keyagg` and `musig_test_vectors_sign` in |
| 204 | +# https://github.com/ElementsProject/secp256k1-zkp/blob/master/src/modules/musig/tests_impl.h |
| 205 | +# |
| 206 | +def fromhex_all(l): |
| 207 | + return [bytes.fromhex(l_i) for l_i in l] |
| 208 | + |
| 209 | +def test_key_agg_vectors(): |
| 210 | + X = fromhex_all([ |
| 211 | + 'F9308A019258C31049344F85F89D5229B531C845836F99B08601F113BCE036F9', |
| 212 | + 'DFF1D77F2A671C5F36183726DB2341BE58FEAE1DA2DECED843240F7B502BA659', |
| 213 | + '3590A94E768F8E1815C2F24B4D80A8E3149316C3518CE7B7AD338368D038CA66', |
| 214 | + ]) |
| 215 | + |
| 216 | + expected = fromhex_all([ |
| 217 | + 'E5830140512195D74C8307E39637CBE5FB730EBEAB80EC514CF88A877CEEEE0B', |
| 218 | + 'D70CD69A2647F7390973DF48CBFA2CCC407B8B2D60B08C5F1641185C7998A290', |
| 219 | + '81A8B093912C9E481408D09776CEFB48AEB8B65481B6BAAFB3C5810106717BEB', |
| 220 | + '2EB18851887E7BDC5E830E89B19DDBC28078F1FA88AAD0AD01CA06FE4F80210B', |
| 221 | + ]) |
| 222 | + |
| 223 | + assert key_agg([X[0], X[1], X[2]]) == expected[0] |
| 224 | + assert key_agg([X[2], X[1], X[0]]) == expected[1] |
| 225 | + assert key_agg([X[0], X[0], X[0]]) == expected[2] |
| 226 | + assert key_agg([X[0], X[0], X[1], X[1]]) == expected[3] |
| 227 | + |
| 228 | +def test_sign_vectors(): |
| 229 | + X = fromhex_all([ |
| 230 | + 'F9308A019258C31049344F85F89D5229B531C845836F99B08601F113BCE036F9', |
| 231 | + 'DFF1D77F2A671C5F36183726DB2341BE58FEAE1DA2DECED843240F7B502BA659', |
| 232 | + ]) |
| 233 | + |
| 234 | + secnonce = bytes.fromhex( |
| 235 | + '508B81A611F100A6B2B6B29656590898AF488BCF2E1F55CF22E5CFB84421FE61' + |
| 236 | + 'FA27FD49B1D50085B481285E1CA205D55C82CC1B31FF5CD54A489829355901F7') |
| 237 | + |
| 238 | + aggnonce = bytes.fromhex( |
| 239 | + '028465FCF0BBDBCF443AABCCE533D42B4B5A10966AC09A49655E8C42DAAB8FCD61' + |
| 240 | + '037496A3CC86926D452CAFCFD55D25972CA1675D549310DE296BFF42F72EEEA8C9') |
| 241 | + |
| 242 | + sk = bytes.fromhex('7FB9E0E687ADA1EEBF7ECFE2F21E73EBDB51A7D450948DFE8D76D7F2D1007671') |
| 243 | + msg = bytes.fromhex('F95466D086770E689964664219266FE5ED215C92AE20BAB5C9D79ADDDDF3C0CF') |
| 244 | + |
| 245 | + expected = fromhex_all([ |
| 246 | + '68537CC5234E505BD14061F8DA9E90C220A181855FD8BDB7F127BB12403B4D3B', |
| 247 | + '2DF67BFFF18E3DE797E13C6475C963048138DAEC5CB20A357CECA7C8424295EA', |
| 248 | + '0D5B651E6DE34A29A12DE7A8B4183B4AE6A7F7FBE15CDCAFA4A3D1BCAABC7517', |
| 249 | + ]) |
| 250 | + |
| 251 | + pk = bytes_from_point(point_mul(G, int_from_bytes(sk))) |
| 252 | + |
| 253 | + assert sign(secnonce, sk, aggnonce, [pk, X[0], X[1]], msg) == expected[0] |
| 254 | + assert sign(secnonce, sk, aggnonce, [X[0], pk, X[1]], msg) == expected[1] |
| 255 | + assert sign(secnonce, sk, aggnonce, [X[0], X[1], pk], msg) == expected[2] |
| 256 | + |
| 257 | +def test_sign_and_verify_random(iters): |
| 258 | + for i in range(iters): |
| 259 | + sk_1 = secrets.token_bytes(32) |
| 260 | + sk_2 = secrets.token_bytes(32) |
| 261 | + pk_1 = bytes_from_point(point_mul(G, int_from_bytes(sk_1))) |
| 262 | + pk_2 = bytes_from_point(point_mul(G, int_from_bytes(sk_2))) |
| 263 | + pubkeys = [pk_1, pk_2] |
| 264 | + |
| 265 | + secnonce_1, pubnonce_1 = nonce_gen() |
| 266 | + secnonce_2, pubnonce_2 = nonce_gen() |
| 267 | + pubnonces = [pubnonce_1, pubnonce_2] |
| 268 | + aggnonce = nonce_agg(pubnonces) |
| 269 | + |
| 270 | + msg = secrets.token_bytes(32) |
| 271 | + |
| 272 | + psig = sign(secnonce_1, sk_1, aggnonce, pubkeys, msg) |
| 273 | + assert partial_sig_verify(psig, pubnonces, pubkeys, msg, 0) |
| 274 | + |
| 275 | + # Wrong signer index |
| 276 | + assert not partial_sig_verify(psig, pubnonces, pubkeys, msg, 1) |
| 277 | + |
| 278 | + # Wrong message |
| 279 | + assert not partial_sig_verify(psig, pubnonces, pubkeys, secrets.token_bytes(32), 0) |
| 280 | + |
| 281 | +if __name__ == '__main__': |
| 282 | + test_key_agg_vectors() |
| 283 | + test_sign_vectors() |
| 284 | + test_sign_and_verify_random(4) |
0 commit comments