Skip to content

Fix policy softmax accuracy if masking is enabled. #912

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Aug 3, 2019
Merged

Fix policy softmax accuracy if masking is enabled. #912

merged 11 commits into from
Aug 3, 2019

Conversation

ddobbelaere
Copy link
Contributor

@ddobbelaere ddobbelaere commented Jul 30, 2019

Perform policy softmax outside backends on set of legal moves.
This should fix the limited accuracy issues observed with CUDA backend FP16 in combination with policy masking enabled in training.

Blas backend has been tested.
Testing of CUDA and OpenCL backend much appreciated!

@ddobbelaere ddobbelaere changed the title Fix softmax accuracy Fix policy softmax accuracy if masking is enabled. Jul 30, 2019
@jkormu
Copy link

jkormu commented Jul 30, 2019

CUDA fp16 looks now much more sane. Tested with net 60020 that gives completely flat policy for version 0.23.1 on startpos with fp16. This pr seems to fix it and policy for fp32 and fp16 are almost identical as can seen below. Although 1000 node search from startpos does not give same best move. Maybe this is just expected side effect of very flat policies of 60020.

Used args: ./lc0 --smart-pruning-factor=0.0 --threads=1 --minibatch-size=1 --max-prefetch=32 --cpuct-base=19652.00 --cpuct=2.5 --fpu-strategy=absolute --fpu-value=-1.0 --fpu-strategy-at-root=absolute --fpu-value-at-root=1.0 --cpuct-factor=0 --no-out-of-order-eval --max-collision-events=1 --max-collision-visits=1 --verbose-move-stats

search results for go nodes 1000:

#pr912 fp16 go nodes 100
info string b1a3  (34  ) N:      22 (+ 0) (P:  5.09%) (Q: -0.05063) (D:  0.033) (U: 0.17483) (Q+U:  0.12420) (V: -0.0149) 
info string b2b3  (230 ) N:      22 (+ 0) (P:  5.05%) (Q: -0.03713) (D:  0.039) (U: 0.17362) (Q+U:  0.13649) (V: -0.0154) 
info string g1f3  (159 ) N:      24 (+ 0) (P:  4.94%) (Q: -0.01873) (D:  0.034) (U: 0.15607) (Q+U:  0.13734) (V: -0.0149) 
info string g1h3  (161 ) N:      26 (+ 0) (P:  4.94%) (Q: -0.01153) (D:  0.039) (U: 0.14464) (Q+U:  0.13311) (V: -0.0151) 
info string e2e3  (317 ) N:      26 (+ 0) (P:  5.01%) (Q: -0.01270) (D:  0.042) (U: 0.14665) (Q+U:  0.13395) (V: -0.0151) 
info string g2g3  (374 ) N:      26 (+ 0) (P:  5.05%) (Q: -0.00938) (D:  0.041) (U: 0.14777) (Q+U:  0.13839) (V: -0.0146) 
info string d2d3  (288 ) N:      26 (+ 0) (P:  4.72%) (Q:  0.00162) (D:  0.036) (U: 0.13812) (Q+U:  0.13975) (V: -0.0151) 
info string a2a4  (207 ) N:      26 (+ 0) (P:  5.00%) (Q: -0.00628) (D:  0.038) (U: 0.14634) (Q+U:  0.14006) (V: -0.0151) 
info string e2e4  (322 ) N:      28 (+ 0) (P:  4.91%) (Q:  0.00632) (D:  0.031) (U: 0.13383) (Q+U:  0.14016) (V: -0.0154) 
info string c2c4  (264 ) N:      30 (+ 0) (P:  5.01%) (Q:  0.00880) (D:  0.033) (U: 0.12773) (Q+U:  0.13652) (V: -0.0156) 
info string a2a3  (204 ) N:      31 (+ 0) (P:  5.34%) (Q:  0.00516) (D:  0.035) (U: 0.13184) (Q+U:  0.13700) (V: -0.0149) 
info string g2g4  (378 ) N:      51 (+ 0) (P:  4.84%) (Q:  0.06247) (D:  0.031) (U: 0.07348) (Q+U:  0.13595) (V: -0.0151) 
info string h2h3  (400 ) N:      52 (+ 0) (P:  5.00%) (Q:  0.06147) (D:  0.033) (U: 0.07455) (Q+U:  0.13602) (V: -0.0151) 
info string f2f3  (346 ) N:      70 (+ 0) (P:  4.86%) (Q:  0.08452) (D:  0.039) (U: 0.05409) (Q+U:  0.13861) (V: -0.0154) 
info string c2c3  (259 ) N:      72 (+ 0) (P:  5.21%) (Q:  0.08131) (D:  0.039) (U: 0.05644) (Q+U:  0.13774) (V: -0.0154) 
info string b2b4  (234 ) N:      72 (+ 0) (P:  5.15%) (Q:  0.08237) (D:  0.039) (U: 0.05576) (Q+U:  0.13813) (V: -0.0156) 
info string f2f4  (351 ) N:      78 (+ 0) (P:  4.87%) (Q:  0.09339) (D:  0.036) (U: 0.04867) (Q+U:  0.14207) (V: -0.0151) 
info string b1c3  (36  ) N:      93 (+ 0) (P:  5.08%) (Q:  0.09620) (D:  0.036) (U: 0.04271) (Q+U:  0.13891) (V: -0.0146) 
info string h2h4  (403 ) N:     108 (+ 0) (P:  5.17%) (Q:  0.10161) (D:  0.035) (U: 0.03748) (Q+U:  0.13909) (V: -0.0149) 
info string d2d4  (293 ) N:     116 (+ 0) (P:  4.76%) (Q:  0.10713) (D:  0.038) (U: 0.03213) (Q+U:  0.13926) (V: -0.0154)


