libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

thread_parallel_runner_internal.h (6308B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 //
      6 
      7 // C++ implementation using std::thread of a ::JxlParallelRunner.
      8 
      9 // The main class in this module, ThreadParallelRunner, implements a static
     10 // method ThreadParallelRunner::Runner that can be passed as a
     11 // JxlParallelRunner when using the JPEG XL library. This uses std::thread
     12 // internally and related synchronization functions. The number of threads
     13 // created is fixed at construction time and the threads are re-used for every
     14 // ThreadParallelRunner::Runner call. Only one concurrent Runner() call per
     15 // instance is allowed at a time.
     16 //
     17 // This is a scalable, lower-overhead thread pool runner, especially suitable
     18 // for data-parallel computations in the fork-join model, where clients need to
     19 // know when all tasks have completed.
     20 //
     21 // This thread pool can efficiently load-balance millions of tasks using an
     22 // atomic counter, thus avoiding per-task virtual or system calls. With 48
     23 // hyperthreads and 1M tasks that add to an atomic counter, overall runtime is
     24 // 10-20x higher when using std::async, and ~200x for a queue-based thread
     25 // pool.
     26 //
     27 // Usage:
     28 //   ThreadParallelRunner runner;
     29 //   JxlDecode(
     30 //       ... , &ThreadParallelRunner::Runner, static_cast<void*>(&runner));
     31 
     32 #ifndef LIB_THREADS_THREAD_PARALLEL_RUNNER_INTERNAL_H_
     33 #define LIB_THREADS_THREAD_PARALLEL_RUNNER_INTERNAL_H_
     34 
     35 #include <jxl/memory_manager.h>
     36 #include <jxl/parallel_runner.h>
     37 #include <stddef.h>
     38 #include <stdint.h>
     39 #include <stdlib.h>
     40 
     41 #include <atomic>
     42 #include <condition_variable>  //NOLINT
     43 #include <mutex>               //NOLINT
     44 #include <thread>              //NOLINT
     45 #include <vector>
     46 
     47 namespace jpegxl {
     48 
     49 // Main helper class implementing the ::JxlParallelRunner interface.
     50 class ThreadParallelRunner {
     51  public:
     52   // ::JxlParallelRunner interface.
     53   static JxlParallelRetCode Runner(void* runner_opaque, void* jpegxl_opaque,
     54                                    JxlParallelRunInit init,
     55                                    JxlParallelRunFunction func,
     56                                    uint32_t start_range, uint32_t end_range);
     57 
     58   // Starts the given number of worker threads and blocks until they are ready.
     59   // "num_worker_threads" defaults to one per hyperthread. If zero, all tasks
     60   // run on the main thread.
     61   explicit ThreadParallelRunner(
     62       int num_worker_threads = std::thread::hardware_concurrency());
     63 
     64   // Waits for all threads to exit.
     65   ~ThreadParallelRunner();
     66 
     67   // Returns maximum number of main/worker threads that may call Func. Useful
     68   // for allocating per-thread storage.
     69   size_t NumThreads() const { return num_threads_; }
     70 
     71   // Runs func(thread, thread) on all thread(s) that may participate in Run.
     72   // If NumThreads() == 0, runs on the main thread with thread == 0, otherwise
     73   // concurrently called by each worker thread in [0, NumThreads()).
     74   template <class Func>
     75   void RunOnEachThread(const Func& func) {
     76     if (num_worker_threads_ == 0) {
     77       const int thread = 0;
     78       func(thread, thread);
     79       return;
     80     }
     81 
     82     data_func_ = reinterpret_cast<JxlParallelRunFunction>(&CallClosure<Func>);
     83     jpegxl_opaque_ = const_cast<void*>(static_cast<const void*>(&func));
     84     StartWorkers(kWorkerOnce);
     85     WorkersReadyBarrier();
     86   }
     87 
     88   JxlMemoryManager memory_manager;
     89 
     90  private:
     91   // After construction and between calls to Run, workers are "ready", i.e.
     92   // waiting on worker_start_cv_. They are "started" by sending a "command"
     93   // and notifying all worker_start_cv_ waiters. (That is why all workers
     94   // must be ready/waiting - otherwise, the notification will not reach all of
     95   // them and the main thread waits in vain for them to report readiness.)
     96   using WorkerCommand = uint64_t;
     97 
     98   // Special values; all others encode the begin/end parameters. Note that all
     99   // these are no-op ranges (begin >= end) and therefore never used to encode
    100   // ranges.
    101   static constexpr WorkerCommand kWorkerWait = ~1ULL;
    102   static constexpr WorkerCommand kWorkerOnce = ~2ULL;
    103   static constexpr WorkerCommand kWorkerExit = ~3ULL;
    104 
    105   // Calls f(task, thread). Used for type erasure of Func arguments. The
    106   // signature must match JxlParallelRunFunction, hence a void* argument.
    107   template <class Closure>
    108   static void CallClosure(void* f, const uint32_t task, const size_t thread) {
    109     (*reinterpret_cast<const Closure*>(f))(task, thread);
    110   }
    111 
    112   void WorkersReadyBarrier() {
    113     std::unique_lock<std::mutex> lock(mutex_);
    114     // Typically only a single iteration.
    115     while (workers_ready_ != threads_.size()) {
    116       workers_ready_cv_.wait(lock);
    117     }
    118     workers_ready_ = 0;
    119 
    120     // Safely handle spurious worker wakeups.
    121     worker_start_command_ = kWorkerWait;
    122   }
    123 
    124   // Precondition: all workers are ready.
    125   void StartWorkers(const WorkerCommand worker_command) {
    126     mutex_.lock();
    127     worker_start_command_ = worker_command;
    128     // Workers will need this lock, so release it before they wake up.
    129     mutex_.unlock();
    130     worker_start_cv_.notify_all();
    131   }
    132 
    133   // Attempts to reserve and perform some work from the global range of tasks,
    134   // which is encoded within "command". Returns after all tasks are reserved.
    135   static void RunRange(ThreadParallelRunner* self, WorkerCommand command,
    136                        int thread);
    137 
    138   static void ThreadFunc(ThreadParallelRunner* self, int thread);
    139 
    140   // Unmodified after ctor, but cannot be const because we call thread::join().
    141   std::vector<std::thread> threads_;
    142 
    143   const uint32_t num_worker_threads_;  // == threads_.size()
    144   const uint32_t num_threads_;
    145 
    146   std::atomic<int> depth_{0};  // detects if Run is re-entered (not supported).
    147 
    148   std::mutex mutex_;  // guards both cv and their variables.
    149   std::condition_variable workers_ready_cv_;
    150   uint32_t workers_ready_ = 0;
    151   std::condition_variable worker_start_cv_;
    152   WorkerCommand worker_start_command_;
    153 
    154   // Written by main thread, read by workers (after mutex lock/unlock).
    155   JxlParallelRunFunction data_func_;
    156   void* jpegxl_opaque_;
    157 
    158   // Updated by workers; padding avoids false sharing.
    159   uint8_t padding1[64];
    160   std::atomic<uint32_t> num_reserved_{0};
    161   uint8_t padding2[64];
    162 };
    163 
    164 }  // namespace jpegxl
    165 
    166 #endif  // LIB_THREADS_THREAD_PARALLEL_RUNNER_INTERNAL_H_