diff --git a/src/aiTaskWorker.js b/src/aiTaskWorker.js
index ded2557..6f191b7 100644
--- a/src/aiTaskWorker.js
+++ b/src/aiTaskWorker.js
@@ -13,6 +13,18 @@ const activePollers = new Map();
 const POLL_INTERVAL_MS = 3000;
 const MAX_POLL_ATTEMPTS = 120;
 const GRS_IMAGE_MAX_POLL_ATTEMPTS = Number(process.env.GRSAI_IMAGE_MAX_POLL_ATTEMPTS || 60);
+const TASK_EVENT_CHANNEL = "generation_task_events";
+// Tags published events so the listener can skip the ones this process emitted.
+const TASK_EVENT_ORIGIN = `${process.pid}-${crypto.randomUUID()}`;
+const POLLER_OWNER_ID = `${process.pid}-${crypto.randomUUID()}`;
+// `Number(x) || fallback` also covers unparsable values: a NaN here would otherwise
+// reach Postgres as "NaN seconds" and make setInterval fire almost continuously.
+const POLLER_OWNER_STALE_MS = Number(process.env.TASK_POLLER_OWNER_STALE_MS) || 20_000;
+const POLLER_RECOVERY_INTERVAL_MS = Math.max(1_000, Number(process.env.TASK_POLLER_RECOVERY_INTERVAL_MS) || 30_000);
+let taskEventListenerClient = null;
+let taskEventListenerStarting = null;
+let pollerStoreReady = null;
+let pollerRecoveryTimer = null;
 
 function normalizeTaskProgress(value) {
   const numeric = Number(value);
@@ -30,6 +42,170 @@ function formatTaskEvent(row) {
   };
 }
 
