59 requires(std::floating_point<T>)
153 T _learning_rate{0.01};
157 bool _variable_lr{
false};
158 bool _is_converged{
false};
160 OlsGDTrainer() =
default;
170 using std::runtime_error::runtime_error;
A container for managing training, evaluation, and test data splits.
Exceptions concerning txeo::OlsGDTrainer.
Ordinary Least Squares trainer using Gradient Descent optimization.
void enable_variable_lr()
Enables adaptive learning rate adjustment (Barzilai-Borwein Method). When enabled,...
T tolerance() const
Gets convergence tolerance.
OlsGDTrainer(txeo::DataTable< T > &&data)
Construct a new OlsGD Trainer object from a data table.
OlsGDTrainer(OlsGDTrainer &&)=delete
T learning_rate() const
Gets current learning rate.
T min_loss() const
Gets the minimum loss during training.
void set_learning_rate(T learning_rate)
Sets learning rate for gradient descent.
OlsGDTrainer & operator=(const OlsGDTrainer &)=delete
OlsGDTrainer(const OlsGDTrainer &)=delete
bool is_converged() const
Checks convergence status.
OlsGDTrainer(const txeo::DataTable< T > &data)
void disable_variable_lr()
Disables adaptive learning rate adjustment (Barzilai-Borwein Method)
const txeo::Matrix< T > & weight_bias() const
Gets weight/bias matrix related to the minimum loss during fit.
OlsGDTrainer & operator=(OlsGDTrainer &&)=delete
txeo::Tensor< T > predict(const txeo::Tensor< T > &input) const override
Makes predictions using learned weights.
void set_tolerance(const T &tolerance)
Sets convergence tolerance.
Abstract base class for machine learning trainers.
LossFunc
Types of loss functions.