C++20協程解糖 - 動手實現協程1 - Future和Promise


std::future和promise在C++20里面沒法直接用的唯一原因就是不支持then,雖然MSVC有一個弱智版開線程阻塞實現的future.then,能then了但不保序,而且libstdc++也用不了。folly之類的庫有靠譜的實現,但是功能太齊全太復雜,不適合新手學習。因此我們先從弱智版future promise schedular開始,從源頭講解如何實現協程相關設施。

如果你看到這行文字,說明這篇文章被無恥的盜用了(或者你正在選中文字),請前往 cnblogs.com/pointer-smq 支持原作者,謝謝

基本結構

我們要實現的功能很簡單:

  1. 單線程模型
  2. promise是入口,future是出口
  3. promise支持set_result, get_future
  4. future支持add_finish_callback,在promise set_result之后按序調用
  5. callback在下次調度時調用而不是立即調用

結構大概是這樣

image

狀態都存在shared_state里面,future和promise實際上只是個空殼


先搭框架

如果你看到這行文字,說明這篇文章被無恥的盜用了(或者你正在選中文字),請前往 cnblogs.com/pointer-smq 支持原作者,謝謝

首先是shared_state,最直觀的,shared_state需要存儲最終設置的結果T,以及記錄結果有沒有設置。這里要求T必須支持默認構造,省事。


template<class T>
class SharedState {
    friend class Future<T>;
    friend class Promise<T>;
public:
    SharedState()
    {}
    SharedState(const SharedState&) = delete;
    SharedState(SharedState&&) = delete;
    SharedState& operator=(const SharedState&) = delete;
    SharedState& operator=(SharedState&&) = delete;

private:
    template<class U>
    void set(U&& v) {
        if (settled) {
            return;
        }
        settled = true;
        value = std::forward<U>(v);
    }

    bool settled = false;
    T value;
};


然后是引用了shared_state的promise和future

template<class T>
class Promise {
public:
    Promise()
        : _state(std::make_shared<SharedState<T>>())
    {}

    Future<T> get_future();

    template<class U>
    void set_result(U&& value) {
        if (_state->settled) {
            throw std::invalid_argument("already set result");
        }
        _state->set(std::forward<U>(value));
    }
private:
    std::shared_ptr<SharedState<T>> _state;
};

template<class T>
class Future {
    friend class Promise<T>;
private:
    Future(std::shared_ptr<SharedState<T>> state)
        : _state(std::move(state))
    {
    }
private:
    std::shared_ptr<SharedState<T>> _state;
};

template<class T>
Future<T> Promise<T>::get_future() {
    return Future<T>(_state);
}


先把調度器的殼寫上,以后會用到

class Schedular {
    template<class T>
    friend class SharedState;
public:
    Schedular() = default;
    Schedular(Schedular&&) = delete;
    Schedular(const Schedular&) = delete;
    Schedular& operator=(Schedular&&) = delete;
    Schedular& operator=(const Schedular&) = delete;

    void poll() {
        // TODO
    }
};



再補功能

如果你看到這行文字,說明這篇文章被無恥的盜用了(或者你正在選中文字),請前往 cnblogs.com/pointer-smq 支持原作者,謝謝

首先future要支持add_finish_callback


template<class T>
class Future {
    // ...
    // public
    void add_finish_callback(std::function<void(T&)> callback) {
        _state->add_finish_callback(std::move(callback));
    }
};

為什么要把callback實際加到_state里面去呢,因為之后需要post所有callback給schedular,而schedular要接受各種Future<T>的callback,要做類型擦除太麻煩了,所以索性,把callback存入_state,_state正好是堆對象,把他的類型擦除了,丟給schedular去post,還是因為簡單

既然callback實際加給了shared_state,那SharedState也得補充對應的功能


template<class T>
class SharedState {
    // ...
    // private
    void add_finish_callback(std::function<void(T&)> callback) {
        finish_callbacks.push_back(std::move(callback));
        // TODO
    }

    std::vector<std::function<void(T&)>> finish_callbacks;
};


然后,就是在shared_state set的時候,或者shared_state已有結果,但剛剛新增了callback的時候,把shared_state自己發送給Schedular,等待下一幀被調度到時調用callback

為此,SharedState本身要存儲一個Schedular的指針,那么Promise就得接受Schedular作為構造函數參數,SharedState還要記錄自己是否已經被post給Schedular,不需要重復post


template<class T>
class SharedState {
    // ...
    // public
    // 構造函數增加一個參數
    SharedState(Schedular& schedular)
        : schedular(&schedular)
    {}

    // ...
    // private
    // set增加內容
    template<class U>
    void set(U&& v) {
        if (settled) {
            return;
        }
        settled = true;
        value = std::forward<U>(v);
        post_all_callbacks();
    }

    void add_finish_callback(std::function<void(T&)> callback) {
        finish_callbacks.push_back(std::move(callback));
        post_all_callbacks();
    }

    void post_all_callbacks();

