
Commit d7fb5e5

Refactor ProcessPoolExecutor. (#56)
1 parent 49d48ac commit d7fb5e5

3 files changed: +37 -13 lines

CHANGES.md

+2
@@ -7,6 +7,8 @@ releases are available on [PyPI](https://pypi.org/project/pytask-parallel) and
 
 ## 0.3.1 - 2023-xx-xx
 
+- {pull}`56` refactors the `ProcessPoolExecutor`.
+
 ## 0.3.0 - 2023-01-23
 
 - {pull}`50` deprecates INI configurations and aligns the package with pytask v0.3.

src/pytask_parallel/backends.py

+31 -2
@@ -2,8 +2,37 @@
 from __future__ import annotations
 
 import enum
+from concurrent.futures import Future
 from concurrent.futures import ProcessPoolExecutor
 from concurrent.futures import ThreadPoolExecutor
+from typing import Any
+from typing import Callable
+
+import cloudpickle
+
+
+def deserialize_and_run_with_cloudpickle(
+    fn: Callable[..., Any], kwargs: dict[str, Any]
+) -> Any:
+    """Deserialize and execute a function and keyword arguments."""
+    deserialized_fn = cloudpickle.loads(fn)
+    deserialized_kwargs = cloudpickle.loads(kwargs)
+    return deserialized_fn(**deserialized_kwargs)
+
+
+class CloudpickleProcessPoolExecutor(ProcessPoolExecutor):
+    """Patches the standard executor to serialize functions with cloudpickle."""
+
+    # The type signature is wrong for version above Py3.7. Fix when 3.7 is deprecated.
+    def submit(  # type: ignore[override]
+        self, fn: Callable[..., Any], *args: Any, **kwargs: Any  # noqa: ARG002
+    ) -> Future[Any]:
+        """Submit a new task."""
+        return super().submit(
+            deserialize_and_run_with_cloudpickle,
+            fn=cloudpickle.dumps(fn),
+            kwargs=cloudpickle.dumps(kwargs),
+        )
 
 
 try:
@@ -20,7 +49,7 @@ class ParallelBackendChoices(enum.Enum):
     PARALLEL_BACKENDS_DEFAULT = ParallelBackendChoices.PROCESSES
 
     PARALLEL_BACKENDS = {
-        ParallelBackendChoices.PROCESSES: ProcessPoolExecutor,
+        ParallelBackendChoices.PROCESSES: CloudpickleProcessPoolExecutor,
         ParallelBackendChoices.THREADS: ThreadPoolExecutor,
     }
 
@@ -36,7 +65,7 @@ class ParallelBackendChoices(enum.Enum):  # type: ignore[no-redef]
     PARALLEL_BACKENDS_DEFAULT = ParallelBackendChoices.PROCESSES
 
     PARALLEL_BACKENDS = {
-        ParallelBackendChoices.PROCESSES: ProcessPoolExecutor,
+        ParallelBackendChoices.PROCESSES: CloudpickleProcessPoolExecutor,
         ParallelBackendChoices.THREADS: ThreadPoolExecutor,
         ParallelBackendChoices.LOKY: (  # type: ignore[attr-defined]
             get_reusable_executor
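
For context, a minimal usage sketch (illustrative only, not part of the commit) of what the new executor buys: because `submit` cloudpickles the function and its keyword arguments, objects that the stdlib pickle module rejects, such as lambdas or locally defined closures, can be sent to worker processes. Note that the overridden `submit` forwards only keyword arguments.

# Illustrative example, not from the repository: the stdlib ProcessPoolExecutor
# raises a PicklingError for a lambda, while the cloudpickle-backed subclass
# serializes it on submit and runs it in a worker process.
from pytask_parallel.backends import CloudpickleProcessPoolExecutor

if __name__ == "__main__":
    double = lambda x: 2 * x  # not importable, so plain pickle cannot handle it

    with CloudpickleProcessPoolExecutor(max_workers=1) as executor:
        # Pass arguments as keywords; positional args are ignored by the
        # overridden submit (note the *args marked with noqa: ARG002).
        future = executor.submit(double, x=3)
        print(future.result())  # prints 6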

src/pytask_parallel/execute.py

+4 -11
@@ -12,7 +12,6 @@
 from typing import List
 
 import attr
-import cloudpickle
 from pybaum.tree_util import tree_map
 from pytask import console
 from pytask import ExecutionReport
@@ -179,13 +178,10 @@ def pytask_execute_task(session: Session, task: Task) -> Future[Any] | None:
     if session.config["n_workers"] > 1:
         kwargs = _create_kwargs_for_task(task)
 
-        bytes_function = cloudpickle.dumps(task)
-        bytes_kwargs = cloudpickle.dumps(kwargs)
-
         return session.config["_parallel_executor"].submit(
             _unserialize_and_execute_task,
-            bytes_function=bytes_function,
-            bytes_kwargs=bytes_kwargs,
+            task=task,
+            kwargs=kwargs,
             show_locals=session.config["show_locals"],
             console_options=console.options,
             session_filterwarnings=session.config["filterwarnings"],
@@ -196,8 +192,8 @@ def pytask_execute_task(session: Session, task: Task) -> Future[Any] | None:
 
 
 def _unserialize_and_execute_task(  # noqa: PLR0913
-    bytes_function: bytes,
-    bytes_kwargs: bytes,
+    task: Task,
+    kwargs: dict[str, Any],
     show_locals: bool,
     console_options: ConsoleOptions,
     session_filterwarnings: tuple[str, ...],
@@ -212,9 +208,6 @@ def _unserialize_and_execute_task(  # noqa: PLR0913
     """
     __tracebackhide__ = True
 
-    task = cloudpickle.loads(bytes_function)
-    kwargs = cloudpickle.loads(bytes_kwargs)
-
     with warnings.catch_warnings(record=True) as log:
         # mypy can't infer that record=True means log is not None; help it.
         assert log is not None
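
The effect on the hook side is that serialization is no longer its concern: `pytask_execute_task` now hands the live `Task` and its kwargs to `submit`, and the cloudpickle round trip happens inside `CloudpickleProcessPoolExecutor`. A rough sketch of the resulting call path follows; `run_task` is a hypothetical stand-in for `_unserialize_and_execute_task`, not code from this repository.

# Rough sketch of the new call path; run_task is a hypothetical stand-in
# for _unserialize_and_execute_task and not part of this commit.
from pytask_parallel.backends import CloudpickleProcessPoolExecutor


def run_task(task, kwargs):
    # Receives live objects: the executor already serialized and deserialized
    # them via deserialize_and_run_with_cloudpickle before calling this function.
    return task(**kwargs)


if __name__ == "__main__":
    with CloudpickleProcessPoolExecutor(max_workers=2) as executor:
        # Before this commit the caller cloudpickled the task and kwargs itself;
        # now plain keyword arguments are enough.
        future = executor.submit(run_task, task=lambda x: x + 1, kwargs={"x": 41})
        print(future.result())  # prints 42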
