Txeo v0.1
A Developer-Friendly TensorFlow C++ Wrapper
Abstract base class for machine learning trainers.
#include <Trainer.h>


Public Member Functions
| | Trainer (const Trainer &)=delete | |
| | Trainer (Trainer &&)=delete | |
| Trainer & | operator= (const Trainer &)=delete | |
| Trainer & | operator= (Trainer &&)=delete | |
| virtual | ~Trainer ()=default | |
| | Trainer (txeo::DataTable< T > &&data, txeo::Logger &logger=txeo::LoggerConsole::instance()) | Construct a new Trainer object from a data table. |
| | Trainer (const txeo::DataTable< T > &data) | |
| virtual void | fit (size_t epochs, txeo::LossFunc metric) | Trains the model for a specified number of epochs. |
| virtual void | fit (size_t epochs, txeo::LossFunc metric, size_t patience) | Trains with early stopping based on validation performance. |
| virtual void | fit (size_t epochs, txeo::LossFunc metric, size_t patience, txeo::NormalizationType type) | Trains with early stopping based on validation performance and feature normalization. |
| T | compute_test_loss (txeo::LossFunc metric) const | Evaluates test data based on a specified metric. |
| virtual txeo::Tensor< T > | predict (const txeo::Tensor< T > &input) const =0 | Makes predictions using the trained model (pure virtual). |
| bool | is_trained () const | Checks whether the model has been trained. |
| const txeo::DataTable< T > & | data_table () const | Returns the data table of this trainer. |
| void | enable_feature_norm (txeo::NormalizationType type) | Enable feature normalization. |
| void | disable_feature_norm () | Disable feature normalization. |

Protected Member Functions
| | Trainer ()=default | |
| virtual void | train (size_t epochs, txeo::LossFunc loss_func)=0 | |

Protected Attributes
| bool | _is_trained {false} |
| bool | _is_early_stop {false} |
| size_t | _patience {0} |
| std::unique_ptr< txeo::DataTable< T > > | _data_table |
| txeo::Logger * | _logger {nullptr} |
| txeo::DataTableNorm< T > | _data_table_norm |
| bool | _is_norm_enabled {false} |
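Abstract base class for machine learning trainers.
Template Parameters
| T | Numeric type for tensor elements (e.g., float, double) |
This class provides the core interface for training machine learning models: training with optional early stopping and feature normalization, evaluation on test data, and prediction.
Derived classes must implement the prediction logic (predict) and the training algorithm (train).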
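The member list suggests a template-method design: the public fit() overloads are the entry points, while the algorithm itself comes from the protected pure virtual train() and the public pure virtual predict() that derived classes override. The sketch below illustrates that structure in plain C++ under explicit assumptions: SketchTrainer and ScaleTrainer are hypothetical names, std::vector<T> stands in for txeo::DataTable<T> and txeo::Tensor<T>, the loss-function argument is dropped, and the derived class fits a single weight w in y = w * x by gradient descent. It is not Txeo's actual implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical, simplified stand-in for txeo::Trainer<T>: std::vector<T> replaces
// txeo::DataTable<T>/txeo::Tensor<T>, and the loss-function argument is omitted.
template <typename T>
class SketchTrainer {
public:
  SketchTrainer(const SketchTrainer &) = delete;
  SketchTrainer &operator=(const SketchTrainer &) = delete;
  virtual ~SketchTrainer() = default;

  // Public entry point: runs the derived class's algorithm, then records that training happened.
  virtual void fit(std::size_t epochs) {
    train(epochs);
    _is_trained = true;
  }

  bool is_trained() const { return _is_trained; }

  // Prediction logic supplied by the derived class.
  virtual std::vector<T> predict(const std::vector<T> &input) const = 0;

protected:
  SketchTrainer() = default;
  // Training algorithm supplied by the derived class.
  virtual void train(std::size_t epochs) = 0;
  bool _is_trained{false};
};

// Hypothetical derived trainer: fits the single weight w of y = w * x by gradient descent on MSE.
template <typename T>
class ScaleTrainer : public SketchTrainer<T> {
public:
  ScaleTrainer(std::vector<T> x, std::vector<T> y, T learning_rate)
      : _x(std::move(x)), _y(std::move(y)), _lr(learning_rate) {
    assert(_x.size() == _y.size() && !_x.empty());
  }

  std::vector<T> predict(const std::vector<T> &input) const override {
    if (!this->is_trained())
      throw std::logic_error("predict() called before fit()");
    std::vector<T> out;
    out.reserve(input.size());
    for (const T &v : input)
      out.push_back(_w * v);
    return out;
  }

  T weight() const { return _w; }

protected:
  void train(std::size_t epochs) override {
    const T n = static_cast<T>(_x.size());
    for (std::size_t e = 0; e < epochs; ++e) {
      T grad = 0;
      for (std::size_t i = 0; i < _x.size(); ++i)
        grad += 2 * (_w * _x[i] - _y[i]) * _x[i] / n;  // d(MSE)/dw
      _w -= _lr * grad;
    }
  }

private:
  std::vector<T> _x, _y;
  T _lr;
  T _w{0};
};
```

Keeping train() protected means callers can only reach the algorithm through fit(), so bookkeeping such as the trained flag stays in one place.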
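txeo::Trainer< T >::Trainer (txeo::DataTable< T > &&data, txeo::Logger &logger = txeo::LoggerConsole::instance())  [inline]
Construct a new Trainer object from a data table.
Parameters
| data | Training/Evaluation/Test data. |
Definition at line 45 of file Trainer.h.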
txeo::Trainer< T >::Trainer (const txeo::DataTable< T > &data)  [inline]
txeo::Trainer< T >::Trainer ()  [protected, default]
T txeo::Trainer< T >::compute_test_loss (txeo::LossFunc metric) const
Evaluates test data based on a specified metric.
Parameters
| metric | Loss function type |
Exceptions
| txeo::TrainerError |
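As an illustration of what evaluating test data against a metric involves, here is a minimal sketch that assumes the model's predictions on the test split are already available as a flat vector. LossKind, SketchTrainerError, and compute_test_loss_sketch are hypothetical; the actual enumerators of txeo::LossFunc and the exact conditions under which txeo::TrainerError is thrown are not given on this page (an untrained model and mismatched sizes are assumed here).

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical loss selector; the enumerators of txeo::LossFunc are not listed on this page.
enum class LossKind { MSE, MAE };

// Hypothetical stand-in for txeo::TrainerError.
struct SketchTrainerError : std::runtime_error {
  using std::runtime_error::runtime_error;
};

// Compares predictions on the test split with the test targets and averages the
// per-sample error: squared error for MSE, absolute error for MAE.
template <typename T>
T compute_test_loss_sketch(bool is_trained, const std::vector<T> &predicted,
                           const std::vector<T> &target, LossKind metric) {
  if (!is_trained)
    throw SketchTrainerError("model has not been trained");
  if (predicted.size() != target.size() || target.empty())
    throw SketchTrainerError("prediction/target size mismatch");
  T sum = 0;
  for (std::size_t i = 0; i < target.size(); ++i) {
    const T diff = predicted[i] - target[i];
    sum += (metric == LossKind::MSE) ? diff * diff : std::abs(diff);
  }
  return sum / static_cast<T>(target.size());
}
```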
const txeo::DataTable< T > & txeo::Trainer< T >::data_table () const  [inline]
void txeo::Trainer< T >::disable_feature_norm ()  [inline]
void txeo::Trainer< T >::enable_feature_norm (txeo::NormalizationType type)
Enable feature normalization.
Parameters
| type | Type of normalization |
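Feature normalization usually means rescaling each feature column with statistics computed from the training data, then applying the same transform to every split. The sketch below shows two common schemes, min-max scaling and z-score standardization, on a row-major [samples, features] matrix. NormKind and ColumnNormalizer are hypothetical; whether txeo::NormalizationType offers these exact schemes, and how txeo::DataTableNorm<T> stores its parameters, is assumed rather than documented here.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical normalization kinds; txeo::NormalizationType's enumerators are not listed on this page.
enum class NormKind { MinMax, ZScore };

// Column-wise normalizer for a row-major matrix of shape [samples, features].
// Offsets and scales are learned once (e.g. from the training split) and reused by apply().
template <typename T>
class ColumnNormalizer {
public:
  ColumnNormalizer(const std::vector<T> &data, std::size_t features, NormKind kind)
      : _features(features), _offset(features), _scale(features) {
    assert(features > 0 && !data.empty() && data.size() % features == 0);
    const std::size_t rows = data.size() / features;
    for (std::size_t j = 0; j < features; ++j) {
      if (kind == NormKind::MinMax) {
        T lo = data[j], hi = data[j];
        for (std::size_t i = 1; i < rows; ++i) {
          lo = std::min(lo, data[i * features + j]);
          hi = std::max(hi, data[i * features + j]);
        }
        _offset[j] = lo;
        _scale[j] = hi - lo;
      } else {
        T mean = 0;
        for (std::size_t i = 0; i < rows; ++i)
          mean += data[i * features + j];
        mean /= static_cast<T>(rows);
        T var = 0;
        for (std::size_t i = 0; i < rows; ++i) {
          const T d = data[i * features + j] - mean;
          var += d * d;
        }
        _offset[j] = mean;
        _scale[j] = std::sqrt(var / static_cast<T>(rows));  // population standard deviation
      }
      if (_scale[j] == 0)
        _scale[j] = 1;  // constant column: avoid division by zero
    }
  }

  // Returns (x - offset) / scale for every element, using its column's parameters.
  std::vector<T> apply(const std::vector<T> &data) const {
    assert(data.size() % _features == 0);
    std::vector<T> out(data.size());
    for (std::size_t k = 0; k < data.size(); ++k) {
      const std::size_t j = k % _features;
      out[k] = (data[k] - _offset[j]) / _scale[j];
    }
    return out;
  }

private:
  std::size_t _features;
  std::vector<T> _offset, _scale;
};
```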
void txeo::Trainer< T >::fit (size_t epochs, txeo::LossFunc metric)  [virtual]
Trains the model for a specified number of epochs.
Parameters
| epochs | Number of training iterations |
| metric | Loss function to optimize |
void txeo::Trainer< T >::fit (size_t epochs, txeo::LossFunc metric, size_t patience)  [virtual]
Trains with early stopping based on validation performance.
Parameters
| epochs | Maximum number of training iterations |
| metric | Loss function to optimize |
| patience | Number of epochs to wait without improvement before stopping |
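The patience parameter corresponds to the classic early-stopping loop: after each epoch, evaluate the validation loss, and stop once patience consecutive epochs pass without a new best value. A minimal sketch, assuming one validation evaluation per epoch and a strict less-than improvement test; fit_with_early_stopping and EarlyStopResult are hypothetical names, and the training and validation steps are passed in as callbacks.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <limits>

// Outcome of a hypothetical early-stopping run.
struct EarlyStopResult {
  std::size_t epochs_run;  // epochs actually executed
  std::size_t best_epoch;  // 1-based epoch with the lowest validation loss
  double best_loss;
};

// Runs up to max_epochs, calling train_one_epoch() and then validation_loss() each epoch.
// Stops once `patience` consecutive epochs pass without improving on the best loss.
EarlyStopResult fit_with_early_stopping(std::size_t max_epochs, std::size_t patience,
                                        const std::function<void()> &train_one_epoch,
                                        const std::function<double()> &validation_loss) {
  assert(patience > 0);
  EarlyStopResult r{0, 0, std::numeric_limits<double>::infinity()};
  std::size_t since_improvement = 0;
  for (std::size_t epoch = 1; epoch <= max_epochs; ++epoch) {
    train_one_epoch();
    r.epochs_run = epoch;
    const double loss = validation_loss();
    if (loss < r.best_loss) {
      r.best_loss = loss;
      r.best_epoch = epoch;
      since_improvement = 0;
    } else if (++since_improvement >= patience) {
      break;  // no improvement for `patience` epochs in a row
    }
  }
  return r;
}
```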
void txeo::Trainer< T >::fit (size_t epochs, txeo::LossFunc metric, size_t patience, txeo::NormalizationType type)  [virtual]
Trains with early stopping based on validation performance and feature normalization.
Parameters
| epochs | Maximum number of training iterations |
| metric | Loss function to optimize |
| patience | Number of epochs to wait without improvement before stopping |
| type | Type of normalization |
bool txeo::Trainer< T >::is_trained () const  [inline]
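txeo::Tensor< T > txeo::Trainer< T >::predict (const txeo::Tensor< T > &input) const  [pure virtual]
Makes predictions using the trained model (pure virtual).
Parameters
| input | Input tensor for prediction (shape: [samples, features]) |
Exceptions
| txeo::TrainerError |
Implemented in txeo::OlsGDTrainer< T >.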
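The input shape [samples, features] suggests a row-major matrix with one example per row. Since predict() is implemented by txeo::OlsGDTrainer< T >, the sketch below assumes a linear model Y = X·W + b of the kind an ordinary-least-squares trainer fits; LinearModelSketch and its fields are hypothetical, and std::vector<T> stands in for txeo::Tensor<T>.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical linear model Y = X * W + b.
// X is row-major [samples, features]; W is row-major [features, outputs]; b has `outputs` entries.
template <typename T>
struct LinearModelSketch {
  std::size_t features, outputs;
  std::vector<T> W, b;

  // Returns a row-major [samples, outputs] matrix of predictions.
  std::vector<T> predict(const std::vector<T> &X) const {
    assert(W.size() == features * outputs && b.size() == outputs);
    if (features == 0 || X.size() % features != 0)
      throw std::invalid_argument("input is not shaped [samples, features]");
    const std::size_t samples = X.size() / features;
    std::vector<T> Y(samples * outputs);
    for (std::size_t i = 0; i < samples; ++i)
      for (std::size_t k = 0; k < outputs; ++k) {
        T acc = b[k];
        for (std::size_t j = 0; j < features; ++j)
          acc += X[i * features + j] * W[j * outputs + k];
        Y[i * outputs + k] = acc;
      }
    return Y;
  }
};
```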
void txeo::Trainer< T >::train (size_t epochs, txeo::LossFunc loss_func)  [protected, pure virtual]