more methods

This commit is contained in:
Barrett Ruth 2025-08-30 17:52:20 -05:00
parent 12fd46747e
commit e002fd91d7
2 changed files with 108 additions and 38 deletions

View file

@ -19,37 +19,79 @@ class mint {
[[nodiscard]] constexpr explicit mint(T t) : value{t % M} {}
[[nodiscard]] constexpr explicit operator T() const noexcept { return value; }
template <std::convertible_to<T> OtherT, T OtherM>
requires(M == OtherM)
[[nodiscard]] constexpr mint operator+(
mint<OtherT, OtherM> const other) const noexcept {
if constexpr (M != 0 && OtherM != 0) {
static_assert(M == OtherM,
"Cannot add integral types with differing moduli");
} else if (M != 0 && OtherM != 0) {
assert(M == OtherM && "Cannot add integral types with differing moduli");
}
return mint<T, M>{value + other.get()};
[[nodiscard]] constexpr mint<T, M> operator+(
mint const other) const noexcept {
T result{add(static_cast<T>(other.get()))};
return mint<T, M>{result};
}
[[nodiscard]] constexpr mint& operator+=(mint const other) const noexcept {
value += other.value;
return *this;
}
[[nodiscard]] constexpr mint<T, M> operator-(
mint const other) const noexcept {
T result{sub(static_cast<T>(other.get()))};
return mint<T, M>{result};
}
[[nodiscard]] constexpr mint operator*(mint const other) const noexcept {
T result{mul(static_cast<T>(other.get()))};
return mint<T, M>{result};
}
[[nodiscard]] constexpr mint& operator*=(mint const other) noexcept {
value *= other.value;
return *this;
}
[[nodiscard]] constexpr mint operator/(mint const other) const noexcept {
if constexpr (other.get() == 0) {
static_assert(false, "Cannot divide by 0");
} else if (other.get() == 0) {
throw std::domain_error("Cannot divide by 0");
}
T result{div(static_cast<T>(other.get()))};
return mint<T, M>{result};
}
[[nodiscard]] constexpr mint& operator/=(mint const other) noexcept {
if constexpr (other.get() == 0) {
static_assert(false, "Cannot divide by 0");
} else if (other.get() == 0) {
throw std::domain_error("Cannot divide by 0");
}
value /= other.value;
return *this;
}
[[nodiscard]] constexpr mint operator%(mint const other) const noexcept {
T result = mod(other);
return mint<T, M>{result};
}
[[nodiscard]] constexpr mint& operator%=(mint const other) const noexcept {
value %= other.value;
return *this;
}
// [[nodiscard]] constexpr mint& operator-(mint const other) const noexcept {}
// [[nodiscard]] constexpr mint& operator*(mint const other) const noexcept {}
// [[nodiscard]] constexpr mint& operator/(mint const other) const noexcept {}
// [[nodiscard]] constexpr mint& operator%(mint const other) const noexcept {}
[[nodiscard]] constexpr T get() const noexcept { return value; }
template <std::convertible_to<T> OtherT, OtherT OtherM>
[[nodiscard]] constexpr bool operator==(
mint<OtherT, OtherM> const other) const noexcept {
[[nodiscard]] constexpr bool operator==(mint const other) const noexcept {
return get() == static_cast<T>(other.get());
}
template <std::convertible_to<T> OtherT>
[[nodiscard]] constexpr bool operator==(OtherT const other) const noexcept {
return get() == static_cast<T>(other);
}
template <std::convertible_to<T> OtherT, T OtherM>
template <std::convertible_to<T> OtherT, OtherT OtherM>
friend std::ostream& operator<<(std::ostream& out,
mint<OtherT, OtherM> const other) {
return out << other.get();
@ -57,10 +99,30 @@ class mint {
private:
T value{};
[[nodiscard]] constexpr T add(T other) const noexcept {
return (get() + other) % M;
}
[[nodiscard]] constexpr T sub(T other) const noexcept {
return (get() - other + M) % M;
}
[[nodiscard]] constexpr T mul(T other) const noexcept {
return (get() * other) % M;
}
[[nodiscard]] constexpr T div(T other) const noexcept {
return get() / other;
}
[[nodiscard]] constexpr T mod(T other) const noexcept {
return get() % other;
}
};
template <std::integral T, T M, std::integral U>
requires(M >= 0)
requires(M > 0)
[[nodiscard]] static constexpr bmath::mint<T, M> pow(mint<T, M> base,
U exponent) {
if (exponent < 0) {
@ -69,7 +131,7 @@ template <std::integral T, T M, std::integral U>
base, exponent, exponent));
}
if (base == 0) {
if (base.get() == 0) {
if (exponent == 0) {
throw std::domain_error("pow(0, 0) is indeterminate");
}
@ -77,17 +139,21 @@ template <std::integral T, T M, std::integral U>
return mint<T, M>{0};
}
if (base == 1 || exponent == 0) {
if (base.get() == 1 || exponent == 0) {
return mint<T, M>{1};
}
T t{};
mint<T, M> t{1};
while (exponent > 0) {
if (exponent & 1) {
t *= base;
}
base *= base;
exponent >>= 1;
}
return mint<T, M>{t};
return t;
}
template <std::integral T, T M = DEFAULT_MOD>