fix(sync): auto-trigger auth flow on unauthenticated sync actions (#120)
Problem: running a sync action (e.g. `:Pending gtasks push`) without being authenticated would silently abort with a warning, requiring the user to manually run `:Pending auth` first. Solution: `oauth.with_token()` now auto-triggers the browser auth flow when no token exists (for non-bundled credentials) and resumes the original action on success. `auth()` and `_exchange_code()` now call `on_complete(ok)` on all exit paths. S3 backends run `aws sts get-caller-identity` before every sync action, auto-triggering SSO login on expired sessions.
This commit is contained in:
parent
422f8f9b05
commit
149f2dac2e
5 changed files with 256 additions and 5 deletions
|
|
@ -1143,6 +1143,21 @@ Shared utilities for backend authors are provided by `sync/util.lua`:
|
||||||
|
|
||||||
Backend-specific configuration goes under `sync.<name>` in |pending-config|.
|
Backend-specific configuration goes under `sync.<name>` in |pending-config|.
|
||||||
|
|
||||||
|
Auto-auth: ~
|
||||||
|
*pending-sync-auto-auth*
|
||||||
|
Running a sync action (`:Pending <name> push/pull/sync`) without valid
|
||||||
|
credentials automatically triggers authentication before proceeding:
|
||||||
|
|
||||||
|
- OAuth backends (gcal, gtasks): if real credentials are configured but no
|
||||||
|
token exists, the browser-based auth flow starts automatically. On
|
||||||
|
success, the original action continues. Bundled placeholder credentials
|
||||||
|
cannot auto-auth and require the setup wizard via `:Pending auth`.
|
||||||
|
- S3: `aws sts get-caller-identity` runs before every sync action. If SSO
|
||||||
|
is expired, `aws sso login` is triggered automatically. Missing
|
||||||
|
credentials abort with an error pointing to |pending-s3|.
|
||||||
|
|
||||||
|
On auth failure, the sync action is aborted with an error message.
|
||||||
|
|
||||||
==============================================================================
|
==============================================================================
|
||||||
GOOGLE CALENDAR *pending-gcal*
|
GOOGLE CALENDAR *pending-gcal*
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -45,8 +45,28 @@ function M.with_token(client, name, callback)
|
||||||
util.with_guard(name, function()
|
util.with_guard(name, function()
|
||||||
local token = client:get_access_token()
|
local token = client:get_access_token()
|
||||||
if not token then
|
if not token then
|
||||||
require('pending.log').warn(name .. ': Not authenticated — run :Pending auth.')
|
local creds = client:resolve_credentials()
|
||||||
return
|
if creds.client_id == BUNDLED_CLIENT_ID then
|
||||||
|
log.warn(name .. ': No credentials configured — run :Pending auth.')
|
||||||
|
return
|
||||||
|
end
|
||||||
|
log.info(name .. ': Not authenticated — starting auth flow...')
|
||||||
|
local co = coroutine.running()
|
||||||
|
client:auth(function(ok)
|
||||||
|
vim.schedule(function()
|
||||||
|
coroutine.resume(co, ok)
|
||||||
|
end)
|
||||||
|
end)
|
||||||
|
local auth_ok = coroutine.yield()
|
||||||
|
if not auth_ok then
|
||||||
|
log.error(name .. ': Authentication failed.')
|
||||||
|
return
|
||||||
|
end
|
||||||
|
token = client:get_access_token()
|
||||||
|
if not token then
|
||||||
|
log.error(name .. ': Still not authenticated after auth flow.')
|
||||||
|
return
|
||||||
|
end
|
||||||
end
|
end
|
||||||
callback(token)
|
callback(token)
|
||||||
end)
|
end)
|
||||||
|
|
@ -349,7 +369,7 @@ function OAuthClient:setup()
|
||||||
end)
|
end)
|
||||||
end
|
end
|
||||||
|
|
||||||
---@param on_complete? fun(): nil
|
---@param on_complete? fun(ok: boolean): nil
|
||||||
---@return nil
|
---@return nil
|
||||||
function OAuthClient:auth(on_complete)
|
function OAuthClient:auth(on_complete)
|
||||||
if _active_close then
|
if _active_close then
|
||||||
|
|
@ -360,6 +380,9 @@ function OAuthClient:auth(on_complete)
|
||||||
local creds = self:resolve_credentials()
|
local creds = self:resolve_credentials()
|
||||||
if creds.client_id == BUNDLED_CLIENT_ID then
|
if creds.client_id == BUNDLED_CLIENT_ID then
|
||||||
log.error(self.name .. ': No credentials configured — run :Pending auth.')
|
log.error(self.name .. ': No credentials configured — run :Pending auth.')
|
||||||
|
if on_complete then
|
||||||
|
on_complete(false)
|
||||||
|
end
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
local port = self.port
|
local port = self.port
|
||||||
|
|
@ -411,6 +434,9 @@ function OAuthClient:auth(on_complete)
|
||||||
if not bind_ok or bind_err == nil then
|
if not bind_ok or bind_err == nil then
|
||||||
close_server()
|
close_server()
|
||||||
log.error(self.name .. ': Port ' .. port .. ' already in use — try again in a moment.')
|
log.error(self.name .. ': Port ' .. port .. ' already in use — try again in a moment.')
|
||||||
|
if on_complete then
|
||||||
|
on_complete(false)
|
||||||
|
end
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
@ -453,6 +479,9 @@ function OAuthClient:auth(on_complete)
|
||||||
if not server_closed then
|
if not server_closed then
|
||||||
close_server()
|
close_server()
|
||||||
log.warn(self.name .. ': OAuth callback timed out (120s).')
|
log.warn(self.name .. ': OAuth callback timed out (120s).')
|
||||||
|
if on_complete then
|
||||||
|
on_complete(false)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end, 120000)
|
end, 120000)
|
||||||
end
|
end
|
||||||
|
|
@ -461,7 +490,7 @@ end
|
||||||
---@param code string
|
---@param code string
|
||||||
---@param code_verifier string
|
---@param code_verifier string
|
||||||
---@param port integer
|
---@param port integer
|
||||||
---@param on_complete? fun(): nil
|
---@param on_complete? fun(ok: boolean): nil
|
||||||
---@return nil
|
---@return nil
|
||||||
function OAuthClient:_exchange_code(creds, code, code_verifier, port, on_complete)
|
function OAuthClient:_exchange_code(creds, code, code_verifier, port, on_complete)
|
||||||
local body = 'client_id='
|
local body = 'client_id='
|
||||||
|
|
@ -491,6 +520,9 @@ function OAuthClient:_exchange_code(creds, code, code_verifier, port, on_complet
|
||||||
if result.code ~= 0 then
|
if result.code ~= 0 then
|
||||||
self:clear_tokens()
|
self:clear_tokens()
|
||||||
log.error(self.name .. ': Token exchange failed.')
|
log.error(self.name .. ': Token exchange failed.')
|
||||||
|
if on_complete then
|
||||||
|
on_complete(false)
|
||||||
|
end
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
@ -498,6 +530,9 @@ function OAuthClient:_exchange_code(creds, code, code_verifier, port, on_complet
|
||||||
if not ok or not decoded.access_token then
|
if not ok or not decoded.access_token then
|
||||||
self:clear_tokens()
|
self:clear_tokens()
|
||||||
log.error(self.name .. ': Invalid token response.')
|
log.error(self.name .. ': Invalid token response.')
|
||||||
|
if on_complete then
|
||||||
|
on_complete(false)
|
||||||
|
end
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
@ -505,7 +540,7 @@ function OAuthClient:_exchange_code(creds, code, code_verifier, port, on_complet
|
||||||
self:save_tokens(decoded)
|
self:save_tokens(decoded)
|
||||||
log.info(self.name .. ': Authorized successfully.')
|
log.info(self.name .. ': Authorized successfully.')
|
||||||
if on_complete then
|
if on_complete then
|
||||||
on_complete()
|
on_complete(true)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -66,6 +66,35 @@ local function ensure_sync_id(task)
|
||||||
return sync_id
|
return sync_id
|
||||||
end
|
end
|
||||||
|
|
||||||
|
---@return boolean
|
||||||
|
local function ensure_credentials()
|
||||||
|
local cmd = base_cmd()
|
||||||
|
vim.list_extend(cmd, { 'sts', 'get-caller-identity', '--output', 'json' })
|
||||||
|
local result = util.system(cmd, { text = true })
|
||||||
|
if result.code == 0 then
|
||||||
|
return true
|
||||||
|
end
|
||||||
|
local stderr = result.stderr or ''
|
||||||
|
if stderr:find('SSO') or stderr:find('sso') then
|
||||||
|
log.info('S3: SSO session expired — running login...')
|
||||||
|
local login_cmd = base_cmd()
|
||||||
|
vim.list_extend(login_cmd, { 'sso', 'login' })
|
||||||
|
local login_result = util.system(login_cmd, { text = true })
|
||||||
|
if login_result.code == 0 then
|
||||||
|
log.info('S3: SSO login successful')
|
||||||
|
return true
|
||||||
|
end
|
||||||
|
log.error('S3: SSO login failed — ' .. (login_result.stderr or ''))
|
||||||
|
return false
|
||||||
|
end
|
||||||
|
if stderr:find('Unable to locate credentials') or stderr:find('NoCredentialProviders') then
|
||||||
|
log.error('S3: no AWS credentials configured. See :h pending-s3')
|
||||||
|
else
|
||||||
|
log.error('S3: credential check failed — ' .. stderr)
|
||||||
|
end
|
||||||
|
return false
|
||||||
|
end
|
||||||
|
|
||||||
local function create_bucket()
|
local function create_bucket()
|
||||||
local name = util.input({ prompt = 'S3 bucket name (pending.nvim): ' })
|
local name = util.input({ prompt = 'S3 bucket name (pending.nvim): ' })
|
||||||
if not name then
|
if not name then
|
||||||
|
|
@ -177,6 +206,9 @@ end
|
||||||
function M.push()
|
function M.push()
|
||||||
util.async(function()
|
util.async(function()
|
||||||
util.with_guard('S3', function()
|
util.with_guard('S3', function()
|
||||||
|
if not ensure_credentials() then
|
||||||
|
return
|
||||||
|
end
|
||||||
local s3cfg = get_config()
|
local s3cfg = get_config()
|
||||||
if not s3cfg or not s3cfg.bucket then
|
if not s3cfg or not s3cfg.bucket then
|
||||||
log.error('S3: bucket is required. Set sync.s3.bucket in config.')
|
log.error('S3: bucket is required. Set sync.s3.bucket in config.')
|
||||||
|
|
@ -231,6 +263,9 @@ end
|
||||||
function M.pull()
|
function M.pull()
|
||||||
util.async(function()
|
util.async(function()
|
||||||
util.with_guard('S3', function()
|
util.with_guard('S3', function()
|
||||||
|
if not ensure_credentials() then
|
||||||
|
return
|
||||||
|
end
|
||||||
local s3cfg = get_config()
|
local s3cfg = get_config()
|
||||||
if not s3cfg or not s3cfg.bucket then
|
if not s3cfg or not s3cfg.bucket then
|
||||||
log.error('S3: bucket is required. Set sync.s3.bucket in config.')
|
log.error('S3: bucket is required. Set sync.s3.bucket in config.')
|
||||||
|
|
@ -330,6 +365,9 @@ end
|
||||||
function M.sync()
|
function M.sync()
|
||||||
util.async(function()
|
util.async(function()
|
||||||
util.with_guard('S3', function()
|
util.with_guard('S3', function()
|
||||||
|
if not ensure_credentials() then
|
||||||
|
return
|
||||||
|
end
|
||||||
local s3cfg = get_config()
|
local s3cfg = get_config()
|
||||||
if not s3cfg or not s3cfg.bucket then
|
if not s3cfg or not s3cfg.bucket then
|
||||||
log.error('S3: bucket is required. Set sync.s3.bucket in config.')
|
log.error('S3: bucket is required. Set sync.s3.bucket in config.')
|
||||||
|
|
@ -466,5 +504,6 @@ function M.health()
|
||||||
end
|
end
|
||||||
|
|
||||||
M._ensure_sync_id = ensure_sync_id
|
M._ensure_sync_id = ensure_sync_id
|
||||||
|
M._ensure_credentials = ensure_credentials
|
||||||
|
|
||||||
return M
|
return M
|
||||||
|
|
|
||||||
|
|
@ -232,4 +232,98 @@ describe('oauth', function()
|
||||||
assert.equals('test', c.config_key)
|
assert.equals('test', c.config_key)
|
||||||
end)
|
end)
|
||||||
end)
|
end)
|
||||||
|
|
||||||
|
describe('with_token', function()
|
||||||
|
it('auto-triggers auth when not authenticated', function()
|
||||||
|
local c = oauth.new({ name = 'test_auth', scope = 'x', port = 0, config_key = 'gtasks' })
|
||||||
|
local call_count = 0
|
||||||
|
c.get_access_token = function()
|
||||||
|
call_count = call_count + 1
|
||||||
|
if call_count == 1 then
|
||||||
|
return nil
|
||||||
|
end
|
||||||
|
return 'new-token'
|
||||||
|
end
|
||||||
|
c.resolve_credentials = function()
|
||||||
|
return { client_id = 'real-id', client_secret = 'real-secret' }
|
||||||
|
end
|
||||||
|
local auth_called = false
|
||||||
|
c.auth = function(_, on_complete)
|
||||||
|
auth_called = true
|
||||||
|
vim.schedule(function()
|
||||||
|
on_complete(true)
|
||||||
|
end)
|
||||||
|
end
|
||||||
|
local received_token
|
||||||
|
oauth.with_token(c, 'test_auth', function(token)
|
||||||
|
received_token = token
|
||||||
|
end)
|
||||||
|
vim.wait(1000, function()
|
||||||
|
return received_token ~= nil
|
||||||
|
end)
|
||||||
|
assert.is_true(auth_called)
|
||||||
|
assert.equals('new-token', received_token)
|
||||||
|
end)
|
||||||
|
|
||||||
|
it('bails on bundled credentials without calling auth', function()
|
||||||
|
local c = oauth.new({ name = 'test_bail', scope = 'x', port = 0, config_key = 'gtasks' })
|
||||||
|
c.get_access_token = function()
|
||||||
|
return nil
|
||||||
|
end
|
||||||
|
c.resolve_credentials = function()
|
||||||
|
return { client_id = oauth.BUNDLED_CLIENT_ID, client_secret = 'x' }
|
||||||
|
end
|
||||||
|
local auth_called = false
|
||||||
|
c.auth = function()
|
||||||
|
auth_called = true
|
||||||
|
end
|
||||||
|
local callback_called = false
|
||||||
|
oauth.with_token(c, 'test_bail', function()
|
||||||
|
callback_called = true
|
||||||
|
end)
|
||||||
|
vim.wait(500, function()
|
||||||
|
return false
|
||||||
|
end)
|
||||||
|
assert.is_false(auth_called)
|
||||||
|
assert.is_false(callback_called)
|
||||||
|
end)
|
||||||
|
|
||||||
|
it('stops when auth fails', function()
|
||||||
|
local c = oauth.new({ name = 'test_fail', scope = 'x', port = 0, config_key = 'gtasks' })
|
||||||
|
c.get_access_token = function()
|
||||||
|
return nil
|
||||||
|
end
|
||||||
|
c.resolve_credentials = function()
|
||||||
|
return { client_id = 'real-id', client_secret = 'real-secret' }
|
||||||
|
end
|
||||||
|
c.auth = function(_, on_complete)
|
||||||
|
vim.schedule(function()
|
||||||
|
on_complete(false)
|
||||||
|
end)
|
||||||
|
end
|
||||||
|
local callback_called = false
|
||||||
|
oauth.with_token(c, 'test_fail', function()
|
||||||
|
callback_called = true
|
||||||
|
end)
|
||||||
|
vim.wait(500, function()
|
||||||
|
return false
|
||||||
|
end)
|
||||||
|
assert.is_false(callback_called)
|
||||||
|
end)
|
||||||
|
|
||||||
|
it('proceeds directly when already authenticated', function()
|
||||||
|
local c = oauth.new({ name = 'test_ok', scope = 'x', port = 0, config_key = 'gtasks' })
|
||||||
|
c.get_access_token = function()
|
||||||
|
return 'existing-token'
|
||||||
|
end
|
||||||
|
local received_token
|
||||||
|
oauth.with_token(c, 'test_ok', function(token)
|
||||||
|
received_token = token
|
||||||
|
end)
|
||||||
|
vim.wait(1000, function()
|
||||||
|
return received_token ~= nil
|
||||||
|
end)
|
||||||
|
assert.equals('existing-token', received_token)
|
||||||
|
end)
|
||||||
|
end)
|
||||||
end)
|
end)
|
||||||
|
|
|
||||||
|
|
@ -374,6 +374,64 @@ describe('s3', function()
|
||||||
end)
|
end)
|
||||||
end)
|
end)
|
||||||
|
|
||||||
|
describe('ensure_credentials', function()
|
||||||
|
it('returns true on valid credentials', function()
|
||||||
|
util.system = function(args)
|
||||||
|
if vim.tbl_contains(args, 'get-caller-identity') then
|
||||||
|
return { code = 0, stdout = '{"Account":"123"}', stderr = '' }
|
||||||
|
end
|
||||||
|
return { code = 0, stdout = '', stderr = '' }
|
||||||
|
end
|
||||||
|
assert.is_true(s3._ensure_credentials())
|
||||||
|
end)
|
||||||
|
|
||||||
|
it('returns false on missing credentials', function()
|
||||||
|
util.system = function()
|
||||||
|
return { code = 1, stdout = '', stderr = 'Unable to locate credentials' }
|
||||||
|
end
|
||||||
|
local msg
|
||||||
|
local orig_notify = vim.notify
|
||||||
|
vim.notify = function(m, level)
|
||||||
|
if level == vim.log.levels.ERROR then
|
||||||
|
msg = m
|
||||||
|
end
|
||||||
|
end
|
||||||
|
assert.is_false(s3._ensure_credentials())
|
||||||
|
vim.notify = orig_notify
|
||||||
|
assert.truthy(msg and msg:find('no AWS credentials'))
|
||||||
|
end)
|
||||||
|
|
||||||
|
it('retries SSO login on expired session', function()
|
||||||
|
local calls = {}
|
||||||
|
util.system = function(args)
|
||||||
|
if vim.tbl_contains(args, 'get-caller-identity') then
|
||||||
|
return { code = 1, stdout = '', stderr = 'Error: SSO session expired' }
|
||||||
|
end
|
||||||
|
if vim.tbl_contains(args, 'sso') then
|
||||||
|
table.insert(calls, 'sso-login')
|
||||||
|
return { code = 0, stdout = '', stderr = '' }
|
||||||
|
end
|
||||||
|
return { code = 0, stdout = '', stderr = '' }
|
||||||
|
end
|
||||||
|
assert.is_true(s3._ensure_credentials())
|
||||||
|
assert.equals(1, #calls)
|
||||||
|
assert.equals('sso-login', calls[1])
|
||||||
|
end)
|
||||||
|
|
||||||
|
it('returns false when SSO login fails', function()
|
||||||
|
util.system = function(args)
|
||||||
|
if vim.tbl_contains(args, 'get-caller-identity') then
|
||||||
|
return { code = 1, stdout = '', stderr = 'SSO token expired' }
|
||||||
|
end
|
||||||
|
if vim.tbl_contains(args, 'sso') then
|
||||||
|
return { code = 1, stdout = '', stderr = 'login failed' }
|
||||||
|
end
|
||||||
|
return { code = 0, stdout = '', stderr = '' }
|
||||||
|
end
|
||||||
|
assert.is_false(s3._ensure_credentials())
|
||||||
|
end)
|
||||||
|
end)
|
||||||
|
|
||||||
describe('push', function()
|
describe('push', function()
|
||||||
it('uploads store to S3', function()
|
it('uploads store to S3', function()
|
||||||
local s = pending.store()
|
local s = pending.store()
|
||||||
|
|
@ -383,6 +441,9 @@ describe('s3', function()
|
||||||
|
|
||||||
local captured_args
|
local captured_args
|
||||||
util.system = function(args)
|
util.system = function(args)
|
||||||
|
if vim.tbl_contains(args, 'get-caller-identity') then
|
||||||
|
return { code = 0, stdout = '{"Account":"123"}', stderr = '' }
|
||||||
|
end
|
||||||
if vim.tbl_contains(args, 's3') then
|
if vim.tbl_contains(args, 's3') then
|
||||||
captured_args = args
|
captured_args = args
|
||||||
return { code = 0, stdout = '', stderr = '' }
|
return { code = 0, stdout = '', stderr = '' }
|
||||||
|
|
@ -405,6 +466,13 @@ describe('s3', function()
|
||||||
pending = require('pending')
|
pending = require('pending')
|
||||||
s3 = require('pending.sync.s3')
|
s3 = require('pending.sync.s3')
|
||||||
|
|
||||||
|
util.system = function(args)
|
||||||
|
if vim.tbl_contains(args, 'get-caller-identity') then
|
||||||
|
return { code = 0, stdout = '{"Account":"123"}', stderr = '' }
|
||||||
|
end
|
||||||
|
return { code = 0, stdout = '', stderr = '' }
|
||||||
|
end
|
||||||
|
|
||||||
local msg
|
local msg
|
||||||
local orig_notify = vim.notify
|
local orig_notify = vim.notify
|
||||||
vim.notify = function(m, level)
|
vim.notify = function(m, level)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue