From f025198f195516da8654bece64bf7a1fb85be3cd Mon Sep 17 00:00:00 2001 From: Lee Clagett Date: Mon, 21 Nov 2016 14:48:42 -0500 Subject: [PATCH] Added task_region - a fork/join task implementation --- src/common/CMakeLists.txt | 2 + src/common/task_region.cpp | 94 +++++++++++++ src/common/task_region.h | 223 ++++++++++++++++++++++++++++++ src/common/thread_group.cpp | 92 ++++++------ src/common/thread_group.h | 66 +++++---- src/ringct/rctSigs.cpp | 105 ++++++-------- tests/unit_tests/CMakeLists.txt | 1 + tests/unit_tests/thread_group.cpp | 177 ++++++++++++++++++++++++ 8 files changed, 620 insertions(+), 140 deletions(-) create mode 100644 src/common/task_region.cpp create mode 100644 src/common/task_region.h create mode 100644 tests/unit_tests/thread_group.cpp diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt index d5d22bca6..dd17f6d64 100644 --- a/src/common/CMakeLists.txt +++ b/src/common/CMakeLists.txt @@ -33,6 +33,7 @@ set(common_sources util.cpp i18n.cpp perf_timer.cpp + task_region.cpp thread_group.cpp) if (STACK_TRACE) @@ -57,6 +58,7 @@ set(common_private_headers i18n.h perf_timer.h stack_trace.h + task_region.h thread_group.h) monero_private_headers(common diff --git a/src/common/task_region.cpp b/src/common/task_region.cpp new file mode 100644 index 000000000..b53a8376a --- /dev/null +++ b/src/common/task_region.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2014-2016, The Monero Project +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without modification, are +// permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of +// conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list +// of conditions and the following disclaimer in the documentation and/or other +// materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be +// used to endorse or promote products derived from this software without specific +// prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL +// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF +// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#include "common/task_region.h" + +#include +#include + +/* `mark_completed` and `wait` can throw in the lock call, but its difficult to +recover from either. An exception in `wait` means the post condition of joining +all threads cannot be achieved, and an exception in `mark_completed` means +certain deadlock. `noexcept` qualifier will force a call to `std::terminate` if +locking throws an exception, which should only happen if a recursive lock +attempt is made (which is not possible since no external function is called +while holding the lock). */ + +namespace tools +{ +void task_region_handle::state::mark_completed(id task_id) noexcept { + assert(task_id != 0 && (task_id & (task_id - 1)) == 0); // power of 2 check + if (pending.fetch_and(~task_id) == task_id) { + // synchronize with wait call, but do not need to hold + boost::unique_lock{sync_on_complete}; + all_complete.notify_all(); + } +} + +void task_region_handle::state::abort() noexcept { + state* current = this; + while (current) { + current->ready = 0; + current = current->next.get(); + } +} + +void task_region_handle::state::wait() noexcept { + state* current = this; + while (current) { + { + boost::unique_lock lock{current->sync_on_complete}; + current->all_complete.wait(lock, [current] { return current->pending == 0; }); + } + current = current->next.get(); + } +} + +void task_region_handle::state::wait(thread_group& threads) noexcept { + state* current = this; + while (current) { + while (current->pending != 0) { + if (!threads.try_run_one()) { + current->wait(); + return; + } + } + current = current->next.get(); + } +} + +void task_region_handle::create_state() { + st = std::make_shared(std::move(st)); + next_id = 1; +} + +void task_region_handle::do_wait() noexcept { + assert(st); + const std::shared_ptr temp = std::move(st); + temp->wait(threads); +} +} diff --git a/src/common/task_region.h b/src/common/task_region.h new file mode 100644 index 000000000..e4d210661 --- /dev/null +++ b/src/common/task_region.h @@ -0,0 +1,223 @@ +// Copyright (c) 2014-2016, The Monero Project +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without modification, are +// permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of +// conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list +// of conditions and the following disclaimer in the documentation and/or other +// materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be +// used to endorse or promote products derived from this software without specific +// prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL +// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF +// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "common/thread_group.h" + +namespace tools +{ + +/*! A model of the fork-join concept. `run(...)` "forks" (i.e. spawns new +tasks), and `~task_region_handle()` or `wait()` "joins" the spawned tasks. +`wait` will block until all tasks have completed, while `~task_region_handle()` +blocks until all tasks have completed or aborted. + +Do _NOT_ give this object to separate thread of execution (which includes +`task_region_handle::run(...)`) because joining on a different thread is +undesireable (potential deadlock). + +This class cannot be constructed directly, use the function +`task_region(...)` instead. +*/ +class task_region_handle +{ + struct state + { + using id = unsigned; + + explicit state(std::shared_ptr next_src) noexcept + : next(std::move(next_src)) + , ready(0) + , pending(0) + , sync_on_complete() + , all_complete() { + } + + state(const state&) = default; + state(state&&) = default; + ~state() = default; + state& operator=(const state&) = default; + state& operator=(state&&) = default; + + void track_id(id task_id) noexcept { + pending |= task_id; + ready |= task_id; + } + + //! \return True only once whether a given id can execute + bool can_run(id task_id) noexcept { + return (ready.fetch_and(~task_id) & task_id); + } + + //! Mark id as completed, and synchronize with waiting threads + void mark_completed(id task_id) noexcept; + + //! Tell all unstarted functions in region to return immediately + void abort() noexcept; + + //! Blocks until all functions in region have aborted or completed. + void wait() noexcept; + + //! Same as `wait()`, except `this_thread` runs tasks while waiting. + void wait(thread_group& threads) noexcept; + + private: + /* This implementation is a bit pessimistic, it ensures that all copies + of a wrapped task can only be executed once. `thread_group` should never + do this, but some variable needs to track whether an abort should be done + anyway... */ + std::shared_ptr next; + std::atomic ready; //!< Tracks whether a task has been invoked + std::atomic pending; //!< Tracks when a task has completed or aborted + boost::mutex sync_on_complete; + boost::condition_variable all_complete; + }; + + template + struct wrapper + { + wrapper(state::id id_src, std::shared_ptr st_src, F f_src) + : task_id(id_src), st(std::move(st_src)), f(std::move(f_src)) { + } + + wrapper(const wrapper&) = default; + wrapper(wrapper&&) = default; + wrapper& operator=(const wrapper&) = default; + wrapper& operator=(wrapper&&) = default; + + void operator()() { + if (st) { + if (st->can_run(task_id)) { + f(); + } + st->mark_completed(task_id); + } + } + + private: + const state::id task_id; + std::shared_ptr st; + F f; + }; + +public: + friend struct task_region_; + + task_region_handle() = delete; + task_region_handle(const task_region_handle&) = delete; + task_region_handle(task_region_handle&&) = delete; + + //! Cancels unstarted pending tasks, and waits for them to respond. + ~task_region_handle() noexcept { + if (st) { + st->abort(); + st->wait(threads); + } + } + + task_region_handle& operator=(const task_region_handle&) = delete; + task_region_handle& operator=(task_region_handle&&) = delete; + + /*! If the group has no threads, `f` is immediately run before returning. + Otherwise, `f` is dispatched to the thread_group associated with `this` + region. If `f` is dispatched to another thread, and it throws, the process + will immediately terminate. See std::packaged_task for getting exceptions on + functions executed on other threads. */ + template + void run(F&& f) { + if (threads.count() == 0) { + f(); + } else { + if (!st || next_id == 0) { + create_state(); + } + const state::id this_id = next_id; + next_id <<= 1; + + st->track_id(this_id); + threads.dispatch(wrapper{this_id, st, std::move(f)}); + } + } + + //! Wait until all functions provided to `run` have completed. + void wait() noexcept { + if (st) { + do_wait(); + } + } + +private: + explicit task_region_handle(thread_group& threads_src) + : st(nullptr), threads(threads_src), next_id(0) { + } + + void create_state(); + void do_wait() noexcept; + + std::shared_ptr st; + thread_group& threads; + state::id next_id; +}; + +/*! Function for creating a `task_region_handle`, which automatically calls +`task_region_handle::wait()` before returning. If a `thread_group` is not +provided, one is created with an optimal number of threads. The callback `f` +must have the signature `void(task_region_handle&)`. */ +struct task_region_ { + template + void operator()(thread_group& threads, F&& f) const { + static_assert( + std::is_same::type>::value, + "f cannot have a return value" + ); + task_region_handle region{threads}; + f(region); + region.wait(); + } + + template + void operator()(thread_group&& threads, F&& f) const { + (*this)(threads, std::forward(f)); + } + + template + void operator()(F&& f) const { + thread_group threads; + (*this)(threads, std::forward(f)); + } +}; + +constexpr const task_region_ task_region{}; +} diff --git a/src/common/thread_group.cpp b/src/common/thread_group.cpp index aa1b64f2e..4e1cc8964 100644 --- a/src/common/thread_group.cpp +++ b/src/common/thread_group.cpp @@ -27,6 +27,7 @@ // THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "common/thread_group.h" +#include #include #include #include @@ -35,14 +36,20 @@ namespace tools { -thread_group::thread_group(std::size_t count) : internal() { +std::size_t thread_group::optimal() { static_assert( std::numeric_limits::max() <= std::numeric_limits::max(), "unexpected truncation" ); - count = std::min(count, get_max_concurrency()); - count = count ? count - 1 : 0; + const std::size_t hardware = get_max_concurrency(); + return hardware ? (hardware - 1) : 0; +} +std::size_t thread_group::optimal_with_max(std::size_t count) { + return count ? std::min(count - 1, optimal()) : 0; +} + +thread_group::thread_group(std::size_t count) : internal() { if (count) { internal.emplace(count); } @@ -52,24 +59,21 @@ thread_group::data::data(std::size_t count) : threads() , head{nullptr} , last(std::addressof(head)) - , pending(count) , mutex() , has_work() - , finished_work() , stop(false) { threads.reserve(count); while (count--) { - threads.push_back(std::thread(&thread_group::data::run, this)); + threads.push_back(boost::thread(&thread_group::data::run, this)); } } thread_group::data::~data() noexcept { { - const std::unique_lock lock(mutex); + const boost::unique_lock lock(mutex); stop = true; } has_work.notify_all(); - finished_work.notify_all(); for (auto& worker : threads) { try { worker.join(); @@ -78,42 +82,6 @@ thread_group::data::~data() noexcept { } } - -void thread_group::data::sync() noexcept { - /* This function and `run()` can both throw when acquiring the lock, or in - the dispatched function. It is tough to recover from either, particularly the - lock case. These functions are marked as noexcept so that if either call - throws, the entire process is terminated. Users of the `dispatch` call are - expected to make their functions noexcept, or use std::packaged_task to copy - exceptions so that the process will continue in all but the most pessimistic - cases (std::bad_alloc). This was the existing behavior; - `asio::io_service::run` propogates errors from dispatched calls, and uncaught - exceptions on threads result in process termination. */ - assert(!threads.empty()); - bool not_first = false; - while (true) { - std::unique_ptr next = nullptr; - { - std::unique_lock lock(mutex); - pending -= std::size_t(not_first); - not_first = true; - finished_work.notify_all(); - - if (stop) { - return; - } - - next = get_next(); - if (next == nullptr) { - finished_work.wait(lock, [this] { return pending == 0 || stop; }); - return; - } - } - assert(next->f); - next->f(); - } -} - std::unique_ptr thread_group::data::get_next() noexcept { std::unique_ptr rc = std::move(head.ptr); if (rc != nullptr) { @@ -125,14 +93,35 @@ std::unique_ptr thread_group::data::get_next() noexcep return rc; } +bool thread_group::data::try_run_one() noexcept { + /* This function and `run()` can both throw when acquiring the lock, or in + dispatched function. It is tough to recover from either, particularly the + lock case. These functions are marked as noexcept so that if either call + throws, the entire process is terminated. Users of the `dispatch` call are + expected to make their functions noexcept, or use std::packaged_task to copy + exceptions so that the process will continue in all but the most pessimistic + cases (std::bad_alloc). This was the existing behavior; + `asio::io_service::run` propogates errors from dispatched calls, and uncaught + exceptions on threads result in process termination. */ + std::unique_ptr next = nullptr; + { + const boost::unique_lock lock(mutex); + next = get_next(); + } + if (next) { + assert(next->f); + next->f(); + return true; + } + return false; +} + void thread_group::data::run() noexcept { - // see `sync()` source for additional information + // see `try_run_one()` source for additional information while (true) { std::unique_ptr next = nullptr; { - std::unique_lock lock(mutex); - --pending; - finished_work.notify_all(); + boost::unique_lock lock(mutex); has_work.wait(lock, [this] { return head.ptr != nullptr || stop; }); if (stop) { return; @@ -149,15 +138,12 @@ void thread_group::data::dispatch(std::function f) { std::unique_ptr latest(new work{std::move(f), node{nullptr}}); node* const latest_node = std::addressof(latest->next); { - const std::unique_lock lock(mutex); + const boost::unique_lock lock(mutex); assert(last != nullptr); assert(last->ptr == nullptr); - if (pending == std::numeric_limits::max()) { - throw std::overflow_error("thread_group exceeded max queue depth"); - } + last->ptr = std::move(latest); last = latest_node; - ++pending; } has_work.notify_one(); } diff --git a/src/common/thread_group.h b/src/common/thread_group.h index d8461d49a..62e82d832 100644 --- a/src/common/thread_group.h +++ b/src/common/thread_group.h @@ -25,8 +25,12 @@ // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, // STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF // THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + #include -#include +#include +#include +#include #include #include #include @@ -35,11 +39,21 @@ namespace tools { -//! Manages zero or more threads for work dispatching +//! Manages zero or more threads for work dispatching. class thread_group { public: - //! Create `min(count, get_max_concurrency()) - 1` threads + + //! \return `get_max_concurrency() ? get_max_concurrency() - 1 : 0` + static std::size_t optimal(); + + //! \return `count ? min(count - 1, optimal()) : 0` + static std::size_t optimal_with_max(std::size_t count); + + //! Create an optimal number of threads. + explicit thread_group() : thread_group(optimal()) {} + + //! Create exactly `count` threads. explicit thread_group(std::size_t count); thread_group(thread_group const&) = delete; @@ -51,30 +65,26 @@ public: thread_group& operator=(thread_group const&) = delete; thread_group& operator=(thread_group&&) = delete; - /*! Blocks until all functions provided to `dispatch` complete. Does not - destroy threads. If a dispatched function calls `this->dispatch(...)`, - `this->sync()` will continue to block until that new function completes. */ - void sync() noexcept { + //! \return Number of threads owned by `this` group. + std::size_t count() const noexcept { if (internal) { - internal->sync(); + return internal->count(); } + return 0; } - /*! Example usage: - std::unique_ptr sync(std::addressof(group)); - which guarantees synchronization before the unique_ptr destructor returns. */ - struct lazy_sync { - void operator()(thread_group* group) const noexcept { - if (group != nullptr) { - group->sync(); - } + //! \return True iff a function was available and executed (on `this_thread`). + bool try_run_one() noexcept { + if (internal) { + return internal->try_run_one(); } - }; + return false; + } - /*! `f` is invoked immediately if the thread_group is empty, otherwise - execution of `f` is queued for next available thread. If `f` is queued, any - exception leaving that function will result in process termination. Use - std::packaged_task if exceptions need to be handled. */ + /*! `f` is invoked immediately if `count() == 0`, otherwise execution of `f` + is queued for next available thread. If `f` is queued, any exception leaving + that function will result in process termination. Use std::packaged_task if + exceptions need to be handled. */ template void dispatch(F&& f) { if (internal) { @@ -91,8 +101,11 @@ private: data(std::size_t count); ~data() noexcept; - void sync() noexcept; + std::size_t count() const noexcept { + return threads.size(); + } + bool try_run_one() noexcept; void dispatch(std::function f); private: @@ -116,13 +129,11 @@ private: void run() noexcept; private: - std::vector threads; + std::vector threads; node head; node* last; - std::size_t pending; - std::condition_variable has_work; - std::condition_variable finished_work; - std::mutex mutex; + boost::condition_variable has_work; + boost::mutex mutex; bool stop; }; @@ -130,4 +141,5 @@ private: // optionally construct elements, without separate heap allocation boost::optional internal; }; + } diff --git a/src/ringct/rctSigs.cpp b/src/ringct/rctSigs.cpp index df33c26b2..b773be1e5 100644 --- a/src/ringct/rctSigs.cpp +++ b/src/ringct/rctSigs.cpp @@ -30,6 +30,7 @@ #include "misc_log_ex.h" #include "common/perf_timer.h" +#include "common/task_region.h" #include "common/thread_group.h" #include "common/util.h" #include "rctSigs.h" @@ -39,22 +40,6 @@ using namespace crypto; using namespace std; namespace rct { - namespace { - struct verRangeWrapper_ { - void operator()(const key & C, const rangeSig & as, bool &result) const { - result = verRange(C, as); - } - }; - constexpr const verRangeWrapper_ verRangeWrapper{}; - - struct verRctMGSimpleWrapper_ { - void operator()(const key &message, const mgSig &mg, const ctkeyV & pubs, const key & C, bool &result) const { - result = verRctMGSimple(message, mg, pubs, C); - } - }; - constexpr const verRctMGSimpleWrapper_ verRctMGSimpleWrapper{}; - } - //Schnorr Non-linkable //Gen Gives a signature (L1, s1, s2) proving that the sender knows "x" such that xG = one of P1 or P2 //Ver Verifies that signer knows an "x" such that xG = one of P1 or P2 @@ -766,15 +751,17 @@ namespace rct { try { std::deque results(rv.outPk.size(), false); - tools::thread_group threadpool(rv.outPk.size()); // this must destruct before results + tools::thread_group threadpool(tools::thread_group::optimal_with_max(rv.outPk.size())); + + tools::task_region(threadpool, [&] (tools::task_region_handle& region) { + DP("range proofs verified?"); + for (size_t i = 0; i < rv.outPk.size(); i++) { + region.run([&, i] { + results[i] = verRange(rv.outPk[i].mask, rv.p.rangeSigs[i]); + }); + } + }); - DP("range proofs verified?"); - for (size_t i = 0; i < rv.outPk.size(); i++) { - threadpool.dispatch( - std::bind(verRangeWrapper, std::cref(rv.outPk[i].mask), std::cref(rv.p.rangeSigs[i]), std::ref(results[i])) - ); - } - threadpool.sync(); for (size_t i = 0; i < rv.outPk.size(); ++i) { if (!results[i]) { LOG_ERROR("Range proof verified failed for input " << i); @@ -804,7 +791,6 @@ namespace rct { //assumes only post-rct style inputs (at least for max anonymity) bool verRctSimple(const rctSig & rv) { PERF_TIMER(verRctSimple); - size_t i = 0; CHECK_AND_ASSERT_MES(rv.type == RCTTypeSimple, false, "verRctSimple called on non simple rctSig"); CHECK_AND_ASSERT_MES(rv.outPk.size() == rv.p.rangeSigs.size(), false, "Mismatched sizes of outPk and rv.p.rangeSigs"); @@ -813,28 +799,29 @@ namespace rct { CHECK_AND_ASSERT_MES(rv.pseudoOuts.size() == rv.mixRing.size(), false, "Mismatched sizes of rv.pseudoOuts and mixRing"); const size_t threads = std::max(rv.outPk.size(), rv.mixRing.size()); - tools::thread_group threadpool(threads); - { - std::deque results(rv.outPk.size(), false); - { - const std::unique_ptr - sync(std::addressof(threadpool)); - for (i = 0; i < rv.outPk.size(); i++) { - threadpool.dispatch( - std::bind(verRangeWrapper, std::cref(rv.outPk[i].mask), std::cref(rv.p.rangeSigs[i]), std::ref(results[i])) - ); - } - } // threadpool.sync(); - for (size_t i = 0; i < rv.outPk.size(); ++i) { - if (!results[i]) { - LOG_ERROR("Range proof verified failed for input " << i); - return false; - } + + std::deque results(threads); + tools::thread_group threadpool(tools::thread_group::optimal_with_max(threads)); + + results.clear(); + results.resize(rv.outPk.size()); + tools::task_region(threadpool, [&] (tools::task_region_handle& region) { + for (size_t i = 0; i < rv.outPk.size(); i++) { + region.run([&, i] { + results[i] = verRange(rv.outPk[i].mask, rv.p.rangeSigs[i]); + }); + } + }); + + for (size_t i = 0; i < results.size(); ++i) { + if (!results[i]) { + LOG_ERROR("Range proof verified failed for input " << i); + return false; } } key sumOutpks = identity(); - for (i = 0; i < rv.outPk.size(); i++) { + for (size_t i = 0; i < rv.outPk.size(); i++) { addKeys(sumOutpks, sumOutpks, rv.outPk[i].mask); } DP(sumOutpks); @@ -843,27 +830,25 @@ namespace rct { key message = get_pre_mlsag_hash(rv); - { - std::deque results(rv.mixRing.size(), false); - { - const std::unique_ptr - sync(std::addressof(threadpool)); - for (i = 0 ; i < rv.mixRing.size() ; i++) { - threadpool.dispatch( - std::bind(verRctMGSimpleWrapper, std::cref(message), std::cref(rv.p.MGs[i]), std::cref(rv.mixRing[i]), std::cref(rv.pseudoOuts[i]), std::ref(results[i])) - ); - } - } // threadpool.sync(); - for (size_t i = 0; i < results.size(); ++i) { - if (!results[i]) { - LOG_ERROR("verRctMGSimple failed for input " << i); - return false; - } + results.clear(); + results.resize(rv.mixRing.size()); + tools::task_region(threadpool, [&] (tools::task_region_handle& region) { + for (size_t i = 0 ; i < rv.mixRing.size() ; i++) { + region.run([&, i] { + results[i] = verRctMGSimple(message, rv.p.MGs[i], rv.mixRing[i], rv.pseudoOuts[i]); + }); + } + }); + + for (size_t i = 0; i < results.size(); ++i) { + if (!results[i]) { + LOG_ERROR("verRctMGSimple failed for input " << i); + return false; } } key sumPseudoOuts = identity(); - for (i = 0 ; i < rv.mixRing.size() ; i++) { + for (size_t i = 0 ; i < rv.mixRing.size() ; i++) { addKeys(sumPseudoOuts, sumPseudoOuts, rv.pseudoOuts[i]); } DP(sumPseudoOuts); diff --git a/tests/unit_tests/CMakeLists.txt b/tests/unit_tests/CMakeLists.txt index f3658b9ff..5f050554f 100644 --- a/tests/unit_tests/CMakeLists.txt +++ b/tests/unit_tests/CMakeLists.txt @@ -50,6 +50,7 @@ set(unit_tests_sources test_format_utils.cpp test_peerlist.cpp test_protocol_pack.cpp + thread_group.cpp hardfork.cpp unbound.cpp varint.cpp diff --git a/tests/unit_tests/thread_group.cpp b/tests/unit_tests/thread_group.cpp new file mode 100644 index 000000000..2e7a78353 --- /dev/null +++ b/tests/unit_tests/thread_group.cpp @@ -0,0 +1,177 @@ +// Copyright (c) 2014-2016, The Monero Project +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without modification, are +// permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of +// conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list +// of conditions and the following disclaimer in the documentation and/or other +// materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be +// used to endorse or promote products derived from this software without specific +// prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL +// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF +// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "gtest/gtest.h" + +#include +#include "common/task_region.h" +#include "common/thread_group.h" + +TEST(ThreadGroup, NoThreads) +{ + tools::task_region(tools::thread_group(0), [] (tools::task_region_handle& region) { + std::atomic completed{false}; + region.run([&] { completed = true; }); + EXPECT_TRUE(completed); + }); + { + tools::thread_group group(0); + std::atomic completed{false}; + group.dispatch([&] { completed = true; }); + EXPECT_TRUE(completed); + } +} + +TEST(ThreadGroup, OneThread) +{ + tools::thread_group group(1); + + for (unsigned i = 0; i < 3; ++i) { + std::atomic completed{false}; + tools::task_region(group, [&] (tools::task_region_handle& region) { + region.run([&] { completed = true; }); + }); + EXPECT_TRUE(completed); + } +} + + +TEST(ThreadGroup, UseActiveThreadOnSync) +{ + tools::thread_group group(1); + + for (unsigned i = 0; i < 3; ++i) { + std::atomic completed{false}; + tools::task_region(group, [&] (tools::task_region_handle& region) { + region.run([&] { while (!completed); }); + region.run([&] { completed = true; }); + }); + EXPECT_TRUE(completed); + } +} + +TEST(ThreadGroup, InOrder) +{ + tools::thread_group group(1); + + for (unsigned i = 0; i < 3; ++i) { + std::atomic count{0}; + std::atomic completed{false}; + tools::task_region(group, [&] (tools::task_region_handle& region) { + region.run([&] { while (!completed); }); + region.run([&] { if (count == 0) completed = true; }); + region.run([&] { ++count; }); + }); + EXPECT_TRUE(completed); + EXPECT_EQ(1u, count); + } +} + +TEST(ThreadGroup, TwoThreads) +{ + tools::thread_group group(2); + + for (unsigned i = 0; i < 3; ++i) { + std::atomic completed{false}; + tools::task_region(group, [&] (tools::task_region_handle& region) { + region.run([&] { while (!completed); }); + region.run([&] { while (!completed); }); + region.run([&] { completed = true; }); + }); + EXPECT_TRUE(completed); + } +} + +TEST(ThreadGroup, Nested) { + struct fib { + unsigned operator()(tools::thread_group& group, unsigned value) const { + if (value == 0 || value == 1) { + return value; + } + unsigned left = 0; + unsigned right = 0; + tools::task_region(group, [&, value] (tools::task_region_handle& region) { + region.run([&, value] { left = fib{}(group, value - 1); }); + region.run([&, value] { right = fib{}(group, value - 2); } ); + }); + return left + right; + } + + unsigned operator()(tools::thread_group&& group, unsigned value) const { + return (*this)(group, value); + } + }; + // be careful of depth on asynchronous version + EXPECT_EQ(6765, fib{}(tools::thread_group(0), 20)); + EXPECT_EQ(377, fib{}(tools::thread_group(1), 14)); +} + +TEST(ThreadGroup, Many) +{ + tools::thread_group group(1); + + for (unsigned i = 0; i < 3; ++i) { + std::atomic count{0}; + tools::task_region(group, [&] (tools::task_region_handle& region) { + for (unsigned tasks = 0; tasks < 1000; ++tasks) { + region.run([&] { ++count; }); + } + }); + EXPECT_EQ(1000u, count); + } +} + +TEST(ThreadGroup, ThrowInTaskRegion) +{ + class test_exception final : std::exception { + public: + explicit test_exception() : std::exception() {} + + virtual const char* what() const noexcept override { + return "test_exception"; + } + }; + + tools::thread_group group(1); + + for (unsigned i = 0; i < 3; ++i) { + std::atomic count{0}; + EXPECT_THROW( + [&] { + tools::task_region(group, [&] (tools::task_region_handle& region) { + for (unsigned tasks = 0; tasks < 1000; ++tasks) { + region.run([&] { ++count; }); + } + throw test_exception(); + }); + }(), + test_exception + ); + EXPECT_GE(1000u, count); + } +}