Txeo v0.1
A Developer-Friendly TensorFlow C++ Wrapper
Loading...
Searching...
No Matches
txeo::Trainer< T > Class Template Reference (abstract)

Abstract base class for machine learning trainers. More...

#include <Trainer.h>

Inheritance diagram for txeo::Trainer< T >:
Inheritance graph
Collaboration diagram for txeo::Trainer< T >:
Collaboration graph

Public Member Functions

 Trainer (const Trainer &)=delete
 
 Trainer (Trainer &&)=delete
 
Trainer & operator= (const Trainer &)=delete
 
Trainer & operator= (Trainer &&)=delete
 
virtual ~Trainer ()=default
 
 Trainer (txeo::DataTable< T > &&data, txeo::Logger &logger=txeo::LoggerConsole::instance())
 Construct a new Trainer object from a data table.
 
 Trainer (const txeo::DataTable< T > &data)
 Construct a new Trainer object from a copy (clone) of a data table.
 
virtual void fit (size_t epochs, txeo::LossFunc metric)
 Trains the model for specified number of epochs.
 
virtual void fit (size_t epochs, txeo::LossFunc metric, size_t patience)
 Trains with early stopping based on validation performance.
 
virtual void fit (size_t epochs, txeo::LossFunc metric, size_t patience, txeo::NormalizationType type)
 Trains with early stopping based on validation performance and feature normalization.
 
T compute_test_loss (txeo::LossFunc metric) const
 Evaluates test data based on a specified metric.
 
virtual txeo::Tensor< T > predict (const txeo::Tensor< T > &input) const =0
 Makes predictions using the trained model (pure virtual)
 
bool is_trained () const
 Checks if model has been trained.
 
const txeo::DataTable< T > & data_table () const
 Returns the data table of this trainer.
 
void enable_feature_norm (txeo::NormalizationType type)
 Enable feature normalization.
 
void disable_feature_norm ()
 Disable feature normalization.
 

Protected Member Functions

 Trainer ()=default
 
virtual void train (size_t epochs, txeo::LossFunc loss_func)=0
 

Protected Attributes

bool _is_trained {false}
 
bool _is_early_stop {false}
 
size_t _patience {0}
 
std::unique_ptr< txeo::DataTable< T > > _data_table
 
txeo::Logger * _logger {nullptr}
 
txeo::DataTableNorm< T > _data_table_norm
 
bool _is_norm_enabled {false}
 

Detailed Description

template<typename T>
class txeo::Trainer< T >

Abstract base class for machine learning trainers.

Template Parameters
T  Numeric type for tensor elements (e.g., float, double)

This class provides the core interface for training machine learning models with:

  • Training/evaluation data management
  • Common training parameters
  • Basic training lifecycle control

Derived classes must implement the prediction logic and training algorithm.

Definition at line 32 of file Trainer.h.

Constructor & Destructor Documentation

◆ Trainer() [1/5]

template<typename T >
txeo::Trainer< T >::Trainer ( const Trainer< T > &  )
delete

◆ Trainer() [2/5]

template<typename T >
txeo::Trainer< T >::Trainer ( Trainer< T > &&  )
delete

◆ ~Trainer()

template<typename T >
virtual txeo::Trainer< T >::~Trainer ( )
virtual default

◆ Trainer() [3/5]

template<typename T >
txeo::Trainer< T >::Trainer ( txeo::DataTable< T > &&  data,
txeo::Logger & logger = txeo::LoggerConsole::instance() 
)
inline

Construct a new Trainer object from a data table.

Parameters
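data  Training/Evaluation/Test data.
logger  Logger instance (defaults to txeo::LoggerConsole::instance()). The trainer stores a pointer to it (_logger{&logger}), so the logger must remain valid for the trainer's whole lifetime.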

Definition at line 45 of file Trainer.h.

46 : _data_table{std::make_unique<txeo::DataTable<T>>(std::move(data))}, _logger{&logger} {};
txeo::Logger * _logger
Definition Trainer.h:136
std::unique_ptr< txeo::DataTable< T > > _data_table
Definition Trainer.h:135

◆ Trainer() [4/5]

template<typename T >
txeo::Trainer< T >::Trainer ( const txeo::DataTable< T > &  data)
inline

Definition at line 48 of file Trainer.h.

48: Trainer{data.clone()} {};
DataTable< T > clone() const
Trainer()=default

◆ Trainer() [5/5]

template<typename T >
txeo::Trainer< T >::Trainer ( )
protected default

Member Function Documentation

◆ compute_test_loss()

template<typename T >
T txeo::Trainer< T >::compute_test_loss ( txeo::LossFunc  metric) const

Evaluates test data based on a specified metric.

Exceptions
txeo::TrainerError

Parameters
metric  Loss function type
Returns
T Loss value

◆ data_table()

template<typename T >
const txeo::DataTable< T > & txeo::Trainer< T >::data_table ( ) const
inline

Returns the data table of this trainer.

Returns
const txeo::DataTable<T>&

Definition at line 112 of file Trainer.h.

112{ return *_data_table; }

◆ disable_feature_norm()

template<typename T >
void txeo::Trainer< T >::disable_feature_norm ( )
inline

Disable feature normalization.

Parameters
None. This function takes no arguments; it only sets _is_norm_enabled to false.

Definition at line 126 of file Trainer.h.

126{ _is_norm_enabled = false; };
bool _is_norm_enabled
Definition Trainer.h:139

◆ enable_feature_norm()

template<typename T >
void txeo::Trainer< T >::enable_feature_norm ( txeo::NormalizationType  type)

Enable feature normalization.

Parameters
type  Type of normalization

◆ fit() [1/3]

template<typename T >
virtual void txeo::Trainer< T >::fit ( size_t  epochs,
txeo::LossFunc  metric 
)
virtual

Trains the model for specified number of epochs.

Parameters
epochs  Number of training iterations
metric  Loss function to optimize

◆ fit() [2/3]

template<typename T >
virtual void txeo::Trainer< T >::fit ( size_t  epochs,
txeo::LossFunc  metric,
size_t  patience 
)
virtual

Trains with early stopping based on validation performance.

Parameters
epochs  Maximum number of training iterations
metric  Loss function to optimize
patience  Number of epochs to wait without improvement before stopping

◆ fit() [3/3]

template<typename T >
virtual void txeo::Trainer< T >::fit ( size_t  epochs,
txeo::LossFunc  metric,
size_t  patience,
txeo::NormalizationType  type 
)
virtual

Trains with early stopping based on validation performance and feature normalization.

Parameters
epochs  Maximum number of training iterations
metric  Loss function to optimize
patience  Number of epochs to wait without improvement before stopping
type  Type of normalization

◆ is_trained()

template<typename T >
bool txeo::Trainer< T >::is_trained ( ) const
inline

Checks if model has been trained.

Returns
true if training has been completed, false otherwise

Definition at line 105 of file Trainer.h.

105{ return _is_trained; }
bool _is_trained
Definition Trainer.h:131

◆ operator=() [1/2]

template<typename T >
Trainer & txeo::Trainer< T >::operator= ( const Trainer< T > &  )
delete

◆ operator=() [2/2]

template<typename T >
Trainer & txeo::Trainer< T >::operator= ( Trainer< T > &&  )
delete

◆ predict()

template<typename T >
virtual txeo::Tensor< T > txeo::Trainer< T >::predict ( const txeo::Tensor< T > &  input) const
pure virtual

Makes predictions using the trained model (pure virtual)

Parameters
input  Input tensor for prediction (shape: [samples, features])
Returns
Prediction tensor (shape: [samples, outputs])
Exceptions
txeo::TrainerError
Note
Must be implemented in derived classes

Implemented in txeo::OlsGDTrainer< T >.

◆ train()

template<typename T >
virtual void txeo::Trainer< T >::train ( size_t  epochs,
txeo::LossFunc  loss_func 
)
protected pure virtual

Member Data Documentation

◆ _data_table

template<typename T >
std::unique_ptr<txeo::DataTable<T> > txeo::Trainer< T >::_data_table
protected

Definition at line 135 of file Trainer.h.

◆ _data_table_norm

template<typename T >
txeo::DataTableNorm<T> txeo::Trainer< T >::_data_table_norm
protected

Definition at line 138 of file Trainer.h.

◆ _is_early_stop

template<typename T >
bool txeo::Trainer< T >::_is_early_stop {false}
protected

Definition at line 132 of file Trainer.h.

132{false};

◆ _is_norm_enabled

template<typename T >
bool txeo::Trainer< T >::_is_norm_enabled {false}
protected

Definition at line 139 of file Trainer.h.

139{false};

◆ _is_trained

template<typename T >
bool txeo::Trainer< T >::_is_trained {false}
protected

Definition at line 131 of file Trainer.h.

131{false};

◆ _logger

template<typename T >
txeo::Logger* txeo::Trainer< T >::_logger {nullptr}
protected

Definition at line 136 of file Trainer.h.

136{nullptr};

◆ _patience

template<typename T >
size_t txeo::Trainer< T >::_patience {0}
protected

Definition at line 133 of file Trainer.h.

133{0};

The documentation for this class was generated from the following file:

  • Trainer.h