more methods
This commit is contained in:
parent
12fd46747e
commit
e002fd91d7
2 changed files with 108 additions and 38 deletions
124
include/bmath.hh
124
include/bmath.hh
|
|
@ -19,37 +19,79 @@ class mint {
|
||||||
[[nodiscard]] constexpr explicit mint(T t) : value{t % M} {}
|
[[nodiscard]] constexpr explicit mint(T t) : value{t % M} {}
|
||||||
[[nodiscard]] constexpr explicit operator T() const noexcept { return value; }
|
[[nodiscard]] constexpr explicit operator T() const noexcept { return value; }
|
||||||
|
|
||||||
template <std::convertible_to<T> OtherT, T OtherM>
|
[[nodiscard]] constexpr mint<T, M> operator+(
|
||||||
requires(M == OtherM)
|
mint const other) const noexcept {
|
||||||
[[nodiscard]] constexpr mint operator+(
|
T result{add(static_cast<T>(other.get()))};
|
||||||
mint<OtherT, OtherM> const other) const noexcept {
|
|
||||||
if constexpr (M != 0 && OtherM != 0) {
|
return mint<T, M>{result};
|
||||||
static_assert(M == OtherM,
|
}
|
||||||
"Cannot add integral types with differing moduli");
|
[[nodiscard]] constexpr mint& operator+=(mint const other) const noexcept {
|
||||||
} else if (M != 0 && OtherM != 0) {
|
value += other.value;
|
||||||
assert(M == OtherM && "Cannot add integral types with differing moduli");
|
|
||||||
}
|
return *this;
|
||||||
return mint<T, M>{value + other.get()};
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] constexpr mint<T, M> operator-(
|
||||||
|
mint const other) const noexcept {
|
||||||
|
T result{sub(static_cast<T>(other.get()))};
|
||||||
|
|
||||||
|
return mint<T, M>{result};
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] constexpr mint operator*(mint const other) const noexcept {
|
||||||
|
T result{mul(static_cast<T>(other.get()))};
|
||||||
|
|
||||||
|
return mint<T, M>{result};
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] constexpr mint& operator*=(mint const other) noexcept {
|
||||||
|
value *= other.value;
|
||||||
|
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] constexpr mint operator/(mint const other) const noexcept {
|
||||||
|
if constexpr (other.get() == 0) {
|
||||||
|
static_assert(false, "Cannot divide by 0");
|
||||||
|
} else if (other.get() == 0) {
|
||||||
|
throw std::domain_error("Cannot divide by 0");
|
||||||
|
}
|
||||||
|
|
||||||
|
T result{div(static_cast<T>(other.get()))};
|
||||||
|
|
||||||
|
return mint<T, M>{result};
|
||||||
|
}
|
||||||
|
[[nodiscard]] constexpr mint& operator/=(mint const other) noexcept {
|
||||||
|
if constexpr (other.get() == 0) {
|
||||||
|
static_assert(false, "Cannot divide by 0");
|
||||||
|
} else if (other.get() == 0) {
|
||||||
|
throw std::domain_error("Cannot divide by 0");
|
||||||
|
}
|
||||||
|
|
||||||
|
value /= other.value;
|
||||||
|
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] constexpr mint operator%(mint const other) const noexcept {
|
||||||
|
T result = mod(other);
|
||||||
|
|
||||||
|
return mint<T, M>{result};
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] constexpr mint& operator%=(mint const other) const noexcept {
|
||||||
|
value %= other.value;
|
||||||
|
|
||||||
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
// [[nodiscard]] constexpr mint& operator-(mint const other) const noexcept {}
|
|
||||||
// [[nodiscard]] constexpr mint& operator*(mint const other) const noexcept {}
|
|
||||||
// [[nodiscard]] constexpr mint& operator/(mint const other) const noexcept {}
|
|
||||||
// [[nodiscard]] constexpr mint& operator%(mint const other) const noexcept {}
|
|
||||||
[[nodiscard]] constexpr T get() const noexcept { return value; }
|
[[nodiscard]] constexpr T get() const noexcept { return value; }
|
||||||
|
|
||||||
template <std::convertible_to<T> OtherT, OtherT OtherM>
|
[[nodiscard]] constexpr bool operator==(mint const other) const noexcept {
|
||||||
[[nodiscard]] constexpr bool operator==(
|
|
||||||
mint<OtherT, OtherM> const other) const noexcept {
|
|
||||||
return get() == static_cast<T>(other.get());
|
return get() == static_cast<T>(other.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
template <std::convertible_to<T> OtherT>
|
template <std::convertible_to<T> OtherT, OtherT OtherM>
|
||||||
[[nodiscard]] constexpr bool operator==(OtherT const other) const noexcept {
|
|
||||||
return get() == static_cast<T>(other);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <std::convertible_to<T> OtherT, T OtherM>
|
|
||||||
friend std::ostream& operator<<(std::ostream& out,
|
friend std::ostream& operator<<(std::ostream& out,
|
||||||
mint<OtherT, OtherM> const other) {
|
mint<OtherT, OtherM> const other) {
|
||||||
return out << other.get();
|
return out << other.get();
|
||||||
|
|
@ -57,10 +99,30 @@ class mint {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
T value{};
|
T value{};
|
||||||
|
|
||||||
|
[[nodiscard]] constexpr T add(T other) const noexcept {
|
||||||
|
return (get() + other) % M;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] constexpr T sub(T other) const noexcept {
|
||||||
|
return (get() - other + M) % M;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] constexpr T mul(T other) const noexcept {
|
||||||
|
return (get() * other) % M;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] constexpr T div(T other) const noexcept {
|
||||||
|
return get() / other;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] constexpr T mod(T other) const noexcept {
|
||||||
|
return get() % other;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <std::integral T, T M, std::integral U>
|
template <std::integral T, T M, std::integral U>
|
||||||
requires(M >= 0)
|
requires(M > 0)
|
||||||
[[nodiscard]] static constexpr bmath::mint<T, M> pow(mint<T, M> base,
|
[[nodiscard]] static constexpr bmath::mint<T, M> pow(mint<T, M> base,
|
||||||
U exponent) {
|
U exponent) {
|
||||||
if (exponent < 0) {
|
if (exponent < 0) {
|
||||||
|
|
@ -69,7 +131,7 @@ template <std::integral T, T M, std::integral U>
|
||||||
base, exponent, exponent));
|
base, exponent, exponent));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (base == 0) {
|
if (base.get() == 0) {
|
||||||
if (exponent == 0) {
|
if (exponent == 0) {
|
||||||
throw std::domain_error("pow(0, 0) is indeterminate");
|
throw std::domain_error("pow(0, 0) is indeterminate");
|
||||||
}
|
}
|
||||||
|
|
@ -77,17 +139,21 @@ template <std::integral T, T M, std::integral U>
|
||||||
return mint<T, M>{0};
|
return mint<T, M>{0};
|
||||||
}
|
}
|
||||||
|
|
||||||
if (base == 1 || exponent == 0) {
|
if (base.get() == 1 || exponent == 0) {
|
||||||
return mint<T, M>{1};
|
return mint<T, M>{1};
|
||||||
}
|
}
|
||||||
|
|
||||||
T t{};
|
mint<T, M> t{1};
|
||||||
|
|
||||||
while (exponent > 0) {
|
while (exponent > 0) {
|
||||||
|
if (exponent & 1) {
|
||||||
|
t *= base;
|
||||||
|
}
|
||||||
|
base *= base;
|
||||||
exponent >>= 1;
|
exponent >>= 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
return mint<T, M>{t};
|
return t;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <std::integral T, T M = DEFAULT_MOD>
|
template <std::integral T, T M = DEFAULT_MOD>
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <iostream>
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
#include "../include/bmath.hh"
|
#include "../include/bmath.hh"
|
||||||
|
|
||||||
|
|
@ -13,20 +13,24 @@ int main() {
|
||||||
constexpr mint<uint64_t> mintfive{five};
|
constexpr mint<uint64_t> mintfive{five};
|
||||||
|
|
||||||
constexpr auto mintnine = mintfour + mintfive;
|
constexpr auto mintnine = mintfour + mintfive;
|
||||||
static_assert(mintnine == four + five);
|
// static_assert(mintnine == four + five);
|
||||||
|
|
||||||
// static_assert(4 + 5 == mint<uint64_t>{9});
|
// static_assert(4 + 5 == mint<uint64_t>{9});
|
||||||
static_assert(mint<uint64_t, 100000>{8} == 4 + 4);
|
static_assert(mint<uint64_t, 100000>{8} ==
|
||||||
|
mint<uint64_t, 100000>{4} + mint<uint64_t, 100000>{4});
|
||||||
|
|
||||||
static_assert(is_trivially_copyable_v<mint<uint64_t>>);
|
static_assert(is_trivially_copyable_v<mint<uint64_t>>);
|
||||||
|
|
||||||
pow(mint<int>{2}, 0);
|
// pow(mint<int>{2}, 0);
|
||||||
|
|
||||||
// cout << (std::format("x: {}\n", mintfour));
|
// cout << (std::format("x: {}\n", mintfour));
|
||||||
|
|
||||||
// auto res = mint<int>{4} + mint<int, 5>{4};
|
// auto res = mint<int>{4} + mint<int, 5>{4};
|
||||||
|
|
||||||
cout << (mint<int, 5>{5} + mint<int, 5>{3});
|
// cout << (mint<int, 4>{5} + mint<int, 4>{7});
|
||||||
|
// cout << pow(mint<int, 5>{4}, 5);
|
||||||
|
|
||||||
|
cout << (pow(mint<int>{5}, 5));
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue