From 12fd46747e6c0c605d6611f64a946ef7ff9d79bb Mon Sep 17 00:00:00 2001 From: Barrett Ruth Date: Sat, 30 Aug 2025 17:12:46 -0500 Subject: [PATCH] some more changes --- compile_flags.txt | 30 +++++++++++ include/bmath.hh | 128 +++++++++++++++++++++++++++++++++++----------- makefile | 2 + tests/test_add.cc | 17 ++++-- 4 files changed, 143 insertions(+), 34 deletions(-) create mode 100644 compile_flags.txt diff --git a/compile_flags.txt b/compile_flags.txt new file mode 100644 index 0000000..541be64 --- /dev/null +++ b/compile_flags.txt @@ -0,0 +1,30 @@ +-pedantic-errors +-O2 +-Wall +-Wextra +-Wpedantic +-Wshadow +-Wformat=2 +-Wfloat-equal +-Wlogical-op +-Wshift-overflow=2 +-Wnon-virtual-dtor +-Wold-style-cast +-Wcast-qual +-Wuseless-cast +-Wno-sign-promotion +-Wcast-align +-Wunused +-Woverloaded-virtual +-Wconversion +-Wmisleading-indentation +-Wduplicated-cond +-Wduplicated-branches +-Wlogical-op +-Wnull-dereference +-Wformat=2 +-Wformat-overflow +-Wformat-truncation +-Wdouble-promotion +-Wundef +-DLOCAL diff --git a/include/bmath.hh b/include/bmath.hh index 6b787c1..57d60e8 100644 --- a/include/bmath.hh +++ b/include/bmath.hh @@ -1,55 +1,121 @@ #ifndef BMATH_HEADER_ONLY_MATH_LIB #define BMATH_HEADER_ONLY_MATH_LIB -#include -#include -#include +#include +#include +#include +#include +#include namespace bmath { -template - requires(std::integral || std::is_same_v) +inline constexpr uint64_t DEFAULT_MOD = 1'000'000'007; + +template (DEFAULT_MOD)> + requires(M > 0 && DEFAULT_MOD <= std::numeric_limits::max()) class mint { public: [[nodiscard]] constexpr explicit mint() : value{} {} - [[nodiscard]] constexpr explicit mint(IntegralType _value) : value{_value} {} - [[nodiscard]] constexpr explicit operator IntegralType() const noexcept { - return value; - } + [[nodiscard]] constexpr explicit mint(T t) : value{t % M} {} + [[nodiscard]] constexpr explicit operator T() const noexcept { return value; } - template + template OtherT, T OtherM> + requires(M == OtherM) [[nodiscard]] constexpr mint operator+( - mint const& otherMint) const noexcept { - return mint{value + otherMint.value}; + mint const other) const noexcept { + if constexpr (M != 0 && OtherM != 0) { + static_assert(M == OtherM, + "Cannot add integral types with differing moduli"); + } else if (M != 0 && OtherM != 0) { + assert(M == OtherM && "Cannot add integral types with differing moduli"); + } + return mint{value + other.get()}; } - [[nodiscard]] constexpr mint& operator-(mint const& other) const noexcept {} - [[nodiscard]] constexpr mint& operator*(mint const& other) const noexcept {} - [[nodiscard]] constexpr mint& operator/(mint const& other) const noexcept {} - [[nodiscard]] constexpr mint& operator%(mint const& other) const noexcept {} + // [[nodiscard]] constexpr mint& operator-(mint const other) const noexcept {} + // [[nodiscard]] constexpr mint& operator*(mint const other) const noexcept {} + // [[nodiscard]] constexpr mint& operator/(mint const other) const noexcept {} + // [[nodiscard]] constexpr mint& operator%(mint const other) const noexcept {} + [[nodiscard]] constexpr T get() const noexcept { return value; } - friend constexpr bool operator==(std::convertible_to auto a, - mint const& b) noexcept { - return static_cast(a) == b.value; - } - - friend constexpr bool operator==( - mint const& a, std::convertible_to auto b) noexcept { - return a.value == static_cast(b); - } - - template + template OtherT, OtherT OtherM> [[nodiscard]] constexpr bool operator==( - mint const& otherMint) const noexcept { - return value == otherMint.value; + mint const other) const noexcept { + return get() == static_cast(other.get()); } - [[nodiscard]] constexpr IntegralType get() const { return value; } + template OtherT> + [[nodiscard]] constexpr bool operator==(OtherT const other) const noexcept { + return get() == static_cast(other); + } + + template OtherT, T OtherM> + friend std::ostream& operator<<(std::ostream& out, + mint const other) { + return out << other.get(); + } private: - IntegralType value{}; + T value{}; }; +template + requires(M >= 0) +[[nodiscard]] static constexpr bmath::mint pow(mint base, + U exponent) { + if (exponent < 0) { + throw std::domain_error( + std::format("cannot compute pow({}, {}) with negative exponent {}", + base, exponent, exponent)); + } + + if (base == 0) { + if (exponent == 0) { + throw std::domain_error("pow(0, 0) is indeterminate"); + } + + return mint{0}; + } + + if (base == 1 || exponent == 0) { + return mint{1}; + } + + T t{}; + + while (exponent > 0) { + exponent >>= 1; + } + + return mint{t}; +} + +template + requires(M > 0) +[[nodiscard]] std::string to_string(mint const number) { + return std::to_string(number.get()); +} + } // namespace bmath +template +struct std::formatter, CharT> { + std::formatter, CharT> inner; + + constexpr auto parse(std::basic_format_parse_context& pc) { + return inner.parse(pc); + } + + template + auto format(bmath::mint const x, Ctx& ctx) const { + std::basic_string tmp; + if constexpr (std::same_as) { + std::format_to(std::back_inserter(tmp), L"{}", x.get()); + } else { + std::format_to(std::back_inserter(tmp), "{}", x.get()); + } + return inner.format(tmp, ctx); + } +}; + #endif diff --git a/makefile b/makefile index 6dcde75..3c4acfa 100644 --- a/makefile +++ b/makefile @@ -57,6 +57,7 @@ DBG_LDFLAGS := \ -fsanitize=float-cast-overflow CXXFLAGS := $(STD) $(WARNFLAGS) $(INCLUDES) $(BASEDEFS) +CXXFLAGS += -MMD -MP ifeq ($(MODE),debug) CXXFLAGS += $(DBGFLAGS) LDFLAGS += $(DBG_LDFLAGS) @@ -87,6 +88,7 @@ test: .TEST $(BUILD_DIR)/$(TEST_DIR)/%.o: $(TEST_DIR)/%.cc @mkdir -p $(dir $@) $(CXX) $(CXXFLAGS) -c $< -o $@ +-include $(TEST_OBJ:.o=.d) $(TEST_EXE): $(TEST_OBJ) @mkdir -p $(dir $@) diff --git a/tests/test_add.cc b/tests/test_add.cc index deccf5c..f35df1d 100644 --- a/tests/test_add.cc +++ b/tests/test_add.cc @@ -1,9 +1,11 @@ #include +#include #include #include "../include/bmath.hh" using namespace bmath; +using namespace std; int main() { constexpr uint64_t four{4}, five{5}; @@ -11,11 +13,20 @@ int main() { constexpr mint mintfive{five}; constexpr auto mintnine = mintfour + mintfive; - static_assert(four + five == mintnine.get()); static_assert(mintnine == four + five); - static_assert(4 + 5 == mint{9}); - static_assert(mint{8} == 4 + 3); + // static_assert(4 + 5 == mint{9}); + static_assert(mint{8} == 4 + 4); + + static_assert(is_trivially_copyable_v>); + + pow(mint{2}, 0); + + // cout << (std::format("x: {}\n", mintfour)); + + // auto res = mint{4} + mint{4}; + + cout << (mint{5} + mint{3}); return 0; }