tensorflow-core-framework-function_testlib.h 2019-06-10 334 tensorflow-core-framework ```cpp #ifndef TENSORFLOW_CORE_FRAMEWORK_FUNCTION_TESTLIB_H_ #define TENSORFLOW_CORE_FRAMEWORK_FUNCTION_TESTLIB_H_ #include #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { namespace test { namespace function { // A helper class to make AttrSlice from initializer lists class Attrs { public: Attrs(const std::initializer_list< // NOLINT(runtime/explicit) std::pair>& attrs) { for (const auto& aval : attrs) { map_.insert({aval.first, aval.second.proto}); } } operator AttrSlice() { return AttrSlice(&map_); } // NOLINT(runtime/explicit) private: AttrValueMap map_; }; // Helper to construct a NodeDef. NodeDef NDef( StringPiece name, StringPiece op, gtl::ArraySlice inputs, gtl::ArraySlice> attrs = {}, const string& device = ""); // Helper to construct a GraphDef proto. GraphDef GDef(gtl::ArraySlice nodes, gtl::ArraySlice funcs = {}); // For testing convenience, we provide a few simple functions that can // be easily executed and tested. // x:T -> x * 2. FunctionDef XTimesTwo(); // x:T -> cpu(x * 2) + cpu(x * 3). FunctionDef TwoDeviceTimesFive(); // x:T -> cpu(x * 2), gpu(x * 3). FunctionDef TwoDeviceMult(); // cpu(x):T, gpu(y):T -> cpu(x * 2), gpu(y * 3). FunctionDef TwoDeviceInputOutput(); // Function taking a list of Tensors as input. FunctionDef FuncWithListInput(); // Function returning a list of Tensors as output. FunctionDef FuncWithListOutput(); // x:T -> x + x. FunctionDef XAddX(); // x: T, y:T -> x + y. FunctionDef XAddY(); // x:T -> x * 2, where x is int32. FunctionDef XTimesTwoInt32(); // x:T -> (x * 2) * 2. FunctionDef XTimesFour(); // x:T -> ((x * 2) * 2) * 2. FunctionDef XTimes16(); // w:T, x:T, b:T -> MatMul(w, x) + b FunctionDef WXPlusB(); // x:T -> x:T, T is a type which we automatically converts to a bool. FunctionDef NonZero(); // x: T -> bool. FunctionDef IsZero(); // x: T -> int64 FunctionDef RandomUniform(); // x:T, y:T -> y:T, x:T FunctionDef Swap(); // x:T, y:T -> y:T, x:T, the body has no nodes. FunctionDef EmptyBodySwap(); // x:float, y:resource -> y:resource, 2*x:float. FunctionDef ResourceOutput(); // x:resource -> x:resource FunctionDef ResourceIdentity(); // x:resource -> y:float. FunctionDef ReadResourceVariable(); // Contains malformed control flow which can't be run by the executor. FunctionDef InvalidControlFlow(); // x:T -> x <= N. FunctionDef LessThanOrEqualToN(int64 N); // x:T, y:T -> x+1, x*y FunctionDef XPlusOneXTimesY(); // x:T, y:T -> x <= N FunctionDef XYXLessThanOrEqualToN(int64 N); // x: T -> bool FunctionDef RandomUniformLess(); // x:T -> y: TensorSliceDatasetOp::Dataset FunctionDef MakeTensorSliceDataset(); // x:T -> y: T, idx: out_idx FunctionDef Unique(); void FunctionTestSchedClosure(std::function fn); } // end namespace function } // end namespace test } // end namespace tensorflow #endif // TENSORFLOW_CORE_FRAMEWORK_FUNCTION_TESTLIB_H_ ``` 本文链接: http://codeeyes.net/archives/tensorflow-core-framework-function_testlib_h.html