Peano
Loading...
Searching...
No Matches
task.h
Go to the documentation of this file.
1#pragma once
2
3#include <omp.h>
4
5#include <atomic>
6#include <condition_variable>
7
8#include <lang/assert.h>
9#include <lang/channel.h>
10#include <lang/thread_pool.h>
11#include <lang/type.h>
12
13struct TaskCtx {
15};
16
17template<typename T>
18concept CtxUnawareStatelessTask = requires(T t) {
19 { t() } -> std::same_as<void>;
20};
21
22template<typename T>
23concept CtxAwareStatelessTask = requires(T t, TaskCtx &ctx) {
24 { t(ctx) } -> std::same_as<void>;
25};
26
27template<typename T>
29
30template<typename T>
32
33template<typename T>
34concept StatelessTaskGenerator = requires(T t) {
35 { t() } -> OptionalStatelessTask;
36};
37
38template<typename T, typename State>
39concept CtxUnawareStatefulTask = requires(T t, State &s) {
40 { t(s) } -> std::same_as<void>;
41};
42
43template<typename T, typename State>
44concept CtxAwareStatefulTask = requires(T t, State &s, TaskCtx &ctx) {
45 { t(s, ctx) } -> std::same_as<void>;
46};
47
48template<typename T, typename State>
50
51template<typename T, typename State>
53
54template<typename T, typename State>
55concept StatefulTaskGenerator = requires(T t) {
57};
58
60 return Lambda([](auto &generator) {
61 auto taskOpt = generator();
62
63 using TaskType = std::remove_reference_t<decltype(taskOpt.value())>;
64
66 return taskOpt ? std::optional(Lambda([](auto &task, auto &acc) { task(); }, taskOpt.value())) : std::nullopt;
67 } else if constexpr (CtxAwareStatelessTask<TaskType>) {
68 return taskOpt ? std::optional(Lambda([](auto &task, auto &acc, auto &ctx) { task(ctx); }, taskOpt.value())) : std::nullopt;
69 } else static_assert(false);
70
71 }, generator);
72}
73
75 template<typename Val, StatefulTask<Val> Task>
76 static void invokeTask(Task &task, Val &val, i32 threadId) {
77 if constexpr (CtxAwareStatefulTask<Task, Val>) {
78 auto ctx = TaskCtx{.workerId = (u32) threadId};
79 task(val, ctx);
80 } else if constexpr (CtxUnawareStatefulTask<Task, Val>) {
81 task(val);
82 } else static_assert(false);
83 }
84
85public:
86 template<StatelessTaskGenerator T>
87 static void runSerialGenerator(T generator, u32 numThreads = 0) {
88 numThreads = OmpTaskBackend::THREAD_COUNT(numThreads);
89 auto mapReducerGenerator = ToStatefulGenerator(generator);
90 OmpTaskBackend::runSerialGenerator(mapReducerGenerator, [](){return Any{};}, numThreads);
91 }
92
93 template<typename StateInitF, StatefulTaskGenerator<std::invoke_result_t<StateInitF>> T>
94 static auto runSerialGenerator(T generator, StateInitF stateInitF, u32 numThreads = 0) {
95 using ValueType = std::invoke_result_t<StateInitF>;
96
97 numThreads = OmpTaskBackend::THREAD_COUNT(numThreads);
98
99 auto accumulators = std::vector<ValueType>(numThreads);
100
101 #pragma omp parallel num_threads(numThreads) default(none) shared(numThreads, generator, accumulators, stateInitF)
102 #pragma omp single nowait
103 {
104 for (int i = 0; i < numThreads; i++) {
105 #pragma omp task default(none) shared(stateInitF, accumulators) firstprivate(i)
106 accumulators[i] = stateInitF();
107 }
108
109 #pragma omp taskwait
110
111 while (true) {
112 auto taskOpt = generator();
113 if (!taskOpt) [[unlikely]] break;
114
115 auto task = taskOpt.value();
116 #pragma omp task default(none) shared(stateInitF, accumulators) firstprivate(task)
117 {
118 auto threadNum = omp_get_thread_num();
119 OmpTaskBackend::invokeTask(task, accumulators[threadNum], threadNum);
120 }
121 }
122 }
123
124 return accumulators;
125 }
126
127 template<StatelessTaskGenerator T>
128 static void runParallelGenerator(T generator, u32 numThreads = 0) {
129 numThreads = OmpTaskBackend::THREAD_COUNT(numThreads);
130 auto mapReducerGenerator = ToStatefulGenerator(generator);
131 OmpTaskBackend::runParallelGenerator(mapReducerGenerator, [](){ return Any{}; }, numThreads);
132 }
133
134 template<typename StateInitF, StatefulTaskGenerator<std::invoke_result_t<StateInitF>> T>
135 static auto runParallelGenerator(T generator, StateInitF stateInitF, u32 numThreads = 0) {
136 using ValueType = std::invoke_result_t<StateInitF>;
137
138 numThreads = OmpTaskBackend::THREAD_COUNT(numThreads);
139
140 auto accumulators = std::vector<ValueType>(numThreads);
141
142 #pragma omp parallel num_threads(numThreads) default(none) shared(generator, stateInitF, accumulators)
143 {
144 auto threadId = omp_get_thread_num();
145 accumulators[threadId] = stateInitF();
146
147 while (true) {
148 auto taskOpt = generator();
149 if (!taskOpt) [[unlikely]] break;
150
151 auto task = taskOpt.value();
152 OmpTaskBackend::invokeTask(task, accumulators[threadId], threadId);
153 }
154 }
155
156 return accumulators;
157 }
158
159 static u32 MAX_THREADS() {
160 auto numThreads = (u32) omp_get_max_threads();
161 return numThreads;
162 }
163
164 static u32 THREAD_COUNT(u32 numThreads = 0) {
165 if (numThreads == 0) numThreads = OmpTaskBackend::MAX_THREADS();
166 numThreads = std::clamp(numThreads, (u32) 1, OmpTaskBackend::MAX_THREADS());
167 return numThreads;
168 }
169};
170
172 template<typename Task, typename TaskGenerator>
173 static std::optional<Task> createNewTasks(TaskGenerator *generator,
174 SPMCChannel<Task> *channel,
175 std::atomic_int32_t *producerThreadId) {
176 if (channel->isClosed()) {
177 producerThreadId->store(-1, std::memory_order_relaxed);
178 return std::nullopt;
179 }
180
181 while (true) {
182 auto taskOpt = (*generator)();
183
184 if (!taskOpt) [[unlikely]] {
185 channel->close();
186 producerThreadId->store(-1, std::memory_order_relaxed);
187 return std::nullopt;
188 }
189
190 auto task = taskOpt.value();
191
192 auto success = channel->put(task);
193 if (!success) [[unlikely]] return task;
194 }
195 }
196
197 template<typename Val, StatefulTask<Val> Task>
198 static void invokeTask(Task &task, Val &val, i32 threadId) {
199 if constexpr (CtxAwareStatefulTask<Task, Val>) {
200 auto ctx = TaskCtx{.workerId = (u32) threadId};
201 task(val, ctx);
202 } else if constexpr (CtxUnawareStatefulTask<Task, Val>) {
203 task(val);
204 } else static_assert(false);
205 }
206
207 template<typename ValueType, StatefulTaskGenerator<ValueType> TaskGenerator, StatefulTask<ValueType> Task, typename StateInitF>
208 static void producerConsumerLoop(i32 threadId,
209 TaskGenerator *generator,
210 SPMCChannel<Task> *channel,
211 StateInitF stateInitF,
212 ValueType *retVal,
213 std::atomic_int32_t *producerThreadId) {
214 auto rx = channel->getReader(threadId);
215 auto acc = stateInitF();
216
217 if (threadId == producerThreadId->load(std::memory_order_relaxed)) [[unlikely]] goto I_AM_PRODUCER;
218 else goto I_AM_CONSUMER;
219
220 I_AM_PRODUCER:
221 while (true) {
222 auto myNewTaskOpt = createNewTasks(generator, channel, producerThreadId);
223 if (!myNewTaskOpt) [[unlikely]] goto I_AM_CONSUMER;
224
225 producerThreadId->store(-1, std::memory_order_relaxed);
226 auto myNewTask = myNewTaskOpt.value();
227 ThreadPoolTaskBackend::invokeTask(myNewTask, acc, threadId);
228
229 auto currentProducerThreadId = producerThreadId->load(std::memory_order_relaxed);
230 while (currentProducerThreadId == -1) {
231 producerThreadId->compare_exchange_weak(currentProducerThreadId, threadId);
232 }
233
234 if (currentProducerThreadId != threadId) [[unlikely]] goto I_AM_CONSUMER;
235 }
236
237 I_AM_CONSUMER:
238 while (true) {
239 auto taskOpt = rx.getFast();
240
241 if (taskOpt) {
242 auto task = taskOpt.value();
243 ThreadPoolTaskBackend::invokeTask(task, acc, threadId);
244 continue;
245 }
246
247 if (channel->isClosed()) [[unlikely]] {
248 while (true) {
249 taskOpt = rx.getOther();
250 if (taskOpt) {
251 auto task = taskOpt.value();
252 ThreadPoolTaskBackend::invokeTask(task, acc, threadId);
253 continue;
254 }
255 break;
256 }
257
258 *retVal = std::move(acc);
259 return;
260 }
261
262 auto m1 = -1;
263 auto someElseGenerator = producerThreadId->compare_exchange_weak(m1, threadId);
264 if (!someElseGenerator) continue;
265
266 goto I_AM_PRODUCER;
267 }
268 }
269
271 static auto value = ThreadPool();
272 return value;
273 }
274
275 template<typename StateInitF, StatefulTaskGenerator<std::invoke_result_t<StateInitF>> T>
276 static auto runSerialGeneratorSerial(T generator, StateInitF stateInitF) {
277 using ValueType = std::invoke_result_t<StateInitF>;
278
279 auto accumulators = std::vector<ValueType>(1);
280 accumulators[0] = stateInitF();
281
282 while (true) {
283 auto taskOpt = generator();
284 if (!taskOpt) break;
285 auto task = taskOpt.value();
286 ThreadPoolTaskBackend::invokeTask(task, accumulators[0], 0);
287 }
288
289 return accumulators;
290 }
291
292 template<typename StateInitF, StatefulTaskGenerator<std::invoke_result_t<StateInitF>> T>
293 static auto runSerialGeneratorParallel(T generator, StateInitF stateInitF, u32 numThreads) {
294 using ValueType = std::invoke_result_t<StateInitF>;
295 using Task = std::invoke_result_t<T>::value_type;
296
297 auto channel = SPMCChannel<Task>(numThreads);
298 auto accumulators = std::vector<ValueType>(numThreads);
299
300 std::atomic_int32_t producerThreadId = 0;
301
302 for (i32 localHelperId = 1; localHelperId < numThreads; localHelperId++) {
303 auto runnable = Lambda([](auto localHelperId, auto *generator, auto *channel, auto stateInitF, auto *accumulator, auto *producerThreadId) {
304 producerConsumerLoop<ValueType, T, Task, StateInitF>(localHelperId, generator, channel, stateInitF, accumulator, producerThreadId);
305 }, localHelperId, &generator, &channel, stateInitF, &accumulators[localHelperId], &producerThreadId);
306
308 }
309
310 producerConsumerLoop<ValueType, T, Task, StateInitF>(0, &generator, &channel, stateInitF, &accumulators[0], &producerThreadId);
311
313
314 return accumulators;
315 }
316
317 template<typename StateInitF, StatefulTaskGenerator<std::invoke_result_t<StateInitF>> T>
318 static auto runParallelGeneratorParallel(T &generator, StateInitF stateInitF, u32 numThreads) {
319 using ValueType = std::invoke_result_t<StateInitF>;
320
321 auto accumulators = std::vector<ValueType>(numThreads);
322
323 for (u32 localHelperId = 1; localHelperId < numThreads; localHelperId++) {
324 ThreadPoolTaskBackend::WORKER_POOL().dispatch([localHelperId, &accumulators, &generator, stateInitF]() {
325 auto acc = stateInitF();
326 while (true) {
327 auto taskOpt = generator();
328 if (!taskOpt) break;
329
330 auto task = taskOpt.value();
331 ThreadPoolTaskBackend::invokeTask(task, acc, localHelperId);
332 }
333 accumulators[localHelperId] = acc;
334 }, 1);
335 }
336
337 auto &acc = accumulators[0];
338 acc = stateInitF();
339 while (true) {
340 auto taskOpt = generator();
341 if (!taskOpt) break;
342
343 auto task = taskOpt.value();
345 }
346
348
349 return accumulators;
350 }
351
352public:
353 template<StatelessTaskGenerator T>
354 static void runSerialGenerator(T generator, u32 numThreads = 0) {
355 numThreads = ThreadPoolTaskBackend::THREAD_COUNT(numThreads);
356 auto mapReducerGenerator = ToStatefulGenerator(generator);
357 ThreadPoolTaskBackend::runSerialGenerator(mapReducerGenerator, [](){return Any{};}, numThreads);
358 }
359
360 template<typename StateInitF, StatefulTaskGenerator<std::invoke_result_t<StateInitF>> T>
361 static auto runSerialGenerator(T generator, StateInitF stateInitF, u32 numThreads = 0) {
362 numThreads = ThreadPoolTaskBackend::THREAD_COUNT(numThreads);
363 if (numThreads == 1) [[unlikely]] return ThreadPoolTaskBackend::runSerialGeneratorSerial(generator, stateInitF);
364 else return ThreadPoolTaskBackend::runSerialGeneratorParallel(generator, stateInitF, numThreads);
365 }
366
367 template<StatelessTaskGenerator T>
368 static void runParallelGenerator(T generator, u32 numThreads = 0) {
369 numThreads = ThreadPoolTaskBackend::THREAD_COUNT(numThreads);
370 auto mapReducerGenerator = ToStatefulGenerator(generator);
371 ThreadPoolTaskBackend::runParallelGenerator(mapReducerGenerator, [](){return Any{};}, numThreads);
372 }
373
374 template<typename StateInitF, StatefulTaskGenerator<std::invoke_result_t<StateInitF>> T>
375 static auto runParallelGenerator(T generator, StateInitF stateInitF, u32 numThreads = 0) {
376 numThreads = ThreadPoolTaskBackend::THREAD_COUNT(numThreads);
377 if (numThreads == 1) [[unlikely]] return ThreadPoolTaskBackend::runSerialGeneratorSerial(generator, stateInitF);
378 else return ThreadPoolTaskBackend::runParallelGeneratorParallel(generator, stateInitF, numThreads);
379 }
380
381 static u32 THREAD_COUNT(u32 numThreads = 0) {
382 if (numThreads == 0) numThreads = ThreadPoolTaskBackend::MAX_THREADS();
383 numThreads = std::clamp(numThreads, (u32) 1, ThreadPoolTaskBackend::MAX_THREADS());
384 return numThreads;
385 }
386
387 static u32 MAX_THREADS() {
388 auto size = 1 + ThreadPoolTaskBackend::WORKER_POOL().size();
389 return size;
390 }
391};
applications::exahype2::acoustic::VariableShortcuts s
Definition Acoustic.cpp:9
Definition lambda.h:6
static void runParallelGenerator(T generator, u32 numThreads=0)
Definition task.h:128
static auto runSerialGenerator(T generator, StateInitF stateInitF, u32 numThreads=0)
Definition task.h:94
static void runSerialGenerator(T generator, u32 numThreads=0)
Definition task.h:87
static auto runParallelGenerator(T generator, StateInitF stateInitF, u32 numThreads=0)
Definition task.h:135
static u32 THREAD_COUNT(u32 numThreads=0)
Definition task.h:164
static void invokeTask(Task &task, Val &val, i32 threadId)
Definition task.h:76
static u32 MAX_THREADS()
Definition task.h:159
bool put(Item item)
Definition channel.h:109
Reader getReader(u32 consumerIdHint=0)
Definition channel.h:105
bool isClosed() const
Definition channel.h:133
void close()
Definition channel.h:129
static void runParallelGenerator(T generator, u32 numThreads=0)
Definition task.h:368
static auto runParallelGeneratorParallel(T &generator, StateInitF stateInitF, u32 numThreads)
Definition task.h:318
static auto runParallelGenerator(T generator, StateInitF stateInitF, u32 numThreads=0)
Definition task.h:375
static ThreadPool & WORKER_POOL()
Definition task.h:270
static auto runSerialGenerator(T generator, StateInitF stateInitF, u32 numThreads=0)
Definition task.h:361
static u32 MAX_THREADS()
Definition task.h:387
static u32 THREAD_COUNT(u32 numThreads=0)
Definition task.h:381
static auto runSerialGeneratorParallel(T generator, StateInitF stateInitF, u32 numThreads)
Definition task.h:293
static void invokeTask(Task &task, Val &val, i32 threadId)
Definition task.h:198
static void runSerialGenerator(T generator, u32 numThreads=0)
Definition task.h:354
static void producerConsumerLoop(i32 threadId, TaskGenerator *generator, SPMCChannel< Task > *channel, StateInitF stateInitF, ValueType *retVal, std::atomic_int32_t *producerThreadId)
Definition task.h:208
static auto runSerialGeneratorSerial(T generator, StateInitF stateInitF)
Definition task.h:276
static std::optional< Task > createNewTasks(TaskGenerator *generator, SPMCChannel< Task > *channel, std::atomic_int32_t *producerThreadId)
Definition task.h:173
void dispatch(const std::function< void()> &work, u32 numThreads=0)
void wait()
u32 size() const
Definition type.h:19
Definition task.h:13
u32 workerId
Definition task.h:14
StatefulTaskGenerator< Any > auto ToStatefulGenerator(StatelessTaskGenerator auto generator)
Definition task.h:59
std::uint32_t u32
Definition type.h:11
std::int32_t i32
Definition type.h:10