#pr912 fp32 go nodes 100
info string b1a3  (34  ) N:      22 (+ 0) (P:  5.09%) (Q: -0.05045) (D:  0.033) (U: 0.17483) (Q+U:  0.12438) (V: -0.0150) 
info string b2b3  (230 ) N:      23 (+ 0) (P:  5.05%) (Q: -0.03849) (D:  0.039) (U: 0.16634) (Q+U:  0.12785) (V: -0.0156) 
info string g1f3  (159 ) N:      26 (+ 0) (P:  4.94%) (Q: -0.02137) (D:  0.033) (U: 0.14446) (Q+U:  0.12309) (V: -0.0152) 
info string e2e3  (317 ) N:      27 (+ 0) (P:  5.01%) (Q: -0.02343) (D:  0.042) (U: 0.14141) (Q+U:  0.11798) (V: -0.0153) 
info string g1h3  (161 ) N:      27 (+ 0) (P:  4.94%) (Q: -0.01351) (D:  0.039) (U: 0.13947) (Q+U:  0.12596) (V: -0.0150) 
info string d2d3  (288 ) N:      27 (+ 0) (P:  4.72%) (Q: -0.00580) (D:  0.036) (U: 0.13319) (Q+U:  0.12739) (V: -0.0152) 
info string a2a4  (207 ) N:      28 (+ 0) (P:  5.00%) (Q: -0.00998) (D:  0.039) (U: 0.13629) (Q+U:  0.12631) (V: -0.0156) 
info string e2e4  (322 ) N:      30 (+ 0) (P:  4.91%) (Q:  0.00105) (D:  0.031) (U: 0.12520) (Q+U:  0.12625) (V: -0.0153) 
info string c2c4  (264 ) N:      31 (+ 0) (P:  5.01%) (Q:  0.00115) (D:  0.032) (U: 0.12374) (Q+U:  0.12489) (V: -0.0156) 
info string g2g3  (374 ) N:      32 (+ 0) (P:  5.05%) (Q: -0.01942) (D:  0.044) (U: 0.12090) (Q+U:  0.10148) (V: -0.0148) 
info string a2a3  (204 ) N:      33 (+ 0) (P:  5.34%) (Q:  0.00308) (D:  0.035) (U: 0.12405) (Q+U:  0.12713) (V: -0.0151) 
info string h2h3  (400 ) N:      34 (+ 0) (P:  5.00%) (Q:  0.00854) (D:  0.032) (U: 0.11292) (Q+U:  0.12147) (V: -0.0155) 
info string g2g4  (378 ) N:      54 (+ 0) (P:  4.83%) (Q:  0.05652) (D:  0.031) (U: 0.06945) (Q+U:  0.12597) (V: -0.0153) 
info string d2d4  (293 ) N:      67 (+ 0) (P:  4.76%) (Q:  0.06972) (D:  0.036) (U: 0.05530) (Q+U:  0.12502) (V: -0.0155) 
info string f2f3  (346 ) N:      76 (+ 0) (P:  4.86%) (Q:  0.07781) (D:  0.039) (U: 0.04989) (Q+U:  0.12769) (V: -0.0155) 
info string c2c3  (259 ) N:      78 (+ 0) (P:  5.21%) (Q:  0.07531) (D:  0.039) (U: 0.05214) (Q+U:  0.12745) (V: -0.0155) 
info string b2b4  (234 ) N:      78 (+ 0) (P:  5.15%) (Q:  0.07649) (D:  0.039) (U: 0.05149) (Q+U:  0.12799) (V: -0.0154) 
info string f2f4  (351 ) N:      87 (+ 0) (P:  4.87%) (Q:  0.08479) (D:  0.036) (U: 0.04371) (Q+U:  0.12850) (V: -0.0151) 
info string b1c3  (36  ) N:     101 (+ 0) (P:  5.08%) (Q:  0.08910) (D:  0.037) (U: 0.03937) (Q+U:  0.12848) (V: -0.0148) 
info string h2h4  (403 ) N:     118 (+ 0) (P:  5.17%) (Q:  0.09425) (D:  0.035) (U: 0.03434) (Q+U:  0.12859) (V: -0.0150)


