
Commit 4f6290a

Fix ComputeSmall and enable it.
1 parent 6644248 commit 4f6290a

File tree

1 file changed (+5, -23 lines)

tensorflow_quantum/core/ops/tfq_simulate_expectation_op_cuda.cu.cc

Lines changed: 5 additions & 23 deletions
@@ -111,7 +111,7 @@ class TfqSimulateExpectationOpCuda : public tensorflow::OpKernel {
     for (const int num : num_qubits) {
       max_num_qubits = std::max(max_num_qubits, num);
     }
-    if (max_num_qubits >= 26 || programs.size() == 1 || true) {
+    if (max_num_qubits >= 26 || programs.size() == 1) {
       ComputeLarge(num_qubits, fused_circuits, pauli_sums, context,
                    &output_tensor);
     } else {
@@ -122,7 +122,6 @@ class TfqSimulateExpectationOpCuda : public tensorflow::OpKernel {

  private:
   int num_threads_in_sim_;
-  int thread_per_block_;
   int block_count_;

   // Define the GPU implementation that launches the CUDA kernel.
@@ -135,24 +134,10 @@ class TfqSimulateExpectationOpCuda : public tensorflow::OpKernel {
     // Instantiate qsim objects.
     using Simulator = qsim::SimulatorCUDA<float>;
     using StateSpace = Simulator::StateSpace;
-    // Launch the cuda kernel. These parameters came from:
-    // 1. min/max num_threads in //third_party/qsim/tests/simulator_cuda_test.cu
-    // 2. min/max num_threads & dblocks in
-    //    //third_party/qsim/tests/statespace_cuda_test.cu
-    //    //third_party/qsim/lib/statespace_cuda.h:55-64
-    // num_dblocks has no explanation. just follow test code 2 or 16.
-    int block_count = 2;  // 2 or 16;
-    // num_threads = 2**q where q in [5..10]
-    int thread_per_block = 128;  // 32, 64, 128, 256, 512, 1024;
-    // TFQ GPU
-    StateSpace::Parameter param_ss;
-    param_ss.num_threads = thread_per_block;
-    param_ss.num_dblocks = block_count;
-
-    // Begin simulation.
+    // Begin simulation with default parameters.
     int largest_nq = 1;
     Simulator sim = Simulator();
-    StateSpace ss = StateSpace(param_ss);
+    StateSpace ss = StateSpace(StateSpace::Parameter());
     auto sv = ss.Create(largest_nq);
     auto scratch = ss.Create(largest_nq);

@@ -199,10 +184,7 @@ class TfqSimulateExpectationOpCuda : public tensorflow::OpKernel {
     using Simulator = qsim::SimulatorCUDA<float>;
     using StateSpace = Simulator::StateSpace;

-    StateSpace::Parameter param_ss;
-    param_ss.num_threads = thread_per_block_;
-    param_ss.num_dblocks = block_count_;
-
+    StateSpace::Parameter param_default;
     const int output_dim_op_size = output_tensor->dimension(1);

     Status compute_status = Status();
@@ -215,7 +197,7 @@ class TfqSimulateExpectationOpCuda : public tensorflow::OpKernel {

       // Begin simulation.
       auto sim = Simulator();
-      auto ss = StateSpace(param_ss);
+      auto ss = StateSpace(param_default);
       auto sv = ss.Create(largest_nq);
       auto scratch = ss.Create(largest_nq);
       for (int i = start; i < end; i++) {
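
For readers skimming the diff: the removed "|| true" had been forcing every batch down the ComputeLarge path, so this commit re-enables the small-circuit dispatch. Below is a minimal standalone sketch of that dispatch decision; the 26-qubit threshold and the single-program special case come straight from the diff, while the function name and the surrounding scaffolding are illustrative only, not TFQ code.

// Minimal sketch of the ComputeLarge / ComputeSmall dispatch this commit
// re-enables. Only the 26-qubit threshold and the single-program check are
// taken from the diff; everything else is illustrative scaffolding.
#include <algorithm>
#include <cstdio>
#include <vector>

// Returns true when the batch should take the ComputeLarge path: either some
// circuit is too wide for per-program fan-out, or there is only one program,
// so there is no batch-level parallelism worth sharding over.
bool UseComputeLarge(const std::vector<int>& num_qubits, int num_programs) {
  int max_num_qubits = 0;
  for (const int num : num_qubits) {
    max_num_qubits = std::max(max_num_qubits, num);
  }
  return max_num_qubits >= 26 || num_programs == 1;
}

int main() {
  std::printf("%d\n", UseComputeLarge({4, 7, 12}, 3));   // 0: ComputeSmall path
  std::printf("%d\n", UseComputeLarge({30, 7, 12}, 3));  // 1: ComputeLarge path
  std::printf("%d\n", UseComputeLarge({4}, 1));          // 1: ComputeLarge path
  return 0;
}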

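The other half of the change drops the hand-tuned CUDA launch configuration (num_threads = 128, num_dblocks = 2) and the now-unused thread_per_block_ member in favor of a default-constructed StateSpace::Parameter. The sketch below shows the new construction pattern; it assumes qsim's SimulatorCUDA/StateSpace API exactly as it appears in the diff, and the include path is a placeholder rather than the real TFQ build setup.

// Sketch only: assumes qsim's CUDA simulator API as used in the diff above.
// The header name below is a placeholder; in TFQ these types are provided by
// the qsim dependency.
#include "simulator_cuda.h"  // hypothetical path to qsim::SimulatorCUDA

using Simulator = qsim::SimulatorCUDA<float>;
using StateSpace = Simulator::StateSpace;

void CreateStateVectors(int largest_nq) {
  // Old path (removed by this commit): explicit launch configuration.
  //   StateSpace::Parameter param_ss;
  //   param_ss.num_threads = 128;  // threads per block
  //   param_ss.num_dblocks = 2;    // data blocks per launch
  //   StateSpace ss = StateSpace(param_ss);
  //
  // New path: let qsim choose its defaults.
  StateSpace ss = StateSpace(StateSpace::Parameter());
  auto sv = ss.Create(largest_nq);       // state vector sized for the widest circuit
  auto scratch = ss.Create(largest_nq);  // scratch buffer for expectation values
  (void)sv;
  (void)scratch;
}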