Skip to content

Commit 7bb95ca

Browse files
ddobbelaere authored and Tilps committed
Fix policy softmax accuracy if masking is enabled. (#912)
* Do softmax outside backend on set of legal moves.
* Remove policy softmax from blas backend.
* Remove policy softmax from CUDA backend.
* Remove policy softmax from OpenCL backend.
* Remove policy softmax from TensorFlow backend.
* Use FastExp for policy softmax calculations.
* Fix for negative exponentials.
* Revert "Fix for negative exponentials."
  This reverts commit 9fb73d0.
* Fuse softmax with softmax temperature.
* Modify random backend policy value distribution.
* Comment improvements.
1 parent 78baefe commit 7bb95ca

File tree

7 files changed

+49
-50
lines changed

7 files changed

+49
-50
lines changed

src/mcts/search.cc

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1235,16 +1235,23 @@ void SearchWorker::FetchSingleNodeResult(NodeToProcess* node_to_process,
12351235
node_to_process->v = -computation_->GetQVal(idx_in_computation);
12361236
node_to_process->d = computation_->GetDVal(idx_in_computation);
12371237
// ...and secondly, the policy data.
1238+
// Calculate maximum first.
1239+
float max_p = -std::numeric_limits<float>::infinity();
1240+
for (auto edge : node->Edges()) {
1241+
max_p =
1242+
std::max(max_p, computation_->GetPVal(idx_in_computation,
1243+
edge.GetMove().as_nn_index()));
1244+
}
12381245
float total = 0.0;
12391246
for (auto edge : node->Edges()) {
12401247
float p =
12411248
computation_->GetPVal(idx_in_computation, edge.GetMove().as_nn_index());
1242-
if (params_.GetPolicySoftmaxTemp() != 1.0f) {
1243-
// Flush denormals to zero.
1244-
p = p < 1.17549435E-38
1245-
? 0.0
1246-
: FastPow2(FastLog2(p) / params_.GetPolicySoftmaxTemp());
1247-
}
1249+
// Perform softmax and take into account policy softmax temperature T.
1250+
// Note that we want to calculate (exp(p-max_p))^(1/T) = exp((p-max_p)/T).
1251+
p = FastExp((p - max_p) / params_.GetPolicySoftmaxTemp());
1252+
1253+
// Note that p now lies in [0, 1], so it is safe to store it in compressed
1254+
// format. Normalization happens later.
12481255
edge.edge()->SetP(p);
12491256
// Edge::SetP does some rounding, so only add to the total after rounding.
12501257
total += edge.edge()->GetP();

src/neural/blas/network_blas.cc

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -303,9 +303,8 @@ void BlasComputation::ComputeBlocking() {
303303
std::vector<float> policy(num_output_policy);
304304

305305
// Get the moves
306-
SoftmaxActivation(num_output_policy, &output_pol[j * num_output_policy],
307-
policy.data());
308-
306+
policy.assign(output_pol.begin() + j * num_output_policy,
307+
output_pol.begin() + (j + 1) * num_output_policy);
309308
policies_.emplace_back(std::move(policy));
310309
}
311310

@@ -418,8 +417,8 @@ BlasNetwork::BlasNetwork(const WeightsFile& file, const OptionsDict& options)
418417
CERR << "MKL " << versionbuf << ".";
419418
MKLVersion version;
420419
mkl_get_version(&version);
421-
CERR << "MKL platform: " << version.Platform << ", processor: "
422-
<< version.Processor << ".";
420+
CERR << "MKL platform: " << version.Platform
421+
<< ", processor: " << version.Processor << ".";
423422
CERR << "MKL can use up to " << max_procs << " thread(s).";
424423
CERR << "MKL using " << blas_cores << " thread(s) for this backend.";
425424
#endif

src/neural/cuda/network_cudnn.cc

Lines changed: 14 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@ struct InputsOutputs {
9999
ReportCUDAErrors(cudaMalloc(
100100
&op_policy_mem_gpu_, maxBatchSize * kNumOutputPolicy * sizeof(float)));
101101

102-
ReportCUDAErrors(
103-
cudaHostAlloc(&op_value_mem_, maxBatchSize * (wdl ? 3 : 1) * sizeof(float),
102+
ReportCUDAErrors(cudaHostAlloc(&op_value_mem_,
103+
maxBatchSize * (wdl ? 3 : 1) * sizeof(float),
104104
cudaHostAllocMapped));
105105
ReportCUDAErrors(
106106
cudaHostGetDevicePointer(&op_value_mem_gpu_, op_value_mem_, 0));
@@ -239,8 +239,7 @@ class CudnnNetwork : public Network {
239239
}
240240

241241
// Override if forced from backend option
242-
if (!options.IsDefault<bool>("nhwc"))
243-
nhwc_ = options.Get<bool>("nhwc");
242+
if (!options.IsDefault<bool>("nhwc")) nhwc_ = options.Get<bool>("nhwc");
244243

245244
if (nhwc_)
246245
ReportCUBLASErrors(cublasSetMathMode(cublas_, CUBLAS_TENSOR_OP_MATH));
@@ -377,10 +376,6 @@ class CudnnNetwork : public Network {
377376
policymap->LoadWeights(kConvPolicyMap, scratch_mem_);
378377

379378
network_.emplace_back(std::move(policymap));
380-
381-
auto softmaxPol =
382-
std::make_unique<SoftMaxLayer<DataType>>(getLastLayer());
383-
network_.emplace_back(std::move(softmaxPol));
384379
} else {
385380
auto convPol = std::make_unique<ConvLayer<DataType>>(
386381
resi_last_, weights.policy.biases.size(), 8, 8, 1, kNumFilters, true,
@@ -394,10 +389,6 @@ class CudnnNetwork : public Network {
394389
FCPol->LoadWeights(&weights.ip_pol_w[0], &weights.ip_pol_b[0],
395390
scratch_mem_);
396391
network_.emplace_back(std::move(FCPol));
397-
398-
auto softmaxPol =
399-
std::make_unique<SoftMaxLayer<DataType>>(getLastLayer());
400-
network_.emplace_back(std::move(softmaxPol));
401392
}
402393
policy_out_ = getLastLayer();
403394

@@ -533,39 +524,32 @@ class CudnnNetwork : public Network {
533524
scratch_mem_, scratch_size_, cudnn_,
534525
cublas_); // conv1
535526

536-
network_[l++]->Eval(batchSize, tensor_mem_[0], tensor_mem_[1], nullptr,
537-
scratch_mem_, scratch_size_, cudnn_,
538-
cublas_); // pol FC
539527
if (fp16) {
540-
// TODO: consider softmax layer that writes directly to fp32
541-
network_[l++]->Eval(batchSize, tensor_mem_[1], tensor_mem_[0], nullptr,
528+
network_[l++]->Eval(batchSize, tensor_mem_[0], tensor_mem_[1], nullptr,
542529
scratch_mem_, scratch_size_, cudnn_,
543-
cublas_); // pol softmax
544-
copyTypeConverted(opPol, (half*)(tensor_mem_[1]),
530+
cublas_); // pol FC
531+
copyTypeConverted(opPol, (half*)(tensor_mem_[0]),
545532
batchSize * kNumOutputPolicy); // POLICY
546533
} else {
547-
network_[l++]->Eval(batchSize, (DataType*)opPol, tensor_mem_[0],
534+
network_[l++]->Eval(batchSize, (DataType*)opPol, tensor_mem_[1],
548535
nullptr, scratch_mem_, scratch_size_, cudnn_,
549-
cublas_); // pol softmax // POLICY
536+
cublas_); // pol FC // POLICY
550537
}
551538
} else {
552539
network_[l++]->Eval(batchSize, tensor_mem_[0], tensor_mem_[2], nullptr,
553540
scratch_mem_, scratch_size_, cudnn_,
554541
cublas_); // pol conv
555-
network_[l++]->Eval(batchSize, tensor_mem_[1], tensor_mem_[0], nullptr,
556-
scratch_mem_, scratch_size_, cudnn_,
557-
cublas_); // pol FC
542+
558543
if (fp16) {
559-
// TODO: consider softmax layer that writes directly to fp32.
560-
network_[l++]->Eval(batchSize, tensor_mem_[0], tensor_mem_[1], nullptr,
544+
network_[l++]->Eval(batchSize, tensor_mem_[1], tensor_mem_[0], nullptr,
561545
scratch_mem_, scratch_size_, cudnn_,
562-
cublas_); // pol softmax
563-
copyTypeConverted(opPol, (half*)(tensor_mem_[0]),
546+
cublas_); // pol FC
547+
copyTypeConverted(opPol, (half*)(tensor_mem_[1]),
564548
batchSize * kNumOutputPolicy); // POLICY
565549
} else {
566-
network_[l++]->Eval(batchSize, (DataType*)opPol, tensor_mem_[1],
550+
network_[l++]->Eval(batchSize, (DataType*)opPol, tensor_mem_[0],
567551
nullptr, scratch_mem_, scratch_size_, cudnn_,
568-
cublas_); // pol softmax // POLICY
552+
cublas_); // pol FC // POLICY
569553
}
570554
}
571555

src/neural/network_random.cc

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,16 @@ class RandomNetworkComputation : public NetworkComputation {
7878

7979
float GetPVal(int sample, int move_id) const override {
8080
if (uniform_mode_) return 1.0f;
81+
82+
// Note that this function returns the policy value *before* softmax.
83+
// We choose a uniform distribution over [0, a], implying that the
84+
// proportion between the smallest and largest policy value *after* softmax
85+
// exponentiation (but before normalization) is equal to S = exp(-a).
86+
// Choosing a = 3.0 leads to S = 0.05.
87+
const float a = 3.0f;
8188
return (HashCat({inputs_[sample], static_cast<unsigned long>(move_id)}) %
82-
10000) /
83-
10000.0;
89+
10000) *
90+
(a / 10000.0f);
8491
}
8592

8693
private:
@@ -97,7 +104,8 @@ class RandomNetwork : public Network {
97104
seed_(options.GetOrDefault<int>("seed", 0)),
98105
uniform_mode_(options.GetOrDefault<bool>("uniform", false)) {}
99106
std::unique_ptr<NetworkComputation> NewComputation() override {
100-
return std::make_unique<RandomNetworkComputation>(delay_ms_, seed_, uniform_mode_);
107+
return std::make_unique<RandomNetworkComputation>(delay_ms_, seed_,
108+
uniform_mode_);
101109
}
102110

103111
private:

src/neural/network_tf.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,6 @@ std::pair<Output, Output> MakeNetwork(const Scope& scope, Input input,
144144
ip_pol_w = Reshape(scope, ip_pol_w, Const(scope, {32 * 8 * 8, 1858}));
145145
auto ip_pol_b = MakeConst(scope, {1858}, weights.ip_pol_b);
146146
auto policy_fc = Add(scope, MatMul(scope, conv_pol, ip_pol_w), ip_pol_b);
147-
auto policy_head = Softmax(scope, policy_fc);
148147

149148
// Value head
150149
auto conv_val =
@@ -163,7 +162,7 @@ std::pair<Output, Output> MakeNetwork(const Scope& scope, Input input,
163162
auto value_head =
164163
Tanh(scope, Add(scope, MatMul(scope, value_flow, ip2_val_w), ip2_val_b));
165164

166-
return {policy_head, value_head};
165+
return {policy_fc, value_head};
167166
}
168167

169168
template <bool CPU>

src/neural/opencl/network_opencl.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,12 +102,11 @@ class OpenCLComputation : public NetworkComputation {
102102
buffers_->forward(input_data, output_pol, output_val, batch_size);
103103

104104
for (size_t j = 0; j < batch_size; j++) {
105-
std::vector<float> policy(weights_.num_output_policies);
105+
std::vector<float> policy(num_output_policies);
106106

107107
// Get the moves.
108-
SoftmaxActivation(num_output_policies,
109-
&output_pol[j * num_output_policies], policy.data());
110-
108+
policy.assign(output_pol.begin() + j * num_output_policies,
109+
output_pol.begin() + (j + 1) * num_output_policies);
111110
policies_.emplace_back(std::move(policy));
112111

113112
// Now get the score.

src/utils/fastmath.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,4 +65,7 @@ inline float FastLog(const float a) {
6565
return 0.6931471805599453f * FastLog2(a);
6666
}
6767

68+
// Fast approximate exp(x). Does only limited range checking.
69+
inline float FastExp(const float a) { return FastPow2(1.442695040f * a); }
70+
6871
} // namespace lczero

0 commit comments

Comments
 (0)