代码 
#include <algorithm>
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <queue>
#include <set>
#include <unordered_set>
#include <vector>

namespace jc {

    // A single vertex of the graph: its key, its payload, and non-owning
    // pointers to the vertices on the other end of its incoming and outgoing
    // edges.
    template <typename K, typename V>
    struct DAGNode {
        // Set of non-owning links to neighbouring vertices.
        using Links = std::set<DAGNode*>;

        K k;        // key identifying this vertex
        V v;        // user payload stored for the key
        Links in;   // vertices with an edge pointing at this vertex
        Links out;  // vertices this vertex points at
    };

    // Keyed directed acyclic graph. Vertices are created through operator[],
    // edges through AddEdge (which refuses edges that would close a cycle).
    // The graph can be walked component by component in topological order, or
    // consumed incrementally through the NextKeys() overloads.
    template <typename K, typename V>
    class DAGGraph {
    private:
        // Callback type accepted by the Walk* functions.
        using Visitor = std::function<void(const K&, const V&)>;

    public:
        // --- construction / mutation ---

        // Returns the payload stored under `key`, creating the vertex first
        // when it does not exist yet.
        V& operator[](const K& key);

        // Adds the edge from -> to; returns false when it is rejected.
        bool AddEdge(const K& from, const K& to);

        // Removes every vertex and edge and re-enables modification.
        void Clear();

        // --- queries ---

        bool Exist(const K& key) const;

        std::size_t Size() const;

        // --- traversal ---

        // Visits every vertex, one connected component after another, in
        // topological order from the heads (default) or from the tails.
        void Walk(Visitor f, bool start_from_head = true);

        // Visits only vertices without incoming edges.
        void WalkHeads(Visitor f);

        // Visits only vertices without outgoing edges.
        void WalkTails(Visitor f);

        // --- incremental consumption ---

        // Freezes the graph and returns the initial set of head keys.
        std::unordered_set<K> NextKeys();

        // Marks `key` as done and returns the keys that became available.
        std::unordered_set<K> NextKeys(const K& key);

    private:
        // --- helpers ---

        bool IsCyclic(const DAGNode<K, V>& from, const DAGNode<K, V>& to) const;

        void RefreshWalkSequences();

        std::vector<std::set<K>> ConnectedComponents() const;

        void DFS(const K& k, std::unordered_set<K>* visited,
            std::set<K>* connected_components) const;

        std::vector<K> TopologicalSequence(const std::set<K>& connected_components,
            bool start_from_head) const;

        // --- graph storage and cached walk orders ---

        std::map<K, DAGNode<K, V>> bucket_;  // all vertices, by key
        std::unordered_set<K> heads_;        // keys without incoming edges
        std::unordered_set<K> tails_;        // keys without outgoing edges
        std::vector<std::vector<K>> sequences_start_from_head_;
        std::vector<std::vector<K>> sequences_start_from_tail_;

        // --- state used by the NextKeys() overloads ---

        bool allow_modify_ = true;  // false once NextKeys() has been called
        std::vector<std::vector<K>> sequences_start_from_head_for_next_;
        std::unordered_set<K> current_heads_for_next_;
    };

    // Connects `from` to `to`. The edge is rejected (returns false) when both
    // keys are equal, when either key is unknown, or when it would introduce a
    // cycle. On success the cached walk sequences are invalidated.
    template <typename K, typename V>
    inline bool DAGGraph<K, V>::AddEdge(const K& from, const K& to) {
        assert(allow_modify_);
        if (from == to) {
            return false;
        }

        const auto source = bucket_.find(from);
        if (source == bucket_.end()) {
            return false;
        }
        const auto target = bucket_.find(to);
        if (target == bucket_.end()) {
            return false;
        }
        if (IsCyclic(source->second, target->second)) {
            return false;
        }

        source->second.out.emplace(&target->second);
        target->second.in.emplace(&source->second);

        // `to` now has a predecessor and `from` a successor.
        heads_.erase(to);
        tails_.erase(from);

        sequences_start_from_head_.clear();
        sequences_start_from_tail_.clear();
        return true;
    }

