From 2249f7ce36d72eb9c45556cb65735f4e17f0a0aa Mon Sep 17 00:00:00 2001 From: Stringadmin Date: Thu, 4 Jun 2026 21:08:14 +0800 Subject: [PATCH] Harden generation task lifecycle --- src/aiTaskWorker.js | 334 ++++++++++++++++++++++++++++++++++++-------- src/dbSetup.js | 13 ++ src/keyManager.js | 5 +- src/routes/ai.js | 10 +- 4 files changed, 299 insertions(+), 63 deletions(-) diff --git a/src/aiTaskWorker.js b/src/aiTaskWorker.js index ded2557..208c027 100644 --- a/src/aiTaskWorker.js +++ b/src/aiTaskWorker.js @@ -5,6 +5,8 @@ const { EventEmitter } = require("node:events"); const { pool } = require("./db"); const { refundTaskBillingOnFailure } = require("./billing"); const { putObject, isOssConfigured } = require("./ossClient"); +const keyManager = require("./keyManager"); +const { resolveImageProviderCandidates, resolveVideoProvider } = require("./aiProviderRouter"); const taskEvents = new EventEmitter(); taskEvents.setMaxListeners(200); @@ -13,6 +15,10 @@ const activePollers = new Map(); const POLL_INTERVAL_MS = 3000; const MAX_POLL_ATTEMPTS = 120; const GRS_IMAGE_MAX_POLL_ATTEMPTS = Number(process.env.GRSAI_IMAGE_MAX_POLL_ATTEMPTS || 60); +const STALE_TASK_TIMEOUT_MINUTES = Math.max(10, Number(process.env.STALE_GENERATION_TASK_MINUTES || 120)); +const RESULT_PERSIST_RETRY_LIMIT = Math.max(1, Number(process.env.RESULT_PERSIST_RETRY_LIMIT || 5)); +const RESULT_PERSIST_RETRY_BATCH_SIZE = Math.max(1, Number(process.env.RESULT_PERSIST_RETRY_BATCH_SIZE || 25)); +const TASK_STARTUP_RECOVERY_LIMIT = Math.max(1, Number(process.env.TASK_STARTUP_RECOVERY_LIMIT || 50)); function normalizeTaskProgress(value) { const numeric = Number(value); @@ -92,8 +98,9 @@ async function updateTaskInDb(taskId, updates) { if (fields.length === 0) return; values.push(taskId); + const protectCancelled = nextUpdates.status !== "cancelled" ? " AND status <> 'cancelled'" : ""; const { rows } = await pool.query( - `UPDATE generation_tasks SET ${fields.join(", ")} WHERE id = $${idx} RETURNING *`, + `UPDATE generation_tasks SET ${fields.join(", ")} WHERE id = $${idx}${protectCancelled} RETURNING *`, values, ); let updatedTask = rows[0]; @@ -124,20 +131,66 @@ function persistTaskResultUrlToOssInBackground(task) { Promise.resolve() .then(async () => { - const durableUrl = await persistResultUrlToOss(task); - if (!durableUrl || durableUrl === task.result_url) return; - - await pool.query( - "UPDATE generation_tasks SET result_url = $1, updated_at = NOW() WHERE id = $2 AND result_url = $3", - [durableUrl, task.id, task.result_url], - ); - console.info(`[aiTaskWorker] task ${task.id} result persisted to OSS after completion`); + await persistTaskResultUrlToOss(task); }) .catch((error) => { console.warn(`[aiTaskWorker] background result persistence failed for task ${task.id}:`, error.message); }); } +async function markResultPersistence(taskId, status, error = null, durableUrl = null, previousUrl = null) { + const fields = [ + "result_persist_status = $1", + "result_persist_attempts = result_persist_attempts + 1", + "result_persist_error = $2", + "updated_at = NOW()", + ]; + const values = [status, error ? String(error).slice(0, 1000) : null]; + let idx = values.length + 1; + + if (status === "succeeded") { + fields.push("result_persisted_at = NOW()"); + } + if (durableUrl) { + fields.push(`result_url = $${idx++}`); + values.push(durableUrl); + } + + values.push(taskId); + let where = `id = $${idx}`; + if (previousUrl) { + idx += 1; + values.push(previousUrl); + where += ` AND result_url = $${idx}`; + } + + await pool.query(`UPDATE generation_tasks SET ${fields.join(", ")} WHERE ${where}`, values); +} + +async function persistTaskResultUrlToOss(task) { + if (!task?.id || !task?.result_url) return null; + + if (isOwnPersistedResultUrl(task.result_url)) { + await markResultPersistence(task.id, "succeeded", null, null); + return task.result_url; + } + + if (!isOssConfigured()) { + await markResultPersistence(task.id, "failed", "OSS is not configured"); + return null; + } + + const durableUrl = await persistResultUrlToOss(task); + if (!durableUrl) { + await markResultPersistence(task.id, "failed", "Result URL could not be copied to OSS"); + return null; + } + + await markResultPersistence(task.id, "succeeded", null, durableUrl, task.result_url); + console.info(`[aiTaskWorker] task ${task.id} result persisted to OSS after completion`); + return durableUrl; +} + function asObject(value) { return value && typeof value === "object" && !Array.isArray(value) ? value : undefined; } @@ -636,35 +689,41 @@ function getMaxPollAttempts(type, providerConfig) { return MAX_POLL_ATTEMPTS; } +async function releasePollingLease(poller) { + if (!poller?.leaseToken || !poller?.keyManager) return; + await poller.keyManager.releaseKey(poller.leaseToken).catch((err) => { + console.warn(`[aiTaskWorker] release lease failed for task ${poller.taskDbId}:`, err.message); + }); +} + function startPolling(taskDbId, { providerTaskId, apiKey, type, providerConfig, leaseToken, keyManager, onTaskFailed }) { if (activePollers.has(taskDbId)) return; let attempts = 0; + let polling = false; const maxPollAttempts = getMaxPollAttempts(type, providerConfig); const interval = setInterval(async () => { + if (polling) return; + polling = true; attempts++; - if (attempts > maxPollAttempts) { - clearInterval(interval); - activePollers.delete(taskDbId); - if (leaseToken && keyManager) await keyManager.releaseKey(leaseToken).catch(() => {}); - if (typeof onTaskFailed === "function") { - const handled = await onTaskFailed("Task timed out").catch((fallbackErr) => { - console.error(`[aiTaskWorker] fallback error for task ${taskDbId}:`, fallbackErr.message); - return false; - }); - if (handled) return; - } - await updateTaskInDb(taskDbId, { status: "failed", error: "Task timed out" }); - return; - } try { - // Check if task was cancelled by user + if (attempts > maxPollAttempts) { + await stopPolling(taskDbId, { releaseLease: true }); + if (typeof onTaskFailed === "function") { + const handled = await onTaskFailed("Task timed out").catch((fallbackErr) => { + console.error(`[aiTaskWorker] fallback error for task ${taskDbId}:`, fallbackErr.message); + return false; + }); + if (handled) return; + } + await updateTaskInDb(taskDbId, { status: "failed", error: "Task timed out" }); + return; + } + const { rows: [taskRow] } = await pool.query("SELECT status FROM generation_tasks WHERE id = $1", [taskDbId]); if (!taskRow || taskRow.status === "cancelled") { - clearInterval(interval); - activePollers.delete(taskDbId); - if (leaseToken && keyManager) await keyManager.releaseKey(leaseToken).catch(() => {}); + await stopPolling(taskDbId, { releaseLease: true }); return; } @@ -680,9 +739,7 @@ function startPolling(taskDbId, { providerTaskId, apiKey, type, providerConfig, } if (result.status === "completed" || result.status === "failed") { - clearInterval(interval); - activePollers.delete(taskDbId); - if (leaseToken && keyManager) await keyManager.releaseKey(leaseToken).catch(() => {}); + await stopPolling(taskDbId, { releaseLease: true }); if (result.status === "failed" && typeof onTaskFailed === "function") { const handled = await onTaskFailed(result.error || "Task failed").catch((fallbackErr) => { console.error(`[aiTaskWorker] fallback error for task ${taskDbId}:`, fallbackErr.message); @@ -695,68 +752,226 @@ function startPolling(taskDbId, { providerTaskId, apiKey, type, providerConfig, await updateTaskInDb(taskDbId, result); } catch (err) { console.error(`[aiTaskWorker] poll error for task ${taskDbId}:`, err.message); + } finally { + polling = false; } }, POLL_INTERVAL_MS); - activePollers.set(taskDbId, { interval, leaseToken }); + activePollers.set(taskDbId, { taskDbId, interval, leaseToken, keyManager }); } -function stopPolling(taskDbId) { +async function stopPolling(taskDbId, options = {}) { const poller = activePollers.get(taskDbId); - if (poller) { - clearInterval(poller.interval); - activePollers.delete(taskDbId); + if (!poller) return; + + clearInterval(poller.interval); + activePollers.delete(taskDbId); + if (options.releaseLease) { + await releasePollingLease(poller); } } +async function cancelTask(taskId, userId) { + const { rows } = await pool.query( + `UPDATE generation_tasks + SET status = 'cancelled', completed_at = NOW(), updated_at = NOW() + WHERE id = $1 AND user_id = $2 AND status IN ('pending', 'running') + RETURNING *`, + [taskId, userId], + ); + const task = rows[0]; + if (!task) return null; + + await stopPolling(task.id, { releaseLease: true }); + taskEvents.emit(`task:${task.id}`, formatTaskEvent(task)); + return task; +} + function getActiveCount() { return activePollers.size; } -// --- Periodic stale task cleanup --- -// Runs every 5 minutes, marks tasks stuck in 'pending'/'running' for too long as 'failed'. -// This catches cases where the worker crashed, the provider API never responded, -// or the cancel request failed silently on the client side. const STALE_TASK_CLEANUP_INTERVAL_MS = 5 * 60 * 1000; let staleTaskCleanupTimer = null; +const TASK_RESULT_PERSIST_RETRY_INTERVAL_MS = 5 * 60 * 1000; +let taskResultPersistenceRetryTimer = null; +let taskStartupRecoveryTimer = null; +let taskStaleCleanupRunning = false; +let taskResultPersistenceRetryRunning = false; +let taskStartupRecoveryRunning = false; + +function parseTaskParams(paramsJson) { + if (!paramsJson) return {}; + if (typeof paramsJson === "object") return paramsJson; + try { + const parsed = JSON.parse(paramsJson); + return parsed && typeof parsed === "object" ? parsed : {}; + } catch { + return {}; + } +} + +function resolveProviderConfigForRecovery(task) { + const params = parseTaskParams(task.params_json); + + if (task.type === "video") { + if (params.model === "video-style-transform" || params.operation === "video-style-super-resolution") { + return { provider: "dashscope", protocol: "wan-i2v", baseUrl: "https://dashscope.aliyuncs.com" }; + } + if (params.model === "aliyun-video-super-resolve" || params.model === "aliyun-erase-subtitles") { + return null; + } + return resolveVideoProvider(params.model); + } + + if (task.type === "image") { + if (params.operation === "image-super-resolution" || params.operation === "image-edit") { + return { provider: "dashscope", transport: "dashscope-image" }; + } + const candidates = resolveImageProviderCandidates(params.model); + return candidates[0] || null; + } + + return null; +} + +function normalizeRecoveryUser(task) { + return { + id: task.user_id, + enterpriseId: task.enterprise_id ?? null, + accountType: task.enterprise_id ? "enterprise" : "personal", + }; +} async function runStaleTaskCleanup() { + if (taskStaleCleanupRunning) return; + taskStaleCleanupRunning = true; try { const { rows } = await pool.query( `UPDATE generation_tasks - SET status = 'failed', error = '任务超时自动释放', updated_at = NOW() + SET status = 'failed', + error = 'Task timed out and was released automatically', + completed_at = NOW(), + updated_at = NOW() WHERE status IN ('pending', 'running') - AND GREATEST(updated_at, COALESCE(last_poll_at, created_at)) < NOW() - INTERVAL '10 minutes' - RETURNING id`, + AND GREATEST(updated_at, COALESCE(last_poll_at, created_at)) < NOW() - ($1::int * INTERVAL '1 minute') + RETURNING *`, + [STALE_TASK_TIMEOUT_MINUTES], ); + for (const row of rows) { - taskEvents.emit(`task:${row.id}`, { - taskId: row.id, - status: "failed", - progress: null, - resultUrl: null, - error: "任务超时自动释放", + await stopPolling(row.id, { releaseLease: true }); + taskEvents.emit(`task:${row.id}`, formatTaskEvent(row)); + await refundTaskBillingOnFailure(row.id).catch((err) => { + console.error(`[aiTaskWorker] stale task refund error for task ${row.id}:`, err.message); }); - // Also stop any active poller for this task - const poller = activePollers.get(row.id); - if (poller) { - clearInterval(poller.timer); - activePollers.delete(row.id); - } } + if (rows.length > 0) { console.log(`[aiTaskWorker] Cleaned up ${rows.length} stale task(s)`); } } catch (err) { console.error("[aiTaskWorker] Stale task cleanup failed:", err.message); + } finally { + taskStaleCleanupRunning = false; + } +} + +async function runResultPersistenceRetry() { + if (taskResultPersistenceRetryRunning) return; + taskResultPersistenceRetryRunning = true; + try { + const { rows } = await pool.query( + `SELECT * + FROM generation_tasks + WHERE status = 'completed' + AND result_url IS NOT NULL + AND result_url ~* '^https?://' + AND result_url !~* '/users/[^/]+/generation-results/' + AND result_persist_status IN ('pending', 'failed') + AND result_persist_attempts < $1 + ORDER BY updated_at ASC + LIMIT $2`, + [RESULT_PERSIST_RETRY_LIMIT, RESULT_PERSIST_RETRY_BATCH_SIZE], + ); + + for (const row of rows) { + await persistTaskResultUrlToOss(row); + } + + if (rows.length > 0) { + console.log(`[aiTaskWorker] Retried OSS result persistence for ${rows.length} task(s)`); + } + } catch (err) { + console.error("[aiTaskWorker] Result persistence retry failed:", err.message); + } finally { + taskResultPersistenceRetryRunning = false; + } +} + +async function runTaskStartupRecovery() { + if (taskStartupRecoveryRunning) return; + taskStartupRecoveryRunning = true; + try { + const { rows } = await pool.query( + `SELECT gt.*, u.enterprise_id + FROM generation_tasks gt + JOIN users u ON u.id = gt.user_id + WHERE gt.status = 'running' + AND gt.provider_task_id IS NOT NULL + AND GREATEST(gt.updated_at, COALESCE(gt.last_poll_at, gt.created_at)) >= NOW() - ($1::int * INTERVAL '1 minute') + ORDER BY gt.updated_at DESC + LIMIT $2`, + [STALE_TASK_TIMEOUT_MINUTES, TASK_STARTUP_RECOVERY_LIMIT], + ); + + let recovered = 0; + for (const task of rows) { + if (activePollers.has(task.id)) continue; + + let providerConfig; + try { + providerConfig = resolveProviderConfigForRecovery(task); + } catch (err) { + console.warn(`[aiTaskWorker] task ${task.id} recovery skipped: ${err.message}`); + continue; + } + + if (!providerConfig?.provider) continue; + const slotResult = await keyManager.acquireKey(providerConfig.provider, normalizeRecoveryUser(task), null, { waitTimeoutMs: 0 }); + if (!slotResult) { + console.warn(`[aiTaskWorker] task ${task.id} recovery waiting for provider capacity: ${providerConfig.provider}`); + continue; + } + + startPolling(task.id, { + providerTaskId: task.provider_task_id, + apiKey: slotResult.apiKey, + type: task.type, + providerConfig, + leaseToken: slotResult.leaseToken, + keyManager, + }); + recovered += 1; + } + + if (recovered > 0) { + console.log(`[aiTaskWorker] Recovered ${recovered} running task poller(s) after startup`); + } + } catch (err) { + console.error("[aiTaskWorker] Startup task recovery failed:", err.message); + } finally { + taskStartupRecoveryRunning = false; } } function startStaleTaskCleanup() { if (staleTaskCleanupTimer) return; staleTaskCleanupTimer = setInterval(runStaleTaskCleanup, STALE_TASK_CLEANUP_INTERVAL_MS); - // Run once shortly after startup + taskResultPersistenceRetryTimer = setInterval(runResultPersistenceRetry, TASK_RESULT_PERSIST_RETRY_INTERVAL_MS); + taskStartupRecoveryTimer = setTimeout(runTaskStartupRecovery, 5_000); setTimeout(runStaleTaskCleanup, 10_000); + setTimeout(runResultPersistenceRetry, 15_000); } function stopStaleTaskCleanup() { @@ -764,11 +979,20 @@ function stopStaleTaskCleanup() { clearInterval(staleTaskCleanupTimer); staleTaskCleanupTimer = null; } + if (taskResultPersistenceRetryTimer) { + clearInterval(taskResultPersistenceRetryTimer); + taskResultPersistenceRetryTimer = null; + } + if (taskStartupRecoveryTimer) { + clearTimeout(taskStartupRecoveryTimer); + taskStartupRecoveryTimer = null; + } } module.exports = { startPolling, stopPolling, + cancelTask, updateTaskInDb, getActiveCount, extractProviderTaskId, diff --git a/src/dbSetup.js b/src/dbSetup.js index c1ee1c8..19f73d1 100644 --- a/src/dbSetup.js +++ b/src/dbSetup.js @@ -353,6 +353,18 @@ async function migrateGenerationTasksBillingColumns(client) { ); } +async function migrateGenerationTaskResultPersistence(client) { + await addColumnIfMissing("generation_tasks", "result_persist_status TEXT NOT NULL DEFAULT 'pending'"); + await addColumnIfMissing("generation_tasks", "result_persist_attempts INTEGER NOT NULL DEFAULT 0"); + await addColumnIfMissing("generation_tasks", "result_persist_error TEXT"); + await addColumnIfMissing("generation_tasks", "result_persisted_at TIMESTAMPTZ"); + await client.query(` + CREATE INDEX IF NOT EXISTS idx_generation_tasks_result_persist_retry + ON generation_tasks(result_persist_status, updated_at) + WHERE status = 'completed' AND result_url IS NOT NULL + `); +} + async function ensureModelPriceSeed() { const columns = await getColumnNames("model_prices"); const useMills = columns.includes("input_price_mills"); @@ -959,6 +971,7 @@ async function ensureSchema() { await runMigration("030_generation_tasks_user_status_index", migrateGenerationTasksUserStatusIndex); await runMigration("031_generation_tasks_billing_columns", migrateGenerationTasksBillingColumns); await runMigration("032_ecommerce_video_history", migrateEcommerceVideoHistorySchema); + await runMigration("033_generation_task_result_persistence", migrateGenerationTaskResultPersistence); await ensureModelPriceSeed(); } diff --git a/src/keyManager.js b/src/keyManager.js index 6f9b49e..5ca70e4 100644 --- a/src/keyManager.js +++ b/src/keyManager.js @@ -284,7 +284,7 @@ async function releaseLeaseInternal(leaseToken, user, options = {}) { const { rows } = await client.query( ` WITH candidate AS ( - SELECT l.id, l.key_id, k.provider + SELECT l.id, l.key_id, l.user_id, l.enterprise_id, k.provider FROM key_leases l JOIN api_keys k ON k.id = l.key_id WHERE l.lease_token = $1 AND l.released_at IS NULL @@ -298,6 +298,7 @@ async function releaseLeaseInternal(leaseToken, user, options = {}) { RETURNING id, key_id ) SELECT r.id, r.key_id, c.provider + , c.user_id, c.enterprise_id FROM released r JOIN candidate c ON c.key_id = r.key_id `, @@ -339,7 +340,7 @@ async function releaseLeaseInternal(leaseToken, user, options = {}) { INSERT INTO usage_logs (user_id, enterprise_id, provider, key_id, action) VALUES ($1, $2, (SELECT provider FROM api_keys WHERE id = $3), $4, $5) `, - [userId, enterpriseId, lease.key_id, lease.key_id, "release"], + [userId ?? lease.user_id, enterpriseId ?? lease.enterprise_id, lease.key_id, lease.key_id, "release"], ); return { diff --git a/src/routes/ai.js b/src/routes/ai.js index b8f57c1..c04967e 100644 --- a/src/routes/ai.js +++ b/src/routes/ai.js @@ -16,6 +16,7 @@ const { } = require("../enterpriseVideoBilling"); const { startPolling, + cancelTask, updateTaskInDb, extractProviderTaskId, extractImageUrl, @@ -1770,12 +1771,9 @@ function registerAiRoutes(router) { if (!Number.isFinite(taskId)) return res.status(400).json({ error: "Invalid task id" }); try { - const { rows } = await pool.query( - "UPDATE generation_tasks SET status = 'cancelled', updated_at = NOW() WHERE id = $1 AND user_id = $2 AND status IN ('pending', 'running') RETURNING id, status", - [taskId, req.user.id], - ); - if (rows.length === 0) return res.status(404).json({ error: "Task not found or not in active state" }); - res.json({ id: rows[0].id, status: rows[0].status }); + const task = await cancelTask(taskId, req.user.id); + if (!task) return res.status(404).json({ error: "Task not found or not in active state" }); + res.json({ id: task.id, status: task.status }); } catch (err) { console.error("[ai/task-cancel] error:", err.message); res.status(500).json({ error: "取消任务失败" });