44template <
typename T =
float>
47 using TensorInfo = std::vector<std::pair<std::string, txeo::TensorShape>>;
48 using TensorIdent = std::vector<std::pair<std::string, txeo::Tensor<T>>>;
87 explicit Predictor(std::filesystem::path model_path);
214 std::unique_ptr<Impl> _impl{
nullptr};
221 using std::runtime_error::runtime_error;
Class that handles the main prediction (inference) tasks.
txeo::Tensor< T > predict(const txeo::Tensor< T > &input) const
Perform single input/single output inference.
Predictor & operator=(const Predictor &)=delete
Predictor(Predictor &&)=delete
const TensorInfo & get_input_metadata() const noexcept
Returns the input tensor metadata for the loaded model.
Predictor & operator=(Predictor &&)=delete
std::optional< txeo::TensorShape > get_output_metadata_shape(const std::string &name) const
Returns the shape of a specific output tensor, looked up by name.
void enable_xla(bool enable)
Enable/disable XLA (Accelerated Linear Algebra) compilation.
const TensorInfo & get_output_metadata() const noexcept
Returns the output tensor metadata for the loaded model.
Predictor(const Predictor &)=delete
std::vector< txeo::Tensor< T > > predict_batch(const TensorIdent &inputs) const
Perform batch inference with multiple named inputs.
std::vector< std::pair< std::string, txeo::Tensor< T > > > TensorIdent
std::vector< DeviceInfo > get_devices() const
Returns the available compute devices.
std::vector< std::pair< std::string, txeo::TensorShape > > TensorInfo
std::optional< txeo::TensorShape > get_input_metadata_shape(const std::string &name) const
Returns the shape of a specific input tensor, looked up by name.
Predictor(std::filesystem::path model_path)
Constructs a Predictor from a TensorFlow SavedModel directory.
Implements the mathematical concept of a tensor, which is a magnitude of multiple orders...
The shape of a tensor is an ordered collection of dimensions of mathematical vector spaces.
Bundle of device information.
std::string name
Device name.
size_t memory_limit
Memory limit in bytes.
std::string device_type
Device type (CPU or GPU).