Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions stan/math/fwd/fun/log_softmax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ namespace math {
* @tparam T `std::vector` whose scalar type is `fvar`
* @param x container of vectors to transform
* @return container of log softmax results
* @throw std::invalid_argument if any input vector is empty
*/
template <typename T, require_std_vector_st<is_fvar, T>* = nullptr>
inline auto log_softmax(T&& x) {
Expand All @@ -33,14 +34,14 @@ inline auto log_softmax(T&& x) {
* @tparam Vec Eigen vector with `fvar` scalar
* @param x vector to transform
* @return log softmax of the vector
* @throw std::domain_error if the input size is 0
* @throw std::invalid_argument if the input size is 0
*/
template <typename Vec, require_eigen_vector_vt<is_fvar, Vec>* = nullptr>
inline auto log_softmax(Vec&& x) {
using vec = std::decay_t<Vec>;
constexpr int Rows = vec::RowsAtCompileTime;
constexpr int Cols = vec::ColsAtCompileTime;
using T = typename value_type_t<Vec>::Scalar;
using T = typename value_type_t<vec>::Scalar;
check_nonzero_size("log_softmax", "x", x);
decltype(auto) x_ref = to_ref(std::forward<Vec>(x));
const auto s = softmax(value_of(x_ref));
Expand Down
7 changes: 4 additions & 3 deletions stan/math/fwd/fun/softmax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/fwd/core.hpp>
#include <stan/math/fwd/fun/value_of.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <stan/math/prim/fun/softmax.hpp>
#include <stan/math/prim/functor/apply_vector_unary.hpp>
Expand All @@ -17,6 +18,7 @@ namespace math {
* @tparam T `std::vector` whose scalar type is `fvar`
* @param x container of vectors to transform
* @return container of softmax results
* @throw std::invalid_argument if any input vector is empty
*/
template <typename T, require_std_vector_st<is_fvar, T>* = nullptr>
inline auto softmax(T&& x) {
Expand All @@ -31,16 +33,15 @@ inline auto softmax(T&& x) {
* @tparam Vec Eigen vector with `fvar` scalar
* @param x vector to transform
* @return softmax of the vector
* @throw std::invalid_argument if the input size is 0
*/
template <typename Vec, require_eigen_vector_vt<is_fvar, Vec>* = nullptr>
inline auto softmax(Vec&& x) {
using vec = std::decay_t<Vec>;
constexpr int Rows = vec::RowsAtCompileTime;
constexpr int Cols = vec::ColsAtCompileTime;
using T = typename value_type_t<vec>::Scalar;
if (x.size() == 0) {
return Eigen::Matrix<fvar<T>, Rows, Cols>();
}
check_nonzero_size("softmax", "x", x);
decltype(auto) x_ref = to_ref(std::forward<Vec>(x));
const auto s = softmax(value_of(x_ref));
const auto d_in = x_ref.d();
Expand Down
5 changes: 2 additions & 3 deletions stan/math/opencl/prim/softmax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <stan/math/opencl/ref_type.hpp>
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err/check_matching_sizes.hpp>
#include <stan/math/prim/err/check_nonzero_size.hpp>
#include <stan/math/prim/fun/to_ref.hpp>

namespace stan {
Expand All @@ -22,9 +23,7 @@ template <typename T,
require_all_kernel_expressions_and_none_scalar_t<T>* = nullptr>
inline matrix_cl<double> softmax(const T& a) {
check_vector("softmax (OpenCL)", "a", a);
if (a.size() == 0) {
return a;
}
check_nonzero_size("softmax", "a", a);
Comment thread
jachymb marked this conversation as resolved.
Outdated
matrix_cl<double> theta;
if constexpr (stan::internal::is_trivial_kg_expression<T>::value) {
matrix_cl<double> a_max = max_2d(a);
Expand Down
5 changes: 2 additions & 3 deletions stan/math/opencl/rev/softmax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <stan/math/opencl/prim/dot_product.hpp>
#include <stan/math/opencl/prim/softmax.hpp>
#include <stan/math/opencl/kernel_generator.hpp>
#include <stan/math/prim/err/check_nonzero_size.hpp>
#include <stan/math/rev/core.hpp>
#include <stan/math/rev/fun/value_of.hpp>

Expand All @@ -22,9 +23,7 @@ namespace math {
template <typename T,
require_all_kernel_expressions_and_none_scalar_t<T>* = nullptr>
inline var_value<matrix_cl<double>> softmax(const var_value<T>& A) {
if (A.size() == 0) {
return A;
}
check_nonzero_size("softmax", "A", A);
return make_callback_var(
softmax(A.val()), [A](vari_value<matrix_cl<double>>& res) mutable {
A.adj() += elt_multiply(
Expand Down
10 changes: 6 additions & 4 deletions stan/math/prim/fun/log_softmax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ namespace math {
*
* @tparam Container type of input: an Eigen vector, `std::vector` of doubles,
* or nested container whose scalar type is arithmetic
* @param[in] x vector or container of vectors to transform
* @param x vector or container of vectors to transform
* @return log softmax of the input, preserving the container structure
* @throw std::domain_error if any input vector is empty
* @throw std::invalid_argument if any input vector is empty
*/
template <typename Container, require_st_arithmetic<Container>* = nullptr,
require_container_t<Container>* = nullptr,
Expand All @@ -51,8 +51,10 @@ inline auto log_softmax(Container&& x) {
return make_holder(
[](auto&& a) {
return apply_vector_unary<ref_type_t<Container>>::apply(
std::forward<decltype(a)>(a),
[](auto&& v) { return v.array() - log_sum_exp(v); });
std::forward<decltype(a)>(a), [](auto&& v) {
check_nonzero_size("log_softmax", "v", v);
return v.array() - log_sum_exp(v);
});
},
to_ref(std::forward<Container>(x)));
}
Expand Down
54 changes: 24 additions & 30 deletions stan/math/prim/fun/softmax.hpp
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
#ifndef STAN_MATH_PRIM_FUN_SOFTMAX_HPP
#define STAN_MATH_PRIM_FUN_SOFTMAX_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <stan/math/prim/functor/apply_vector_unary.hpp>
#include <cmath>

namespace stan {
namespace math {

/**
* Return the softmax of the specified vector.
* Return the softmax of the specified vector, or of each vector in a container.
*
* <p>
* \f$
* \mbox{softmax}(y)
* = \frac{\exp(y)}
Expand All @@ -39,36 +38,31 @@ namespace math {
* \end{array}
* \f$
*
* @tparam Vec type of the input vector
* @param[in] v Vector to transform.
* @return Unit simplex result of the softmax transform of the vector.
* @tparam Container type of input: an Eigen vector, `std::vector` of doubles,
* or nested container whose scalar type is arithmetic
* @param x vector or container of vectors to transform
* @return softmax of the input, preserving the container structure
* @throw std::invalid_argument if any input vector is empty
*/
template <typename Vec,
require_eigen_vector_vt<std::is_arithmetic, Vec>* = nullptr>
inline plain_type_t<Vec> softmax(Vec&& v) {
if (v.size() == 0) {
return v;
}
decltype(auto) v_ref = to_ref(std::forward<Vec>(v));
const auto theta = (v_ref.array() - v_ref.maxCoeff()).exp();
return (theta / theta.sum()).matrix();
}

/**
* Return the softmax of each vector in an array.
*
* @tparam T `std::vector` whose scalar type is arithmetic
* @param[in] x Array of vectors to transform.
* @return Array of unit simplex results.
*/
template <typename T, require_std_vector_st<std::is_arithmetic, T>* = nullptr>
inline auto softmax(T&& x) {
return apply_vector_unary<T>::apply(std::forward<T>(x), [](auto&& v) {
return softmax(std::forward<decltype(v)>(v));
});
template <typename Container, require_st_arithmetic<Container>* = nullptr,
require_container_t<Container>* = nullptr,
require_not_t<bool_constant<
is_eigen<std::decay_t<Container>>::value
&& !is_eigen_vector<std::decay_t<Container>>::value>>* = nullptr>
inline auto softmax(Container&& x) {
check_nonzero_size("softmax", "x", x);
return make_holder(
[](auto&& a) {
return apply_vector_unary<ref_type_t<Container>>::apply(
std::forward<decltype(a)>(a), [](auto&& v) {
check_nonzero_size("softmax", "v", v);
const auto theta = (v.array() - v.maxCoeff()).exp();
return (theta / theta.sum()).matrix();
});
},
to_ref(std::forward<Container>(x)));
}

} // namespace math
} // namespace stan

#endif
4 changes: 2 additions & 2 deletions stan/math/rev/fun/log_softmax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace math {
* @tparam T a `var_value` or Eigen vector/row_vector with `var` scalar
* @param x input
* @return log softmax of the input
* @throw std::domain_error if the input size is 0
* @throw std::invalid_argument if the input size is 0
*/
template <typename T, require_rev_matrix_t<T>* = nullptr>
inline auto log_softmax(T&& x) {
Expand All @@ -42,7 +42,7 @@ inline auto log_softmax(T&& x) {
* @tparam T `std::vector` whose scalar type is `var`
* @param x array of vectors to transform
* @return array of log softmax results
* @throw std::domain_error if any element size is 0
* @throw std::invalid_argument if any input vector is empty
*/
template <typename T, require_std_vector_st<is_var, T>* = nullptr>
inline auto log_softmax(T&& x) {
Expand Down
7 changes: 4 additions & 3 deletions stan/math/rev/fun/softmax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <stan/math/rev/core/reverse_pass_callback.hpp>
#include <stan/math/rev/core/arena_matrix.hpp>
#include <stan/math/rev/fun/to_arena.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <stan/math/prim/fun/softmax.hpp>
#include <stan/math/prim/functor/apply_vector_unary.hpp>
Expand All @@ -19,15 +20,14 @@ namespace math {
* @tparam T a `var_value` or Eigen vector/row_vector with `var` scalar
* @param x input
* @return softmax of the input
* @throw std::invalid_argument if the input size is 0
*/
template <typename T, require_rev_matrix_t<T>* = nullptr>
inline auto softmax(T&& x) {
check_nonzero_size("softmax", "x", x);
auto x_arena = to_arena(std::forward<T>(x));
using return_t
= return_var_matrix_t<plain_type_t<decltype(x_arena.val())>, T>;
if (x_arena.size() == 0) {
return x_arena;
}
arena_t<return_t> res = softmax(x_arena.val());
reverse_pass_callback([x_arena, res]() mutable {
x_arena.adj().array()
Expand All @@ -42,6 +42,7 @@ inline auto softmax(T&& x) {
* @tparam T `std::vector` whose scalar type is `var`
* @param x array of vectors to transform
* @return array of softmax results
* @throw std::invalid_argument if any input vector is empty
*/
template <typename T, require_std_vector_st<is_var, T>* = nullptr>
inline auto softmax(T&& x) {
Expand Down
4 changes: 2 additions & 2 deletions test/unit/math/mix/fun/softmax_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ TEST(MathMixMatFun, softmax) {
tols.hessian_fvar_hessian_ = 1e-2;

// Column vectors
Eigen::VectorXd a(0);
Eigen::VectorXd a(0); // error case
stan::test::expect_ad(tols, f, a);
expect_ad_matvar(f, a);
Eigen::VectorXd b(1);
Expand Down Expand Up @@ -44,7 +44,7 @@ TEST(MathMixMatFun, softmax) {
expect_ad_matvar(f, d4);

// Row vectors
Eigen::RowVectorXd ra(0);
Eigen::RowVectorXd ra(0); // error case
stan::test::expect_ad(tols, f, ra);
expect_ad_matvar(f, ra);

Expand Down
5 changes: 5 additions & 0 deletions test/unit/math/opencl/rev/log_softmax_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ TEST(OpenCLLogSoftmax, prim_rev_size_1) {
stan::math::test::compare_cpu_opencl_prim_rev(log_softmax_functor, a);
}

TEST(OpenCLLogSoftmax, prim_rev_size_0_throws) {
Eigen::VectorXd a(0);
EXPECT_THROW(stan::math::log_softmax(a), std::invalid_argument);
}

TEST(OpenCLLogSoftmax, prim_rev_values_large) {
int N = 71;

Expand Down
8 changes: 3 additions & 5 deletions test/unit/math/opencl/rev/softmax_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@ TEST(OpenCLSoftmax, prim_rev_values_small) {
stan::math::test::compare_cpu_opencl_prim_rev(softmax_functor, a);
}

TEST(OpenCLSoftmax, prim_rev_size_0) {
int N = 0;

Eigen::VectorXd a(N);
stan::math::test::compare_cpu_opencl_prim_rev(softmax_functor, a);
TEST(OpenCLSoftmax, prim_rev_size_0_throws) {
Eigen::VectorXd a(0);
EXPECT_THROW(stan::math::softmax(a), std::invalid_argument);
}

TEST(OpenCLSoftmax, prim_rev_values_large) {
Expand Down
7 changes: 7 additions & 0 deletions test/unit/math/prim/fun/softmax_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ TEST(MathMatrixPrimMat, softmax_neg_inf) {
EXPECT_FLOAT_EQ(1.0, theta.sum());
}

TEST(MathMatrixPrimMat, softmax_exception) {
using stan::math::softmax;
Eigen::Matrix<double, Eigen::Dynamic, 1> v0; // size == 0

EXPECT_THROW(softmax(v0), std::invalid_argument);
}

TEST(MathMatrixPrimMat, softmax_row_vector) {
using Eigen::Dynamic;
using Eigen::Matrix;
Expand Down
Loading