diff --git a/src/aiTaskWorker.js b/src/aiTaskWorker.js index 6f191b7..9781ae0 100644 --- a/src/aiTaskWorker.js +++ b/src/aiTaskWorker.js @@ -152,6 +152,15 @@ async function clearPollerState(taskDbId) { await pool.query("DELETE FROM generation_task_pollers WHERE task_id = $1", [taskDbId]); } +async function getPersistedLeaseToken(taskDbId) { + await ensureTaskPollerStore(); + const { rows } = await pool.query( + "SELECT lease_token FROM generation_task_pollers WHERE task_id = $1 LIMIT 1", + [taskDbId], + ); + return rows[0]?.lease_token || null; +} + async function getLeaseKey(leaseToken) { if (!leaseToken) return null; const { rows } = await pool.query( @@ -880,6 +889,29 @@ function stopPolling(taskDbId) { clearPollerState(taskDbId).catch(() => {}); } +async function cancelTaskRuntimeState(taskDbId, keyManager) { + const poller = activePollers.get(taskDbId); + if (poller) { + clearInterval(poller.interval); + activePollers.delete(taskDbId); + } + + const leaseToken = poller?.leaseToken || await getPersistedLeaseToken(taskDbId).catch(() => null); + await clearPollerState(taskDbId).catch(() => {}); + if (leaseToken && keyManager) { + await keyManager.releaseKey(leaseToken).catch((err) => { + console.error(`[aiTaskWorker] failed to release lease for cancelled task ${taskDbId}:`, err.message); + }); + } + await publishTaskEvent({ + taskId: taskDbId, + status: "cancelled", + progress: 100, + resultUrl: null, + error: "任务已取消", + }); +} + function getActiveCount() { return activePollers.size; } @@ -1064,6 +1096,7 @@ function stopPollerRecovery() { module.exports = { startPolling, stopPolling, + cancelTaskRuntimeState, updateTaskInDb, getActiveCount, extractProviderTaskId, diff --git a/src/routes/ai.js b/src/routes/ai.js index 5a6cf78..3c5e012 100644 --- a/src/routes/ai.js +++ b/src/routes/ai.js @@ -16,6 +16,7 @@ const { } = require("../enterpriseVideoBilling"); const { startPolling, + cancelTaskRuntimeState, updateTaskInDb, extractProviderTaskId, extractImageUrl, @@ -1742,6 +1743,35 @@ function registerAiRoutes(router) { } }); + const streamTaskStatusPoll = async (taskId, userId, emit) => { + const { rows } = await pool.query( + "SELECT * FROM generation_tasks WHERE id = $1 AND user_id = $2", + [taskId, userId], + ); + const row = rows[0]; + if (!row) return { found: false, terminal: true }; + + if (row.status === "pending" || row.status === "running") { + pool.query( + "UPDATE generation_tasks SET last_poll_at = NOW() WHERE id = $1", + [taskId], + ).catch(() => {}); + } + + const event = { + taskId: row.id, + status: row.status, + progress: Number(row.progress || 0), + resultUrl: row.result_url || null, + error: row.error || null, + }; + emit(event); + return { + found: true, + terminal: ["completed", "failed", "cancelled"].includes(row.status), + }; + }; + router.get("/ai/tasks/:taskId/stream", requireAuth, async (req, res) => { const { taskId } = req.params; try { @@ -1773,16 +1803,43 @@ function registerAiRoutes(router) { return; } + let closed = false; + let lastSnapshot = JSON.stringify(initial); + let dbPollTimer = null; + const endStream = () => { + if (closed) return; + closed = true; + if (dbPollTimer) clearInterval(dbPollTimer); + taskEvents.off(`task:${taskId}`, onUpdate); + res.end(); + }; + const emitIfChanged = (evt) => { + if (closed) return; + const snapshot = JSON.stringify(evt); + if (snapshot === lastSnapshot) return; + lastSnapshot = snapshot; + res.write(`data: ${snapshot}\n\n`); + }; const onUpdate = (evt) => { - res.write(`data: ${JSON.stringify(evt)}\n\n`); + emitIfChanged(evt); if (["completed", "failed", "cancelled"].includes(evt.status)) { - res.end(); + endStream(); } }; taskEvents.on(`task:${taskId}`, onUpdate); + dbPollTimer = setInterval(() => { + streamTaskStatusPoll(taskId, req.user.id, emitIfChanged) + .then((result) => { + if (!result.found || result.terminal) endStream(); + }) + .catch((pollErr) => { + console.error(`[ai/task-stream] db poll failed for task ${taskId}:`, pollErr.message); + }); + }, 3000); + req.on("close", () => { - taskEvents.off(`task:${taskId}`, onUpdate); + endStream(); }); } catch (err) { if (!res.headersSent) res.status(err.name === "AbortError" ? 504 : 500).json({ error: err.name === "AbortError" ? "AI 上游响应超时,请重试" : err.message }); @@ -1799,6 +1856,7 @@ function registerAiRoutes(router) { [taskId, req.user.id], ); if (rows.length === 0) return res.status(404).json({ error: "Task not found or not in active state" }); + await cancelTaskRuntimeState(taskId, keyManager); res.json({ id: rows[0].id, status: rows[0].status }); } catch (err) { console.error("[ai/task-cancel] error:", err.message);