Txeo v0.1
A Developer-Friendly TensorFlow C++ Wrapper
Predictor.h
#ifndef PREDICTOR_H
#define PREDICTOR_H

#pragma once

#include "txeo/Tensor.h"
#include "txeo/TensorShape.h"

#include <cstddef>
#include <filesystem>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

namespace txeo {

/// @brief Bundle of device information.
struct DeviceInfo {
  /// @brief Device name.
  std::string name{};

  /// @brief Device type (CPU or GPU).
  std::string device_type{};

  /// @brief Memory limit in bytes.
  size_t memory_limit{};
};

/// @brief Class that deals with the main tasks of prediction (inference).
template <typename T = float>
class Predictor {
  public:
    using TensorInfo = std::vector<std::pair<std::string, txeo::TensorShape>>;
    using TensorIdent = std::vector<std::pair<std::string, txeo::Tensor<T>>>;

    explicit Predictor() = delete;
    Predictor(const Predictor &) = delete;
    Predictor(Predictor &&) = delete;
    Predictor &operator=(const Predictor &) = delete;
    Predictor &operator=(Predictor &&) = delete;

    /// @brief Constructs a Predictor from a TensorFlow SavedModel directory.
    explicit Predictor(std::filesystem::path model_path);

    /// @brief Returns the input tensor metadata for the loaded model.
    [[nodiscard]] const TensorInfo &get_input_metadata() const noexcept;

    /// @brief Returns the output tensor metadata for the loaded model.
    [[nodiscard]] const TensorInfo &get_output_metadata() const noexcept;

    /// @brief Returns the shape of a specific input tensor, looked up by name.
    [[nodiscard]] std::optional<txeo::TensorShape>
    get_input_metadata_shape(const std::string &name) const;

    /// @brief Returns the shape of a specific output tensor, looked up by name.
    [[nodiscard]] std::optional<txeo::TensorShape>
    get_output_metadata_shape(const std::string &name) const;

    /// @brief Returns the available compute devices.
    [[nodiscard]] std::vector<DeviceInfo> get_devices() const;

    /// @brief Performs single-input/single-output inference.
    [[nodiscard]] txeo::Tensor<T> predict(const txeo::Tensor<T> &input) const;

    /// @brief Performs batch inference with multiple named inputs.
    [[nodiscard]] std::vector<txeo::Tensor<T>> predict_batch(const TensorIdent &inputs) const;

    /// @brief Enables/disables XLA (Accelerated Linear Algebra) compilation.
    void enable_xla(bool enable);

  private:
    struct Impl;
    std::unique_ptr<Impl> _impl{nullptr};

    void load_model();
};

/// @brief Exception type reported by Predictor operations.
class PredictorError : public std::runtime_error {
  public:
    using std::runtime_error::runtime_error;
};

} // namespace txeo

#endif // PREDICTOR_H
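As a usage illustration (not taken from the generated documentation), the sketch below loads a SavedModel and inspects its metadata using only members declared in this header. The include path txeo/Predictor.h, the model directory, and the tensor name passed to get_input_metadata_shape are placeholders; the example also assumes that loading failures surface as exceptions derived from std::runtime_error, which PredictorError suggests but the header does not state.

#include "txeo/Predictor.h"

#include <exception>
#include <iostream>

int main() {
  try {
    // Load a TensorFlow SavedModel; the directory path is a placeholder.
    txeo::Predictor<float> predictor{"path/to/saved_model"};

    // List the model's named inputs and outputs.
    for (const auto &entry : predictor.get_input_metadata())
      std::cout << "input:  " << entry.first << '\n';
    for (const auto &entry : predictor.get_output_metadata())
      std::cout << "output: " << entry.first << '\n';

    // Look up the shape of one input by name (placeholder tensor name).
    if (predictor.get_input_metadata_shape("serving_default_input:0"))
      std::cout << "found shape for the requested input tensor\n";

    // Enumerate the compute devices visible to the predictor.
    for (const auto &device : predictor.get_devices())
      std::cout << device.device_type << ' ' << device.name
                << " (limit: " << device.memory_limit << " bytes)\n";
  } catch (const std::exception &e) {
    // PredictorError derives from std::runtime_error, so it is caught here.
    std::cerr << "error: " << e.what() << '\n';
    return 1;
  }
  return 0;
}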
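A second sketch covers inference itself. The txeo::TensorShape and txeo::Tensor constructors used here (a shape built from a dimension list, a tensor built from a shape) are assumed for illustration only and should be checked against Tensor.h and TensorShape.h; the tensor name in the batch call and the input dimensions are likewise placeholders.

#include "txeo/Predictor.h"
#include "txeo/Tensor.h"
#include "txeo/TensorShape.h"

#include <iostream>
#include <vector>

int main() {
  txeo::Predictor<float> predictor{"path/to/saved_model"};

  // Optionally request XLA compilation before running inference.
  predictor.enable_xla(true);

  // Single input / single output: one tensor in, one tensor out.
  // NOTE: these Tensor/TensorShape constructors are assumed; see Tensor.h.
  txeo::Tensor<float> input{txeo::TensorShape{{1, 224, 224, 3}}};
  txeo::Tensor<float> output = predictor.predict(input);

  // Batch inference: pair each named model input with a tensor.
  // The tensor name is a placeholder; assumes Tensor<float> is movable.
  txeo::Predictor<float>::TensorIdent named_inputs;
  named_inputs.emplace_back("serving_default_input:0",
                            txeo::Tensor<float>{txeo::TensorShape{{1, 224, 224, 3}}});
  std::vector<txeo::Tensor<float>> outputs = predictor.predict_batch(named_inputs);

  std::cout << "batch produced " << outputs.size() << " output tensor(s)\n";
  return 0;
}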