#0.23.1 fp16 go nodes 1000
info string e2e4  (322 ) N:      22 (+ 0) (P:  5.00%) (Q: -0.06230) (D:  0.041) (U: 0.17179) (Q+U:  0.10949) (V: -0.0154) 
info string b2b3  (230 ) N:      23 (+ 0) (P:  5.00%) (Q: -0.04130) (D:  0.038) (U: 0.16463) (Q+U:  0.12333) (V: -0.0154) 
info string f2f3  (346 ) N:      23 (+ 0) (P:  5.00%) (Q: -0.04068) (D:  0.037) (U: 0.16463) (Q+U:  0.12395) (V: -0.0154) 
info string c2c3  (259 ) N:      23 (+ 0) (P:  5.00%) (Q: -0.04032) (D:  0.037) (U: 0.16463) (Q+U:  0.12431) (V: -0.0154) 
info string a2a4  (207 ) N:      26 (+ 0) (P:  5.00%) (Q: -0.02268) (D:  0.038) (U: 0.14634) (Q+U:  0.12366) (V: -0.0151) 
info string g1h3  (161 ) N:      27 (+ 0) (P:  5.00%) (Q: -0.02153) (D:  0.040) (U: 0.14111) (Q+U:  0.11958) (V: -0.0151) 
info string b1c3  (36  ) N:      28 (+ 0) (P:  5.00%) (Q: -0.01823) (D:  0.033) (U: 0.13625) (Q+U:  0.11801) (V: -0.0146) 
info string b1a3  (34  ) N:      28 (+ 0) (P:  5.00%) (Q: -0.01730) (D:  0.038) (U: 0.13625) (Q+U:  0.11895) (V: -0.0149) 
info string e2e3  (317 ) N:      30 (+ 0) (P:  5.00%) (Q: -0.01437) (D:  0.035) (U: 0.12746) (Q+U:  0.11308) (V: -0.0151) 
info string g1f3  (159 ) N:      31 (+ 0) (P:  5.00%) (Q: -0.00395) (D:  0.040) (U: 0.12347) (Q+U:  0.11952) (V: -0.0149) 
info string g2g4  (378 ) N:      38 (+ 0) (P:  5.00%) (Q:  0.01774) (D:  0.040) (U: 0.10131) (Q+U:  0.11905) (V: -0.0151) 
info string c2c4  (264 ) N:      51 (+ 0) (P:  5.00%) (Q:  0.04886) (D:  0.033) (U: 0.07598) (Q+U:  0.12484) (V: -0.0156) 
info string b2b4  (234 ) N:      65 (+ 0) (P:  5.00%) (Q:  0.06484) (D:  0.038) (U: 0.05987) (Q+U:  0.12471) (V: -0.0156) 
info string d2d4  (293 ) N:      66 (+ 0) (P:  5.00%) (Q:  0.06121) (D:  0.045) (U: 0.05897) (Q+U:  0.12018) (V: -0.0154) 
info string h2h3  (400 ) N:      66 (+ 0) (P:  5.00%) (Q:  0.06347) (D:  0.035) (U: 0.05897) (Q+U:  0.12244) (V: -0.0151) 
info string f2f4  (351 ) N:      66 (+ 0) (P:  5.00%) (Q:  0.06466) (D:  0.038) (U: 0.05897) (Q+U:  0.12363) (V: -0.0151) 
info string a2a3  (204 ) N:      67 (+ 0) (P:  5.00%) (Q:  0.06612) (D:  0.038) (U: 0.05810) (Q+U:  0.12422) (V: -0.0149) 
info string d2d3  (288 ) N:      89 (+ 0) (P:  5.00%) (Q:  0.08015) (D:  0.047) (U: 0.04390) (Q+U:  0.12405) (V: -0.0151) 
info string g2g3  (374 ) N:      98 (+ 0) (P:  5.00%) (Q:  0.08423) (D:  0.036) (U: 0.03991) (Q+U:  0.12414) (V: -0.0146) 
info string h2h4  (403 ) N:     132 (+ 0) (P:  5.00%) (Q:  0.09431) (D:  0.035) (U: 0.02971) (Q+U:  0.12402) (V: -0.0149)