    // Gives access to the payload stored under `key`. An unknown key is
    // inserted with a default-constructed payload; the new vertex starts out as
    // both a head and a tail, and the cached walk sequences are invalidated.
    template <typename K, typename V>
    inline V& DAGGraph<K, V>::operator[](const K& key) {
        const auto existing = bucket_.find(key);
        if (existing != bucket_.end()) {
            return existing->second.v;
        }

        assert(allow_modify_);
        DAGNode<K, V>& created = bucket_[key];
        created.k = key;
        heads_.emplace(key);
        tails_.emplace(key);
        sequences_start_from_head_.clear();
        sequences_start_from_tail_.clear();
        return created.v;
    }

    // Reports whether a vertex with the given key has been created.
    template <typename K, typename V>
    inline bool DAGGraph<K, V>::Exist(const K& key) const {
        return bucket_.find(key) != bucket_.end();
    }

    // Resets the graph to its freshly constructed state: all vertices, edges
    // and cached walk sequences are dropped and modification is allowed again.
    template <typename K, typename V>
    inline void DAGGraph<K, V>::Clear() {
        allow_modify_ = true;
        bucket_.clear();
        heads_.clear();
        tails_.clear();
        sequences_start_from_head_.clear();
        sequences_start_from_tail_.clear();
        // Also drop the NextKeys() bookkeeping so that no keys of the previous
        // graph survive the reset.
        sequences_start_from_head_for_next_.clear();
        current_heads_for_next_.clear();
    }

    // Returns the number of vertices currently stored in the graph.
    template <typename K, typename V>
    inline std::size_t DAGGraph<K, V>::Size() const {
        return bucket_.size();
    }

    // Calls `f(key, value)` for every vertex. Connected components are visited
    // one after another; inside a component the vertices come in topological
    // order, starting from the heads when `start_from_head` is true and from
    // the tails otherwise. The cached sequences are rebuilt on demand.
    template <typename K, typename V>
    inline void DAGGraph<K, V>::Walk(std::function<void(const K& k, const V& v)> f,
        bool start_from_head) {
        if (sequences_start_from_head_.empty()) {
            RefreshWalkSequences();
        }

        const auto& ordering = start_from_head ? sequences_start_from_head_
                                               : sequences_start_from_tail_;
        for (const auto& component : ordering) {
            for (const K& key : component) {
                const DAGNode<K, V>& vertex = bucket_.at(key);
                f(vertex.k, vertex.v);
            }
        }
    }

    // Calls `f(key, value)` for every head (vertex without incoming edges),
    // following the order of the head-first walk sequences.
    template <typename K, typename V>
    inline void DAGGraph<K, V>::WalkHeads(
        std::function<void(const K& k, const V& v)> f) {
        if (sequences_start_from_head_.empty()) {
            RefreshWalkSequences();
        }

        for (const auto& component : sequences_start_from_head_) {
            for (const K& key : component) {
                if (heads_.count(key) == 0) {
                    continue;  // not a head
                }
                const DAGNode<K, V>& vertex = bucket_.at(key);
                f(vertex.k, vertex.v);
            }
        }
    }

    // Calls `f(key, value)` for every tail (vertex without outgoing edges),
    // following the order of the tail-first walk sequences.
    template <typename K, typename V>
    inline void DAGGraph<K, V>::WalkTails(
        std::function<void(const K& k, const V& v)> f) {
        if (sequences_start_from_head_.empty()) {
            RefreshWalkSequences();
        }

        for (const auto& component : sequences_start_from_tail_) {
            for (const K& key : component) {
                if (tails_.count(key) == 0) {
                    continue;  // not a tail
                }
                const DAGNode<K, V>& vertex = bucket_.at(key);
                f(vertex.k, vertex.v);
            }
        }
    }

    // Starts incremental consumption of the graph: forbids further
    // modification, makes sure the walk sequences exist, remembers the current
    // heads as the set of available keys and returns a copy of the heads.
    template <typename K, typename V>
    inline std::unordered_set<K> DAGGraph<K, V>::NextKeys() {
        assert(allow_modify_);  // may be called only once until Clear()
        if (sequences_start_from_head_.empty()) {
            RefreshWalkSequences();
        }
        current_heads_for_next_ = heads_;
        allow_modify_ = false;
        return heads_;
    }

