Txeo v0.1
A Developer-Friendly TensorFlow C++ Wrapper
Loading...
Searching...
No Matches
Trainer.h
Go to the documentation of this file.
1#ifndef TRAINER_H
2#define TRAINER_H
3#include "txeo/Logger.h"
5#pragma once
6
7#include "txeo/DataTable.h"
9#include "txeo/Tensor.h"
10#include "txeo/types.h"
11
12#include <cstddef>
13#include <stdexcept>
14
15namespace txeo {
16enum class LossFunc;
17
/// @brief Abstract base class for machine learning trainers.
///
/// @tparam T Value type used for txeo::DataTable<T> and txeo::Tensor<T>.
///
/// Trainers are neither copyable nor movable (all four copy/move operations
/// are deleted). Derived classes must implement predict() and train().
31template <typename T>
32class Trainer {
33 public:
34 Trainer(const Trainer &) = delete;
35 Trainer(Trainer &&) = delete;
36 Trainer &operator=(const Trainer &) = delete;
37 Trainer &operator=(Trainer &&) = delete;
38 virtual ~Trainer() = default;
39
// NOTE(review): source lines 40-46 are not shown in this listing. Per the
// member index, line 45 defines
//   Trainer(txeo::DataTable<T> &&data,
//           txeo::Logger &logger = txeo::LoggerConsole::instance());
// which is presumably the overload the constructor below delegates to.
47
/// @brief Constructs a trainer from a copy of @p data.
///
/// Delegates to another constructor with data.clone(). @p data is taken by
/// const reference, so the caller's table is not modified.
/// @param data Data table to copy.
// NOTE(review): this single-argument constructor is not `explicit`; consider
// marking it so. The `;` after the empty body is redundant (-Wextra-semi).
48 Trainer(const txeo::DataTable<T> &data) : Trainer{data.clone()} {};
49
/// @brief Trains the model for the specified number of epochs.
/// @param epochs Number of training epochs.
/// @param metric Loss function selector (txeo::LossFunc).
56 virtual void fit(size_t epochs, txeo::LossFunc metric);
57
/// @brief Trains with early stopping based on validation performance.
/// @param epochs Number of training epochs.
/// @param metric Loss function selector (txeo::LossFunc).
/// @param patience Early-stopping patience (presumably the number of epochs
///        without validation improvement before stopping -- confirm in the
///        implementation).
65 virtual void fit(size_t epochs, txeo::LossFunc metric, size_t patience);
66
/// @brief Trains with early stopping based on validation performance and
///        feature normalization.
// NOTE(review): source line 76 is missing from this listing; per the member
// index the complete signature is
//   fit(size_t epochs, txeo::LossFunc metric, size_t patience,
//       txeo::NormalizationType type)
75 virtual void fit(size_t epochs, txeo::LossFunc metric, size_t patience,
77
87
/// @brief Makes predictions using the trained model (pure virtual).
/// @param input Input tensor.
/// @return Predictions for @p input, as computed by the derived trainer.
98 virtual txeo::Tensor<T> predict(const txeo::Tensor<T> &input) const = 0;
99
/// @brief Checks whether the model has been trained.
/// @return The current value of _is_trained.
105 [[nodiscard]] bool is_trained() const { return _is_trained; }
106
/// @brief Returns the data table of this trainer.
// NOTE(review): *_data_table is dereferenced unconditionally. The protected
// default constructor leaves _data_table null, so unless a derived class
// assigns it, calling this is undefined behavior.
112 const txeo::DataTable<T> &data_table() const { return *_data_table; }
113
// NOTE(review): source lines 114-119 and 121-126 are not shown. The member
// index lists enable_feature_norm(txeo::NormalizationType),
// disable_feature_norm() (defined at line 126) and
// compute_test_loss(txeo::LossFunc) const, which presumably sit in omitted
// ranges of this listing -- confirm against Trainer.h.
120
127
128 protected:
/// Default constructor for derived classes; _data_table starts out null.
129 Trainer() = default;
130
/// Training-state flag returned by is_trained().
131 bool _is_trained{false};
/// Early-stopping flag (presumably set by the fit overloads taking
/// `patience` -- confirm).
132 bool _is_early_stop{false};
/// Early-stopping patience (presumably the value passed to fit -- confirm).
133 size_t _patience{0};
134
/// Owned data table used by the trainer.
// NOTE(review): no <memory> include is visible in this listing; the header
// should include it directly rather than rely on a transitive include for
// std::unique_ptr.
135 std::unique_ptr<txeo::DataTable<T>> _data_table;
// NOTE(review): source lines 136 and 138 are not shown; per the member index
// they declare `txeo::Logger *_logger` and
// `txeo::DataTableNorm<T> _data_table_norm`.
137
/// Feature-normalization flag (presumably toggled by enable_feature_norm() /
/// disable_feature_norm() -- confirm).
139 bool _is_norm_enabled{false};
140
/// @brief Training routine implemented by derived classes (pure virtual).
/// @param epochs Number of training epochs.
/// @param loss_func Loss function selector (txeo::LossFunc).
// NOTE(review): presumably called by the fit() overloads -- confirm.
141 virtual void train(size_t epochs, txeo::LossFunc loss_func) = 0;
142};
143
/// @brief Exception type for trainer-related errors.
///
/// A thin subclass of std::runtime_error that adds no state and reuses the
/// base-class constructors. Callers may catch it as TrainerError,
/// std::runtime_error, or std::exception.
class TrainerError : public std::runtime_error {
 public:
  // Inherit runtime_error's constructors (from std::string or const char*).
  using runtime_error::runtime_error;
};
152
153} // namespace txeo
154
155#endif
A container for managing training, evaluation, and test data splits.
Definition DataTable.h:24
static LoggerConsole & instance()
Access the singleton instance.
Abstract base class for logging subsystems.
Definition Logger.h:45
Exceptions concerning txeo::Trainer and derived trainers such as txeo::OlsGDTrainer.
Definition Trainer.h:148
Abstract base class for machine learning trainers.
Definition Trainer.h:32
txeo::Logger * _logger
Definition Trainer.h:136
Trainer & operator=(Trainer &&)=delete
txeo::DataTableNorm< T > _data_table_norm
Definition Trainer.h:138
Trainer()=default
virtual void fit(size_t epochs, txeo::LossFunc metric, size_t patience)
Trains with early stopping based on validation performance.
bool _is_trained
Definition Trainer.h:131
Trainer(const Trainer &)=delete
Trainer(Trainer &&)=delete
size_t _patience
Definition Trainer.h:133
bool _is_early_stop
Definition Trainer.h:132
virtual void train(size_t epochs, txeo::LossFunc loss_func)=0
Trainer(const txeo::DataTable< T > &data)
Definition Trainer.h:48
void disable_feature_norm()
Disable feature normalization.
Definition Trainer.h:126
virtual ~Trainer()=default
Trainer & operator=(const Trainer &)=delete
Trainer(txeo::DataTable< T > &&data, txeo::Logger &logger=txeo::LoggerConsole::instance())
Construct a new Trainer object from a data table.
Definition Trainer.h:45
std::unique_ptr< txeo::DataTable< T > > _data_table
Definition Trainer.h:135
bool is_trained() const
Checks whether the model has been trained.
Definition Trainer.h:105
virtual txeo::Tensor< T > predict(const txeo::Tensor< T > &input) const =0
Makes predictions using the trained model (pure virtual).
const txeo::DataTable< T > & data_table() const
Returns the data table of this trainer.
Definition Trainer.h:112
void enable_feature_norm(txeo::NormalizationType type)
Enable feature normalization.
T compute_test_loss(txeo::LossFunc metric) const
Evaluates test data based on a specified metric.
virtual void fit(size_t epochs, txeo::LossFunc metric, size_t patience, txeo::NormalizationType type)
Trains with early stopping based on validation performance and feature normalization.
virtual void fit(size_t epochs, txeo::LossFunc metric)
Trains the model for the specified number of epochs.
bool _is_norm_enabled
Definition Trainer.h:139
NormalizationType
Normalization types to be used in normalization functions.
Definition types.h:38
LossFunc
Types of loss functions.
Definition types.h:44