1
1
# load op_wrapper
2
+ load ("@org_tensorflow//tensorflow:tensorflow.bzl" , "tf_gpu_kernel_library" , "tf_gen_op_wrapper_py" )
3
+ load ("@local_config_cuda//cuda:build_defs.bzl" , "if_cuda_is_configured" , "if_cuda" )
2
4
3
5
package (default_visibility = ["//visibility:public" ])
4
6
@@ -12,6 +14,23 @@ config_setting(
12
14
constraint_values = ["@bazel_tools//platforms:windows" ],
13
15
)
14
16
17
+ cc_library (
18
+ name = "cuda" ,
19
+ data = [
20
+ "@local_config_cuda//cuda:cudart" ,
21
+ ],
22
+ linkopts = select ({
23
+ ":windows" : [],
24
+ "//conditions:default" : [
25
+ "-Wl,-rpath,../local_config_cuda/cuda/lib64" ,
26
+ "-Wl,-rpath,../local_config_cuda/cuda/extras/CUPTI/lib64" ,
27
+ ],
28
+ }),
29
+ deps = [
30
+ "@local_config_cuda//cuda:cudart" ,
31
+ ],
32
+ )
33
+
15
34
py_library (
16
35
name = "ops" ,
17
36
srcs = ["__init__.py" ],
@@ -24,6 +43,7 @@ py_library(
24
43
":tfq_adj_grad_op_py" ,
25
44
":tfq_ps_util_ops_py" ,
26
45
":tfq_simulate_ops_py" ,
46
+ ":tfq_simulate_ops_cuda_py" ,
27
47
":tfq_unitary_op_py" ,
28
48
":tfq_utility_ops_py" ,
29
49
# test addons
@@ -619,6 +639,103 @@ py_test(
619
639
],
620
640
)
621
641
642
+ py_library (
643
+ name = "tfq_simulate_ops_cuda_py" ,
644
+ srcs = ["tfq_simulate_ops_cuda.py" ],
645
+ data = [
646
+ ":_tfq_simulate_ops_cuda.so" ,
647
+ ],
648
+ srcs_version = "PY3" ,
649
+ deps = [
650
+ # tensorflow framework for wrappers
651
+ ":load_module" ,
652
+ ],
653
+ )
654
+
655
+ py_test (
656
+ name = "tfq_simulate_ops_cuda_test" ,
657
+ srcs = ["tfq_simulate_ops_cuda_test.py" ],
658
+ deps = [
659
+ ":tfq_simulate_ops_cuda_py" ,
660
+ ":tfq_simulate_ops_py" ,
661
+ "//tensorflow_quantum/python:util" ,
662
+ ],
663
+ srcs_version = "PY3" ,
664
+ )
665
+
666
+ cc_binary (
667
+ name = "_tfq_simulate_ops_cuda.so" ,
668
+ srcs = [
669
+ "tfq_simulate_expectation_op_cuda.cu.cc" ,
670
+ ],
671
+ linkshared = 1 ,
672
+ features = select ({
673
+ ":windows" : ["windows_export_all_symbols" ],
674
+ "//conditions:default" : [],
675
+ }),
676
+ copts = select ({
677
+ ":windows" : [
678
+ "/D__CLANG_SUPPORT_DYN_ANNOTATION__" ,
679
+ "/D_USE_MATH_DEFINES" ,
680
+ "/DEIGEN_MPL2_ONLY" ,
681
+ "/DEIGEN_MAX_ALIGN_BYTES=64" ,
682
+ "/DEIGEN_HAS_TYPE_TRAITS=0" ,
683
+ "/DTF_USE_SNAPPY" ,
684
+ "/showIncludes" ,
685
+ "/MD" ,
686
+ "/O2" ,
687
+ "/DNDEBUG" ,
688
+ "/w" ,
689
+ "-DWIN32_LEAN_AND_MEAN" ,
690
+ "-DNOGDI" ,
691
+ "/d2ReducedOptimizeHugeFunctions" ,
692
+ "/arch:AVX" ,
693
+ "/std:c++17" ,
694
+ "-DTENSORFLOW_MONOLITHIC_BUILD" ,
695
+ "/DPLATFORM_WINDOWS" ,
696
+ "/DEIGEN_HAS_C99_MATH" ,
697
+ "/DTENSORFLOW_USE_EIGEN_THREADPOOL" ,
698
+ "/DEIGEN_AVOID_STL_ARRAY" ,
699
+ "/Iexternal/gemmlowp" ,
700
+ "/wd4018" ,
701
+ "/wd4577" ,
702
+ "/DNOGDI" ,
703
+ "/UTF_COMPILE_LIBRARY" ,
704
+ ],
705
+ "//conditions:default" : [
706
+ "-Iexternal/local_cuda/cuda/include" ,
707
+ # "--cuda-gpu-arch=sm_86",
708
+ # "-L/usr/local/cuda/lib64",
709
+ # "-lcudart_static",
710
+ # "-ldl",
711
+ # "-lrt",
712
+ "-pthread" ,
713
+ "-std=c++17" ,
714
+ "-D_GLIBCXX_USE_CXX11_ABI=1" ,
715
+ "-O3" ,
716
+ "-Iexternal/cuda_headers" ,
717
+ "-DNV_CUDNN_DISABLE_EXCEPTION" ,
718
+ # "-fpermissive",
719
+ ],
720
+ }) + if_cuda_is_configured (["-DTENSORFLOW_USE_NVCC=1" , "-DGOOGLE_CUDA=1" , "-x cuda" , "-nvcc_options=relaxed-constexpr" , "-nvcc_options=ftz=true" ]),
721
+ deps = [
722
+ # cirq cc proto
723
+ "//tensorflow_quantum/core/ops:parse_context" ,
724
+ "//tensorflow_quantum/core/ops:tfq_simulate_utils" ,
725
+ "//tensorflow_quantum/core/proto:pauli_sum_cc_proto" ,
726
+ "//tensorflow_quantum/core/proto:program_cc_proto" ,
727
+ "//tensorflow_quantum/core/src:circuit_parser_qsim" ,
728
+ "//tensorflow_quantum/core/src:util_qsim" ,
729
+ "@qsim//lib:qsim_cuda_lib" ,
730
+ "@eigen//:eigen3" ,
731
+ # "@local_cuda//:cuda_headers"
732
+ # tensorflow core framework
733
+ # tensorflow core lib
734
+ # tensorflow core protos
735
+ ] + if_cuda_is_configured ([":cuda" , "@local_config_cuda//cuda:cuda_headers" ]),
736
+ # alwayslink=1,
737
+ )
738
+
622
739
py_library (
623
740
name = "load_module" ,
624
741
srcs = ["load_module.py" ],
0 commit comments