]>
Commit | Line | Data |
---|---|---|
7c673cae FG |
1 | // Copyright Jim Bosch & Ankit Daftery 2010-2012. |
2 | // Copyright Stefan Seefeld 2016. | |
3 | // Distributed under the Boost Software License, Version 1.0. | |
4 | // (See accompanying file LICENSE_1_0.txt or copy at | |
5 | // http://www.boost.org/LICENSE_1_0.txt) | |
6 | ||
// Boost.Python (which pulls in Python.h) is kept ahead of the standard
// headers, since Python.h is meant to be included before them.
#include <boost/python/numpy.hpp>
#include <boost/mpl/vector.hpp>
#include <boost/mpl/vector_c.hpp>
#include <complex>
#include <cstddef>
10 | ||
11 | namespace p = boost::python; | |
12 | namespace np = boost::python::numpy; | |
13 | ||
14 | struct ArrayFiller | |
15 | { | |
16 | ||
17 | typedef boost::mpl::vector< short, int, float, std::complex<double> > TypeSequence; | |
18 | typedef boost::mpl::vector_c< int, 1, 2 > DimSequence; | |
19 | ||
20 | explicit ArrayFiller(np::ndarray const & arg) : argument(arg) {} | |
21 | ||
22 | template <typename T, int N> | |
23 | void apply() const | |
24 | { | |
25 | if (N == 1) | |
26 | { | |
27 | char * p = argument.get_data(); | |
28 | int stride = argument.strides(0); | |
29 | int size = argument.shape(0); | |
30 | for (int n = 0; n != size; ++n, p += stride) | |
31 | *reinterpret_cast<T*>(p) = static_cast<T>(n); | |
32 | } | |
33 | else | |
34 | { | |
35 | char * row_p = argument.get_data(); | |
36 | int row_stride = argument.strides(0); | |
37 | int col_stride = argument.strides(1); | |
38 | int rows = argument.shape(0); | |
39 | int cols = argument.shape(1); | |
40 | int i = 0; | |
41 | for (int n = 0; n != rows; ++n, row_p += row_stride) | |
42 | { | |
43 | char * col_p = row_p; | |
44 | for (int m = 0; m != cols; ++i, ++m, col_p += col_stride) | |
45 | *reinterpret_cast<T*>(col_p) = static_cast<T>(i); | |
46 | } | |
47 | } | |
48 | } | |
49 | ||
50 | np::ndarray argument; | |
51 | }; | |
52 | ||
53 | void fill(np::ndarray const & arg) | |
54 | { | |
55 | ArrayFiller filler(arg); | |
56 | np::invoke_matching_array<ArrayFiller::TypeSequence, ArrayFiller::DimSequence >(arg, filler); | |
57 | } | |
58 | ||
59 | BOOST_PYTHON_MODULE(templates_ext) | |
60 | { | |
61 | np::initialize(); | |
62 | p::def("fill", fill); | |
63 | } |