    bool settled = false;
    bool callback_posted = false;
    Schedular* schedular = nullptr;
    T value;
    std::vector<std::function<void(T&)>> finish_callbacks;
};

template<class T>
class Promise {
    // ...
	// public
	// 構造函數增加參數
    Promise(Schedular& schedular)
        : _schedular(&schedular)
        , _state(std::make_shared<SharedState<T>>(*_schedular))
    {}
};

// ...
// 在Schedular定義的后面
template<class T>
void SharedState<T>::post_all_callbacks() {
    if (callback_posted) {
        return;
    }
    callback_posted = true;
    schedular->post_call_state(shared_from_this());
}

可以發現,在post_all_callback時,SharedState把自己shared_from_this()后發送給了schedular,顯然這里既要enable_shared_from_this,又要類型擦除,於是

class SharedStateBase : public std::enable_shared_from_this<SharedStateBase> {
    friend class Schedular;
public:
    virtual ~SharedStateBase() = default;
};

template<class T>
class SharedState : public SharedStateBase {
    // ...
};


Schedular里面也得補充存儲shared_ptr<SharedStateBase>的東西

class Schedular {
    // ...
    // private
    void post_call_state(std::shared_ptr<SharedStateBase> state) {
        pending_states.push_back(std::move(state));
    }

    std::vector<std::shared_ptr<SharedStateBase>> pending_states;
};


下面,就輪到Schedular的poll函數在每幀調用被post過來的SharedStateBase了

class Schedular {
    // ...
    // public
    void poll() {
        size_t sz = pending_states.size();
        for (size_t i = 0; i != sz; i++) {
            auto state = std::move(pending_states[i]);
            state->invoke_all_callback();
        }
        pending_states.erase(pending_states.begin(), pending_states.begin()+sz);
    }
};


  • 之所以這里使用下標循環,是因為迭代過程中還可能有callback繼續往pending_states里面新增元素
  • 之所以不用while !empty()而是預先獲取size,是為了避免調度是callback內無限post callback導致無限循環
  • 之所以調用前先move出來,是為了避免調用callback期間callback繼續往pending_states里面新增元素導致容器擴容,內容物失效

這里對state調用了invoke_all_callback,顯然這是一個虛函數,需要給SharedState補上


class SharedStateBase : public std::enable_shared_from_this<SharedStateBase> {
    // ...
    // private
    virtual void invoke_all_callback() = 0;
};

template<class T>
class SharedState : public SharedStateBase {
    // ...
    // private
    virtual void invoke_all_callback() override {
        callback_posted = false;
        size_t sz = finish_callbacks.size();
        for (size_t i = 0; i != sz; i++) {
            auto v = std::move(finish_callbacks[i]);
            v(value);
        }
        finish_callbacks.erase(finish_callbacks.begin(), finish_callbacks.begin()+sz);
    }
};


這里在invoke_all_callbacks里面,使用了和上面schedular poll里面類似的代碼結構,可以讓本幀callback里新增的callback post到下一幀調用

好了,全部功能都齊了,下面可以測試了,完整的代碼在文章最后貼出


測試

如果你看到這行文字,說明這篇文章被無恥的盜用了(或者你正在選中文字),請前往 cnblogs.com/pointer-smq 支持原作者,謝謝


int main() {
    Schedular schedular;
    Promise<int> promise(schedular);
    Future<int> future = promise.get_future();
    std::cout << "future get\n";
    promise.set_result(10);
    std::cout << "promise result set\n";
    future.add_finish_callback([](int v) {
        std::cout << "callback 1 got result " << v << "\n";
    });
    std::cout << "future callback add\n";
    std::cout << "tick 1\n";
    schedular.poll();
    std::cout << "tick 2\n";
    future.add_finish_callback([](int v) {
        std::cout << "callback 2 got result " << v << "\n";
    });
    std::cout << "future callback 2 add\n";
    schedular.poll();

    std::cout << "\n";

    Promise<double> promise2(schedular);
    promise2.set_result(12.34);
    std::cout << "promise result2 set\n";
    Future<double> future2 = promise2.get_future();
    std::cout << "future2 get\n";
    future2.add_finish_callback([&](double v) {
        std::cout << "future2 callback 1 got result" << v << "\n";
        future2.add_finish_callback([](double v) {
            std::cout << "future2 callback 2 got result" << v << "\n";
        });
        std::cout << "future2 callback 2 add inside callback\n";
    });
    std::cout << "future2 callback add\n";
    std::cout << "tick 3\n";
    schedular.poll();
    std::cout << "tick 4\n";
    schedular.poll();
}


輸出

future get
promise result set
future callback add
tick 1
callback 1 got result 10
tick 2
future callback 2 add
callback 2 got result 10

promise result2 set
future2 get
future2 callback add
tick 3
future2 callback 1 got result12.34
future2 callback 2 add inside callback
tick 4
future2 callback 2 got result12.34

怎么樣,是不是很簡單呢,趕緊自己回家造一個吧!


附錄

如果你看到這行文字,說明這篇文章被無恥的盜用了(或者你正在選中文字),請前往 cnblogs.com/pointer-smq 支持原作者,謝謝

