diff --git a/doc/pending.txt b/doc/pending.txt index 739ca88..af6a663 100644 --- a/doc/pending.txt +++ b/doc/pending.txt @@ -1143,6 +1143,21 @@ Shared utilities for backend authors are provided by `sync/util.lua`: Backend-specific configuration goes under `sync.` in |pending-config|. +Auto-auth: ~ + *pending-sync-auto-auth* +Running a sync action (`:Pending push/pull/sync`) without valid +credentials automatically triggers authentication before proceeding: + +- OAuth backends (gcal, gtasks): if real credentials are configured but no + token exists, the browser-based auth flow starts automatically. On + success, the original action continues. Bundled placeholder credentials + cannot auto-auth and require the setup wizard via `:Pending auth`. +- S3: `aws sts get-caller-identity` runs before every sync action. If SSO + is expired, `aws sso login` is triggered automatically. Missing + credentials abort with an error pointing to |pending-s3|. + +On auth failure, the sync action is aborted with an error message. + ============================================================================== GOOGLE CALENDAR *pending-gcal* diff --git a/lua/pending/sync/oauth.lua b/lua/pending/sync/oauth.lua index a64b984..22f4803 100644 --- a/lua/pending/sync/oauth.lua +++ b/lua/pending/sync/oauth.lua @@ -45,8 +45,28 @@ function M.with_token(client, name, callback) util.with_guard(name, function() local token = client:get_access_token() if not token then - require('pending.log').warn(name .. ': Not authenticated — run :Pending auth.') - return + local creds = client:resolve_credentials() + if creds.client_id == BUNDLED_CLIENT_ID then + log.warn(name .. ': No credentials configured — run :Pending auth.') + return + end + log.info(name .. ': Not authenticated — starting auth flow...') + local co = coroutine.running() + client:auth(function(ok) + vim.schedule(function() + coroutine.resume(co, ok) + end) + end) + local auth_ok = coroutine.yield() + if not auth_ok then + log.error(name .. ': Authentication failed.') + return + end + token = client:get_access_token() + if not token then + log.error(name .. ': Still not authenticated after auth flow.') + return + end end callback(token) end) @@ -349,7 +369,7 @@ function OAuthClient:setup() end) end ----@param on_complete? fun(): nil +---@param on_complete? fun(ok: boolean): nil ---@return nil function OAuthClient:auth(on_complete) if _active_close then @@ -360,6 +380,9 @@ function OAuthClient:auth(on_complete) local creds = self:resolve_credentials() if creds.client_id == BUNDLED_CLIENT_ID then log.error(self.name .. ': No credentials configured — run :Pending auth.') + if on_complete then + on_complete(false) + end return end local port = self.port @@ -411,6 +434,9 @@ function OAuthClient:auth(on_complete) if not bind_ok or bind_err == nil then close_server() log.error(self.name .. ': Port ' .. port .. ' already in use — try again in a moment.') + if on_complete then + on_complete(false) + end return end @@ -453,6 +479,9 @@ function OAuthClient:auth(on_complete) if not server_closed then close_server() log.warn(self.name .. ': OAuth callback timed out (120s).') + if on_complete then + on_complete(false) + end end end, 120000) end @@ -461,7 +490,7 @@ end ---@param code string ---@param code_verifier string ---@param port integer ----@param on_complete? fun(): nil +---@param on_complete? fun(ok: boolean): nil ---@return nil function OAuthClient:_exchange_code(creds, code, code_verifier, port, on_complete) local body = 'client_id=' @@ -491,6 +520,9 @@ function OAuthClient:_exchange_code(creds, code, code_verifier, port, on_complet if result.code ~= 0 then self:clear_tokens() log.error(self.name .. ': Token exchange failed.') + if on_complete then + on_complete(false) + end return end @@ -498,6 +530,9 @@ function OAuthClient:_exchange_code(creds, code, code_verifier, port, on_complet if not ok or not decoded.access_token then self:clear_tokens() log.error(self.name .. ': Invalid token response.') + if on_complete then + on_complete(false) + end return end @@ -505,7 +540,7 @@ function OAuthClient:_exchange_code(creds, code, code_verifier, port, on_complet self:save_tokens(decoded) log.info(self.name .. ': Authorized successfully.') if on_complete then - on_complete() + on_complete(true) end end diff --git a/lua/pending/sync/s3.lua b/lua/pending/sync/s3.lua index 0d669b3..91e52c1 100644 --- a/lua/pending/sync/s3.lua +++ b/lua/pending/sync/s3.lua @@ -66,6 +66,35 @@ local function ensure_sync_id(task) return sync_id end +---@return boolean +local function ensure_credentials() + local cmd = base_cmd() + vim.list_extend(cmd, { 'sts', 'get-caller-identity', '--output', 'json' }) + local result = util.system(cmd, { text = true }) + if result.code == 0 then + return true + end + local stderr = result.stderr or '' + if stderr:find('SSO') or stderr:find('sso') then + log.info('S3: SSO session expired — running login...') + local login_cmd = base_cmd() + vim.list_extend(login_cmd, { 'sso', 'login' }) + local login_result = util.system(login_cmd, { text = true }) + if login_result.code == 0 then + log.info('S3: SSO login successful') + return true + end + log.error('S3: SSO login failed — ' .. (login_result.stderr or '')) + return false + end + if stderr:find('Unable to locate credentials') or stderr:find('NoCredentialProviders') then + log.error('S3: no AWS credentials configured. See :h pending-s3') + else + log.error('S3: credential check failed — ' .. stderr) + end + return false +end + local function create_bucket() local name = util.input({ prompt = 'S3 bucket name (pending.nvim): ' }) if not name then @@ -177,6 +206,9 @@ end function M.push() util.async(function() util.with_guard('S3', function() + if not ensure_credentials() then + return + end local s3cfg = get_config() if not s3cfg or not s3cfg.bucket then log.error('S3: bucket is required. Set sync.s3.bucket in config.') @@ -231,6 +263,9 @@ end function M.pull() util.async(function() util.with_guard('S3', function() + if not ensure_credentials() then + return + end local s3cfg = get_config() if not s3cfg or not s3cfg.bucket then log.error('S3: bucket is required. Set sync.s3.bucket in config.') @@ -330,6 +365,9 @@ end function M.sync() util.async(function() util.with_guard('S3', function() + if not ensure_credentials() then + return + end local s3cfg = get_config() if not s3cfg or not s3cfg.bucket then log.error('S3: bucket is required. Set sync.s3.bucket in config.') @@ -466,5 +504,6 @@ function M.health() end M._ensure_sync_id = ensure_sync_id +M._ensure_credentials = ensure_credentials return M diff --git a/spec/oauth_spec.lua b/spec/oauth_spec.lua index a4a6f1d..d004b90 100644 --- a/spec/oauth_spec.lua +++ b/spec/oauth_spec.lua @@ -232,4 +232,98 @@ describe('oauth', function() assert.equals('test', c.config_key) end) end) + + describe('with_token', function() + it('auto-triggers auth when not authenticated', function() + local c = oauth.new({ name = 'test_auth', scope = 'x', port = 0, config_key = 'gtasks' }) + local call_count = 0 + c.get_access_token = function() + call_count = call_count + 1 + if call_count == 1 then + return nil + end + return 'new-token' + end + c.resolve_credentials = function() + return { client_id = 'real-id', client_secret = 'real-secret' } + end + local auth_called = false + c.auth = function(_, on_complete) + auth_called = true + vim.schedule(function() + on_complete(true) + end) + end + local received_token + oauth.with_token(c, 'test_auth', function(token) + received_token = token + end) + vim.wait(1000, function() + return received_token ~= nil + end) + assert.is_true(auth_called) + assert.equals('new-token', received_token) + end) + + it('bails on bundled credentials without calling auth', function() + local c = oauth.new({ name = 'test_bail', scope = 'x', port = 0, config_key = 'gtasks' }) + c.get_access_token = function() + return nil + end + c.resolve_credentials = function() + return { client_id = oauth.BUNDLED_CLIENT_ID, client_secret = 'x' } + end + local auth_called = false + c.auth = function() + auth_called = true + end + local callback_called = false + oauth.with_token(c, 'test_bail', function() + callback_called = true + end) + vim.wait(500, function() + return false + end) + assert.is_false(auth_called) + assert.is_false(callback_called) + end) + + it('stops when auth fails', function() + local c = oauth.new({ name = 'test_fail', scope = 'x', port = 0, config_key = 'gtasks' }) + c.get_access_token = function() + return nil + end + c.resolve_credentials = function() + return { client_id = 'real-id', client_secret = 'real-secret' } + end + c.auth = function(_, on_complete) + vim.schedule(function() + on_complete(false) + end) + end + local callback_called = false + oauth.with_token(c, 'test_fail', function() + callback_called = true + end) + vim.wait(500, function() + return false + end) + assert.is_false(callback_called) + end) + + it('proceeds directly when already authenticated', function() + local c = oauth.new({ name = 'test_ok', scope = 'x', port = 0, config_key = 'gtasks' }) + c.get_access_token = function() + return 'existing-token' + end + local received_token + oauth.with_token(c, 'test_ok', function(token) + received_token = token + end) + vim.wait(1000, function() + return received_token ~= nil + end) + assert.equals('existing-token', received_token) + end) + end) end) diff --git a/spec/s3_spec.lua b/spec/s3_spec.lua index a9b1dbd..47bdc49 100644 --- a/spec/s3_spec.lua +++ b/spec/s3_spec.lua @@ -374,6 +374,64 @@ describe('s3', function() end) end) + describe('ensure_credentials', function() + it('returns true on valid credentials', function() + util.system = function(args) + if vim.tbl_contains(args, 'get-caller-identity') then + return { code = 0, stdout = '{"Account":"123"}', stderr = '' } + end + return { code = 0, stdout = '', stderr = '' } + end + assert.is_true(s3._ensure_credentials()) + end) + + it('returns false on missing credentials', function() + util.system = function() + return { code = 1, stdout = '', stderr = 'Unable to locate credentials' } + end + local msg + local orig_notify = vim.notify + vim.notify = function(m, level) + if level == vim.log.levels.ERROR then + msg = m + end + end + assert.is_false(s3._ensure_credentials()) + vim.notify = orig_notify + assert.truthy(msg and msg:find('no AWS credentials')) + end) + + it('retries SSO login on expired session', function() + local calls = {} + util.system = function(args) + if vim.tbl_contains(args, 'get-caller-identity') then + return { code = 1, stdout = '', stderr = 'Error: SSO session expired' } + end + if vim.tbl_contains(args, 'sso') then + table.insert(calls, 'sso-login') + return { code = 0, stdout = '', stderr = '' } + end + return { code = 0, stdout = '', stderr = '' } + end + assert.is_true(s3._ensure_credentials()) + assert.equals(1, #calls) + assert.equals('sso-login', calls[1]) + end) + + it('returns false when SSO login fails', function() + util.system = function(args) + if vim.tbl_contains(args, 'get-caller-identity') then + return { code = 1, stdout = '', stderr = 'SSO token expired' } + end + if vim.tbl_contains(args, 'sso') then + return { code = 1, stdout = '', stderr = 'login failed' } + end + return { code = 0, stdout = '', stderr = '' } + end + assert.is_false(s3._ensure_credentials()) + end) + end) + describe('push', function() it('uploads store to S3', function() local s = pending.store() @@ -383,6 +441,9 @@ describe('s3', function() local captured_args util.system = function(args) + if vim.tbl_contains(args, 'get-caller-identity') then + return { code = 0, stdout = '{"Account":"123"}', stderr = '' } + end if vim.tbl_contains(args, 's3') then captured_args = args return { code = 0, stdout = '', stderr = '' } @@ -405,6 +466,13 @@ describe('s3', function() pending = require('pending') s3 = require('pending.sync.s3') + util.system = function(args) + if vim.tbl_contains(args, 'get-caller-identity') then + return { code = 0, stdout = '{"Account":"123"}', stderr = '' } + end + return { code = 0, stdout = '', stderr = '' } + end + local msg local orig_notify = vim.notify vim.notify = function(m, level)