diff --git a/lua/cp/pickers/init.lua b/lua/cp/pickers/init.lua index 77c0685..2380b74 100644 --- a/lua/cp/pickers/init.lua +++ b/lua/cp/pickers/init.lua @@ -1,6 +1,7 @@ local M = {} local cache = require('cp.cache') +local config = require('cp.config').get_config() local logger = require('cp.log') local utils = require('cp.utils') @@ -18,26 +19,28 @@ local utils = require('cp.utils') ---@field name string Problem name (e.g. "Two Permutations", "Painting Walls") ---@field display_name string Formatted display name for picker ----Get list of available competitive programming platforms ---@return cp.PlatformItem[] local function get_platforms() local constants = require('cp.constants') - return vim.tbl_map(function(platform) - return { - id = platform, - display_name = constants.PLATFORM_DISPLAY_NAMES[platform] or platform, - } - end, constants.PLATFORMS) + local result = {} + + for _, platform in ipairs(constants.PLATFORMS) do + if config.contests[platform] then + table.insert(result, { + id = platform, + display_name = constants.PLATFORM_DISPLAY_NAMES[platform] or platform, + }) + end + end + + return result end ---Get list of contests for a specific platform ---@param platform string Platform identifier (e.g. "codeforces", "atcoder") ---@return cp.ContestItem[] local function get_contests_for_platform(platform) - local constants = require('cp.constants') - local platform_display_name = constants.PLATFORM_DISPLAY_NAMES[platform] or platform - - logger.log(('loading %s contests...'):format(platform_display_name), vim.log.levels.INFO, true) + logger.log('loading contests...', vim.log.levels.INFO, true) cache.load() local cached_contests = cache.get_contest_list(platform) @@ -114,8 +117,6 @@ end ---@param contest_id string Contest identifier ---@return cp.ProblemItem[] local function get_problems_for_contest(platform, contest_id) - local constants = require('cp.constants') - local platform_display_name = constants.PLATFORM_DISPLAY_NAMES[platform] or platform local problems = {} cache.load() @@ -131,11 +132,7 @@ local function get_problems_for_contest(platform, contest_id) return problems end - logger.log( - ('loading %s %s problems...'):format(platform_display_name, contest_id), - vim.log.levels.INFO, - true - ) + logger.log('loading contest problems...', vim.log.levels.INFO, true) if not utils.setup_python_env() then return problems diff --git a/lua/cp/setup.lua b/lua/cp/setup.lua index 48ebed8..21dba2c 100644 --- a/lua/cp/setup.lua +++ b/lua/cp/setup.lua @@ -136,7 +136,7 @@ function M.setup_problem(contest_id, problem_id, language) local source_file = state.get_source_file(language) if not source_file then - error('Failed to generate source file path') + return end vim.cmd.e(source_file) local source_buf = vim.api.nvim_get_current_buf() diff --git a/lua/cp/ui/panel.lua b/lua/cp/ui/panel.lua index 9fe2a9f..c851a04 100644 --- a/lua/cp/ui/panel.lua +++ b/lua/cp/ui/panel.lua @@ -9,10 +9,6 @@ local state = require('cp.state') local current_diff_layout = nil local current_mode = nil -local function get_current_problem() - return state.get_problem_id() -end - function M.toggle_run_panel(is_debug) if state.is_run_panel_active() then if current_diff_layout then @@ -39,7 +35,7 @@ function M.toggle_run_panel(is_debug) return end - local problem_id = get_current_problem() + local problem_id = state.get_problem_id() if not problem_id then return end @@ -49,9 +45,9 @@ function M.toggle_run_panel(is_debug) logger.log( ('run panel: platform=%s, contest=%s, problem=%s'):format( - platform or 'nil', - contest_id or 'nil', - problem_id or 'nil' + tostring(platform), + tostring(contest_id), + tostring(problem_id) ) ) @@ -124,12 +120,7 @@ function M.toggle_run_panel(is_debug) return end - test_state.current_index = test_state.current_index + delta - if test_state.current_index < 1 then - test_state.current_index = #test_state.test_cases - elseif test_state.current_index > #test_state.test_cases then - test_state.current_index = 1 - end + test_state.current_index = (test_state.current_index + delta) % #test_state.test_cases refresh_run_panel() end diff --git a/scrapers/__init__.py b/scrapers/__init__.py index 6140dce..4749123 100644 --- a/scrapers/__init__.py +++ b/scrapers/__init__.py @@ -1,56 +1,5 @@ -# Lazy imports to avoid module loading conflicts when running scrapers with -m -def __getattr__(name): - if name == "AtCoderScraper": - from .atcoder import AtCoderScraper +from .atcoder import AtCoderScraper +from .codeforces import CodeforcesScraper +from .cses import CSESScraper - return AtCoderScraper - elif name == "BaseScraper": - from .base import BaseScraper - - return BaseScraper - elif name == "ScraperConfig": - from .base import ScraperConfig - - return ScraperConfig - elif name == "CodeforcesScraper": - from .codeforces import CodeforcesScraper - - return CodeforcesScraper - elif name == "CSESScraper": - from .cses import CSESScraper - - return CSESScraper - elif name in [ - "ContestListResult", - "ContestSummary", - "MetadataResult", - "ProblemSummary", - "TestCase", - "TestsResult", - ]: - from .models import ( - ContestListResult, # noqa: F401 - ContestSummary, # noqa: F401 - MetadataResult, # noqa: F401 - ProblemSummary, # noqa: F401 - TestCase, # noqa: F401 - TestsResult, # noqa: F401 - ) - - return locals()[name] - raise AttributeError(f"module 'scrapers' has no attribute '{name}'") - - -__all__ = [ - "AtCoderScraper", - "BaseScraper", - "CodeforcesScraper", - "CSESScraper", - "ScraperConfig", - "ContestListResult", - "ContestSummary", - "MetadataResult", - "ProblemSummary", - "TestCase", - "TestsResult", -] +__all__ = ["CodeforcesScraper", "CSESScraper", "AtCoderScraper"] diff --git a/scrapers/cses.py b/scrapers/cses.py old mode 100755 new mode 100644 diff --git a/spec/picker_spec.lua b/spec/picker_spec.lua index e9bb5e2..eeee9d7 100644 --- a/spec/picker_spec.lua +++ b/spec/picker_spec.lua @@ -11,33 +11,6 @@ describe('cp.picker', function() spec_helper.teardown() end) - describe('get_platforms', function() - it('returns platform list with display names', function() - local platforms = picker.get_platforms() - - assert.is_table(platforms) - assert.is_true(#platforms > 0) - - for _, platform in ipairs(platforms) do - assert.is_string(platform.id) - assert.is_string(platform.display_name) - assert.is_not_nil(platform.display_name:match('^%u')) - end - end) - - it('includes expected platforms with correct display names', function() - local platforms = picker.get_platforms() - local platform_map = {} - for _, p in ipairs(platforms) do - platform_map[p.id] = p.display_name - end - - assert.equals('CodeForces', platform_map['codeforces']) - assert.equals('AtCoder', platform_map['atcoder']) - assert.equals('CSES', platform_map['cses']) - end) - end) - describe('get_contests_for_platform', function() it('returns empty list when scraper fails', function() vim.system = function(_, _) diff --git a/tests/scrapers/test_interface_compliance.py b/tests/scrapers/test_interface_compliance.py deleted file mode 100644 index ab07ff2..0000000 --- a/tests/scrapers/test_interface_compliance.py +++ /dev/null @@ -1,167 +0,0 @@ -from unittest.mock import Mock - -import pytest - -import scrapers -from scrapers.base import BaseScraper -from scrapers.models import ContestListResult, MetadataResult, TestsResult - -SCRAPERS = [ - scrapers.AtCoderScraper, - scrapers.CodeforcesScraper, - scrapers.CSESScraper, -] - - -class TestScraperInterfaceCompliance: - @pytest.mark.parametrize("scraper_class", SCRAPERS) - def test_implements_base_interface(self, scraper_class): - scraper = scraper_class() - - assert isinstance(scraper, BaseScraper) - assert hasattr(scraper, "platform_name") - assert hasattr(scraper, "scrape_contest_metadata") - assert hasattr(scraper, "scrape_problem_tests") - assert hasattr(scraper, "scrape_contest_list") - - @pytest.mark.parametrize("scraper_class", SCRAPERS) - def test_platform_name_is_string(self, scraper_class): - scraper = scraper_class() - platform_name = scraper.platform_name - - assert isinstance(platform_name, str) - assert len(platform_name) > 0 - assert platform_name.islower() # Convention: lowercase platform names - - @pytest.mark.parametrize("scraper_class", SCRAPERS) - def test_metadata_method_signature(self, scraper_class, mocker): - scraper = scraper_class() - - # Mock the underlying HTTP calls to avoid network requests - if scraper.platform_name == "codeforces": - mock_scraper = Mock() - mock_response = Mock() - mock_response.text = "A. Test" - mock_scraper.get.return_value = mock_response - mocker.patch( - "scrapers.codeforces.cloudscraper.create_scraper", - return_value=mock_scraper, - ) - - result = scraper.scrape_contest_metadata("test_contest") - - assert isinstance(result, MetadataResult) - assert hasattr(result, "success") - assert hasattr(result, "error") - assert hasattr(result, "problems") - assert hasattr(result, "contest_id") - assert isinstance(result.success, bool) - assert isinstance(result.error, str) - - @pytest.mark.parametrize("scraper_class", SCRAPERS) - def test_problem_tests_method_signature(self, scraper_class, mocker): - scraper = scraper_class() - - if scraper.platform_name == "codeforces": - mock_scraper = Mock() - mock_response = Mock() - mock_response.text = """ -
3
6