Skip to content

Commit 2ce5fab

Browse files
authored
Patch pprint to make pytest diffs nicer for big objects (#92)
Replaces alexmojaki#1 Closes #73
1 parent b1e6384 commit 2ce5fab

File tree

2 files changed

+80
-8
lines changed

2 files changed

+80
-8
lines changed

dirty_equals/_base.py

+22
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import io
12
from abc import ABCMeta
3+
from pprint import PrettyPrinter
24
from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Optional, Protocol, Tuple, TypeVar
35

46
from ._utils import Omit
@@ -131,6 +133,26 @@ def __repr__(self) -> str:
131133
# else return something which explains what's going on.
132134
return self._repr_ne()
133135

136+
def _pprint_format(self, pprinter: PrettyPrinter, stream: io.StringIO, *args: Any, **kwargs: Any) -> None:
137+
# pytest diffs use pprint to format objects, so we patch pprint to call this method
138+
# for DirtyEquals objects. So this method needs to follow the same pattern as __repr__.
139+
# We check that the protected _format method actually exists
140+
# to be safe and to make linters happy.
141+
if self._was_equal and hasattr(pprinter, '_format'):
142+
pprinter._format(self._other, stream, *args, **kwargs)
143+
else:
144+
stream.write(repr(self)) # i.e. self._repr_ne() (for now)
145+
146+
147+
# Patch pprint to call _pprint_format for DirtyEquals objects
148+
# Check that the protected attribute _dispatch exists to be safe and to make linters happy.
149+
# The reason we modify _dispatch rather than _format
150+
# is that pytest sometimes uses a subclass of PrettyPrinter which overrides _format.
151+
if hasattr(PrettyPrinter, '_dispatch'): # pragma: no branch
152+
PrettyPrinter._dispatch[DirtyEquals.__repr__] = lambda pprinter, obj, *args, **kwargs: obj._pprint_format(
153+
pprinter, *args, **kwargs
154+
)
155+
134156

135157
InstanceOrType: 'TypeAlias' = 'Union[DirtyEquals[Any], DirtyEqualsMeta]'
136158

tests/test_base.py

+58-8
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import platform
2+
import pprint
23

34
import packaging.version
45
import pytest
56

6-
from dirty_equals import Contains, IsApprox, IsInt, IsNegative, IsOneOf, IsPositive, IsStr
7+
from dirty_equals import Contains, IsApprox, IsInt, IsList, IsNegative, IsOneOf, IsPositive, IsStr
78
from dirty_equals.version import VERSION
89

910

@@ -39,8 +40,7 @@ def test_value_eq():
3940
v.value
4041

4142
assert 'foo' == v
42-
assert str(v) == "'foo'"
43-
assert repr(v) == "'foo'"
43+
assert repr(v) == str(v) == "'foo'" == pprint.pformat(v)
4444
assert v.value == 'foo'
4545

4646

@@ -50,8 +50,7 @@ def test_value_ne():
5050
with pytest.raises(AssertionError):
5151
assert 1 == v
5252

53-
assert str(v) == 'IsStr()'
54-
assert repr(v) == 'IsStr()'
53+
assert repr(v) == str(v) == 'IsStr()' == pprint.pformat(v)
5554
with pytest.raises(AttributeError, match='value is not available until __eq__ has been called'):
5655
v.value
5756

@@ -110,7 +109,7 @@ def test_repr():
110109
],
111110
)
112111
def test_repr_class(v, v_repr):
113-
assert repr(v) == v_repr
112+
assert repr(v) == str(v) == v_repr == pprint.pformat(v)
114113

115114

116115
def test_is_approx_without_init():
@@ -119,11 +118,62 @@ def test_is_approx_without_init():
119118

120119
def test_ne_repr():
121120
v = IsInt
122-
assert repr(v) == 'IsInt'
121+
assert repr(v) == str(v) == 'IsInt' == pprint.pformat(v)
123122

124123
assert 'x' != v
125124

126-
assert repr(v) == 'IsInt'
125+
assert repr(v) == str(v) == 'IsInt' == pprint.pformat(v)
126+
127+
128+
def test_pprint():
129+
v = [IsList(length=...), 1, [IsList(length=...), 2], 3, IsInt()]
130+
lorem = ['lorem', 'ipsum', 'dolor', 'sit', 'amet'] * 2
131+
with pytest.raises(AssertionError):
132+
assert [lorem, 1, [lorem, 2], 3, '4'] == v
133+
134+
assert repr(v) == (f'[{lorem}, 1, [{lorem}, 2], 3, IsInt()]')
135+
assert pprint.pformat(v) == (
136+
"[['lorem',\n"
137+
" 'ipsum',\n"
138+
" 'dolor',\n"
139+
" 'sit',\n"
140+
" 'amet',\n"
141+
" 'lorem',\n"
142+
" 'ipsum',\n"
143+
" 'dolor',\n"
144+
" 'sit',\n"
145+
" 'amet'],\n"
146+
' 1,\n'
147+
" [['lorem',\n"
148+
" 'ipsum',\n"
149+
" 'dolor',\n"
150+
" 'sit',\n"
151+
" 'amet',\n"
152+
" 'lorem',\n"
153+
" 'ipsum',\n"
154+
" 'dolor',\n"
155+
" 'sit',\n"
156+
" 'amet'],\n"
157+
' 2],\n'
158+
' 3,\n'
159+
' IsInt()]'
160+
)
161+
162+
163+
def test_pprint_not_equal():
164+
v = IsList(*range(30)) # need a big value to trigger pprint
165+
with pytest.raises(AssertionError):
166+
assert [] == v
167+
168+
assert (
169+
pprint.pformat(v)
170+
== (
171+
'IsList(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, '
172+
'15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29)'
173+
)
174+
== repr(v)
175+
== str(v)
176+
)
127177

128178

129179
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)