Txeo v0.1
A Developer-Friendly TensorFlow C++ Wrapper
Abstract base class for machine learning trainers.
#include <Trainer.h>
Public Member Functions
  Trainer (const Trainer &)=delete
  Trainer (Trainer &&)=delete
  Trainer & operator= (const Trainer &)=delete
  Trainer & operator= (Trainer &&)=delete
  virtual ~Trainer ()=default
  Trainer (txeo::DataTable< T > &&data, txeo::Logger &logger=txeo::LoggerConsole::instance())
      Construct a new Trainer object from a data table.
  Trainer (const txeo::DataTable< T > &data)
  virtual void fit (size_t epochs, txeo::LossFunc metric)
      Trains the model for the specified number of epochs.
  virtual void fit (size_t epochs, txeo::LossFunc metric, size_t patience)
      Trains with early stopping based on validation performance.
  virtual void fit (size_t epochs, txeo::LossFunc metric, size_t patience, txeo::NormalizationType type)
      Trains with early stopping based on validation performance and feature normalization.
  T compute_test_loss (txeo::LossFunc metric) const
      Evaluates the test data using the specified metric.
  virtual txeo::Tensor< T > predict (const txeo::Tensor< T > &input) const =0
      Makes predictions using the trained model (pure virtual).
  bool is_trained () const
      Checks whether the model has been trained.
  const txeo::DataTable< T > & data_table () const
      Returns the data table of this trainer.
  void enable_feature_norm (txeo::NormalizationType type)
      Enable feature normalization.
  void disable_feature_norm ()
      Disable feature normalization.
Protected Member Functions
  Trainer ()=default
  virtual void train (size_t epochs, txeo::LossFunc loss_func)=0
Protected Attributes
  bool _is_trained {false}
  bool _is_early_stop {false}
  size_t _patience {0}
  std::unique_ptr< txeo::DataTable< T > > _data_table
  txeo::Logger * _logger {nullptr}
  txeo::DataTableNorm< T > _data_table_norm
  bool _is_norm_enabled {false}
Detailed Description
Abstract base class for machine learning trainers.
Template Parameters
  T  Numeric type for tensor elements (e.g., float, double)
This class provides the core interface for training machine learning models with:
Derived classes must implement the prediction logic (predict) and the training algorithm (train).
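The split between the public fit overloads and the pure virtual train/predict members follows the template-method pattern: the base class owns the workflow, and subclasses supply the algorithm. The sketch below illustrates that shape in isolation. It is not Txeo code; TrainerSketch and MeanTrainer are hypothetical names, and std::vector stands in for txeo::DataTable and txeo::Tensor.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical stand-in for txeo::Trainer<T> (not Txeo's API): the base class
// drives training, derived classes supply train() and predict().
template <typename T>
class TrainerSketch {
 public:
  virtual ~TrainerSketch() = default;

  // Public entry point: runs the derived training algorithm, then records
  // that the model is trained (mirrors the _is_trained flag).
  void fit(std::size_t epochs) {
    train(epochs);
    _is_trained = true;
  }

  bool is_trained() const { return _is_trained; }

  virtual std::vector<T> predict(const std::vector<T>& input) const = 0;

 protected:
  explicit TrainerSketch(std::vector<T> targets) : _targets(std::move(targets)) {}
  virtual void train(std::size_t epochs) = 0;

  std::vector<T> _targets;
  bool _is_trained{false};
};

// Hypothetical toy trainer: "learns" the mean of the targets and predicts it
// for every input sample.
template <typename T>
class MeanTrainer : public TrainerSketch<T> {
 public:
  explicit MeanTrainer(std::vector<T> targets) : TrainerSketch<T>(std::move(targets)) {}

  std::vector<T> predict(const std::vector<T>& input) const override {
    if (!this->is_trained()) throw std::logic_error("model has not been trained");
    return std::vector<T>(input.size(), _mean);
  }

 protected:
  void train(std::size_t /*epochs*/) override {
    // A closed-form model needs no iterations, so epochs is ignored here.
    T sum = std::accumulate(this->_targets.begin(), this->_targets.end(), T{0});
    _mean = sum / static_cast<T>(this->_targets.size());
  }

 private:
  T _mean{0};
};
```

Because fit is non-virtual in this sketch, every subclass gets the same bookkeeping (the trained flag) for free; only the algorithm varies.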
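Constructor & Destructor Documentation
txeo::Trainer< T >::Trainer (txeo::DataTable< T > && data, txeo::Logger & logger = txeo::LoggerConsole::instance())  [inline]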
Construct a new Trainer object from a data table.
Parameters
  data  Training/Evaluation/Test data.
Definition at line 45 of file Trainer.h.
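The constructor takes the data table by rvalue reference, and the protected attributes keep it in a std::unique_ptr next to a non-owning txeo::Logger pointer, while copy and move are deleted. A minimal sketch of that ownership pattern follows. Table, Logger, ConsoleLogger, and OwningTrainer are hypothetical placeholders, not Txeo types, and the sketch assumes the logger singleton outlives the trainer.

```cpp
#include <cassert>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical placeholders for txeo::DataTable and txeo::Logger.
struct Table {
  std::vector<double> values;
};

struct Logger {
  virtual ~Logger() = default;
  virtual void info(const std::string& msg) = 0;
};

struct ConsoleLogger : Logger {
  // Function-local static singleton, similar in spirit to LoggerConsole::instance().
  static ConsoleLogger& instance() {
    static ConsoleLogger logger;
    return logger;
  }
  void info(const std::string& msg) override {
    std::cout << msg << '\n';
    last = msg;  // kept so the behavior can be checked
  }
  std::string last;
};

class OwningTrainer {
 public:
  // Taking Table&& makes callers hand over (move) their data; no copy is made.
  explicit OwningTrainer(Table&& data, Logger& logger = ConsoleLogger::instance())
      : _data(std::make_unique<Table>(std::move(data))), _logger(&logger) {
    _logger->info("trainer created");
  }

  // Copy and move disabled, matching the deleted members listed above.
  OwningTrainer(const OwningTrainer&) = delete;
  OwningTrainer& operator=(const OwningTrainer&) = delete;
  OwningTrainer(OwningTrainer&&) = delete;
  OwningTrainer& operator=(OwningTrainer&&) = delete;

  const Table& data_table() const { return *_data; }

 private:
  std::unique_ptr<Table> _data;  // sole owner of the data
  Logger* _logger{nullptr};      // non-owning; must outlive the trainer
};
```

A caller would write `OwningTrainer tr(std::move(table));`, after which `table` should be treated as moved-from and the trainer is the only owner of the data.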
Member Function Documentation
T txeo::Trainer< T >::compute_test_loss (txeo::LossFunc metric) const
Evaluates the test data using the specified metric.
Parameters
  metric  Loss function type
Exceptions
  txeo::TrainerError
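compute_test_loss scores the trained model on the held-out test split. The sketch below shows what such an evaluation amounts to for two common metrics, mean squared error and mean absolute error. The Metric enum and test_loss function are hypothetical stand-ins (the txeo::LossFunc enumerators are not listed on this page), and throwing when the model is untrained or the data sizes disagree is only an assumption about when txeo::TrainerError might be raised.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for txeo::LossFunc.
enum class Metric { MSE, MAE };

// Hypothetical helper: average loss between predictions and test targets.
// The exceptions below are assumed analogues of txeo::TrainerError cases.
template <typename T>
T test_loss(const std::vector<T>& predicted, const std::vector<T>& target,
            Metric metric, bool is_trained) {
  if (!is_trained) throw std::runtime_error("model has not been trained");
  if (predicted.size() != target.size() || target.empty())
    throw std::invalid_argument("empty or mismatched test data");

  T sum{0};
  for (std::size_t i = 0; i < target.size(); ++i) {
    T err = predicted[i] - target[i];
    sum += (metric == Metric::MSE) ? err * err : std::abs(err);
  }
  return sum / static_cast<T>(target.size());
}
```

For predictions {1, 2, 4} against targets {1, 3, 2}, the errors are 0, -1 and 2, so MSE = 5/3 and MAE = 1.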
void txeo::Trainer< T >::enable_feature_norm (txeo::NormalizationType type)
Enable feature normalization.
Parameters
  type  Type of normalization
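Feature normalization rescales each feature column before training. The sketch below assumes two common schemes, min-max scaling to [0, 1] and z-score standardization; NormKind, ColumnNorm, fit_norm, and apply_norm are hypothetical names, and NormKind is not a claim about the actual txeo::NormalizationType enumerators. The sketch learns statistics from training rows only and reuses them for other splits, a common practice that keeps test information out of training.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for txeo::NormalizationType.
enum class NormKind { MinMax, ZScore };

// Per-feature parameters learned from training data: x' = (x - shift) / scale.
template <typename T>
struct ColumnNorm {
  std::vector<T> shift, scale;
};

// Learns column statistics from a row-major [samples, features] matrix.
// Precondition: n_features > 0 and x holds at least one full row.
template <typename T>
ColumnNorm<T> fit_norm(const std::vector<T>& x, std::size_t n_features, NormKind kind) {
  const std::size_t n_samples = x.size() / n_features;
  ColumnNorm<T> p{std::vector<T>(n_features), std::vector<T>(n_features)};
  for (std::size_t j = 0; j < n_features; ++j) {
    T lo = x[j], hi = x[j], mean{0};
    for (std::size_t i = 0; i < n_samples; ++i) {
      T v = x[i * n_features + j];
      lo = std::min(lo, v);
      hi = std::max(hi, v);
      mean += v;
    }
    mean /= static_cast<T>(n_samples);
    if (kind == NormKind::MinMax) {
      p.shift[j] = lo;
      p.scale[j] = hi - lo;
    } else {
      T var{0};
      for (std::size_t i = 0; i < n_samples; ++i) {
        T d = x[i * n_features + j] - mean;
        var += d * d;
      }
      p.shift[j] = mean;
      p.scale[j] = std::sqrt(var / static_cast<T>(n_samples));  // population std dev
    }
    if (p.scale[j] == T{0}) p.scale[j] = T{1};  // constant column: avoid division by zero
  }
  return p;
}

// Applies previously learned parameters to any split (training, evaluation, or test).
template <typename T>
void apply_norm(std::vector<T>& x, const ColumnNorm<T>& p) {
  const std::size_t n_features = p.shift.size();
  for (std::size_t k = 0; k < x.size(); ++k) {
    const std::size_t j = k % n_features;  // column index in row-major layout
    x[k] = (x[k] - p.shift[j]) / p.scale[j];
  }
}
```

Note that min-max scaling can map unseen test values outside [0, 1]; that is expected when the parameters come from the training rows alone.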
void txeo::Trainer< T >::fit (size_t epochs, txeo::LossFunc metric)  [virtual]
Trains the model for the specified number of epochs.
Parameters
  epochs  Number of training iterations
  metric  Loss function to optimize
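Each epoch is one full pass over the training data. As an illustration only, the sketch below runs batch gradient descent on a one-feature linear model with a mean-squared-error loss, one parameter update per epoch. LinearModel, fit_linear, and the learning-rate parameter are hypothetical, and the update rule actually used by derived trainers such as txeo::OlsGDTrainer may differ.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical one-feature linear model: y = w * x + b.
struct LinearModel {
  double w{0.0};  // slope
  double b{0.0};  // intercept
};

// Hypothetical helper: batch gradient descent on MSE, one update per epoch.
LinearModel fit_linear(const std::vector<double>& x, const std::vector<double>& y,
                       std::size_t epochs, double learning_rate) {
  LinearModel m;
  const double n = static_cast<double>(x.size());
  for (std::size_t epoch = 0; epoch < epochs; ++epoch) {
    double grad_w = 0.0, grad_b = 0.0;
    for (std::size_t i = 0; i < x.size(); ++i) {
      double err = (m.w * x[i] + m.b) - y[i];  // prediction minus target
      grad_w += 2.0 * err * x[i] / n;          // d(MSE)/dw
      grad_b += 2.0 * err / n;                 // d(MSE)/db
    }
    m.w -= learning_rate * grad_w;
    m.b -= learning_rate * grad_b;
  }
  return m;
}
```

On the points (0, 1), (1, 3), (2, 5), (3, 7), which lie on y = 2x + 1, a learning rate of 0.05 and a couple of thousand epochs bring w and b very close to 2 and 1; too large a learning rate would make the loop diverge instead.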
void txeo::Trainer< T >::fit (size_t epochs, txeo::LossFunc metric, size_t patience)  [virtual]
Trains with early stopping based on validation performance.
Parameters
  epochs    Maximum number of training iterations
  metric    Loss function to optimize
  patience  Number of epochs to wait without improvement before stopping
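Early stopping with a patience value can be sketched as follows: after each epoch the validation loss is compared with the best value seen so far, and training halts once patience consecutive epochs pass without improvement. The callback interface (train_one_epoch, validation_loss) and the function name are hypothetical simplifications, and counting improvement with a strict less-than comparison is an assumption about how Txeo decides.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <limits>

// Hypothetical helper: runs up to `epochs` epochs and stops early once the
// validation loss has not improved for `patience` consecutive epochs.
// Returns the number of epochs actually run.
std::size_t fit_with_early_stopping(std::size_t epochs, std::size_t patience,
                                    const std::function<void()>& train_one_epoch,
                                    const std::function<double()>& validation_loss) {
  double best = std::numeric_limits<double>::infinity();
  std::size_t epochs_without_improvement = 0;
  std::size_t epoch = 0;
  while (epoch < epochs) {
    train_one_epoch();
    ++epoch;
    const double loss = validation_loss();
    if (loss < best) {
      best = loss;  // new best: reset the counter
      epochs_without_improvement = 0;
    } else if (++epochs_without_improvement >= patience) {
      break;  // patience exhausted
    }
  }
  return epoch;
}
```

With validation losses 5, 4, 3, 3.5, 3.2, 3.1, ... and a patience of 3, training stops after epoch 6, the third consecutive epoch without beating 3. A fuller version would usually also restore the parameters from the best epoch.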
void txeo::Trainer< T >::fit (size_t epochs, txeo::LossFunc metric, size_t patience, txeo::NormalizationType type)  [virtual]
Trains with early stopping based on validation performance and feature normalization.
Parameters
  epochs    Maximum number of training iterations
  metric    Loss function to optimize
  patience  Number of epochs to wait without improvement before stopping
  type      Type of normalization
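txeo::Tensor< T > txeo::Trainer< T >::predict (const txeo::Tensor< T > & input) const  [pure virtual]
Makes predictions using the trained model (pure virtual).
Parameters
  input  Input tensor for prediction (shape: [samples, features])
Exceptions
  txeo::TrainerError
Implemented in txeo::OlsGDTrainer< T >.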
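predict takes a [samples, features] input. For a linear model, such as the ordinary-least-squares model that the name txeo::OlsGDTrainer suggests, prediction reduces to a matrix product plus a bias. The sketch below stores the input row-major in a plain std::vector; the predict_linear name, the [features, outputs] weight layout, and throwing on a shape mismatch are illustrative assumptions rather than Txeo behavior.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical helper (not Txeo's API): linear prediction for a row-major
// input of shape [samples, features], returning shape [samples, outputs]:
//   y[i][k] = b[k] + sum_j x[i][j] * W[j][k]
std::vector<double> predict_linear(const std::vector<double>& x, std::size_t n_features,
                                   const std::vector<double>& W,  // [features, outputs], row-major
                                   const std::vector<double>& b)  // [outputs]
{
  const std::size_t n_outputs = b.size();
  if (n_features == 0 || x.size() % n_features != 0 || W.size() != n_features * n_outputs)
    throw std::invalid_argument("input shape does not match the model");
  const std::size_t n_samples = x.size() / n_features;
  std::vector<double> y(n_samples * n_outputs);
  for (std::size_t i = 0; i < n_samples; ++i) {
    for (std::size_t k = 0; k < n_outputs; ++k) {
      double acc = b[k];
      for (std::size_t j = 0; j < n_features; ++j)
        acc += x[i * n_features + j] * W[j * n_outputs + k];
      y[i * n_outputs + k] = acc;
    }
  }
  return y;
}
```

For inputs {1, 2} and {3, 4} with weights {1, 2} and bias 0.5, this yields 5.5 and 11.5.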