Txeo v0.1
A Developer-Friendly TensorFlow C++ Wrapper
Loading...
Searching...
No Matches
Predictor.h
Go to the documentation of this file.
1#ifndef PREDICTOR_H
2#define PREDICTOR_H
3
4#pragma once
5
6#include "txeo/Logger.h"
8#include "txeo/Tensor.h"
9#include "txeo/TensorShape.h"
10#include "types.h"
11
12#include <filesystem>
13#include <optional>
14#include <string>
15
16namespace txeo {
17
// Class that deals with the main tasks of prediction (inference).
// T is the element type of the txeo::Tensor objects used by this class
// (see TensorIdent below); it defaults to float.
23template <typename T = float>
24class Predictor {
25 public:
// List of (name, shape) pairs; the return type of the input/output metadata accessors.
26 using TensorInfo = std::vector<std::pair<std::string, txeo::TensorShape>>;
// List of (name, tensor) pairs; the argument type of predict_batch.
27 using TensorIdent = std::vector<std::pair<std::string, txeo::Tensor<T>>>;
28
// Default construction, copy construction, move construction and copy
// assignment are all explicitly deleted.
29 explicit Predictor() = delete;
30 Predictor(const Predictor &) = delete;
31 Predictor(Predictor &&) = delete;
32 Predictor &operator=(const Predictor &) = delete;
35
// Constructs a Predictor from a TensorFlow SavedModel directory.
// NOTE(review): this declaration and listing lines 68-177 are cut off in this
// view. The member index on the same page gives the full signature as
// Predictor(std::filesystem::path model_path,
//           txeo::Logger &logger = txeo::LoggerConsole::instance()).
66 explicit Predictor(std::filesystem::path model_path,
68
83
96
112
127
141
158
177
// Enable/disable XLA (Accelerated Linear Algebra) compilation.
190 void enable_xla(bool enable);
191
192 private:
// Implementation type, declared here and defined elsewhere; held through _impl.
193 struct Impl;
194 std::unique_ptr<Impl> _impl{nullptr};
// NOTE(review): presumably points at the logger passed to the constructor;
// it has no in-class initializer, so confirm the constructor always sets it.
195 txeo::Logger *_logger;
196
// NOTE(review): presumably loads the model given to the constructor - confirm
// against the implementation file.
197 void load_model();
198};
199
// Exception type declared alongside Predictor in namespace txeo. It derives
// from std::runtime_error and inherits all of its constructors, so it is
// constructed from a message in the same way as std::runtime_error.
// NOTE(review): presumably thrown by Predictor operations on failure - confirm
// against the implementation.
200class PredictorError : public std::runtime_error {
201 public:
202 using std::runtime_error::runtime_error;
203};
204
205} // namespace txeo
206#endif
A container for managing training, evaluation, and test data splits.
Definition DataTable.h:24
static LoggerConsole & instance()
Access the singleton instance.
Abstract base class for logging subsystems.
Definition Logger.h:45
Class that deals with the main tasks of prediction (inference)
Definition Predictor.h:24
txeo::Tensor< T > predict(const txeo::Tensor< T > &input) const
Perform single input/single output inference.
Predictor & operator=(const Predictor &)=delete
Predictor(Predictor &&)=delete
const TensorInfo & get_input_metadata() const noexcept
Returns the input tensor metadata for the loaded model.
Predictor & operator=(Predictor &&)=delete
std::optional< txeo::TensorShape > get_output_metadata_shape(const std::string &name) const
Get shape for specific output tensor by name.
void enable_xla(bool enable)
Enable/disable XLA (Accelerated Linear Algebra) compilation.
const TensorInfo & get_output_metadata() const noexcept
Returns the output tensor metadata for the loaded model.
Predictor(const Predictor &)=delete
Predictor()=delete
std::vector< txeo::Tensor< T > > predict_batch(const TensorIdent &inputs) const
Perform batch inference with multiple named inputs.
std::vector< std::pair< std::string, txeo::Tensor< T > > > TensorIdent
Definition Predictor.h:27
std::vector< DeviceInfo > get_devices() const
Returns the available compute devices.
std::vector< std::pair< std::string, txeo::TensorShape > > TensorInfo
Definition Predictor.h:26
std::optional< txeo::TensorShape > get_input_metadata_shape(const std::string &name) const
Returns shape for specific input tensor by name.
Predictor(std::filesystem::path model_path, txeo::Logger &logger=txeo::LoggerConsole::instance())
Constructs a Predictor from a TensorFlow SavedModel directory.
Implements the mathematical concept of a tensor, which is a magnitude of multiple order.
Definition Tensor.h:50
The shape of a tensor is an ordered collection of dimensions of mathematical vector spaces.
Definition TensorShape.h:31
Bundle of device information.
Definition types.h:14