From f9da506017b64c4436329d2a570c30fc22ab8120 Mon Sep 17 00:00:00 2001 From: Stringadmin Date: Tue, 9 Jun 2026 11:32:53 +0800 Subject: [PATCH] fix: harden provider polling recovery --- .env.example | 6 + package.json | 5 +- scripts/keyManagerReleaseContract.test.js | 73 ++++++++ scripts/providerPollLimiterContract.test.js | 96 ++++++++++ src/aiTaskWorker.js | 197 ++++++++++++++++---- src/index.js | 75 +++++--- src/keyManager.js | 12 +- src/providerPollLimiter.js | 120 ++++++++++++ src/routes/ai.js | 24 ++- 9 files changed, 539 insertions(+), 69 deletions(-) create mode 100644 scripts/keyManagerReleaseContract.test.js create mode 100644 scripts/providerPollLimiterContract.test.js create mode 100644 src/providerPollLimiter.js diff --git a/.env.example b/.env.example index 19642c2..03cf184 100644 --- a/.env.example +++ b/.env.example @@ -20,6 +20,12 @@ JWT_EXPIRES_IN=7d # Connection pool PG_POOL_MAX=10 +# Provider polling reliability +# Shared across PM2 workers through Postgres-backed poll slots. +TASK_PROVIDER_POLL_MAX_CONCURRENCY=8 +TASK_PROVIDER_POLL_SLOT_TTL_MS=30000 +TASK_PROVIDER_POLL_REQUEST_TIMEOUT_MS=25000 + # CORS (comma separated allowed origins, * for all) CORS_ORIGINS=* diff --git a/package.json b/package.json index 4d6c1b5..b0ffbcd 100644 --- a/package.json +++ b/package.json @@ -14,7 +14,10 @@ "audit-routes": "node src/cli/auditModelRoutes.js", "import-config": "node src/cli/importConfig.js", "init-pools": "node src/cli/initPools.js", - "test:community-routes": "node scripts/communityRouteContract.test.js" + "test:community-routes": "node scripts/communityRouteContract.test.js", + "test:key-manager": "node scripts/keyManagerReleaseContract.test.js", + "test:provider-poll-limiter": "node scripts/providerPollLimiterContract.test.js", + "test": "npm run test:community-routes && npm run test:key-manager && npm run test:provider-poll-limiter" }, "dependencies": { "alipay-sdk": "^4.14.0", diff --git a/scripts/keyManagerReleaseContract.test.js b/scripts/keyManagerReleaseContract.test.js new file mode 100644 index 0000000..a8d515e --- /dev/null +++ b/scripts/keyManagerReleaseContract.test.js @@ -0,0 +1,73 @@ +const assert = require("node:assert/strict"); +const { createRequire } = require("node:module"); + +const nodeRequire = createRequire(__filename); + +function loadKeyManagerWithPool(pool) { + const dbPath = nodeRequire.resolve("../src/db"); + const keyManagerPath = nodeRequire.resolve("../src/keyManager"); + const originalDbModule = nodeRequire.cache[dbPath]; + const originalKeyManagerModule = nodeRequire.cache[keyManagerPath]; + + delete nodeRequire.cache[keyManagerPath]; + nodeRequire.cache[dbPath] = { + id: dbPath, + filename: dbPath, + loaded: true, + exports: { + pool, + withTransaction: async (fn) => fn(pool), + }, + }; + + return { + keyManager: nodeRequire("../src/keyManager"), + restore() { + delete nodeRequire.cache[keyManagerPath]; + if (originalKeyManagerModule) nodeRequire.cache[keyManagerPath] = originalKeyManagerModule; + if (originalDbModule) nodeRequire.cache[dbPath] = originalDbModule; + else delete nodeRequire.cache[dbPath]; + }, + }; +} + +function createReleasePool() { + const calls = []; + return { + calls, + async query(sql, params) { + calls.push({ sql, params }); + if (/WITH candidate AS/i.test(sql)) { + return { + rows: [{ + id: 10, + key_id: 20, + lease_user_id: 30, + lease_enterprise_id: 40, + provider: "dashscope", + }], + }; + } + if (/UPDATE api_keys SET active_count/i.test(sql)) return { rows: [] }; + if (/INSERT INTO usage_logs/i.test(sql)) return { rows: [] }; + throw new Error(`Unexpected SQL: ${sql}`); + }, + }; +} + +(async () => { + const pool = createReleasePool(); + const { keyManager, restore } = loadKeyManagerWithPool(pool); + try { + const result = await keyManager.releaseKey("lease-token-without-user-context"); + + assert.equal(result.released, true); + const usageLogCall = pool.calls.find((call) => /INSERT INTO usage_logs/i.test(call.sql)); + assert.deepEqual(usageLogCall.params, [30, 40, 20, 20, "release"]); + } finally { + restore(); + } +})().catch((error) => { + console.error(error); + process.exitCode = 1; +}); diff --git a/scripts/providerPollLimiterContract.test.js b/scripts/providerPollLimiterContract.test.js new file mode 100644 index 0000000..a215257 --- /dev/null +++ b/scripts/providerPollLimiterContract.test.js @@ -0,0 +1,96 @@ +const assert = require("node:assert/strict"); +const { createRequire } = require("node:module"); + +const nodeRequire = createRequire(__filename); + +function loadLimiterWithPool(pool) { + const dbPath = nodeRequire.resolve("../src/db"); + const limiterPath = nodeRequire.resolve("../src/providerPollLimiter"); + const originalDbModule = nodeRequire.cache[dbPath]; + const originalLimiterModule = nodeRequire.cache[limiterPath]; + + delete nodeRequire.cache[limiterPath]; + nodeRequire.cache[dbPath] = { + id: dbPath, + filename: dbPath, + loaded: true, + exports: { pool }, + }; + + return { + limiter: nodeRequire("../src/providerPollLimiter"), + restore() { + delete nodeRequire.cache[limiterPath]; + if (originalLimiterModule) nodeRequire.cache[limiterPath] = originalLimiterModule; + if (originalDbModule) nodeRequire.cache[dbPath] = originalDbModule; + else delete nodeRequire.cache[dbPath]; + }, + }; +} + +function createPool(options = {}) { + const calls = []; + return { + calls, + async query(sql, params = []) { + calls.push({ sql, params }); + if (/CREATE TABLE IF NOT EXISTS generation_provider_poll_slots/i.test(sql)) return { rows: [] }; + if (/WITH candidate AS/i.test(sql)) { + if (options.noAvailableSlot) return { rows: [] }; + return { rows: [{ scope: params[0], slot_no: 2 }] }; + } + if (/DELETE FROM generation_provider_poll_slots/i.test(sql)) return { rows: [] }; + throw new Error(`Unexpected SQL: ${sql}`); + }, + }; +} + +(async () => { + const previousLimit = process.env.TASK_PROVIDER_POLL_MAX_CONCURRENCY; + process.env.TASK_PROVIDER_POLL_MAX_CONCURRENCY = "3"; + + const pool = createPool(); + const { limiter, restore } = loadLimiterWithPool(pool); + try { + const outcome = await limiter.withProviderPollSlot(101, async () => "polled"); + + assert.equal(outcome.acquired, true); + assert.equal(outcome.value, "polled"); + + const acquireCall = pool.calls.find((call) => /WITH candidate AS/i.test(call.sql)); + assert.equal(acquireCall.params[1], 3); + assert.equal(acquireCall.params[3], 101); + + const releaseCall = pool.calls.find((call) => /DELETE FROM generation_provider_poll_slots/i.test(call.sql)); + assert.equal(releaseCall.params[0], acquireCall.params[0]); + assert.equal(releaseCall.params[1], 2); + assert.equal(releaseCall.params[2], acquireCall.params[2]); + } finally { + if (previousLimit === undefined) delete process.env.TASK_PROVIDER_POLL_MAX_CONCURRENCY; + else process.env.TASK_PROVIDER_POLL_MAX_CONCURRENCY = previousLimit; + restore(); + } + + const saturatedPool = createPool({ noAvailableSlot: true }); + const { limiter: saturatedLimiter, restore: restoreSaturated } = loadLimiterWithPool(saturatedPool); + try { + let called = false; + const outcome = await saturatedLimiter.withProviderPollSlot(202, async () => { + called = true; + return "should-not-run"; + }); + + assert.equal(outcome.acquired, false); + assert.equal(outcome.value, undefined); + assert.equal(called, false); + assert.equal( + saturatedPool.calls.some((call) => /DELETE FROM generation_provider_poll_slots/i.test(call.sql)), + false, + ); + } finally { + restoreSaturated(); + } +})().catch((error) => { + console.error(error); + process.exitCode = 1; +}); diff --git a/src/aiTaskWorker.js b/src/aiTaskWorker.js index 9781ae0..f061ed8 100644 --- a/src/aiTaskWorker.js +++ b/src/aiTaskWorker.js @@ -5,6 +5,7 @@ const { EventEmitter } = require("node:events"); const { pool } = require("./db"); const { refundTaskBillingOnFailure } = require("./billing"); const { putObject, isOssConfigured } = require("./ossClient"); +const { withProviderPollSlot } = require("./providerPollLimiter"); const taskEvents = new EventEmitter(); taskEvents.setMaxListeners(200); @@ -18,10 +19,12 @@ const TASK_EVENT_ORIGIN = `${process.pid}-${crypto.randomUUID()}`; const POLLER_OWNER_ID = `${process.pid}-${crypto.randomUUID()}`; const POLLER_OWNER_STALE_MS = Number(process.env.TASK_POLLER_OWNER_STALE_MS || 20_000); const POLLER_RECOVERY_INTERVAL_MS = Number(process.env.TASK_POLLER_RECOVERY_INTERVAL_MS || 30_000); +const PROVIDER_POLL_REQUEST_TIMEOUT_MS = Number(process.env.TASK_PROVIDER_POLL_REQUEST_TIMEOUT_MS || 25_000); let taskEventListenerClient = null; let taskEventListenerStarting = null; let pollerStoreReady = null; let pollerRecoveryTimer = null; +let staleTaskCleanupStartupTimer = null; function normalizeTaskProgress(value) { const numeric = Number(value); @@ -152,6 +155,14 @@ async function clearPollerState(taskDbId) { await pool.query("DELETE FROM generation_task_pollers WHERE task_id = $1", [taskDbId]); } +async function orphanOwnedPollerState() { + await ensureTaskPollerStore(); + await pool.query( + "UPDATE generation_task_pollers SET owner_id = NULL, owner_heartbeat_at = NULL, updated_at = NOW() WHERE owner_id = $1", + [POLLER_OWNER_ID], + ); +} + async function getPersistedLeaseToken(taskDbId) { await ensureTaskPollerStore(); const { rows } = await pool.query( @@ -280,6 +291,12 @@ async function updateTaskInDb(taskId, updates) { }); } + if (nextUpdates.status === "completed") { + await markTaskBillingAccepted(taskId).catch((err) => { + console.error(`[aiTaskWorker] billing accept error for task ${taskId}:`, err.message); + }); + } + if (nextUpdates.status === "failed") { await refundTaskBillingOnFailure(taskId).catch((err) => { console.error(`[aiTaskWorker] refund error for task ${taskId}:`, err.message); @@ -287,6 +304,13 @@ async function updateTaskInDb(taskId, updates) { } } +async function markTaskBillingAccepted(taskId) { + await pool.query( + "UPDATE credit_ledger SET status = 'charged', updated_at = NOW() WHERE task_id = $1 AND status = 'reserved'", + [taskId], + ); +} + function persistTaskResultUrlToOssInBackground(task) { if (!task?.id || !task?.result_url) return; @@ -641,9 +665,22 @@ function extractErrorMessage(json, fallback) { } async function fetchJson(url, headers) { - const res = await fetch(url, { method: "GET", headers }); - if (!res.ok) return { ok: false, json: null }; - return { ok: true, json: await res.json() }; + const controller = new AbortController(); + const timeoutMs = Number.isFinite(PROVIDER_POLL_REQUEST_TIMEOUT_MS) && PROVIDER_POLL_REQUEST_TIMEOUT_MS > 0 + ? PROVIDER_POLL_REQUEST_TIMEOUT_MS + : 25_000; + const timer = setTimeout(() => controller.abort(), timeoutMs); + timer.unref?.(); + + try { + const res = await fetch(url, { method: "GET", headers, signal: controller.signal }); + if (!res.ok) return { ok: false, json: null }; + return { ok: true, json: await res.json() }; + } catch (err) { + return { ok: false, json: null, error: err }; + } finally { + clearTimeout(timer); + } } async function pollGrsaiImage(_taskId, providerTaskId, apiKey, baseUrl, resultEndpoint) { @@ -813,26 +850,31 @@ function startPolling(taskDbId, { providerTaskId, apiKey, type, providerConfig, } let attempts = 0; + let polling = false; + let skippedPolls = 0; const maxPollAttempts = getMaxPollAttempts(type, providerConfig); const interval = setInterval(async () => { - attempts++; - if (attempts > maxPollAttempts) { - clearInterval(interval); - activePollers.delete(taskDbId); - if (leaseToken && keyManager) await keyManager.releaseKey(leaseToken).catch(() => {}); - if (typeof onTaskFailed === "function") { - const handled = await onTaskFailed("Task timed out").catch((fallbackErr) => { - console.error(`[aiTaskWorker] fallback error for task ${taskDbId}:`, fallbackErr.message); - return false; - }); - if (handled) return; - } - await updateTaskInDb(taskDbId, { status: "failed", error: "Task timed out" }); - await clearPollerState(taskDbId).catch(() => {}); - return; - } + if (polling) return; + polling = true; try { + if (attempts >= maxPollAttempts) { + clearInterval(interval); + activePollers.delete(taskDbId); + if (leaseToken && keyManager) await keyManager.releaseKey(leaseToken).catch(() => {}); + if (typeof onTaskFailed === "function") { + await clearPollerState(taskDbId).catch(() => {}); + const handled = await onTaskFailed("Task timed out").catch((fallbackErr) => { + console.error(`[aiTaskWorker] fallback error for task ${taskDbId}:`, fallbackErr.message); + return false; + }); + if (handled) return; + } + await updateTaskInDb(taskDbId, { status: "failed", error: "Task timed out" }); + await clearPollerState(taskDbId).catch(() => {}); + return; + } + // Check if task was cancelled by user const { rows: [taskRow] } = await pool.query("SELECT status FROM generation_tasks WHERE id = $1", [taskDbId]); if (!taskRow || taskRow.status === "cancelled") { @@ -844,15 +886,29 @@ function startPolling(taskDbId, { providerTaskId, apiKey, type, providerConfig, } await refreshPollerHeartbeat(taskDbId).catch(() => {}); - let result; - if (type === "image") { - if (providerConfig.transport === "dashscope-image") { - result = await pollDashscopeImage(taskDbId, providerTaskId, apiKey); - } else { - result = await pollGrsaiImage(taskDbId, providerTaskId, apiKey, providerConfig.baseUrl, providerConfig.resultEndpoint || "/result"); + const pollOutcome = await withProviderPollSlot(taskDbId, async () => { + attempts++; + if (type === "image") { + if (providerConfig.transport === "dashscope-image") { + return pollDashscopeImage(taskDbId, providerTaskId, apiKey); + } + return pollGrsaiImage(taskDbId, providerTaskId, apiKey, providerConfig.baseUrl, providerConfig.resultEndpoint || "/result"); } - } else { - result = await pollVideoTask(taskDbId, providerTaskId, apiKey, providerConfig); + return pollVideoTask(taskDbId, providerTaskId, apiKey, providerConfig); + }); + + if (!pollOutcome.acquired) { + skippedPolls++; + if (skippedPolls % 20 === 0) { + console.info(`[aiTaskWorker] task ${taskDbId} waiting for provider poll slot (skipped=${skippedPolls})`); + } + return; + } + + skippedPolls = 0; + const result = pollOutcome.value; + if (!result) { + return; } if (result.status === "completed" || result.status === "failed") { @@ -860,6 +916,7 @@ function startPolling(taskDbId, { providerTaskId, apiKey, type, providerConfig, activePollers.delete(taskDbId); if (leaseToken && keyManager) await keyManager.releaseKey(leaseToken).catch(() => {}); if (result.status === "failed" && typeof onTaskFailed === "function") { + await clearPollerState(taskDbId).catch(() => {}); const handled = await onTaskFailed(result.error || "Task failed").catch((fallbackErr) => { console.error(`[aiTaskWorker] fallback error for task ${taskDbId}:`, fallbackErr.message); return false; @@ -874,6 +931,8 @@ function startPolling(taskDbId, { providerTaskId, apiKey, type, providerConfig, } } catch (err) { console.error(`[aiTaskWorker] poll error for task ${taskDbId}:`, err.message); + } finally { + polling = false; } }, POLL_INTERVAL_MS); @@ -921,7 +980,7 @@ async function recoverRunnablePollers() { const staleInterval = `${Math.max(5, Math.ceil(POLLER_OWNER_STALE_MS / 1000))} seconds`; const { rows } = await pool.query( ` - SELECT p.task_id + SELECT p.task_id, p.updated_at FROM generation_task_pollers p JOIN generation_tasks t ON t.id = p.task_id WHERE t.status IN ('pending', 'running') @@ -944,6 +1003,7 @@ async function recoverRunnablePollers() { const apiKey = await getLeaseKey(poller.lease_token); if (apiKey == null) { console.warn(`[aiTaskWorker] cannot recover task ${taskId}: active lease not found`); + await releaseUnrecoverableTask(taskId, "任务执行状态已失效,已自动释放"); continue; } @@ -955,11 +1015,51 @@ async function recoverRunnablePollers() { providerConfig: parseProviderConfig(poller.provider_config_json), leaseToken: poller.lease_token, keyManager: require("./keyManager"), + onTaskFailed: async (failureMessage) => { + await updateTaskInDb(taskId, { status: "failed", error: failureMessage || "Task failed" }); + return true; + }, skipPersist: true, }); } } +async function releaseUnrecoverableTask(taskId, message) { + const { rows } = await pool.query( + ` + UPDATE generation_tasks t + SET status = 'failed', error = $2, completed_at = NOW(), updated_at = NOW() + FROM generation_task_pollers p + WHERE t.id = $1 + AND p.task_id = t.id + AND p.owner_id = $3 + AND t.status IN ('pending', 'running') + RETURNING t.* + `, + [taskId, message, POLLER_OWNER_ID], + ); + + const task = rows[0]; + if (!task) return false; + + const leaseToken = await getPersistedLeaseToken(taskId).catch(() => null); + await clearPollerState(taskId).catch(() => {}); + if (leaseToken) { + await require("./keyManager").releaseKey(leaseToken).catch((err) => { + console.error(`[aiTaskWorker] failed to release lease for unrecoverable task ${taskId}:`, err.message); + }); + } + await publishTaskEvent(formatTaskEvent(task)); + await createTaskLifecycleNotification(task).catch((err) => { + console.error(`[aiTaskWorker] notification error for unrecoverable task ${taskId}:`, err.message); + }); + await refundTaskBillingOnFailure(taskId).catch((err) => { + console.error(`[aiTaskWorker] refund error for unrecoverable task ${taskId}:`, err.message); + }); + console.warn(`[aiTaskWorker] released unrecoverable task ${taskId}: ${message}`); + return true; +} + // --- Periodic stale task cleanup --- // Runs every 5 minutes, marks tasks stuck in 'pending'/'running' for too long as 'failed'. // This catches cases where the worker crashed, the provider API never responded, @@ -971,26 +1071,32 @@ async function runStaleTaskCleanup() { try { const { rows } = await pool.query( `UPDATE generation_tasks - SET status = 'failed', error = '任务超时自动释放', updated_at = NOW() + SET status = 'failed', error = '任务超时自动释放', completed_at = NOW(), updated_at = NOW() WHERE status IN ('pending', 'running') AND GREATEST(updated_at, COALESCE(last_poll_at, created_at)) < NOW() - INTERVAL '10 minutes' - RETURNING id`, + RETURNING *`, ); for (const row of rows) { - await publishTaskEvent({ - taskId: row.id, - status: "failed", - progress: null, - resultUrl: null, - error: "任务超时自动释放", - }); // Also stop any active poller for this task const poller = activePollers.get(row.id); if (poller) { clearInterval(poller.interval); activePollers.delete(row.id); } + const leaseToken = poller?.leaseToken || await getPersistedLeaseToken(row.id).catch(() => null); await clearPollerState(row.id).catch(() => {}); + if (leaseToken) { + await require("./keyManager").releaseKey(leaseToken).catch((err) => { + console.error(`[aiTaskWorker] failed to release lease for stale task ${row.id}:`, err.message); + }); + } + await publishTaskEvent(formatTaskEvent(row)); + await createTaskLifecycleNotification(row).catch((err) => { + console.error(`[aiTaskWorker] notification error for stale task ${row.id}:`, err.message); + }); + await refundTaskBillingOnFailure(row.id).catch((err) => { + console.error(`[aiTaskWorker] refund error for stale task ${row.id}:`, err.message); + }); } if (rows.length > 0) { console.log(`[aiTaskWorker] Cleaned up ${rows.length} stale task(s)`); @@ -1064,10 +1170,14 @@ function startStaleTaskCleanup() { if (staleTaskCleanupTimer) return; staleTaskCleanupTimer = setInterval(runStaleTaskCleanup, STALE_TASK_CLEANUP_INTERVAL_MS); // Run once shortly after startup - setTimeout(runStaleTaskCleanup, 10_000); + staleTaskCleanupStartupTimer = setTimeout(runStaleTaskCleanup, 10_000); } function stopStaleTaskCleanup() { + if (staleTaskCleanupStartupTimer) { + clearTimeout(staleTaskCleanupStartupTimer); + staleTaskCleanupStartupTimer = null; + } if (staleTaskCleanupTimer) { clearInterval(staleTaskCleanupTimer); staleTaskCleanupTimer = null; @@ -1093,9 +1203,20 @@ function stopPollerRecovery() { } } +async function stopAllPollers() { + for (const [taskId, poller] of activePollers.entries()) { + clearInterval(poller.interval); + activePollers.delete(taskId); + } + await orphanOwnedPollerState().catch((err) => { + console.error("[aiTaskWorker] failed to orphan owned poller state:", err.message); + }); +} + module.exports = { startPolling, stopPolling, + stopAllPollers, cancelTaskRuntimeState, updateTaskInDb, getActiveCount, diff --git a/src/index.js b/src/index.js index bc6b3cd..b01200e 100644 --- a/src/index.js +++ b/src/index.js @@ -3,8 +3,17 @@ const express = require('express') const rateLimit = require('express-rate-limit') const cors = require('cors') const helmet = require('helmet') -const { startSettlementWorker } = require('./settlementWorker') -const { startProviderHealthMonitor } = require('./providerHealthMonitor') +const { startSettlementWorker, stopSettlementWorker } = require('./settlementWorker') +const { startProviderHealthMonitor, stopProviderHealthMonitor } = require('./providerHealthMonitor') +const { + startStaleTaskCleanup, + startTaskEventListener, + startPollerRecovery, + stopStaleTaskCleanup, + stopTaskEventListener, + stopPollerRecovery, + stopAllPollers, +} = require('./aiTaskWorker') const { ensureDatabase } = require('./dbSetup') const { assertRuntimeSecurityConfig } = require('./securityConfig') const { loadPriceCache } = require('./pricing') @@ -17,6 +26,7 @@ const PORT = Number(process.env.PORT) || 3600 const HOST = process.env.HOST || '0.0.0.0' const IS_PRODUCTION = process.env.NODE_ENV === 'production' let server = null +let staleLeaseCleanupTimer = null // CORS: in production, require explicit allowlist; in dev, allow all with credentials function buildCorsOptions() { @@ -133,18 +143,18 @@ async function main() { // Periodic stale lease cleanup (every 5 min) const { cleanStaleLeases } = require('./keyManager') - setInterval(() => { + staleLeaseCleanupTimer = setInterval(() => { cleanStaleLeases().then((cleaned) => { if (cleaned > 0) console.log(`[cleanup] Released ${cleaned} stale lease(s)`) }).catch((err) => { console.error('[cleanup] error:', err) }) }, 5 * 60 * 1000) + if (staleLeaseCleanupTimer.unref) staleLeaseCleanupTimer.unref() startSettlementWorker() startProviderHealthMonitor() - const { startStaleTaskCleanup, startTaskEventListener, startPollerRecovery } = require('./aiTaskWorker') await startTaskEventListener() startPollerRecovery() startStaleTaskCleanup() @@ -175,32 +185,47 @@ process.on('uncaughtException', (err) => { // ── Graceful shutdown ─────────────────────────────────────────────────── let shuttingDown = false -function gracefulShutdown(signal) { +async function shutdownRuntimeState() { + if (staleLeaseCleanupTimer) { + clearInterval(staleLeaseCleanupTimer) + staleLeaseCleanupTimer = null + } + stopSettlementWorker() + stopProviderHealthMonitor() + stopPollerRecovery() + stopStaleTaskCleanup() + await Promise.allSettled([stopTaskEventListener(), stopAllPollers()]) +} + +function closeServer() { + if (!server || !server.listening) return Promise.resolve() + return new Promise((resolve) => { + server.close(() => { + console.log('[shutdown] Server closed, cleaning up...') + resolve() + }) + }) +} + +async function gracefulShutdown(signal) { if (shuttingDown) return shuttingDown = true console.log('[shutdown] Received ' + signal + ', draining connections...') - if (server && server.listening) { - server.close(() => { - console.log('[shutdown] Server closed, cleaning up...') - const { stopProviderHealthMonitor } = require('./providerHealthMonitor') - stopProviderHealthMonitor() - const { stopTaskEventListener, stopPollerRecovery } = require('./aiTaskWorker') - stopPollerRecovery() - void stopTaskEventListener() - const { pool } = require('./db') - pool.end().then(() => { - console.log('[shutdown] Database pool closed') - process.exit(0) - }).catch(() => process.exit(0)) - }) + setTimeout(() => { + console.error('[shutdown] Forced exit after timeout') + process.exit(1) + }, 15000).unref() - // Force exit after timeout - setTimeout(() => { - console.error('[shutdown] Forced exit after timeout') - process.exit(1) - }, 15000).unref() - } else { + try { + await shutdownRuntimeState() + await closeServer() + const { pool } = require('./db') + await pool.end() + console.log('[shutdown] Database pool closed') + process.exit(0) + } catch (err) { + console.error('[shutdown] error:', err) process.exit(0) } } diff --git a/src/keyManager.js b/src/keyManager.js index 6f9b49e..4d30cbd 100644 --- a/src/keyManager.js +++ b/src/keyManager.js @@ -284,7 +284,7 @@ async function releaseLeaseInternal(leaseToken, user, options = {}) { const { rows } = await client.query( ` WITH candidate AS ( - SELECT l.id, l.key_id, k.provider + SELECT l.id, l.key_id, l.user_id, l.enterprise_id, k.provider FROM key_leases l JOIN api_keys k ON k.id = l.key_id WHERE l.lease_token = $1 AND l.released_at IS NULL @@ -297,7 +297,7 @@ async function releaseLeaseInternal(leaseToken, user, options = {}) { WHERE id = (SELECT id FROM candidate) RETURNING id, key_id ) - SELECT r.id, r.key_id, c.provider + SELECT r.id, r.key_id, c.user_id AS lease_user_id, c.enterprise_id AS lease_enterprise_id, c.provider FROM released r JOIN candidate c ON c.key_id = r.key_id `, @@ -339,7 +339,13 @@ async function releaseLeaseInternal(leaseToken, user, options = {}) { INSERT INTO usage_logs (user_id, enterprise_id, provider, key_id, action) VALUES ($1, $2, (SELECT provider FROM api_keys WHERE id = $3), $4, $5) `, - [userId, enterpriseId, lease.key_id, lease.key_id, "release"], + [ + userId || lease.lease_user_id, + enterpriseId || lease.lease_enterprise_id, + lease.key_id, + lease.key_id, + "release", + ], ); return { diff --git a/src/providerPollLimiter.js b/src/providerPollLimiter.js new file mode 100644 index 0000000..2de5a8f --- /dev/null +++ b/src/providerPollLimiter.js @@ -0,0 +1,120 @@ +"use strict"; + +const crypto = require("node:crypto"); +const { pool } = require("./db"); + +const DEFAULT_MAX_CONCURRENCY = 8; +const DEFAULT_SLOT_TTL_MS = 30_000; +const POLL_SCOPE = "generation-provider-poll:global"; +const OWNER_ID = `${process.pid}-${crypto.randomUUID()}`; + +let storeReady = null; + +function normalizePositiveInteger(value, fallback) { + const numeric = Number(value); + if (!Number.isFinite(numeric) || numeric <= 0) return fallback; + return Math.max(1, Math.trunc(numeric)); +} + +function getMaxConcurrency() { + return normalizePositiveInteger(process.env.TASK_PROVIDER_POLL_MAX_CONCURRENCY, DEFAULT_MAX_CONCURRENCY); +} + +function getSlotTtlInterval() { + const ttlMs = normalizePositiveInteger(process.env.TASK_PROVIDER_POLL_SLOT_TTL_MS, DEFAULT_SLOT_TTL_MS); + return `${Math.max(1, Math.ceil(ttlMs / 1000))} seconds`; +} + +async function ensureProviderPollLimiterStore() { + if (storeReady) return storeReady; + storeReady = pool.query(` + CREATE TABLE IF NOT EXISTS generation_provider_poll_slots ( + scope TEXT NOT NULL, + slot_no INTEGER NOT NULL, + owner_id TEXT NOT NULL, + task_id INTEGER, + expires_at TIMESTAMPTZ NOT NULL, + acquired_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + PRIMARY KEY (scope, slot_no) + ); + CREATE INDEX IF NOT EXISTS idx_generation_provider_poll_slots_expires + ON generation_provider_poll_slots(expires_at); + `).catch((err) => { + storeReady = null; + throw err; + }); + return storeReady; +} + +async function acquireProviderPollSlot(taskId, options = {}) { + await ensureProviderPollLimiterStore(); + + const scope = options.scope || POLL_SCOPE; + const maxConcurrency = normalizePositiveInteger(options.maxConcurrency, getMaxConcurrency()); + const ttlInterval = options.ttlInterval || getSlotTtlInterval(); + const { rows } = await pool.query( + ` + WITH candidate AS ( + SELECT s.slot_no + FROM generate_series(1, $2::integer) AS s(slot_no) + LEFT JOIN generation_provider_poll_slots l + ON l.scope = $1 AND l.slot_no = s.slot_no + WHERE l.scope IS NULL OR l.expires_at < NOW() + ORDER BY s.slot_no ASC + LIMIT 1 + ), + claimed AS ( + INSERT INTO generation_provider_poll_slots ( + scope, slot_no, owner_id, task_id, expires_at, acquired_at, updated_at + ) + SELECT $1, slot_no, $3, $4, NOW() + ($5::text)::interval, NOW(), NOW() + FROM candidate + ON CONFLICT (scope, slot_no) DO UPDATE SET + owner_id = EXCLUDED.owner_id, + task_id = EXCLUDED.task_id, + expires_at = EXCLUDED.expires_at, + acquired_at = NOW(), + updated_at = NOW() + WHERE generation_provider_poll_slots.expires_at < NOW() + RETURNING scope, slot_no + ) + SELECT scope, slot_no FROM claimed + `, + [scope, maxConcurrency, OWNER_ID, taskId || null, ttlInterval], + ); + + const slot = rows[0]; + return slot ? { scope: slot.scope, slotNo: slot.slot_no, ownerId: OWNER_ID } : null; +} + +async function releaseProviderPollSlot(slot) { + if (!slot?.scope || !slot?.slotNo) return; + await ensureProviderPollLimiterStore(); + await pool.query( + "DELETE FROM generation_provider_poll_slots WHERE scope = $1 AND slot_no = $2 AND owner_id = $3", + [slot.scope, slot.slotNo, slot.ownerId || OWNER_ID], + ); +} + +async function withProviderPollSlot(taskId, fn, options = {}) { + const slot = await acquireProviderPollSlot(taskId, options); + if (!slot) return { acquired: false, value: undefined }; + + try { + return { acquired: true, value: await fn() }; + } finally { + await releaseProviderPollSlot(slot).catch((err) => { + console.error(`[providerPollLimiter] failed to release poll slot ${slot.scope}:${slot.slotNo}:`, err.message); + }); + } +} + +module.exports = { + acquireProviderPollSlot, + ensureProviderPollLimiterStore, + getMaxConcurrency, + normalizePositiveInteger, + releaseProviderPollSlot, + withProviderPollSlot, +}; diff --git a/src/routes/ai.js b/src/routes/ai.js index 3c5e012..d71cf4e 100644 --- a/src/routes/ai.js +++ b/src/routes/ai.js @@ -1072,6 +1072,16 @@ function registerAiRoutes(router) { error.costCents = billingResult.costCents; throw error; } + if (billingResult.costCents > 0) { + await client.query( + "UPDATE generation_tasks SET cost_cents = $1, billing_target = $2, billing_refunded = 0, updated_at = NOW() WHERE id = $3", + [ + billingResult.costCents, + billingResult.deductionType === "enterprise_image_flat" ? "enterprise_image" : "user", + nextTaskRow.id, + ], + ); + } return { taskRow: nextTaskRow, imageBilling: billingResult }; }); const preauth = { authorized: true, estimatedCostCents: 0, billingMode: imageBilling.deductionType }; @@ -1086,9 +1096,11 @@ function registerAiRoutes(router) { }, providerDebug: buildImageProviderDebug(model), }); - submitImageWithProviderFallback(taskRow.id, providerCandidates, req.user, preauth, params).catch((err) => { + submitImageWithProviderFallback(taskRow.id, providerCandidates, req.user, preauth, params).catch(async (err) => { console.error("[ai/image] submit error:", err.message); - updateTaskInDb(taskRow.id, { status: "failed", error: err.message }); + await updateTaskInDb(taskRow.id, { status: "failed", error: err.message }).catch((updateErr) => { + console.error(`[ai/image] failed to persist task ${taskRow.id} failure:`, updateErr.message); + }); }); } catch (err) { console.error("[ai/image] error:", err.message); @@ -1200,6 +1212,10 @@ function registerAiRoutes(router) { ...enterpriseBilling, taskId: nextTaskRow.id, }); + await client.query( + "UPDATE generation_tasks SET cost_cents = $1, billing_target = 'enterprise_video', billing_refunded = 0, updated_at = NOW() WHERE id = $2", + [nextBilling.amountCents, nextTaskRow.id], + ); return { taskRow: nextTaskRow, reservedBilling: nextBilling, regularBilling: null }; } // Regular user: deduct from personal balance @@ -1222,6 +1238,10 @@ function registerAiRoutes(router) { "INSERT INTO transactions (user_id, type, amount_cents, balance_after_cents, description) VALUES ($1, 'deduct', $2, $3, $4)", [req.user.id, -costCents, deducted.balance_cents, `视频生成扣费 ${credits} 积分`], ); + await client.query( + "UPDATE generation_tasks SET cost_cents = $1, billing_target = 'user', billing_refunded = 0, updated_at = NOW() WHERE id = $2", + [costCents, nextTaskRow.id], + ); return { taskRow: nextTaskRow, reservedBilling: null, regularBilling: { costCents, balanceAfterCents: deducted.balance_cents, credits } }; });