libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git

resizable_parallel_runner.cc (6014B)

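A JxlParallelRunner implementation whose thread count can be changed between runs; tasks also execute on the calling thread.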

// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include <jxl/jxl_threads_export.h>
#include <jxl/memory_manager.h>
#include <jxl/parallel_runner.h>
#include <jxl/resizable_parallel_runner.h>

#include <algorithm>
#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <mutex>
#include <thread>
#include <vector>

namespace jpegxl {
namespace {

// A thread pool that allows changing the number of threads it runs. It also
// runs tasks on the calling thread, which can work better on schedulers for
// heterogeneous architectures.
struct ResizeableParallelRunner {
  void SetNumThreads(size_t num) {
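    // `num` counts the calling thread too (it also runs tasks), so only
    // `num - 1` dedicated workers are kept.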
    if (num > 0) {
      num -= 1;
    }
    {
      std::unique_lock<std::mutex> l(state_mutex_);
      num_desired_workers_ = num;
      workers_can_proceed_.notify_all();
    }
    if (workers_.size() < num) {
      for (size_t i = workers_.size(); i < num; i++) {
        workers_.emplace_back([this, i]() { WorkerBody(i); });
      }
    }
    if (workers_.size() > num) {
      for (size_t i = num; i < workers_.size(); i++) {
        workers_[i].join();
      }
      workers_.resize(num);
    }
  }

  ~ResizeableParallelRunner() { SetNumThreads(0); }

  JxlParallelRetCode Run(void* jxl_opaque, JxlParallelRunInit init,
                         JxlParallelRunFunction func, uint32_t start,
                         uint32_t end) {
    if (start + 1 == end) {
      JxlParallelRetCode ret = init(jxl_opaque, 1);
      if (ret != 0) return ret;

      func(jxl_opaque, start, 0);
      return ret;
    }

    size_t num_workers = std::min<size_t>(workers_.size() + 1, end - start);
    JxlParallelRetCode ret = init(jxl_opaque, num_workers);
    if (ret != 0) {
      return ret;
    }

    {
      std::unique_lock<std::mutex> l(state_mutex_);
      // Avoid waking up more workers than needed.
      max_running_workers_ = end - start - 1;
      next_task_ = start;
      end_task_ = end;
      func_ = func;
      jxl_opaque_ = jxl_opaque;
      work_available_ = true;
      num_running_workers_++;
      workers_can_proceed_.notify_all();
    }
    DequeueTasks(0);

    while (true) {
      std::unique_lock<std::mutex> l(state_mutex_);
      if (num_running_workers_ == 0) break;
      work_done_.wait(l);
    }

    return ret;
  }

 private:
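  // Loop run by each dedicated worker: sleep until either the pool has been
  // shrunk below this worker's id (then exit) or work is available, then help
  // drain the task range.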
  void WorkerBody(size_t worker_id) {
    while (true) {
      {
        std::unique_lock<std::mutex> l(state_mutex_);
        // Worker pool was reduced, resize down.
        if (worker_id >= num_desired_workers_) {
          return;
        }
        // Nothing to do this time.
        if (!work_available_ || worker_id >= max_running_workers_) {
          workers_can_proceed_.wait(l);
          continue;
        }
        num_running_workers_++;
      }
      DequeueTasks(worker_id + 1);
    }
  }

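  // Claims tasks from [next_task_, end_task_) one at a time (next_task_ is
  // atomic, so no lock is needed on the hot path) and runs them on the
  // current thread; the last thread to finish signals work_done_.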
  void DequeueTasks(size_t thread_id) {
    while (true) {
      uint32_t task = next_task_++;
      if (task >= end_task_) {
        std::unique_lock<std::mutex> l(state_mutex_);
        num_running_workers_--;
        work_available_ = false;
        if (num_running_workers_ == 0) {
          work_done_.notify_all();
        }
        break;
      }
      func_(jxl_opaque_, task, thread_id);
    }
  }

  // Signaled when a worker may have something to do, which can be one of:
  // - quitting (worker_id >= num_desired_workers_)
  // - work being available for it (work_available_ is true and worker_id <
  //   max_running_workers_)
  std::condition_variable workers_can_proceed_;

  // Signaled when all workers are done and the main thread can proceed
  // (num_running_workers_ == 0).
  std::condition_variable work_done_;

  std::vector<std::thread> workers_;

  // Protects all the remaining variables, except for func_, jxl_opaque_ and
  // end_task_ (for which only the write by the main thread is protected, and
  // subsequent uses by workers happen-after it) and next_task_ (which is
  // atomic).
  std::mutex state_mutex_;

  // Range of tasks that still need to be done.
  std::atomic<uint32_t> next_task_;
  uint32_t end_task_;

  // Function to run and its argument.
  JxlParallelRunFunction func_;
  void* jxl_opaque_;  // not owned

  // Variables that control the workers:
  // - work_available_ is set to true after a call to Run() and to false at the
  //   end of it.
  // - num_desired_workers_ represents the number of workers that should be
  //   present.
  // - max_running_workers_ represents the number of workers that should be
  //   executing tasks.
  // - num_running_workers_ represents the number of workers that are executing
  //   tasks.
  size_t num_desired_workers_ = 0;
  size_t max_running_workers_ = 0;
  size_t num_running_workers_ = 0;
  bool work_available_ = false;
};
}  // namespace
}  // namespace jpegxl

extern "C" {
JXL_THREADS_EXPORT JxlParallelRetCode JxlResizableParallelRunner(
    void* runner_opaque, void* jpegxl_opaque, JxlParallelRunInit init,
    JxlParallelRunFunction func, uint32_t start_range, uint32_t end_range) {
  return static_cast<jpegxl::ResizeableParallelRunner*>(runner_opaque)
      ->Run(jpegxl_opaque, init, func, start_range, end_range);
}

JXL_THREADS_EXPORT void* JxlResizableParallelRunnerCreate(
    const JxlMemoryManager* memory_manager) {
  return new jpegxl::ResizeableParallelRunner();
}

JXL_THREADS_EXPORT void JxlResizableParallelRunnerSetThreads(
    void* runner_opaque, size_t num_threads) {
  static_cast<jpegxl::ResizeableParallelRunner*>(runner_opaque)
      ->SetNumThreads(num_threads);
}

JXL_THREADS_EXPORT void JxlResizableParallelRunnerDestroy(void* runner_opaque) {
  delete static_cast<jpegxl::ResizeableParallelRunner*>(runner_opaque);
}

JXL_THREADS_EXPORT uint32_t
JxlResizableParallelRunnerSuggestThreads(uint64_t xsize, uint64_t ysize) {
  // ~one thread per group.
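  // e.g. a 4096x3072 image covers 4096 * 3072 / (256 * 256) = 192 groups of
  // 256x256 samples, so more threads than that would not help there.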
  return std::min<uint64_t>(std::thread::hardware_concurrency(),
                            xsize * ysize / (256 * 256));
}
}
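
A minimal usage sketch (not part of the file above) of how these exported
functions are typically wired into a libjxl decoder: attach the runner before
decoding starts, then resize it once JXL_DEC_BASIC_INFO reveals the image
dimensions. The JxlDecoder* calls are the standard <jxl/decode.h> API;
DecodeWithResizableRunner is a hypothetical helper, and output-buffer setup
and error handling are elided.

#include <jxl/decode.h>
#include <jxl/resizable_parallel_runner.h>

#include <cstddef>
#include <cstdint>

void DecodeWithResizableRunner(const uint8_t* data, size_t size) {
  void* runner = JxlResizableParallelRunnerCreate(/*memory_manager=*/nullptr);
  JxlDecoder* dec = JxlDecoderCreate(/*memory_manager=*/nullptr);
  JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO | JXL_DEC_FULL_IMAGE);
  // All of the decoder's internal parallelism now runs through the
  // resizable runner implemented above.
  JxlDecoderSetParallelRunner(dec, JxlResizableParallelRunner, runner);
  JxlDecoderSetInput(dec, data, size);

  for (;;) {
    JxlDecoderStatus status = JxlDecoderProcessInput(dec);
    if (status == JXL_DEC_BASIC_INFO) {
      // The group count is known once the dimensions are, so this is the
      // natural point to (re)size the pool.
      JxlBasicInfo info;
      JxlDecoderGetBasicInfo(dec, &info);
      JxlResizableParallelRunnerSetThreads(
          runner,
          JxlResizableParallelRunnerSuggestThreads(info.xsize, info.ysize));
    } else {
      // A full decode loop would service JXL_DEC_NEED_IMAGE_OUT_BUFFER and
      // JXL_DEC_FULL_IMAGE here; this sketch stops at the first other event.
      break;
    }
  }

  JxlDecoderDestroy(dec);
  JxlResizableParallelRunnerDestroy(runner);
}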