Txeo v0.1
A Developer-Friendly TensorFlow C++ Wrapper
Class that handles the main tasks of prediction (inference).
#include <Predictor.h>
Public Types

  using TensorInfo  = std::vector<std::pair<std::string, txeo::TensorShape>>
  using TensorIdent = std::vector<std::pair<std::string, txeo::Tensor<T>>>
Public Member Functions

  Predictor() = delete
  Predictor(const Predictor &) = delete
  Predictor(Predictor &&) = delete
  Predictor &operator=(const Predictor &) = delete
  Predictor &operator=(Predictor &&) = delete
  ~Predictor()

  explicit Predictor(std::filesystem::path model_path)
      Constructs a Predictor from a TensorFlow SavedModel directory.

  const TensorInfo &get_input_metadata() const noexcept
      Returns the input tensor metadata for the loaded model.

  const TensorInfo &get_output_metadata() const noexcept
      Returns the output tensor metadata for the loaded model.

  std::optional<txeo::TensorShape> get_input_metadata_shape(const std::string &name) const
      Returns the shape of a specific input tensor, looked up by name.

  std::optional<txeo::TensorShape> get_output_metadata_shape(const std::string &name) const
      Returns the shape of a specific output tensor, looked up by name.

  std::vector<DeviceInfo> get_devices() const
      Returns the available compute devices.

  txeo::Tensor<T> predict(const txeo::Tensor<T> &input) const
      Performs single-input/single-output inference.

  std::vector<txeo::Tensor<T>> predict_batch(const TensorIdent &inputs) const
      Performs batch inference with multiple named inputs.

  void enable_xla(bool enable)
      Enables or disables XLA (Accelerated Linear Algebra) compilation.
Detailed Description

Class that handles the main tasks of prediction (inference).

Template Parameters
  T  Specifies the data type of the model involved

Definition at line 45 of file Predictor.h.
Member Typedef Documentation

using txeo::Predictor<T>::TensorIdent = std::vector<std::pair<std::string, txeo::Tensor<T>>>

Definition at line 48 of file Predictor.h.

using txeo::Predictor<T>::TensorInfo = std::vector<std::pair<std::string, txeo::TensorShape>>

Definition at line 47 of file Predictor.h.
Constructor & Destructor Documentation

txeo::Predictor<T>::Predictor() = delete  (explicit)

txeo::Predictor<T>::Predictor(const Predictor &) = delete

txeo::Predictor<T>::Predictor(Predictor &&) = delete

txeo::Predictor<T>::~Predictor()
explicit txeo::Predictor<T>::Predictor(std::filesystem::path model_path)

Constructs a Predictor from a TensorFlow SavedModel directory.

The directory must contain a valid SavedModel (typically with a .pb file). For best performance, use models with frozen weights.

Parameters
  model_path  Path to the directory of the .pb saved model

Exceptions
  PredictorError
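A minimal loading sketch. The directory path and the float element type below are assumptions; point the path at a real SavedModel directory and use the dtype your model was exported with:

```cpp
#include <iostream>
#include <Predictor.h>

int main() {
  // Hypothetical path; must name a directory containing saved_model.pb.
  txeo::Predictor<float> predictor{"models/regression"};

  // List the input names exposed by the model's signature.
  for (const auto &entry : predictor.get_input_metadata())
    std::cout << "model input: " << entry.first << '\n';
}
```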
Member Function Documentation

void txeo::Predictor<T>::enable_xla(bool enable)

Enables or disables XLA (Accelerated Linear Algebra) compilation.

Parameters
  enable  Whether to enable XLA optimizations
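A toggling sketch; whether XLA compilation actually speeds up inference depends on the model and the available hardware:

```cpp
#include <Predictor.h>

void configure(txeo::Predictor<float> &predictor) {
  // Request XLA JIT compilation of the model graph; pass false to
  // fall back to the standard TensorFlow executor.
  predictor.enable_xla(true);
}
```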
std::vector<DeviceInfo> txeo::Predictor<T>::get_devices() const
Returns the available compute devices.
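A sketch that checks what TensorFlow can see before running inference. The fields of DeviceInfo are not documented on this page, so only the device count is used here:

```cpp
#include <iostream>
#include <Predictor.h>

void report_devices(const txeo::Predictor<float> &predictor) {
  auto devices = predictor.get_devices();
  std::cout << "visible compute devices: " << devices.size() << '\n';
}
```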
const TensorInfo &txeo::Predictor<T>::get_input_metadata() const noexcept

Returns the input tensor metadata for the loaded model.
std::optional<txeo::TensorShape> txeo::Predictor<T>::get_input_metadata_shape(const std::string &name) const

Returns the shape of a specific input tensor, looked up by name.

Parameters
  name  Tensor name from the model signature
const TensorInfo &txeo::Predictor<T>::get_output_metadata() const noexcept

Returns the output tensor metadata for the loaded model.
std::optional<txeo::TensorShape> txeo::Predictor<T>::get_output_metadata_shape(const std::string &name) const

Returns the shape of a specific output tensor, looked up by name.

Parameters
  name  Tensor name from the model signature
Predictor &txeo::Predictor<T>::operator=(const Predictor &) = delete

Predictor &txeo::Predictor<T>::operator=(Predictor &&) = delete
txeo::Tensor<T> txeo::Predictor<T>::predict(const txeo::Tensor<T> &input) const

Performs single-input/single-output inference.

Parameters
  input  Input tensor matching the model's expected shape

Exceptions
  PredictorError
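A usage sketch; the helper name is hypothetical, and the input tensor must already have the shape reported by get_input_metadata():

```cpp
#include <Predictor.h>

// Forwards one tensor through the model; predict throws PredictorError
// if the input does not match the model signature.
txeo::Tensor<float> infer(const txeo::Predictor<float> &predictor,
                          const txeo::Tensor<float> &input) {
  return predictor.predict(input);
}
```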
std::vector<txeo::Tensor<T>> txeo::Predictor<T>::predict_batch(const TensorIdent &inputs) const

Performs batch inference with multiple named inputs.

Parameters
  inputs  Vector of (name, tensor) pairs

Exceptions
  PredictorError
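A multi-input sketch; the tensor names below are hypothetical and must match the names reported by get_input_metadata():

```cpp
#include <Predictor.h>
#include <utility>
#include <vector>

std::vector<txeo::Tensor<float>>
infer_two_inputs(const txeo::Predictor<float> &predictor,
                 txeo::Tensor<float> features, txeo::Tensor<float> mask) {
  // TensorIdent is a vector of (name, tensor) pairs.
  txeo::Predictor<float>::TensorIdent inputs;
  inputs.emplace_back("serving_default_features", std::move(features));
  inputs.emplace_back("serving_default_mask", std::move(mask));
  return predictor.predict_batch(inputs);
}
```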