+// Emit `event` on the in-process taskEvents emitter; events without a taskId are ignored.
+function emitTaskEvent(event) {
+  if (!event?.taskId) return;
+  taskEvents.emit(`task:${event.taskId}`, event);
+}
+
+// Emit `event` locally, then broadcast it through pg_notify for other processes.
+// A failed broadcast is only logged, so the caller's own work is not interrupted.
+async function publishTaskEvent(event) {
+  if (!event?.taskId) return;
+  emitTaskEvent(event);
+  try {
+    await pool.query("SELECT pg_notify($1, $2)", [
+      TASK_EVENT_CHANNEL,
+      JSON.stringify({ origin: TASK_EVENT_ORIGIN, event }),
+    ]);
+  } catch (err) {
+    console.error(`[aiTaskWorker] task event publish failed for task ${event.taskId}:`, err.message);
+  }
+}
+
+// Pick the allow-listed keys of providerConfig for storage; every other key is dropped.
+function serializeProviderConfig(providerConfig) {
+  if (!providerConfig || typeof providerConfig !== "object") return {};
+  const allowedKeys = [
+    "provider",
+    "transport",
+    "protocol",
+    "baseUrl",
+    "endpoint",
+    "resultEndpoint",
+    "model",
+    "requestedModel",
+  ];
+  const result = {};
+  for (const key of allowedKeys) {
+    if (providerConfig[key] !== undefined) result[key] = providerConfig[key];
+  }
+  return result;
+}
+
+// Turn a stored provider_config_json value back into an object; invalid input yields {}.
+function parseProviderConfig(value) {
+  if (!value) return {};
+  if (typeof value === "object") return value;
+  try {
+    const parsed = JSON.parse(value);
+    return parsed && typeof parsed === "object" ? parsed : {};
+  } catch {
+    return {};
+  }
+}
+
+// Create the generation_task_pollers table if needed. The promise is cached and
+// cleared again on failure so that a later call can retry.
+async function ensureTaskPollerStore() {
+  if (pollerStoreReady) return pollerStoreReady;
+  pollerStoreReady = pool.query(`
+    CREATE TABLE IF NOT EXISTS generation_task_pollers (
+      task_id INTEGER PRIMARY KEY REFERENCES generation_tasks(id) ON DELETE CASCADE,
+      provider_task_id TEXT NOT NULL,
+      task_type TEXT NOT NULL,
+      provider_config_json TEXT NOT NULL,
+      lease_token TEXT,
+      owner_id TEXT,
+      owner_heartbeat_at TIMESTAMPTZ,
+      created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+      updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+    );
+    CREATE INDEX IF NOT EXISTS idx_generation_task_pollers_owner
+      ON generation_task_pollers(owner_heartbeat_at);
+  `).catch((err) => {
+    pollerStoreReady = null;
+    throw err;
+  });
+  return pollerStoreReady;
+}
+
+// Upsert the data needed to resume polling this task and mark this process as its owner.
+async function persistPollerState(taskDbId, { providerTaskId, type, providerConfig, leaseToken }) {
+  await ensureTaskPollerStore();
+  await pool.query(
+    `
+    INSERT INTO generation_task_pollers (
+      task_id, provider_task_id, task_type, provider_config_json, lease_token,
+      owner_id, owner_heartbeat_at, updated_at
+    )
+    VALUES ($1, $2, $3, $4, $5, $6, NOW(), NOW())
+    ON CONFLICT (task_id) DO UPDATE SET
+      provider_task_id = EXCLUDED.provider_task_id,
+      task_type = EXCLUDED.task_type,
+      provider_config_json = EXCLUDED.provider_config_json,
+      lease_token = EXCLUDED.lease_token,
+      owner_id = EXCLUDED.owner_id,
+      owner_heartbeat_at = NOW(),
+      updated_at = NOW()
+    `,
+    [
+      taskDbId,
+      providerTaskId,
+      type,
+      JSON.stringify(serializeProviderConfig(providerConfig)),
+      leaseToken || null,
+      POLLER_OWNER_ID,
+    ],
+  );
+}
+
+// Record this process as owner of the task's poller row and bump its heartbeat.
+async function refreshPollerHeartbeat(taskDbId) {
+  await ensureTaskPollerStore();
+  await pool.query(
+    "UPDATE generation_task_pollers SET owner_id = $1, owner_heartbeat_at = NOW(), updated_at = NOW() WHERE task_id = $2",
+    [POLLER_OWNER_ID, taskDbId],
+  );
+}
+
+// Delete the persisted poller row for a task.
+async function clearPollerState(taskDbId) {
+  await ensureTaskPollerStore();
+  await pool.query("DELETE FROM generation_task_pollers WHERE task_id = $1", [taskDbId]);
+}
+
+// Resolve the API key behind an unreleased lease on an enabled key. Returns null when
+// no such lease exists and "" when the stored key is the "pool-slot" placeholder.
+async function getLeaseKey(leaseToken) {
+  if (!leaseToken) return null;
+  const { rows } = await pool.query(
+    `
+    SELECT k.api_key
+    FROM key_leases l
+    JOIN api_keys k ON k.id = l.key_id
+    WHERE l.lease_token = $1
+      AND l.released_at IS NULL
+      AND k.enabled = 1
+    LIMIT 1
+    `,
+    [leaseToken],
+  );
+  const apiKey = rows[0]?.api_key;
+  return apiKey === "pool-slot" ? "" : apiKey || null;
+}
+
+// Take ownership of a poller row whose heartbeat is missing or stale. Returns the
+// updated row, or null when the row does not exist or is not claimable.
+async function claimPoller(taskId) {
+  await ensureTaskPollerStore();
+  const staleInterval = `${Math.max(5, Math.ceil(POLLER_OWNER_STALE_MS / 1000))} seconds`;
+  const { rows } = await pool.query(
+    `
+    UPDATE generation_task_pollers
+    SET owner_id = $1, owner_heartbeat_at = NOW(), updated_at = NOW()
+    WHERE task_id = $2
+      AND (
+        owner_heartbeat_at IS NULL
+        OR owner_heartbeat_at < NOW() - ($3::text)::interval
+      )
+    RETURNING *
+    `,
+    [POLLER_OWNER_ID, taskId, staleInterval],
+  );
+  return rows[0] || null;
+}
+
 async function createTaskLifecycleNotification(task) {
   if (!task || !task.user_id || !task.id) return;
 
@@ -99,7 +275,7 @@ async function updateTaskInDb(taskId, updates) {
 
   let updatedTask = rows[0];
   if (updatedTask) {
-    taskEvents.emit(`task:${taskId}`, formatTaskEvent(updatedTask));
+    await publishTaskEvent(formatTaskEvent(updatedTask));
   }
 
   if (nextUpdates.status === "completed" && updatedTask?.result_url) {
@@ -636,8 +812,14 @@ function getMaxPollAttempts(type, providerConfig) {
   return MAX_POLL_ATTEMPTS;
 }
 
-function startPolling(taskDbId, { providerTaskId, apiKey, type, providerConfig, leaseToken, keyManager, onTaskFailed }) {
+function startPolling(taskDbId, { providerTaskId, apiKey, type, providerConfig, leaseToken, keyManager, onTaskFailed, skipPersist = false }) {
   if (activePollers.has(taskDbId)) return;
+  // Fire-and-forget: polling starts right away and a failed write is only logged.
+  if (!skipPersist) {
+    persistPollerState(taskDbId, { providerTaskId, type, providerConfig, leaseToken }).catch((err) => {
+      console.error(`[aiTaskWorker] failed to persist poller state for task ${taskDbId}:`, err.message);
+    });
+  }
 
   let attempts = 0;
   const maxPollAttempts = getMaxPollAttempts(type, providerConfig);
@@ -655,6 +837,7 @@ function startPolling(taskDbId, { providerTaskId, apiKey, type, providerConfig,
         if (handled) return;
       }
       await updateTaskInDb(taskDbId, { status: "failed", error: "Task timed out" });
+      await clearPollerState(taskDbId).catch(() => {});
       return;
     }
 
@@ -664,9 +847,11 @@ function startPolling(taskDbId, { providerTaskId, apiKey, type, providerConfig,
       if (!taskRow || taskRow.status === "cancelled") {
         clearInterval(interval);
         activePollers.delete(taskDbId);
+        await clearPollerState(taskDbId).catch(() => {});
         if (leaseToken && keyManager) await keyManager.releaseKey(leaseToken).catch(() => {});
         return;
       }
+      await refreshPollerHeartbeat(taskDbId).catch(() => {});
 
       let result;
       if (type === "image") {
@@ -693,6 +878,9 @@ function startPolling(taskDbId, { providerTaskId, apiKey, type, providerConfig,
       }
 
       await updateTaskInDb(taskDbId, result);
+      if (result.status === "completed" || result.status === "failed") {
+        await clearPollerState(taskDbId).catch(() => {});
+      }
     } catch (err) {
       console.error(`[aiTaskWorker] poll error for task ${taskDbId}:`, err.message);
     }
@@ -707,12 +895,59 @@ function stopPolling(taskDbId) {
     clearInterval(poller.interval);
     activePollers.delete(taskDbId);
   }
+  clearPollerState(taskDbId).catch(() => {});
 }
 
 function getActiveCount() {
   return activePollers.size;
 }
 
+// Find up to 20 pending/running tasks whose poller heartbeat is missing or stale, claim
+// each one for this process and resume polling it with the API key from its lease.
+async function recoverRunnablePollers() {
+  await ensureTaskPollerStore();
+  const staleInterval = `${Math.max(5, Math.ceil(POLLER_OWNER_STALE_MS / 1000))} seconds`;
+  const { rows } = await pool.query(
+    `
+    SELECT p.task_id
+    FROM generation_task_pollers p
+    JOIN generation_tasks t ON t.id = p.task_id
+    WHERE t.status IN ('pending', 'running')
+      AND (
+        p.owner_heartbeat_at IS NULL
+        OR p.owner_heartbeat_at < NOW() - ($1::text)::interval
+      )
+    ORDER BY p.owner_heartbeat_at NULLS FIRST, p.updated_at ASC
+    LIMIT 20
+    `,
+    [staleInterval],
+  );
+
+  for (const row of rows) {
+    const taskId = row.task_id;
+    if (activePollers.has(taskId)) continue;
+    const poller = await claimPoller(taskId);
+    if (!poller || activePollers.has(taskId)) continue;
+
+    const apiKey = await getLeaseKey(poller.lease_token);
+    if (apiKey == null) {
+      console.warn(`[aiTaskWorker] cannot recover task ${taskId}: active lease not found`);
+      continue;
+    }
+
+    console.info(`[aiTaskWorker] recovering poller for task ${taskId}`);
+    startPolling(taskId, {
+      providerTaskId: poller.provider_task_id,
+      apiKey,
+      type: poller.task_type,
+      providerConfig: parseProviderConfig(poller.provider_config_json),
+      leaseToken: poller.lease_token,
+      keyManager: require("./keyManager"),
+      skipPersist: true,
+    });
+  }
+}
+
 // --- Periodic stale task cleanup ---
 // Runs every 5 minutes, marks tasks stuck in 'pending'/'running' for too long as 'failed'.
 // This catches cases where the worker crashed, the provider API never responded,
@@ -730,7 +965,7 @@ async function runStaleTaskCleanup() {
        RETURNING id`,
     );
     for (const row of rows) {
-      taskEvents.emit(`task:${row.id}`, {
+      await publishTaskEvent({
         taskId: row.id,
         status: "failed",
         progress: null,
@@ -740,9 +975,10 @@ async function runStaleTaskCleanup() {
       // Also stop any active poller for this task
       const poller = activePollers.get(row.id);
       if (poller) {
-        clearInterval(poller.timer);
+        clearInterval(poller.interval);
         activePollers.delete(row.id);
       }
+      await clearPollerState(row.id).catch(() => {});
     }
     if (rows.length > 0) {
       console.log(`[aiTaskWorker] Cleaned up ${rows.length} stale task(s)`);
@@ -752,6 +988,114 @@ async function runStaleTaskCleanup() {
   }
 }
 
+// Set by stopTaskEventListener() so that a late connection error or a pending
+// restart timer cannot bring the listener back during shutdown.
+let taskEventListenerStopped = false;
+
+// Retry the LISTEN connection after a delay. A failed attempt re-arms the timer;
+// otherwise one failed reconnect would permanently end cross-process delivery.
+function scheduleTaskEventListenerRestart() {
+  if (taskEventListenerStopped) return;
+  setTimeout(() => {
+    if (taskEventListenerStopped) return;
+    startTaskEventListener().catch((restartErr) => {
+      console.error("[aiTaskWorker] task event listener restart failed:", restartErr.message);
+      scheduleTaskEventListenerRestart();
+    });
+  }, 5000).unref?.();
+}
+
+// Dedicate one pooled connection to LISTEN on TASK_EVENT_CHANNEL and re-emit events
+// published by other processes on the local taskEvents emitter. Rejects when the
+// connection or the LISTEN command fails.
+async function startTaskEventListener() {
+  taskEventListenerStopped = false;
+  if (taskEventListenerClient) return;
+  if (taskEventListenerStarting) return taskEventListenerStarting;
+
+  taskEventListenerStarting = (async () => {
+    const client = await pool.connect();
+    let released = false;
+
+    function onNotification(message) {
+      if (message.channel !== TASK_EVENT_CHANNEL || !message.payload) return;
+      try {
+        const payload = JSON.parse(message.payload);
+        // publishTaskEvent() already emitted this process's own events locally.
+        if (payload?.origin === TASK_EVENT_ORIGIN) return;
+        emitTaskEvent(payload?.event || payload);
+      } catch (err) {
+        console.error("[aiTaskWorker] task event notification parse failed:", err.message);
+      }
+    }
+
+    function releaseClient() {
+      if (released) return;
+      released = true;
+      if (taskEventListenerClient === client) taskEventListenerClient = null;
+      client.removeListener("notification", onNotification);
+      client.removeListener("error", onError);
+      try {
+        // Destroy rather than recycle: the connection is broken or still subscribed.
+        client.release(true);
+      } catch {
+        // Already released elsewhere; nothing left to clean up.
+      }
+    }
+
+    function onError(err) {
+      console.error("[aiTaskWorker] task event listener error:", err.message);
+      const wasActive = taskEventListenerClient === client;
+      releaseClient();
+      // Errors during startup surface through the rejected start promise instead.
+      if (wasActive) scheduleTaskEventListenerRestart();
+    }
+
+    client.on("notification", onNotification);
+    client.on("error", onError);
+
+    try {
+      await client.query(`LISTEN ${TASK_EVENT_CHANNEL}`);
+    } catch (err) {
+      // Without this the checked-out client would never go back to the pool.
+      releaseClient();
+      throw err;
+    }
+    if (taskEventListenerStopped) {
+      // stopTaskEventListener() ran while the connection was being set up.
+      releaseClient();
+      return;
+    }
+    taskEventListenerClient = client;
+    console.log(`[aiTaskWorker] listening for task events on ${TASK_EVENT_CHANNEL}`);
+  })();
+
+  try {
+    await taskEventListenerStarting;
+  } finally {
+    taskEventListenerStarting = null;
+  }
+}
+
+// Stop cross-process delivery and destroy the dedicated connection. Never rejects.
+async function stopTaskEventListener() {
+  taskEventListenerStopped = true;
+  const client = taskEventListenerClient;
+  taskEventListenerClient = null;
+  if (!client) return;
+  try {
+    await client.query(`UNLISTEN ${TASK_EVENT_CHANNEL}`);
+  } catch {
+    // Best effort: destroying the connection below drops the subscription anyway.
+  }
+  client.removeAllListeners("notification");
+  try {
+    client.release(true);
+  } catch {
+    // The error handler released the client while UNLISTEN was in flight.
+  }
+}
+
 function startStaleTaskCleanup() {
   if (staleTaskCleanupTimer) return;
   staleTaskCleanupTimer = setInterval(runStaleTaskCleanup, STALE_TASK_CLEANUP_INTERVAL_MS);
@@ -766,6 +1110,27 @@ function stopStaleTaskCleanup() {
   }
 }
 
+// Run one recovery pass immediately, then repeat every POLLER_RECOVERY_INTERVAL_MS.
+function startPollerRecovery() {
+  if (pollerRecoveryTimer) return;
+  ensureTaskPollerStore()
+    .then(() => recoverRunnablePollers())
+    .catch((err) => console.error("[aiTaskWorker] initial poller recovery failed:", err.message));
+  pollerRecoveryTimer = setInterval(() => {
+    recoverRunnablePollers().catch((err) => {
+      console.error("[aiTaskWorker] poller recovery failed:", err.message);
+    });
+  }, POLLER_RECOVERY_INTERVAL_MS);
+}
+
+// Cancel the periodic recovery pass started by startPollerRecovery().
+function stopPollerRecovery() {
+  if (pollerRecoveryTimer) {
+    clearInterval(pollerRecoveryTimer);
+    pollerRecoveryTimer = null;
+  }
+}
+
 module.exports = {
   startPolling,
   stopPolling,
@@ -778,6 +1143,10 @@ module.exports = {
   parseKlingCredential,
   createKlingJwt,
   taskEvents,
+  startTaskEventListener,
+  stopTaskEventListener,
+  startPollerRecovery,
+  stopPollerRecovery,
   startStaleTaskCleanup,
   stopStaleTaskCleanup,
 };
diff --git a/src/index.js b/src/index.js
index 0775847..bc6b3cd 100644
--- a/src/index.js
+++ b/src/index.js
@@ -144,7 +144,9 @@ async function main() {
   startSettlementWorker()
   startProviderHealthMonitor()
 
-  const { startStaleTaskCleanup } = require('./aiTaskWorker')
+  const { startStaleTaskCleanup, startTaskEventListener, startPollerRecovery } = require('./aiTaskWorker')
+  await startTaskEventListener()
+  startPollerRecovery()
   startStaleTaskCleanup()
 
   server = app.listen(PORT, HOST, () => {
@@ -183,6 +185,9 @@ function gracefulShutdown(signal) {
     console.log('[shutdown] Server closed, cleaning up...')
     const { stopProviderHealthMonitor } = require('./providerHealthMonitor')
     stopProviderHealthMonitor()
+    const { stopTaskEventListener, stopPollerRecovery } = require('./aiTaskWorker')
+    stopPollerRecovery()
+    void stopTaskEventListener()
     const { pool } = require('./db')
     pool.end().then(() => {
       console.log('[shutdown] Database pool closed')