Skip to content

Commit 8b500d2

Browse files
committed
Internal wrapper for direct eager dispatching.
- Add support for dispatching an eager op which returns a tensor group. - AnyTensor type for array output and input lists. - Add proper eager support for saveV2 and restoreV2 ops. These will be used for model checkpointing.
1 parent 2ed1ad2 commit 8b500d2

File tree

5 files changed

+222
-0
lines changed

5 files changed

+222
-0
lines changed
+112
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
//===-- ArrayOps.swift ----------------------------------------*- swift -*-===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
//
13+
// This file contains some Array ops that cannot be properly handled by #tfop.
14+
//
15+
// TODO: These should be deleted once we can properly generate raw ops for these.
16+
//
17+
//===----------------------------------------------------------------------===//
18+
19+
import CTensorFlow
20+
21+
public extension Raw {
22+
/// Saves tensors in V2 checkpoint format.
23+
///
24+
/// By default, saves the named tensors in full. If the caller wishes to save
25+
/// specific slices of full tensors, "shape_and_slices" should be non-empty strings
26+
/// and correspondingly well-formed.
27+
///
28+
/// - Parameters:
29+
/// - prefix: Must have a single element. The prefix of the V2 checkpoint to which we
30+
/// write the tensors.
31+
/// - tensor_names: shape {N}. The names of the tensors to be saved.
32+
/// - shape_and_slices: shape {N}. The slice specs of the tensors to be saved.
33+
/// Empty strings indicate that they are non-partitioned tensors.
34+
/// - tensors: `N` tensors to save.
35+
@inlinable @inline(__always)
36+
static func saveV2(
37+
prefix: StringTensor,
38+
tensorNames: StringTensor,
39+
shapeAndSlices: StringTensor,
40+
tensors: [AnyTensor]
41+
) {
42+
let s: CTFStatus = TF_NewStatus()
43+
defer { TF_DeleteStatus(s) }
44+
let op: CTFEOp = TFE_NewOp(_ExecutionContext.global.eagerContext, "SaveV2", s)
45+
defer { TFE_DeleteOp(op) }
46+
let _ = _TFCOpAddInputFromTensorGroup(op, prefix, s)
47+
let _ = _TFCOpAddInputFromTensorGroup(op, tensorNames, s)
48+
let _ = _TFCOpAddInputFromTensorGroup(op, shapeAndSlices, s)
49+
let _ = _TFCOpAddInputFromAnyTensors(op, tensors, s)
50+
let _ = _TFCOpSetAttrTypeArray(op, "dtypes", tensors.map { $0._tensorFlowDataType })
51+
return _TFCExecuteOp(op, s)
52+
}
53+
54+
/// Restores tensors from a V2 checkpoint.
55+
///
56+
/// For backward compatibility with the V1 format, this Op currently allows
57+
/// restoring from a V1 checkpoint as well:
58+
/// - This Op first attempts to find the V2 index file pointed to by "prefix", and
59+
/// if found proceed to read it as a V2 checkpoint;
60+
/// - Otherwise the V1 read path is invoked.
61+
/// Relying on this behavior is not recommended, as the ability to fall back to read
62+
/// V1 might be deprecated and eventually removed.
63+
///
64+
/// By default, restores the named tensors in full. If the caller wishes to restore
65+
/// specific slices of stored tensors, "shape_and_slices" should be non-empty
66+
/// strings and correspondingly well-formed.
67+
///
68+
/// Callers must ensure all the named tensors are indeed stored in the checkpoint.
69+
///
70+
/// - Parameters:
71+
/// - prefix: Must have a single element. The prefix of a V2 checkpoint.
72+
/// - tensor_names: shape {N}. The names of the tensors to be restored.
73+
/// - shape_and_slices: shape {N}. The slice specs of the tensors to be restored.
74+
/// Empty strings indicate that they are non-partitioned tensors.
75+
///
76+
/// - Attr dtypes: shape {N}. The list of expected dtype for the tensors. Must match
77+
/// those stored in the checkpoint.
78+
///
79+
/// - Output tensors: shape {N}. The restored tensors, whose shapes are read from the
80+
/// checkpoint directly.
81+
@inlinable @inline(__always)
82+
static func restoreV2(
83+
prefix: StringTensor,
84+
tensorNames: StringTensor,
85+
shapeAndSlices: StringTensor,
86+
dtypes: [TensorDataType]
87+
) -> [AnyTensor] {
88+
let s: CTFStatus = TF_NewStatus()
89+
defer { TF_DeleteStatus(s) }
90+
let op: CTFEOp = TFE_NewOp(_ExecutionContext.global.eagerContext, "RestoreV2", s)
91+
defer { TFE_DeleteOp(op) }
92+
let _ = _TFCOpAddInputFromTensorGroup(op, prefix, s)
93+
let _ = _TFCOpAddInputFromTensorGroup(op, tensorNames, s)
94+
let _ = _TFCOpAddInputFromTensorGroup(op, shapeAndSlices, s)
95+
let _ = _TFCOpSetAttrTypeArray(op, "dtypes", dtypes)
96+
97+
var count: Int32 = Int32(dtypes.count)
98+
let buffer: UnsafeMutablePointer<CTensorHandle> =
99+
UnsafeMutablePointer.allocate(capacity: Int(count))
100+
defer { buffer.deallocate() }
101+
_TFCEagerExecute(op, UnsafeMutablePointer<CTensorHandle?>(buffer), &count, s)
102+
checkOk(s)
103+
104+
var out: [AnyTensor] = []
105+
var cursor = buffer
106+
for type in dtypes {
107+
out.append(makeTensor(dataType: type, owning: cursor.pointee))
108+
cursor = cursor.advanced(by: 1)
109+
}
110+
return out
111+
}
112+
}

stdlib/public/TensorFlow/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ set(SOURCES
4848
TensorProtocol.swift
4949
TensorShape.swift
5050
Utilities.swift
51+
ArrayOps.swift
5152
Threading.swift
53+
ExecuteOp.swift.gyb
5254
# NumPy bridging for `ShapedArray` and `Tensor`.
5355
NumpyConversion.swift)
5456

stdlib/public/TensorFlow/CompilerRuntime.swift

+23
Original file line numberDiff line numberDiff line change
@@ -1703,6 +1703,29 @@ func _TFCOpAddInputFromTensorGroup<T : TensorArrayProtocol>(
17031703
return count
17041704
}
17051705

1706+
/// Special protocol for calling tensorflow operations that take heterogeneous
1707+
/// arrays as input.
1708+
public protocol AnyTensor {
1709+
var _rawTensorHandle: CTensorHandle { get }
1710+
var _tensorFlowDataType: TensorDataType { get }
1711+
}
1712+
1713+
extension Tensor : AnyTensor {
1714+
public var _rawTensorHandle: CTensorHandle { return handle._cTensorHandle }
1715+
public var _tensorFlowDataType: TensorDataType { return Scalar.tensorFlowDataType }
1716+
}
1717+
1718+
@usableFromInline
1719+
func _TFCOpAddInputFromAnyTensors(
1720+
_ op: CTFEOp, _ tensors: [AnyTensor], _ status: CTFStatus
1721+
) {
1722+
for tensor in tensors {
1723+
let handle = tensor._rawTensorHandle
1724+
TFE_OpAddInput(op, handle, status)
1725+
checkOk(status)
1726+
}
1727+
}
1728+
17061729
/// Initializes a TensorGroup value, taking ownership of all the tensor
17071730
/// handles in `tensorHandles`.
17081731
@usableFromInline

stdlib/public/TensorFlow/DataTypes.swift

+37
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,43 @@ public struct TensorDataType {
3636
}
3737
}
3838

39+
@usableFromInline
40+
internal func makeTensor(
41+
dataType: TensorDataType,
42+
owning pointer: CTensorHandle
43+
) -> AnyTensor {
44+
switch dataType._cDataType {
45+
case TF_BOOL:
46+
return Tensor<Bool>(handle: TensorHandle(_owning: pointer))
47+
case TF_INT8:
48+
return Tensor<Int8>(handle: TensorHandle(_owning: pointer))
49+
case TF_UINT8:
50+
return Tensor<UInt8>(handle: TensorHandle(_owning: pointer))
51+
case TF_INT16:
52+
return Tensor<Int16>(handle: TensorHandle(_owning: pointer))
53+
case TF_UINT16:
54+
return Tensor<UInt16>(handle: TensorHandle(_owning: pointer))
55+
case TF_INT32:
56+
return Tensor<Int32>(handle: TensorHandle(_owning: pointer))
57+
case TF_UINT32:
58+
return Tensor<UInt32>(handle: TensorHandle(_owning: pointer))
59+
case TF_INT64:
60+
return Tensor<Int64>(handle: TensorHandle(_owning: pointer))
61+
case TF_UINT64:
62+
return Tensor<UInt64>(handle: TensorHandle(_owning: pointer))
63+
case TF_BFLOAT16:
64+
return Tensor<BFloat16>(handle: TensorHandle(_owning: pointer))
65+
case TF_FLOAT:
66+
return Tensor<Float>(handle: TensorHandle(_owning: pointer))
67+
case TF_DOUBLE:
68+
return Tensor<Double>(handle: TensorHandle(_owning: pointer))
69+
case TF_STRING:
70+
fatalError("StringTensor does not conform to AnyTensor")
71+
default:
72+
fatalError("Unhandled type: \(dataType)")
73+
}
74+
}
75+
3976
/// A data type compatible with TensorFlow.
4077
public protocol _TensorFlowDataTypeCompatible {
4178
/// The underlying TensorFlow data type.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
//===-- ExecuteOp.swift.gyb -----------------------------------*- swift -*-===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
//
13+
// This file contains _TFCExecuteOp which allows dispatching an op and
14+
// returning an arbitrary set of tensor-groups.
15+
//
16+
// TODO: A nice wrapper for TFEOp could possibly make this simpler to use. This
17+
// may need to be extended in order to work with multiple tfops.
18+
//
19+
//===----------------------------------------------------------------------===//
20+
21+
@usableFromInline
22+
func _TFCExecuteOp(_ op: CTFEOp, _ s: CTFStatus) {
23+
var count: Int32 = 0
24+
var unused: CTensorHandle?
25+
_TFCEagerExecute(op, &unused, &count, s)
26+
checkOk(s)
27+
}
28+
29+
%for n in range(1, 11):
30+
// Calls _TFCEagerExecute under the hood and unpacks into TensorGroup conforming
31+
// types.
32+
@usableFromInline
33+
func _TFCExecuteOp<${", ".join(["T" + str(i) + " : TensorGroup" for i in range(n)])}>
34+
(_ op: CTFEOp, _ s: CTFStatus)
35+
-> (${", ".join(["T" + str(i) for i in range(n)])}) {
36+
37+
var count: Int32 = ${" + ".join(["T" + str(i) + "._tensorHandleCount" for i in range(n)])}
38+
let buffer: UnsafeMutablePointer<CTensorHandle> =
39+
UnsafeMutablePointer.allocate(capacity: Int(count))
40+
defer { buffer.deallocate() }
41+
_TFCEagerExecute(op, UnsafeMutablePointer<CTensorHandle?>(buffer), &count, s)
42+
checkOk(s)
43+
%for i in range(n):
44+
let off${i}: Int32 = ${"0" if i == 0 else "off" + str(i - 1) + " + T" + str(i - 1) + "._tensorHandleCount"}
45+
%end
46+
return (${", ".join(["T" + str(i) + ".init(_owning: buffer.advanced(by: Int(off" + str(i) + ")))" for i in range(n)])})
47+
}
48+
%end

0 commit comments

Comments
 (0)