2 // Copyright (c) 2018-2019, Cem Bassoy, cem.bassoy@gmail.com
4 // Distributed under the Boost Software License, Version 1.0. (See
5 // accompanying file LICENSE_1_0.txt or copy at
6 // http://www.boost.org/LICENSE_1_0.txt)
8 // The authors gratefully acknowledge the support of
9 // Fraunhofer IOSB, Ettlingen, Germany
13 #ifndef BOOST_UBLAS_TENSOR_MULTIPLICATION
14 #define BOOST_UBLAS_TENSOR_MULTIPLICATION
25 /** @brief Computes the tensor-times-tensor product for q contraction modes
27 * Implements C[i1,...,ir,j1,...,js] = sum( A[i1,...,ir+q] * B[j1,...,js+q] )
29 * nc[x] = na[phia[x] ] for 1 <= x <= r
30 * nc[r+x] = nb[phib[x] ] for 1 <= x <= s
31 * na[phia[r+x]] = nb[phib[s+x]] for 1 <= x <= q
33 * @note is used in function ttt
35 * @param k zero-based recursion level starting with 0
36 * @param r number of non-contraction indices of A
37 * @param s number of non-contraction indices of B
38 * @param q number of contraction indices with q > 0
39 * @param phia pointer to the one-based permutation tuple of length q+r for A
40 * @param phib pointer to the one-based permutation tuple of length q+s for B
41 * @param c pointer to the output tensor C with rank(C)=r+s
42 * @param nc pointer to the extents of tensor C
43 * @param wc pointer to the strides of tensor C
44 * @param a pointer to the first input tensor with rank(A)=r+q
45 * @param na pointer to the extents of the first input tensor A
46 * @param wa pointer to the strides of the first input tensor A
47 * @param b pointer to the second input tensor B with rank(B)=s+q
48 * @param nb pointer to the extents of the second input tensor B
49 * @param wb pointer to the strides of the second input tensor B
52 template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
53 void ttt(SizeType const k,
54 SizeType const r, SizeType const s, SizeType const q,
55 SizeType const*const phia, SizeType const*const phib,
56 PointerOut c, SizeType const*const nc, SizeType const*const wc,
57 PointerIn1 a, SizeType const*const na, SizeType const*const wa,
58 PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
62 assert(nc[k] == na[phia[k]-1]);
63 for(size_t ic = 0u; ic < nc[k]; a += wa[phia[k]-1], c += wc[k], ++ic)
64 ttt(k+1, r, s, q, phia,phib, c, nc, wc, a, na, wa, b, nb, wb);
68 assert(nc[k] == nb[phib[k-r]-1]);
69 for(size_t ic = 0u; ic < nc[k]; b += wb[phib[k-r]-1], c += wc[k], ++ic)
70 ttt(k+1, r, s, q, phia, phib, c, nc, wc, a, na, wa, b, nb, wb);
74 assert(na[phia[k-s]-1] == nb[phib[k-r]-1]);
75 for(size_t ia = 0u; ia < na[phia[k-s]-1]; a += wa[phia[k-s]-1], b += wb[phib[k-r]-1], ++ia)
76 ttt(k+1, r, s, q, phia, phib, c, nc, wc, a, na, wa, b, nb, wb);
80 assert(na[phia[k-s]-1] == nb[phib[k-r]-1]);
81 for(size_t ia = 0u; ia < na[phia[k-s]-1]; a += wa[phia[k-s]-1], b += wb[phib[k-r]-1], ++ia)
89 /** @brief Computes the tensor-times-tensor product for q contraction modes
91 * Implements C[i1,...,ir,j1,...,js] = sum( A[i1,...,ir+q] * B[j1,...,js+q] )
93 * @note no permutation tuple is used
95 * nc[x] = na[x ] for 1 <= x <= r
96 * nc[r+x] = nb[x ] for 1 <= x <= s
97 * na[r+x] = nb[s+x] for 1 <= x <= q
99 * @note is used in function ttt
101 * @param k zero-based recursion level starting with 0
102 * @param r number of non-contraction indices of A
103 * @param s number of non-contraction indices of B
104 * @param q number of contraction indices with q > 0
105 * @param c pointer to the output tensor C with rank(C)=r+s
106 * @param nc pointer to the extents of tensor C
107 * @param wc pointer to the strides of tensor C
108 * @param a pointer to the first input tensor with rank(A)=r+q
109 * @param na pointer to the extents of the first input tensor A
110 * @param wa pointer to the strides of the first input tensor A
111 * @param b pointer to the second input tensor B with rank(B)=s+q
112 * @param nb pointer to the extents of the second input tensor B
113 * @param wb pointer to the strides of the second input tensor B
116 template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
117 void ttt(SizeType const k,
118 SizeType const r, SizeType const s, SizeType const q,
119 PointerOut c, SizeType const*const nc, SizeType const*const wc,
120 PointerIn1 a, SizeType const*const na, SizeType const*const wa,
121 PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
125 assert(nc[k] == na[k]);
126 for(size_t ic = 0u; ic < nc[k]; a += wa[k], c += wc[k], ++ic)
127 ttt(k+1, r, s, q, c, nc, wc, a, na, wa, b, nb, wb);
131 assert(nc[k] == nb[k-r]);
132 for(size_t ic = 0u; ic < nc[k]; b += wb[k-r], c += wc[k], ++ic)
133 ttt(k+1, r, s, q, c, nc, wc, a, na, wa, b, nb, wb);
137 assert(na[k-s] == nb[k-r]);
138 for(size_t ia = 0u; ia < na[k-s]; a += wa[k-s], b += wb[k-r], ++ia)
139 ttt(k+1, r, s, q, c, nc, wc, a, na, wa, b, nb, wb);
143 assert(na[k-s] == nb[k-r]);
144 for(size_t ia = 0u; ia < na[k-s]; a += wa[k-s], b += wb[k-r], ++ia)
150 /** @brief Computes the tensor-times-matrix product for the contraction mode m > 0
152 * Implements C[i1,i2,...,im-1,j,im+1,...,ip] = sum(A[i1,i2,...,im,...,ip] * B[j,im])
154 * @note is used in function ttm
156 * @param m zero-based contraction mode with 0<m<p
157 * @param r zero-based recursion level starting with p-1
158 * @param c pointer to the output tensor
159 * @param nc pointer to the extents of tensor c
160 * @param wc pointer to the strides of tensor c
161 * @param a pointer to the first input tensor
162 * @param na pointer to the extents of input tensor a
163 * @param wa pointer to the strides of input tensor a
164 * @param b pointer to the second input tensor
165 * @param nb pointer to the extents of input tensor b
166 * @param wb pointer to the strides of input tensor b
169 template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
170 void ttm(SizeType const m, SizeType const r,
171 PointerOut c, SizeType const*const nc, SizeType const*const wc,
172 PointerIn1 a, SizeType const*const na, SizeType const*const wa,
173 PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
177 ttm(m, r-1, c, nc, wc, a, na, wa, b, nb, wb);
180 for(auto i0 = 0ul; i0 < nc[0]; c += wc[0], a += wa[0], ++i0) {
183 for(auto i0 = 0ul; i0 < nc[m]; cm += wc[m], b0 += wb[0], ++i0){
186 for(auto i1 = 0ul; i1 < nb[1]; am += wa[m], b1 += wb[1], ++i1)
193 for(auto i = 0ul; i < na[r]; c += wc[r], a += wa[r], ++i)
194 ttm(m, r-1, c, nc, wc, a, na, wa, b, nb, wb);
198 /** @brief Computes the tensor-times-matrix product for the contraction mode m = 0
200 * Implements C[j,i2,...,ip] = sum(A[i1,i2,...,ip] * B[j,i1])
202 * @note is used in function ttm
204 * @note the contraction mode is fixed to m=0, so this function takes no mode parameter
205 * @param r zero-based recursion level starting with p-1
206 * @param c pointer to the output tensor
207 * @param nc pointer to the extents of tensor c
208 * @param wc pointer to the strides of tensor c
209 * @param a pointer to the first input tensor
210 * @param na pointer to the extents of input tensor a
211 * @param wa pointer to the strides of input tensor a
212 * @param b pointer to the second input tensor
213 * @param nb pointer to the extents of input tensor b
214 * @param wb pointer to the strides of input tensor b
216 template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
217 void ttm0( SizeType const r,
218 PointerOut c, SizeType const*const nc, SizeType const*const wc,
219 PointerIn1 a, SizeType const*const na, SizeType const*const wa,
220 PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
224 for(auto i = 0ul; i < na[r]; c += wc[r], a += wa[r], ++i)
225 ttm0(r-1, c, nc, wc, a, na, wa, b, nb, wb);
228 for(auto i1 = 0ul; i1 < nc[1]; c += wc[1], a += wa[1], ++i1) {
232 for(auto i0 = 0ul; i0 < nc[0]; cm += wc[0], b0 += wb[0], ++i0){
236 for(auto i1 = 0u; i1 < nb[1]; am += wa[0], b1 += wb[1], ++i1){
246 //////////////////////////////////////////////////////////////////////////////////////////
247 //////////////////////////////////////////////////////////////////////////////////////////
248 //////////////////////////////////////////////////////////////////////////////////////////
249 //////////////////////////////////////////////////////////////////////////////////////////
252 /** @brief Computes the tensor-times-vector product for the contraction mode m > 0
254 * Implements C[i1,i2,...,im-1,im+1,...,ip] = sum(A[i1,i2,...,im,...,ip] * b[im])
256 * @note is used in function ttv
258 * @param m zero-based contraction mode with 0<m<p
259 * @param r zero-based recursion level starting with p-1 for tensor A
260 * @param q zero-based recursion level starting with p-2 for tensor C (C has rank p-1)
261 * @param c pointer to the output tensor
262 * @param nc pointer to the extents of tensor c
263 * @param wc pointer to the strides of tensor c
264 * @param a pointer to the first input tensor
265 * @param na pointer to the extents of input tensor a
266 * @param wa pointer to the strides of input tensor a
267 * @param b pointer to the second input tensor
270 template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
271 void ttv( SizeType const m, SizeType const r, SizeType const q,
272 PointerOut c, SizeType const*const nc, SizeType const*const wc,
273 PointerIn1 a, SizeType const*const na, SizeType const*const wa,
278 ttv(m, r-1, q, c, nc, wc, a, na, wa, b);
281 for(auto i0 = 0u; i0 < na[0]; c += wc[0], a += wa[0], ++i0) {
282 auto c1 = c; auto a1 = a; auto b1 = b;
283 for(auto im = 0u; im < na[m]; a1 += wa[m], ++b1, ++im)
288 for(auto i = 0u; i < na[r]; c += wc[q], a += wa[r], ++i)
289 ttv(m, r-1, q-1, c, nc, wc, a, na, wa, b);
294 /** @brief Computes the tensor-times-vector product for the contraction mode m = 0
296 * Implements C[i2,...,ip] = sum(A[i1,...,ip] * b[i1])
298 * @note is used in function ttv
300 * @note the contraction mode is fixed to m=0, so this function takes no mode parameter
301 * @param r zero-based recursion level starting with p-1
302 * @param c pointer to the output tensor
303 * @param nc pointer to the extents of tensor c
304 * @param wc pointer to the strides of tensor c
305 * @param a pointer to the first input tensor
306 * @param na pointer to the extents of input tensor a
307 * @param wa pointer to the strides of input tensor a
308 * @param b pointer to the second input tensor
310 template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
311 void ttv0(SizeType const r,
312 PointerOut c, SizeType const*const nc, SizeType const*const wc,
313 PointerIn1 a, SizeType const*const na, SizeType const*const wa,
318 for(auto i = 0u; i < na[r]; c += wc[r-1], a += wa[r], ++i)
319 ttv0(r-1, c, nc, wc, a, na, wa, b);
322 for(auto i1 = 0u; i1 < na[1]; c += wc[0], a += wa[1], ++i1)
324 auto c1 = c; auto a1 = a; auto b1 = b;
325 for(auto i0 = 0u; i0 < na[0]; a1 += wa[0], ++b1, ++i0)
332 /** @brief Computes the matrix-times-vector product
334 * Implements C[i1] = sum(A[i1,i2] * b[i2]) or C[i2] = sum(A[i1,i2] * b[i1])
336 * @note is used in function ttv
338 * @param[in] m zero-based contraction mode with m=0 or m=1
339 * @param[out] c pointer to the output tensor C
340 * @param[in] nc pointer to the extents of tensor C
341 * @param[in] wc pointer to the strides of tensor C
342 * @param[in] a pointer to the first input tensor A
343 * @param[in] na pointer to the extents of input tensor A
344 * @param[in] wa pointer to the strides of input tensor A
345 * @param[in] b pointer to the second input tensor B
347 template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
348 void mtv(SizeType const m,
349 PointerOut c, SizeType const*const , SizeType const*const wc,
350 PointerIn1 a, SizeType const*const na, SizeType const*const wa,
353 // decides whether matrix multiplied with vector or vector multiplied with matrix
354 const auto o = (m == 0) ? 1 : 0;
356 for(auto io = 0u; io < na[o]; c += wc[o], a += wa[o], ++io) {
357 auto c1 = c; auto a1 = a; auto b1 = b;
358 for(auto im = 0u; im < na[m]; a1 += wa[m], ++b1, ++im)
364 /** @brief Computes the matrix-times-matrix product
366 * Implements C[i1,i3] = sum(A[i1,i2] * B[i2,i3])
368 * @note is used in function ttm
370 * @param[out] c pointer to the output tensor C
371 * @param[in] nc pointer to the extents of tensor C
372 * @param[in] wc pointer to the strides of tensor C
373 * @param[in] a pointer to the first input tensor A
374 * @param[in] na pointer to the extents of input tensor A
375 * @param[in] wa pointer to the strides of input tensor A
376 * @param[in] b pointer to the second input tensor B
377 * @param[in] nb pointer to the extents of input tensor B
378 * @param[in] wb pointer to the strides of input tensor B
380 template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
381 void mtm(PointerOut c, SizeType const*const nc, SizeType const*const wc,
382 PointerIn1 a, SizeType const*const na, SizeType const*const wa,
383 PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
386 // C(i,j) = A(i,k) * B(k,j)
388 assert(nc[0] == na[0]);
389 assert(nc[1] == nb[1]);
390 assert(na[1] == nb[0]);
392 auto cj = c; auto bj = b;
393 for(auto j = 0u; j < nc[1]; cj += wc[1], bj += wb[1], ++j) {
395 auto bk = bj; auto ak = a;
396 for(auto k = 0u; k < na[1]; ak += wa[1], bk += wb[0], ++k) {
398 auto ci = cj; auto ai = ak;
399 for(auto i = 0u; i < na[0]; ai += wa[0], ci += wc[0], ++i){
410 /** @brief Computes the inner product of two tensors
412 * Implements c = sum(A[i1,i2,...,ip] * B[i1,i2,...,ip])
414 * @note is used in function inner
416 * @param r zero-based recursion level starting with p-1
417 * @param n pointer to the extents shared by both input tensors
418 * @param a pointer to the first input tensor
419 * @param wa pointer to the strides of input tensor a
420 * @param b pointer to the second input tensor
421 * @param wb pointer to the strides of tensor b
422 * @param v previously computed value (start with v = 0).
423 * @return inner product of two tensors.
425 template <class PointerIn1, class PointerIn2, class value_t, class SizeType>
426 value_t inner(SizeType const r, SizeType const*const n,
427 PointerIn1 a, SizeType const*const wa,
428 PointerIn2 b, SizeType const*const wb,
432 for(auto i0 = 0u; i0 < n[0]; a += wa[0], b += wb[0], ++i0)
435 for(auto ir = 0u; ir < n[r]; a += wa[r], b += wb[r], ++ir)
436 v = inner(r-1, n, a, wa, b, wb, v);
441 template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
442 void outer_2x2(SizeType const pa,
443 PointerOut c, SizeType const*const , SizeType const*const wc,
444 PointerIn1 a, SizeType const*const na, SizeType const*const wa,
445 PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
451 for(auto ib1 = 0u; ib1 < nb[1]; b += wb[1], c += wc[pa+1], ++ib1) {
454 for(auto ib0 = 0u; ib0 < nb[0]; b0 += wb[0], c2 += wc[pa], ++ib0) {
458 for(auto ia1 = 0u; ia1 < na[1]; a1 += wa[1], c1 += wc[1], ++ia1) {
461 for(SizeType ia0 = 0u; ia0 < na[0]; a0 += wa[0], c0 += wc[0], ++ia0)
468 /** @brief Computes the outer product of two tensors
470 * Implements C[i1,...,ip,j1,...,jq] = A[i1,i2,...,ip] * B[j1,j2,...,jq]
472 * @note called by outer
475 * @param[in] pa number of dimensions (rank) of the first input tensor A with pa > 0
477 * @param[in] rc recursion level for C that starts with pc-1
478 * @param[out] c pointer to the output tensor
479 * @param[in] nc pointer to the extents of output tensor c
480 * @param[in] wc pointer to the strides of output tensor c
482 * @param[in] ra recursion level for A that starts with pa-1
483 * @param[in] a pointer to the first input tensor
484 * @param[in] na pointer to the extents of the first input tensor a
485 * @param[in] wa pointer to the strides of the first input tensor a
487 * @param[in] rb recursion level for B that starts with pb-1
488 * @param[in] b pointer to the second input tensor
489 * @param[in] nb pointer to the extents of the second input tensor b
490 * @param[in] wb pointer to the strides of the second input tensor b
492 template<class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
493 void outer(SizeType const pa,
494 SizeType const rc, PointerOut c, SizeType const*const nc, SizeType const*const wc,
495 SizeType const ra, PointerIn1 a, SizeType const*const na, SizeType const*const wa,
496 SizeType const rb, PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
499 for(auto ib = 0u; ib < nb[rb]; b += wb[rb], c += wc[rc], ++ib)
500 outer(pa, rc-1, c, nc, wc, ra, a, na, wa, rb-1, b, nb, wb);
502 for(auto ia = 0u; ia < na[ra]; a += wa[ra], c += wc[ra], ++ia)
503 outer(pa, rc-1, c, nc, wc, ra-1, a, na, wa, rb, b, nb, wb);
505 outer_2x2(pa, c, nc, wc, a, na, wa, b, nb, wb); //assert(ra==1 && rb==1 && rc==3);
511 /** @brief Computes the outer product with permutation tuples
513 * Implements C[i1,...,ir,j1,...,js] = A[i1,...,ir] * B[j1,...,js]
515 * nc[x] = na[phia[x]] for 1 <= x <= r
516 * nc[r+x] = nb[phib[x]] for 1 <= x <= s
518 * @note may be called by the ttt function
520 * @param k zero-based recursion level starting with 0
521 * @param r number of non-contraction indices of A
522 * @param s number of non-contraction indices of B
523 * @param phia pointer to the one-based permutation tuple of length r for A
524 * @param phib pointer to the one-based permutation tuple of length s for B
525 * @param c pointer to the output tensor C with rank(C)=r+s
526 * @param nc pointer to the extents of tensor C
527 * @param wc pointer to the strides of tensor C
528 * @param a pointer to the first input tensor with rank(A)=r
529 * @param na pointer to the extents of the first input tensor A
530 * @param wa pointer to the strides of the first input tensor A
531 * @param b pointer to the second input tensor B with rank(B)=s
532 * @param nb pointer to the extents of the second input tensor B
533 * @param wb pointer to the strides of the second input tensor B
536 template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
537 void outer(SizeType const k,
538 SizeType const r, SizeType const s,
539 SizeType const*const phia, SizeType const*const phib,
540 PointerOut c, SizeType const*const nc, SizeType const*const wc,
541 PointerIn1 a, SizeType const*const na, SizeType const*const wa,
542 PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
546 assert(nc[k] == na[phia[k]-1]);
547 for(size_t ic = 0u; ic < nc[k]; a += wa[phia[k]-1], c += wc[k], ++ic)
548 outer(k+1, r, s, phia,phib, c, nc, wc, a, na, wa, b, nb, wb);
552 assert(nc[k] == nb[phib[k-r]-1]);
553 for(size_t ic = 0u; ic < nc[k]; b += wb[phib[k-r]-1], c += wc[k], ++ic)
554 outer(k+1, r, s, phia, phib, c, nc, wc, a, na, wa, b, nb, wb);
558 assert(nc[k] == nb[phib[k-r]-1]);
559 for(size_t ic = 0u; ic < nc[k]; b += wb[phib[k-r]-1], c += wc[k], ++ic)
565 } // namespace recursive
566 } // namespace detail
568 } // namespace numeric
574 //////////////////////////////////////////////////////////////////////////////////////////
575 //////////////////////////////////////////////////////////////////////////////////////////
576 //////////////////////////////////////////////////////////////////////////////////////////
577 //////////////////////////////////////////////////////////////////////////////////////////
579 //////////////////////////////////////////////////////////////////////////////////////////
580 //////////////////////////////////////////////////////////////////////////////////////////
581 //////////////////////////////////////////////////////////////////////////////////////////
582 //////////////////////////////////////////////////////////////////////////////////////////
591 /** @brief Computes the tensor-times-vector product
594 * C[i1,i2,...,im-1,im+1,...,ip] = sum(A[i1,i2,...,im,...,ip] * b[im]) for m>1 and
595 * C[i2,...,ip] = sum(A[i1,...,ip] * b[i1]) for m=1
597 * @note calls detail::recursive::ttv, ttv0, mtv or inner
599 * @param[in] m contraction mode with 0 < m <= p
600 * @param[in] p number of dimensions (rank) of the first input tensor with p > 0
601 * @param[out] c pointer to the output tensor with rank p-1
602 * @param[in] nc pointer to the extents of tensor c
603 * @param[in] wc pointer to the strides of tensor c
604 * @param[in] a pointer to the first input tensor
605 * @param[in] na pointer to the extents of input tensor a
606 * @param[in] wa pointer to the strides of input tensor a
607 * @param[in] b pointer to the second input tensor
608 * @param[in] nb pointer to the extents of input tensor b
609 * @param[in] wb pointer to the strides of input tensor b
611 template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
612 void ttv(SizeType const m, SizeType const p,
613 PointerOut c, SizeType const*const nc, SizeType const*const wc,
614 const PointerIn1 a, SizeType const*const na, SizeType const*const wa,
615 const PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
617 static_assert( std::is_pointer<PointerOut>::value & std::is_pointer<PointerIn1>::value & std::is_pointer<PointerIn2>::value,
618 "Static error in boost::numeric::ublas::ttv: Argument types for pointers are not pointer types.");
621 throw std::length_error("Error in boost::numeric::ublas::ttv: Contraction mode must be greater than zero.");
624 throw std::length_error("Error in boost::numeric::ublas::ttv: Rank must be greater equal the modus.");
627 throw std::length_error("Error in boost::numeric::ublas::ttv: Rank must be greater than zero.");
629 if(c == nullptr || a == nullptr || b == nullptr)
630 throw std::length_error("Error in boost::numeric::ublas::ttv: Pointers shall not be null pointers.");
632 for(auto i = 0u; i < m-1; ++i)
634 throw std::length_error("Error in boost::numeric::ublas::ttv: Extents (except of dimension mode) of A and C must be equal.");
636 for(auto i = m; i < p; ++i)
638 throw std::length_error("Error in boost::numeric::ublas::ttv: Extents (except of dimension mode) of A and C must be equal.");
640 const auto max = std::max(nb[0], nb[1]);
642 throw std::length_error("Error in boost::numeric::ublas::ttv: Extent of dimension mode of A and b must be equal.");
645 if((m != 1) && (p > 2))
646 detail::recursive::ttv(m-1, p-1, p-2, c, nc, wc, a, na, wa, b);
647 else if ((m == 1) && (p > 2))
648 detail::recursive::ttv0(p-1, c, nc, wc, a, na, wa, b);
650 detail::recursive::mtv(m-1, c, nc, wc, a, na, wa, b);
651 else /*if( p == 1 )*/{
652 auto v = std::remove_pointer_t<std::remove_cv_t<PointerOut>>{};
653 *c = detail::recursive::inner(SizeType(0), na, a, wa, b, wb, v);
660 /** @brief Computes the tensor-times-matrix product
663 * C[i1,i2,...,im-1,j,im+1,...,ip] = sum(A[i1,i2,...,im,...,ip] * B[j,im]) for m>1 and
664 * C[j,i2,...,ip] = sum(A[i1,i2,...,ip] * B[j,i1]) for m=1
666 * @note calls detail::recursive::ttm or detail::recursive::ttm0
668 * @param[in] m contraction mode with 0 < m <= p
669 * @param[in] p number of dimensions (rank) of the first input tensor with p > 0
670 * @param[out] c pointer to the output tensor with rank p
671 * @param[in] nc pointer to the extents of tensor c
672 * @param[in] wc pointer to the strides of tensor c
673 * @param[in] a pointer to the first input tensor
674 * @param[in] na pointer to the extents of input tensor a
675 * @param[in] wa pointer to the strides of input tensor a
676 * @param[in] b pointer to the second input tensor
677 * @param[in] nb pointer to the extents of input tensor b
678 * @param[in] wb pointer to the strides of input tensor b
681 template <class PointerIn1, class PointerIn2, class PointerOut, class SizeType>
682 void ttm(SizeType const m, SizeType const p,
683 PointerOut c, SizeType const*const nc, SizeType const*const wc,
684 const PointerIn1 a, SizeType const*const na, SizeType const*const wa,
685 const PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
688 static_assert( std::is_pointer<PointerOut>::value & std::is_pointer<PointerIn1>::value & std::is_pointer<PointerIn2>::value,
689 "Static error in boost::numeric::ublas::ttm: Argument types for pointers are not pointer types.");
692 throw std::length_error("Error in boost::numeric::ublas::ttm: Contraction mode must be greater than zero.");
695 throw std::length_error("Error in boost::numeric::ublas::ttm: Rank must be greater equal than the specified mode.");
698 throw std::length_error("Error in boost::numeric::ublas::ttm:Rank must be greater than zero.");
700 if(c == nullptr || a == nullptr || b == nullptr)
701 throw std::length_error("Error in boost::numeric::ublas::ttm: Pointers shall not be null pointers.");
703 for(auto i = 0u; i < m-1; ++i)
705 throw std::length_error("Error in boost::numeric::ublas::ttm: Extents (except of dimension mode) of A and C must be equal.");
707 for(auto i = m; i < p; ++i)
709 throw std::length_error("Error in boost::numeric::ublas::ttm: Extents (except of dimension mode) of A and C must be equal.");
712 throw std::length_error("Error in boost::numeric::ublas::ttm: 2nd Extent of B and M-th Extent of A must be the equal.");
715 throw std::length_error("Error in boost::numeric::ublas::ttm: 1nd Extent of B and M-th Extent of C must be the equal.");
718 detail::recursive::ttm (m-1, p-1, c, nc, wc, a, na, wa, b, nb, wb);
719 else /*if (m == 1 && p > 2)*/
720 detail::recursive::ttm0( p-1, c, nc, wc, a, na, wa, b, nb, wb);
725 /** @brief Computes the tensor-times-tensor product
727 * Implements C[i1,...,ir,j1,...,js] = sum( A[i1,...,ir+q] * B[j1,...,js+q] )
729 * @note calls detail::recursive::ttt or ttm or ttv or inner or outer
731 * nc[x] = na[phia[x] ] for 1 <= x <= r
732 * nc[r+x] = nb[phib[x] ] for 1 <= x <= s
733 * na[phia[r+x]] = nb[phib[s+x]] for 1 <= x <= q
735 * @param[in] pa number of dimensions (rank) of the first input tensor a with pa > 0
736 * @param[in] pb number of dimensions (rank) of the second input tensor b with pb > 0
737 * @param[in] q number of contraction dimensions with pa >= q and pb >= q and q >= 0
738 * @param[in] phia pointer to a one-based permutation tuple of length pa for the first input tensor a
739 * @param[in] phib pointer to a one-based permutation tuple of length pb for the second input tensor b
740 * @param[out] c pointer to the output tensor with rank pa+pb-2q
741 * @param[in] nc pointer to the extents of tensor c
742 * @param[in] wc pointer to the strides of tensor c
743 * @param[in] a pointer to the first input tensor
744 * @param[in] na pointer to the extents of input tensor a
745 * @param[in] wa pointer to the strides of input tensor a
746 * @param[in] b pointer to the second input tensor
747 * @param[in] nb pointer to the extents of input tensor b
748 * @param[in] wb pointer to the strides of input tensor b
751 template <class PointerIn1, class PointerIn2, class PointerOut, class SizeType>
752 void ttt(SizeType const pa, SizeType const pb, SizeType const q,
753 SizeType const*const phia, SizeType const*const phib,
754 PointerOut c, SizeType const*const nc, SizeType const*const wc,
755 PointerIn1 a, SizeType const*const na, SizeType const*const wa,
756 PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
758 static_assert( std::is_pointer<PointerOut>::value & std::is_pointer<PointerIn1>::value & std::is_pointer<PointerIn2>::value,
759 "Static error in boost::numeric::ublas::ttm: Argument types for pointers are not pointer types.");
761 if( pa == 0 || pb == 0)
762 throw std::length_error("Error in boost::numeric::ublas::ttt: tensor order must be greater zero.");
764 if( q > pa && q > pb)
765 throw std::length_error("Error in boost::numeric::ublas::ttt: number of contraction must be smaller than or equal to the tensor order.");
768 SizeType const r = pa - q;
769 SizeType const s = pb - q;
771 if(c == nullptr || a == nullptr || b == nullptr)
772 throw std::length_error("Error in boost::numeric::ublas::ttm: Pointers shall not be null pointers.");
774 for(auto i = 0ul; i < r; ++i)
775 if( na[phia[i]-1] != nc[i] )
776 throw std::length_error("Error in boost::numeric::ublas::ttt: dimensions of lhs and res tensor not correct.");
778 for(auto i = 0ul; i < s; ++i)
779 if( nb[phib[i]-1] != nc[r+i] )
780 throw std::length_error("Error in boost::numeric::ublas::ttt: dimensions of rhs and res not correct.");
782 for(auto i = 0ul; i < q; ++i)
783 if( nb[phib[s+i]-1] != na[phia[r+i]-1] )
784 throw std::length_error("Error in boost::numeric::ublas::ttt: dimensions of lhs and rhs not correct.");
788 detail::recursive::outer(SizeType{0},r,s, phia,phib, c,nc,wc, a,na,wa, b,nb,wb);
790 detail::recursive::ttt(SizeType{0},r,s,q, phia,phib, c,nc,wc, a,na,wa, b,nb,wb);
795 /** @brief Computes the tensor-times-tensor product
797 * Implements C[i1,...,ir,j1,...,js] = sum( A[i1,...,ir+q] * B[j1,...,js+q] )
799 * @note calls detail::recursive::ttt or ttm or ttv or inner or outer
801 * nc[x] = na[x ] for 1 <= x <= r
802 * nc[r+x] = nb[x ] for 1 <= x <= s
803 * na[r+x] = nb[s+x] for 1 <= x <= q
805 * @param[in] pa number of dimensions (rank) of the first input tensor a with pa > 0
806 * @param[in] pb number of dimensions (rank) of the second input tensor b with pb > 0
807 * @param[in] q number of contraction dimensions with pa >= q and pb >= q and q >= 0
808 * @param[out] c pointer to the output tensor with rank pa+pb-2q
809 * @param[in] nc pointer to the extents of tensor c
810 * @param[in] wc pointer to the strides of tensor c
811 * @param[in] a pointer to the first input tensor
812 * @param[in] na pointer to the extents of input tensor a
813 * @param[in] wa pointer to the strides of input tensor a
814 * @param[in] b pointer to the second input tensor
815 * @param[in] nb pointer to the extents of input tensor b
816 * @param[in] wb pointer to the strides of input tensor b
819 template <class PointerIn1, class PointerIn2, class PointerOut, class SizeType>
820 void ttt(SizeType const pa, SizeType const pb, SizeType const q,
821 PointerOut c, SizeType const*const nc, SizeType const*const wc,
822 PointerIn1 a, SizeType const*const na, SizeType const*const wa,
823 PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
825 static_assert( std::is_pointer<PointerOut>::value & std::is_pointer<PointerIn1>::value & std::is_pointer<PointerIn2>::value,
826 "Static error in boost::numeric::ublas::ttm: Argument types for pointers are not pointer types.");
828 if( pa == 0 || pb == 0)
829 throw std::length_error("Error in boost::numeric::ublas::ttt: tensor order must be greater zero.");
831 if( q > pa && q > pb)
832 throw std::length_error("Error in boost::numeric::ublas::ttt: number of contraction must be smaller than or equal to the tensor order.");
835 SizeType const r = pa - q;
836 SizeType const s = pb - q;
837 SizeType const pc = r+s;
839 if(c == nullptr || a == nullptr || b == nullptr)
840 throw std::length_error("Error in boost::numeric::ublas::ttm: Pointers shall not be null pointers.");
842 for(auto i = 0ul; i < r; ++i)
844 throw std::length_error("Error in boost::numeric::ublas::ttt: dimensions of lhs and res tensor not correct.");
846 for(auto i = 0ul; i < s; ++i)
847 if( nb[i] != nc[r+i] )
848 throw std::length_error("Error in boost::numeric::ublas::ttt: dimensions of rhs and res not correct.");
850 for(auto i = 0ul; i < q; ++i)
851 if( nb[s+i] != na[r+i] )
852 throw std::length_error("Error in boost::numeric::ublas::ttt: dimensions of lhs and rhs not correct.");
854 using value_type = std::decay_t<decltype(*c)>;
859 detail::recursive::outer(pa, pc-1, c,nc,wc, pa-1, a,na,wa, pb-1, b,nb,wb);
860 else if(r == 0ul && s == 0ul)
861 *c = detail::recursive::inner(q-1, na, a,wa, b,wb, value_type(0) );
863 detail::recursive::ttt(SizeType{0},r,s,q, c,nc,wc, a,na,wa, b,nb,wb);
867 /** @brief Computes the inner product of two tensors
869 * Implements c = sum(A[i1,i2,...,ip] * B[i1,i2,...,ip])
871 * @note calls detail::recursive::inner
873 * @param[in] p number of dimensions (rank) of the first input tensor with p > 0
874 * @param[in] n pointer to the extents shared by both input tensors
875 * @param[in] a pointer to the first input tensor
876 * @param[in] wa pointer to the strides of input tensor a
877 * @param[in] b pointer to the second input tensor
878 * @param[in] wb pointer to the strides of input tensor b
879 * @param[in] v initial value
881 * @return inner product of two tensors.
883 template <class PointerIn1, class PointerIn2, class value_t, class SizeType>
884 auto inner(const SizeType p, SizeType const*const n,
885 const PointerIn1 a, SizeType const*const wa,
886 const PointerIn2 b, SizeType const*const wb,
889 static_assert( std::is_pointer<PointerIn1>::value && std::is_pointer<PointerIn2>::value,
890 "Static error in boost::numeric::ublas::inner: Argument types for pointers must be pointer types.");
892 throw std::length_error("Error in boost::numeric::ublas::inner: Rank must be greater than zero.");
893 if(a == nullptr || b == nullptr)
894 throw std::length_error("Error in boost::numeric::ublas::inner: Pointers shall not be null pointers.");
896 return detail::recursive::inner(p-1, n, a, wa, b, wb, v);
901 /** @brief Computes the outer product of two tensors
903 * Implements C[i1,...,ip,j1,...,jq] = A[i1,i2,...,ip] * B[j1,j2,...,jq]
905 * @note calls detail::recursive::outer
907 * @param[out] c pointer to the output tensor
908 * @param[in] pc number of dimensions (rank) of the output tensor c with pc = pa + pb
909 * @param[in] nc pointer to the extents of output tensor c
910 * @param[in] wc pointer to the strides of output tensor c
911 * @param[in] a pointer to the first input tensor
912 * @param[in] pa number of dimensions (rank) of the first input tensor a with pa >= 2
913 * @param[in] na pointer to the extents of the first input tensor a
914 * @param[in] wa pointer to the strides of the first input tensor a
915 * @param[in] b pointer to the second input tensor
916 * @param[in] pb number of dimensions (rank) of the second input tensor b with pb >= 2
917 * @param[in] nb pointer to the extents of the second input tensor b
918 * @param[in] wb pointer to the strides of the second input tensor b
920 template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
921 void outer(PointerOut c, SizeType const pc, SizeType const*const nc, SizeType const*const wc,
922 const PointerIn1 a, SizeType const pa, SizeType const*const na, SizeType const*const wa,
923 const PointerIn2 b, SizeType const pb, SizeType const*const nb, SizeType const*const wb)
925 static_assert( std::is_pointer<PointerIn1>::value & std::is_pointer<PointerIn2>::value & std::is_pointer<PointerOut>::value,
926 "Static error in boost::numeric::ublas::outer: argument types for pointers must be pointer types.");
927 if(pa < 2u || pb < 2u)
928 throw std::length_error("Error in boost::numeric::ublas::outer: number of extents of lhs and rhs tensor must be equal or greater than two.");
930 throw std::length_error("Error in boost::numeric::ublas::outer: number of extents of lhs plus rhs tensor must be equal to the number of extents of C.");
931 if(a == nullptr || b == nullptr || c == nullptr)
932 throw std::length_error("Error in boost::numeric::ublas::outer: pointers shall not be null pointers.");
934 detail::recursive::outer(pa, pc-1, c, nc, wc, pa-1, a, na, wa, pb-1, b, nb, wb);