Skip to content

Commit 7f43559

Browse files
authored
fix:pickle args/kwargs for job (#251)
1 parent 72766a3 commit 7f43559

File tree

6 files changed

+71
-17
lines changed

6 files changed

+71
-17
lines changed

docs/changelog.md

+4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44

55
See breaking changes in 4.0.0 beta versions.
66

7+
### 🐛 Bug Fixes
8+
9+
- Fix issue with non-primitive parameters for @job #249
10+
711
## v4.0.0b3 🌈
812

913
Refactor the code to make it more organized and easier to maintain. This includes:

poetry.lock

+3-3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

scheduler/redis_models/base.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def _deserialize(value: str, _type: Type) -> Any:
6363
return int(value)
6464
elif _type is float or _type == Optional[float]:
6565
return float(value)
66-
elif _type in {List[Any], List[str], Dict[str, str]}:
66+
elif _type in {List[str], Dict[str, str]}:
6767
return json.loads(value)
6868
elif _type == Optional[Any]:
6969
return json.loads(value)
@@ -78,6 +78,9 @@ def _deserialize(value: str, _type: Type) -> Any:
7878
class BaseModel:
7979
name: str
8080
_element_key_template: ClassVar[str] = ":element:{}"
81+
# fields that are not serializable using method above and should be dealt with in the subclass
82+
# e.g. args/kwargs for a job
83+
_non_serializable_fields: ClassVar[Set[str]] = set()
8184

8285
@classmethod
8386
def key_for(cls, name: str) -> str:
@@ -92,14 +95,14 @@ def serialize(self, with_nones: bool = False) -> Dict[str, str]:
9295
self, dict_factory=lambda fields: {key: value for (key, value) in fields if not key.startswith("_")}
9396
)
9497
if not with_nones:
95-
data = {k: v for k, v in data.items() if v is not None}
98+
data = {k: v for k, v in data.items() if v is not None and k not in self._non_serializable_fields}
9699
for k in data:
97100
data[k] = _serialize(data[k])
98101
return data
99102

100103
@classmethod
101104
def deserialize(cls, data: Dict[str, Any]) -> Self:
102-
types = {f.name: f.type for f in dataclasses.fields(cls)}
105+
types = {f.name: f.type for f in dataclasses.fields(cls) if f.name not in cls._non_serializable_fields}
103106
for k in data:
104107
if k not in types:
105108
logger.warning(f"Unknown field {k} in {cls.__name__}")

scheduler/redis_models/job.py

+21-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import base64
12
import dataclasses
23
import inspect
34
import numbers
5+
import pickle
46
from datetime import datetime
57
from enum import Enum
68
from typing import ClassVar, Dict, Optional, List, Callable, Any, Union, Tuple
@@ -35,13 +37,15 @@ class JobModel(HashModel):
3537
_list_key: ClassVar[str] = ":jobs:ALL:"
3638
_children_key_template: ClassVar[str] = ":{}:jobs:"
3739
_element_key_template: ClassVar[str] = ":jobs:{}"
40+
_non_serializable_fields = {"args", "kwargs"}
41+
42+
args: List[Any]
43+
kwargs: Dict[str, str]
3844

3945
queue_name: str
4046
description: str
4147
func_name: str
4248

43-
args: List[Any]
44-
kwargs: Dict[str, str]
4549
timeout: int = SCHEDULER_CONFIG.DEFAULT_JOB_TIMEOUT
4650
success_ttl: int = SCHEDULER_CONFIG.DEFAULT_SUCCESS_TTL
4751
job_info_ttl: int = SCHEDULER_CONFIG.DEFAULT_JOB_TTL
@@ -63,10 +67,6 @@ class JobModel(HashModel):
6367
task_type: Optional[str] = None
6468
scheduled_task_id: Optional[int] = None
6569

66-
def serialize(self, with_nones: bool = False) -> Dict[str, str]:
67-
res = super(JobModel, self).serialize()
68-
return res
69-
7070
def __hash__(self):
7171
return hash(self.name)
7272

@@ -166,6 +166,21 @@ def stopped_callback(self) -> Optional[Callable[..., Any]]:
166166
def get_call_string(self):
167167
return _get_call_string(self.func_name, self.args, self.kwargs)
168168

