Why do these c++ coroutines correctly optimize away in one case but not the other?

89 Views Asked by At

I have this a set of coroutines (generators) that I call from func1 and func2. The question is: While the compiler correctly optimizes away the coroutines in func1, producing a constant, why does it fail to do so with func2?

https://godbolt.org/z/GEndxb66M

#include <coroutine>
#include <exception>
#include <numeric>

template <typename T> struct generator 
{
    struct promise_type 
    {
        generator           get_return_object()         { return generator{this}; };
        std::suspend_always yield_value(T value)        { current_value_ = value; return {}; }
        std::suspend_always initial_suspend()           { return {}; }
        std::suspend_always final_suspend() noexcept    { return {}; }
        void                return_void()               {}
        void                unhandled_exception()       { std::terminate(); }

        T const&            get_value() const           { return current_value_; }

    private:
        T                   current_value_;
    };

    struct iterator 
    {
        iterator(std::coroutine_handle<promise_type> coro, bool done) : 
            coro_{coro}, 
            done_{done} 
        {}

        iterator& operator++() 
        {
            coro_.resume();
            done_ = coro_.done();
            return *this;
        }

        bool        operator==(iterator const& rhs) const { return done_ == rhs.done_; }
        bool        operator!=(iterator const& rhs) const { return !(*this == rhs); }
        T const&    operator*() const { return coro_.promise().get_value(); }
        T const*    operator->() const { return &(operator*()); }

    private:
        std::coroutine_handle<promise_type> coro_;
        bool                                done_;
    };

    iterator begin() const { promise_.resume(); return {promise_, promise_.done()}; }
    iterator end() const { return {promise_, true}; }

    generator(generator const&) = delete;
    generator(generator&& rhs) : promise_{rhs.promise_} { rhs.promise_ = nullptr; }
    ~generator() { if (promise_) promise_.destroy(); }

private:
    std::coroutine_handle<promise_type> promise_;

    explicit generator(promise_type *promise) : 
        promise_{std::coroutine_handle<promise_type>::from_promise(*promise)}
    {}
};

template <typename T>
static generator<T> seq() noexcept
{
    for (T i = {};; ++i)
        co_yield i;
}

template <typename T>
static generator<T> take_until(generator<T> const& g, T limit) noexcept
{
    for (auto&& v: g)
        if (v < limit) 
            co_yield v;
        else
            break;
}

template <typename T>
static generator<T> multiply(generator<T> const& g, T factor) noexcept
{
    for (auto&& v: g)
        co_yield v * factor;
}

template <typename T>
static generator<T> add(generator<T> const& g, T adder) noexcept
{
    for (auto&& v: g)
        co_yield v + adder;
}

template <typename T>
static generator<T> all(generator<T> const& g, T limit, T factor, T adder) noexcept
{
    for (auto&& v: add(multiply(take_until(g, limit), factor), adder))
        co_yield v;
}

int func1() noexcept
{
    auto s = seq<int>();
    auto t = take_until(s, 10);
    auto m = multiply(t, 2);
    auto a = add(m, 110);
    return std::accumulate(a.begin(), a.end(), 0);
}

int func2() noexcept
{
    auto a = all(seq<int>(), 10, 2, 110);
    return std::accumulate(a.begin(), a.end(), 0);
}

and the assembly showing the difference between the two: (clang 17.0.1 -std=c++23 -O3 -stdlib=libc++)

func1():                              # @func1()
        mov     eax, 1190
        ret
func2():                              # @func2()
        push    r14
        push    rbx
        sub     rsp, 184
        lea     rax, [rip + generator<int> seq<int>() [clone .resume]]
        mov     qword ptr [rsp + 8], rax
        lea     rax, [rip + generator<int> seq<int>() [clone .cleanup]]
        mov     qword ptr [rsp + 16], rax
        mov     byte ptr [rsp + 32], 0
        lea     rax, [rip + generator<int> all<int>(generator<int> const&, int, int, int) [clone .resume]]
        mov     qword ptr [rsp + 40], rax
        lea     rax, [rip + generator<int> all<int>(generator<int> const&, int, int, int) [clone .cleanup]]
        mov     qword ptr [rsp + 48], rax
        movabs  rax, 472446402562
        mov     qword ptr [rsp + 168], rax
        mov     dword ptr [rsp + 60], 10
        mov     rax, rsp
        mov     qword ptr [rsp + 160], rax
        mov     byte ptr [rsp + 176], 0
        lea     rdi, [rsp + 40]
        call    generator<int> all<int>(generator<int> const&, int, int, int) [clone .resume]
        cmp     qword ptr [rsp + 40], 0
        je      .LBB1_1
        xor     ebx, ebx
        lea     r14, [rsp + 40]
.LBB1_3:                                # =>This Inner Loop Header: Depth=1
        add     ebx, dword ptr [rsp + 56]
        mov     rdi, r14
        call    generator<int> all<int>(generator<int> const&, int, int, int) [clone .resume]
        cmp     qword ptr [rsp + 40], 0
        jne     .LBB1_3
        jmp     .LBB1_4