    // Marks `key` as finished and returns the successors of `key` that became
    // available, i.e. those whose predecessors have all been finished already.
    // `key` must currently be available (returned by an earlier NextKeys call
    // and not yet finished).
    template <typename K, typename V>
    inline std::unordered_set<K> DAGGraph<K, V>::NextKeys(const K& key) {
        assert(!allow_modify_);  // NextKeys() has to be called first
        assert(current_heads_for_next_.count(key));
        current_heads_for_next_.erase(key);

        std::unordered_set<K> unlocked;
        for (std::vector<K>& remaining : sequences_start_from_head_for_next_) {
            const auto position = std::find(remaining.begin(), remaining.end(), key);
            if (position == remaining.end()) {
                continue;  // key belongs to another component
            }
            remaining.erase(position);

            // A vertex is still pending while its key is left in the sequence.
            const auto is_pending = [&remaining](const DAGNode<K, V>* vertex) {
                return std::find(remaining.begin(), remaining.end(), vertex->k) !=
                    remaining.end();
            };

            for (DAGNode<K, V>* successor : bucket_.at(key).out) {
                const std::set<DAGNode<K, V>*>& predecessors = successor->in;
                if (std::none_of(predecessors.begin(), predecessors.end(),
                        is_pending)) {
                    current_heads_for_next_.emplace(successor->k);
                    unlocked.emplace(successor->k);
                }
            }
            break;  // only the first sequence holding the key is processed
        }
        return unlocked;
    }

    // Tells whether adding the edge from -> to would close a cycle. That is the
    // case exactly when `to` is reachable from `from` by following incoming
    // edges backwards, which is checked with a breadth-first search.
    template <typename K, typename V>
    inline bool DAGGraph<K, V>::IsCyclic(const DAGNode<K, V>& from,
        const DAGNode<K, V>& to) const {
        std::unordered_set<DAGNode<K, V>*> seen;
        std::queue<DAGNode<K, V>*> frontier;
        for (DAGNode<K, V>* predecessor : from.in) {
            frontier.emplace(predecessor);
        }

        // The front element is popped by the loop increment; new entries are
        // appended at the back, so the front is still `current` at that point.
        for (; !frontier.empty(); frontier.pop()) {
            DAGNode<K, V>* current = frontier.front();
            if (seen.count(current) != 0) {
                continue;
            }
            if (current == &to) {
                return true;
            }
            seen.emplace(current);
            for (DAGNode<K, V>* predecessor : current->in) {
                frontier.emplace(predecessor);
            }
        }
        return false;
    }

    // Rebuilds the cached walk orders: for every connected component one
    // topological sequence starting at the heads and one starting at the
    // tails. The head-first sequences are also copied into the working state
    // used by NextKeys().
    template <typename K, typename V>
    inline void DAGGraph<K, V>::RefreshWalkSequences() {
        sequences_start_from_head_.clear();
        sequences_start_from_tail_.clear();

        for (const std::set<K>& component : ConnectedComponents()) {
            sequences_start_from_head_.emplace_back(
                TopologicalSequence(component, true));
            sequences_start_from_tail_.emplace_back(
                TopologicalSequence(component, false));
            assert(!sequences_start_from_head_.back().empty());
            assert(!sequences_start_from_tail_.back().empty());
        }

        sequences_start_from_head_for_next_ = sequences_start_from_head_;
    }

    // Splits the graph into its connected components (edge direction is
    // ignored) and returns them ordered by ascending size. Components of equal
    // size keep the order in which they were discovered, i.e. ascending order
    // of their smallest key, so the result is deterministic.
    template <typename K, typename V>
    inline std::vector<std::set<K>> DAGGraph<K, V>::ConnectedComponents() const {
        std::vector<std::set<K>> res;
        std::unordered_set<K> visited;
        for (const auto& entry : bucket_) {
            if (visited.count(entry.first) != 0) {
                continue;  // already part of an earlier component
            }
            // Build the component in place instead of copying a temporary.
            res.emplace_back();
            DFS(entry.first, &visited, &res.back());
        }
        // stable_sort: std::sort would leave the relative order of equally
        // sized components unspecified.
        std::stable_sort(std::begin(res), std::end(res),
            [](const std::set<K>& lhs, const std::set<K>& rhs) {
                return lhs.size() < rhs.size();
            });
        return res;
    }

