Skip to content

Commit dc3d8ff

Browse files
fazeelghafoorn2ygkpre-commit-ci[bot]
authored
change token to TextField in AbstractAccessToken model (#1447)
* change token field to TextField in AbstractAccessToken model - add TokenChecksumField field - update middleware, validators, and views to use token checksums for token retrieval and validation - modified test migrations to include token_checksum field in "sampleaccesstoken" model - add test for token checksum field --------- Co-authored-by: Alan Crosswell <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com.> Co-authored-by: Alan Crosswell <[email protected]>
1 parent 51d9798 commit dc3d8ff

File tree

10 files changed

+83
-10
lines changed

10 files changed

+83
-10
lines changed

AUTHORS

+1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ Dylan Tack
5151
Eduardo Oliveira
5252
Egor Poderiagin
5353
Emanuele Palazzetti
54+
Fazeel Ghafoor
5455
Federico Dolce
5556
Florian Demmer
5657
Frederico Vieira

CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1616

1717
## [unreleased]
1818
### Added
19+
* Add migration to include `token_checksum` field in AbstractAccessToken model.
1920
* #1404 Add a new setting `REFRESH_TOKEN_REUSE_PROTECTION`
2021
### Changed
22+
* Update token to TextField from CharField with 255 character limit and SHA-256 checksum in AbstractAccessToken model. Removing the 255 character limit enables supporting JWT tokens with additional claims
23+
24+
* Update middleware, validators, and views to use token checksums instead of token for token retrieval and validation.
2125
### Deprecated
2226
### Removed
2327
* #1425 Remove deprecated `RedirectURIValidator`, `WildcardSet` per #1345; `validate_logout_request` per #1274

oauth2_provider/middleware.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import hashlib
12
import logging
23

34
from django.contrib.auth import authenticate
@@ -55,7 +56,8 @@ def __call__(self, request):
5556
tokenstring = authheader.split()[1]
5657
AccessToken = get_access_token_model()
5758
try:
58-
token = AccessToken.objects.get(token=tokenstring)
59+
token_checksum = hashlib.sha256(tokenstring.encode("utf-8")).hexdigest()
60+
token = AccessToken.objects.get(token_checksum=token_checksum)
5961
request.access_token = token
6062
except AccessToken.DoesNotExist as e:
6163
log.exception(e)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Generated by Django 5.0.7 on 2024-07-29 23:13
2+
3+
import oauth2_provider.models
4+
from django.db import migrations, models
5+
from oauth2_provider.settings import oauth2_settings
6+
7+
class Migration(migrations.Migration):
8+
dependencies = [
9+
("oauth2_provider", "0011_refreshtoken_token_family"),
10+
migrations.swappable_dependency(oauth2_settings.ACCESS_TOKEN_MODEL),
11+
]
12+
13+
operations = [
14+
migrations.AddField(
15+
model_name="accesstoken",
16+
name="token_checksum",
17+
field=oauth2_provider.models.TokenChecksumField(
18+
blank=True, db_index=True, max_length=64, unique=True
19+
),
20+
),
21+
migrations.AlterField(
22+
model_name="accesstoken",
23+
name="token",
24+
field=models.TextField(),
25+
),
26+
]

oauth2_provider/models.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import hashlib
12
import logging
23
import time
34
import uuid
@@ -44,6 +45,14 @@ def pre_save(self, model_instance, add):
4445
return super().pre_save(model_instance, add)
4546

4647

48+
class TokenChecksumField(models.CharField):
49+
def pre_save(self, model_instance, add):
50+
token = getattr(model_instance, "token")
51+
checksum = hashlib.sha256(token.encode("utf-8")).hexdigest()
52+
setattr(model_instance, self.attname, checksum)
53+
return super().pre_save(model_instance, add)
54+
55+
4756
class AbstractApplication(models.Model):
4857
"""
4958
An Application instance represents a Client on the Authorization server.
@@ -379,8 +388,10 @@ class AbstractAccessToken(models.Model):
379388
null=True,
380389
related_name="refreshed_access_token",
381390
)
382-
token = models.CharField(
383-
max_length=255,
391+
token = models.TextField()
392+
token_checksum = TokenChecksumField(
393+
max_length=64,
394+
blank=True,
384395
unique=True,
385396
db_index=True,
386397
)

oauth2_provider/oauth2_validators.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import base64
22
import binascii
3+
import hashlib
34
import http.client
45
import inspect
56
import json
@@ -461,7 +462,12 @@ def validate_bearer_token(self, token, scopes, request):
461462
return False
462463

463464
def _load_access_token(self, token):
464-
return AccessToken.objects.select_related("application", "user").filter(token=token).first()
465+
token_checksum = hashlib.sha256(token.encode("utf-8")).hexdigest()
466+
return (
467+
AccessToken.objects.select_related("application", "user")
468+
.filter(token_checksum=token_checksum)
469+
.first()
470+
)
465471

466472
def validate_code(self, client_id, code, client, request, *args, **kwargs):
467473
try:

oauth2_provider/views/base.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import hashlib
12
import json
23
import logging
34
from urllib.parse import parse_qsl, urlencode, urlparse
@@ -289,7 +290,8 @@ def post(self, request, *args, **kwargs):
289290
if status == 200:
290291
access_token = json.loads(body).get("access_token")
291292
if access_token is not None:
292-
token = get_access_token_model().objects.get(token=access_token)
293+
token_checksum = hashlib.sha256(access_token.encode("utf-8")).hexdigest()
294+
token = get_access_token_model().objects.get(token_checksum=token_checksum)
293295
app_authorized.send(sender=self, request=request, token=token)
294296
response = HttpResponse(content=body, status=status)
295297

oauth2_provider/views/introspect.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import calendar
2+
import hashlib
23

34
from django.core.exceptions import ObjectDoesNotExist
45
from django.http import JsonResponse
@@ -24,8 +25,11 @@ class IntrospectTokenView(ClientProtectedScopedResourceView):
2425
@staticmethod
2526
def get_token_response(token_value=None):
2627
try:
28+
token_checksum = hashlib.sha256(token_value.encode("utf-8")).hexdigest()
2729
token = (
28-
get_access_token_model().objects.select_related("user", "application").get(token=token_value)
30+
get_access_token_model()
31+
.objects.select_related("user", "application")
32+
.get(token_checksum=token_checksum)
2933
)
3034
except ObjectDoesNotExist:
3135
return JsonResponse({"active": False}, status=200)

tests/migrations/0002_swapped_models.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,14 @@ class Migration(migrations.Migration):
118118
field=models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='s_refreshed_access_token', to=settings.OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL),
119119
),
120120
migrations.AddField(
121-
model_name='sampleaccesstoken',
122-
name='token',
123-
field=models.CharField(max_length=255, unique=True),
124-
preserve_default=False,
121+
model_name="sampleaccesstoken",
122+
name="token",
123+
field=models.TextField(),
124+
),
125+
migrations.AddField(
126+
model_name="sampleaccesstoken",
127+
name="token_checksum",
128+
field=models.CharField(max_length=64, unique=True, db_index=True),
125129
),
126130
migrations.AddField(
127131
model_name='sampleaccesstoken',

tests/test_models.py

+13
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import hashlib
2+
import secrets
13
from datetime import timedelta
24

35
import pytest
@@ -310,6 +312,17 @@ def test_expires_can_be_none(self):
310312
self.assertIsNone(access_token.expires)
311313
self.assertTrue(access_token.is_expired())
312314

315+
def test_token_checksum_field(self):
316+
token = secrets.token_urlsafe(32)
317+
access_token = AccessToken.objects.create(
318+
user=self.user,
319+
token=token,
320+
expires=timezone.now() + timedelta(hours=1),
321+
)
322+
expected_checksum = hashlib.sha256(token.encode()).hexdigest()
323+
324+
self.assertEqual(access_token.token_checksum, expected_checksum)
325+
313326

314327
class TestRefreshTokenModel(BaseTestModels):
315328
def test_str(self):

0 commit comments

Comments
 (0)