Deprecate Differentiable.AllDifferentiableVariables. #419

Merged · 3 commits · Aug 7, 2019
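
This change removes `AllDifferentiableVariables` from protocol requirements, optimizer
constraints, and conformances across the library, leaving `TangentVector` as the single
associated type describing a model's differentiable structure; optimizers now mutate
models directly. A minimal sketch of the user-facing pattern after this change (the
`Linear` model and training values are hypothetical, not taken from this diff):

    import TensorFlow

    // Hypothetical model; `TangentVector` is compiler-synthesized, and no
    // `AllDifferentiableVariables` typealias is needed after this change.
    struct Linear: Layer {
        var w = Tensor<Float>(ones: [2, 2])
        var b = Tensor<Float>(zeros: [2])

        @differentiable
        func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
            matmul(input, w) + b
        }
    }

    var model = Linear()
    let optimizer = SGD(for: model, learningRate: 0.1)
    let x = Tensor<Float>(ones: [4, 2])
    let y = Tensor<Float>(zeros: [4, 2])
    let 𝛁model = gradient(at: model) { model -> Tensor<Float> in
        meanSquaredError(predicted: model(x), expected: y)
    }
    // Previously: optimizer.update(&model.allDifferentiableVariables, along: 𝛁model)
    optimizer.update(&model, along: 𝛁model)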

Sources/TensorFlow/Core/DataTypes.swift (3 changes: 1 addition & 2 deletions)
@@ -84,8 +84,7 @@ extension Int64: TensorFlowIndex {}
 public protocol TensorFlowFloatingPoint:
     TensorFlowScalar & BinaryFloatingPoint & Differentiable & ElementaryFunctions
     where Self.RawSignificand: FixedWidthInteger,
-          Self == Self.TangentVector,
-          Self == Self.AllDifferentiableVariables {}
+          Self == Self.TangentVector {}
 
 extension Float: TensorFlowFloatingPoint {}
 extension Double: TensorFlowFloatingPoint {}

Sources/TensorFlow/Core/Tensor.swift (1 change: 0 additions & 1 deletion)
@@ -578,5 +578,4 @@ extension Tensor: PointwiseMultiplicative where Scalar: Numeric {
 
 extension Tensor: Differentiable where Scalar: TensorFlowFloatingPoint {
     public typealias TangentVector = Tensor
-    public typealias AllDifferentiableVariables = Tensor
 }

Sources/TensorFlow/Layer.swift (13 changes: 3 additions & 10 deletions)
@@ -14,8 +14,7 @@
 
 public protocol Module: Differentiable, KeyPathIterable
     where TangentVector: VectorProtocol & ElementaryFunctions &
-                         PointwiseMultiplicative & KeyPathIterable,
-          AllDifferentiableVariables == TangentVector {
+                         PointwiseMultiplicative & KeyPathIterable {
     /// The input type of the layer.
     associatedtype Input
     /// The output type of the layer.
@@ -55,7 +54,6 @@ public extension Layer {
 /// An empty struct representing empty `TangentVector`s for parameterless layers.
 public struct EmptyTangentVector: Differentiable, VectorProtocol, ElementaryFunctions,
                                   PointwiseMultiplicative, KeyPathIterable {
-    public typealias AllDifferentiableVariables = EmptyTangentVector
     public typealias VectorSpaceScalar = Float
 
     public func adding(_ x: Float) -> EmptyTangentVector { self }
@@ -69,17 +67,12 @@ public struct EmptyTangentVector: Differentiable, VectorProtocol, ElementaryFunctions,
 /// A parameterless neural network layer.
 ///
 /// The `TangentVector` of parameterless layers is always `EmptyTangentVector`.
-public protocol ParameterlessLayer: Layer where AllDifferentiableVariables == EmptyTangentVector {
+public protocol ParameterlessLayer: Layer {
     @differentiable
     func callAsFunction(_ input: Input) -> Output
 }
 
 public extension ParameterlessLayer {
-    var allDifferentiableVariables: EmptyTangentVector {
-        get { EmptyTangentVector() }
-        set {}
-    }
-
     mutating func move(along direction: EmptyTangentVector) {}
 }
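
With the `AllDifferentiableVariables == EmptyTangentVector` constraint gone, a
parameterless layer only has to pin `TangentVector`. A minimal sketch of a conforming
layer (hypothetical type, assuming the declarations above; not part of this diff, and
conformance-synthesis details may vary by toolchain):

    /// A toy parameterless layer that doubles its input.
    public struct Doubling<Scalar: TensorFlowFloatingPoint>: ParameterlessLayer {
        public typealias TangentVector = EmptyTangentVector

        public init() {}

        @differentiable
        public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
            input * 2
        }
    }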

@@ -98,7 +91,7 @@ public extension Layer {
     @usableFromInline
     internal func _vjpInferring(from input: Input)
         -> (value: Output, pullback: (Output.TangentVector)
-            -> (AllDifferentiableVariables, Input.TangentVector)) {
+            -> (TangentVector, Input.TangentVector)) {
         withLearningPhase(LearningPhase.inference) {
             let (output, pullback) = appliedForBackpropagation(to: input)
             return (output, { v in pullback(v) })

Sources/TensorFlow/Layers/Upsampling.swift (2 changes: 1 addition & 1 deletion)
@@ -93,7 +93,7 @@ public struct UpSampling3D<Scalar: TensorFlowFloatingPoint>: ParameterlessLayer
 
     private func _vjpRepeatingElements(
         _ input: Tensor<Scalar>, alongAxis axis: Int, count: Int
-    ) -> (Tensor<Scalar>, (Tensor<Scalar>) -> (AllDifferentiableVariables, Tensor<Scalar>)) {
+    ) -> (Tensor<Scalar>, (Tensor<Scalar>) -> (TangentVector, Tensor<Scalar>)) {
         let value = repeatingElements(input, alongAxis: axis, count: count)
         return (value, { v in
             let splits = Raw.split(

Sources/TensorFlow/Loss.swift (5 changes: 4 additions & 1 deletion)
@@ -287,5 +287,8 @@ public func sigmoidCrossEntropy<Scalar: TensorFlowFloatingPoint>(
 ) -> Tensor<Scalar> {
     // This numerically stable implementation is based on the TensorFlow Python API.
     let maxLogitsWithZero = max(logits, Tensor(0))
-    return reduction(maxLogitsWithZero - logits * labels + log(1 + exp(-abs(logits))))
+    // Note: `result` is split into two lines to avoid the "compiler is unable to type-check this
+    // expression in reasonable time" error.
+    let result = log(1 + exp(-abs(logits)))
+    return reduction(maxLogitsWithZero - logits * labels + result)
 }
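
For reference, the returned expression is the standard numerically stable rewrite of
sigmoid cross-entropy. A short derivation (not part of the diff), with logits x and
labels z:

    \begin{aligned}
    L &= -z \log \sigma(x) - (1 - z) \log\bigl(1 - \sigma(x)\bigr),
        \qquad \sigma(x) = \frac{1}{1 + e^{-x}} \\
      &= x - xz + \log\bigl(1 + e^{-x}\bigr) \\
      &= \max(x, 0) - xz + \log\bigl(1 + e^{-|x|}\bigr)
    \end{aligned}

The last step applies the identity x + log(1 + e^{-x}) = log(1 + e^{x}) on the x < 0
branch, so exp never receives a large positive argument; the final form is exactly
`maxLogitsWithZero - logits * labels + result`.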

Sources/TensorFlow/Optimizers/MomentumBased.swift (68 changes: 9 additions & 59 deletions)
@@ -55,14 +55,6 @@ public class RMSProp<Model: Differentiable>: Optimizer
     }
 
     public func update(_ model: inout Model, along direction: Model.TangentVector) {
-        update(&model.allDifferentiableVariables, along: direction)
-    }
-
-    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
-    public func update(
-        _ model: inout Model.AllDifferentiableVariables,
-        along direction: Model.TangentVector
-    ) {
         step += 1
         let learningRate = self.learningRate * 1 / (1 + decay * Float(step))
         alpha = alpha * rho + direction .* direction * (1 - rho)
@@ -107,14 +99,6 @@ public class AdaGrad<Model: Differentiable>: Optimizer
     }
 
     public func update(_ model: inout Model, along direction: Model.TangentVector) {
-        update(&model.allDifferentiableVariables, along: direction)
-    }
-
-    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
-    public func update(
-        _ model: inout Model.AllDifferentiableVariables,
-        along direction: Model.TangentVector
-    ) {
         alpha = rho + direction .* direction
         let denominator = Model.TangentVector.sqrt(alpha) + epsilon
         model.move(along: -learningRate * direction ./ denominator)
@@ -166,14 +150,6 @@ public class AdaDelta<Model: Differentiable>: Optimizer
     }
 
     public func update(_ model: inout Model, along direction: Model.TangentVector) {
-        update(&model.allDifferentiableVariables, along: direction)
-    }
-
-    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
-    public func update(
-        _ model: inout Model.AllDifferentiableVariables,
-        along direction: Model.TangentVector
-    ) {
         step += 1
         let learningRate = self.learningRate / (1 + decay * Float(step))
         averageSquared = rho * averageSquared + (1 - rho) * direction .* direction
@@ -230,15 +206,7 @@ public class Adam<Model: Differentiable>: Optimizer
     }
 
     public func update(_ model: inout Model, along direction: Model.TangentVector) {
-        update(&model.allDifferentiableVariables, along: direction)
-    }
-
-    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
-    public func update(
-        _ model: inout Model.AllDifferentiableVariables,
-        along direction: Model.TangentVector
-    ) {
-        self.step += 1
+        step += 1
         let step = Float(self.step)
         let learningRate = self.learningRate * 1 / (1 + decay * step)
         // Note: `stepSize` and `secondMoments` are split into two lines to avoid the "compiler is
@@ -262,8 +230,7 @@ public class Adam<Model: Differentiable>: Optimizer
 public class AdaMax<Model: Differentiable & KeyPathIterable>: Optimizer
     where Model.TangentVector: VectorProtocol & PointwiseMultiplicative &
                                ElementaryFunctions & KeyPathIterable,
-          Model.TangentVector.VectorSpaceScalar == Float,
-          Model.AllDifferentiableVariables == Model.TangentVector {
+          Model.TangentVector.VectorSpaceScalar == Float {
     public typealias Model = Model
     /// The learning rate.
     public var learningRate: Float
@@ -304,15 +271,7 @@ public class AdaMax<Model: Differentiable & KeyPathIterable>: Optimizer
     }
 
     public func update(_ model: inout Model, along direction: Model.TangentVector) {
-        update(&model.allDifferentiableVariables, along: direction)
-    }
-
-    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
-    public func update(
-        _ model: inout Model.AllDifferentiableVariables,
-        along direction: Model.TangentVector
-    ) {
-        self.step += 1
+        step += 1
         let step = Float(self.step)
         let learningRate = self.learningRate * 1 / (1 + decay * step)
         // Note: `stepSize` is split into two lines to avoid the "compiler is unable to type-check
@@ -323,11 +282,11 @@ public class AdaMax<Model: Differentiable & KeyPathIterable>: Optimizer
 
         // Update `infinityNorm` using a key path approach because `max(_:_:)` cannot be
         // currently applied in a simpler manner.
-        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
+        for kp in infinityNorm.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
             infinityNorm[keyPath: kp] = max(
                 beta2 * infinityNorm[keyPath: kp], abs(direction[keyPath: kp]))
         }
-        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
+        for kp in infinityNorm.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
             infinityNorm[keyPath: kp] = max(
                 Double(beta2) * infinityNorm[keyPath: kp], abs(direction[keyPath: kp]))
         }
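
With the `Model.AllDifferentiableVariables == Model.TangentVector` constraint removed,
key paths taken from `model` (of type `Model`) no longer index into tangent-vector
values such as `infinityNorm`, so the loops above now enumerate `infinityNorm`'s own
key paths. A toy sketch of the `recursivelyAllWritableKeyPaths(to:)` idiom (the
`Moments` type is hypothetical, not from this diff):

    import TensorFlow

    // Any KeyPathIterable struct exposes writable key paths to every nested
    // tensor of a given scalar type.
    struct Moments: KeyPathIterable {
        var first = Tensor<Float>(zeros: [2])
        var second = Tensor<Float>(zeros: [2])
    }

    var moments = Moments()
    let lowerBound = Tensor<Float>(repeating: 0.5, shape: [2])
    // Apply `max(_:_:)` element-wise to every nested Tensor<Float>, mirroring
    // the `infinityNorm` update above.
    for kp in moments.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
        moments[keyPath: kp] = max(moments[keyPath: kp], lowerBound)
    }
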
@@ -347,8 +306,7 @@ public class AdaMax<Model: Differentiable & KeyPathIterable>: Optimizer
 public class AMSGrad<Model: Differentiable & KeyPathIterable>: Optimizer
     where Model.TangentVector: VectorProtocol & PointwiseMultiplicative &
                                ElementaryFunctions & KeyPathIterable,
-          Model.TangentVector.VectorSpaceScalar == Float,
-          Model.AllDifferentiableVariables == Model.TangentVector {
+          Model.TangentVector.VectorSpaceScalar == Float {
     public typealias Model = Model
     /// The learning rate.
     public var learningRate: Float
@@ -390,15 +348,7 @@ public class AMSGrad<Model: Differentiable & KeyPathIterable>: Optimizer
     }
 
     public func update(_ model: inout Model, along direction: Model.TangentVector) {
-        update(&model.allDifferentiableVariables, along: direction)
-    }
-
-    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
-    public func update(
-        _ model: inout Model.AllDifferentiableVariables,
-        along direction: Model.TangentVector
-    ) {
-        self.step += 1
+        step += 1
         let step = Float(self.step)
         let beta1Power = pow(beta1, step)
         let beta2Power = pow(beta2, step)
@@ -413,11 +363,11 @@ public class AMSGrad<Model: Differentiable & KeyPathIterable>: Optimizer
 
         // Update `secondMomentsMax` using a key path approach because `max(_:_:)` cannot be
         // currently applied in a simpler manner.
-        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
+        for kp in secondMomentsMax.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
             secondMomentsMax[keyPath: kp] = max(
                 secondMomentsMax[keyPath: kp], secondMoments[keyPath: kp])
         }
-        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
+        for kp in secondMomentsMax.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
             secondMomentsMax[keyPath: kp] = max(
                 secondMomentsMax[keyPath: kp], secondMoments[keyPath: kp])
         }

Sources/TensorFlow/Optimizers/SGD.swift (8 changes: 0 additions & 8 deletions)
@@ -52,14 +52,6 @@ public class SGD<Model: Differentiable>: Optimizer
     }
 
     public func update(_ model: inout Model, along direction: Model.TangentVector) {
-        update(&model.allDifferentiableVariables, along: direction)
-    }
-
-    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
-    public func update(
-        _ model: inout Model.AllDifferentiableVariables,
-        along direction: Model.TangentVector
-    ) {
         step += 1
         let learningRate = self.learningRate * 1 / (1 + decay * Float(step))
         velocity = momentum * velocity - direction * learningRate

Sources/third_party/Experimental/Complex.swift (1 change: 0 additions & 1 deletion)
@@ -56,7 +56,6 @@ struct Complex<T: FloatingPoint> {
 
 extension Complex: Differentiable where T: Differentiable {
     typealias TangentVector = Complex
-    typealias AllDifferentiableVariables = Complex
 }
 
 extension Complex {

Tests/TensorFlowTests/OptimizersTests.swift (2 changes: 1 addition & 1 deletion)
@@ -56,7 +56,7 @@ final class OptimizerTests: XCTestCase {
             let ŷ = classifier(x)
             return meanSquaredError(predicted: ŷ, expected: y)
         }
-        optimizer.update(&classifier.allDifferentiableVariables, along: 𝛁model)
+        optimizer.update(&classifier, along: 𝛁model)
     }
 
     // trained classifier should return valid values

Tests/TensorFlowTests/SequentialTests.swift (7 changes: 0 additions & 7 deletions)
@@ -52,19 +52,12 @@ final class SequentialTests: XCTestCase {
                 return meanSquaredError(predicted: ŷ, expected: y)
             }
             sgd.update(&model, along: 𝛁model)
-            sgd.update(&model.allDifferentiableVariables, along: 𝛁model)
             rmsprop.update(&model, along: 𝛁model)
-            rmsprop.update(&model.allDifferentiableVariables, along: 𝛁model)
             adam.update(&model, along: 𝛁model)
-            adam.update(&model.allDifferentiableVariables, along: 𝛁model)
             adamax.update(&model, along: 𝛁model)
-            adamax.update(&model.allDifferentiableVariables, along: 𝛁model)
             amsgrad.update(&model, along: 𝛁model)
-            amsgrad.update(&model.allDifferentiableVariables, along: 𝛁model)
             adagrad.update(&model, along: 𝛁model)
-            adagrad.update(&model.allDifferentiableVariables, along: 𝛁model)
             adadelta.update(&model, along: 𝛁model)
-            adadelta.update(&model.allDifferentiableVariables, along: 𝛁model)
         }
     }
     XCTAssertEqual(model.inferring(from: [[0, 0], [0, 1], [1, 0], [1, 1]]),