diff --git a/docs/make.jl b/docs/make.jl index 573fe9fc1..95bbe10f6 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -12,6 +12,7 @@ makedocs(; "sparsearrays.md", "tuples.md", "wrapping.md", + "dimnames.md", ] ) diff --git a/docs/src/dimnames.md b/docs/src/dimnames.md new file mode 100644 index 000000000..40d11ca06 --- /dev/null +++ b/docs/src/dimnames.md @@ -0,0 +1,9 @@ +# Named Dimensions Interface + +The following functions provide a common interface for interacting with named dimensions. + +```@docs +ArrayInterface.has_dimnames +ArrayInterface.dimnames +ArrayInterface.to_dims +``` diff --git a/docs/src/indexing.md b/docs/src/indexing.md index ed1282e74..81e2cee0c 100644 --- a/docs/src/indexing.md +++ b/docs/src/indexing.md @@ -46,4 +46,4 @@ and index translations. ArrayInterface.ArrayIndex ArrayInterface.GetIndex ArrayInterface.SetIndex! -``` \ No newline at end of file +``` diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index ed616a87f..e545f8a01 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -17,6 +17,14 @@ else end end end + +@assume_effects :total function _find_first_egal(v::T, vals::NTuple{N, T}) where {N, T} + for i in 1:N + getfield(vals, i, false) === v && return i + end + return 0 +end + @assume_effects :total __parameterless_type(T)=Base.typename(T).wrapper parameterless_type(x) = parameterless_type(typeof(x)) parameterless_type(x::Type) = __parameterless_type(x) @@ -1030,6 +1038,129 @@ ensures_sorted(@nospecialize( T::Type{<:AbstractRange})) = true ensures_sorted(T::Type) = is_forwarding_wrapper(T) ? ensures_sorted(parent_type(T)) : false ensures_sorted(@nospecialize(x)) = ensures_sorted(typeof(x)) +DIMNAMES_EXTENDED_HELP = """ +## Extended help + +Structures that explicitly provide named dimensions must define both `has_dimnames` and +`dimnames`. Wrappers that don't change the layout of their parent data and define +`is_forwarding_wrapper` will propagate these methods freely. 
All other wrappers must +define `has_dimnames` and `dimnames`. For example: + +```julia +function ArrayInterface.has_dimnames(T::Type{<:Wrapper}) + has_dimnames(ArrayInterface.parent_type(T)) +end + +function ArrayInterface.dimnames(x::Wrapper) + if has_dimnames(typeof(x)) + # appropriately modify wrapped dimension names to reflect lazy changes + # in the parent data layout + modify_wrapped_dimnames(dimnames(parent(x)))::NTuple{ndims(x), Symbol} + else # need to return "blank" dimension name :_ when names aren't defined + ntuple(_ -> :_, ndims(x)) + end +end +``` + +In some cases `Wrapper` may modify some aspect of its parent data's layout that has no +impact on the dimension names (e.g., mapping offset indices to a parent array). In such +cases there may be no need to modify dimension names and simply defining +`ArrayInterface.dimnames(x::Wrapper) = dimnames(parent(x))` may be sufficient. + +Since the utility of dimension names is highly specific to the domain they are used in, +there are very few explicit guidelines for how they should be modified by wrappers. The most +important guideline is that `dimnames(x)` returns an instance of type +`NTuple{ndims(x), Symbol}`. +""" + +""" + has_dimnames(T::Type) -> Bool + +Return `true` if instances of `T` have named dimensions. Structures overloading this +method are also responsible for defining [`ArrayInterface.dimnames`](@ref). + +See also: [`ArrayInterface.to_dims`](@ref) + +$(DIMNAMES_EXTENDED_HELP) +""" +has_dimnames(T::Type) = is_forwarding_wrapper(T) ? has_dimnames(parent_type(T)) : false + +""" + dimnames(x) -> NTuple{ndims(x), Symbol} + dimnames(x, dim::Integer) -> Symbol + dimnames(x, dim::Tuple{Vararg{Integer, N}}) -> NTuple{N, Symbol} + +Return the names of the dimensions for `x`. `:_` is used to indicate a dimension does not +have a name. Structures overloading this method are also responsible for defining +[`ArrayInterface.has_dimnames`](@ref). 
+ +See also: [`ArrayInterface.to_dims`](@ref) + +$(DIMNAMES_EXTENDED_HELP) +""" +@inline function dimnames(x::X) where {X} + if is_forwarding_wrapper(X) + return dimnames(buffer(x)) + elseif isa(Base.IteratorSize(X), Base.HasShape) + return ntuple(_ -> :_, Val(ndims(X))) + else + return (:_,) + end +end +@inline function dimnames(x::X, dim::Tuple{Vararg{Integer, N}}) where {X, N} + has_dimnames(X) || return ntuple(_ -> :_, Val{N}()) + dnames = dimnames(x) + nd = nfields(dnames) + ntuple(Val{N}()) do i + dim_i = Int(getfield(dim, i)) + in(dim_i, 1:nd) ? getfield(dnames, dim_i, false) : :_ + end +end +@inline function dimnames(x::X, dim::Integer) where {X} + if dim in 1:(isa(Base.IteratorSize(X), Base.HasShape) ? ndims(X) : 1) + return getfield(dimnames(x), Int(dim), false) # already know is inbounds + else # trailing dim is unnamed + return :_ + end +end + +@noinline function _throw_dimname(s::Symbol) + throw(DimensionMismatch("dimension name $(s) not found")) +end + +""" + to_dims(x, d::Integer) -> Int + to_dims(x, d::Symbol) -> Int + to_dims(x, d::NTuple{N}) -> NTuple{N, Int} + +Return the dimension(s) of `x` corresponding to `d`. Symbols are converted to dimensions +by searching through dimension names (see [`dimnames`](@ref)). Integers may be converted +to `Int` but are otherwise returned as is. 
+ +""" +to_dims(x, dim::Colon) = dim +to_dims(x, dim::Integer) = Int(dim) +function to_dims(x::X, s::Symbol) where {X} + dim = _find_first_egal(s, dimnames(x)) + dim === 0 && _throw_dimname(s) + return dim +end +to_dims(x, dims::Tuple{Vararg{Int}}) = dims +function to_dims(x::X, dims::Tuple{Vararg{Union{Symbol, Integer}, N}}) where {X, N} + dnames = dimnames(x) + ntuple(Val{N}()) do i + dim = getfield(dims, i, false) + if dim isa Symbol + dim_i = _find_first_egal(dim, dnames) + dim_i === 0 && _throw_dimname(dim) + dim_i + else + dim_i = to_dims(x, dim) + end + dim_i + end +end + ## Extensions import Requires diff --git a/test/core.jl b/test/core.jl index bd0cd6cf3..e35338100 100644 --- a/test/core.jl +++ b/test/core.jl @@ -7,6 +7,24 @@ using Random using SparseArrays using Test +struct NamedDimsWrapper{D,T,N,P<:AbstractArray{T,N}} <: AbstractArray{T,N} + parent::P + + NamedDimsWrapper{D}(p::P) where {D,P} = new{D,eltype(P),ndims(p),P}(p) +end + +ArrayInterface.has_dimnames(T::Type{<:NamedDimsWrapper}) = true +ArrayInterface.is_forwarding_wrapper(::Type{<:NamedDimsWrapper}) = true +ArrayInterface.parent_type(::Type{T}) where {P,T<:NamedDimsWrapper{<:Any,<:Any,<:Any,P}} = P +ArrayInterface.dimnames(::NamedDimsWrapper{D}) where {D} = D +Base.parent(x::NamedDimsWrapper) = getfield(x, :parent) +Base.size(x::NamedDimsWrapper) = size(parent(x)) +Base.IndexStyle(T::Type{<:NamedDimsWrapper}) = IndexStyle(parent_type(T)) +Base.@propagate_inbounds Base.getindex(x::NamedDimsWrapper, inds...) = parent(x)[inds...] +Base.@propagate_inbounds function Base.setindex!(x::NamedDimsWrapper, v, inds...) + setindex!(parent(x), v, inds...) 
+end + # ensure we are correctly parsing these ArrayInterface.@assume_effects :total foo(x::Bool) = x ArrayInterface.@assume_effects bar(x::Bool) = x @@ -282,4 +300,28 @@ end end @test ArrayInterface.ldlt_instance(SymTridiagonal(A' * A)) isa typeof(ldlt(SymTridiagonal(A' * A))) end -end \ No newline at end of file +end + +@testset "dimnames interface" begin + a = zeros(3, 4, 5); + nda = NamedDimsWrapper{(:x, :y, :z)}(a) + + @test !@inferred(ArrayInterface.has_dimnames(typeof(a))) + @test @inferred(ArrayInterface.has_dimnames(typeof(nda))) + + @test @inferred(ArrayInterface.dimnames(a)) === (:_, :_, :_) + @test @inferred(ArrayInterface.dimnames(nda)) === (:x, :y, :z) + @test @inferred(ArrayInterface.dimnames(nda, 1)) === :x + @test @inferred(ArrayInterface.dimnames(nda, (1, 2))) === (:x, :y) + @test @inferred(ArrayInterface.dimnames((1,))) === (:_,) + + @test @inferred(ArrayInterface.to_dims(nda, (:))) === Colon() + @test @inferred(ArrayInterface.to_dims(nda, 1)) === 1 + @test @inferred(ArrayInterface.to_dims(nda, :x)) === 1 + @test @inferred(ArrayInterface.to_dims(nda, (1, 2))) === (1, 2) + @test @inferred(ArrayInterface.to_dims(nda, (:x, :y))) === (1, 2) + @test @inferred(ArrayInterface.to_dims(nda, (:y, :x))) === (2, 1) + @test @inferred(ArrayInterface.to_dims(nda, (:y, 1))) === (2, 1) + + @test_throws DimensionMismatch ArrayInterface.to_dims(a, :x) +end