Skip to content
Snippets Groups Projects
Commit fcf88bc2 authored by cmaffeo2's avatar cmaffeo2
Browse files

Bindings to create numpy view of Array<Vector_t<T>> with obj.as_array()

parent f46a7112
No related branches found
No related tags found
No related merge requests found
......@@ -333,6 +333,8 @@ struct Array {
}
#endif
HOST DEVICE size_t size() const { return num; }
HOST T* get_pointer() const { return values; }
private:
HOST void host_allocate() {
......
......@@ -3,6 +3,8 @@
#include <pybind11/numpy.h>
#include <pybind11/operators.h>
#include <iostream>
namespace py = pybind11;
/// Convert a NumPy array to a Vector3_t object.
......@@ -27,13 +29,14 @@ Vector3_t<T> array_to_vector(py::array_t<T> a) {
return Vector3_t<T>(ptr[0],ptr[1],ptr[2],ptr[3]);
}
/// Convert a NumPy array to a Vector3_t object.
/// Convert a NumPy array to an Array<Vector3_t> object.
///
/// This function converts a 1D NumPy array of size 3 or 4 into a Vector3_t object.
/// This function converts a 2D NumPy array of shape [N,M] with M in (3,4) into an Array<Vector3_t<T> object.
///
/// \tparam T - The data type of the elements in the NumPy array.
/// \param a - The NumPy array to convert.
/// \return A Array<Vector3_t<T>> object created from the array.
/// \return An Array<Vector3_t<T>> object created from the array.
template<typename T>
Array<Vector3_t<T>> array_to_vector_arr(py::array_t<T> a) {
py::buffer_info info = a.request();
......@@ -45,11 +48,13 @@ Array<Vector3_t<T>> array_to_vector_arr(py::array_t<T> a) {
T *ptr = static_cast<T *>(info.ptr);
Array<Vector3_t<T>> arr(static_cast<size_t>(info.shape[0]));
assert( info.strides[0] == info.shape[1] );
// std::cerr << "Shape : " << info.shape[0] << " " << info.shape[1] << std::endl;
// std::cerr << "Stride : " << info.strides[0] << " " << info.strides[1] << std::endl;
if (info.shape[1] == 3) {
for (size_t i = 0; i < info.shape[0]; ++i) {
size_t j = i*info.strides[0];
size_t j = i*info.shape[1];
arr[i] = (Vector3_t<T>( ptr[j], ptr[j+1], ptr[j+2] ));
// std::cerr << arr[i].to_string() << std::endl;
}
} else {
for (size_t i = 0; i < info.shape[0]; ++i) {
......@@ -57,9 +62,81 @@ Array<Vector3_t<T>> array_to_vector_arr(py::array_t<T> a) {
arr[i] = (Vector3_t<T>( ptr[j], ptr[j+1], ptr[j+2], ptr[j+3] ));
}
}
// std::cerr << "Created Array<Vector3_t<T>> @" << &arr <<
// " with data @" << arr.get_pointer() << std::endl;
return arr;
}
/// Convert am Array<Vector3_t> object to a NumPy array.
///
/// This function converts an Array<Vector3_t<T> object into a 2D NumPy array of shape [N,M] with M in (3,4).
///
/// \tparam T - The data type of the elements in the arrays.
/// \param a - The Array<Vector3_t<T>> object to convert.
/// \return A NumPy object created from the array.
template<typename T>
auto vector_arr_to_numpy_array(Array<Vector3_t<T>>& inp) {
// NOTE: this may not work on all architectures; reinterpret_cast is dangerous!
// Create a Python object that will free the allocated
// memory when destroyed:
T* ptr = reinterpret_cast<T*>(inp.get_pointer());
// unsigned char* cptr = reinterpret_cast<unsigned char*>(inp.get_pointer());
// Vector3_t<T> tmp(1);
// assert( reinterpret_cast<unsigned char*>(&(tmp.x)) - reinterpret_cast<unsigned char*>(&(tmp)) == 0 );
// cptr += reinterpret_cast<unsigned char*>(&(tmp.x)) - reinterpret_cast<unsigned char*>(&(tmp));
// T* ptr = reinterpret_cast<T*>(cptr);
// assert(&(inp[0]) != inp.get_pointer());
// std::cerr << "addr: " << inp.get_pointer() << " (" <<
// ptr << ")" << std::endl;
py::capsule free_when_done(ptr, [](void *f) {
/* Don't do anything
T *ptr = reinterpret_cast<T *>(f);
std::cerr << "Element [0] = " << ptr[0] << "\n";
std::cerr << "not freeing memory @ " << f << "\n";
delete[] ptr;
*/
});
// std::cerr << "Printing array" << std::endl;
// for (size_t i = 0; i < inp.size(); ++i) {
// std::cerr << inp[i].to_string() << std::endl;
// }
// std::cerr << " done printing" << std::endl;
// std::cerr << "Printing raw data" << std::endl;
// for (size_t i = 0; i < inp.size(); ++i) {
// for (size_t j = i*4; j < 4*i+4; ++j) {
// std::cerr << ptr[j] << " ";
// }
// std::cerr << std::endl;
// }
// std::cerr << " done printing" << std::endl;
// std::vector<size_t> shape;
// shape.push_back(inp.size());
// shape.push_back(4);
py::array::ShapeContainer shape = {inp.size(), std::size_t{4}};
py::array::StridesContainer strides = {sizeof(Vector3_t<T>), sizeof(T)};
// std::cerr << "sizeof(T) " << sizeof(T) << std::endl;
assert( sizeof(Vector3_t<T>) == 4 * sizeof(T) );
auto a = py::template array_t<T>(
shape, //shape
strides, // strides
ptr, // data pointer
free_when_done); // numpy array references this parent
return a;
}
/// Declare Python bindings for Vector3_t<some_type>
///
......@@ -120,6 +197,7 @@ void declare_vector_array(py::module &m, const std::string &typestr) {
// .def(-py::self)
// Conversions
.def("as_array", [](Array<Vector3_t<T>>& a) { return vector_arr_to_numpy_array<T>(a); })
// .def("__repr__", &Class::to_string)
;
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment