LIVE / parallel.cpp
Xu Ma
update
1c3c0d9
raw
history blame
No virus
8.66 kB
#include "parallel.h"
#include <algorithm>
#include <cassert>
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <list>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>
// From https://github.com/mmp/pbrt-v3/blob/master/src/core/parallel.cpp
// Worker threads started by parallel_init(). The thread that calls
// parallel_init() is not in this list; it runs loop iterations itself.
// Joined and emptied by parallel_cleanup().
static std::vector<std::thread> threads;
// Set to true (under workListMutex) by parallel_cleanup() to make the
// workers leave their main loop.
static bool shutdownThreads = false;
struct ParallelForLoop;
// Head of a singly-linked list (through ParallelForLoop::next) of loops
// published for worker threads to pick chunks from. Guarded by
// workListMutex.
static ParallelForLoop *workList = nullptr;
// Guards workList, shutdownThreads, and the nextIndex/activeWorkers
// bookkeeping of the loops being executed.
static std::mutex workListMutex;
// Bookkeeping for one parallel_for_host() call. Instances live on the
// calling thread's stack and are linked into workList (by address) while
// they still have unclaimed iterations. nextIndex and activeWorkers are only
// touched while holding workListMutex.
struct ParallelForLoop {
    // 1D loop over [0, maxIndex), handed out in chunks of chunkSize indices.
    ParallelForLoop(std::function<void(int64_t)> func1D, int64_t maxIndex, int chunkSize)
        : func1D(std::move(func1D)), maxIndex(maxIndex), chunkSize(chunkSize) {
    }
    // 2D loop over a count[0] x count[1] grid, flattened so that
    // x = index % nX and y = index / nX, handed out one index at a time.
    // The function is taken by value so callers can move into it.
    ParallelForLoop(std::function<void(Vector2i)> f, const Vector2i count)
        // Widen before multiplying: the product of two ints can overflow
        // even though maxIndex itself is 64-bit.
        : func2D(std::move(f)),
          maxIndex(int64_t(count[0]) * count[1]),
          chunkSize(1),
          nX(count[0]) {
    }
    // workList stores the address of this object; a copy would be a
    // second, unlinked loop sharing nothing with the original.
    ParallelForLoop(const ParallelForLoop &) = delete;
    ParallelForLoop &operator=(const ParallelForLoop &) = delete;

    std::function<void(int64_t)> func1D;   // set for 1D loops
    std::function<void(Vector2i)> func2D;  // set for 2D loops
    const int64_t maxIndex;                // one past the last index
    const int chunkSize;                   // indices claimed per grab
    int64_t nextIndex = 0;                 // first unclaimed index
    int activeWorkers = 0;                 // threads currently running a chunk
    ParallelForLoop *next = nullptr;       // next loop in workList
    int nX = -1;                           // grid width for 2D loops

    // True once every index has been claimed and no chunk is still running.
    bool Finished() const {
        return nextIndex >= maxIndex && activeWorkers == 0;
    }
};
// Decrements the barrier's count and blocks the caller until it reaches
// zero. The thread that performs the final decrement wakes all waiters.
// Nothing here resets count, so a Barrier supports one rendezvous.
void Barrier::Wait() {
    std::unique_lock<std::mutex> guard(mutex);
    assert(count > 0);
    --count;
    if (count != 0) {
        // Others are still on their way: release the mutex and sleep until
        // the last arrival drives the counter to zero.
        cv.wait(guard, [this] { return count == 0; });
        return;
    }
    // Last one in: release everyone blocked above.
    cv.notify_all();
}
// Notified (notify_all) when a loop is published to workList, when a loop
// becomes Finished() after a worker's chunk, and at shutdown. Always used
// together with workListMutex.
static std::condition_variable workListCondition;
static void worker_thread_func(const int tIndex, std::shared_ptr<Barrier> barrier) {
ThreadIndex = tIndex;
// The main thread sets up a barrier so that it can be sure that all
// workers have called ProfilerWorkerThreadInit() before it continues
// (and actually starts the profiling system).
barrier->Wait();
// Release our reference to the Barrier so that it's freed once all of
// the threads have cleared it.
barrier.reset();
std::unique_lock<std::mutex> lock(workListMutex);
while (!shutdownThreads) {
if (!workList) {
// Sleep until there are more tasks to run
workListCondition.wait(lock);
} else {
// Get work from _workList_ and run loop iterations
ParallelForLoop &loop = *workList;
// Run a chunk of loop iterations for _loop_
// Find the set of loop iterations to run next
int64_t indexStart = loop.nextIndex;
int64_t indexEnd = std::min(indexStart + loop.chunkSize, loop.maxIndex);
// Update _loop_ to reflect iterations this thread will run
loop.nextIndex = indexEnd;
if (loop.nextIndex == loop.maxIndex)
workList = loop.next;
loop.activeWorkers++;
// Run loop indices in _[indexStart, indexEnd)_
lock.unlock();
for (int64_t index = indexStart; index < indexEnd; ++index) {
if (loop.func1D) {
loop.func1D(index);
}
// Handle other types of loops
else {
assert(loop.func2D != nullptr);
loop.func2D(Vector2i{int(index % loop.nX),
int(index / loop.nX)});
}
}
lock.lock();
// Update _loop_ to reflect completion of iterations
loop.activeWorkers--;
if (loop.Finished()) {
workListCondition.notify_all();
}
}
}
}
void parallel_for_host(const std::function<void(int64_t)> &func,
int64_t count,
int chunkSize) {
// Run iterations immediately if not using threads or if _count_ is small
if (threads.empty() || count < chunkSize) {
for (int64_t i = 0; i < count; ++i) {
func(i);
}
return;
}
// Create and enqueue _ParallelForLoop_ for this loop
ParallelForLoop loop(func, count, chunkSize);
workListMutex.lock();
loop.next = workList;
workList = &loop;
workListMutex.unlock();
// Notify worker threads of work to be done
std::unique_lock<std::mutex> lock(workListMutex);
workListCondition.notify_all();
// Help out with parallel loop iterations in the current thread
while (!loop.Finished()) {
// Run a chunk of loop iterations for _loop_
// Find the set of loop iterations to run next
int64_t indexStart = loop.nextIndex;
int64_t indexEnd = std::min(indexStart + loop.chunkSize, loop.maxIndex);
// Update _loop_ to reflect iterations this thread will run
loop.nextIndex = indexEnd;
if (loop.nextIndex == loop.maxIndex) {
workList = loop.next;
}
loop.activeWorkers++;
// Run loop indices in _[indexStart, indexEnd)_
lock.unlock();
for (int64_t index = indexStart; index < indexEnd; ++index) {
if (loop.func1D) {
loop.func1D(index);
}
// Handle other types of loops
else {
assert(loop.func2D != nullptr);
loop.func2D(Vector2i{int(index % loop.nX),
int(index / loop.nX)});
}
}
lock.lock();
// Update _loop_ to reflect completion of iterations
loop.activeWorkers--;
}
}
// Per-thread index. parallel_init() sets it to 0 on the calling thread, and
// worker_thread_func() sets it to 1..N-1 on the workers. Any other thread
// keeps the zero-initialized default.
thread_local int ThreadIndex;
void parallel_for_host(
std::function<void(Vector2i)> func, const Vector2i count) {
// Launch worker threads if needed
if (threads.empty() || count.x * count.y <= 1) {
for (int y = 0; y < count.y; ++y) {
for (int x = 0; x < count.x; ++x) {
func(Vector2i{x, y});
}
}
return;
}
ParallelForLoop loop(std::move(func), count);
{
std::lock_guard<std::mutex> lock(workListMutex);
loop.next = workList;
workList = &loop;
}
std::unique_lock<std::mutex> lock(workListMutex);
workListCondition.notify_all();
// Help out with parallel loop iterations in the current thread
while (!loop.Finished()) {
// Run a chunk of loop iterations for _loop_
// Find the set of loop iterations to run next
int64_t indexStart = loop.nextIndex;
int64_t indexEnd = std::min(indexStart + loop.chunkSize, loop.maxIndex);
// Update _loop_ to reflect iterations this thread will run
loop.nextIndex = indexEnd;
if (loop.nextIndex == loop.maxIndex) {
workList = loop.next;
}
loop.activeWorkers++;
// Run loop indices in _[indexStart, indexEnd)_
lock.unlock();
for (int64_t index = indexStart; index < indexEnd; ++index) {
if (loop.func1D) {
loop.func1D(index);
}
// Handle other types of loops
else {
assert(loop.func2D != nullptr);
loop.func2D(Vector2i{int(index % loop.nX),
int(index / loop.nX)});
}
}
lock.lock();
// Update _loop_ to reflect completion of iterations
loop.activeWorkers--;
}
}
// Number of hardware threads reported by std::thread::hardware_concurrency().
// Falls back to 16 when the implementation reports 0, meaning "unknown".
int num_system_cores() {
    // Debug knob: uncomment to force single-threaded execution.
    // return 1;
    const unsigned reported = std::thread::hardware_concurrency();
    return reported != 0 ? static_cast<int>(reported) : 16;
}
void parallel_init() {
assert(threads.size() == 0);
int nThreads = num_system_cores();
ThreadIndex = 0;
// Create a barrier so that we can be sure all worker threads get past
// their call to ProfilerWorkerThreadInit() before we return from this
// function. In turn, we can be sure that the profiling system isn't
// started until after all worker threads have done that.
std::shared_ptr<Barrier> barrier = std::make_shared<Barrier>(nThreads);
// Launch one fewer worker thread than the total number we want doing
// work, since the main thread helps out, too.
for (int i = 0; i < nThreads - 1; ++i) {
threads.push_back(std::thread(worker_thread_func, i + 1, barrier));
}
barrier->Wait();
}
// Signals every worker started by parallel_init() to exit, joins them, and
// resets the shutdown flag so parallel_init() may be called again. Does
// nothing when no workers are running.
void parallel_cleanup() {
    if (!threads.empty()) {
        {
            std::lock_guard<std::mutex> guard(workListMutex);
            shutdownThreads = true;
            workListCondition.notify_all();
        }
        for (auto &worker : threads) {
            worker.join();
        }
        threads.clear();
        // No lock needed: every worker has been joined.
        shutdownThreads = false;
    }
}