diff --git a/include/bmath.hh b/include/bmath.hh index 57d60e8..8422c85 100644 --- a/include/bmath.hh +++ b/include/bmath.hh @@ -19,37 +19,79 @@ class mint { [[nodiscard]] constexpr explicit mint(T t) : value{t % M} {} [[nodiscard]] constexpr explicit operator T() const noexcept { return value; } - template OtherT, T OtherM> - requires(M == OtherM) - [[nodiscard]] constexpr mint operator+( - mint const other) const noexcept { - if constexpr (M != 0 && OtherM != 0) { - static_assert(M == OtherM, - "Cannot add integral types with differing moduli"); - } else if (M != 0 && OtherM != 0) { - assert(M == OtherM && "Cannot add integral types with differing moduli"); - } - return mint{value + other.get()}; + [[nodiscard]] constexpr mint operator+( + mint const other) const noexcept { + T result{add(static_cast(other.get()))}; + + return mint{result}; + } + [[nodiscard]] constexpr mint& operator+=(mint const other) const noexcept { + value += other.value; + + return *this; + } + + [[nodiscard]] constexpr mint operator-( + mint const other) const noexcept { + T result{sub(static_cast(other.get()))}; + + return mint{result}; + } + + [[nodiscard]] constexpr mint operator*(mint const other) const noexcept { + T result{mul(static_cast(other.get()))}; + + return mint{result}; + } + + [[nodiscard]] constexpr mint& operator*=(mint const other) noexcept { + value *= other.value; + + return *this; + } + + [[nodiscard]] constexpr mint operator/(mint const other) const noexcept { + if constexpr (other.get() == 0) { + static_assert(false, "Cannot divide by 0"); + } else if (other.get() == 0) { + throw std::domain_error("Cannot divide by 0"); + } + + T result{div(static_cast(other.get()))}; + + return mint{result}; + } + [[nodiscard]] constexpr mint& operator/=(mint const other) noexcept { + if constexpr (other.get() == 0) { + static_assert(false, "Cannot divide by 0"); + } else if (other.get() == 0) { + throw std::domain_error("Cannot divide by 0"); + } + + value /= other.value; + + return *this; + } + + [[nodiscard]] constexpr mint operator%(mint const other) const noexcept { + T result = mod(other); + + return mint{result}; + } + + [[nodiscard]] constexpr mint& operator%=(mint const other) const noexcept { + value %= other.value; + + return *this; } - // [[nodiscard]] constexpr mint& operator-(mint const other) const noexcept {} - // [[nodiscard]] constexpr mint& operator*(mint const other) const noexcept {} - // [[nodiscard]] constexpr mint& operator/(mint const other) const noexcept {} - // [[nodiscard]] constexpr mint& operator%(mint const other) const noexcept {} [[nodiscard]] constexpr T get() const noexcept { return value; } - template OtherT, OtherT OtherM> - [[nodiscard]] constexpr bool operator==( - mint const other) const noexcept { + [[nodiscard]] constexpr bool operator==(mint const other) const noexcept { return get() == static_cast(other.get()); } - template OtherT> - [[nodiscard]] constexpr bool operator==(OtherT const other) const noexcept { - return get() == static_cast(other); - } - - template OtherT, T OtherM> + template OtherT, OtherT OtherM> friend std::ostream& operator<<(std::ostream& out, mint const other) { return out << other.get(); @@ -57,10 +99,30 @@ class mint { private: T value{}; + + [[nodiscard]] constexpr T add(T other) const noexcept { + return (get() + other) % M; + } + + [[nodiscard]] constexpr T sub(T other) const noexcept { + return (get() - other + M) % M; + } + + [[nodiscard]] constexpr T mul(T other) const noexcept { + return (get() * other) % M; + } + + [[nodiscard]] constexpr T div(T other) const noexcept { + return get() / other; + } + + [[nodiscard]] constexpr T mod(T other) const noexcept { + return get() % other; + } }; template - requires(M >= 0) + requires(M > 0) [[nodiscard]] static constexpr bmath::mint pow(mint base, U exponent) { if (exponent < 0) { @@ -69,7 +131,7 @@ template base, exponent, exponent)); } - if (base == 0) { + if (base.get() == 0) { if (exponent == 0) { throw std::domain_error("pow(0, 0) is indeterminate"); } @@ -77,17 +139,21 @@ template return mint{0}; } - if (base == 1 || exponent == 0) { + if (base.get() == 1 || exponent == 0) { return mint{1}; } - T t{}; + mint t{1}; while (exponent > 0) { + if (exponent & 1) { + t *= base; + } + base *= base; exponent >>= 1; } - return mint{t}; + return t; } template diff --git a/tests/test_add.cc b/tests/test_add.cc index f35df1d..3377ea5 100644 --- a/tests/test_add.cc +++ b/tests/test_add.cc @@ -1,6 +1,6 @@ #include -#include #include +#include #include "../include/bmath.hh" @@ -13,20 +13,24 @@ int main() { constexpr mint mintfive{five}; constexpr auto mintnine = mintfour + mintfive; - static_assert(mintnine == four + five); + // static_assert(mintnine == four + five); - // static_assert(4 + 5 == mint{9}); - static_assert(mint{8} == 4 + 4); + // static_assert(4 + 5 == mint{9}); + static_assert(mint{8} == + mint{4} + mint{4}); - static_assert(is_trivially_copyable_v>); + static_assert(is_trivially_copyable_v>); - pow(mint{2}, 0); + // pow(mint{2}, 0); - // cout << (std::format("x: {}\n", mintfour)); + // cout << (std::format("x: {}\n", mintfour)); - // auto res = mint{4} + mint{4}; + // auto res = mint{4} + mint{4}; - cout << (mint{5} + mint{3}); + // cout << (mint{5} + mint{7}); + // cout << pow(mint{4}, 5); + + cout << (pow(mint{5}, 5)); return 0; }