Skip to content

Commit 9bbb415

Browse files
authored
zstd: translate fseDecoder.buildDtable into asm (#598)
* zstd: translate fseDecoder.buildDtable into asm
1 parent 4bc73d3 commit 9bbb415

File tree

6 files changed

+626
-63
lines changed

6 files changed

+626
-63
lines changed

zstd/_generate/gen.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ func main() {
8080
decodeSync.generateProcedure("sequenceDecs_decodeSync_safe_amd64")
8181
decodeSync.setBMI2(true)
8282
decodeSync.generateProcedure("sequenceDecs_decodeSync_safe_bmi2")
83+
8384
Generate()
8485
}
8586

zstd/_generate/gen_fse.go

Lines changed: 362 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,362 @@
1+
package main
2+
3+
//go:generate go run gen_fse.go -out ../fse_decoder_amd64.s -pkg=zstd
4+
5+
import (
6+
"flag"
7+
8+
_ "github.com/klauspost/compress"
9+
10+
. "github.com/mmcloughlin/avo/build"
11+
"github.com/mmcloughlin/avo/buildtags"
12+
. "github.com/mmcloughlin/avo/operand"
13+
"github.com/mmcloughlin/avo/reg"
14+
)
15+
16+
func main() {
17+
flag.Parse()
18+
19+
Constraint(buildtags.Not("appengine").ToConstraint())
20+
Constraint(buildtags.Not("noasm").ToConstraint())
21+
Constraint(buildtags.Term("gc").ToConstraint())
22+
Constraint(buildtags.Not("noasm").ToConstraint())
23+
24+
buildDtable := buildDtable{}
25+
buildDtable.generateProcedure("buildDtable_asm")
26+
Generate()
27+
}
28+
29+
// Exit codes written as the result of the generated procedure when it
// rejects its input; the success path returns 0 instead.
//
// The values start at 1 and follow declaration order, so new codes must only
// be appended at the end to keep the existing values stable.
const (
	// errorCorruptedNormalizedCounter (1) is reported when the spread
	// position is not zero after all symbols have been laid out.
	errorCorruptedNormalizedCounter = iota + 1
	// errorNewStateTooBig (2) is reported when a computed newState is
	// greater than the table size.
	errorNewStateTooBig
	// errorNewStateNoBits (3) is reported when newState equals the current
	// table index while nBits is zero.
	errorNewStateNoBits
)
34+
35+
type buildDtable struct {
36+
bmi2 bool
37+
38+
// values used across all methods
39+
actualTableLog reg.GPVirtual
40+
tableSize reg.GPVirtual
41+
highThreshold reg.GPVirtual
42+
symbolNext reg.GPVirtual // array []uint16
43+
dt reg.GPVirtual // array []uint64
44+
}
45+
46+
func (b *buildDtable) generateProcedure(name string) {
47+
Package("github.com/klauspost/compress/zstd")
48+
TEXT(name, 0, "func (s *fseDecoder, ctx *buildDtableAsmContext ) int")
49+
Doc(name+" implements fseDecoder.buildDtable in asm", "")
50+
Pragma("noescape")
51+
52+
ctx := Dereference(Param("ctx"))
53+
s := Dereference(Param("s"))
54+
55+
Comment("Load values")
56+
{
57+
// tableSize = (1 << s.actualTableLog)
58+
b.tableSize = GP64()
59+
b.actualTableLog = GP64()
60+
Load(s.Field("actualTableLog"), b.actualTableLog)
61+
XORQ(b.tableSize, b.tableSize)
62+
BTSQ(b.actualTableLog, b.tableSize)
63+
64+
// symbolNext = &s.stateTable[0]
65+
b.symbolNext = GP64()
66+
Load(ctx.Field("stateTable"), b.symbolNext)
67+
68+
// dt = &s.dt[0]
69+
b.dt = GP64()
70+
Load(ctx.Field("dt"), b.dt)
71+
72+
// highThreshold = tableSize - 1
73+
b.highThreshold = GP64()
74+
LEAQ(Mem{Base: b.tableSize, Disp: -1}, b.highThreshold)
75+
}
76+
77+
norm := GP64()
78+
Load(ctx.Field("norm"), norm)
79+
80+
symbolLen := GP64()
81+
Load(s.Field("symbolLen"), symbolLen)
82+
Comment("End load values")
83+
84+
b.init(norm, symbolLen)
85+
b.spread(norm, symbolLen)
86+
b.buildTable()
87+
88+
b.returnCode(0)
89+
}
90+
91+
func (b *buildDtable) init(norm, symbolLen reg.GPVirtual) {
92+
Comment("Init, lay down lowprob symbols")
93+
/*
94+
for i, v := range s.norm[:s.symbolLen] {
95+
if v == -1 {
96+
s.dt[highThreshold].setAddBits(uint8(i))
97+
highThreshold--
98+
symbolNext[i] = 1
99+
} else {
100+
symbolNext[i] = uint16(v)
101+
}
102+
}
103+
*/
104+
105+
i := New64()
106+
JMP(LabelRef("init_main_loop_condition"))
107+
Label("init_main_loop")
108+
109+
v := GP64()
110+
MOVWQSX(Mem{Base: norm, Index: i, Scale: 2}, v)
111+
112+
CMPW(v.As16(), I16(-1))
113+
JNE(LabelRef("do_not_update_high_threshold"))
114+
115+
{
116+
// s.dt[highThreshold].setAddBits(uint8(i))
117+
MOVB(i.As8(), Mem{Base: b.dt, Index: b.highThreshold, Scale: 8, Disp: 1}) // set highThreshold*8 + 1 byte
118+
// highThreshold--
119+
DECQ(b.highThreshold)
120+
121+
// symbolNext[i] = 1
122+
MOVQ(U64(1), v)
123+
}
124+
125+
Label("do_not_update_high_threshold")
126+
{
127+
// symbolNext[i] = uint16(v)
128+
MOVW(v.As16(), Mem{Base: b.symbolNext, Index: i, Scale: 2})
129+
130+
INCQ(i)
131+
Label("init_main_loop_condition")
132+
CMPQ(i, symbolLen)
133+
JL(LabelRef("init_main_loop"))
134+
}
135+
136+
Label("init_end")
137+
}
138+
139+
func (b *buildDtable) spread(norm, symbolLen reg.GPVirtual) {
140+
Comment("Spread symbols")
141+
/*
142+
tableMask := tableSize - 1
143+
step := tableStep(tableSize)
144+
position := uint32(0)
145+
for ss, v := range s.norm[:s.symbolLen] {
146+
for i := 0; i < int(v); i++ {
147+
s.dt[position].setAddBits(uint8(ss))
148+
position = (position + step) & tableMask
149+
for position > highThreshold {
150+
// lowprob area
151+
position = (position + step) & tableMask
152+
}
153+
}
154+
}
155+
*/
156+
step := GP64()
157+
Comment("Calculate table step")
158+
{
159+
// tmp1 = tableSize >> 1
160+
tmp1 := Copy64(b.tableSize)
161+
SHRQ(U8(1), tmp1)
162+
163+
// tmp3 = tableSize >> 3
164+
tmp3 := Copy64(b.tableSize)
165+
SHRQ(U8(3), tmp3)
166+
167+
// step = tmp1 + tmp3 + 3
168+
LEAQ(Mem{Base: tmp1, Index: tmp3, Scale: 1, Disp: 3}, step)
169+
}
170+
171+
Comment("Fill add bits values")
172+
173+
// tableMask = tableSize - 1 (tableSize is a pow of 2)
174+
tableMask := GP64()
175+
LEAQ(Mem{Base: b.tableSize, Disp: -1}, tableMask)
176+
177+
// position := 0
178+
position := New64()
179+
180+
// ss := 0
181+
ss := New64()
182+
JMP(LabelRef("spread_main_loop_condition"))
183+
Label("spread_main_loop")
184+
{
185+
i := New64()
186+
v := GP64()
187+
MOVWQSX(Mem{Base: norm, Index: ss, Scale: 2}, v)
188+
JMP(LabelRef("spread_inner_loop_condition"))
189+
Label("spread_inner_loop")
190+
191+
{
192+
// s.dt[position].setAddBits(uint8(ss))
193+
MOVB(ss.As8(), Mem{Base: b.dt, Index: position, Scale: 8, Disp: 1})
194+
195+
Label("adjust_position")
196+
// position = (position + step) & tableMask
197+
ADDQ(step, position)
198+
ANDQ(tableMask, position)
199+
200+
// for position > highThreshold {
201+
// // lowprob area
202+
// position = (position + step) & tableMask
203+
// }
204+
CMPQ(position, b.highThreshold)
205+
JG(LabelRef("adjust_position"))
206+
}
207+
INCQ(i)
208+
Label("spread_inner_loop_condition")
209+
CMPQ(i, v)
210+
JL(LabelRef("spread_inner_loop"))
211+
}
212+
213+
INCQ(ss)
214+
Label("spread_main_loop_condition")
215+
CMPQ(ss, symbolLen)
216+
JL(LabelRef("spread_main_loop"))
217+
218+
/*
219+
if position != 0 {
220+
// position must reach all cells once, otherwise normalizedCounter is incorrect
221+
return errors.New("corrupted input (position != 0)")
222+
}
223+
*/
224+
TESTQ(position, position)
225+
{
226+
JZ(LabelRef("spread_check_ok"))
227+
b.returnError(errorCorruptedNormalizedCounter, position)
228+
}
229+
Label("spread_check_ok")
230+
}
231+
232+
func (b *buildDtable) buildTable() {
233+
Comment("Build Decoding table")
234+
/*
235+
tableSize := uint16(1 << s.actualTableLog)
236+
for u, v := range s.dt[:tableSize] {
237+
symbol := v.addBits()
238+
nextState := symbolNext[symbol]
239+
symbolNext[symbol] = nextState + 1
240+
nBits := s.actualTableLog - byte(highBits(uint32(nextState)))
241+
s.dt[u&maxTableMask].setNBits(nBits)
242+
newState := (nextState << nBits) - tableSize
243+
if newState > tableSize {
244+
return fmt.Errorf("newState (%d) outside table size (%d)", newState, tableSize)
245+
}
246+
if newState == uint16(u) && nBits == 0 {
247+
// Seems weird that this is possible with nbits > 0.
248+
return fmt.Errorf("newState (%d) == oldState (%d) and no bits", newState, u)
249+
}
250+
s.dt[u&maxTableMask].setNewState(newState)
251+
}
252+
*/
253+
u := New64()
254+
Label("build_table_main_table")
255+
{
256+
// v := s.dt[u]
257+
// symbol := v.addBits()
258+
symbol := GP64()
259+
MOVBQZX(Mem{Base: b.dt, Index: u, Scale: 8, Disp: 1}, symbol)
260+
261+
// nextState := symbolNext[symbol]
262+
nextState := GP64()
263+
ptr := Mem{Base: b.symbolNext, Index: symbol, Scale: 2}
264+
MOVWQZX(ptr, nextState)
265+
266+
// symbolNext[symbol] = nextState + 1
267+
{
268+
tmp := GP64()
269+
LEAQ(Mem{Base: nextState, Disp: 1}, tmp)
270+
MOVW(tmp.As16(), ptr)
271+
}
272+
273+
// nBits := s.actualTableLog - byte(highBits(uint32(nextState)))
274+
nBits := reg.RCX // As we use nBits to shift
275+
{
276+
highBits := Copy64(nextState)
277+
BSRQ(highBits, highBits)
278+
279+
MOVQ(b.actualTableLog, nBits)
280+
SUBQ(highBits, nBits)
281+
}
282+
283+
// newState := (nextState << nBits) - tableSize
284+
newState := Copy64(nextState)
285+
SHLQ(reg.CL, newState)
286+
SUBQ(b.tableSize, newState)
287+
288+
// s.dt[u&maxTableMask].setNBits(nBits) // sets byte #0
289+
// s.dt[u&maxTableMask].setNewState(newState) // sets word #1 (bytes #2 & #3)
290+
{
291+
MOVB(nBits.As8(), Mem{Base: b.dt, Index: u, Scale: 8})
292+
MOVW(newState.As16(), Mem{Base: b.dt, Index: u, Scale: 8, Disp: 2})
293+
}
294+
295+
// if newState > tableSize {
296+
// return fmt.Errorf("newState (%d) outside table size (%d)", newState, tableSize)
297+
// }
298+
{
299+
CMPQ(newState, b.tableSize)
300+
JLE(LabelRef("build_table_check1_ok"))
301+
302+
b.returnError(errorNewStateTooBig, newState, b.tableSize)
303+
Label("build_table_check1_ok")
304+
}
305+
306+
// if newState == uint16(u) && nBits == 0 {
307+
// // Seems weird that this is possible with nbits > 0.
308+
// return fmt.Errorf("newState (%d) == oldState (%d) and no bits", newState, u)
309+
// }
310+
{
311+
TESTB(nBits.As8(), nBits.As8())
312+
JNZ(LabelRef("build_table_check2_ok"))
313+
CMPW(newState.As16(), u.As16())
314+
JNE(LabelRef("build_table_check2_ok"))
315+
316+
b.returnError(errorNewStateNoBits, newState, u)
317+
Label("build_table_check2_ok")
318+
}
319+
}
320+
INCQ(u)
321+
CMPQ(u, b.tableSize)
322+
JL(LabelRef("build_table_main_table"))
323+
}
324+
325+
// returnCode sets function result and terminates the function.
326+
func (b *buildDtable) returnCode(code int) {
327+
a, err := ReturnIndex(0).Resolve()
328+
if err != nil {
329+
panic(err)
330+
}
331+
MOVQ(I32(code), a.Addr)
332+
RET()
333+
}
334+
335+
// returnError sets error params and terminates function with given exit code.
336+
func (b *buildDtable) returnError(code int, args ...reg.GPVirtual) {
337+
ctx := Dereference(Param("ctx"))
338+
339+
if len(args) > 0 {
340+
Store(args[0], ctx.Field("errParam1"))
341+
}
342+
343+
if len(args) > 1 {
344+
Store(args[1], ctx.Field("errParam2"))
345+
}
346+
347+
b.returnCode(code)
348+
}
349+
350+
func New64() reg.GPVirtual {
351+
cnt := GP64()
352+
XORQ(cnt, cnt)
353+
354+
return cnt
355+
}
356+
357+
func Copy64(val reg.GPVirtual) reg.GPVirtual {
358+
tmp := GP64()
359+
MOVQ(val, tmp)
360+
361+
return tmp
362+
}

0 commit comments

Comments
 (0)