|
2 | 2 |
|
3 | 3 | from __future__ import annotations
|
4 | 4 |
|
5 |
| -import enum |
6 | 5 | import os
|
7 | 6 | from typing import Any
|
8 | 7 |
|
9 | 8 | from pytask import hookimpl
|
10 | 9 |
|
| 10 | +from pytask_parallel import execute |
| 11 | +from pytask_parallel import processes |
| 12 | +from pytask_parallel import threads |
11 | 13 | from pytask_parallel.backends import ParallelBackend
|
12 | 14 |
|
13 | 15 |
|
14 | 16 | @hookimpl
|
15 | 17 | def pytask_parse_config(config: dict[str, Any]) -> None:
|
16 | 18 | """Parse the configuration."""
|
| 19 | + __tracebackhide__ = True |
| 20 | + |
17 | 21 | if config["n_workers"] == "auto":
|
18 | 22 | config["n_workers"] = max(os.cpu_count() - 1, 1)
|
19 | 23 |
|
20 |
| - if ( |
21 |
| - isinstance(config["parallel_backend"], str) |
22 |
| - and config["parallel_backend"] in ParallelBackend._value2member_map_ # noqa: SLF001 |
23 |
| - ): |
| 24 | + try: |
24 | 25 | config["parallel_backend"] = ParallelBackend(config["parallel_backend"])
|
25 |
| - elif ( |
26 |
| - isinstance(config["parallel_backend"], enum.Enum) |
27 |
| - and config["parallel_backend"] in ParallelBackend |
28 |
| - ): |
29 |
| - pass |
30 |
| - else: |
31 |
| - msg = f"Invalid value for 'parallel_backend'. Got {config['parallel_backend']}." |
32 |
| - raise ValueError(msg) |
| 26 | + except ValueError: |
| 27 | + msg = ( |
| 28 | + f"Invalid value for 'parallel_backend'. Got {config['parallel_backend']}. " |
| 29 | + f"Choose one of {', '.join([e.value for e in ParallelBackend])}." |
| 30 | + ) |
| 31 | + raise ValueError(msg) from None |
33 | 32 |
|
34 | 33 | config["delay"] = 0.1
|
35 | 34 |
|
36 | 35 |
|
37 | 36 | @hookimpl
|
38 | 37 | def pytask_post_parse(config: dict[str, Any]) -> None:
|
39 |
| - """Disable parallelization if debugging is enabled.""" |
| 38 | + """Register the parallel backend if debugging is not enabled.""" |
40 | 39 | if config["pdb"] or config["trace"] or config["dry_run"]:
|
41 | 40 | config["n_workers"] = 1
|
| 41 | + |
| 42 | + if config["n_workers"] > 1: |
| 43 | + config["pm"].register(execute) |
| 44 | + if config["parallel_backend"] == ParallelBackend.THREADS: |
| 45 | + config["pm"].register(threads) |
| 46 | + else: |
| 47 | + config["pm"].register(processes) |
0 commit comments