Skip to content

Commit 995d70c

Browse files
authored
Set explicit Origin in CORS preflight response if allow_credentials is True and allow_origins is wildcard (#1113)
* Set explicit Origin in CORS preflight response if allow_credentials is True and allow_origins is wildcard When making a preflight request, the browser makes no indication as to whether the actual subsequent request will pass up credentials. However, unless the preflight response explicitly allows the request's `Origin` in the `Access-Control-Response-Header`, the browser will fail the CORS check and prevent the actual follow-up CORS request. This means that responding with the `*` wildcard is not sufficient to allow preflighted credentialed requests. The current workaround is to provide an equivalently permissive `allow_origin_regex` pattern. The `simple_response()` code already performs similar logic which currently only applies to non-preflighted requests since the browser would never make a preflighted request that hits this code due to this issue: ``` if self.allow_all_origins and has_cookie: headers["Access-Control-Allow-Origin"] = origin ``` This just bring the two halves inline with each other. * Add Vary header to preflight response if allow_credentials * Use allow_explicit_origin() for preflight request_headers This simplifies the code slightly by using this recently added method. It has some trade-offs, though. We now construct a `MutableHeaders` instead of a simple `dict` when copying the pre-computed preflight headers, and we move the `Vary` header construction out of the pre-computation and into the call handler. I think it makes the code more maintainable and the added per-call computation is minimal. * Convert MutableHeaders to dict for PlainTextResponse * Revert back to dict() for preflight headers This also names and caches some of the boolean tests in __init__() which we use in later if-blocks. This follows the existing pattern in order to better self-document the code. * Clean up comments * Remove unused self.allow_credentials attribute
1 parent f5ecb53 commit 995d70c

File tree

2 files changed

+126
-13
lines changed

2 files changed

+126
-13
lines changed

starlette/middleware/cors.py

+16-12
Original file line numberDiff line numberDiff line change
@@ -30,27 +30,32 @@ def __init__(
3030
if allow_origin_regex is not None:
3131
compiled_allow_origin_regex = re.compile(allow_origin_regex)
3232

33+
allow_all_origins = "*" in allow_origins
34+
allow_all_headers = "*" in allow_headers
35+
preflight_explicit_allow_origin = not allow_all_origins or allow_credentials
36+
3337
simple_headers = {}
34-
if "*" in allow_origins:
38+
if allow_all_origins:
3539
simple_headers["Access-Control-Allow-Origin"] = "*"
3640
if allow_credentials:
3741
simple_headers["Access-Control-Allow-Credentials"] = "true"
3842
if expose_headers:
3943
simple_headers["Access-Control-Expose-Headers"] = ", ".join(expose_headers)
4044

4145
preflight_headers = {}
42-
if "*" in allow_origins:
43-
preflight_headers["Access-Control-Allow-Origin"] = "*"
44-
else:
46+
if preflight_explicit_allow_origin:
47+
# The origin value will be set in preflight_response() if it is allowed.
4548
preflight_headers["Vary"] = "Origin"
49+
else:
50+
preflight_headers["Access-Control-Allow-Origin"] = "*"
4651
preflight_headers.update(
4752
{
4853
"Access-Control-Allow-Methods": ", ".join(allow_methods),
4954
"Access-Control-Max-Age": str(max_age),
5055
}
5156
)
5257
allow_headers = sorted(SAFELISTED_HEADERS | set(allow_headers))
53-
if allow_headers and "*" not in allow_headers:
58+
if allow_headers and not allow_all_headers:
5459
preflight_headers["Access-Control-Allow-Headers"] = ", ".join(allow_headers)
5560
if allow_credentials:
5661
preflight_headers["Access-Control-Allow-Credentials"] = "true"
@@ -59,8 +64,9 @@ def __init__(
5964
self.allow_origins = allow_origins
6065
self.allow_methods = allow_methods
6166
self.allow_headers = [h.lower() for h in allow_headers]
62-
self.allow_all_origins = "*" in allow_origins
63-
self.allow_all_headers = "*" in allow_headers
67+
self.allow_all_origins = allow_all_origins
68+
self.allow_all_headers = allow_all_headers
69+
self.preflight_explicit_allow_origin = preflight_explicit_allow_origin
6470
self.allow_origin_regex = compiled_allow_origin_regex
6571
self.simple_headers = simple_headers
6672
self.preflight_headers = preflight_headers
@@ -105,11 +111,9 @@ def preflight_response(self, request_headers: Headers) -> Response:
105111
failures = []
106112

107113
if self.is_allowed_origin(origin=requested_origin):
108-
if not self.allow_all_origins:
109-
# If self.allow_all_origins is True, then the
110-
# "Access-Control-Allow-Origin" header is already set to "*".
111-
# If we only allow specific origins, then we have to mirror back
112-
# the Origin header in the response.
114+
if self.preflight_explicit_allow_origin:
115+
# The "else" case is already accounted for in self.preflight_headers
116+
# and the value would be "*".
113117
headers["Access-Control-Allow-Origin"] = requested_origin
114118
else:
115119
failures.append("origin")

tests/middleware/test_cors.py

+110-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,62 @@ def homepage(request):
2222

2323
client = TestClient(app)
2424

25+
# Test pre-flight response
26+
headers = {
27+
"Origin": "https://example.org",
28+
"Access-Control-Request-Method": "GET",
29+
"Access-Control-Request-Headers": "X-Example",
30+
}
31+
response = client.options("/", headers=headers)
32+
assert response.status_code == 200
33+
assert response.text == "OK"
34+
assert response.headers["access-control-allow-origin"] == "https://example.org"
35+
assert response.headers["access-control-allow-headers"] == "X-Example"
36+
assert response.headers["access-control-allow-credentials"] == "true"
37+
assert response.headers["vary"] == "Origin"
38+
39+
# Test standard response
40+
headers = {"Origin": "https://example.org"}
41+
response = client.get("/", headers=headers)
42+
assert response.status_code == 200
43+
assert response.text == "Homepage"
44+
assert response.headers["access-control-allow-origin"] == "*"
45+
assert response.headers["access-control-expose-headers"] == "X-Status"
46+
assert response.headers["access-control-allow-credentials"] == "true"
47+
48+
# Test standard credentialed response
49+
headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"}
50+
response = client.get("/", headers=headers)
51+
assert response.status_code == 200
52+
assert response.text == "Homepage"
53+
assert response.headers["access-control-allow-origin"] == "https://example.org"
54+
assert response.headers["access-control-expose-headers"] == "X-Status"
55+
assert response.headers["access-control-allow-credentials"] == "true"
56+
57+
# Test non-CORS response
58+
response = client.get("/")
59+
assert response.status_code == 200
60+
assert response.text == "Homepage"
61+
assert "access-control-allow-origin" not in response.headers
62+
63+
64+
def test_cors_allow_all_except_credentials():
65+
app = Starlette()
66+
67+
app.add_middleware(
68+
CORSMiddleware,
69+
allow_origins=["*"],
70+
allow_headers=["*"],
71+
allow_methods=["*"],
72+
expose_headers=["X-Status"],
73+
)
74+
75+
@app.route("/")
76+
def homepage(request):
77+
return PlainTextResponse("Homepage", status_code=200)
78+
79+
client = TestClient(app)
80+
2581
# Test pre-flight response
2682
headers = {
2783
"Origin": "https://example.org",
@@ -33,6 +89,8 @@ def homepage(request):
3389
assert response.text == "OK"
3490
assert response.headers["access-control-allow-origin"] == "*"
3591
assert response.headers["access-control-allow-headers"] == "X-Example"
92+
assert "access-control-allow-credentials" not in response.headers
93+
assert "vary" not in response.headers
3694

3795
# Test standard response
3896
headers = {"Origin": "https://example.org"}
@@ -41,6 +99,7 @@ def homepage(request):
4199
assert response.text == "Homepage"
42100
assert response.headers["access-control-allow-origin"] == "*"
43101
assert response.headers["access-control-expose-headers"] == "X-Status"
102+
assert "access-control-allow-credentials" not in response.headers
44103

45104
# Test non-CORS response
46105
response = client.get("/")
@@ -77,13 +136,15 @@ def homepage(request):
77136
assert response.headers["access-control-allow-headers"] == (
78137
"Accept, Accept-Language, Content-Language, Content-Type, X-Example"
79138
)
139+
assert "access-control-allow-credentials" not in response.headers
80140

81141
# Test standard response
82142
headers = {"Origin": "https://example.org"}
83143
response = client.get("/", headers=headers)
84144
assert response.status_code == 200
85145
assert response.text == "Homepage"
86146
assert response.headers["access-control-allow-origin"] == "https://example.org"
147+
assert "access-control-allow-credentials" not in response.headers
87148

88149
# Test non-CORS response
89150
response = client.get("/")
@@ -116,6 +177,38 @@ def homepage(request):
116177
response = client.options("/", headers=headers)
117178
assert response.status_code == 400
118179
assert response.text == "Disallowed CORS origin, method, headers"
180+
assert "access-control-allow-origin" not in response.headers
181+
182+
183+
def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_allowed():
184+
app = Starlette()
185+
186+
app.add_middleware(
187+
CORSMiddleware,
188+
allow_origins=["*"],
189+
allow_methods=["POST"],
190+
allow_credentials=True,
191+
)
192+
193+
@app.route("/")
194+
def homepage(request):
195+
return # pragma: no cover
196+
197+
client = TestClient(app)
198+
199+
# Test pre-flight response
200+
headers = {
201+
"Origin": "https://example.org",
202+
"Access-Control-Request-Method": "POST",
203+
}
204+
response = client.options(
205+
"/",
206+
headers=headers,
207+
)
208+
assert response.status_code == 200
209+
assert response.headers["access-control-allow-origin"] == "https://example.org"
210+
assert response.headers["access-control-allow-credentials"] == "true"
211+
assert response.headers["vary"] == "Origin"
119212

120213

121214
def test_cors_preflight_allow_all_methods():
@@ -175,6 +268,7 @@ def test_cors_allow_origin_regex():
175268
CORSMiddleware,
176269
allow_headers=["X-Example", "Content-Type"],
177270
allow_origin_regex="https://.*",
271+
allow_credentials=True,
178272
)
179273

180274
@app.route("/")
@@ -189,8 +283,17 @@ def homepage(request):
189283
assert response.status_code == 200
190284
assert response.text == "Homepage"
191285
assert response.headers["access-control-allow-origin"] == "https://example.org"
286+
assert response.headers["access-control-allow-credentials"] == "true"
192287

193-
# Test diallowed standard response
288+
# Test standard credentialed response
289+
headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"}
290+
response = client.get("/", headers=headers)
291+
assert response.status_code == 200
292+
assert response.text == "Homepage"
293+
assert response.headers["access-control-allow-origin"] == "https://example.org"
294+
assert response.headers["access-control-allow-credentials"] == "true"
295+
296+
# Test disallowed standard response
194297
# Note that enforcement is a browser concern. The disallowed-ness is reflected
195298
# in the lack of an "access-control-allow-origin" header in the response.
196299
headers = {"Origin": "http://example.org"}
@@ -212,6 +315,7 @@ def homepage(request):
212315
assert response.headers["access-control-allow-headers"] == (
213316
"Accept, Accept-Language, Content-Language, Content-Type, X-Example"
214317
)
318+
assert response.headers["access-control-allow-credentials"] == "true"
215319

216320
# Test disallowed pre-flight response
217321
headers = {
@@ -249,6 +353,7 @@ def homepage(request):
249353
response.headers["access-control-allow-origin"]
250354
== "https://subdomain.example.org"
251355
)
356+
assert "access-control-allow-credentials" not in response.headers
252357

253358
# Test diallowed standard response
254359
headers = {"Origin": "https://subdomain.example.org.hacker.com"}
@@ -275,6 +380,7 @@ def homepage(request):
275380
assert response.status_code == 200
276381
assert response.text == "Homepage"
277382
assert response.headers["access-control-allow-origin"] == "https://example.org"
383+
assert "access-control-allow-credentials" not in response.headers
278384

279385

280386
def test_cors_vary_header_defaults_to_origin():
@@ -365,11 +471,14 @@ def homepage(request):
365471
client = TestClient(app)
366472
response = client.get("/", headers={"Origin": "https://someplace.org"})
367473
assert response.headers["access-control-allow-origin"] == "*"
474+
assert "access-control-allow-credentials" not in response.headers
368475

369476
response = client.get(
370477
"/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"}
371478
)
372479
assert response.headers["access-control-allow-origin"] == "https://someplace.org"
480+
assert "access-control-allow-credentials" not in response.headers
373481

374482
response = client.get("/", headers={"Origin": "https://someplace.org"})
375483
assert response.headers["access-control-allow-origin"] == "*"
484+
assert "access-control-allow-credentials" not in response.headers

0 commit comments

Comments
 (0)