Skip to content

Commit 4258a60

Browse files
Add a preference system for turning on/off slow fallbacks
This gives a good way to balance development vs usage. For development, you want to just error if you hit any slower path. But for users, code should just work. Thus the slower fallbacks were given a preference system for allowing error throwing, without forcing all users to have to always see errors on new types just for more optimizations.
1 parent e1870b7 commit 4258a60

File tree

1 file changed

+52
-36
lines changed

1 file changed

+52
-36
lines changed

src/ArrayInterface.jl

+52-36
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
module ArrayInterface
22

3+
const SLOWFALLBACKS = @load_preference("slow_fallbacks", true)
4+
35
using LinearAlgebra
46
using SparseArrays
57
using SuiteSparse
@@ -282,7 +284,7 @@ function ismutable end
282284
ismutable(::Type{T}) -> Bool
283285
284286
Query whether instances of type `T` are mutable or not, see
285-
https://github.com./JuliaDiffEq/RecursiveArrayTools.jl/issues/19.
287+
https://github.com./SciML/RecursiveArrayTools.jl/issues/19.
286288
"""
287289
ismutable(x) = ismutable(typeof(x))
288290
function ismutable(::Type{T}) where {T <: AbstractArray}
@@ -460,12 +462,15 @@ Returns the number.
460462
"""
461463
bunchkaufman_instance(a::Number) = a
462464

463-
"""
464-
bunchkaufman_instance(a::Any) -> cholesky(a, check=false)
465+
@static if SLOWFALLBACKS
466+
"""
467+
bunchkaufman_instance(a::Any) -> cholesky(a, check=false)
465468
466-
Returns the number.
467-
"""
468-
bunchkaufman_instance(a::Any) = bunchkaufman(a, check = false)
469+
Slow fallback which gets the instance via factorization. Should get
470+
specialized for new matrix types.
471+
"""
472+
bunchkaufman_instance(a::Any) = bunchkaufman(a, check = false)
473+
end
469474

470475
"""
471476
cholesky_instance(A, pivot = LinearAlgebra.RowMaximum()) -> cholesky_factorization_instance
@@ -487,13 +492,15 @@ Returns the number.
487492
"""
488493
cholesky_instance(a::Number, pivot = LinearAlgebra.RowMaximum()) = a
489494

490-
"""
491-
cholesky_instance(a::Any, pivot = LinearAlgebra.RowMaximum()) -> cholesky(a, check=false)
495+
@static if SLOWFALLBACKS
496+
"""
497+
cholesky_instance(a::Any, pivot = LinearAlgebra.RowMaximum()) -> cholesky(a, check=false)
492498
493-
Slow fallback which gets the instance via factorization. Should get
494-
specialized for new matrix types.
495-
"""
496-
cholesky_instance(a::Any, pivot = LinearAlgebra.RowMaximum()) = cholesky(a, pivot, check = false)
499+
Slow fallback which gets the instance via factorization. Should get
500+
specialized for new matrix types.
501+
"""
502+
cholesky_instance(a::Any, pivot = LinearAlgebra.RowMaximum()) = cholesky(a, pivot, check = false)
503+
end
497504

498505
"""
499506
ldlt_instance(A) -> ldlt_factorization_instance
@@ -515,13 +522,15 @@ Returns the number.
515522
"""
516523
ldlt_instance(a::Number) = a
517524

518-
"""
519-
ldlt_instance(a::Any) -> ldlt(a, check=false)
525+
@static if SLOWFALLBACKS
526+
"""
527+
ldlt_instance(a::Any) -> ldlt(a, check=false)
520528
521-
Slow fallback which gets the instance via factorization. Should get
522-
specialized for new matrix types.
523-
"""
524-
ldlt_instance(a::Any) = ldlt(a)
529+
Slow fallback which gets the instance via factorization. Should get
530+
specialized for new matrix types.
531+
"""
532+
ldlt_instance(a::Any) = ldlt(a)
533+
end
525534

526535
"""
527536
lu_instance(A) -> lu_factorization_instance
@@ -558,13 +567,15 @@ Returns the number.
558567
"""
559568
lu_instance(a::Number) = a
560569

561-
"""
562-
lu_instance(a::Any) -> lu(a, check=false)
570+
@static if SLOWFALLBACKS
571+
"""
572+
lu_instance(a::Any) -> lu(a, check=false)
563573
564-
Slow fallback which gets the instance via factorization. Should get
565-
specialized for new matrix types.
566-
"""
567-
lu_instance(a::Any) = lu(a, check = false)
574+
Slow fallback which gets the instance via factorization. Should get
575+
specialized for new matrix types.
576+
"""
577+
lu_instance(a::Any) = lu(a, check = false)
578+
end
568579

569580
"""
570581
qr_instance(A) -> qr_factorization_instance
@@ -588,13 +599,15 @@ Returns the number.
588599
"""
589600
qr_instance(a::Number) = a
590601

591-
"""
592-
qr_instance(a::Any) -> qr(a)
602+
@static if SLOWFALLBACKS
603+
"""
604+
qr_instance(a::Any) -> qr(a)
593605
594-
Slow fallback which gets the instance via factorization. Should get
595-
specialized for new matrix types.
596-
"""
597-
qr_instance(a::Any) = qr(a)# check = false)
606+
Slow fallback which gets the instance via factorization. Should get
607+
specialized for new matrix types.
608+
"""
609+
qr_instance(a::Any) = qr(a)# check = false)
610+
end
598611

599612
"""
600613
svd_instance(A) -> qr_factorization_instance
@@ -613,13 +626,15 @@ Returns the number.
613626
"""
614627
svd_instance(a::Number) = a
615628

616-
"""
617-
svd_instance(a::Any) -> svd(a)
629+
@static if SLOWFALLBACKS
630+
"""
631+
svd_instance(a::Any) -> svd(a)
618632
619-
Slow fallback which gets the instance via factorization. Should get
620-
specialized for new matrix types.
621-
"""
622-
svd_instance(a::Any) = svd(a) #check = false)
633+
Slow fallback which gets the instance via factorization. Should get
634+
specialized for new matrix types.
635+
"""
636+
svd_instance(a::Any) = svd(a) #check = false)
637+
end
623638

624639
"""
625640
safevec(v)
@@ -1034,3 +1049,4 @@ import Requires
10341049
end
10351050

10361051
end # module
1052+

0 commit comments

Comments
 (0)