Txeo v0.1
A Developer-Friendly TensorFlow C++ Wrapper
Loading...
Searching...
No Matches
OlsGDTrainer.h
Go to the documentation of this file.
1#ifndef OLSGDTRAINER_H
2#define OLSGDTRAINER_H
3#pragma once
4
5#include "txeo/Matrix.h"
6#include "txeo/Tensor.h"
7#include "txeo/TensorShape.h"
8#include "txeo/Trainer.h"
9
10#include <concepts>
11#include <cstddef>
12#include <stdexcept>
13
14namespace txeo {
15enum class LossFunc;
16
58template <typename T> // Ordinary Least Squares trainer using Gradient Descent optimization.
59 requires(std::floating_point<T>) // T must be a floating-point type. NOTE(review): Doxygen line 60 is missing from this extraction; presumably `class OlsGDTrainer : public txeo::Trainer<T> {` (the ctors below initialize txeo::Trainer<T> and train() is marked override) — confirm against the real header.
61 public:
62 OlsGDTrainer(const OlsGDTrainer &) = delete; // copy construction disabled. NOTE(review): Doxygen lines 63-65 are missing here; the member summary lists deleted copy assignment, move construction and move assignment, presumably declared on those lines.
66 ~OlsGDTrainer() = default;
67
73 OlsGDTrainer(txeo::DataTable<T> &&data) : txeo::Trainer<T>(std::move(data)) {}; // builds the trainer from a data table moved into the Trainer<T> base. NOTE(review): single-argument ctor is not `explicit`, so DataTable<T> converts implicitly; std::move also needs <utility>, which this header does not include directly.
74
75 OlsGDTrainer(const txeo::DataTable<T> &data) : txeo::Trainer<T>(data) {}; // builds the trainer from a copy of the data table; same non-`explicit` note as above.
76
86 // NOTE(review): declaration text missing from this extraction; presumably predict(const txeo::Tensor<T> &input) const override from the member summary.
93 // NOTE(review): declaration text missing from this extraction; presumably learning_rate() const from the member summary.
101 void set_learning_rate(T learning_rate); // sets the learning rate for gradient descent (defined out of line).
102
108 void enable_variable_lr() { _variable_lr = true; } // enables adaptive learning-rate adjustment (Barzilai-Borwein method).
109
113 void disable_variable_lr() { _variable_lr = false; } // disables adaptive learning-rate adjustment.
114
123 // NOTE(review): declaration text missing from this extraction; presumably weight_bias() const from the member summary.
129 T tolerance() const { return _tolerance; } // current convergence tolerance.
130
136 void set_tolerance(const T &tolerance); // sets the convergence tolerance (defined out of line).
137
143 [[nodiscard]] bool is_converged() const { return _is_converged; } // reports the _is_converged flag.
144
150 T min_loss() const; // minimum loss recorded during training (defined out of line).
151
152 private:
153 T _learning_rate{0.01}; // gradient-descent step size; default 0.01.
154 T _tolerance{0.001}; // convergence tolerance; default 0.001.
155 T _min_loss{0}; // minimum loss seen during training, returned via min_loss().
156 txeo::Matrix<T> _weight_bias{}; // weight/bias matrix associated with the minimum loss.
157 bool _variable_lr{false}; // toggled by enable_variable_lr() / disable_variable_lr().
158 bool _is_converged{false}; // exposed through is_converged().
159
160 OlsGDTrainer() = default; // private: callers must construct from a DataTable.
161 void train(size_t epochs, txeo::LossFunc metric) override; // training hook overriding the Trainer<T> base. NOTE(review): unqualified size_t relies on <cstddef> also declaring ::size_t; std::size_t is the portable spelling.
162};
163
168class OlsGDTrainerError : public std::runtime_error { // exception type for errors concerning txeo::OlsGDTrainer.
169 public:
170 using std::runtime_error::runtime_error; // inherits every std::runtime_error constructor (e.g. from a message string).
171};
172
173} // namespace txeo
174
175#endif
A container for managing training, evaluation, and test data splits.
Definition DataTable.h:24
Exception type for errors concerning txeo::OlsGDTrainer.
Ordinary Least Squares trainer using Gradient Descent optimization.
void enable_variable_lr()
Enables adaptive learning rate adjustment (Barzilai-Borwein Method). When enabled,...
T tolerance() const
Gets the convergence tolerance.
OlsGDTrainer(txeo::DataTable< T > &&data)
Constructs a new OlsGDTrainer object from a data table.
OlsGDTrainer(OlsGDTrainer &&)=delete
T learning_rate() const
Gets the current learning rate.
T min_loss() const
Gets the minimum loss reached during training.
void set_learning_rate(T learning_rate)
Sets the learning rate for gradient descent.
OlsGDTrainer & operator=(const OlsGDTrainer &)=delete
OlsGDTrainer(const OlsGDTrainer &)=delete
~OlsGDTrainer()=default
bool is_converged() const
Checks the convergence status.
OlsGDTrainer(const txeo::DataTable< T > &data)
void disable_variable_lr()
Disables adaptive learning rate adjustment (Barzilai-Borwein Method).
const txeo::Matrix< T > & weight_bias() const
Gets the weight/bias matrix associated with the minimum loss reached during training.
OlsGDTrainer & operator=(OlsGDTrainer &&)=delete
txeo::Tensor< T > predict(const txeo::Tensor< T > &input) const override
Makes predictions using the learned weights.
void set_tolerance(const T &tolerance)
Sets the convergence tolerance.
Abstract base class for machine learning trainers.
Definition Trainer.h:32
LossFunc
Types of loss functions.
Definition types.h:44