diff --git a/Sources/TensorFlow/Core/DataTypes.swift b/Sources/TensorFlow/Core/DataTypes.swift
index 472a32ce2..05740e3de 100644
--- a/Sources/TensorFlow/Core/DataTypes.swift
+++ b/Sources/TensorFlow/Core/DataTypes.swift
@@ -84,8 +84,7 @@ extension Int64: TensorFlowIndex {}
 public protocol TensorFlowFloatingPoint:
     TensorFlowScalar & BinaryFloatingPoint & Differentiable & ElementaryFunctions
     where Self.RawSignificand: FixedWidthInteger,
-          Self == Self.TangentVector,
-          Self == Self.AllDifferentiableVariables {}
+          Self == Self.TangentVector {}
 
 extension Float: TensorFlowFloatingPoint {}
 extension Double: TensorFlowFloatingPoint {}
diff --git a/Sources/TensorFlow/Core/Tensor.swift b/Sources/TensorFlow/Core/Tensor.swift
index d3d25e66f..d8eb05553 100644
--- a/Sources/TensorFlow/Core/Tensor.swift
+++ b/Sources/TensorFlow/Core/Tensor.swift
@@ -578,5 +578,4 @@ extension Tensor: PointwiseMultiplicative where Scalar: Numeric {
 
 extension Tensor: Differentiable where Scalar: TensorFlowFloatingPoint {
     public typealias TangentVector = Tensor
-    public typealias AllDifferentiableVariables = Tensor
 }
diff --git a/Sources/TensorFlow/Layer.swift b/Sources/TensorFlow/Layer.swift
index 4eb0a8718..757b3e763 100644
--- a/Sources/TensorFlow/Layer.swift
+++ b/Sources/TensorFlow/Layer.swift
@@ -14,8 +14,7 @@
 public protocol Module: Differentiable, KeyPathIterable
     where TangentVector: VectorProtocol & ElementaryFunctions &
-                         PointwiseMultiplicative & KeyPathIterable,
-          AllDifferentiableVariables == TangentVector {
+                         PointwiseMultiplicative & KeyPathIterable {
     /// The input type of the layer.
     associatedtype Input
     /// The output type of the layer.
     associatedtype Output
@@ -55,7 +54,6 @@ public extension Layer {
 /// An empty struct representing empty `TangentVector`s for parameterless layers.
 public struct EmptyTangentVector: Differentiable, VectorProtocol, ElementaryFunctions,
                                   PointwiseMultiplicative, KeyPathIterable {
-    public typealias AllDifferentiableVariables = EmptyTangentVector
     public typealias VectorSpaceScalar = Float
 
     public func adding(_ x: Float) -> EmptyTangentVector { self }
@@ -69,17 +67,12 @@ public struct EmptyTangentVector: Differentiable, VectorProtocol, ElementaryFunc
 
 /// A parameterless neural network layer.
 ///
 /// The `TangentVector` of parameterless layers is always `EmptyTangentVector`.
-public protocol ParameterlessLayer: Layer where AllDifferentiableVariables == EmptyTangentVector {
+public protocol ParameterlessLayer: Layer {
     @differentiable
     func callAsFunction(_ input: Input) -> Output
 }
 
 public extension ParameterlessLayer {
-    var allDifferentiableVariables: EmptyTangentVector {
-        get { EmptyTangentVector() }
-        set {}
-    }
-
     mutating func move(along direction: EmptyTangentVector) {}
 }
@@ -98,7 +91,7 @@ public extension Layer {
     @usableFromInline
     internal func _vjpInferring(from input: Input)
         -> (value: Output, pullback: (Output.TangentVector)
-            -> (AllDifferentiableVariables, Input.TangentVector)) {
+            -> (TangentVector, Input.TangentVector)) {
         withLearningPhase(LearningPhase.inference) {
             let (output, pullback) = appliedForBackpropagation(to: input)
             return (output, { v in pullback(v) })
diff --git a/Sources/TensorFlow/Layers/Upsampling.swift b/Sources/TensorFlow/Layers/Upsampling.swift
index 554de6d2d..3c6f50056 100644
--- a/Sources/TensorFlow/Layers/Upsampling.swift
+++ b/Sources/TensorFlow/Layers/Upsampling.swift
@@ -93,7 +93,7 @@ public struct UpSampling3D<Scalar: TensorFlowFloatingPoint>: ParameterlessLayer
 
     private func _vjpRepeatingElements(
         _ input: Tensor<Scalar>, alongAxis axis: Int, count: Int
-    ) -> (Tensor<Scalar>, (Tensor<Scalar>) -> (AllDifferentiableVariables, Tensor<Scalar>)) {
+    ) -> (Tensor<Scalar>, (Tensor<Scalar>) -> (TangentVector, Tensor<Scalar>)) {
         let value = repeatingElements(input, alongAxis: axis, count: count)
         return (value, { v in
             let splits = Raw.split(
diff --git a/Sources/TensorFlow/Loss.swift b/Sources/TensorFlow/Loss.swift
index 1e5092475..d4454a15c 100644
--- a/Sources/TensorFlow/Loss.swift
+++ b/Sources/TensorFlow/Loss.swift
@@ -287,5 +287,8 @@ public func sigmoidCrossEntropy<Scalar: TensorFlowFloatingPoint>(
 ) -> Tensor<Scalar> {
     // This numerically stable implementation is based on the TensorFlow Python API.
     let maxLogitsWithZero = max(logits, Tensor(0))
-    return reduction(maxLogitsWithZero - logits * labels + log(1 + exp(-abs(logits))))
+    // Note: `result` is split into two lines to avoid the "compiler is unable to type-check this
+    // expression in reasonable time" error.
+    let result = log(1 + exp(-abs(logits)))
+    return reduction(maxLogitsWithZero - logits * labels + result)
 }
diff --git a/Sources/TensorFlow/Optimizers/MomentumBased.swift b/Sources/TensorFlow/Optimizers/MomentumBased.swift
index 50325064f..b83964c8f 100644
--- a/Sources/TensorFlow/Optimizers/MomentumBased.swift
+++ b/Sources/TensorFlow/Optimizers/MomentumBased.swift
@@ -55,14 +55,6 @@ public class RMSProp<Model: Differentiable>: Optimizer
     }
 
     public func update(_ model: inout Model, along direction: Model.TangentVector) {
-        update(&model.allDifferentiableVariables, along: direction)
-    }
-
-    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
-    public func update(
-        _ model: inout Model.AllDifferentiableVariables,
-        along direction: Model.TangentVector
-    ) {
         step += 1
         let learningRate = self.learningRate * 1 / (1 + decay * Float(step))
         alpha = alpha * rho + direction .* direction * (1 - rho)
@@ -107,14 +99,6 @@ public class AdaGrad<Model: Differentiable>: Optimizer
     }
 
     public func update(_ model: inout Model, along direction: Model.TangentVector) {
-        update(&model.allDifferentiableVariables, along: direction)
-    }
-
-    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
-    public func update(
-        _ model: inout Model.AllDifferentiableVariables,
-        along direction: Model.TangentVector
-    ) {
         alpha = rho + direction .* direction
         let denominator = Model.TangentVector.sqrt(alpha) + epsilon
         model.move(along: -learningRate * direction ./ denominator)
@@ -166,14 +150,6 @@ public class AdaDelta<Model: Differentiable>: Optimizer
     }
 
     public func update(_ model: inout Model, along direction: Model.TangentVector) {
-        update(&model.allDifferentiableVariables, along: direction)
-    }
-
-    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
-    public func update(
-        _ model: inout Model.AllDifferentiableVariables,
-        along direction: Model.TangentVector
-    ) {
         step += 1
         let learningRate = self.learningRate / (1 + decay * Float(step))
         averageSquared = rho * averageSquared + (1 - rho) * direction .* direction
@@ -230,15 +206,7 @@ public class Adam<Model: Differentiable>: Optimizer
     }
 
     public func update(_ model: inout Model, along direction: Model.TangentVector) {
-        update(&model.allDifferentiableVariables, along: direction)
-    }
-
-    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
-    public func update(
-        _ model: inout Model.AllDifferentiableVariables,
-        along direction: Model.TangentVector
-    ) {
-        self.step += 1
+        step += 1
         let step = Float(self.step)
         let learningRate = self.learningRate * 1 / (1 + decay * step)
         // Note: `stepSize` and `secondMoments` are split into two lines to avoid the "compiler is
@@ -262,8 +230,7 @@
 public class AdaMax<Model: Differentiable & KeyPathIterable>: Optimizer
     where Model.TangentVector: VectorProtocol & PointwiseMultiplicative &
                                ElementaryFunctions & KeyPathIterable,
-          Model.TangentVector.VectorSpaceScalar == Float,
-          Model.AllDifferentiableVariables == Model.TangentVector {
+          Model.TangentVector.VectorSpaceScalar == Float {
     public typealias Model = Model
     /// The learning rate.
     public var learningRate: Float
@@ -304,15 +271,7 @@ public class AdaMax<Model: Differentiable & KeyPathIterable>: Optimizer
     }
 
     public func update(_ model: inout Model, along direction: Model.TangentVector) {
-        update(&model.allDifferentiableVariables, along: direction)
-    }
-
-    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
-    public func update(
-        _ model: inout Model.AllDifferentiableVariables,
-        along direction: Model.TangentVector
-    ) {
-        self.step += 1
+        step += 1
         let step = Float(self.step)
         let learningRate = self.learningRate * 1 / (1 + decay * step)
         // Note: `stepSize` is split into two lines to avoid the "compiler is unable to type-check
@@ -323,11 +282,11 @@ public class AdaMax<Model: Differentiable & KeyPathIterable>: Optimizer
 
         // Update `infinityNorm` using a key path approach because `max(_:_:)` cannot be
         // currently applied in a simpler manner.
-        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
+        for kp in infinityNorm.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
             infinityNorm[keyPath: kp] = max(
                 beta2 * infinityNorm[keyPath: kp], abs(direction[keyPath: kp]))
         }
-        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
+        for kp in infinityNorm.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
             infinityNorm[keyPath: kp] = max(
                 Double(beta2) * infinityNorm[keyPath: kp], abs(direction[keyPath: kp]))
         }
@@ -347,8 +306,7 @@
 public class AMSGrad<Model: Differentiable & KeyPathIterable>: Optimizer
     where Model.TangentVector: VectorProtocol & PointwiseMultiplicative &
                                ElementaryFunctions & KeyPathIterable,
-          Model.TangentVector.VectorSpaceScalar == Float,
-          Model.AllDifferentiableVariables == Model.TangentVector {
+          Model.TangentVector.VectorSpaceScalar == Float {
     public typealias Model = Model
     /// The learning rate.
     public var learningRate: Float
@@ -390,15 +348,7 @@ public class AMSGrad<Model: Differentiable & KeyPathIterable>: Optimizer
     }
 
     public func update(_ model: inout Model, along direction: Model.TangentVector) {
-        update(&model.allDifferentiableVariables, along: direction)
-    }
-
-    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
-    public func update(
-        _ model: inout Model.AllDifferentiableVariables,
-        along direction: Model.TangentVector
-    ) {
-        self.step += 1
+        step += 1
         let step = Float(self.step)
         let beta1Power = pow(beta1, step)
         let beta2Power = pow(beta2, step)
@@ -413,11 +363,11 @@ public class AMSGrad<Model: Differentiable & KeyPathIterable>: Optimizer
 
         // Update `secondMomentsMax` using a key path approach because `max(_:_:)` cannot be
         // currently applied in a simpler manner.
-        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
+        for kp in secondMomentsMax.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
             secondMomentsMax[keyPath: kp] = max(
                 secondMomentsMax[keyPath: kp], secondMoments[keyPath: kp])
         }
-        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
+        for kp in secondMomentsMax.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
             secondMomentsMax[keyPath: kp] = max(
                 secondMomentsMax[keyPath: kp], secondMoments[keyPath: kp])
         }
diff --git a/Sources/TensorFlow/Optimizers/SGD.swift b/Sources/TensorFlow/Optimizers/SGD.swift
index 4c71ad3e5..743899de7 100644
--- a/Sources/TensorFlow/Optimizers/SGD.swift
+++ b/Sources/TensorFlow/Optimizers/SGD.swift
@@ -52,14 +52,6 @@ public class SGD<Model: Differentiable>: Optimizer
     }
 
     public func update(_ model: inout Model, along direction: Model.TangentVector) {
-        update(&model.allDifferentiableVariables, along: direction)
-    }
-
-    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
-    public func update(
-        _ model: inout Model.AllDifferentiableVariables,
-        along direction: Model.TangentVector
-    ) {
         step += 1
         let learningRate = self.learningRate * 1 / (1 + decay * Float(step))
         velocity = momentum * velocity - direction * learningRate
diff --git a/Sources/third_party/Experimental/Complex.swift b/Sources/third_party/Experimental/Complex.swift
index c09f3dcf3..271403c73 100644
--- a/Sources/third_party/Experimental/Complex.swift
+++ b/Sources/third_party/Experimental/Complex.swift
@@ -56,7 +56,6 @@ struct Complex {
 
 extension Complex: Differentiable where T: Differentiable {
     typealias TangentVector = Complex
-    typealias AllDifferentiableVariables = Complex
 }
 
 extension Complex {
diff --git a/Tests/TensorFlowTests/OptimizersTests.swift b/Tests/TensorFlowTests/OptimizersTests.swift
index 9c910a0d4..e0b30d5e8 100644
--- a/Tests/TensorFlowTests/OptimizersTests.swift
+++ b/Tests/TensorFlowTests/OptimizersTests.swift
@@ -56,7 +56,7 @@ final class OptimizerTests: XCTestCase {
                 let ŷ = classifier(x)
                 return meanSquaredError(predicted: ŷ, expected: y)
             }
-            optimizer.update(&classifier.allDifferentiableVariables, along: 𝛁model)
+            optimizer.update(&classifier, along: 𝛁model)
         }
 
         // trained classifier should return valid values
diff --git a/Tests/TensorFlowTests/SequentialTests.swift b/Tests/TensorFlowTests/SequentialTests.swift
index 5af8d819a..041769440 100644
--- a/Tests/TensorFlowTests/SequentialTests.swift
+++ b/Tests/TensorFlowTests/SequentialTests.swift
@@ -52,19 +52,12 @@ final class SequentialTests: XCTestCase {
                 return meanSquaredError(predicted: ŷ, expected: y)
             }
             sgd.update(&model, along: 𝛁model)
-            sgd.update(&model.allDifferentiableVariables, along: 𝛁model)
             rmsprop.update(&model, along: 𝛁model)
-            rmsprop.update(&model.allDifferentiableVariables, along: 𝛁model)
             adam.update(&model, along: 𝛁model)
-            adam.update(&model.allDifferentiableVariables, along: 𝛁model)
             adamax.update(&model, along: 𝛁model)
-            adamax.update(&model.allDifferentiableVariables, along: 𝛁model)
             amsgrad.update(&model, along: 𝛁model)
-            amsgrad.update(&model.allDifferentiableVariables, along: 𝛁model)
             adagrad.update(&model, along: 𝛁model)
-            adagrad.update(&model.allDifferentiableVariables, along: 𝛁model)
             adadelta.update(&model, along: 𝛁model)
-            adadelta.update(&model.allDifferentiableVariables, along: 𝛁model)
         }
     }
     XCTAssertEqual(model.inferring(from: [[0, 0], [0, 1], [1, 0], [1, 1]]),