Skip to content

Support partialed functions. #80

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ jobs:

- name: Run unit tests and doctests.
shell: bash -l {0}
run: tox -e test -- tests -m "unit or (not integration and not end_to_end)" --cov=./ --cov-report=xml
run: tox -e test -- tests -m "unit or (not integration and not end_to_end)" --cov=src --cov=tests --cov-report=xml

- name: Upload coverage report for unit tests and doctests.
if: runner.os == 'Linux' && matrix.python-version == '3.10'
Expand All @@ -51,7 +51,7 @@ jobs:

- name: Run end-to-end tests.
shell: bash -l {0}
run: tox -e test -- tests -m end_to_end --cov=./ --cov-report=xml
run: tox -e test -- tests -m end_to_end --cov=src --cov=tests --cov-report=xml

- name: Upload coverage reports of end-to-end tests.
if: runner.os == 'Linux' && matrix.python-version == '3.10'
Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ __pycache__
build
dist
src/pytask_parallel/_version.py
tests/test_jupyter/file.txt
tests/test_jupyter/*.txt
3 changes: 2 additions & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ chronological order. Releases follow [semantic versioning](https://semver.org/)
releases are available on [PyPI](https://pypi.org/project/pytask-parallel) and
[Anaconda.org](https://anaconda.org/conda-forge/pytask-parallel).

## 0.4.1 - 2023-12-xx
## 0.4.1 - 2024-01-12

- {pull}`72` moves the project to `pyproject.toml`.
- {pull}`75` updates the release strategy.
- {pull}`79` adds tests for Jupyter and fixes parallelization with `PythonNode`s.
- {pull}`80` adds support for partialed functions.
- {pull}`82` fixes testing with pytask v0.4.5.

## 0.4.0 - 2023-10-07
Expand Down
64 changes: 35 additions & 29 deletions src/pytask_parallel/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import time
import warnings
from concurrent.futures import Future
from functools import partial
from pathlib import Path
from types import ModuleType
from types import TracebackType
Expand Down Expand Up @@ -296,23 +297,7 @@ def _execute_task( # noqa: PLR0913
exc_info, show_locals, console_options
)
else:
if "return" in task.produces:
structure_out = tree_structure(out)
structure_return = tree_structure(task.produces["return"])
# strict must be false when none is leaf.
if not structure_return.is_prefix(structure_out, strict=False):
msg = (
"The structure of the return annotation is not a subtree of "
"the structure of the function return.\n\nFunction return: "
f"{structure_out}\n\nReturn annotation: {structure_return}"
)
raise ValueError(msg)

nodes = tree_leaves(task.produces["return"])
values = structure_return.flatten_up_to(out)
for node, value in zip(nodes, values):
node.save(value)

_handle_task_function_return(task, out)
processed_exc_info = None

task_display_name = getattr(task, "display_name", task.name)
Expand Down Expand Up @@ -347,6 +332,27 @@ def _process_exception(
return (*exc_info[:2], text)


def _handle_task_function_return(task: PTask, out: Any) -> None:
    """Store the return value of a task function in the task's ``"return"`` nodes.

    Nothing happens when ``task.produces`` has no ``"return"`` entry. Otherwise, the
    tree structure of the return annotation must be a prefix of the tree structure of
    ``out``. The value is then flattened up to the annotation and every piece is
    handed to the ``save`` method of the corresponding node.

    Raises
    ------
    ValueError
        If the structure of the return annotation is not a prefix of the structure of
        the function return.

    """
    if "return" in task.produces:
        out_structure = tree_structure(out)
        annotation_structure = tree_structure(task.produces["return"])

        # strict must be false when none is leaf.
        if not annotation_structure.is_prefix(out_structure, strict=False):
            raise ValueError(
                "The structure of the return annotation is not a subtree of the "
                "structure of the function return.\n\n"
                f"Function return: {out_structure}\n\n"
                f"Return annotation: {annotation_structure}"
            )

        return_nodes = tree_leaves(task.produces["return"])
        return_values = annotation_structure.flatten_up_to(out)
        for return_node, return_value in zip(return_nodes, return_values):
            return_node.save(return_value)


class DefaultBackendNameSpace:
"""The name space for hooks related to threads."""

Expand All @@ -362,13 +368,13 @@ def pytask_execute_task(session: Session, task: Task) -> Future[Any] | None:
if session.config["n_workers"] > 1:
kwargs = _create_kwargs_for_task(task)
return session.config["_parallel_executor"].submit(
_mock_processes_for_threads, func=task.execute, **kwargs
_mock_processes_for_threads, task=task, **kwargs
)
return None


def _mock_processes_for_threads(
func: Callable[..., Any], **kwargs: Any
task: PTask, **kwargs: Any
) -> tuple[
None, list[Any], tuple[type[BaseException], BaseException, TracebackType] | None
]:
Expand All @@ -381,10 +387,11 @@ def _mock_processes_for_threads(
"""
__tracebackhide__ = True
try:
func(**kwargs)
out = task.function(**kwargs)
except Exception: # noqa: BLE001
exc_info = sys.exc_info()
else:
_handle_task_function_return(task, out)
exc_info = None
return None, [], exc_info

Expand Down Expand Up @@ -430,18 +437,17 @@ def sleep(self) -> None:
def _get_module(func: Callable[..., Any], path: Path | None) -> ModuleType:
"""Get the module of a python function.

For Python <3.10, functools.partial does not set a `__module__` attribute which is
why ``inspect.getmodule`` returns ``None`` and ``cloudpickle.pickle_by_value``
fails. In later versions, ``functools`` is returned and everything seems to work
fine.
``functools.partial`` obfuscates the module of the function and
``inspect.getmodule`` returns :mod`functools`. Therefore, we recover the original
function.

Therefore, we use the path from the task module to aid the search which works for
Python <3.10.

We do not unwrap the partialed function with ``func.func``, since pytask in general
does not really support ``functools.partial``. Instead, use ``@task(kwargs=...)``.
We use the path from the task module to aid the search although it is not clear
whether it helps.

"""
if isinstance(func, partial):
func = func.func

if path:
return inspect.getmodule(func, path.as_posix())
return inspect.getmodule(func)
40 changes: 34 additions & 6 deletions tests/test_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,12 @@ def task_example() -> Annotated[str, Path("file.txt")]:
"""
tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source))
result = runner.invoke(
cli, [tmp_path.as_posix(), "--parallel-backend", parallel_backend]
cli, [tmp_path.as_posix(), "-n", "2", "--parallel-backend", parallel_backend]
)
assert result.exit_code == ExitCode.OK
assert tmp_path.joinpath("file.txt").exists()
assert (
tmp_path.joinpath("file.txt").read_text() == "Hello, Darkness, my old friend."
)


@pytest.mark.end_to_end()
Expand All @@ -252,10 +254,12 @@ def test_task_without_path_that_return(runner, tmp_path, parallel_backend):
"""
tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source))
result = runner.invoke(
cli, [tmp_path.as_posix(), "--parallel-backend", parallel_backend]
cli, [tmp_path.as_posix(), "-n", "2", "--parallel-backend", parallel_backend]
)
assert result.exit_code == ExitCode.OK
assert tmp_path.joinpath("file.txt").exists()
assert (
tmp_path.joinpath("file.txt").read_text() == "Hello, Darkness, my old friend."
)


@pytest.mark.end_to_end()
Expand All @@ -264,7 +268,8 @@ def test_task_without_path_that_return(runner, tmp_path, parallel_backend):
def test_parallel_execution_is_deactivated(runner, tmp_path, flag, parallel_backend):
tmp_path.joinpath("task_example.py").write_text("def task_example(): pass")
result = runner.invoke(
cli, [tmp_path.as_posix(), "-n 2", "--parallel-backend", parallel_backend, flag]
cli,
[tmp_path.as_posix(), "-n", "2", "--parallel-backend", parallel_backend, flag],
)
assert result.exit_code == ExitCode.OK
assert "Started 2 workers" not in result.output
Expand All @@ -278,7 +283,30 @@ def test_parallel_execution_is_deactivated(runner, tmp_path, flag, parallel_back
def test_raise_error_on_breakpoint(runner, tmp_path, code, parallel_backend):
tmp_path.joinpath("task_example.py").write_text(f"def task_example(): {code}")
result = runner.invoke(
cli, [tmp_path.as_posix(), "-n 2", "--parallel-backend", parallel_backend]
cli, [tmp_path.as_posix(), "-n", "2", "--parallel-backend", parallel_backend]
)
assert result.exit_code == ExitCode.FAILED
assert "You cannot use 'breakpoint()'" in result.output


@pytest.mark.end_to_end()
@pytest.mark.parametrize("parallel_backend", PARALLEL_BACKENDS)
def test_task_partialed(runner, tmp_path, parallel_backend):
    """Run a task created from ``functools.partial`` with two workers."""
    source = """
    from pathlib import Path
    from pytask import task
    from functools import partial

    def create_text(text):
        return text

    task_example = task(
        produces=Path("file.txt")
    )(partial(create_text, text="Hello, Darkness, my old friend."))
    """
    tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source))
    result = runner.invoke(
        cli, [tmp_path.as_posix(), "-n", "2", "--parallel-backend", parallel_backend]
    )
    assert result.exit_code == ExitCode.OK
    assert tmp_path.joinpath("file.txt").exists()
    # Existence alone does not show that the value returned by the partialed
    # function was saved, so compare the content like the other return tests do.
    assert (
        tmp_path.joinpath("file.txt").read_text() == "Hello, Darkness, my old friend."
    )
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ requires = tox>=4
envlist = test

[testenv]
package = wheel
package = editable

[testenv:test]
extras = test
Expand Down