    // Collects into `connected_components` every not yet visited vertex that
    // can be reached from `k` over edges in either direction, marking each of
    // them in `visited`. Does nothing when `k` has been visited already.
    //
    // The traversal uses an explicit work stack rather than recursion, so its
    // depth is not limited by the call stack on long chains of vertices.
    template <typename K, typename V>
    inline void DAGGraph<K, V>::DFS(const K& k, std::unordered_set<K>* visited,
        std::set<K>* connected_components) const {
        if (!visited->emplace(k).second) {
            return;
        }
        connected_components->emplace(k);

        std::vector<const DAGNode<K, V>*> pending;
        pending.push_back(&bucket_.at(k));

        // Marks all unvisited neighbours and schedules them for expansion.
        const auto enqueue_unvisited =
            [&](const std::set<DAGNode<K, V>*>& neighbours) {
                for (const DAGNode<K, V>* next : neighbours) {
                    if (visited->emplace(next->k).second) {
                        connected_components->emplace(next->k);
                        pending.push_back(next);
                    }
                }
            };

        while (!pending.empty()) {
            const DAGNode<K, V>* node = pending.back();
            pending.pop_back();
            enqueue_unvisited(node->in);
            enqueue_unvisited(node->out);
        }
    }

    // Returns the keys of one connected component in topological order, using
    // the in-degree counting (Kahn) algorithm. With `start_from_head` the
    // outgoing edges are followed, otherwise the edges are followed backwards
    // so that the sequence starts at the tails.
    //
    // Neighbours are processed in ascending key order. The node sets
    // themselves are ordered by pointer value, so iterating them directly
    // would make the result depend on where the vertices live in memory.
    template <typename K, typename V>
    inline std::vector<K> DAGGraph<K, V>::TopologicalSequence(
        const std::set<K>& connected_components, bool start_from_head) const {
        std::map<K, std::vector<K>> adjacency_list;
        std::map<K, std::size_t> in_degree;

        for (const K& key : connected_components) {
            in_degree.emplace(key, std::size_t{0});  // no-op when already counted
            const DAGNode<K, V>& node = bucket_.at(key);
            const std::set<DAGNode<K, V>*>& nodes =
                start_from_head ? node.out : node.in;
            std::vector<K>& successors = adjacency_list[key];
            successors.reserve(nodes.size());
            for (const DAGNode<K, V>* v : nodes) {
                successors.emplace_back(v->k);
                ++in_degree[v->k];
            }
            std::sort(std::begin(successors), std::end(successors),
                std::less<K>());
        }

        // Seed the queue with every vertex that has nothing in front of it.
        std::queue<K> q;
        for (const auto& entry : in_degree) {
            if (entry.second == 0) {
                q.emplace(entry.first);
            }
        }

        std::vector<K> res;
        res.reserve(connected_components.size());
        while (!q.empty()) {
            const K key = q.front();
            q.pop();
            res.emplace_back(key);
            for (const K& next : adjacency_list[key]) {
                if (--in_degree.at(next) == 0) {
                    q.emplace(next);
                }
            }
        }

        assert(res.size() == connected_components.size());  // graph is DAG
        return res;
    }

}  // namespace jc

namespace jc::test {

    // Stand-in payload type for the test: every lifecycle hook is a no-op.
    class MockPipelineEngine {
    public:
        void Start() { /* nothing to start */ }

        void Stop() { /* nothing to stop */ }

        void Destroy() { /* nothing to release */ }
    };

