fix: harden provider polling recovery

This commit is contained in:
2026-06-09 11:32:53 +08:00
parent 1166811ee4
commit f9da506017
9 changed files with 539 additions and 69 deletions
+6
View File
@@ -20,6 +20,12 @@ JWT_EXPIRES_IN=7d
# Connection pool # Connection pool
PG_POOL_MAX=10 PG_POOL_MAX=10
# Provider polling reliability
# Shared across PM2 workers through Postgres-backed poll slots.
TASK_PROVIDER_POLL_MAX_CONCURRENCY=8
TASK_PROVIDER_POLL_SLOT_TTL_MS=30000
TASK_PROVIDER_POLL_REQUEST_TIMEOUT_MS=25000
# CORS (comma separated allowed origins, * for all) # CORS (comma separated allowed origins, * for all)
CORS_ORIGINS=* CORS_ORIGINS=*
+4 -1
View File
@@ -14,7 +14,10 @@
"audit-routes": "node src/cli/auditModelRoutes.js", "audit-routes": "node src/cli/auditModelRoutes.js",
"import-config": "node src/cli/importConfig.js", "import-config": "node src/cli/importConfig.js",
"init-pools": "node src/cli/initPools.js", "init-pools": "node src/cli/initPools.js",
"test:community-routes": "node scripts/communityRouteContract.test.js" "test:community-routes": "node scripts/communityRouteContract.test.js",
"test:key-manager": "node scripts/keyManagerReleaseContract.test.js",
"test:provider-poll-limiter": "node scripts/providerPollLimiterContract.test.js",
"test": "npm run test:community-routes && npm run test:key-manager && npm run test:provider-poll-limiter"
}, },
"dependencies": { "dependencies": {
"alipay-sdk": "^4.14.0", "alipay-sdk": "^4.14.0",
+73
View File
@@ -0,0 +1,73 @@
const assert = require("node:assert/strict");
const { createRequire } = require("node:module");
const nodeRequire = createRequire(__filename);
function loadKeyManagerWithPool(pool) {
const dbPath = nodeRequire.resolve("../src/db");
const keyManagerPath = nodeRequire.resolve("../src/keyManager");
const originalDbModule = nodeRequire.cache[dbPath];
const originalKeyManagerModule = nodeRequire.cache[keyManagerPath];
delete nodeRequire.cache[keyManagerPath];
nodeRequire.cache[dbPath] = {
id: dbPath,
filename: dbPath,
loaded: true,
exports: {
pool,
withTransaction: async (fn) => fn(pool),
},
};
return {
keyManager: nodeRequire("../src/keyManager"),
restore() {
delete nodeRequire.cache[keyManagerPath];
if (originalKeyManagerModule) nodeRequire.cache[keyManagerPath] = originalKeyManagerModule;
if (originalDbModule) nodeRequire.cache[dbPath] = originalDbModule;
else delete nodeRequire.cache[dbPath];
},
};
}
function createReleasePool() {
const calls = [];
return {
calls,
async query(sql, params) {
calls.push({ sql, params });
if (/WITH candidate AS/i.test(sql)) {
return {
rows: [{
id: 10,
key_id: 20,
lease_user_id: 30,
lease_enterprise_id: 40,
provider: "dashscope",
}],
};
}
if (/UPDATE api_keys SET active_count/i.test(sql)) return { rows: [] };
if (/INSERT INTO usage_logs/i.test(sql)) return { rows: [] };
throw new Error(`Unexpected SQL: ${sql}`);
},
};
}
(async () => {
const pool = createReleasePool();
const { keyManager, restore } = loadKeyManagerWithPool(pool);
try {
const result = await keyManager.releaseKey("lease-token-without-user-context");
assert.equal(result.released, true);
const usageLogCall = pool.calls.find((call) => /INSERT INTO usage_logs/i.test(call.sql));
assert.deepEqual(usageLogCall.params, [30, 40, 20, 20, "release"]);
} finally {
restore();
}
})().catch((error) => {
console.error(error);
process.exitCode = 1;
});
@@ -0,0 +1,96 @@
const assert = require("node:assert/strict");
const { createRequire } = require("node:module");
const nodeRequire = createRequire(__filename);
function loadLimiterWithPool(pool) {
const dbPath = nodeRequire.resolve("../src/db");
const limiterPath = nodeRequire.resolve("../src/providerPollLimiter");
const originalDbModule = nodeRequire.cache[dbPath];
const originalLimiterModule = nodeRequire.cache[limiterPath];
delete nodeRequire.cache[limiterPath];
nodeRequire.cache[dbPath] = {
id: dbPath,
filename: dbPath,
loaded: true,
exports: { pool },
};
return {
limiter: nodeRequire("../src/providerPollLimiter"),
restore() {
delete nodeRequire.cache[limiterPath];
if (originalLimiterModule) nodeRequire.cache[limiterPath] = originalLimiterModule;
if (originalDbModule) nodeRequire.cache[dbPath] = originalDbModule;
else delete nodeRequire.cache[dbPath];
},
};
}
function createPool(options = {}) {
const calls = [];
return {
calls,
async query(sql, params = []) {
calls.push({ sql, params });
if (/CREATE TABLE IF NOT EXISTS generation_provider_poll_slots/i.test(sql)) return { rows: [] };
if (/WITH candidate AS/i.test(sql)) {
if (options.noAvailableSlot) return { rows: [] };
return { rows: [{ scope: params[0], slot_no: 2 }] };
}
if (/DELETE FROM generation_provider_poll_slots/i.test(sql)) return { rows: [] };
throw new Error(`Unexpected SQL: ${sql}`);
},
};
}
(async () => {
const previousLimit = process.env.TASK_PROVIDER_POLL_MAX_CONCURRENCY;
process.env.TASK_PROVIDER_POLL_MAX_CONCURRENCY = "3";
const pool = createPool();
const { limiter, restore } = loadLimiterWithPool(pool);
try {
const outcome = await limiter.withProviderPollSlot(101, async () => "polled");
assert.equal(outcome.acquired, true);
assert.equal(outcome.value, "polled");
const acquireCall = pool.calls.find((call) => /WITH candidate AS/i.test(call.sql));
assert.equal(acquireCall.params[1], 3);
assert.equal(acquireCall.params[3], 101);
const releaseCall = pool.calls.find((call) => /DELETE FROM generation_provider_poll_slots/i.test(call.sql));
assert.equal(releaseCall.params[0], acquireCall.params[0]);
assert.equal(releaseCall.params[1], 2);
assert.equal(releaseCall.params[2], acquireCall.params[2]);
} finally {
if (previousLimit === undefined) delete process.env.TASK_PROVIDER_POLL_MAX_CONCURRENCY;
else process.env.TASK_PROVIDER_POLL_MAX_CONCURRENCY = previousLimit;
restore();
}
const saturatedPool = createPool({ noAvailableSlot: true });
const { limiter: saturatedLimiter, restore: restoreSaturated } = loadLimiterWithPool(saturatedPool);
try {
let called = false;
const outcome = await saturatedLimiter.withProviderPollSlot(202, async () => {
called = true;
return "should-not-run";
});
assert.equal(outcome.acquired, false);
assert.equal(outcome.value, undefined);
assert.equal(called, false);
assert.equal(
saturatedPool.calls.some((call) => /DELETE FROM generation_provider_poll_slots/i.test(call.sql)),
false,
);
} finally {
restoreSaturated();
}
})().catch((error) => {
console.error(error);
process.exitCode = 1;
});
+142 -21
View File
@@ -5,6 +5,7 @@ const { EventEmitter } = require("node:events");
const { pool } = require("./db"); const { pool } = require("./db");
const { refundTaskBillingOnFailure } = require("./billing"); const { refundTaskBillingOnFailure } = require("./billing");
const { putObject, isOssConfigured } = require("./ossClient"); const { putObject, isOssConfigured } = require("./ossClient");
const { withProviderPollSlot } = require("./providerPollLimiter");
const taskEvents = new EventEmitter(); const taskEvents = new EventEmitter();
taskEvents.setMaxListeners(200); taskEvents.setMaxListeners(200);
@@ -18,10 +19,12 @@ const TASK_EVENT_ORIGIN = `${process.pid}-${crypto.randomUUID()}`;
const POLLER_OWNER_ID = `${process.pid}-${crypto.randomUUID()}`; const POLLER_OWNER_ID = `${process.pid}-${crypto.randomUUID()}`;
const POLLER_OWNER_STALE_MS = Number(process.env.TASK_POLLER_OWNER_STALE_MS || 20_000); const POLLER_OWNER_STALE_MS = Number(process.env.TASK_POLLER_OWNER_STALE_MS || 20_000);
const POLLER_RECOVERY_INTERVAL_MS = Number(process.env.TASK_POLLER_RECOVERY_INTERVAL_MS || 30_000); const POLLER_RECOVERY_INTERVAL_MS = Number(process.env.TASK_POLLER_RECOVERY_INTERVAL_MS || 30_000);
const PROVIDER_POLL_REQUEST_TIMEOUT_MS = Number(process.env.TASK_PROVIDER_POLL_REQUEST_TIMEOUT_MS || 25_000);
let taskEventListenerClient = null; let taskEventListenerClient = null;
let taskEventListenerStarting = null; let taskEventListenerStarting = null;
let pollerStoreReady = null; let pollerStoreReady = null;
let pollerRecoveryTimer = null; let pollerRecoveryTimer = null;
let staleTaskCleanupStartupTimer = null;
function normalizeTaskProgress(value) { function normalizeTaskProgress(value) {
const numeric = Number(value); const numeric = Number(value);
@@ -152,6 +155,14 @@ async function clearPollerState(taskDbId) {
await pool.query("DELETE FROM generation_task_pollers WHERE task_id = $1", [taskDbId]); await pool.query("DELETE FROM generation_task_pollers WHERE task_id = $1", [taskDbId]);
} }
async function orphanOwnedPollerState() {
await ensureTaskPollerStore();
await pool.query(
"UPDATE generation_task_pollers SET owner_id = NULL, owner_heartbeat_at = NULL, updated_at = NOW() WHERE owner_id = $1",
[POLLER_OWNER_ID],
);
}
async function getPersistedLeaseToken(taskDbId) { async function getPersistedLeaseToken(taskDbId) {
await ensureTaskPollerStore(); await ensureTaskPollerStore();
const { rows } = await pool.query( const { rows } = await pool.query(
@@ -280,6 +291,12 @@ async function updateTaskInDb(taskId, updates) {
}); });
} }
if (nextUpdates.status === "completed") {
await markTaskBillingAccepted(taskId).catch((err) => {
console.error(`[aiTaskWorker] billing accept error for task ${taskId}:`, err.message);
});
}
if (nextUpdates.status === "failed") { if (nextUpdates.status === "failed") {
await refundTaskBillingOnFailure(taskId).catch((err) => { await refundTaskBillingOnFailure(taskId).catch((err) => {
console.error(`[aiTaskWorker] refund error for task ${taskId}:`, err.message); console.error(`[aiTaskWorker] refund error for task ${taskId}:`, err.message);
@@ -287,6 +304,13 @@ async function updateTaskInDb(taskId, updates) {
} }
} }
async function markTaskBillingAccepted(taskId) {
await pool.query(
"UPDATE credit_ledger SET status = 'charged', updated_at = NOW() WHERE task_id = $1 AND status = 'reserved'",
[taskId],
);
}
function persistTaskResultUrlToOssInBackground(task) { function persistTaskResultUrlToOssInBackground(task) {
if (!task?.id || !task?.result_url) return; if (!task?.id || !task?.result_url) return;
@@ -641,9 +665,22 @@ function extractErrorMessage(json, fallback) {
} }
async function fetchJson(url, headers) { async function fetchJson(url, headers) {
const res = await fetch(url, { method: "GET", headers }); const controller = new AbortController();
const timeoutMs = Number.isFinite(PROVIDER_POLL_REQUEST_TIMEOUT_MS) && PROVIDER_POLL_REQUEST_TIMEOUT_MS > 0
? PROVIDER_POLL_REQUEST_TIMEOUT_MS
: 25_000;
const timer = setTimeout(() => controller.abort(), timeoutMs);
timer.unref?.();
try {
const res = await fetch(url, { method: "GET", headers, signal: controller.signal });
if (!res.ok) return { ok: false, json: null }; if (!res.ok) return { ok: false, json: null };
return { ok: true, json: await res.json() }; return { ok: true, json: await res.json() };
} catch (err) {
return { ok: false, json: null, error: err };
} finally {
clearTimeout(timer);
}
} }
async function pollGrsaiImage(_taskId, providerTaskId, apiKey, baseUrl, resultEndpoint) { async function pollGrsaiImage(_taskId, providerTaskId, apiKey, baseUrl, resultEndpoint) {
@@ -813,14 +850,20 @@ function startPolling(taskDbId, { providerTaskId, apiKey, type, providerConfig,
} }
let attempts = 0; let attempts = 0;
let polling = false;
let skippedPolls = 0;
const maxPollAttempts = getMaxPollAttempts(type, providerConfig); const maxPollAttempts = getMaxPollAttempts(type, providerConfig);
const interval = setInterval(async () => { const interval = setInterval(async () => {
attempts++; if (polling) return;
if (attempts > maxPollAttempts) { polling = true;
try {
if (attempts >= maxPollAttempts) {
clearInterval(interval); clearInterval(interval);
activePollers.delete(taskDbId); activePollers.delete(taskDbId);
if (leaseToken && keyManager) await keyManager.releaseKey(leaseToken).catch(() => {}); if (leaseToken && keyManager) await keyManager.releaseKey(leaseToken).catch(() => {});
if (typeof onTaskFailed === "function") { if (typeof onTaskFailed === "function") {
await clearPollerState(taskDbId).catch(() => {});
const handled = await onTaskFailed("Task timed out").catch((fallbackErr) => { const handled = await onTaskFailed("Task timed out").catch((fallbackErr) => {
console.error(`[aiTaskWorker] fallback error for task ${taskDbId}:`, fallbackErr.message); console.error(`[aiTaskWorker] fallback error for task ${taskDbId}:`, fallbackErr.message);
return false; return false;
@@ -832,7 +875,6 @@ function startPolling(taskDbId, { providerTaskId, apiKey, type, providerConfig,
return; return;
} }
try {
// Check if task was cancelled by user // Check if task was cancelled by user
const { rows: [taskRow] } = await pool.query("SELECT status FROM generation_tasks WHERE id = $1", [taskDbId]); const { rows: [taskRow] } = await pool.query("SELECT status FROM generation_tasks WHERE id = $1", [taskDbId]);
if (!taskRow || taskRow.status === "cancelled") { if (!taskRow || taskRow.status === "cancelled") {
@@ -844,15 +886,29 @@ function startPolling(taskDbId, { providerTaskId, apiKey, type, providerConfig,
} }
await refreshPollerHeartbeat(taskDbId).catch(() => {}); await refreshPollerHeartbeat(taskDbId).catch(() => {});
let result; const pollOutcome = await withProviderPollSlot(taskDbId, async () => {
attempts++;
if (type === "image") { if (type === "image") {
if (providerConfig.transport === "dashscope-image") { if (providerConfig.transport === "dashscope-image") {
result = await pollDashscopeImage(taskDbId, providerTaskId, apiKey); return pollDashscopeImage(taskDbId, providerTaskId, apiKey);
} else {
result = await pollGrsaiImage(taskDbId, providerTaskId, apiKey, providerConfig.baseUrl, providerConfig.resultEndpoint || "/result");
} }
} else { return pollGrsaiImage(taskDbId, providerTaskId, apiKey, providerConfig.baseUrl, providerConfig.resultEndpoint || "/result");
result = await pollVideoTask(taskDbId, providerTaskId, apiKey, providerConfig); }
return pollVideoTask(taskDbId, providerTaskId, apiKey, providerConfig);
});
if (!pollOutcome.acquired) {
skippedPolls++;
if (skippedPolls % 20 === 0) {
console.info(`[aiTaskWorker] task ${taskDbId} waiting for provider poll slot (skipped=${skippedPolls})`);
}
return;
}
skippedPolls = 0;
const result = pollOutcome.value;
if (!result) {
return;
} }
if (result.status === "completed" || result.status === "failed") { if (result.status === "completed" || result.status === "failed") {
@@ -860,6 +916,7 @@ function startPolling(taskDbId, { providerTaskId, apiKey, type, providerConfig,
activePollers.delete(taskDbId); activePollers.delete(taskDbId);
if (leaseToken && keyManager) await keyManager.releaseKey(leaseToken).catch(() => {}); if (leaseToken && keyManager) await keyManager.releaseKey(leaseToken).catch(() => {});
if (result.status === "failed" && typeof onTaskFailed === "function") { if (result.status === "failed" && typeof onTaskFailed === "function") {
await clearPollerState(taskDbId).catch(() => {});
const handled = await onTaskFailed(result.error || "Task failed").catch((fallbackErr) => { const handled = await onTaskFailed(result.error || "Task failed").catch((fallbackErr) => {
console.error(`[aiTaskWorker] fallback error for task ${taskDbId}:`, fallbackErr.message); console.error(`[aiTaskWorker] fallback error for task ${taskDbId}:`, fallbackErr.message);
return false; return false;
@@ -874,6 +931,8 @@ function startPolling(taskDbId, { providerTaskId, apiKey, type, providerConfig,
} }
} catch (err) { } catch (err) {
console.error(`[aiTaskWorker] poll error for task ${taskDbId}:`, err.message); console.error(`[aiTaskWorker] poll error for task ${taskDbId}:`, err.message);
} finally {
polling = false;
} }
}, POLL_INTERVAL_MS); }, POLL_INTERVAL_MS);
@@ -921,7 +980,7 @@ async function recoverRunnablePollers() {
const staleInterval = `${Math.max(5, Math.ceil(POLLER_OWNER_STALE_MS / 1000))} seconds`; const staleInterval = `${Math.max(5, Math.ceil(POLLER_OWNER_STALE_MS / 1000))} seconds`;
const { rows } = await pool.query( const { rows } = await pool.query(
` `
SELECT p.task_id SELECT p.task_id, p.updated_at
FROM generation_task_pollers p FROM generation_task_pollers p
JOIN generation_tasks t ON t.id = p.task_id JOIN generation_tasks t ON t.id = p.task_id
WHERE t.status IN ('pending', 'running') WHERE t.status IN ('pending', 'running')
@@ -944,6 +1003,7 @@ async function recoverRunnablePollers() {
const apiKey = await getLeaseKey(poller.lease_token); const apiKey = await getLeaseKey(poller.lease_token);
if (apiKey == null) { if (apiKey == null) {
console.warn(`[aiTaskWorker] cannot recover task ${taskId}: active lease not found`); console.warn(`[aiTaskWorker] cannot recover task ${taskId}: active lease not found`);
await releaseUnrecoverableTask(taskId, "任务执行状态已失效,已自动释放");
continue; continue;
} }
@@ -955,11 +1015,51 @@ async function recoverRunnablePollers() {
providerConfig: parseProviderConfig(poller.provider_config_json), providerConfig: parseProviderConfig(poller.provider_config_json),
leaseToken: poller.lease_token, leaseToken: poller.lease_token,
keyManager: require("./keyManager"), keyManager: require("./keyManager"),
onTaskFailed: async (failureMessage) => {
await updateTaskInDb(taskId, { status: "failed", error: failureMessage || "Task failed" });
return true;
},
skipPersist: true, skipPersist: true,
}); });
} }
} }
async function releaseUnrecoverableTask(taskId, message) {
const { rows } = await pool.query(
`
UPDATE generation_tasks t
SET status = 'failed', error = $2, completed_at = NOW(), updated_at = NOW()
FROM generation_task_pollers p
WHERE t.id = $1
AND p.task_id = t.id
AND p.owner_id = $3
AND t.status IN ('pending', 'running')
RETURNING t.*
`,
[taskId, message, POLLER_OWNER_ID],
);
const task = rows[0];
if (!task) return false;
const leaseToken = await getPersistedLeaseToken(taskId).catch(() => null);
await clearPollerState(taskId).catch(() => {});
if (leaseToken) {
await require("./keyManager").releaseKey(leaseToken).catch((err) => {
console.error(`[aiTaskWorker] failed to release lease for unrecoverable task ${taskId}:`, err.message);
});
}
await publishTaskEvent(formatTaskEvent(task));
await createTaskLifecycleNotification(task).catch((err) => {
console.error(`[aiTaskWorker] notification error for unrecoverable task ${taskId}:`, err.message);
});
await refundTaskBillingOnFailure(taskId).catch((err) => {
console.error(`[aiTaskWorker] refund error for unrecoverable task ${taskId}:`, err.message);
});
console.warn(`[aiTaskWorker] released unrecoverable task ${taskId}: ${message}`);
return true;
}
// --- Periodic stale task cleanup --- // --- Periodic stale task cleanup ---
// Runs every 5 minutes, marks tasks stuck in 'pending'/'running' for too long as 'failed'. // Runs every 5 minutes, marks tasks stuck in 'pending'/'running' for too long as 'failed'.
// This catches cases where the worker crashed, the provider API never responded, // This catches cases where the worker crashed, the provider API never responded,
@@ -971,26 +1071,32 @@ async function runStaleTaskCleanup() {
try { try {
const { rows } = await pool.query( const { rows } = await pool.query(
`UPDATE generation_tasks `UPDATE generation_tasks
SET status = 'failed', error = '任务超时自动释放', updated_at = NOW() SET status = 'failed', error = '任务超时自动释放', completed_at = NOW(), updated_at = NOW()
WHERE status IN ('pending', 'running') WHERE status IN ('pending', 'running')
AND GREATEST(updated_at, COALESCE(last_poll_at, created_at)) < NOW() - INTERVAL '10 minutes' AND GREATEST(updated_at, COALESCE(last_poll_at, created_at)) < NOW() - INTERVAL '10 minutes'
RETURNING id`, RETURNING *`,
); );
for (const row of rows) { for (const row of rows) {
await publishTaskEvent({
taskId: row.id,
status: "failed",
progress: null,
resultUrl: null,
error: "任务超时自动释放",
});
// Also stop any active poller for this task // Also stop any active poller for this task
const poller = activePollers.get(row.id); const poller = activePollers.get(row.id);
if (poller) { if (poller) {
clearInterval(poller.interval); clearInterval(poller.interval);
activePollers.delete(row.id); activePollers.delete(row.id);
} }
const leaseToken = poller?.leaseToken || await getPersistedLeaseToken(row.id).catch(() => null);
await clearPollerState(row.id).catch(() => {}); await clearPollerState(row.id).catch(() => {});
if (leaseToken) {
await require("./keyManager").releaseKey(leaseToken).catch((err) => {
console.error(`[aiTaskWorker] failed to release lease for stale task ${row.id}:`, err.message);
});
}
await publishTaskEvent(formatTaskEvent(row));
await createTaskLifecycleNotification(row).catch((err) => {
console.error(`[aiTaskWorker] notification error for stale task ${row.id}:`, err.message);
});
await refundTaskBillingOnFailure(row.id).catch((err) => {
console.error(`[aiTaskWorker] refund error for stale task ${row.id}:`, err.message);
});
} }
if (rows.length > 0) { if (rows.length > 0) {
console.log(`[aiTaskWorker] Cleaned up ${rows.length} stale task(s)`); console.log(`[aiTaskWorker] Cleaned up ${rows.length} stale task(s)`);
@@ -1064,10 +1170,14 @@ function startStaleTaskCleanup() {
if (staleTaskCleanupTimer) return; if (staleTaskCleanupTimer) return;
staleTaskCleanupTimer = setInterval(runStaleTaskCleanup, STALE_TASK_CLEANUP_INTERVAL_MS); staleTaskCleanupTimer = setInterval(runStaleTaskCleanup, STALE_TASK_CLEANUP_INTERVAL_MS);
// Run once shortly after startup // Run once shortly after startup
setTimeout(runStaleTaskCleanup, 10_000); staleTaskCleanupStartupTimer = setTimeout(runStaleTaskCleanup, 10_000);
} }
function stopStaleTaskCleanup() { function stopStaleTaskCleanup() {
if (staleTaskCleanupStartupTimer) {
clearTimeout(staleTaskCleanupStartupTimer);
staleTaskCleanupStartupTimer = null;
}
if (staleTaskCleanupTimer) { if (staleTaskCleanupTimer) {
clearInterval(staleTaskCleanupTimer); clearInterval(staleTaskCleanupTimer);
staleTaskCleanupTimer = null; staleTaskCleanupTimer = null;
@@ -1093,9 +1203,20 @@ function stopPollerRecovery() {
} }
} }
async function stopAllPollers() {
for (const [taskId, poller] of activePollers.entries()) {
clearInterval(poller.interval);
activePollers.delete(taskId);
}
await orphanOwnedPollerState().catch((err) => {
console.error("[aiTaskWorker] failed to orphan owned poller state:", err.message);
});
}
module.exports = { module.exports = {
startPolling, startPolling,
stopPolling, stopPolling,
stopAllPollers,
cancelTaskRuntimeState, cancelTaskRuntimeState,
updateTaskInDb, updateTaskInDb,
getActiveCount, getActiveCount,
+47 -22
View File
@@ -3,8 +3,17 @@ const express = require('express')
const rateLimit = require('express-rate-limit') const rateLimit = require('express-rate-limit')
const cors = require('cors') const cors = require('cors')
const helmet = require('helmet') const helmet = require('helmet')
const { startSettlementWorker } = require('./settlementWorker') const { startSettlementWorker, stopSettlementWorker } = require('./settlementWorker')
const { startProviderHealthMonitor } = require('./providerHealthMonitor') const { startProviderHealthMonitor, stopProviderHealthMonitor } = require('./providerHealthMonitor')
const {
startStaleTaskCleanup,
startTaskEventListener,
startPollerRecovery,
stopStaleTaskCleanup,
stopTaskEventListener,
stopPollerRecovery,
stopAllPollers,
} = require('./aiTaskWorker')
const { ensureDatabase } = require('./dbSetup') const { ensureDatabase } = require('./dbSetup')
const { assertRuntimeSecurityConfig } = require('./securityConfig') const { assertRuntimeSecurityConfig } = require('./securityConfig')
const { loadPriceCache } = require('./pricing') const { loadPriceCache } = require('./pricing')
@@ -17,6 +26,7 @@ const PORT = Number(process.env.PORT) || 3600
const HOST = process.env.HOST || '0.0.0.0' const HOST = process.env.HOST || '0.0.0.0'
const IS_PRODUCTION = process.env.NODE_ENV === 'production' const IS_PRODUCTION = process.env.NODE_ENV === 'production'
let server = null let server = null
let staleLeaseCleanupTimer = null
// CORS: in production, require explicit allowlist; in dev, allow all with credentials // CORS: in production, require explicit allowlist; in dev, allow all with credentials
function buildCorsOptions() { function buildCorsOptions() {
@@ -133,18 +143,18 @@ async function main() {
// Periodic stale lease cleanup (every 5 min) // Periodic stale lease cleanup (every 5 min)
const { cleanStaleLeases } = require('./keyManager') const { cleanStaleLeases } = require('./keyManager')
setInterval(() => { staleLeaseCleanupTimer = setInterval(() => {
cleanStaleLeases().then((cleaned) => { cleanStaleLeases().then((cleaned) => {
if (cleaned > 0) console.log(`[cleanup] Released ${cleaned} stale lease(s)`) if (cleaned > 0) console.log(`[cleanup] Released ${cleaned} stale lease(s)`)
}).catch((err) => { }).catch((err) => {
console.error('[cleanup] error:', err) console.error('[cleanup] error:', err)
}) })
}, 5 * 60 * 1000) }, 5 * 60 * 1000)
if (staleLeaseCleanupTimer.unref) staleLeaseCleanupTimer.unref()
startSettlementWorker() startSettlementWorker()
startProviderHealthMonitor() startProviderHealthMonitor()
const { startStaleTaskCleanup, startTaskEventListener, startPollerRecovery } = require('./aiTaskWorker')
await startTaskEventListener() await startTaskEventListener()
startPollerRecovery() startPollerRecovery()
startStaleTaskCleanup() startStaleTaskCleanup()
@@ -175,32 +185,47 @@ process.on('uncaughtException', (err) => {
// ── Graceful shutdown ─────────────────────────────────────────────────── // ── Graceful shutdown ───────────────────────────────────────────────────
let shuttingDown = false let shuttingDown = false
function gracefulShutdown(signal) { async function shutdownRuntimeState() {
if (staleLeaseCleanupTimer) {
clearInterval(staleLeaseCleanupTimer)
staleLeaseCleanupTimer = null
}
stopSettlementWorker()
stopProviderHealthMonitor()
stopPollerRecovery()
stopStaleTaskCleanup()
await Promise.allSettled([stopTaskEventListener(), stopAllPollers()])
}
function closeServer() {
if (!server || !server.listening) return Promise.resolve()
return new Promise((resolve) => {
server.close(() => {
console.log('[shutdown] Server closed, cleaning up...')
resolve()
})
})
}
async function gracefulShutdown(signal) {
if (shuttingDown) return if (shuttingDown) return
shuttingDown = true shuttingDown = true
console.log('[shutdown] Received ' + signal + ', draining connections...') console.log('[shutdown] Received ' + signal + ', draining connections...')
if (server && server.listening) {
server.close(() => {
console.log('[shutdown] Server closed, cleaning up...')
const { stopProviderHealthMonitor } = require('./providerHealthMonitor')
stopProviderHealthMonitor()
const { stopTaskEventListener, stopPollerRecovery } = require('./aiTaskWorker')
stopPollerRecovery()
void stopTaskEventListener()
const { pool } = require('./db')
pool.end().then(() => {
console.log('[shutdown] Database pool closed')
process.exit(0)
}).catch(() => process.exit(0))
})
// Force exit after timeout
setTimeout(() => { setTimeout(() => {
console.error('[shutdown] Forced exit after timeout') console.error('[shutdown] Forced exit after timeout')
process.exit(1) process.exit(1)
}, 15000).unref() }, 15000).unref()
} else {
try {
await shutdownRuntimeState()
await closeServer()
const { pool } = require('./db')
await pool.end()
console.log('[shutdown] Database pool closed')
process.exit(0)
} catch (err) {
console.error('[shutdown] error:', err)
process.exit(0) process.exit(0)
} }
} }
+9 -3
View File
@@ -284,7 +284,7 @@ async function releaseLeaseInternal(leaseToken, user, options = {}) {
const { rows } = await client.query( const { rows } = await client.query(
` `
WITH candidate AS ( WITH candidate AS (
SELECT l.id, l.key_id, k.provider SELECT l.id, l.key_id, l.user_id, l.enterprise_id, k.provider
FROM key_leases l FROM key_leases l
JOIN api_keys k ON k.id = l.key_id JOIN api_keys k ON k.id = l.key_id
WHERE l.lease_token = $1 AND l.released_at IS NULL WHERE l.lease_token = $1 AND l.released_at IS NULL
@@ -297,7 +297,7 @@ async function releaseLeaseInternal(leaseToken, user, options = {}) {
WHERE id = (SELECT id FROM candidate) WHERE id = (SELECT id FROM candidate)
RETURNING id, key_id RETURNING id, key_id
) )
SELECT r.id, r.key_id, c.provider SELECT r.id, r.key_id, c.user_id AS lease_user_id, c.enterprise_id AS lease_enterprise_id, c.provider
FROM released r FROM released r
JOIN candidate c ON c.key_id = r.key_id JOIN candidate c ON c.key_id = r.key_id
`, `,
@@ -339,7 +339,13 @@ async function releaseLeaseInternal(leaseToken, user, options = {}) {
INSERT INTO usage_logs (user_id, enterprise_id, provider, key_id, action) INSERT INTO usage_logs (user_id, enterprise_id, provider, key_id, action)
VALUES ($1, $2, (SELECT provider FROM api_keys WHERE id = $3), $4, $5) VALUES ($1, $2, (SELECT provider FROM api_keys WHERE id = $3), $4, $5)
`, `,
[userId, enterpriseId, lease.key_id, lease.key_id, "release"], [
userId || lease.lease_user_id,
enterpriseId || lease.lease_enterprise_id,
lease.key_id,
lease.key_id,
"release",
],
); );
return { return {
+120
View File
@@ -0,0 +1,120 @@
"use strict";
const crypto = require("node:crypto");
const { pool } = require("./db");
const DEFAULT_MAX_CONCURRENCY = 8;
const DEFAULT_SLOT_TTL_MS = 30_000;
const POLL_SCOPE = "generation-provider-poll:global";
const OWNER_ID = `${process.pid}-${crypto.randomUUID()}`;
let storeReady = null;
function normalizePositiveInteger(value, fallback) {
const numeric = Number(value);
if (!Number.isFinite(numeric) || numeric <= 0) return fallback;
return Math.max(1, Math.trunc(numeric));
}
function getMaxConcurrency() {
return normalizePositiveInteger(process.env.TASK_PROVIDER_POLL_MAX_CONCURRENCY, DEFAULT_MAX_CONCURRENCY);
}
function getSlotTtlInterval() {
const ttlMs = normalizePositiveInteger(process.env.TASK_PROVIDER_POLL_SLOT_TTL_MS, DEFAULT_SLOT_TTL_MS);
return `${Math.max(1, Math.ceil(ttlMs / 1000))} seconds`;
}
async function ensureProviderPollLimiterStore() {
if (storeReady) return storeReady;
storeReady = pool.query(`
CREATE TABLE IF NOT EXISTS generation_provider_poll_slots (
scope TEXT NOT NULL,
slot_no INTEGER NOT NULL,
owner_id TEXT NOT NULL,
task_id INTEGER,
expires_at TIMESTAMPTZ NOT NULL,
acquired_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
PRIMARY KEY (scope, slot_no)
);
CREATE INDEX IF NOT EXISTS idx_generation_provider_poll_slots_expires
ON generation_provider_poll_slots(expires_at);
`).catch((err) => {
storeReady = null;
throw err;
});
return storeReady;
}
async function acquireProviderPollSlot(taskId, options = {}) {
await ensureProviderPollLimiterStore();
const scope = options.scope || POLL_SCOPE;
const maxConcurrency = normalizePositiveInteger(options.maxConcurrency, getMaxConcurrency());
const ttlInterval = options.ttlInterval || getSlotTtlInterval();
const { rows } = await pool.query(
`
WITH candidate AS (
SELECT s.slot_no
FROM generate_series(1, $2::integer) AS s(slot_no)
LEFT JOIN generation_provider_poll_slots l
ON l.scope = $1 AND l.slot_no = s.slot_no
WHERE l.scope IS NULL OR l.expires_at < NOW()
ORDER BY s.slot_no ASC
LIMIT 1
),
claimed AS (
INSERT INTO generation_provider_poll_slots (
scope, slot_no, owner_id, task_id, expires_at, acquired_at, updated_at
)
SELECT $1, slot_no, $3, $4, NOW() + ($5::text)::interval, NOW(), NOW()
FROM candidate
ON CONFLICT (scope, slot_no) DO UPDATE SET
owner_id = EXCLUDED.owner_id,
task_id = EXCLUDED.task_id,
expires_at = EXCLUDED.expires_at,
acquired_at = NOW(),
updated_at = NOW()
WHERE generation_provider_poll_slots.expires_at < NOW()
RETURNING scope, slot_no
)
SELECT scope, slot_no FROM claimed
`,
[scope, maxConcurrency, OWNER_ID, taskId || null, ttlInterval],
);
const slot = rows[0];
return slot ? { scope: slot.scope, slotNo: slot.slot_no, ownerId: OWNER_ID } : null;
}
async function releaseProviderPollSlot(slot) {
if (!slot?.scope || !slot?.slotNo) return;
await ensureProviderPollLimiterStore();
await pool.query(
"DELETE FROM generation_provider_poll_slots WHERE scope = $1 AND slot_no = $2 AND owner_id = $3",
[slot.scope, slot.slotNo, slot.ownerId || OWNER_ID],
);
}
async function withProviderPollSlot(taskId, fn, options = {}) {
const slot = await acquireProviderPollSlot(taskId, options);
if (!slot) return { acquired: false, value: undefined };
try {
return { acquired: true, value: await fn() };
} finally {
await releaseProviderPollSlot(slot).catch((err) => {
console.error(`[providerPollLimiter] failed to release poll slot ${slot.scope}:${slot.slotNo}:`, err.message);
});
}
}
module.exports = {
acquireProviderPollSlot,
ensureProviderPollLimiterStore,
getMaxConcurrency,
normalizePositiveInteger,
releaseProviderPollSlot,
withProviderPollSlot,
};
+22 -2
View File
@@ -1072,6 +1072,16 @@ function registerAiRoutes(router) {
error.costCents = billingResult.costCents; error.costCents = billingResult.costCents;
throw error; throw error;
} }
if (billingResult.costCents > 0) {
await client.query(
"UPDATE generation_tasks SET cost_cents = $1, billing_target = $2, billing_refunded = 0, updated_at = NOW() WHERE id = $3",
[
billingResult.costCents,
billingResult.deductionType === "enterprise_image_flat" ? "enterprise_image" : "user",
nextTaskRow.id,
],
);
}
return { taskRow: nextTaskRow, imageBilling: billingResult }; return { taskRow: nextTaskRow, imageBilling: billingResult };
}); });
const preauth = { authorized: true, estimatedCostCents: 0, billingMode: imageBilling.deductionType }; const preauth = { authorized: true, estimatedCostCents: 0, billingMode: imageBilling.deductionType };
@@ -1086,9 +1096,11 @@ function registerAiRoutes(router) {
}, },
providerDebug: buildImageProviderDebug(model), providerDebug: buildImageProviderDebug(model),
}); });
submitImageWithProviderFallback(taskRow.id, providerCandidates, req.user, preauth, params).catch((err) => { submitImageWithProviderFallback(taskRow.id, providerCandidates, req.user, preauth, params).catch(async (err) => {
console.error("[ai/image] submit error:", err.message); console.error("[ai/image] submit error:", err.message);
updateTaskInDb(taskRow.id, { status: "failed", error: err.message }); await updateTaskInDb(taskRow.id, { status: "failed", error: err.message }).catch((updateErr) => {
console.error(`[ai/image] failed to persist task ${taskRow.id} failure:`, updateErr.message);
});
}); });
} catch (err) { } catch (err) {
console.error("[ai/image] error:", err.message); console.error("[ai/image] error:", err.message);
@@ -1200,6 +1212,10 @@ function registerAiRoutes(router) {
...enterpriseBilling, ...enterpriseBilling,
taskId: nextTaskRow.id, taskId: nextTaskRow.id,
}); });
await client.query(
"UPDATE generation_tasks SET cost_cents = $1, billing_target = 'enterprise_video', billing_refunded = 0, updated_at = NOW() WHERE id = $2",
[nextBilling.amountCents, nextTaskRow.id],
);
return { taskRow: nextTaskRow, reservedBilling: nextBilling, regularBilling: null }; return { taskRow: nextTaskRow, reservedBilling: nextBilling, regularBilling: null };
} }
// Regular user: deduct from personal balance // Regular user: deduct from personal balance
@@ -1222,6 +1238,10 @@ function registerAiRoutes(router) {
"INSERT INTO transactions (user_id, type, amount_cents, balance_after_cents, description) VALUES ($1, 'deduct', $2, $3, $4)", "INSERT INTO transactions (user_id, type, amount_cents, balance_after_cents, description) VALUES ($1, 'deduct', $2, $3, $4)",
[req.user.id, -costCents, deducted.balance_cents, `视频生成扣费 ${credits} 积分`], [req.user.id, -costCents, deducted.balance_cents, `视频生成扣费 ${credits} 积分`],
); );
await client.query(
"UPDATE generation_tasks SET cost_cents = $1, billing_target = 'user', billing_refunded = 0, updated_at = NOW() WHERE id = $2",
[costCents, nextTaskRow.id],
);
return { taskRow: nextTaskRow, reservedBilling: null, regularBilling: { costCents, balanceAfterCents: deducted.balance_cents, credits } }; return { taskRow: nextTaskRow, reservedBilling: null, regularBilling: { costCents, balanceAfterCents: deducted.balance_cents, credits } };
}); });