169+
def serialize(self, with_nones: bool = False) -> Dict[str, str]:
170+
"""Serialize the job model to a dictionary."""
171+
res = super(JobModel, self).serialize(with_nones=with_nones)
172+
res["args"] = base64.encodebytes(pickle.dumps(self.args)).decode("utf-8")
173+
res["kwargs"] = base64.encodebytes(pickle.dumps(self.kwargs)).decode("utf-8")
174+
return res
175+
176+
@classmethod
177+
def deserialize(cls, data: Dict[str, Any]) -> Self:
178+
"""Deserialize the job model from a dictionary."""
179+
res = super(JobModel, cls).deserialize(data)
180+
res.args = pickle.loads(base64.decodebytes(data.get("args").encode("utf-8")))
181+
res.kwargs = pickle.loads(base64.decodebytes(data.get("kwargs").encode("utf-8")))
182+
return res
183+
169184
@classmethod
170185
def create(
171186
cls,

scheduler/tests/test_job_decorator.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
from scheduler.helpers.queues import get_queue
77
from . import test_settings # noqa
88
from ..decorators import JOB_METHODS_LIST, job
9+
from ..redis_models import JobStatus
910
from ..redis_models.job import JobModel
11+
from ..worker import create_worker
1012

1113

1214
@job()
@@ -32,12 +34,27 @@ def test_job_result_ttl():
3234
return 1 + 1
3335

3436

37+
class MyClass:
38+
def run(self):
39+
print("Hello")
40+
41+
def __eq__(self, other):
42+
if not isinstance(other, MyClass):
43+
return False
44+
return True
45+
46+
47+
@job()
48+
def long_running_func(x):
49+
x.run()
50+
51+
3552
class JobDecoratorTest(TestCase):
3653
def setUp(self) -> None:
3754
get_queue("default").connection.flushall()
3855

3956
def test_all_job_methods_registered(self):
40-
self.assertEqual(4, len(JOB_METHODS_LIST))
57+
self.assertEqual(5, len(JOB_METHODS_LIST))
4158

4259
def test_job_decorator_no_params(self):
4360
test_job.delay()
@@ -92,3 +109,18 @@ def test_job_decorator_bad_queue(self):
92109
def test_job_bad_queue():
93110
time.sleep(1)
94111
return 1 + 1
112+
113+
def test_job_decorator_delay_with_param(self):
114+
queue_name = "default"
115+
long_running_func.delay(MyClass())
116+
117+
worker = create_worker(queue_name, burst=True)
118+
worker.work()
119+
120+
jobs_list = worker.queues[0].get_all_jobs()
121+
self.assertEqual(1, len(jobs_list))
122+
job = jobs_list[0]
123+
self.assertEqual(job.func, long_running_func)
124+
self.assertEqual(job.kwargs, {})
125+
self.assertEqual(job.status, JobStatus.FINISHED)
126+
self.assertEqual(job.args, (MyClass(),))

scheduler/tests/test_task_types/test_task_model.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
import zoneinfo
22
from datetime import datetime, timedelta
33

4+
import time_machine
45
from django.contrib.messages import get_messages
56
from django.core.exceptions import ValidationError
67
from django.test import override_settings
78
from django.urls import reverse
89
from django.utils import timezone
9-
from freezegun import freeze_time
1010

1111
from scheduler import settings
12-
from scheduler.models import TaskType, Task, TaskArg, TaskKwarg, run_task
1312
from scheduler.helpers.queues import get_queue
1413
from scheduler.helpers.queues import perform_job
14+
from scheduler.models import TaskType, Task, TaskArg, TaskKwarg, run_task
1515
from scheduler.redis_models import JobStatus, JobModel
1616
from scheduler.tests import jobs, test_settings # noqa
1717
from scheduler.tests.testtools import (
@@ -480,14 +480,14 @@ class TestSchedulableTask(TestBaseTask):
480480
# Currently ScheduledJob and RepeatableJob
481481
task_type = TaskType.ONCE
482482

483-
@freeze_time("2016-12-25")
483+
@time_machine.travel(datetime(2016, 12, 25))
484484
@override_settings(USE_TZ=False)
485485
def test_schedule_time_no_tz(self):
486486
task = task_factory(self.task_type)
487487
task.scheduled_time = datetime(2016, 12, 25, 8, 0, 0, tzinfo=None)
488488
self.assertEqual("2016-12-25T08:00:00", task._schedule_time().isoformat())
489489

490-
@freeze_time("2016-12-25")
490+
@time_machine.travel(datetime(2016, 12, 25))
491491
@override_settings(USE_TZ=True)
492492
def test_schedule_time_with_tz(self):
493493
task = task_factory(self.task_type)

0 commit comments

Comments
 (0)