    void test() {
        DAGGraph<int, std::unique_ptr<MockPipelineEngine>> d;
        // Make Direct Acyclic Graph:
        //    0    6      11  13
        //   / \   |      |
        //  1   3  7  8   12
        //  | x |      \ /
        //  2   4       9
        //   \ /        |
        //    5         10
        // Traverse each child graph in order whose size smaller

        // Start Order:
        // 13
        // 6 -> 7
        // 8 -> 11 -> 12 -> 9 -> 10
        // 0 -> 1 -> 3 -> 2 -> 4 -> 5
        // Stop Order:
        // 13
        // 7 -> 6
        // 10 -> 9 -> 8 -> 12 -> 11
        // 5 -> 2 -> 4 -> 1 -> 3 -> 0

        constexpr int nodes_count = 14;
        for (int i = 0; i < nodes_count; ++i) {
            d[i].reset(new MockPipelineEngine);
        }
        assert(d.AddEdge(0, 1));
        assert(d.AddEdge(0, 3));
        assert(d.AddEdge(1, 2));
        assert(d.AddEdge(3, 4));
        assert(d.AddEdge(1, 4));
        assert(d.AddEdge(3, 2));
        assert(d.AddEdge(2, 5));
        assert(d.AddEdge(4, 5));
        assert(d.AddEdge(6, 7));
        assert(d.AddEdge(8, 9));
        assert(d.AddEdge(9, 10));
        assert(d.AddEdge(11, 12));
        assert(d.AddEdge(12, 9));

        assert(d.Size() == nodes_count);

        for (int i = 0; i < nodes_count; ++i) {
            assert(d.Exist(i));
        }

        assert(!d.AddEdge(1, 0));
        assert(!d.AddEdge(2, 0));
        assert(!d.AddEdge(4, 0));
        assert(!d.AddEdge(7, 6));
        assert(!d.AddEdge(10, 11));
        assert(!d.AddEdge(13, 13));
        assert(!d.AddEdge(13, 14));

        constexpr bool start_from_head = true;
        {
            std::vector<int> v;
            std::vector<int> start_order{ 13, 6, 7, 8, 11, 12, 9, 10, 0, 1, 3, 2, 4, 5 };
            std::vector<int> start_order2{ 13, 6, 7, 8, 11, 12, 9, 10, 0, 3, 1, 4, 2, 5 };
            d.Walk(
                [&](int key, const std::unique_ptr<MockPipelineEngine>& pipeline) {
                    pipeline->Start();
                    v.emplace_back(key);
                },
                start_from_head);
            assert(v == start_order || v == start_order2);
        }

        {
            std::vector<int> v;
            std::vector<int> stop_order{ 13, 7, 6, 10, 9, 8, 12, 11, 5, 2, 4, 1, 3, 0 };
            std::vector<int> stop_order2{ 13, 7, 6, 10, 9, 12, 8, 11, 5, 2, 4, 1, 3, 0 };
            d.Walk(
                [&](int key, const std::unique_ptr<MockPipelineEngine>& pipeline) {
                    pipeline->Stop();
                    v.emplace_back(key);
                },
                !start_from_head);
            assert(v == stop_order || v == stop_order2
            );
        }

        {
            std::vector<int> v;
            std::vector<int> heads_order{ 13, 6, 8, 11, 0 };
            d.WalkHeads(
                [&](int key, const std::unique_ptr<MockPipelineEngine>& pipeline) {
                    pipeline->Destroy();
                    v.emplace_back(key);
                });
            assert(v == heads_order);
        }

        {
            std::vector<int> v;
            std::vector<int> tails_order{ 13, 7, 10, 5 };
            d.WalkTails(
                [&](int key, const std::unique_ptr<MockPipelineEngine>& pipeline) {
                    pipeline->Destroy();
                    v.emplace_back(key);
                });
            assert(v == tails_order);
        }

        {
            std::vector<int> test_sequence{ 13, 6, 7, 0,  1,  3, 4,
                                           2,  5, 8, 11, 12, 9, 10 };

            std::unordered_set<int> heads{ 0, 6, 8, 11, 13 };
            assert(d.NextKeys() == heads);

            std::vector<std::unordered_set<int>> next_keys{
                {}, {7}, {}, {1, 3}, {}, {2, 4}, {}, {5}, {}, {}, {12}, {9}, {10}, {},
            };

            assert(test_sequence.size() == nodes_count);
            assert(next_keys.size() == nodes_count);
            for (int i = 0; i < nodes_count; ++i) {
                assert(d.NextKeys(test_sequence[i]) == next_keys[i]);
            }
        }

        d.Clear();
        assert(d.Size() == 0);
        for (int i = 0; i < nodes_count; ++i) {
            assert(!d.Exist(i));
        }
    }

}  // namespace jc::test

//int main() { jc::test::test(); }
效果

 

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