#0.23.1 fp32 go nodes 1000
info string b1a3  (34  ) N:      22 (+ 0) (P:  5.09%) (Q: -0.05045) (D:  0.033) (U: 0.17493) (Q+U:  0.12448) (V: -0.0150) 
info string b2b3  (230 ) N:      23 (+ 0) (P:  5.05%) (Q: -0.03849) (D:  0.039) (U: 0.16634) (Q+U:  0.12785) (V: -0.0156) 
info string g1f3  (159 ) N:      26 (+ 0) (P:  4.93%) (Q: -0.02137) (D:  0.033) (U: 0.14424) (Q+U:  0.12287) (V: -0.0152) 
info string e2e3  (317 ) N:      27 (+ 0) (P:  5.01%) (Q: -0.02343) (D:  0.042) (U: 0.14133) (Q+U:  0.11790) (V: -0.0153) 
info string g1h3  (161 ) N:      27 (+ 0) (P:  4.94%) (Q: -0.01351) (D:  0.039) (U: 0.13935) (Q+U:  0.12584) (V: -0.0150) 
info string d2d3  (288 ) N:      27 (+ 0) (P:  4.71%) (Q: -0.00580) (D:  0.036) (U: 0.13297) (Q+U:  0.12717) (V: -0.0152) 
info string a2a4  (207 ) N:      28 (+ 0) (P:  5.00%) (Q: -0.00998) (D:  0.039) (U: 0.13616) (Q+U:  0.12618) (V: -0.0156) 
info string e2e4  (322 ) N:      30 (+ 0) (P:  4.90%) (Q:  0.00105) (D:  0.031) (U: 0.12500) (Q+U:  0.12606) (V: -0.0153) 
info string c2c4  (264 ) N:      31 (+ 0) (P:  5.01%) (Q:  0.00115) (D:  0.032) (U: 0.12370) (Q+U:  0.12485) (V: -0.0156) 
info string g2g3  (374 ) N:      32 (+ 0) (P:  5.05%) (Q: -0.01942) (D:  0.044) (U: 0.12090) (Q+U:  0.10148) (V: -0.0148) 
info string h2h3  (400 ) N:      34 (+ 0) (P:  5.00%) (Q:  0.00854) (D:  0.032) (U: 0.11285) (Q+U:  0.12140) (V: -0.0155) 
info string a2a3  (204 ) N:      34 (+ 0) (P:  5.37%) (Q:  0.00322) (D:  0.035) (U: 0.12126) (Q+U:  0.12448) (V: -0.0151) 
info string g2g4  (378 ) N:      53 (+ 0) (P:  4.83%) (Q:  0.05807) (D:  0.031) (U: 0.07062) (Q+U:  0.12870) (V: -0.0153) 
info string d2d4  (293 ) N:      67 (+ 0) (P:  4.75%) (Q:  0.06972) (D:  0.036) (U: 0.05520) (Q+U:  0.12492) (V: -0.0155) 
info string f2f3  (346 ) N:      76 (+ 0) (P:  4.86%) (Q:  0.07781) (D:  0.039) (U: 0.04983) (Q+U:  0.12763) (V: -0.0155) 
info string c2c3  (259 ) N:      78 (+ 0) (P:  5.23%) (Q:  0.07529) (D:  0.039) (U: 0.05229) (Q+U:  0.12758) (V: -0.0155) 
info string b2b4  (234 ) N:      78 (+ 0) (P:  5.16%) (Q:  0.07646) (D:  0.039) (U: 0.05159) (Q+U:  0.12805) (V: -0.0154) 
info string f2f4  (351 ) N:      87 (+ 0) (P:  4.86%) (Q:  0.08479) (D:  0.036) (U: 0.04364) (Q+U:  0.12843) (V: -0.0151) 
info string b1c3  (36  ) N:     101 (+ 0) (P:  5.08%) (Q:  0.08910) (D:  0.037) (U: 0.03937) (Q+U:  0.12848) (V: -0.0148) 
info string h2h4  (403 ) N:     118 (+ 0) (P:  5.18%) (Q:  0.09425) (D:  0.035) (U: 0.03440) (Q+U:  0.12865) (V: -0.0150)

@oscardssmith
Copy link
Contributor

If it works for cuda 16, and regular, I'd say merge now and deal with potential opencl problems later. This will let us train with masking which has given us nice gains in sl

@mooskagh
Copy link
Member

Could you take some small network (e.g. 10b) and check on fast RTX GPU that it didn't become slower? Probably it didn't but worth checking.

@ddobbelaere
Copy link
Contributor Author

ddobbelaere commented Jul 31, 2019

FWIW, a quick benchmark (best of three) with the random backend gives me 240562 nps for master and 230368 nps for this PR.

Note that for practical backends, slightly less time needs to be spent on NN eval. So far, I had no big success in using approximate exp(.) (from an accuracy standpoint of view, it does go faster).

@ddobbelaere
Copy link
Contributor Author

ddobbelaere commented Jul 31, 2019

Ok, now it works better.

