Skip to content

Commit cc9f91c

Browse files
committed
XMSS/XMSS^MT implementation according to draft-irtf-cfrg-xmss-hash-based-signatures-09 including BDS algorithm for efficient auth path computation
1 parent b94ff84 commit cc9f91c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+8090
-1
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,366 @@
1+
package org.bouncycastle.pqc.crypto.xmss;
2+
3+
import java.io.Serializable;
4+
import java.util.ArrayDeque;
5+
import java.util.ArrayList;
6+
import java.util.List;
7+
import java.util.Map;
8+
import java.util.Stack;
9+
import java.util.TreeMap;
10+
11+
/**
12+
* BDS.
13+
*
14+
* @author Sebastian Roland <[email protected]>
15+
*/
16+
public class BDS implements Serializable {
17+
18+
private static final long serialVersionUID = 1L;
19+
20+
private class TreeHash implements Serializable {
21+
22+
private static final long serialVersionUID = 1L;
23+
24+
private XMSSNode tailNode;
25+
private final int initialHeight;
26+
private int height;
27+
private int nextIndex;
28+
private boolean initialized;
29+
private boolean finished;
30+
31+
private TreeHash(int initialHeight) {
32+
super();
33+
this.initialHeight = initialHeight;
34+
initialized = false;
35+
finished = false;
36+
}
37+
38+
private void initialize(int nextIndex) {
39+
tailNode = null;
40+
height = initialHeight;
41+
this.nextIndex = nextIndex;
42+
initialized = true;
43+
finished = false;
44+
}
45+
46+
private void update(OTSHashAddress otsHashAddress) {
47+
if (otsHashAddress == null) {
48+
throw new NullPointerException("otsHashAddress == null");
49+
}
50+
if (finished || !initialized) {
51+
throw new IllegalStateException("finished or not initialized");
52+
}
53+
/* prepare addresses */
54+
otsHashAddress.setOTSAddress(nextIndex);
55+
LTreeAddress lTreeAddress = new LTreeAddress();
56+
lTreeAddress.setLayerAddress(otsHashAddress.getLayerAddress());
57+
lTreeAddress.setTreeAddress(otsHashAddress.getTreeAddress());
58+
lTreeAddress.setLTreeAddress(nextIndex);
59+
HashTreeAddress hashTreeAddress = new HashTreeAddress();
60+
hashTreeAddress.setLayerAddress(otsHashAddress.getLayerAddress());
61+
hashTreeAddress.setTreeAddress(otsHashAddress.getTreeAddress());
62+
hashTreeAddress.setTreeHeight(0);
63+
hashTreeAddress.setTreeIndex(nextIndex);
64+
65+
/* calculate leaf node */
66+
wotsPlus.importKeys(xmss.getWOTSPlusSecretKey(otsHashAddress), xmss.getPublicSeed());
67+
WOTSPlusPublicKey wotsPlusPublicKey = wotsPlus.getPublicKey(otsHashAddress);
68+
XMSSNode node = xmss.lTree(wotsPlusPublicKey, lTreeAddress);
69+
70+
while (!stack.isEmpty() && stack.peek().getHeight() == node.getHeight() && stack.peek().getHeight() != initialHeight) {
71+
hashTreeAddress.setTreeIndex((hashTreeAddress.getTreeIndex() - 1) / 2);
72+
node = xmss.randomizeHash(stack.pop(), node, hashTreeAddress);
73+
node.setHeight(node.getHeight() + 1);
74+
hashTreeAddress.setTreeHeight(hashTreeAddress.getTreeHeight() + 1);
75+
}
76+
77+
if (tailNode == null) {
78+
tailNode = node;
79+
} else {
80+
if (tailNode.getHeight() == node.getHeight()) {
81+
hashTreeAddress.setTreeIndex((hashTreeAddress.getTreeIndex() - 1) / 2);
82+
node = xmss.randomizeHash(tailNode, node, hashTreeAddress);
83+
node.setHeight(tailNode.getHeight() + 1);
84+
tailNode = node;
85+
hashTreeAddress.setTreeHeight(hashTreeAddress.getTreeHeight() + 1);
86+
} else {
87+
stack.push(node);
88+
}
89+
}
90+
91+
if (tailNode.getHeight() == initialHeight) {
92+
finished = true;
93+
} else {
94+
height = node.getHeight();
95+
nextIndex++;
96+
}
97+
}
98+
99+
private int getHeight() {
100+
if (!initialized || finished) {
101+
return Integer.MAX_VALUE;
102+
}
103+
return height;
104+
}
105+
106+
private int getIndexLeaf() {
107+
return nextIndex;
108+
}
109+
110+
private void setNode(XMSSNode node) {
111+
tailNode = node;
112+
height = node.getHeight();
113+
if (height == initialHeight) {
114+
finished = true;
115+
}
116+
}
117+
118+
private boolean isFinished() {
119+
return finished;
120+
}
121+
122+
private boolean isInitialized() {
123+
return initialized;
124+
}
125+
}
126+
127+
private transient XMSS xmss;
128+
private transient WOTSPlus wotsPlus;
129+
private final int treeHeight;
130+
private int k;
131+
private XMSSNode root;
132+
private List<XMSSNode> authenticationPath;
133+
private Map<Integer, ArrayDeque<XMSSNode>> retain;
134+
private Stack<XMSSNode> stack;
135+
private List<TreeHash> treeHashInstances;
136+
private Map<Integer, XMSSNode> keep;
137+
private int index;
138+
139+
protected BDS(XMSS xmss) {
140+
super();
141+
if (xmss == null) {
142+
throw new NullPointerException("xmss == null");
143+
}
144+
this.xmss = xmss;
145+
wotsPlus = xmss.getWOTSPlus();
146+
treeHeight = xmss.getParams().getHeight();
147+
k = xmss.getParams().getK();
148+
if (k > treeHeight || k < 2 || ((treeHeight - k) % 2) != 0) {
149+
throw new IllegalArgumentException("illegal value for BDS parameter k");
150+
}
151+
authenticationPath = new ArrayList<XMSSNode>();
152+
retain = new TreeMap<Integer, ArrayDeque<XMSSNode>>();
153+
stack = new Stack<XMSSNode>();
154+
initializeTreeHashInstances();
155+
keep = new TreeMap<Integer, XMSSNode>();
156+
index = 0;
157+
}
158+
159+
private void initializeTreeHashInstances() {
160+
treeHashInstances = new ArrayList<TreeHash>();
161+
for (int height = 0; height < (treeHeight - k); height++) {
162+
treeHashInstances.add(new TreeHash(height));
163+
}
164+
}
165+
166+
protected XMSSNode initialize(OTSHashAddress otsHashAddress) {
167+
if (otsHashAddress == null) {
168+
throw new NullPointerException("otsHashAddress == null");
169+
}
170+
/* prepare addresses */
171+
LTreeAddress lTreeAddress = new LTreeAddress();
172+
lTreeAddress.setLayerAddress(otsHashAddress.getLayerAddress());
173+
lTreeAddress.setTreeAddress(otsHashAddress.getTreeAddress());
174+
HashTreeAddress hashTreeAddress = new HashTreeAddress();
175+
hashTreeAddress.setLayerAddress(otsHashAddress.getLayerAddress());
176+
hashTreeAddress.setTreeAddress(otsHashAddress.getTreeAddress());
177+
178+
/* iterate indexes */
179+
for (int indexLeaf = 0; indexLeaf < (1 << treeHeight); indexLeaf++) {
180+
/* generate leaf */
181+
otsHashAddress.setOTSAddress(indexLeaf);
182+
/* import WOTSPlusSecretKey as its needed to calculate the public key on the fly */
183+
wotsPlus.importKeys(xmss.getWOTSPlusSecretKey(otsHashAddress), xmss.getPublicSeed());
184+
WOTSPlusPublicKey wotsPlusPublicKey = wotsPlus.getPublicKey(otsHashAddress);
185+
lTreeAddress.setLTreeAddress(indexLeaf);
186+
XMSSNode node = xmss.lTree(wotsPlusPublicKey, lTreeAddress);
187+
188+
hashTreeAddress.setTreeHeight(0);
189+
hashTreeAddress.setTreeIndex(indexLeaf);
190+
while (!stack.isEmpty() && stack.peek().getHeight() == node.getHeight()) {
191+
/* add to authenticationPath if leafIndex == 1 */
192+
int indexOnHeight = ((int)Math.floor(indexLeaf / (1 << node.getHeight())));
193+
if (indexOnHeight == 1) {
194+
authenticationPath.add(node.clone());
195+
}
196+
/* store next right authentication node */
197+
if (indexOnHeight == 3 && node.getHeight() < (treeHeight - k)) {
198+
treeHashInstances.get(node.getHeight()).setNode(node.clone());
199+
}
200+
if (indexOnHeight >= 3 && (indexOnHeight & 1) == 1 && node.getHeight() >= (treeHeight - k) && node.getHeight() <= (treeHeight - 2)) {
201+
if (retain.get(node.getHeight()) == null) {
202+
ArrayDeque<XMSSNode> queue = new ArrayDeque<XMSSNode>();
203+
queue.add(node.clone());
204+
retain.put(node.getHeight(), queue);
205+
} else {
206+
retain.get(node.getHeight()).add(node.clone());
207+
}
208+
}
209+
hashTreeAddress.setTreeIndex((hashTreeAddress.getTreeIndex() - 1) / 2);
210+
node = xmss.randomizeHash(stack.pop(), node, hashTreeAddress);
211+
node.setHeight(node.getHeight() + 1);
212+
hashTreeAddress.setTreeHeight(hashTreeAddress.getTreeHeight() + 1);
213+
}
214+
/* push to stack */
215+
stack.push(node);
216+
}
217+
root = stack.pop();
218+
return root.clone();
219+
}
220+
221+
protected void nextAuthenticationPath(OTSHashAddress otsHashAddress) {
222+
if (otsHashAddress == null) {
223+
throw new NullPointerException("otsHashAddress == null");
224+
}
225+
if (index > ((1 << treeHeight) - 2)) {
226+
throw new IllegalStateException("index out of bounds");
227+
}
228+
/* prepare addresses */
229+
LTreeAddress lTreeAddress = new LTreeAddress();
230+
lTreeAddress.setLayerAddress(otsHashAddress.getLayerAddress());
231+
lTreeAddress.setTreeAddress(otsHashAddress.getTreeAddress());
232+
HashTreeAddress hashTreeAddress = new HashTreeAddress();
233+
hashTreeAddress.setLayerAddress(otsHashAddress.getLayerAddress());
234+
hashTreeAddress.setTreeAddress(otsHashAddress.getTreeAddress());
235+
236+
/* determine tau */
237+
int tau = XMSSUtil.calculateTau(index, treeHeight);
238+
239+
/* parent of leaf on height tau+1 is a left node */
240+
if (((index >> (tau + 1)) & 1) == 0 && (tau < (treeHeight - 1))) {
241+
keep.put(tau, authenticationPath.get(tau).clone());
242+
}
243+
/* leaf is a left node */
244+
if (tau == 0) {
245+
otsHashAddress.setOTSAddress(index);
246+
/* import WOTSPlusSecretKey as its needed to calculate the public key on the fly */
247+
wotsPlus.importKeys(xmss.getWOTSPlusSecretKey(otsHashAddress), xmss.getPublicSeed());
248+
WOTSPlusPublicKey wotsPlusPublicKey = wotsPlus.getPublicKey(otsHashAddress);
249+
lTreeAddress.setLTreeAddress(index);
250+
XMSSNode node = xmss.lTree(wotsPlusPublicKey, lTreeAddress);
251+
authenticationPath.set(0, node);
252+
} else {
253+
/* add new left node on height tau to authentication path */
254+
hashTreeAddress.setTreeHeight(tau - 1);
255+
hashTreeAddress.setTreeIndex(index >> tau);
256+
XMSSNode node = xmss.randomizeHash(authenticationPath.get(tau - 1), keep.get(tau - 1), hashTreeAddress);
257+
node.setHeight(node.getHeight() + 1);
258+
authenticationPath.set(tau, node);
259+
keep.remove(tau - 1);
260+
261+
/* add new right nodes to authentication path */
262+
for (int height = 0; height < tau; height++) {
263+
if (height < (treeHeight - k)) {
264+
authenticationPath.set(height, treeHashInstances.get(height).tailNode.clone());
265+
} else {
266+
authenticationPath.set(height, retain.get(height).pop());
267+
}
268+
}
269+
270+
/* reinitialize treehash instances */
271+
int minHeight = Math.min(tau, treeHeight - k);
272+
for (int height = 0; height < minHeight; height++) {
273+
int startIndex = index + 1 + (3 * (1 << height));
274+
if (startIndex < (1 << treeHeight)) {
275+
treeHashInstances.get(height).initialize(startIndex);
276+
}
277+
}
278+
}
279+
280+
/* update treehash instances */
281+
for (int i = 0; i < (treeHeight - k) >> 1; i++) {
282+
TreeHash treeHash = getTreeHashInstanceForUpdate();
283+
if (treeHash != null) {
284+
treeHash.update(otsHashAddress);
285+
}
286+
}
287+
index++;
288+
}
289+
290+
private TreeHash getTreeHashInstanceForUpdate() {
291+
TreeHash ret = null;
292+
for (TreeHash treeHash : treeHashInstances) {
293+
if (treeHash.isFinished() || !treeHash.isInitialized()) {
294+
continue;
295+
}
296+
if (ret == null) {
297+
ret = treeHash;
298+
continue;
299+
}
300+
if (treeHash.getHeight() < ret.getHeight()) {
301+
ret = treeHash;
302+
continue;
303+
}
304+
if (treeHash.getHeight() == ret.getHeight()) {
305+
if (treeHash.getIndexLeaf() < ret.getIndexLeaf()) {
306+
ret = treeHash;
307+
}
308+
}
309+
}
310+
return ret;
311+
}
312+
313+
protected void validate(boolean isStateForRootTree) {
314+
if (treeHeight != xmss.getParams().getHeight()) {
315+
throw new IllegalStateException("wrong height");
316+
}
317+
if (isStateForRootTree) {
318+
if (!XMSSUtil.compareByteArray(root.getValue(), xmss.getRoot())) {
319+
throw new IllegalStateException("root in BDS state does not match root of public / private key");
320+
}
321+
}
322+
if (authenticationPath == null) {
323+
throw new IllegalStateException("authenticationPath == null");
324+
}
325+
if (retain == null) {
326+
throw new IllegalStateException("retain == null");
327+
}
328+
if (stack == null) {
329+
throw new IllegalStateException("stack == null");
330+
}
331+
if (treeHashInstances == null) {
332+
throw new IllegalStateException("treeHashInstances == null");
333+
}
334+
if (keep == null) {
335+
throw new IllegalStateException("keep == null");
336+
}
337+
if (!XMSSUtil.isIndexValid(treeHeight, index)) {
338+
throw new IllegalStateException("index in BDS state out of bounds");
339+
}
340+
}
341+
342+
protected int getTreeHeight() {
343+
return treeHeight;
344+
}
345+
346+
protected XMSSNode getRoot() {
347+
return root.clone();
348+
}
349+
350+
protected List<XMSSNode> getAuthenticationPath() {
351+
List<XMSSNode> authenticationPath = new ArrayList<XMSSNode>();
352+
for (XMSSNode node : this.authenticationPath) {
353+
authenticationPath.add(node.clone());
354+
}
355+
return authenticationPath;
356+
}
357+
358+
protected void setXMSS(XMSS xmss) {
359+
this.xmss = xmss;
360+
this.wotsPlus = xmss.getWOTSPlus();
361+
}
362+
363+
protected int getIndex() {
364+
return index;
365+
}
366+
}

0 commit comments

Comments
 (0)