using std::runtime_error::runtime_error;
A container for managing training, evaluation, and test data splits.
static LoggerConsole & instance()
Access the singleton instance.
Abstract base class for logging subsystems.
Exceptions concerning txeo::OlsGDTrainer.
Abstract base class for machine learning trainers.
Trainer & operator=(Trainer &&)=delete
txeo::DataTableNorm< T > _data_table_norm
virtual void fit(size_t epochs, txeo::LossFunc metric, size_t patience)
Trains with early stopping based on validation performance.
Trainer(const Trainer &)=delete
Trainer(Trainer &&)=delete
virtual void train(size_t epochs, txeo::LossFunc loss_func)=0
Trainer(const txeo::DataTable< T > &data)
void disable_feature_norm()
Disable feature normalization.
virtual ~Trainer()=default
Trainer & operator=(const Trainer &)=delete
Trainer(txeo::DataTable< T > &&data, txeo::Logger &logger=txeo::LoggerConsole::instance())
Construct a new Trainer object from a data table.
std::unique_ptr< txeo::DataTable< T > > _data_table
bool is_trained() const
Checks whether the model has been trained.
virtual txeo::Tensor< T > predict(const txeo::Tensor< T > &input) const =0
Makes predictions using the trained model (pure virtual).
const txeo::DataTable< T > & data_table() const
Returns the data table of this trainer.
void enable_feature_norm(txeo::NormalizationType type)
Enable feature normalization.
T compute_test_loss(txeo::LossFunc metric) const
Evaluates test data based on a specified metric.
virtual void fit(size_t epochs, txeo::LossFunc metric, size_t patience, txeo::NormalizationType type)
Trains with early stopping based on validation performance and feature normalization.
virtual void fit(size_t epochs, txeo::LossFunc metric)
Trains the model for a specified number of epochs.
NormalizationType
Normalization types to be used in normalization functions.
LossFunc
Types of loss functions.