With the FastExp calculations (which are in line with @borg323's fast approximation for softmax policy temperature corrections), results are better now, best of three (under current laptop load conditions) with random backend gives 238537 nps (master), 231485 nps (with approx.), 228647 (without approx.).

The slight loss in accuracy is marginal:

# master
go nodes 1
Loading weights file from: networks/36092_70be354d727332e6fc992a4f01eaf8831a85e7ab4ad6fa695980a860722db2b6
Creating backend [blas]...
BLAS vendor: OpenBLAS.
OpenBLAS [OpenBLAS 0.3.6 DYNAMIC_ARCH NO_AFFINITY Haswell SINGLE_THREADED].
OpenBLAS found 1 Haswell core(s).
OpenBLAS using 1 core(s) for this backend.
BLAS max batch size is 256.
info depth 1 seldepth 1 time 13 nodes 1 score cp 31 hashfull 0 nps 76 tbhits 0 pv e2e4
info string g2g4  (378 ) N:       0 (+ 0) (P:  1.80%) (Q:  0.10555) (D:  0.000) (U: 0.05388) (Q+U:  0.15943) (V:  -.----) 
info string f2f3  (346 ) N:       0 (+ 0) (P:  1.88%) (Q:  0.10555) (D:  0.000) (U: 0.05651) (Q+U:  0.16207) (V:  -.----) 
info string g1h3  (161 ) N:       0 (+ 0) (P:  2.20%) (Q:  0.10555) (D:  0.000) (U: 0.06599) (Q+U:  0.17154) (V:  -.----) 
info string f2f4  (351 ) N:       0 (+ 0) (P:  2.41%) (Q:  0.10555) (D:  0.000) (U: 0.07219) (Q+U:  0.17774) (V:  -.----) 
info string b1a3  (34  ) N:       0 (+ 0) (P:  2.42%) (Q:  0.10555) (D:  0.000) (U: 0.07258) (Q+U:  0.17813) (V:  -.----) 
info string h2h4  (403 ) N:       0 (+ 0) (P:  2.50%) (Q:  0.10555) (D:  0.000) (U: 0.07487) (Q+U:  0.18042) (V:  -.----) 
info string b2b4  (234 ) N:       0 (+ 0) (P:  2.57%) (Q:  0.10555) (D:  0.000) (U: 0.07711) (Q+U:  0.18267) (V:  -.----) 
info string a2a4  (207 ) N:       0 (+ 0) (P:  2.78%) (Q:  0.10555) (D:  0.000) (U: 0.08336) (Q+U:  0.18891) (V:  -.----) 
info string d2d3  (288 ) N:       0 (+ 0) (P:  3.23%) (Q:  0.10555) (D:  0.000) (U: 0.09687) (Q+U:  0.20242) (V:  -.----) 
info string h2h3  (400 ) N:       0 (+ 0) (P:  3.23%) (Q:  0.10555) (D:  0.000) (U: 0.09696) (Q+U:  0.20251) (V:  -.----) 
info string b2b3  (230 ) N:       0 (+ 0) (P:  3.46%) (Q:  0.10555) (D:  0.000) (U: 0.10387) (Q+U:  0.20942) (V:  -.----) 
info string b1c3  (36  ) N:       0 (+ 0) (P:  3.51%) (Q:  0.10555) (D:  0.000) (U: 0.10529) (Q+U:  0.21084) (V:  -.----) 
info string a2a3  (204 ) N:       0 (+ 0) (P:  3.55%) (Q:  0.10555) (D:  0.000) (U: 0.10653) (Q+U:  0.21208) (V:  -.----) 
info string c2c3  (259 ) N:       0 (+ 0) (P:  3.88%) (Q:  0.10555) (D:  0.000) (U: 0.11628) (Q+U:  0.22183) (V:  -.----) 
info string e2e3  (317 ) N:       0 (+ 0) (P:  4.20%) (Q:  0.10555) (D:  0.000) (U: 0.12603) (Q+U:  0.23158) (V:  -.----) 
info string g2g3  (374 ) N:       0 (+ 0) (P:  4.43%) (Q:  0.10555) (D:  0.000) (U: 0.13294) (Q+U:  0.23849) (V:  -.----) 
info string c2c4  (264 ) N:       0 (+ 0) (P:  4.45%) (Q:  0.10555) (D:  0.000) (U: 0.13358) (Q+U:  0.23913) (V:  -.----) 
info string g1f3  (159 ) N:       0 (+ 0) (P:  7.75%) (Q:  0.10555) (D:  0.000) (U: 0.23264) (Q+U:  0.33820) (V:  -.----) 
info string d2d4  (293 ) N:       0 (+ 0) (P:  8.93%) (Q:  0.10555) (D:  0.000) (U: 0.26798) (Q+U:  0.37354) (V:  -.----) 
info string e2e4  (322 ) N:       0 (+ 0) (P: 30.82%) (Q:  0.10555) (D:  0.000) (U: 0.92472) (Q+U:  1.03027) (V:  -.----) 
bestmove e2e4

# this PR with FastExp
go nodes 1
Loading weights file from: networks/36092_70be354d727332e6fc992a4f01eaf8831a85e7ab4ad6fa695980a860722db2b6
Creating backend [blas]...
BLAS vendor: OpenBLAS.
OpenBLAS [OpenBLAS 0.3.6 DYNAMIC_ARCH NO_AFFINITY Haswell SINGLE_THREADED].
OpenBLAS found 1 Haswell core(s).
OpenBLAS using 1 core(s) for this backend.
BLAS max batch size is 256.
info depth 1 seldepth 1 time 13 nodes 1 score cp 31 hashfull 0 nps 76 tbhits 0 pv e2e4
info string g2g4  (378 ) N:       0 (+ 0) (P:  1.80%) (Q:  0.10555) (D:  0.000) (U: 0.05393) (Q+U:  0.15948) (V:  -.----) 
info string f2f3  (346 ) N:       0 (+ 0) (P:  1.87%) (Q:  0.10555) (D:  0.000) (U: 0.05622) (Q+U:  0.16177) (V:  -.----) 
info string g1h3  (161 ) N:       0 (+ 0) (P:  2.21%) (Q:  0.10555) (D:  0.000) (U: 0.06629) (Q+U:  0.17184) (V:  -.----) 
info string f2f4  (351 ) N:       0 (+ 0) (P:  2.42%) (Q:  0.10555) (D:  0.000) (U: 0.07247) (Q+U:  0.17802) (V:  -.----) 
info string b1a3  (34  ) N:       0 (+ 0) (P:  2.43%) (Q:  0.10555) (D:  0.000) (U: 0.07281) (Q+U:  0.17836) (V:  -.----) 
info string h2h4  (403 ) N:       0 (+ 0) (P:  2.49%) (Q:  0.10555) (D:  0.000) (U: 0.07485) (Q+U:  0.18040) (V:  -.----) 
info string b2b4  (234 ) N:       0 (+ 0) (P:  2.56%) (Q:  0.10555) (D:  0.000) (U: 0.07693) (Q+U:  0.18248) (V:  -.----) 
info string a2a4  (207 ) N:       0 (+ 0) (P:  2.78%) (Q:  0.10555) (D:  0.000) (U: 0.08354) (Q+U:  0.18910) (V:  -.----) 
info string d2d3  (288 ) N:       0 (+ 0) (P:  3.25%) (Q:  0.10555) (D:  0.000) (U: 0.09751) (Q+U:  0.20306) (V:  -.----) 
info string h2h3  (400 ) N:       0 (+ 0) (P:  3.25%) (Q:  0.10555) (D:  0.000) (U: 0.09751) (Q+U:  0.20306) (V:  -.----) 
info string b2b3  (230 ) N:       0 (+ 0) (P:  3.45%) (Q:  0.10555) (D:  0.000) (U: 0.10364) (Q+U:  0.20919) (V:  -.----) 
info string b1c3  (36  ) N:       0 (+ 0) (P:  3.50%) (Q:  0.10555) (D:  0.000) (U: 0.10488) (Q+U:  0.21043) (V:  -.----) 
info string a2a3  (204 ) N:       0 (+ 0) (P:  3.54%) (Q:  0.10555) (D:  0.000) (U: 0.10607) (Q+U:  0.21162) (V:  -.----) 
info string c2c3  (259 ) N:       0 (+ 0) (P:  3.87%) (Q:  0.10555) (D:  0.000) (U: 0.11623) (Q+U:  0.22178) (V:  -.----) 
info string e2e3  (317 ) N:       0 (+ 0) (P:  4.21%) (Q:  0.10555) (D:  0.000) (U: 0.12639) (Q+U:  0.23195) (V:  -.----) 
info string g2g3  (374 ) N:       0 (+ 0) (P:  4.45%) (Q:  0.10555) (D:  0.000) (U: 0.13344) (Q+U:  0.23900) (V:  -.----) 
info string c2c4  (264 ) N:       0 (+ 0) (P:  4.47%) (Q:  0.10555) (D:  0.000) (U: 0.13404) (Q+U:  0.23959) (V:  -.----) 
info string g1f3  (159 ) N:       0 (+ 0) (P:  7.78%) (Q:  0.10555) (D:  0.000) (U: 0.23338) (Q+U:  0.33893) (V:  -.----) 
info string d2d4  (293 ) N:       0 (+ 0) (P:  8.90%) (Q:  0.10555) (D:  0.000) (U: 0.26707) (Q+U:  0.37262) (V:  -.----) 
info string e2e4  (322 ) N:       0 (+ 0) (P: 30.76%) (Q:  0.10555) (D:  0.000) (U: 0.92288) (Q+U:  1.02844) (V:  -.----) 
bestmove e2e4

@Tilps
Copy link
Contributor

Tilps commented Jul 31, 2019

I wonder if the random backend should change the values it outputs for policy in random mode to have potentially greater range with this change.

@Sopel97
Copy link

Sopel97 commented Jul 31, 2019

p is always in range [0, 1] right? maybe a simple taylor series approximation would do better
https://stackoverflow.com/a/10552567/3763139
FastPow2 is actually pretty complex for what is needed here

@Tilps
Copy link
Contributor

Tilps commented Jul 31, 2019

The result of GetPValue as of this change can be anything - any positive or negative float value. Use of max move the range to be non-positive, with at least one value being 0.

@jjoshua2
Copy link
Contributor

jjoshua2 commented Jul 31, 2019 via email

@ddobbelaere
Copy link
Contributor Author

I've fused the softmax and softmax temperature steps, such that now the random backend has just about the same performance: best of three yields 240698 nps (master) and 239228 nps (this PR), even if this PR has to do more work (namely softmax)!

The trick is that w.r.t. master an extra FastLog2 call could be dropped (note that FastExp uses FastPow2 internally).

Accuracy doesn't seem to be an issue:

# master
go nodes 1
Loading weights file from: networks/36092_70be354d727332e6fc992a4f01eaf8831a85e7ab4ad6fa695980a860722db2b6
Creating backend [blas]...
BLAS vendor: OpenBLAS.
OpenBLAS [OpenBLAS 0.3.6 DYNAMIC_ARCH NO_AFFINITY Haswell SINGLE_THREADED].
OpenBLAS found 1 Haswell core(s).
OpenBLAS using 1 core(s) for this backend.
BLAS max batch size is 256.
info depth 1 seldepth 1 time 15 nodes 1 score cp 31 hashfull 0 nps 66 tbhits 0 pv e2e4
info string g2g4  (378 ) N:       0 (+ 0) (P:  1.80%) (Q:  0.10555) (D:  0.000) (U: 0.05388) (Q+U:  0.15943) (V:  -.----) 
info string f2f3  (346 ) N:       0 (+ 0) (P:  1.88%) (Q:  0.10555) (D:  0.000) (U: 0.05651) (Q+U:  0.16207) (V:  -.----) 
info string g1h3  (161 ) N:       0 (+ 0) (P:  2.20%) (Q:  0.10555) (D:  0.000) (U: 0.06599) (Q+U:  0.17154) (V:  -.----) 
info string f2f4  (351 ) N:       0 (+ 0) (P:  2.41%) (Q:  0.10555) (D:  0.000) (U: 0.07219) (Q+U:  0.17774) (V:  -.----) 
info string b1a3  (34  ) N:       0 (+ 0) (P:  2.42%) (Q:  0.10555) (D:  0.000) (U: 0.07258) (Q+U:  0.17813) (V:  -.----) 
info string h2h4  (403 ) N:       0 (+ 0) (P:  2.50%) (Q:  0.10555) (D:  0.000) (U: 0.07487) (Q+U:  0.18042) (V:  -.----) 
info string b2b4  (234 ) N:       0 (+ 0) (P:  2.57%) (Q:  0.10555) (D:  0.000) (U: 0.07711) (Q+U:  0.18267) (V:  -.----) 
info string a2a4  (207 ) N:       0 (+ 0) (P:  2.78%) (Q:  0.10555) (D:  0.000) (U: 0.08336) (Q+U:  0.18891) (V:  -.----) 
info string d2d3  (288 ) N:       0 (+ 0) (P:  3.23%) (Q:  0.10555) (D:  0.000) (U: 0.09687) (Q+U:  0.20242) (V:  -.----) 
info string h2h3  (400 ) N:       0 (+ 0) (P:  3.23%) (Q:  0.10555) (D:  0.000) (U: 0.09696) (Q+U:  0.20251) (V:  -.----) 
info string b2b3  (230 ) N:       0 (+ 0) (P:  3.46%) (Q:  0.10555) (D:  0.000) (U: 0.10387) (Q+U:  0.20942) (V:  -.----) 
info string b1c3  (36  ) N:       0 (+ 0) (P:  3.51%) (Q:  0.10555) (D:  0.000) (U: 0.10529) (Q+U:  0.21084) (V:  -.----) 
info string a2a3  (204 ) N:       0 (+ 0) (P:  3.55%) (Q:  0.10555) (D:  0.000) (U: 0.10653) (Q+U:  0.21208) (V:  -.----) 
info string c2c3  (259 ) N:       0 (+ 0) (P:  3.88%) (Q:  0.10555) (D:  0.000) (U: 0.11628) (Q+U:  0.22183) (V:  -.----) 
info string e2e3  (317 ) N:       0 (+ 0) (P:  4.20%) (Q:  0.10555) (D:  0.000) (U: 0.12603) (Q+U:  0.23158) (V:  -.----) 
info string g2g3  (374 ) N:       0 (+ 0) (P:  4.43%) (Q:  0.10555) (D:  0.000) (U: 0.13294) (Q+U:  0.23849) (V:  -.----) 
info string c2c4  (264 ) N:       0 (+ 0) (P:  4.45%) (Q:  0.10555) (D:  0.000) (U: 0.13358) (Q+U:  0.23913) (V:  -.----) 
info string g1f3  (159 ) N:       0 (+ 0) (P:  7.75%) (Q:  0.10555) (D:  0.000) (U: 0.23264) (Q+U:  0.33820) (V:  -.----) 
info string d2d4  (293 ) N:       0 (+ 0) (P:  8.93%) (Q:  0.10555) (D:  0.000) (U: 0.26798) (Q+U:  0.37354) (V:  -.----) 
info string e2e4  (322 ) N:       0 (+ 0) (P: 30.82%) (Q:  0.10555) (D:  0.000) (U: 0.92472) (Q+U:  1.03027) (V:  -.----) 
bestmove e2e4

# this PR
go nodes 1
Loading weights file from: networks/36092_70be354d727332e6fc992a4f01eaf8831a85e7ab4ad6fa695980a860722db2b6
Creating backend [blas]...
BLAS vendor: OpenBLAS.
OpenBLAS [OpenBLAS 0.3.6 DYNAMIC_ARCH NO_AFFINITY Haswell SINGLE_THREADED].
OpenBLAS found 1 Haswell core(s).
OpenBLAS using 1 core(s) for this backend.
BLAS max batch size is 256.
info depth 1 seldepth 1 time 12 nodes 1 score cp 31 hashfull 0 nps 83 tbhits 0 pv e2e4
info string g2g4  (378 ) N:       0 (+ 0) (P:  1.80%) (Q:  0.10555) (D:  0.000) (U: 0.05388) (Q+U:  0.15943) (V:  -.----) 
info string f2f3  (346 ) N:       0 (+ 0) (P:  1.88%) (Q:  0.10555) (D:  0.000) (U: 0.05640) (Q+U:  0.16195) (V:  -.----) 
info string g1h3  (161 ) N:       0 (+ 0) (P:  2.21%) (Q:  0.10555) (D:  0.000) (U: 0.06617) (Q+U:  0.17172) (V:  -.----) 
info string f2f4  (351 ) N:       0 (+ 0) (P:  2.41%) (Q:  0.10555) (D:  0.000) (U: 0.07233) (Q+U:  0.17788) (V:  -.----) 
info string b1a3  (34  ) N:       0 (+ 0) (P:  2.42%) (Q:  0.10555) (D:  0.000) (U: 0.07267) (Q+U:  0.17823) (V:  -.----) 
info string h2h4  (403 ) N:       0 (+ 0) (P:  2.50%) (Q:  0.10555) (D:  0.000) (U: 0.07496) (Q+U:  0.18051) (V:  -.----) 
info string b2b4  (234 ) N:       0 (+ 0) (P:  2.57%) (Q:  0.10555) (D:  0.000) (U: 0.07720) (Q+U:  0.18276) (V:  -.----) 
info string a2a4  (207 ) N:       0 (+ 0) (P:  2.79%) (Q:  0.10555) (D:  0.000) (U: 0.08380) (Q+U:  0.18935) (V:  -.----) 
info string d2d3  (288 ) N:       0 (+ 0) (P:  3.24%) (Q:  0.10555) (D:  0.000) (U: 0.09719) (Q+U:  0.20274) (V:  -.----) 
info string h2h3  (400 ) N:       0 (+ 0) (P:  3.24%) (Q:  0.10555) (D:  0.000) (U: 0.09723) (Q+U:  0.20279) (V:  -.----) 
info string b2b3  (230 ) N:       0 (+ 0) (P:  3.46%) (Q:  0.10555) (D:  0.000) (U: 0.10387) (Q+U:  0.20942) (V:  -.----) 
info string b1c3  (36  ) N:       0 (+ 0) (P:  3.51%) (Q:  0.10555) (D:  0.000) (U: 0.10520) (Q+U:  0.21075) (V:  -.----) 
info string a2a3  (204 ) N:       0 (+ 0) (P:  3.55%) (Q:  0.10555) (D:  0.000) (U: 0.10643) (Q+U:  0.21199) (V:  -.----) 
info string c2c3  (259 ) N:       0 (+ 0) (P:  3.88%) (Q:  0.10555) (D:  0.000) (U: 0.11646) (Q+U:  0.22201) (V:  -.----) 
info string e2e3  (317 ) N:       0 (+ 0) (P:  4.20%) (Q:  0.10555) (D:  0.000) (U: 0.12607) (Q+U:  0.23163) (V:  -.----) 
info string g2g3  (374 ) N:       0 (+ 0) (P:  4.43%) (Q:  0.10555) (D:  0.000) (U: 0.13299) (Q+U:  0.23854) (V:  -.----) 
info string c2c4  (264 ) N:       0 (+ 0) (P:  4.46%) (Q:  0.10555) (D:  0.000) (U: 0.13367) (Q+U:  0.23922) (V:  -.----) 
info string g1f3  (159 ) N:       0 (+ 0) (P:  7.77%) (Q:  0.10555) (D:  0.000) (U: 0.23301) (Q+U:  0.33856) (V:  -.----) 
info string d2d4  (293 ) N:       0 (+ 0) (P:  8.92%) (Q:  0.10555) (D:  0.000) (U: 0.26771) (Q+U:  0.37326) (V:  -.----) 
info string e2e4  (322 ) N:       0 (+ 0) (P: 30.76%) (Q:  0.10555) (D:  0.000) (U: 0.92288) (Q+U:  1.02844) (V:  -.----) 
bestmove e2e4

@ddobbelaere
Copy link
Contributor Author

ddobbelaere commented Aug 1, 2019

@Tilps I have modified the distribution of the random backend policy value.

Note that it is difficult to compare the two cases (master and this PR) with the random backend directly, as the policy distributions are different, as you mentioned on Discord. Therefore, a speed comparison with a fast GPU like @mooskagh proposed seems highly advisable and interesting at this stage.

@ddobbelaere
Copy link
Contributor Author

ddobbelaere commented Aug 1, 2019

Actually a fair comparison is possible with the uniform random backend (--backend=random --backend-opts="uniform=true"). Here are the results for lc0 benchmark over 10 runs:

nps master this PR
min 121692 121258
max 126924 126891
avg 124489 124287 (-0.16%)

It is safe to say that there is no significant regression in terms of nps.

NN evals are expected to be faster (because of dropped softmax layer), so the balance might even be positive, although speedup was not the goal of this PR in any case.

@jkormu
Copy link

jkormu commented Aug 1, 2019

No speed difference with real nets using RTX 2070.

Ten samples of goodgyal-5 (48x5), cudnn-fp16 go nodes 1000000 from startpos:

0.23.1 goodgyal-5 (48x5):
mean     12.215547 s
std       0.308873 s
min      11.480607 s
max      12.616816 s

pr912 goodgyal-5 (48x5):
mean     12.234400 s
std       0.480077 s
min      10.932873 s
max      12.564241 s

Same for net 42850 but with 100000 nodes:

0.23.1 net 42850
mean      5.508349 s
std       0.015322 s
min       5.482581 s
max       5.532932 s

pr912 net 42850
mean      5.492028 s
std       0.026364 s
min       5.421928 s
max       5.509242 s

@Tilps Tilps merged commit 7bb95ca into LeelaChessZero:master Aug 3, 2019
Tilps pushed a commit that referenced this pull request Aug 3, 2019
* Do softmax outside backend on set of legal moves.

* Remove policy softmax from blas backend.

* Remove policy softmax from CUDA backend.

* Remove policy softmax from OpenCL backend.

* Remove policy softmax from TensorFlow backend.

* Use FastExp for policy softmax calculations.

* Fix for negative exponentials.

* Revert "Fix for negative exponentials."

This reverts commit 9fb73d0.

* Fuse softmax with softmax temperature.

* Modify random backend policy value distribution.

* Comment improvements.
borg323 pushed a commit to borg323/lc0 that referenced this pull request Aug 4, 2019
* Do softmax outside backend on set of legal moves.

* Remove policy softmax from blas backend.

* Remove policy softmax from CUDA backend.

* Remove policy softmax from OpenCL backend.

* Remove policy softmax from TensorFlow backend.

* Use FastExp for policy softmax calculations.

* Fix for negative exponentials.

* Revert "Fix for negative exponentials."

This reverts commit 9fb73d0.

* Fuse softmax with softmax temperature.

* Modify random backend policy value distribution.

* Comment improvements.
@ddobbelaere ddobbelaere deleted the fix-softmax-accuracy branch August 9, 2019 12:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

8 participants