This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit bcdc695

Deprecate `Differentiable.AllDifferentiableVariables`.
Remove usages of `AllDifferentiableVariables` and `var allDifferentiableVariables`.
1 parent faf540a commit bcdc695
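
For user code, the practical effect is that a `Differentiable` conformance needs only the synthesized `TangentVector`; the parallel `AllDifferentiableVariables` typealias and the `allDifferentiableVariables` property go away. A minimal sketch with a hypothetical `MyModel` (illustrative, not part of this commit), assuming the Swift for TensorFlow toolchain APIs of this era:

    import TensorFlow

    // A hypothetical differentiable struct. The compiler synthesizes
    // `TangentVector`; no `AllDifferentiableVariables` is declared.
    struct MyModel: Differentiable {
        var w = Tensor<Float>([1, 2, 3])
    }

    var model = MyModel()
    // Differentiate a scalar-valued function of the model.
    let grad = gradient(at: model) { m in (m.w * m.w).sum() }
    // `grad` is a `MyModel.TangentVector`; applying it is a single `move(along:)`.
    model.move(along: grad)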

File tree: 6 files changed (+14 −81 lines)

Sources/TensorFlow/Core/DataTypes.swift (+1 −2)

@@ -84,8 +84,7 @@ extension Int64: TensorFlowIndex {}
 public protocol TensorFlowFloatingPoint:
     TensorFlowScalar & BinaryFloatingPoint & Differentiable & ElementaryFunctions
     where Self.RawSignificand: FixedWidthInteger,
-          Self == Self.TangentVector,
-          Self == Self.AllDifferentiableVariables {}
+          Self == Self.TangentVector {}
 
 extension Float: TensorFlowFloatingPoint {}
 extension Double: TensorFlowFloatingPoint {}
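
The remaining constraint is exactly what scalar differentiation needs: for any `TensorFlowFloatingPoint` type, the gradient of a scalar function is a value of the same type, because `Self == Self.TangentVector`. A minimal sketch, assuming the toolchain's `gradient(at:in:)`:

    import TensorFlow

    // Generic over Float and Double; differentiation works because the
    // protocol guarantees Self == Self.TangentVector.
    @differentiable
    func cube<T: TensorFlowFloatingPoint>(_ x: T) -> T {
        x * x * x
    }

    let dydx = gradient(at: Float(2)) { x in cube(x) }
    print(dydx)  // 12.0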

Sources/TensorFlow/Core/Tensor.swift (−1)

@@ -578,5 +578,4 @@ extension Tensor: PointwiseMultiplicative where Scalar: Numeric {
 
 extension Tensor: Differentiable where Scalar: TensorFlowFloatingPoint {
     public typealias TangentVector = Tensor
-    public typealias AllDifferentiableVariables = Tensor
 }
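
Dropping the typealias changes nothing observable: `Tensor`'s tangent vector is still `Tensor` itself, so moving a tensor along a gradient stays plain elementwise addition. A short sketch:

    import TensorFlow

    var weights = Tensor<Float>([1, 2, 3])
    // Tensor.TangentVector == Tensor, so a direction is just another tensor.
    weights.move(along: Tensor<Float>([-0.1, -0.2, -0.3]))
    print(weights)  // [0.9, 1.8, 2.7]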

Sources/TensorFlow/Layer.swift (+3 −10)

@@ -14,8 +14,7 @@
 
 public protocol Module: Differentiable, KeyPathIterable
     where TangentVector: VectorProtocol & ElementaryFunctions &
-          PointwiseMultiplicative & KeyPathIterable,
-          AllDifferentiableVariables == TangentVector {
+          PointwiseMultiplicative & KeyPathIterable {
     /// The input type of the layer.
     associatedtype Input
     /// The output type of the layer.

@@ -55,7 +54,6 @@ public extension Layer {
 /// An empty struct representing empty `TangentVector`s for parameterless layers.
 public struct EmptyTangentVector: Differentiable, VectorProtocol, ElementaryFunctions,
                                   PointwiseMultiplicative, KeyPathIterable {
-    public typealias AllDifferentiableVariables = EmptyTangentVector
     public typealias VectorSpaceScalar = Float
 
     public func adding(_ x: Float) -> EmptyTangentVector { self }

@@ -69,17 +67,12 @@ public struct EmptyTangentVector: Differentiable, VectorProtocol, ElementaryFunc
 /// A parameterless neural network layer.
 ///
 /// The `TangentVector` of parameterless layers is always `EmptyTangentVector`.
-public protocol ParameterlessLayer: Layer where AllDifferentiableVariables == EmptyTangentVector {
+public protocol ParameterlessLayer: Layer {
     @differentiable
     func callAsFunction(_ input: Input) -> Output
 }
 
 public extension ParameterlessLayer {
-    var allDifferentiableVariables: EmptyTangentVector {
-        get { EmptyTangentVector() }
-        set {}
-    }
-
     mutating func move(along direction: EmptyTangentVector) {}
 }

@@ -98,7 +91,7 @@ public extension Layer {
     @usableFromInline
     internal func _vjpInferring(from input: Input)
         -> (value: Output, pullback: (Output.TangentVector)
-            -> (AllDifferentiableVariables, Input.TangentVector)) {
+            -> (TangentVector, Input.TangentVector)) {
         withLearningPhase(LearningPhase.inference) {
             let (output, pullback) = appliedForBackpropagation(to: input)
             return (output, { v in pullback(v) })
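
A parameterless layer now conforms without the `allDifferentiableVariables` boilerplate: it declares `EmptyTangentVector` as its tangent type and inherits the no-op `move(along:)` from the extension above. A sketch with a hypothetical `DoubleIt` layer (illustrative, not in the library):

    import TensorFlow

    // A hypothetical layer with no trainable parameters.
    struct DoubleIt<Scalar: TensorFlowFloatingPoint>: ParameterlessLayer {
        typealias TangentVector = EmptyTangentVector

        @differentiable
        func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
            input * 2
        }
    }

    let layer = DoubleIt<Float>()
    let grad = gradient(at: Tensor<Float>([1, 2, 3])) { x in layer(x).sum() }
    print(grad)  // [2.0, 2.0, 2.0]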

Sources/TensorFlow/Layers/Upsampling.swift (+1 −1)

@@ -93,7 +93,7 @@ public struct UpSampling3D<Scalar: TensorFlowFloatingPoint>: ParameterlessLayer
 
     private func _vjpRepeatingElements(
         _ input: Tensor<Scalar>, alongAxis axis: Int, count: Int
-    ) -> (Tensor<Scalar>, (Tensor<Scalar>) -> (AllDifferentiableVariables, Tensor<Scalar>)) {
+    ) -> (Tensor<Scalar>, (Tensor<Scalar>) -> (TangentVector, Tensor<Scalar>)) {
         let value = repeatingElements(input, alongAxis: axis, count: count)
         return (value, { v in
             let splits = Raw.split(
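
Since `UpSampling3D` is a `ParameterlessLayer`, the pullback's first component is its empty `TangentVector`, and user-visible gradients flow only to the input. A usage sketch, assuming `UpSampling3D.init(size:)` from the library:

    import TensorFlow

    let upsample = UpSampling3D<Float>(size: 2)
    // Expected input shape: [batch, depth, height, width, channels].
    let input = Tensor<Float>(ones: [1, 2, 2, 2, 1])
    let output = upsample(input)  // shape [1, 4, 4, 4, 1]
    let inputGrad = gradient(at: input) { x in upsample(x).sum() }
    print(output.shape, inputGrad.shape)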

Sources/TensorFlow/Optimizers/MomentumBased.swift (+9 −59)

@@ -55,14 +55,6 @@ public class RMSProp<Model: Differentiable>: Optimizer
     }
 
     public func update(_ model: inout Model, along direction: Model.TangentVector) {
-        update(&model.allDifferentiableVariables, along: direction)
-    }
-
-    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
-    public func update(
-        _ model: inout Model.AllDifferentiableVariables,
-        along direction: Model.TangentVector
-    ) {
         step += 1
         let learningRate = self.learningRate * 1 / (1 + decay * Float(step))
         alpha = alpha * rho + direction .* direction * (1 - rho)

@@ -107,14 +99,6 @@ public class AdaGrad<Model: Differentiable>: Optimizer
     }
 
     public func update(_ model: inout Model, along direction: Model.TangentVector) {
-        update(&model.allDifferentiableVariables, along: direction)
-    }
-
-    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
-    public func update(
-        _ model: inout Model.AllDifferentiableVariables,
-        along direction: Model.TangentVector
-    ) {
         alpha = rho + direction .* direction
         let denominator = Model.TangentVector.sqrt(alpha) + epsilon
         model.move(along: -learningRate * direction ./ denominator)

@@ -166,14 +150,6 @@ public class AdaDelta<Model: Differentiable>: Optimizer
     }
 
     public func update(_ model: inout Model, along direction: Model.TangentVector) {
-        update(&model.allDifferentiableVariables, along: direction)
-    }
-
-    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
-    public func update(
-        _ model: inout Model.AllDifferentiableVariables,
-        along direction: Model.TangentVector
-    ) {
         step += 1
         let learningRate = self.learningRate / (1 + decay * Float(step))
         averageSquared = rho * averageSquared + (1 - rho) * direction .* direction

@@ -230,15 +206,7 @@ public class Adam<Model: Differentiable>: Optimizer
     }
 
     public func update(_ model: inout Model, along direction: Model.TangentVector) {
-        update(&model.allDifferentiableVariables, along: direction)
-    }
-
-    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
-    public func update(
-        _ model: inout Model.AllDifferentiableVariables,
-        along direction: Model.TangentVector
-    ) {
-        self.step += 1
+        step += 1
         let step = Float(self.step)
         let learningRate = self.learningRate * 1 / (1 + decay * step)
         // Note: `stepSize` and `secondMoments` are split into two lines to avoid the "compiler is

@@ -262,8 +230,7 @@ public class Adam<Model: Differentiable>: Optimizer
 public class AdaMax<Model: Differentiable & KeyPathIterable>: Optimizer
     where Model.TangentVector: VectorProtocol & PointwiseMultiplicative &
           ElementaryFunctions & KeyPathIterable,
-          Model.TangentVector.VectorSpaceScalar == Float,
-          Model.AllDifferentiableVariables == Model.TangentVector {
+          Model.TangentVector.VectorSpaceScalar == Float {
     public typealias Model = Model
     /// The learning rate.
     public var learningRate: Float

@@ -304,15 +271,7 @@ public class AdaMax<Model: Differentiable & KeyPathIterable>: Optimizer
     }
 
     public func update(_ model: inout Model, along direction: Model.TangentVector) {
-        update(&model.allDifferentiableVariables, along: direction)
-    }
-
-    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
-    public func update(
-        _ model: inout Model.AllDifferentiableVariables,
-        along direction: Model.TangentVector
-    ) {
-        self.step += 1
+        step += 1
         let step = Float(self.step)
         let learningRate = self.learningRate * 1 / (1 + decay * step)
         // Note: `stepSize` is split into two lines to avoid the "compiler is unable to type-check

@@ -323,11 +282,11 @@ public class AdaMax<Model: Differentiable & KeyPathIterable>: Optimizer
 
         // Update `infinityNorm` using a key path approach because `max(_:_:)` cannot be
         // currently applied in a simpler manner.
-        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
+        for kp in infinityNorm.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
             infinityNorm[keyPath: kp] = max(
                 beta2 * infinityNorm[keyPath: kp], abs(direction[keyPath: kp]))
         }
-        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
+        for kp in infinityNorm.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
             infinityNorm[keyPath: kp] = max(
                 Double(beta2) * infinityNorm[keyPath: kp], abs(direction[keyPath: kp]))
         }

@@ -347,8 +306,7 @@ public class AdaMax<Model: Differentiable & KeyPathIterable>: Optimizer
 public class AMSGrad<Model: Differentiable & KeyPathIterable>: Optimizer
     where Model.TangentVector: VectorProtocol & PointwiseMultiplicative &
           ElementaryFunctions & KeyPathIterable,
-          Model.TangentVector.VectorSpaceScalar == Float,
-          Model.AllDifferentiableVariables == Model.TangentVector {
+          Model.TangentVector.VectorSpaceScalar == Float {
     public typealias Model = Model
     /// The learning rate.
     public var learningRate: Float

@@ -390,15 +348,7 @@ public class AMSGrad<Model: Differentiable & KeyPathIterable>: Optimizer
     }
 
     public func update(_ model: inout Model, along direction: Model.TangentVector) {
-        update(&model.allDifferentiableVariables, along: direction)
-    }
-
-    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
-    public func update(
-        _ model: inout Model.AllDifferentiableVariables,
-        along direction: Model.TangentVector
-    ) {
-        self.step += 1
+        step += 1
         let step = Float(self.step)
         let beta1Power = pow(beta1, step)
         let beta2Power = pow(beta2, step)

@@ -413,11 +363,11 @@ public class AMSGrad<Model: Differentiable & KeyPathIterable>: Optimizer
 
         // Update `secondMomentsMax` using a key path approach because `max(_:_:)` cannot be
         // currently applied in a simpler manner.
-        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
+        for kp in secondMomentsMax.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
             secondMomentsMax[keyPath: kp] = max(
                 secondMomentsMax[keyPath: kp], secondMoments[keyPath: kp])
         }
-        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
+        for kp in secondMomentsMax.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
             secondMomentsMax[keyPath: kp] = max(
                 secondMomentsMax[keyPath: kp], secondMoments[keyPath: kp])
         }
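
With the duplicate overload removed, an optimizer step is a single `update(_:along:)` call that mutates the model in place through `move(along:)`. A minimal training-step sketch, assuming `Dense`, `valueWithGradient(at:in:)`, and `meanSquaredError(predicted:expected:)` from the toolchain of this era:

    import TensorFlow

    var model = Dense<Float>(inputSize: 2, outputSize: 1)
    let optimizer = Adam(for: model, learningRate: 1e-3)
    let x = Tensor<Float>(randomNormal: [4, 2])
    let y = Tensor<Float>(zeros: [4, 1])

    let (loss, grad) = valueWithGradient(at: model) { model -> Tensor<Float> in
        meanSquaredError(predicted: model(x), expected: y)
    }
    // No more `&model.allDifferentiableVariables` indirection:
    optimizer.update(&model, along: grad)
    print(loss)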

Sources/TensorFlow/Optimizers/SGD.swift (−8)

@@ -52,14 +52,6 @@ public class SGD<Model: Differentiable>: Optimizer
     }
 
     public func update(_ model: inout Model, along direction: Model.TangentVector) {
-        update(&model.allDifferentiableVariables, along: direction)
-    }
-
-    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
-    public func update(
-        _ model: inout Model.AllDifferentiableVariables,
-        along direction: Model.TangentVector
-    ) {
         step += 1
         let learningRate = self.learningRate * 1 / (1 + decay * Float(step))
         velocity = momentum * velocity - direction * learningRate
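
`SGD` follows the same pattern. Reusing `model` and `grad` from the training-step sketch above:

    // The single remaining overload mutates the model directly.
    let sgd = SGD(for: model, learningRate: 0.01, momentum: 0.9)
    sgd.update(&model, along: grad)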
