N-ary tree template metaprogramming


For learning purposes, I'm trying to implement a class similar to the Value object in micrograd. More precisely, I'm trying to implement it with template metaprogramming.

So far, I'm able to do addition and multiplication with the following code:

#include <iostream>
#include <array>

template<double N>
struct Value {
    static constexpr double data = N;

    friend std::ostream& operator<<(std::ostream& os, const Value &v) {
        os << "Value(data=" << v.data << ")";
        return os;
    }
};

template<double N, double R>
constexpr auto operator+(const Value<N> lhs, const Value<R> rhs) {
    return Value<lhs.data + rhs.data>();
}

template<double N, double R>
constexpr auto operator*(const Value<N> lhs, const Value<R> rhs) {
    return Value<lhs.data * rhs.data>();
}

int main() {
    Value<3.5> v;
    Value<2.0> w;

    std::cout << v << std::endl;
    std::cout << v + w << std::endl;
    std::cout << v * w << std::endl;
    return 0;
}

Now, the problem arises when I try to keep references to the children of each Value object. For example, I would like v + w to give me a new Value object with data equal to 5.5 and a list of children {v, w}. Furthermore, I would like to keep the children field generic and not assume a fixed number of children. In other words, I'm trying to build an N-ary tree, not a binary tree. So right now, I have:

#include <iostream>
#include <array>

template<double N, auto... Children>
struct Value {
    static constexpr double data = N;
    static constexpr Value children[sizeof...(Children)]{Children...};

    friend std::ostream& operator<<(std::ostream& os, const Value &v) {
        os << "Value(data=" << v.data << ", children=[";
        for (auto &&child : children) {
            os << child << ",";
        }
        os << "])";
        return os;
    }
};

template<double N, auto... Ns, double R, auto... Rs>
constexpr auto operator+(const Value<N, Ns...> lhs, const Value<R, Rs...> rhs) {
    return Value<lhs.data + rhs.data, lhs, rhs>();
}

template<double N, auto... Ns, double R, auto... Rs>
constexpr auto operator*(const Value<N, Ns...> lhs, const Value<R, Rs...> rhs) {
    return Value<lhs.data * rhs.data, lhs, rhs>();
}

int main() {
    Value<3.5> v;
    Value<2.0> w;

    std::cout << v << std::endl;
    std::cout << v + w << std::endl;
    std::cout << v * w << std::endl;
    return 0;
}

and this gives me the following errors:

main.cc: In instantiation of 'constexpr const Value<5.5e+0, Value<3.5e+0>(), Value<2.0e+0>()> Value<5.5e+0, Value<3.5e+0>(), Value<2.0e+0>()>::children [2]':
main.cc:11:23:   required from 'std::ostream& operator<<(std::ostream&, const Value<5.5e+0, Value<3.5e+0>(), Value<2.0e+0>()>&)'
main.cc:34:19:   required from here
main.cc:7:32: error: initializer for 'const Value<5.5e+0, Value<3.5e+0>(), Value<2.0e+0>()>' must be brace-enclosed
    7 |         static constexpr Value children[sizeof...(Children)]{Children...};
      |                                ^~~~~~~~
main.cc: In instantiation of 'constexpr const Value<7.0e+0, Value<3.5e+0>(), Value<2.0e+0>()> Value<7.0e+0, Value<3.5e+0>(), Value<2.0e+0>()>::children [2]':
main.cc:11:23:   required from 'std::ostream& operator<<(std::ostream&, const Value<7.0e+0, Value<3.5e+0>(), Value<2.0e+0>()>&)'
main.cc:35:19:   required from here
main.cc:7:32: error: initializer for 'const Value<7.0e+0, Value<3.5e+0>(), Value<2.0e+0>()>' must be brace-enclosed

Unfortunately, I don't understand the error. It seems to say that I need to put braces around Children..., but I already have them. If someone could give me feedback on the error, and maybe on how to improve the code, that would be highly appreciated.

Edit

This seems to be caused by operator<< accessing the children member. I'm not sure why, but this might give someone more experienced a lead.

Best answer

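Value<3.5> and Value<2.0> are different types, and a built-in array holds elements of a single type, so you can't put Value<3.5> and Value<2.0> in the same array. In your code the element type is not even one of those: inside the class template, the bare name Value is the injected-class-name and means the current specialization. So children is declared as an array of Value<5.5, Value<3.5>(), Value<2.0>()>, the parent's own type, and a Value<3.5> cannot initialize such an element. That is what GCC's "must be brace-enclosed" message is complaining about. You would need either an array of pointers to a polymorphic base class or a std::tuple. The latter has no runtime overhead, because every child's type is known at compile time: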

#include <iostream>
#include <tuple>

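// Requires C++20: a double and class-type objects are used as non-type template parameters.
template<double N, auto... Children>
struct Value {
    static constexpr double data = N;
    // Unlike a built-in array, a tuple can hold children of different types.
    static constexpr auto children = std::make_tuple(Children...);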

    friend std::ostream& operator<<(std::ostream& os, const Value &v) {
        os << "Value(data=" << v.data << ", children=[";
        // Expand the tuple into a pack and print each child with a fold expression.
        std::apply([&os](auto const&... child) {
            ((os << child << ","), ...);
        }, children);
        os << "])";
        return os;
    }
};

template<double N, auto... Ns, double R, auto... Rs>
constexpr auto operator+(const Value<N, Ns...> lhs, const Value<R, Rs...> rhs) {
    return Value<lhs.data + rhs.data, lhs, rhs>();
}

template<double N, auto... Ns, double R, auto... Rs>
constexpr auto operator*(const Value<N, Ns...> lhs, const Value<R, Rs...> rhs) {
    return Value<lhs.data * rhs.data, lhs, rhs>();
}

int main() {
    Value<3.5> v;
    Value<2.0> w;

    std::cout << v << std::endl;
    std::cout << v + w << std::endl;
    std::cout << v * w << std::endl;
    return 0;
}

This prints:

Value(data=3.5, children=[])
Value(data=5.5, children=[Value(data=3.5, children=[]),Value(data=2, children=[]),])
Value(data=7, children=[Value(data=3.5, children=[]),Value(data=2, children=[]),])