Txeo v0.1
A Developer-Friendly TensorFlow C++ Wrapper
Loading...
Searching...
No Matches
Loss.h
Go to the documentation of this file.
1#ifndef LOSS_H
2#define LOSS_H
3#pragma once
4
5#include "txeo/Tensor.h"
6#include "txeo/types.h"
7
8#include <functional>
9#include <stdexcept>
10namespace txeo {
11template <typename T>
12class Tensor;
13
29template <typename T>
30class Loss {
31 public:
32 Loss(const Loss &) = default;
33 Loss(Loss &&) = default;
34 Loss &operator=(const Loss &) = default;
35 Loss &operator=(Loss &&) = default;
36 ~Loss() = default;
37
51
66
81
94
104
114
125
135
136 const txeo::Tensor<T> &label() const { return _label; }
137
140 T mse(const txeo::Tensor<T> &pred) const {
141 return mean_squared_error(pred);
142 };
143
144 T mae(const txeo::Tensor<T> &pred) const {
146 };
147
148 T msle(const txeo::Tensor<T> &pred) const {
150 };
151
152 T lche(const txeo::Tensor<T> &pred) const {
153 return log_cosh_error(pred);
154 };
156
157 private:
158 Loss() = default;
159 txeo::Tensor<T> _label{};
160
161 void verify_parameter(const txeo::Tensor<T> &pred) const;
162 std::function<T(const txeo::Tensor<T> &)> _loss_func;
163};
164
/// @brief Exception type for errors concerning txeo::Loss.
///
/// Adds no state of its own. It inherits every std::runtime_error
/// constructor, so it accepts a `const char*` or `std::string` message and
/// can be caught as std::runtime_error or std::exception.
class LossError : public std::runtime_error {
 public:
  // Bring in the base-class constructors via the base's injected-class-name.
  using runtime_error::runtime_error;
};
173
174} // namespace txeo
175
176#endif
A container for managing training, evaluation, and test data splits.
Definition DataTable.h:24
Exceptions concerning txeo::Loss.
Definition Loss.h:169
Computes error metrics between predicted and validation tensors.
Definition Loss.h:30
Loss(const txeo::Tensor< T > &label, txeo::LossFunc func=txeo::LossFunc::MSE)
Construct a new Loss object.
Definition Loss.h:64
void set_loss(txeo::LossFunc func)
Set the active loss function.
T msle(const txeo::Tensor< T > &pred) const
Definition Loss.h:148
T mean_squared_logarithmic_error(const txeo::Tensor< T > &pred) const
Compute Mean Squared Logarithmic Error (MSLE)
T mse(const txeo::Tensor< T > &pred) const
Definition Loss.h:140
Loss(Loss &&)=default
Loss(txeo::Tensor< T > &&label, txeo::LossFunc func=txeo::LossFunc::MSE)
Construct a new Loss object from a label rvalue.
T get_loss(const txeo::Tensor< T > &pred) const
Compute loss using currently selected function.
Loss & operator=(Loss &&)=default
T mean_absolute_error(const txeo::Tensor< T > &pred) const
Compute Mean Absolute Error (MAE)
T lche(const txeo::Tensor< T > &pred) const
Definition Loss.h:152
T mae(const txeo::Tensor< T > &pred) const
Definition Loss.h:144
~Loss()=default
const txeo::Tensor< T > & label() const
Definition Loss.h:136
T mean_squared_error(const txeo::Tensor< T > &pred) const
Compute Mean Squared Error (MSE)
Loss & operator=(const Loss &)=default
T log_cosh_error(const txeo::Tensor< T > &pred) const
Compute Log-Cosh Error (LCHE)
Loss(const Loss &)=default
LossFunc
Types of loss functions.
Definition types.h:44