.LBB1_1:
        xor     ebx, ebx
.LBB1_4:
        mov     eax, ebx
        add     rsp, 184
        pop     rbx
        pop     r14
        ret
generator<int> seq<int>() [clone .resume]:       # @generator<int> seq<int>() [clone .resume]
        mov     eax, dword ptr [rdi + 20]
        inc     eax
        xor     ecx, ecx
        cmp     byte ptr [rdi + 24], 0
        cmovne  ecx, eax
        mov     dword ptr [rdi + 20], ecx
        mov     dword ptr [rdi + 16], ecx
        mov     byte ptr [rdi + 24], 1
        ret
generator<int> seq<int>() [clone .cleanup]:      # @generator<int> seq<int>() [clone .cleanup]
        ret
generator<int> take_until<int>(generator<int> const&, int) [clone .resume]: # @generator<int> take_until<int>(generator<int> const&, int) [clone .resume]
        push    rbx
        mov     rbx, rdi
        cmp     byte ptr [rdi + 40], 0
        je      .LBB4_1
        mov     rdi, qword ptr [rbx + 32]
        call    qword ptr [rdi]
        mov     rax, qword ptr [rbx + 32]
        cmp     qword ptr [rax], 0
        jne     .LBB4_4
        jmp     .LBB4_10
.LBB4_1:
        mov     rax, qword ptr [rbx + 24]
        mov     rdi, qword ptr [rax]
        call    qword ptr [rdi]
        mov     rax, qword ptr [rbx + 24]
        mov     rax, qword ptr [rax]
        mov     qword ptr [rbx + 32], rax
        cmp     qword ptr [rax], 0
        je      .LBB4_10
.LBB4_4:
        mov     eax, dword ptr [rax + 16]
        cmp     eax, dword ptr [rbx + 20]
        jge     .LBB4_10
        mov     dword ptr [rbx + 16], eax
        mov     byte ptr [rbx + 40], 1
        pop     rbx
        ret
.LBB4_10:
        mov     qword ptr [rbx], 0
        pop     rbx
        ret
        mov     rdi, rax
        call    __cxa_begin_catch@PLT
        call    std::terminate()@PLT
        mov     rdi, rax
        call    __cxa_begin_catch@PLT
        call    std::terminate()@PLT
generator<int> take_until<int>(generator<int> const&, int) [clone .cleanup]: # @generator<int> take_until<int>(generator<int> const&, int) [clone .cleanup]
        ret
generator<int> multiply<int>(generator<int> const&, int) [clone .resume]: # @generator<int> multiply<int>(generator<int> const&, int) [clone .resume]
        push    rbx
        mov     rbx, rdi
        cmp     byte ptr [rdi + 40], 0
        je      .LBB6_1
        mov     rdi, qword ptr [rbx + 32]
        call    qword ptr [rdi]
        mov     rax, qword ptr [rbx + 32]
        cmp     qword ptr [rax], 0
        je      .LBB6_9
.LBB6_4:
        mov     eax, dword ptr [rax + 16]
        imul    eax, dword ptr [rbx + 20]
        mov     dword ptr [rbx + 16], eax
        mov     byte ptr [rbx + 40], 1
        pop     rbx
        ret
.LBB6_1:
        mov     rax, qword ptr [rbx + 24]
        mov     rdi, qword ptr [rax]
        call    qword ptr [rdi]
        mov     rax, qword ptr [rbx + 24]
        mov     rax, qword ptr [rax]
        mov     qword ptr [rbx + 32], rax
        cmp     qword ptr [rax], 0
        jne     .LBB6_4
.LBB6_9:
        mov     qword ptr [rbx], 0
        pop     rbx
        ret
        mov     rdi, rax
        call    __cxa_begin_catch@PLT
        call    std::terminate()@PLT
        mov     rdi, rax
        call    __cxa_begin_catch@PLT
        call    std::terminate()@PLT
generator<int> multiply<int>(generator<int> const&, int) [clone .cleanup]: # @generator<int> multiply<int>(generator<int> const&, int) [clone .cleanup]
        ret
generator<int> add<int>(generator<int> const&, int) [clone .resume]: # @generator<int> add<int>(generator<int> const&, int) [clone .resume]
        push    rbx
        mov     rbx, rdi
        cmp     byte ptr [rdi + 40], 0
        je      .LBB8_1
        mov     rdi, qword ptr [rbx + 32]
        call    qword ptr [rdi]
        mov     rax, qword ptr [rbx + 32]
        cmp     qword ptr [rax], 0
        je      .LBB8_9
.LBB8_4:
        mov     eax, dword ptr [rax + 16]
        add     eax, dword ptr [rbx + 20]
        mov     dword ptr [rbx + 16], eax
        mov     byte ptr [rbx + 40], 1
        pop     rbx
        ret
.LBB8_1:
        mov     rax, qword ptr [rbx + 24]
        mov     rdi, qword ptr [rax]
        call    qword ptr [rdi]
        mov     rax, qword ptr [rbx + 24]
        mov     rax, qword ptr [rax]
        mov     qword ptr [rbx + 32], rax
        cmp     qword ptr [rax], 0
        jne     .LBB8_4
.LBB8_9:
        mov     qword ptr [rbx], 0
        pop     rbx
        ret
        mov     rdi, rax
        call    __cxa_begin_catch@PLT
        call    std::terminate()@PLT
        mov     rdi, rax
        call    __cxa_begin_catch@PLT
        call    std::terminate()@PLT
generator<int> add<int>(generator<int> const&, int) [clone .cleanup]: # @generator<int> add<int>(generator<int> const&, int) [clone .cleanup]
        ret
generator<int> all<int>(generator<int> const&, int, int, int) [clone .resume]: # @generator<int> all<int>(generator<int> const&, int, int, int) [clone .resume]
        push    rbx
        sub     rsp, 64
        mov     rbx, rdi
        cmp     byte ptr [rdi + 136], 0
        je      .LBB10_1
        cmp     byte ptr [rbx + 112], 0
        je      .LBB10_6
        mov     rdi, qword ptr [rbx + 104]
        call    qword ptr [rdi]
        mov     rax, qword ptr [rbx + 104]
        cmp     qword ptr [rax], 0
        jne     .LBB10_2
.LBB10_9:
        mov     qword ptr [rbx + 72], 0
        jmp     .LBB10_10
.LBB10_1:
        lea     rax, [rbx + 24]
        mov     ecx, dword ptr [rbx + 132]
        mov     edx, dword ptr [rbx + 128]
        mov     esi, dword ptr [rbx + 20]
        mov     rdi, qword ptr [rbx + 120]
        lea     r8, [rip + generator<int> take_until<int>(generator<int> const&, int) [clone .resume]]
        mov     qword ptr [rsp + 16], r8
        lea     r8, [rip + generator<int> take_until<int>(generator<int> const&, int) [clone .cleanup]]
        mov     qword ptr [rsp + 24], r8
        mov     dword ptr [rsp + 36], esi
        mov     qword ptr [rsp + 40], rdi
        lea     rsi, [rsp + 16]
        mov     qword ptr [rsp + 8], rsi
        mov     byte ptr [rsp + 56], 0
        lea     rsi, [rip + generator<int> multiply<int>(generator<int> const&, int) [clone .resume]]
        mov     qword ptr [rbx + 24], rsi
        lea     rsi, [rip + generator<int> multiply<int>(generator<int> const&, int) [clone .cleanup]]
        mov     qword ptr [rbx + 32], rsi
        mov     dword ptr [rbx + 44], edx
        lea     rdx, [rsp + 8]
        mov     qword ptr [rbx + 48], rdx
        mov     qword ptr [rsp], rax
        mov     byte ptr [rbx + 64], 0
        lea     rax, [rip + generator<int> add<int>(generator<int> const&, int) [clone .resume]]
        mov     qword ptr [rbx + 72], rax
        lea     rax, [rip + generator<int> add<int>(generator<int> const&, int) [clone .cleanup]]
        mov     qword ptr [rbx + 80], rax
        mov     dword ptr [rbx + 92], ecx
        mov     rax, rsp
        mov     qword ptr [rbx + 96], rax
        mov     byte ptr [rbx + 112], 0
        mov     rdi, qword ptr [rsp]
        call    qword ptr [rdi]
        jmp     .LBB10_7
.LBB10_6:
        mov     rax, qword ptr [rbx + 96]
        mov     rdi, qword ptr [rax]
        call    qword ptr [rdi]
.LBB10_7:
        mov     rax, qword ptr [rbx + 96]
        mov     rax, qword ptr [rax]
        mov     qword ptr [rbx + 104], rax
        cmp     qword ptr [rax], 0
        je      .LBB10_9
.LBB10_2:
        mov     eax, dword ptr [rax + 16]
        add     eax, dword ptr [rbx + 92]
        mov     dword ptr [rbx + 88], eax
        mov     byte ptr [rbx + 112], 1
        cmp     qword ptr [rbx + 72], 0
        je      .LBB10_10
        mov     eax, dword ptr [rbx + 88]
        mov     dword ptr [rbx + 16], eax
        mov     byte ptr [rbx + 136], 1
        add     rsp, 64
        pop     rbx
        ret
.LBB10_10:
        mov     qword ptr [rbx], 0
        add     rsp, 64
        pop     rbx
        ret
        mov     rdi, rax
        call    __cxa_begin_catch@PLT
        call    std::terminate()@PLT
        mov     rdi, rax
        call    __cxa_begin_catch@PLT
        call    std::terminate()@PLT
        mov     rdi, rax
        call    __cxa_begin_catch@PLT
        call    std::terminate()@PLT
generator<int> all<int>(generator<int> const&, int, int, int) [clone .cleanup]: # @generator<int> all<int>(generator<int> const&, int, int, int) [clone .cleanup]
        ret
DW.ref.__gxx_personality_v0:
        .quad   __gxx_personality_v0
0

There are 0 best solutions below