完整代碼

#include <vector>
#include <memory>
#include <iostream>
#include <functional>

template<class T>
class Future;

template<class T>
class Promise;

class Schedular;

class SharedStateBase : public std::enable_shared_from_this<SharedStateBase> {
    friend class Schedular;
public:
    virtual ~SharedStateBase() = default;
private:
    virtual void invoke_all_callback() = 0;
};

template<class T>
class SharedState : public SharedStateBase {
    friend class Future<T>;
    friend class Promise<T>;
public:
    SharedState(Schedular& schedular)
        : schedular(&schedular)
    {}
    SharedState(const SharedState&) = delete;
    SharedState(SharedState&&) = delete;
    SharedState& operator=(const SharedState&) = delete;
    SharedState& operator=(SharedState&&) = delete;

private:
    template<class U>
    void set(U&& v) {
        if (settled) {
            return;
        }
        settled = true;
        value = std::forward<U>(v);
        post_all_callbacks();
    }

    T& get() { return value; }

    void add_finish_callback(std::function<void(T&)> callback) {
        finish_callbacks.push_back(std::move(callback));
        post_all_callbacks();
    }

    void post_all_callbacks();

    virtual void invoke_all_callback() override {
        callback_posted = false;
        size_t sz = finish_callbacks.size();
        for (size_t i = 0; i != sz; i++) {
            auto v = std::move(finish_callbacks[i]);
            v(value);
        }
        finish_callbacks.erase(finish_callbacks.begin(), finish_callbacks.begin()+sz);
    }

    bool has_owner = false;
    bool settled = false;
    bool callback_posted = false;
    Schedular* schedular = nullptr;
    T value;
    std::vector<std::function<void(T&)>> finish_callbacks;
};

template<class T>
class Promise {
public:
    Promise(Schedular& schedular)
        : _schedular(&schedular)
        , _state(std::make_shared<SharedState<T>>(*_schedular))
    {}

    Future<T> get_future();

    template<class U>
    void set_result(U&& value) {
        if (_state->settled) {
            throw std::invalid_argument("already set result");
        }
        _state->set(std::forward<U>(value));
    }
private:
    Schedular* _schedular;
    std::shared_ptr<SharedState<T>> _state;
};

template<class T>
class Future {
    friend class Promise<T>;
private:
    Future(std::shared_ptr<SharedState<T>> state)
        : _state(std::move(state))
    {
    }
public:

    void add_finish_callback(std::function<void(T&)> callback) {
        _state->add_finish_callback(std::move(callback));
    }
private:
    std::shared_ptr<SharedState<T>> _state;
};

template<class T>
Future<T> Promise<T>::get_future() {
    return Future<T>(_state);
}

class Schedular {
    template<class T>
    friend class SharedState;
public:
    Schedular() = default;
    Schedular(Schedular&&) = delete;
    Schedular(const Schedular&) = delete;
    Schedular& operator=(Schedular&&) = delete;
    Schedular& operator=(const Schedular&) = delete;

    void poll() {
        size_t sz = pending_states.size();
        for (size_t i = 0; i != sz; i++) {
            auto state = std::move(pending_states[i]);
            state->invoke_all_callback();
        }
        pending_states.erase(pending_states.begin(), pending_states.begin()+sz);
    }
private:
    void post_call_state(std::shared_ptr<SharedStateBase> state) {
        pending_states.push_back(std::move(state));
    }

    std::vector<std::shared_ptr<SharedStateBase>> pending_states;
};

template<class T>
void SharedState<T>::post_all_callbacks() {
    if (callback_posted) {
        return;
    }
    callback_posted = true;
    schedular->post_call_state(shared_from_this());
}

int main() {
    Schedular schedular;
    Promise<int> promise(schedular);
    Future<int> future = promise.get_future();
    std::cout << "future get\n";
    promise.set_result(10);
    std::cout << "promise result set\n";
    future.add_finish_callback([](int v) {
        std::cout << "callback 1 got result " << v << "\n";
    });
    std::cout << "future callback add\n";
    std::cout << "tick 1\n";
    schedular.poll();
    std::cout << "tick 2\n";
    future.add_finish_callback([](int v) {
        std::cout << "callback 2 got result " << v << "\n";
    });
    std::cout << "future callback 2 add\n";
    schedular.poll();

    std::cout << "\n";

    Promise<double> promise2(schedular);
    promise2.set_result(12.34);
    std::cout << "promise result2 set\n";
    Future<double> future2 = promise2.get_future();
    std::cout << "future2 get\n";
    future2.add_finish_callback([&](double v) {
        std::cout << "future2 callback 1 got result" << v << "\n";
        future2.add_finish_callback([](double v) {
            std::cout << "future2 callback 2 got result" << v << "\n";
        });
        std::cout << "future2 callback 2 add inside callback\n";
    });
    std::cout << "future2 callback add\n";
    std::cout << "tick 3\n";
    schedular.poll();
    std::cout << "tick 4\n";
    schedular.poll();
}


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM