Txeo v0.1
A Developer-Friendly TensorFlow C++ Wrapper
txeo::OlsGDTrainer< T > Class Template Reference

Ordinary Least Squares trainer using Gradient Descent optimization.

#include <OlsGDTrainer.h>


Public Member Functions

 OlsGDTrainer (const OlsGDTrainer &)=delete
 
 OlsGDTrainer (OlsGDTrainer &&)=delete
 
OlsGDTrainer & operator= (const OlsGDTrainer &)=delete
 
OlsGDTrainer & operator= (OlsGDTrainer &&)=delete
 
 ~OlsGDTrainer ()=default
 
 OlsGDTrainer (txeo::DataTable< T > &&data)
 Constructs a new OlsGDTrainer object from a data table.
 
 OlsGDTrainer (const txeo::DataTable< T > &data)
 
txeo::Tensor< T > predict (const txeo::Tensor< T > &input) const override
 Makes predictions using learned weights.
 
T learning_rate () const
 Gets current learning rate.
 
void set_learning_rate (T learning_rate)
 Sets learning rate for gradient descent.
 
void enable_variable_lr ()
 Enables adaptive learning rate adjustment (Barzilai-Borwein Method). When enabled, the learning rate is automatically reduced when the loss plateaus. In most cases this speeds up convergence considerably.
 
void disable_variable_lr ()
 Disables adaptive learning rate adjustment (Barzilai-Borwein Method)
 
const txeo::Matrix< T > & weight_bias () const
 Gets the weight/bias matrix that produced the minimum loss during fit.
 
T tolerance () const
 Gets convergence tolerance.
 
void set_tolerance (const T &tolerance)
 Sets convergence tolerance.
 
bool is_converged () const
 Checks convergence status.
 
T min_loss () const
 Gets the minimum loss during training.
 
- Public Member Functions inherited from txeo::Trainer< T >
 Trainer (const Trainer &)=delete
 
 Trainer (Trainer &&)=delete
 
Trainer & operator= (const Trainer &)=delete
 
Trainer & operator= (Trainer &&)=delete
 
virtual ~Trainer ()=default
 
 Trainer (txeo::DataTable< T > &&data, txeo::Logger &logger=txeo::LoggerConsole::instance())
 Construct a new Trainer object from a data table.
 
 Trainer (const txeo::DataTable< T > &data)
 
virtual void fit (size_t epochs, txeo::LossFunc metric)
 Trains the model for specified number of epochs.
 
virtual void fit (size_t epochs, txeo::LossFunc metric, size_t patience)
 Trains with early stopping based on validation performance.
 
virtual void fit (size_t epochs, txeo::LossFunc metric, size_t patience, txeo::NormalizationType type)
 Trains with early stopping based on validation performance and feature normalization.
 
T compute_test_loss (txeo::LossFunc metric) const
 Evaluates test data based on a specified metric.
 
bool is_trained () const
 Checks if model has been trained.
 
const txeo::DataTable< T > & data_table () const
 Returns the data table of this trainer.
 
void enable_feature_norm (txeo::NormalizationType type)
 Enable feature normalization.
 
void disable_feature_norm ()
 Disable feature normalization.
 

Additional Inherited Members

- Protected Member Functions inherited from txeo::Trainer< T >
 Trainer ()=default
 
- Protected Attributes inherited from txeo::Trainer< T >
bool _is_trained {false}
 
bool _is_early_stop {false}
 
size_t _patience {0}
 
std::unique_ptr< txeo::DataTable< T > > _data_table
 
txeo::Logger * _logger {nullptr}
 
txeo::DataTableNorm< T > _data_table_norm
 
bool _is_norm_enabled {false}
 

Detailed Description

template<typename T>
requires (std::floating_point<T>)
class txeo::OlsGDTrainer< T >

Ordinary Least Squares trainer using Gradient Descent optimization.

Template Parameters
T    Floating-point type for calculations (float, double, etc.)

Implements linear regression training through gradient descent with:

  • Configurable learning rate
  • Convergence tolerance
  • Variable learning rate support (Barzilai-Borwein Method)
  • Weight/bias matrix access

Inherits from txeo::Trainer<T> and implements required virtual methods.

Implements algorithms based on paper: Algarte, R.D., "Tensor-Based Foundations of Ordinary Least Squares and Neural Network Regression Models" (https://arxiv.org/abs/2411.12873)
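As a rough illustration of the kind of training loop described above, the sketch below fits a single-feature linear model with batch gradient descent on the mean squared error. It does not use txeo and is not the library's implementation: `LineFit` and `fit_line_gd` are hypothetical names, and the fixed learning rate and plain `std::vector` inputs are simplifying assumptions.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical result type for this sketch (not part of txeo).
struct LineFit {
  double weight{0.0};
  double bias{0.0};
  double loss{0.0};  // MSE at the final parameters
};

// Batch gradient descent on the MSE of the model y = weight * x + bias.
LineFit fit_line_gd(const std::vector<double> &x, const std::vector<double> &y,
                    double learning_rate, std::size_t epochs) {
  assert(!x.empty() && x.size() == y.size());
  assert(learning_rate > 0.0);
  const double n = static_cast<double>(x.size());
  LineFit fit;
  for (std::size_t epoch = 0; epoch < epochs; ++epoch) {
    double grad_w = 0.0;
    double grad_b = 0.0;
    for (std::size_t i = 0; i < x.size(); ++i) {
      const double residual = fit.weight * x[i] + fit.bias - y[i];
      grad_w += 2.0 * residual * x[i] / n;  // d(MSE)/d(weight)
      grad_b += 2.0 * residual / n;         // d(MSE)/d(bias)
    }
    fit.weight -= learning_rate * grad_w;
    fit.bias -= learning_rate * grad_b;
  }
  // Evaluate the loss once at the final parameters.
  for (std::size_t i = 0; i < x.size(); ++i) {
    const double residual = fit.weight * x[i] + fit.bias - y[i];
    fit.loss += residual * residual / n;
  }
  return fit;
}
```

Calling `fit_line_gd` on the points (1, 3), (2, 5), (3, 7) with a learning rate of 0.05 and a few thousand epochs should recover a weight close to 2 and a bias close to 1.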

Example Usage:

// Create training data (y = 2x + 1)
txeo::Matrix<double> X({{1.0}, {2.0}, {3.0}}); // 3x1
txeo::Matrix<double> y({{3.0}, {5.0}, {7.0}}); // 3x1

// The constructors documented on this page take a txeo::DataTable<T>.
// Assumption: DataTable can be built from a feature matrix and a label matrix;
// check DataTable.h for the constructors actually available.
txeo::DataTable<double> data(X, y);
txeo::OlsGDTrainer<double> trainer(data);
trainer.set_tolerance(1e-5);

// Train with early stopping (patience of 10)
trainer.fit(1000, txeo::LossFunc::MSE, 10);

if (trainer.is_converged()) {
  const auto &weights = trainer.weight_bias();
  std::cout << "Model: y = " << weights(0, 0) << "x + " << weights(1, 0) << std::endl;

  // Make prediction
  txeo::Matrix<double> test_input(1, 1, {4.0});
  auto prediction = trainer.predict(test_input);
  std::cout << "Prediction for x=4: " << prediction(0, 0) << std::endl;
}

Definition at line 60 of file OlsGDTrainer.h.
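The inherited `fit` overloads that take a `patience` argument train with early stopping. The sketch below shows one common way such a rule can work, assuming training stops once the evaluation loss has failed to improve for `patience` consecutive epochs. `early_stop_epoch` is a hypothetical helper, and the exact rule used by txeo is not stated on this page.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical helper (not part of txeo): given the evaluation loss observed
// after each epoch, returns how many epochs run before early stopping fires.
std::size_t early_stop_epoch(const std::vector<double> &eval_losses, std::size_t patience) {
  assert(patience > 0);
  double best = std::numeric_limits<double>::infinity();
  std::size_t stale = 0;  // consecutive epochs without improvement
  for (std::size_t epoch = 0; epoch < eval_losses.size(); ++epoch) {
    if (eval_losses[epoch] < best) {
      best = eval_losses[epoch];
      stale = 0;
    } else if (++stale >= patience) {
      return epoch + 1;  // number of epochs actually run
    }
  }
  return eval_losses.size();  // patience never exhausted
}
```

A larger `patience` tolerates longer stretches without improvement before giving up, at the cost of extra epochs when the model has genuinely stopped improving.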

Constructor & Destructor Documentation

◆ OlsGDTrainer() [1/4]

template<typename T >
txeo::OlsGDTrainer< T >::OlsGDTrainer ( const OlsGDTrainer< T > &  )
delete

◆ OlsGDTrainer() [2/4]

template<typename T >
txeo::OlsGDTrainer< T >::OlsGDTrainer ( OlsGDTrainer< T > &&  )
delete

◆ ~OlsGDTrainer()

template<typename T >
txeo::OlsGDTrainer< T >::~OlsGDTrainer ( )
default

◆ OlsGDTrainer() [3/4]

template<typename T >
txeo::OlsGDTrainer< T >::OlsGDTrainer ( txeo::DataTable< T > &&  data)
inline

Constructs a new OlsGDTrainer object from a data table.

Parameters
data    Training/Evaluation/Test data

Definition at line 73 of file OlsGDTrainer.h.

: txeo::Trainer<T>(std::move(data)) {};

◆ OlsGDTrainer() [4/4]

template<typename T >
txeo::OlsGDTrainer< T >::OlsGDTrainer ( const txeo::DataTable< T > &  data)
inline

Definition at line 75 of file OlsGDTrainer.h.

: txeo::Trainer<T>(data) {};

Member Function Documentation

◆ disable_variable_lr()

template<typename T >
void txeo::OlsGDTrainer< T >::disable_variable_lr ( )
inline

Disables adaptive learning rate adjustment (Barzilai-Borwein Method)

Definition at line 113 of file OlsGDTrainer.h.

{ _variable_lr = false; }

◆ enable_variable_lr()

template<typename T >
void txeo::OlsGDTrainer< T >::enable_variable_lr ( )
inline

Enables adaptive learning rate adjustment (Barzilai-Borwein Method). When enabled, the learning rate is automatically reduced when the loss plateaus. In most cases this speeds up convergence considerably.

Definition at line 108 of file OlsGDTrainer.h.

{ _variable_lr = true; }
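For readers unfamiliar with the Barzilai-Borwein method, its textbook form derives the step size from the change in parameters and the change in gradient between two successive iterations. The sketch below computes the "long" variant of that step. `bb_step` is a hypothetical helper, and which variant (or safeguard) txeo applies is not stated on this page.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of txeo): "long" Barzilai-Borwein step size
//   alpha = (s . s) / (s . d),  with s = w - prev_w and d = grad - prev_grad.
// Returns `fallback` when the denominator is not positive.
double bb_step(const std::vector<double> &prev_w, const std::vector<double> &w,
               const std::vector<double> &prev_grad, const std::vector<double> &grad,
               double fallback) {
  assert(prev_w.size() == w.size());
  assert(prev_grad.size() == w.size() && grad.size() == w.size());
  double ss = 0.0;
  double sd = 0.0;
  for (std::size_t i = 0; i < w.size(); ++i) {
    const double s = w[i] - prev_w[i];
    const double d = grad[i] - prev_grad[i];
    ss += s * s;
    sd += s * d;
  }
  return sd > 0.0 ? ss / sd : fallback;
}
```

On a one-dimensional quadratic with curvature `a`, this step evaluates to `1 / a`, which is why the method can adapt the learning rate to the local shape of the loss without a line search.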

◆ is_converged()

template<typename T >
bool txeo::OlsGDTrainer< T >::is_converged ( ) const
inline

Checks convergence status.

Returns
true if training converged before max epochs

Definition at line 143 of file OlsGDTrainer.h.

{ return _is_converged; }

◆ learning_rate()

template<typename T >
T txeo::OlsGDTrainer< T >::learning_rate ( ) const

Gets current learning rate.

Returns
Current learning rate value

◆ min_loss()

template<typename T >
T txeo::OlsGDTrainer< T >::min_loss ( ) const

Gets the minimum loss during training.

Returns
Value of the minimum loss

◆ operator=() [1/2]

template<typename T >
OlsGDTrainer & txeo::OlsGDTrainer< T >::operator= ( const OlsGDTrainer< T > &  )
delete

◆ operator=() [2/2]

template<typename T >
OlsGDTrainer & txeo::OlsGDTrainer< T >::operator= ( OlsGDTrainer< T > &&  )
delete

◆ predict()

template<typename T >
txeo::Tensor< T > txeo::OlsGDTrainer< T >::predict ( const txeo::Tensor< T > &  input) const
overridevirtual

Makes predictions using learned weights.

Parameters
input    Feature matrix (shape: [samples, features])
Returns
Prediction matrix (shape: [samples, outputs])
Exceptions
OlsGDTrainerError

Implements txeo::Trainer< T >.
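To make the shapes concrete, the sketch below applies a weight/bias matrix of shape [features+1, outputs] to an input of shape [samples, features]. It assumes the last row holds the bias, which matches the usage example where `weights(0,0)` is the slope and `weights(1,0)` the intercept. `Mat` and `predict_linear` are hypothetical stand-ins and do not use txeo types.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major matrix stand-in for this sketch (not txeo::Matrix).
using Mat = std::vector<std::vector<double>>;

// Computes input * W + b, where weight_bias stacks W ([features, outputs])
// on top of a single bias row ([1, outputs]).
Mat predict_linear(const Mat &input, const Mat &weight_bias) {
  assert(!weight_bias.empty());
  const std::size_t features = weight_bias.size() - 1;
  const std::size_t outputs = weight_bias.front().size();
  Mat result(input.size(), std::vector<double>(outputs, 0.0));
  for (std::size_t r = 0; r < input.size(); ++r) {
    assert(input[r].size() == features);
    for (std::size_t o = 0; o < outputs; ++o) {
      double acc = weight_bias[features][o];  // start from the bias term
      for (std::size_t f = 0; f < features; ++f) {
        acc += input[r][f] * weight_bias[f][o];
      }
      result[r][o] = acc;
    }
  }
  return result;
}
```

With a weight/bias matrix holding a slope of 2 and an intercept of 1, an input of 4 maps to 9, matching the model y = 2x + 1.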

◆ set_learning_rate()

template<typename T >
void txeo::OlsGDTrainer< T >::set_learning_rate ( T  learning_rate)

Sets learning rate for gradient descent.

Parameters
learning_rate    Must be > 0
Exceptions
OlsGDTrainerError    for invalid values

◆ set_tolerance()

template<typename T >
void txeo::OlsGDTrainer< T >::set_tolerance ( const T &  tolerance)

Sets convergence tolerance.

Parameters
tolerance    Minimum loss difference to consider converged (>0)
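One plausible reading of this parameter is that training counts as converged once the loss changes by less than the tolerance between two consecutive epochs. The sketch below encodes that reading. `has_converged` is a hypothetical helper, and the use of an absolute difference with a strict comparison is an assumption, not something this page confirms.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper (not part of txeo): true when the loss changed by less
// than `tolerance` between two consecutive epochs.
bool has_converged(double previous_loss, double current_loss, double tolerance) {
  assert(tolerance > 0.0);
  return std::fabs(previous_loss - current_loss) < tolerance;
}
```

A smaller tolerance demands a flatter loss curve before `is_converged()` would report true under this reading, so training runs for more epochs.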

◆ tolerance()

template<typename T >
T txeo::OlsGDTrainer< T >::tolerance ( ) const
inline

Gets convergence tolerance.

Returns
Current tolerance value

Definition at line 129 of file OlsGDTrainer.h.

{ return _tolerance; }

◆ weight_bias()

template<typename T >
const txeo::Matrix< T > & txeo::OlsGDTrainer< T >::weight_bias ( ) const

Gets the weight/bias matrix that produced the minimum loss during fit.

Returns
Matrix containing model parameters (shape: [features+1, outputs])
Exceptions
OlsGDTrainerError

The documentation for this class was generated from the following file: