From cec942e3bf09a2373ea237b65defd2225cb32b48 Mon Sep 17 00:00:00 2001
From: Chris Maffeo <cmaffeo2@illinois.edu>
Date: Fri, 4 Aug 2023 17:35:47 -0500
Subject: [PATCH] Add a few bindings for Vector3_t<T>

---
 .gitmodules     |  4 +++
 CMakeLists.txt  | 11 +++++++-
 extern/pybind11 |  1 +
 src/pyarbd.cpp  | 70 +++++++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 85 insertions(+), 1 deletion(-)
 create mode 160000 extern/pybind11
 create mode 100644 src/pyarbd.cpp

diff --git a/.gitmodules b/.gitmodules
index eae98cd..3f20580 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +1,7 @@
 [submodule "extern/spdlog"]
 	path = extern/spdlog
 	url = https://github.com/gabime/spdlog.git
+[submodule "extern/pybind11"]
+	path = extern/pybind11
+	url = ../pybind11
+	branch = stable
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 2d24803..16e6d81 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,5 +1,6 @@
 ## Specify the project
-cmake_minimum_required(VERSION 3.9 FATAL_ERROR)
+# cmake_minimum_required(VERSION 3.9 FATAL_ERROR)
+cmake_minimum_required(VERSION 3.12 FATAL_ERROR) # FOR PYBIND, should be 3.15+?
 
 # option(USE_CUDA "Use CUDA" ON)
 set(USE_CUDA ON)
@@ -164,6 +165,14 @@ src/SimManager.cu
 src/useful.cu
 )
 
+if(USE_PYBIND)
+  ## https://pybind11.readthedocs.io/en/latest/compiling.html
+  add_subdirectory(extern/pybind11)
+  find_package(Python 3.6 COMPONENTS Interpreter Development REQUIRED)
+  pybind11_add_module("py${PROJECT_NAME}" MODULE src/pyarbd.cpp)
+  target_link_libraries("py${PROJECT_NAME}" PUBLIC "lib${PROJECT_NAME}")
+endif()
+
 target_link_libraries("${PROJECT_NAME}" PUBLIC "lib${PROJECT_NAME}")
 
 ## Add optional libraries
diff --git a/extern/pybind11 b/extern/pybind11
new file mode 160000
index 0000000..914c06f
--- /dev/null
+++ b/extern/pybind11
@@ -0,0 +1 @@
+Subproject commit 914c06fb252b6cc3727d0eedab6736e88a3fcb01
diff --git a/src/pyarbd.cpp b/src/pyarbd.cpp
new file mode 100644
index 0000000..3e49f05
--- /dev/null
+++ b/src/pyarbd.cpp
@@ -0,0 +1,70 @@
+#include "Types.h"
+#include "ParticlePatch.h"
+#include <pybind11/pybind11.h>
+#include <pybind11/numpy.h>
+#include <pybind11/operators.h>
+
+
+namespace py = pybind11;
+
+template<typename T>
+Vector3_t<T> array_to_vector(py::array_t<T> a) {
+    py::buffer_info buf1 = a.request();
+    if (buf1.ndim != 1)
+        throw std::runtime_error("Number of dimensions must be one");
+    if (buf1.size < 3 || buf1.size > 4)
+        throw std::runtime_error("Size of array must be 3 or 4");
+   
+    T *ptr = static_cast<T *>(buf1.ptr);
+    if (buf1.size == 3)
+	return Vector3_t<T>(ptr[0],ptr[1],ptr[2]);
+    else
+	return Vector3_t<T>(ptr[0],ptr[1],ptr[2],ptr[3]);
+}
+
+// basic types
+template<typename T>
+void declare_vector(py::module &m, const std::string &typestr) {
+    using Class = Vector3_t<T>;
+    std::string pyclass_name = std::string("Vector3_t_") + typestr;
+    py::class_<Class>(m, pyclass_name.c_str(), py::buffer_protocol(), py::dynamic_attr())
+	.def(py::init<>())
+	.def(py::init<T>())
+	.def(py::init<T, T, T>())
+	.def(py::init([](py::array_t<T> a) { return array_to_vector<T>(a); }))
+	// .def(py::init<T, T, T, T>(),[] lambda )
+	// .def(py::init<T*>())
+	.def(py::self + py::self)
+	.def(py::self * py::self)
+        // .def(py::self += py::self)
+        .def(py::self *= float())
+        .def(float() * py::self)
+        .def(py::self * float())
+        .def(-py::self)
+        .def("__repr__", &Class::to_string);
+}
+
+PYBIND11_MODULE(pyarbd, m) {
+    declare_vector<int>(m, "int");
+    declare_vector<float>(m, "float");
+    declare_vector<double>(m, "double");
+    m.attr("Vector3") = m.attr("Vector3_t_float");
+
+/*
+    py::class_<Vector_t<T>>(m, "Vector3")
+        .def(py::init<float, float, float>())
+        .def(py::init<float, float, float>())
+
+        py::class_<Patch>(m, "Patch")
+        .def(py::init<float, float, float>())
+	.def(py::self + py::self)
+	.def(py::self * py::self)
+        // .def(py::self += py::self)
+        .def(py::self *= float())
+        .def(float() * py::self)
+        .def(py::self * float())
+        .def(-py::self)
+        .def("__repr__", &Vector3::to_string);
+*/
+}
+
-- 
GitLab