Why does adding two xtensor expressions together in a template function broadcast incorrectly?


Consider the following program:

#include <iostream>
#include "xtensor/xarray.hpp"
#include "xtensor/xio.hpp"
#include "xtensor/xview.hpp"

xt::xarray<double> arr1
  {1.0, 2.0, 3.0};

xt::xarray<double> arr2
  {5.0, 6.0, 7.0};

// A plain holder for two operands; it is not an xt::xexpression
template <typename T, typename U>
struct container{
    container(const T& t, const U& u) : a(t), b(u) {}
    T a;
    U b;
};

template <typename T, typename U>
container<T, U> make_container(const T& t, const U& u){
    return container<T,U>(t, u);
}

// Intended to compute (e1.a * e1.b) + e2 element-wise
template <typename A, typename B, typename R>
auto operator+(const container<A, B>& e1, const R& e2){
    return (e1.a * e1.b) + e2;
}

int main(){
    auto c = make_container(arr1, arr1);
    std::cout << (arr1 * arr1) + arr2 << std::endl;
    std::cout << (c + arr2) << std::endl;
}

If we look at the code:

std::cout << (arr1 * arr1) + arr2;

It will output:

{  6.,  10.,  16.}

However, running the last line:

std::cout << (c + arr2);

Yields the following:

{{  6.,   9.,  14.}, {  7.,  10.,  15.}, {  8.,  11.,  16.}}

Why is this the case? To investigate, I changed the definition of operator+ to the following:

template <typename A, typename B, typename R>
auto operator+(const container<A, B>& e1, const R& e2){
    // __PRETTY_FUNCTION__ (a GCC/Clang extension) prints the deduced template arguments
    std::cout << __PRETTY_FUNCTION__ << std::endl;
    return (e1.a * e1.b) + e2;
}

And the output was a bit surprising:

auto operator+(const container<A, B> &, const R &) [A = xt::xarray_container<xt::uvector<double, std::allocator<double> >, xt::layout_type::row_major, xt::svector<unsigned long, 4, std::allocator<unsigned long>, true>, xt::xtensor_expression_tag>, B = xt::xarray_container<xt::uvector<double, std::allocator<double> >, xt::layout_type::row_major, xt::svector<unsigned long, 4, std::allocator<unsigned long>, true>, xt::xtensor_expression_tag>, R = double]
auto operator+(const container<A, B> &, const R &) [A = xt::xarray_container<xt::uvector<double, std::allocator<double> >, xt::layout_type::row_major, xt::svector<unsigned long, 4, std::allocator<unsigned long>, true>, xt::xtensor_expression_tag>, B = xt::xarray_container<xt::uvector<double, std::allocator<double> >, xt::layout_type::row_major, xt::svector<unsigned long, 4, std::allocator<unsigned long>, true>, xt::xtensor_expression_tag>, R = double]
auto operator+(const container<A, B> &, const R &) [A = xt::xarray_container<xt::uvector<double, std::allocator<double> >, xt::layout_type::row_major, xt::svector<unsigned long, 4, std::allocator<unsigned long>, true>, xt::xtensor_expression_tag>, B = xt::xarray_container<xt::uvector<double, std::allocator<double> >, xt::layout_type::row_major, xt::svector<unsigned long, 4, std::allocator<unsigned long>, true>, xt::xtensor_expression_tag>, R = double]
{{  6.,   9.,  14.}, {  7.,  10.,  15.}, {  8.,  11.,  16.}}

Why is operator+ called three times for a single addition? Is a macro defined somewhere that causes this behavior? Also, R is deduced as double, whereas I expected it to be xt::xarray<double>.

Any insights would be appreciated, thanks.

Best answer:

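The operator+ defined in namespace xt takes universal (forwarding) references, and it is found through argument-dependent lookup because arr2 belongs to namespace xt. Since c and arr2 are non-const lvalues, binding them to those forwarding references is a better match than binding them to the const references of your overload, so xt's operator+ is preferred when you write c + arr2.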

Thus this last line will return an xfunction whose first operand is your container, and the second one an xarray.

Now, since container is not an xexpression, inside the xfunction it is handled as ... an xscalar<container>!

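Thus when you access the i-th element of this xfunction, the following operation is performed: xscalar<container> + arr2[i] (the xscalar is broadcast). Since xscalar<container> is convertible to container, your operator+ overload is called with R deduced as the value_type of arr2, that is, double. Each of these calls returns a whole three-element expression, (arr1 * arr1) + arr2[i], so every element of the result is itself an array, which is why the output is nested: row i is {1, 4, 9} + arr2[i].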

The following loop illustrates this behavior:

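auto f = c + arr2;  // xt's operator+: a lazy xfunction, nothing is computed yet
for(auto iter = f.begin(); iter != f.end(); ++iter)
{
    // Dereferencing computes one element, i.e. calls your operator+(c, arr2[i])
    std::cout << *iter << std::endl;
}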

It generates the following calls of your operator+:

operator+(c, arr2[0]);
operator+(c, arr2[1]);
operator+(c, arr2[2]);

This is why your operator+ is called three times.