fix(sync): auto-trigger auth flow on unauthenticated sync actions (#120)
Problem: running a sync action (e.g. `:Pending gtasks push`) without being authenticated would silently abort with a warning, requiring the user to manually run `:Pending auth` first. Solution: `oauth.with_token()` now auto-triggers the browser auth flow when no token exists (for non-bundled credentials) and resumes the original action on success. `auth()` and `_exchange_code()` now call `on_complete(ok)` on all exit paths. S3 backends run `aws sts get-caller-identity` before every sync action, auto-triggering SSO login on expired sessions.
This commit is contained in:
parent
422f8f9b05
commit
149f2dac2e
5 changed files with 256 additions and 5 deletions
|
|
@ -1143,6 +1143,21 @@ Shared utilities for backend authors are provided by `sync/util.lua`:
|
|||
|
||||
Backend-specific configuration goes under `sync.<name>` in |pending-config|.
|
||||
|
||||
Auto-auth: ~
|
||||
*pending-sync-auto-auth*
|
||||
Running a sync action (`:Pending <name> push/pull/sync`) without valid
|
||||
credentials automatically triggers authentication before proceeding:
|
||||
|
||||
- OAuth backends (gcal, gtasks): if real credentials are configured but no
|
||||
token exists, the browser-based auth flow starts automatically. On
|
||||
success, the original action continues. Bundled placeholder credentials
|
||||
cannot auto-auth and require the setup wizard via `:Pending auth`.
|
||||
- S3: `aws sts get-caller-identity` runs before every sync action. If SSO
|
||||
is expired, `aws sso login` is triggered automatically. Missing
|
||||
credentials abort with an error pointing to |pending-s3|.
|
||||
|
||||
On auth failure, the sync action is aborted with an error message.
|
||||
|
||||
==============================================================================
|
||||
GOOGLE CALENDAR *pending-gcal*
|
||||
|
||||
|
|
|
|||
|
|
@ -45,8 +45,28 @@ function M.with_token(client, name, callback)
|
|||
util.with_guard(name, function()
|
||||
local token = client:get_access_token()
|
||||
if not token then
|
||||
require('pending.log').warn(name .. ': Not authenticated — run :Pending auth.')
|
||||
return
|
||||
local creds = client:resolve_credentials()
|
||||
if creds.client_id == BUNDLED_CLIENT_ID then
|
||||
log.warn(name .. ': No credentials configured — run :Pending auth.')
|
||||
return
|
||||
end
|
||||
log.info(name .. ': Not authenticated — starting auth flow...')
|
||||
local co = coroutine.running()
|
||||
client:auth(function(ok)
|
||||
vim.schedule(function()
|
||||
coroutine.resume(co, ok)
|
||||
end)
|
||||
end)
|
||||
local auth_ok = coroutine.yield()
|
||||
if not auth_ok then
|
||||
log.error(name .. ': Authentication failed.')
|
||||
return
|
||||
end
|
||||
token = client:get_access_token()
|
||||
if not token then
|
||||
log.error(name .. ': Still not authenticated after auth flow.')
|
||||
return
|
||||
end
|
||||
end
|
||||
callback(token)
|
||||
end)
|
||||
|
|
@ -349,7 +369,7 @@ function OAuthClient:setup()
|
|||
end)
|
||||
end
|
||||
|
||||
---@param on_complete? fun(): nil
|
||||
---@param on_complete? fun(ok: boolean): nil
|
||||
---@return nil
|
||||
function OAuthClient:auth(on_complete)
|
||||
if _active_close then
|
||||
|
|
@ -360,6 +380,9 @@ function OAuthClient:auth(on_complete)
|
|||
local creds = self:resolve_credentials()
|
||||
if creds.client_id == BUNDLED_CLIENT_ID then
|
||||
log.error(self.name .. ': No credentials configured — run :Pending auth.')
|
||||
if on_complete then
|
||||
on_complete(false)
|
||||
end
|
||||
return
|
||||
end
|
||||
local port = self.port
|
||||
|
|
@ -411,6 +434,9 @@ function OAuthClient:auth(on_complete)
|
|||
if not bind_ok or bind_err == nil then
|
||||
close_server()
|
||||
log.error(self.name .. ': Port ' .. port .. ' already in use — try again in a moment.')
|
||||
if on_complete then
|
||||
on_complete(false)
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
|
|
@ -453,6 +479,9 @@ function OAuthClient:auth(on_complete)
|
|||
if not server_closed then
|
||||
close_server()
|
||||
log.warn(self.name .. ': OAuth callback timed out (120s).')
|
||||
if on_complete then
|
||||
on_complete(false)
|
||||
end
|
||||
end
|
||||
end, 120000)
|
||||
end
|
||||
|
|
@ -461,7 +490,7 @@ end
|
|||
---@param code string
|
||||
---@param code_verifier string
|
||||
---@param port integer
|
||||
---@param on_complete? fun(): nil
|
||||
---@param on_complete? fun(ok: boolean): nil
|
||||
---@return nil
|
||||
function OAuthClient:_exchange_code(creds, code, code_verifier, port, on_complete)
|
||||
local body = 'client_id='
|
||||
|
|
@ -491,6 +520,9 @@ function OAuthClient:_exchange_code(creds, code, code_verifier, port, on_complet
|
|||
if result.code ~= 0 then
|
||||
self:clear_tokens()
|
||||
log.error(self.name .. ': Token exchange failed.')
|
||||
if on_complete then
|
||||
on_complete(false)
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
|
|
@ -498,6 +530,9 @@ function OAuthClient:_exchange_code(creds, code, code_verifier, port, on_complet
|
|||
if not ok or not decoded.access_token then
|
||||
self:clear_tokens()
|
||||
log.error(self.name .. ': Invalid token response.')
|
||||
if on_complete then
|
||||
on_complete(false)
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
|
|
@ -505,7 +540,7 @@ function OAuthClient:_exchange_code(creds, code, code_verifier, port, on_complet
|
|||
self:save_tokens(decoded)
|
||||
log.info(self.name .. ': Authorized successfully.')
|
||||
if on_complete then
|
||||
on_complete()
|
||||
on_complete(true)
|
||||
end
|
||||
end
|
||||
|
||||
|
|
|
|||
|
|
@ -66,6 +66,35 @@ local function ensure_sync_id(task)
|
|||
return sync_id
|
||||
end
|
||||
|
||||
---@return boolean
|
||||
local function ensure_credentials()
|
||||
local cmd = base_cmd()
|
||||
vim.list_extend(cmd, { 'sts', 'get-caller-identity', '--output', 'json' })
|
||||
local result = util.system(cmd, { text = true })
|
||||
if result.code == 0 then
|
||||
return true
|
||||
end
|
||||
local stderr = result.stderr or ''
|
||||
if stderr:find('SSO') or stderr:find('sso') then
|
||||
log.info('S3: SSO session expired — running login...')
|
||||
local login_cmd = base_cmd()
|
||||
vim.list_extend(login_cmd, { 'sso', 'login' })
|
||||
local login_result = util.system(login_cmd, { text = true })
|
||||
if login_result.code == 0 then
|
||||
log.info('S3: SSO login successful')
|
||||
return true
|
||||
end
|
||||
log.error('S3: SSO login failed — ' .. (login_result.stderr or ''))
|
||||
return false
|
||||
end
|
||||
if stderr:find('Unable to locate credentials') or stderr:find('NoCredentialProviders') then
|
||||
log.error('S3: no AWS credentials configured. See :h pending-s3')
|
||||
else
|
||||
log.error('S3: credential check failed — ' .. stderr)
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
local function create_bucket()
|
||||
local name = util.input({ prompt = 'S3 bucket name (pending.nvim): ' })
|
||||
if not name then
|
||||
|
|
@ -177,6 +206,9 @@ end
|
|||
function M.push()
|
||||
util.async(function()
|
||||
util.with_guard('S3', function()
|
||||
if not ensure_credentials() then
|
||||
return
|
||||
end
|
||||
local s3cfg = get_config()
|
||||
if not s3cfg or not s3cfg.bucket then
|
||||
log.error('S3: bucket is required. Set sync.s3.bucket in config.')
|
||||
|
|
@ -231,6 +263,9 @@ end
|
|||
function M.pull()
|
||||
util.async(function()
|
||||
util.with_guard('S3', function()
|
||||
if not ensure_credentials() then
|
||||
return
|
||||
end
|
||||
local s3cfg = get_config()
|
||||
if not s3cfg or not s3cfg.bucket then
|
||||
log.error('S3: bucket is required. Set sync.s3.bucket in config.')
|
||||
|
|
@ -330,6 +365,9 @@ end
|
|||
function M.sync()
|
||||
util.async(function()
|
||||
util.with_guard('S3', function()
|
||||
if not ensure_credentials() then
|
||||
return
|
||||
end
|
||||
local s3cfg = get_config()
|
||||
if not s3cfg or not s3cfg.bucket then
|
||||
log.error('S3: bucket is required. Set sync.s3.bucket in config.')
|
||||
|
|
@ -466,5 +504,6 @@ function M.health()
|
|||
end
|
||||
|
||||
M._ensure_sync_id = ensure_sync_id
|
||||
M._ensure_credentials = ensure_credentials
|
||||
|
||||
return M
|
||||
|
|
|
|||
|
|
@ -232,4 +232,98 @@ describe('oauth', function()
|
|||
assert.equals('test', c.config_key)
|
||||
end)
|
||||
end)
|
||||
|
||||
describe('with_token', function()
|
||||
it('auto-triggers auth when not authenticated', function()
|
||||
local c = oauth.new({ name = 'test_auth', scope = 'x', port = 0, config_key = 'gtasks' })
|
||||
local call_count = 0
|
||||
c.get_access_token = function()
|
||||
call_count = call_count + 1
|
||||
if call_count == 1 then
|
||||
return nil
|
||||
end
|
||||
return 'new-token'
|
||||
end
|
||||
c.resolve_credentials = function()
|
||||
return { client_id = 'real-id', client_secret = 'real-secret' }
|
||||
end
|
||||
local auth_called = false
|
||||
c.auth = function(_, on_complete)
|
||||
auth_called = true
|
||||
vim.schedule(function()
|
||||
on_complete(true)
|
||||
end)
|
||||
end
|
||||
local received_token
|
||||
oauth.with_token(c, 'test_auth', function(token)
|
||||
received_token = token
|
||||
end)
|
||||
vim.wait(1000, function()
|
||||
return received_token ~= nil
|
||||
end)
|
||||
assert.is_true(auth_called)
|
||||
assert.equals('new-token', received_token)
|
||||
end)
|
||||
|
||||
it('bails on bundled credentials without calling auth', function()
|
||||
local c = oauth.new({ name = 'test_bail', scope = 'x', port = 0, config_key = 'gtasks' })
|
||||
c.get_access_token = function()
|
||||
return nil
|
||||
end
|
||||
c.resolve_credentials = function()
|
||||
return { client_id = oauth.BUNDLED_CLIENT_ID, client_secret = 'x' }
|
||||
end
|
||||
local auth_called = false
|
||||
c.auth = function()
|
||||
auth_called = true
|
||||
end
|
||||
local callback_called = false
|
||||
oauth.with_token(c, 'test_bail', function()
|
||||
callback_called = true
|
||||
end)
|
||||
vim.wait(500, function()
|
||||
return false
|
||||
end)
|
||||
assert.is_false(auth_called)
|
||||
assert.is_false(callback_called)
|
||||
end)
|
||||
|
||||
it('stops when auth fails', function()
|
||||
local c = oauth.new({ name = 'test_fail', scope = 'x', port = 0, config_key = 'gtasks' })
|
||||
c.get_access_token = function()
|
||||
return nil
|
||||
end
|
||||
c.resolve_credentials = function()
|
||||
return { client_id = 'real-id', client_secret = 'real-secret' }
|
||||
end
|
||||
c.auth = function(_, on_complete)
|
||||
vim.schedule(function()
|
||||
on_complete(false)
|
||||
end)
|
||||
end
|
||||
local callback_called = false
|
||||
oauth.with_token(c, 'test_fail', function()
|
||||
callback_called = true
|
||||
end)
|
||||
vim.wait(500, function()
|
||||
return false
|
||||
end)
|
||||
assert.is_false(callback_called)
|
||||
end)
|
||||
|
||||
it('proceeds directly when already authenticated', function()
|
||||
local c = oauth.new({ name = 'test_ok', scope = 'x', port = 0, config_key = 'gtasks' })
|
||||
c.get_access_token = function()
|
||||
return 'existing-token'
|
||||
end
|
||||
local received_token
|
||||
oauth.with_token(c, 'test_ok', function(token)
|
||||
received_token = token
|
||||
end)
|
||||
vim.wait(1000, function()
|
||||
return received_token ~= nil
|
||||
end)
|
||||
assert.equals('existing-token', received_token)
|
||||
end)
|
||||
end)
|
||||
end)
|
||||
|
|
|
|||
|
|
@ -374,6 +374,64 @@ describe('s3', function()
|
|||
end)
|
||||
end)
|
||||
|
||||
describe('ensure_credentials', function()
|
||||
it('returns true on valid credentials', function()
|
||||
util.system = function(args)
|
||||
if vim.tbl_contains(args, 'get-caller-identity') then
|
||||
return { code = 0, stdout = '{"Account":"123"}', stderr = '' }
|
||||
end
|
||||
return { code = 0, stdout = '', stderr = '' }
|
||||
end
|
||||
assert.is_true(s3._ensure_credentials())
|
||||
end)
|
||||
|
||||
it('returns false on missing credentials', function()
|
||||
util.system = function()
|
||||
return { code = 1, stdout = '', stderr = 'Unable to locate credentials' }
|
||||
end
|
||||
local msg
|
||||
local orig_notify = vim.notify
|
||||
vim.notify = function(m, level)
|
||||
if level == vim.log.levels.ERROR then
|
||||
msg = m
|
||||
end
|
||||
end
|
||||
assert.is_false(s3._ensure_credentials())
|
||||
vim.notify = orig_notify
|
||||
assert.truthy(msg and msg:find('no AWS credentials'))
|
||||
end)
|
||||
|
||||
it('retries SSO login on expired session', function()
|
||||
local calls = {}
|
||||
util.system = function(args)
|
||||
if vim.tbl_contains(args, 'get-caller-identity') then
|
||||
return { code = 1, stdout = '', stderr = 'Error: SSO session expired' }
|
||||
end
|
||||
if vim.tbl_contains(args, 'sso') then
|
||||
table.insert(calls, 'sso-login')
|
||||
return { code = 0, stdout = '', stderr = '' }
|
||||
end
|
||||
return { code = 0, stdout = '', stderr = '' }
|
||||
end
|
||||
assert.is_true(s3._ensure_credentials())
|
||||
assert.equals(1, #calls)
|
||||
assert.equals('sso-login', calls[1])
|
||||
end)
|
||||
|
||||
it('returns false when SSO login fails', function()
|
||||
util.system = function(args)
|
||||
if vim.tbl_contains(args, 'get-caller-identity') then
|
||||
return { code = 1, stdout = '', stderr = 'SSO token expired' }
|
||||
end
|
||||
if vim.tbl_contains(args, 'sso') then
|
||||
return { code = 1, stdout = '', stderr = 'login failed' }
|
||||
end
|
||||
return { code = 0, stdout = '', stderr = '' }
|
||||
end
|
||||
assert.is_false(s3._ensure_credentials())
|
||||
end)
|
||||
end)
|
||||
|
||||
describe('push', function()
|
||||
it('uploads store to S3', function()
|
||||
local s = pending.store()
|
||||
|
|
@ -383,6 +441,9 @@ describe('s3', function()
|
|||
|
||||
local captured_args
|
||||
util.system = function(args)
|
||||
if vim.tbl_contains(args, 'get-caller-identity') then
|
||||
return { code = 0, stdout = '{"Account":"123"}', stderr = '' }
|
||||
end
|
||||
if vim.tbl_contains(args, 's3') then
|
||||
captured_args = args
|
||||
return { code = 0, stdout = '', stderr = '' }
|
||||
|
|
@ -405,6 +466,13 @@ describe('s3', function()
|
|||
pending = require('pending')
|
||||
s3 = require('pending.sync.s3')
|
||||
|
||||
util.system = function(args)
|
||||
if vim.tbl_contains(args, 'get-caller-identity') then
|
||||
return { code = 0, stdout = '{"Account":"123"}', stderr = '' }
|
||||
end
|
||||
return { code = 0, stdout = '', stderr = '' }
|
||||
end
|
||||
|
||||
local msg
|
||||
local orig_notify = vim.notify
|
||||
vim.notify = function(m, level)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue