fix: harden provider polling recovery
This commit is contained in:
+159
-38
@@ -5,6 +5,7 @@ const { EventEmitter } = require("node:events");
|
||||
const { pool } = require("./db");
|
||||
const { refundTaskBillingOnFailure } = require("./billing");
|
||||
const { putObject, isOssConfigured } = require("./ossClient");
|
||||
const { withProviderPollSlot } = require("./providerPollLimiter");
|
||||
|
||||
const taskEvents = new EventEmitter();
|
||||
taskEvents.setMaxListeners(200);
|
||||
@@ -18,10 +19,12 @@ const TASK_EVENT_ORIGIN = `${process.pid}-${crypto.randomUUID()}`;
|
||||
const POLLER_OWNER_ID = `${process.pid}-${crypto.randomUUID()}`;
|
||||
const POLLER_OWNER_STALE_MS = Number(process.env.TASK_POLLER_OWNER_STALE_MS || 20_000);
|
||||
const POLLER_RECOVERY_INTERVAL_MS = Number(process.env.TASK_POLLER_RECOVERY_INTERVAL_MS || 30_000);
|
||||
const PROVIDER_POLL_REQUEST_TIMEOUT_MS = Number(process.env.TASK_PROVIDER_POLL_REQUEST_TIMEOUT_MS || 25_000);
|
||||
let taskEventListenerClient = null;
|
||||
let taskEventListenerStarting = null;
|
||||
let pollerStoreReady = null;
|
||||
let pollerRecoveryTimer = null;
|
||||
let staleTaskCleanupStartupTimer = null;
|
||||
|
||||
function normalizeTaskProgress(value) {
|
||||
const numeric = Number(value);
|
||||
@@ -152,6 +155,14 @@ async function clearPollerState(taskDbId) {
|
||||
await pool.query("DELETE FROM generation_task_pollers WHERE task_id = $1", [taskDbId]);
|
||||
}
|
||||
|
||||
/**
 * Drop this process's ownership of every persisted poller row.
 *
 * Rows in generation_task_pollers whose owner_id equals POLLER_OWNER_ID get
 * owner_id and owner_heartbeat_at reset to NULL (updated_at is bumped). The
 * rows themselves are kept.
 * NOTE(review): presumably this lets the poller-recovery pass re-claim the
 * rows after this process stops — confirm against recoverRunnablePollers.
 *
 * @returns {Promise<void>} Rejects if the store cannot be prepared or the UPDATE fails.
 */
async function orphanOwnedPollerState() {
  await ensureTaskPollerStore();

  const releaseOwnershipSql =
    "UPDATE generation_task_pollers SET owner_id = NULL, owner_heartbeat_at = NULL, updated_at = NOW() WHERE owner_id = $1";
  const params = [POLLER_OWNER_ID];

  await pool.query(releaseOwnershipSql, params);
}
|
||||
|
||||
async function getPersistedLeaseToken(taskDbId) {
|
||||
await ensureTaskPollerStore();
|
||||
const { rows } = await pool.query(
|
||||
@@ -280,6 +291,12 @@ async function updateTaskInDb(taskId, updates) {
|
||||
});
|
||||
}
|
||||
|
||||
if (nextUpdates.status === "completed") {
|
||||
await markTaskBillingAccepted(taskId).catch((err) => {
|
||||
console.error(`[aiTaskWorker] billing accept error for task ${taskId}:`, err.message);
|
||||
});
|
||||
}
|
||||
|
||||
if (nextUpdates.status === "failed") {
|
||||
await refundTaskBillingOnFailure(taskId).catch((err) => {
|
||||
console.error(`[aiTaskWorker] refund error for task ${taskId}:`, err.message);
|
||||
@@ -287,6 +304,13 @@ async function updateTaskInDb(taskId, updates) {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Flip the task's reserved credit_ledger entries to 'charged'.
 *
 * Only rows with status = 'reserved' for the given task are touched, so
 * calling this again for the same task matches no rows.
 *
 * @param {*} taskId - Value compared against credit_ledger.task_id.
 * @returns {Promise<void>} Rejects if the UPDATE fails.
 */
async function markTaskBillingAccepted(taskId) {
  const acceptSql =
    "UPDATE credit_ledger SET status = 'charged', updated_at = NOW() WHERE task_id = $1 AND status = 'reserved'";
  const params = [taskId];

  await pool.query(acceptSql, params);
}
|
||||
|
||||
function persistTaskResultUrlToOssInBackground(task) {
|
||||
if (!task?.id || !task?.result_url) return;
|
||||
|
||||
@@ -641,9 +665,22 @@ function extractErrorMessage(json, fallback) {
|
||||
}
|
||||
|
||||
/**
 * GET a URL and parse the response body as JSON, bounded by a request timeout.
 *
 * The request is aborted through an AbortController once the timeout elapses.
 * This function never rejects: network errors, aborts and JSON parse errors
 * are all reported through the returned object.
 *
 * @param {string|URL} url - Endpoint to request.
 * @param {object} [headers] - Request headers passed straight to fetch.
 * @param {number} [timeoutMs=PROVIDER_POLL_REQUEST_TIMEOUT_MS] - Abort deadline in
 *   milliseconds. Non-finite or non-positive values fall back to 25 000 ms.
 * @returns {Promise<{ok: boolean, json: any, error?: Error}>}
 *   `{ ok: true, json }` on a 2xx response with a parseable body;
 *   `{ ok: false, json: null }` on a non-2xx response;
 *   `{ ok: false, json: null, error }` when the request threw or was aborted.
 */
async function fetchJson(url, headers, timeoutMs = PROVIDER_POLL_REQUEST_TIMEOUT_MS) {
  // setTimeout stores the delay as a signed 32-bit integer; anything larger
  // overflows and fires almost immediately, so clamp to that ceiling.
  const MAX_TIMER_DELAY_MS = 2_147_483_647;
  const effectiveTimeoutMs = Number.isFinite(timeoutMs) && timeoutMs > 0
    ? Math.min(timeoutMs, MAX_TIMER_DELAY_MS)
    : 25_000;

  const controller = new AbortController();
  const timer = setTimeout(() => controller.abort(), effectiveTimeoutMs);
  // The timeout must not keep the process alive on its own during shutdown.
  timer.unref?.();

  try {
    const res = await fetch(url, { method: "GET", headers, signal: controller.signal });
    if (!res.ok) return { ok: false, json: null };
    // Reading the body stays inside the try so the abort deadline also covers it.
    return { ok: true, json: await res.json() };
  } catch (err) {
    return { ok: false, json: null, error: err };
  } finally {
    clearTimeout(timer);
  }
}
|
||||
|
||||
async function pollGrsaiImage(_taskId, providerTaskId, apiKey, baseUrl, resultEndpoint) {
|
||||
@@ -813,26 +850,31 @@ function startPolling(taskDbId, { providerTaskId, apiKey, type, providerConfig,
|
||||
}
|
||||
|
||||
let attempts = 0;
|
||||
let polling = false;
|
||||
let skippedPolls = 0;
|
||||
const maxPollAttempts = getMaxPollAttempts(type, providerConfig);
|
||||
const interval = setInterval(async () => {
|
||||
attempts++;
|
||||
if (attempts > maxPollAttempts) {
|
||||
clearInterval(interval);
|
||||
activePollers.delete(taskDbId);
|
||||
if (leaseToken && keyManager) await keyManager.releaseKey(leaseToken).catch(() => {});
|
||||
if (typeof onTaskFailed === "function") {
|
||||
const handled = await onTaskFailed("Task timed out").catch((fallbackErr) => {
|
||||
console.error(`[aiTaskWorker] fallback error for task ${taskDbId}:`, fallbackErr.message);
|
||||
return false;
|
||||
});
|
||||
if (handled) return;
|
||||
}
|
||||
await updateTaskInDb(taskDbId, { status: "failed", error: "Task timed out" });
|
||||
await clearPollerState(taskDbId).catch(() => {});
|
||||
return;
|
||||
}
|
||||
if (polling) return;
|
||||
polling = true;
|
||||
|
||||
try {
|
||||
if (attempts >= maxPollAttempts) {
|
||||
clearInterval(interval);
|
||||
activePollers.delete(taskDbId);
|
||||
if (leaseToken && keyManager) await keyManager.releaseKey(leaseToken).catch(() => {});
|
||||
if (typeof onTaskFailed === "function") {
|
||||
await clearPollerState(taskDbId).catch(() => {});
|
||||
const handled = await onTaskFailed("Task timed out").catch((fallbackErr) => {
|
||||
console.error(`[aiTaskWorker] fallback error for task ${taskDbId}:`, fallbackErr.message);
|
||||
return false;
|
||||
});
|
||||
if (handled) return;
|
||||
}
|
||||
await updateTaskInDb(taskDbId, { status: "failed", error: "Task timed out" });
|
||||
await clearPollerState(taskDbId).catch(() => {});
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if task was cancelled by user
|
||||
const { rows: [taskRow] } = await pool.query("SELECT status FROM generation_tasks WHERE id = $1", [taskDbId]);
|
||||
if (!taskRow || taskRow.status === "cancelled") {
|
||||
@@ -844,15 +886,29 @@ function startPolling(taskDbId, { providerTaskId, apiKey, type, providerConfig,
|
||||
}
|
||||
await refreshPollerHeartbeat(taskDbId).catch(() => {});
|
||||
|
||||
let result;
|
||||
if (type === "image") {
|
||||
if (providerConfig.transport === "dashscope-image") {
|
||||
result = await pollDashscopeImage(taskDbId, providerTaskId, apiKey);
|
||||
} else {
|
||||
result = await pollGrsaiImage(taskDbId, providerTaskId, apiKey, providerConfig.baseUrl, providerConfig.resultEndpoint || "/result");
|
||||
const pollOutcome = await withProviderPollSlot(taskDbId, async () => {
|
||||
attempts++;
|
||||
if (type === "image") {
|
||||
if (providerConfig.transport === "dashscope-image") {
|
||||
return pollDashscopeImage(taskDbId, providerTaskId, apiKey);
|
||||
}
|
||||
return pollGrsaiImage(taskDbId, providerTaskId, apiKey, providerConfig.baseUrl, providerConfig.resultEndpoint || "/result");
|
||||
}
|
||||
} else {
|
||||
result = await pollVideoTask(taskDbId, providerTaskId, apiKey, providerConfig);
|
||||
return pollVideoTask(taskDbId, providerTaskId, apiKey, providerConfig);
|
||||
});
|
||||
|
||||
if (!pollOutcome.acquired) {
|
||||
skippedPolls++;
|
||||
if (skippedPolls % 20 === 0) {
|
||||
console.info(`[aiTaskWorker] task ${taskDbId} waiting for provider poll slot (skipped=${skippedPolls})`);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
skippedPolls = 0;
|
||||
const result = pollOutcome.value;
|
||||
if (!result) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (result.status === "completed" || result.status === "failed") {
|
||||
@@ -860,6 +916,7 @@ function startPolling(taskDbId, { providerTaskId, apiKey, type, providerConfig,
|
||||
activePollers.delete(taskDbId);
|
||||
if (leaseToken && keyManager) await keyManager.releaseKey(leaseToken).catch(() => {});
|
||||
if (result.status === "failed" && typeof onTaskFailed === "function") {
|
||||
await clearPollerState(taskDbId).catch(() => {});
|
||||
const handled = await onTaskFailed(result.error || "Task failed").catch((fallbackErr) => {
|
||||
console.error(`[aiTaskWorker] fallback error for task ${taskDbId}:`, fallbackErr.message);
|
||||
return false;
|
||||
@@ -874,6 +931,8 @@ function startPolling(taskDbId, { providerTaskId, apiKey, type, providerConfig,
|
||||
}
|
||||
} catch (err) {
|
||||
console.error(`[aiTaskWorker] poll error for task ${taskDbId}:`, err.message);
|
||||
} finally {
|
||||
polling = false;
|
||||
}
|
||||
}, POLL_INTERVAL_MS);
|
||||
|
||||
@@ -921,7 +980,7 @@ async function recoverRunnablePollers() {
|
||||
const staleInterval = `${Math.max(5, Math.ceil(POLLER_OWNER_STALE_MS / 1000))} seconds`;
|
||||
const { rows } = await pool.query(
|
||||
`
|
||||
SELECT p.task_id
|
||||
SELECT p.task_id, p.updated_at
|
||||
FROM generation_task_pollers p
|
||||
JOIN generation_tasks t ON t.id = p.task_id
|
||||
WHERE t.status IN ('pending', 'running')
|
||||
@@ -944,6 +1003,7 @@ async function recoverRunnablePollers() {
|
||||
const apiKey = await getLeaseKey(poller.lease_token);
|
||||
if (apiKey == null) {
|
||||
console.warn(`[aiTaskWorker] cannot recover task ${taskId}: active lease not found`);
|
||||
await releaseUnrecoverableTask(taskId, "任务执行状态已失效,已自动释放");
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -955,11 +1015,51 @@ async function recoverRunnablePollers() {
|
||||
providerConfig: parseProviderConfig(poller.provider_config_json),
|
||||
leaseToken: poller.lease_token,
|
||||
keyManager: require("./keyManager"),
|
||||
onTaskFailed: async (failureMessage) => {
|
||||
await updateTaskInDb(taskId, { status: "failed", error: failureMessage || "Task failed" });
|
||||
return true;
|
||||
},
|
||||
skipPersist: true,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Fail a task whose poller cannot be resumed and release what is attached to it.
 *
 * The task row is only updated when it is still 'pending' or 'running' AND its
 * generation_task_pollers row is owned by this process (POLLER_OWNER_ID). When
 * no row matches, nothing else happens and `false` is returned.
 *
 * After the row is marked failed, each follow-up step (lease release, event
 * publish, notification, billing refund) is best-effort and logged on failure,
 * so one failing step cannot skip the ones after it. In particular a publish
 * error must not prevent the refund, because the task is already 'failed'.
 *
 * @param {*} taskId - generation_tasks.id of the task to release.
 * @param {string} message - Text stored in the task's `error` column.
 * @returns {Promise<boolean>} `true` if the task was marked failed, otherwise `false`.
 *   Rejects only if the initial UPDATE fails.
 */
async function releaseUnrecoverableTask(taskId, message) {
  const { rows } = await pool.query(
    `
    UPDATE generation_tasks t
    SET status = 'failed', error = $2, completed_at = NOW(), updated_at = NOW()
    FROM generation_task_pollers p
    WHERE t.id = $1
      AND p.task_id = t.id
      AND p.owner_id = $3
      AND t.status IN ('pending', 'running')
    RETURNING t.*
    `,
    [taskId, message, POLLER_OWNER_ID],
  );

  const task = rows[0];
  if (!task) return false;

  // Read the lease token before the poller row is deleted below.
  // NOTE(review): assumes the token is stored on the poller row — confirm in getPersistedLeaseToken.
  const leaseToken = await getPersistedLeaseToken(taskId).catch(() => null);
  await clearPollerState(taskId).catch(() => {});
  if (leaseToken) {
    await require("./keyManager").releaseKey(leaseToken).catch((err) => {
      console.error(`[aiTaskWorker] failed to release lease for unrecoverable task ${taskId}:`, err.message);
    });
  }

  try {
    await publishTaskEvent(formatTaskEvent(task));
  } catch (err) {
    console.error(`[aiTaskWorker] publish error for unrecoverable task ${taskId}:`, err.message);
  }
  await createTaskLifecycleNotification(task).catch((err) => {
    console.error(`[aiTaskWorker] notification error for unrecoverable task ${taskId}:`, err.message);
  });
  await refundTaskBillingOnFailure(taskId).catch((err) => {
    console.error(`[aiTaskWorker] refund error for unrecoverable task ${taskId}:`, err.message);
  });

  console.warn(`[aiTaskWorker] released unrecoverable task ${taskId}: ${message}`);
  return true;
}
|
||||
|
||||
// --- Periodic stale task cleanup ---
|
||||
// Runs every 5 minutes, marks tasks stuck in 'pending'/'running' for too long as 'failed'.
|
||||
// This catches cases where the worker crashed, the provider API never responded,
|
||||
@@ -971,26 +1071,32 @@ async function runStaleTaskCleanup() {
|
||||
try {
|
||||
const { rows } = await pool.query(
|
||||
`UPDATE generation_tasks
|
||||
SET status = 'failed', error = '任务超时自动释放', updated_at = NOW()
|
||||
SET status = 'failed', error = '任务超时自动释放', completed_at = NOW(), updated_at = NOW()
|
||||
WHERE status IN ('pending', 'running')
|
||||
AND GREATEST(updated_at, COALESCE(last_poll_at, created_at)) < NOW() - INTERVAL '10 minutes'
|
||||
RETURNING id`,
|
||||
RETURNING *`,
|
||||
);
|
||||
for (const row of rows) {
|
||||
await publishTaskEvent({
|
||||
taskId: row.id,
|
||||
status: "failed",
|
||||
progress: null,
|
||||
resultUrl: null,
|
||||
error: "任务超时自动释放",
|
||||
});
|
||||
// Also stop any active poller for this task
|
||||
const poller = activePollers.get(row.id);
|
||||
if (poller) {
|
||||
clearInterval(poller.interval);
|
||||
activePollers.delete(row.id);
|
||||
}
|
||||
const leaseToken = poller?.leaseToken || await getPersistedLeaseToken(row.id).catch(() => null);
|
||||
await clearPollerState(row.id).catch(() => {});
|
||||
if (leaseToken) {
|
||||
await require("./keyManager").releaseKey(leaseToken).catch((err) => {
|
||||
console.error(`[aiTaskWorker] failed to release lease for stale task ${row.id}:`, err.message);
|
||||
});
|
||||
}
|
||||
await publishTaskEvent(formatTaskEvent(row));
|
||||
await createTaskLifecycleNotification(row).catch((err) => {
|
||||
console.error(`[aiTaskWorker] notification error for stale task ${row.id}:`, err.message);
|
||||
});
|
||||
await refundTaskBillingOnFailure(row.id).catch((err) => {
|
||||
console.error(`[aiTaskWorker] refund error for stale task ${row.id}:`, err.message);
|
||||
});
|
||||
}
|
||||
if (rows.length > 0) {
|
||||
console.log(`[aiTaskWorker] Cleaned up ${rows.length} stale task(s)`);
|
||||
@@ -1064,10 +1170,14 @@ function startStaleTaskCleanup() {
|
||||
if (staleTaskCleanupTimer) return;
|
||||
staleTaskCleanupTimer = setInterval(runStaleTaskCleanup, STALE_TASK_CLEANUP_INTERVAL_MS);
|
||||
// Run once shortly after startup
|
||||
setTimeout(runStaleTaskCleanup, 10_000);
|
||||
staleTaskCleanupStartupTimer = setTimeout(runStaleTaskCleanup, 10_000);
|
||||
}
|
||||
|
||||
function stopStaleTaskCleanup() {
|
||||
if (staleTaskCleanupStartupTimer) {
|
||||
clearTimeout(staleTaskCleanupStartupTimer);
|
||||
staleTaskCleanupStartupTimer = null;
|
||||
}
|
||||
if (staleTaskCleanupTimer) {
|
||||
clearInterval(staleTaskCleanupTimer);
|
||||
staleTaskCleanupTimer = null;
|
||||
@@ -1093,9 +1203,20 @@ function stopPollerRecovery() {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Stop every in-process poller and give up ownership of their persisted state.
 *
 * Each interval registered in activePollers is cleared and its entry removed.
 * orphanOwnedPollerState() is then called; a failure there is logged and
 * swallowed, so this function does not reject because of it.
 *
 * @returns {Promise<void>}
 */
async function stopAllPollers() {
  // Map#forEach tolerates deleting the entry currently being visited.
  activePollers.forEach((poller, taskId) => {
    clearInterval(poller.interval);
    activePollers.delete(taskId);
  });

  try {
    await orphanOwnedPollerState();
  } catch (err) {
    console.error("[aiTaskWorker] failed to orphan owned poller state:", err.message);
  }
}
|
||||
|
||||
module.exports = {
|
||||
startPolling,
|
||||
stopPolling,
|
||||
stopAllPollers,
|
||||
cancelTaskRuntimeState,
|
||||
updateTaskInDb,
|
||||
getActiveCount,
|
||||
|
||||
+50
-25
@@ -3,8 +3,17 @@ const express = require('express')
|
||||
const rateLimit = require('express-rate-limit')
|
||||
const cors = require('cors')
|
||||
const helmet = require('helmet')
|
||||
const { startSettlementWorker } = require('./settlementWorker')
|
||||
const { startProviderHealthMonitor } = require('./providerHealthMonitor')
|
||||
const { startSettlementWorker, stopSettlementWorker } = require('./settlementWorker')
|
||||
const { startProviderHealthMonitor, stopProviderHealthMonitor } = require('./providerHealthMonitor')
|
||||
const {
|
||||
startStaleTaskCleanup,
|
||||
startTaskEventListener,
|
||||
startPollerRecovery,
|
||||
stopStaleTaskCleanup,
|
||||
stopTaskEventListener,
|
||||
stopPollerRecovery,
|
||||
stopAllPollers,
|
||||
} = require('./aiTaskWorker')
|
||||
const { ensureDatabase } = require('./dbSetup')
|
||||
const { assertRuntimeSecurityConfig } = require('./securityConfig')
|
||||
const { loadPriceCache } = require('./pricing')
|
||||
@@ -17,6 +26,7 @@ const PORT = Number(process.env.PORT) || 3600
|
||||
const HOST = process.env.HOST || '0.0.0.0'
|
||||
const IS_PRODUCTION = process.env.NODE_ENV === 'production'
|
||||
let server = null
|
||||
let staleLeaseCleanupTimer = null
|
||||
|
||||
// CORS: in production, require explicit allowlist; in dev, allow all with credentials
|
||||
function buildCorsOptions() {
|
||||
@@ -133,18 +143,18 @@ async function main() {
|
||||
|
||||
// Periodic stale lease cleanup (every 5 min)
|
||||
const { cleanStaleLeases } = require('./keyManager')
|
||||
setInterval(() => {
|
||||
staleLeaseCleanupTimer = setInterval(() => {
|
||||
cleanStaleLeases().then((cleaned) => {
|
||||
if (cleaned > 0) console.log(`[cleanup] Released ${cleaned} stale lease(s)`)
|
||||
}).catch((err) => {
|
||||
console.error('[cleanup] error:', err)
|
||||
})
|
||||
}, 5 * 60 * 1000)
|
||||
if (staleLeaseCleanupTimer.unref) staleLeaseCleanupTimer.unref()
|
||||
|
||||
startSettlementWorker()
|
||||
startProviderHealthMonitor()
|
||||
|
||||
const { startStaleTaskCleanup, startTaskEventListener, startPollerRecovery } = require('./aiTaskWorker')
|
||||
await startTaskEventListener()
|
||||
startPollerRecovery()
|
||||
startStaleTaskCleanup()
|
||||
@@ -175,32 +185,47 @@ process.on('uncaughtException', (err) => {
|
||||
// ── Graceful shutdown ───────────────────────────────────────────────────
|
||||
let shuttingDown = false
|
||||
|
||||
function gracefulShutdown(signal) {
|
||||
/**
 * Stop all background workers and timers owned by this process.
 *
 * The stale-lease cleanup interval is cleared, the synchronous stop hooks are
 * called, and then the task event listener and the task pollers are stopped
 * concurrently. A failure in either asynchronous step is logged rather than
 * silently dropped, and never prevents the other from finishing.
 *
 * @returns {Promise<void>}
 */
async function shutdownRuntimeState() {
  if (staleLeaseCleanupTimer) {
    clearInterval(staleLeaseCleanupTimer)
    staleLeaseCleanupTimer = null
  }

  stopSettlementWorker()
  stopProviderHealthMonitor()
  stopPollerRecovery()
  stopStaleTaskCleanup()

  const stepNames = ['stopTaskEventListener', 'stopAllPollers']
  const outcomes = await Promise.allSettled([stopTaskEventListener(), stopAllPollers()])
  outcomes.forEach((outcome, index) => {
    if (outcome.status === 'rejected') {
      console.error(`[shutdown] ${stepNames[index]} failed:`, outcome.reason)
    }
  })
}
|
||||
|
||||
/**
 * Stop the HTTP server from accepting connections and wait until it has closed.
 *
 * Resolves immediately when there is no server or it is not listening.
 * The returned promise never rejects.
 *
 * @returns {Promise<void>}
 */
function closeServer() {
  const isListening = Boolean(server) && server.listening
  if (!isListening) return Promise.resolve()

  return new Promise((resolve) => {
    const onClosed = () => {
      console.log('[shutdown] Server closed, cleaning up...')
      resolve()
    }
    server.close(onClosed)
  })
}
|
||||
|
||||
async function gracefulShutdown(signal) {
|
||||
if (shuttingDown) return
|
||||
shuttingDown = true
|
||||
console.log('[shutdown] Received ' + signal + ', draining connections...')
|
||||
|
||||
if (server && server.listening) {
|
||||
server.close(() => {
|
||||
console.log('[shutdown] Server closed, cleaning up...')
|
||||
const { stopProviderHealthMonitor } = require('./providerHealthMonitor')
|
||||
stopProviderHealthMonitor()
|
||||
const { stopTaskEventListener, stopPollerRecovery } = require('./aiTaskWorker')
|
||||
stopPollerRecovery()
|
||||
void stopTaskEventListener()
|
||||
const { pool } = require('./db')
|
||||
pool.end().then(() => {
|
||||
console.log('[shutdown] Database pool closed')
|
||||
process.exit(0)
|
||||
}).catch(() => process.exit(0))
|
||||
})
|
||||
setTimeout(() => {
|
||||
console.error('[shutdown] Forced exit after timeout')
|
||||
process.exit(1)
|
||||
}, 15000).unref()
|
||||
|
||||
// Force exit after timeout
|
||||
setTimeout(() => {
|
||||
console.error('[shutdown] Forced exit after timeout')
|
||||
process.exit(1)
|
||||
}, 15000).unref()
|
||||
} else {
|
||||
try {
|
||||
await shutdownRuntimeState()
|
||||
await closeServer()
|
||||
const { pool } = require('./db')
|
||||
await pool.end()
|
||||
console.log('[shutdown] Database pool closed')
|
||||
process.exit(0)
|
||||
} catch (err) {
|
||||
console.error('[shutdown] error:', err)
|
||||
process.exit(0)
|
||||
}
|
||||
}
|
||||
|
||||
+9
-3
@@ -284,7 +284,7 @@ async function releaseLeaseInternal(leaseToken, user, options = {}) {
|
||||
const { rows } = await client.query(
|
||||
`
|
||||
WITH candidate AS (
|
||||
SELECT l.id, l.key_id, k.provider
|
||||
SELECT l.id, l.key_id, l.user_id, l.enterprise_id, k.provider
|
||||
FROM key_leases l
|
||||
JOIN api_keys k ON k.id = l.key_id
|
||||
WHERE l.lease_token = $1 AND l.released_at IS NULL
|
||||
@@ -297,7 +297,7 @@ async function releaseLeaseInternal(leaseToken, user, options = {}) {
|
||||
WHERE id = (SELECT id FROM candidate)
|
||||
RETURNING id, key_id
|
||||
)
|
||||
SELECT r.id, r.key_id, c.provider
|
||||
SELECT r.id, r.key_id, c.user_id AS lease_user_id, c.enterprise_id AS lease_enterprise_id, c.provider
|
||||
FROM released r
|
||||
JOIN candidate c ON c.key_id = r.key_id
|
||||
`,
|
||||
@@ -339,7 +339,13 @@ async function releaseLeaseInternal(leaseToken, user, options = {}) {
|
||||
INSERT INTO usage_logs (user_id, enterprise_id, provider, key_id, action)
|
||||
VALUES ($1, $2, (SELECT provider FROM api_keys WHERE id = $3), $4, $5)
|
||||
`,
|
||||
[userId, enterpriseId, lease.key_id, lease.key_id, "release"],
|
||||
[
|
||||
userId || lease.lease_user_id,
|
||||
enterpriseId || lease.lease_enterprise_id,
|
||||
lease.key_id,
|
||||
lease.key_id,
|
||||
"release",
|
||||
],
|
||||
);
|
||||
|
||||
return {
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
"use strict";
|
||||
|
||||
const crypto = require("node:crypto");
|
||||
const { pool } = require("./db");
|
||||
|
||||
const DEFAULT_MAX_CONCURRENCY = 8;
|
||||
const DEFAULT_SLOT_TTL_MS = 30_000;
|
||||
const POLL_SCOPE = "generation-provider-poll:global";
|
||||
const OWNER_ID = `${process.pid}-${crypto.randomUUID()}`;
|
||||
|
||||
let storeReady = null;
|
||||
|
||||
/**
 * Coerce a loosely-typed setting (env string, option value) to a positive integer.
 *
 * @param {*} value - Candidate value; converted with Number().
 * @param {number} fallback - Returned as-is when `value` is not a finite number > 0.
 * @param {number} [max=Number.MAX_SAFE_INTEGER] - Upper bound for the result. Without
 *   a ceiling, inputs such as "1e30" would yield a non-safe integer that is later
 *   bound into SQL. An invalid `max` is ignored.
 * @returns {number} An integer in [1, max], or `fallback`.
 */
function normalizePositiveInteger(value, fallback, max = Number.MAX_SAFE_INTEGER) {
  const numeric = Number(value);
  if (!Number.isFinite(numeric) || numeric <= 0) return fallback;

  const ceiling = Number.isFinite(max) && max >= 1 ? Math.trunc(max) : Number.MAX_SAFE_INTEGER;
  // Fractions in (0, 1) truncate to 0, so floor the result at 1.
  const whole = Math.max(1, Math.trunc(numeric));
  return Math.min(whole, ceiling);
}
|
||||
|
||||
/**
 * Read the poll-slot concurrency limit from TASK_PROVIDER_POLL_MAX_CONCURRENCY.
 *
 * @returns {number} The configured positive integer, or DEFAULT_MAX_CONCURRENCY
 *   when the variable is unset or not a positive number.
 */
function getMaxConcurrency() {
  const configured = process.env.TASK_PROVIDER_POLL_MAX_CONCURRENCY;
  return normalizePositiveInteger(configured, DEFAULT_MAX_CONCURRENCY);
}
|
||||
|
||||
/**
 * Build the slot time-to-live as an interval string such as "30 seconds".
 *
 * The millisecond value comes from TASK_PROVIDER_POLL_SLOT_TTL_MS (default
 * DEFAULT_SLOT_TTL_MS), is rounded up to whole seconds, and is at least 1 second.
 *
 * @returns {string} Text of the form "<n> seconds".
 */
function getSlotTtlInterval() {
  const configuredMs = process.env.TASK_PROVIDER_POLL_SLOT_TTL_MS;
  const ttlMs = normalizePositiveInteger(configuredMs, DEFAULT_SLOT_TTL_MS);
  const ttlSeconds = Math.max(1, Math.ceil(ttlMs / 1000));
  return `${ttlSeconds} seconds`;
}
|
||||
|
||||
/**
 * Create the generation_provider_poll_slots table and its expiry index if missing.
 *
 * The in-flight or completed promise is cached in `storeReady`, so the DDL is
 * issued once per process. On failure the cache is cleared so a later call
 * retries, and the error is rethrown to the caller.
 *
 * @returns {Promise<*>} Settles with the result of the DDL query.
 */
async function ensureProviderPollLimiterStore() {
  if (!storeReady) {
    const createStoreSql = `
      CREATE TABLE IF NOT EXISTS generation_provider_poll_slots (
        scope TEXT NOT NULL,
        slot_no INTEGER NOT NULL,
        owner_id TEXT NOT NULL,
        task_id INTEGER,
        expires_at TIMESTAMPTZ NOT NULL,
        acquired_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
        updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
        PRIMARY KEY (scope, slot_no)
      );
      CREATE INDEX IF NOT EXISTS idx_generation_provider_poll_slots_expires
        ON generation_provider_poll_slots(expires_at);
    `;

    storeReady = pool.query(createStoreSql).catch((err) => {
      // Forget the failed attempt so the next caller tries again.
      storeReady = null;
      throw err;
    });
  }

  return storeReady;
}
|
||||
|
||||
/**
 * Try to claim one provider-poll slot without waiting.
 *
 * The query picks the lowest-numbered slot in 1..maxConcurrency that either has
 * no row for the scope or whose row has expired, then inserts the row or takes
 * over the expired one. The conflict update is guarded by `expires_at < NOW()`,
 * so if the chosen slot is live by the time the write happens, no row is
 * returned and the result is `null`. The caller is expected to retry later.
 *
 * @param {*} taskId - Stored in the slot's task_id column; falsy values are stored as NULL.
 * @param {object} [options]
 * @param {string} [options.scope] - Slot namespace; defaults to POLL_SCOPE.
 * @param {*} [options.maxConcurrency] - Number of slots; defaults to getMaxConcurrency().
 * @param {string} [options.ttlInterval] - Interval text; defaults to getSlotTtlInterval().
 * @returns {Promise<{scope: string, slotNo: number, ownerId: string}|null>}
 *   The claimed slot, or `null` when none could be claimed.
 */
async function acquireProviderPollSlot(taskId, options = {}) {
  await ensureProviderPollLimiterStore();

  const scope = options.scope || POLL_SCOPE;
  const maxConcurrency = normalizePositiveInteger(options.maxConcurrency, getMaxConcurrency());
  const ttlInterval = options.ttlInterval || getSlotTtlInterval();

  const claimSql = `
    WITH candidate AS (
      SELECT s.slot_no
      FROM generate_series(1, $2::integer) AS s(slot_no)
      LEFT JOIN generation_provider_poll_slots l
        ON l.scope = $1 AND l.slot_no = s.slot_no
      WHERE l.scope IS NULL OR l.expires_at < NOW()
      ORDER BY s.slot_no ASC
      LIMIT 1
    ),
    claimed AS (
      INSERT INTO generation_provider_poll_slots (
        scope, slot_no, owner_id, task_id, expires_at, acquired_at, updated_at
      )
      SELECT $1, slot_no, $3, $4, NOW() + ($5::text)::interval, NOW(), NOW()
      FROM candidate
      ON CONFLICT (scope, slot_no) DO UPDATE SET
        owner_id = EXCLUDED.owner_id,
        task_id = EXCLUDED.task_id,
        expires_at = EXCLUDED.expires_at,
        acquired_at = NOW(),
        updated_at = NOW()
      WHERE generation_provider_poll_slots.expires_at < NOW()
      RETURNING scope, slot_no
    )
    SELECT scope, slot_no FROM claimed
  `;
  const claimParams = [scope, maxConcurrency, OWNER_ID, taskId || null, ttlInterval];

  const { rows } = await pool.query(claimSql, claimParams);
  const claimedRow = rows[0];
  if (!claimedRow) return null;

  return { scope: claimedRow.scope, slotNo: claimedRow.slot_no, ownerId: OWNER_ID };
}
|
||||
|
||||
/**
 * Delete a previously claimed poll slot row.
 *
 * The row is removed only if scope, slot number and owner all match, so a slot
 * that has since been taken over by a different owner is left alone. A slot
 * object without `scope` or `slotNo` is ignored.
 * NOTE(review): the match does not include task_id, so a slot that expired and
 * was re-claimed by this same process would also be deleted — verify this is acceptable.
 *
 * @param {{scope: string, slotNo: number, ownerId?: string}|null|undefined} slot
 *   Slot descriptor; `ownerId` defaults to this process's OWNER_ID.
 * @returns {Promise<void>}
 */
async function releaseProviderPollSlot(slot) {
  const hasIdentity = Boolean(slot?.scope) && Boolean(slot?.slotNo);
  if (!hasIdentity) return;

  await ensureProviderPollLimiterStore();

  const deleteSql =
    "DELETE FROM generation_provider_poll_slots WHERE scope = $1 AND slot_no = $2 AND owner_id = $3";
  const ownerId = slot.ownerId || OWNER_ID;
  await pool.query(deleteSql, [slot.scope, slot.slotNo, ownerId]);
}
|
||||
|
||||
/**
 * Run `fn` while holding a provider-poll slot, if one can be claimed right now.
 *
 * When no slot is available `fn` is not called. When a slot is claimed it is
 * released after `fn` settles, whether it resolved or threw; a release failure
 * is logged and does not replace the outcome of `fn`.
 *
 * @param {*} taskId - Forwarded to acquireProviderPollSlot.
 * @param {() => *} fn - Work to perform while the slot is held; may be async.
 * @param {object} [options] - Forwarded to acquireProviderPollSlot.
 * @returns {Promise<{acquired: boolean, value: *}>}
 *   `{ acquired: false, value: undefined }` when no slot was free, otherwise
 *   `{ acquired: true, value }` with the awaited result of `fn`. Rejects if
 *   acquisition or `fn` throws.
 */
async function withProviderPollSlot(taskId, fn, options = {}) {
  const slot = await acquireProviderPollSlot(taskId, options);
  if (!slot) {
    return { acquired: false, value: undefined };
  }

  try {
    const value = await fn();
    return { acquired: true, value };
  } finally {
    try {
      await releaseProviderPollSlot(slot);
    } catch (err) {
      console.error(`[providerPollLimiter] failed to release poll slot ${slot.scope}:${slot.slotNo}:`, err.message);
    }
  }
}
|
||||
|
||||
module.exports = {
|
||||
acquireProviderPollSlot,
|
||||
ensureProviderPollLimiterStore,
|
||||
getMaxConcurrency,
|
||||
normalizePositiveInteger,
|
||||
releaseProviderPollSlot,
|
||||
withProviderPollSlot,
|
||||
};
|
||||
+22
-2
@@ -1072,6 +1072,16 @@ function registerAiRoutes(router) {
|
||||
error.costCents = billingResult.costCents;
|
||||
throw error;
|
||||
}
|
||||
if (billingResult.costCents > 0) {
|
||||
await client.query(
|
||||
"UPDATE generation_tasks SET cost_cents = $1, billing_target = $2, billing_refunded = 0, updated_at = NOW() WHERE id = $3",
|
||||
[
|
||||
billingResult.costCents,
|
||||
billingResult.deductionType === "enterprise_image_flat" ? "enterprise_image" : "user",
|
||||
nextTaskRow.id,
|
||||
],
|
||||
);
|
||||
}
|
||||
return { taskRow: nextTaskRow, imageBilling: billingResult };
|
||||
});
|
||||
const preauth = { authorized: true, estimatedCostCents: 0, billingMode: imageBilling.deductionType };
|
||||
@@ -1086,9 +1096,11 @@ function registerAiRoutes(router) {
|
||||
},
|
||||
providerDebug: buildImageProviderDebug(model),
|
||||
});
|
||||
submitImageWithProviderFallback(taskRow.id, providerCandidates, req.user, preauth, params).catch((err) => {
|
||||
submitImageWithProviderFallback(taskRow.id, providerCandidates, req.user, preauth, params).catch(async (err) => {
|
||||
console.error("[ai/image] submit error:", err.message);
|
||||
updateTaskInDb(taskRow.id, { status: "failed", error: err.message });
|
||||
await updateTaskInDb(taskRow.id, { status: "failed", error: err.message }).catch((updateErr) => {
|
||||
console.error(`[ai/image] failed to persist task ${taskRow.id} failure:`, updateErr.message);
|
||||
});
|
||||
});
|
||||
} catch (err) {
|
||||
console.error("[ai/image] error:", err.message);
|
||||
@@ -1200,6 +1212,10 @@ function registerAiRoutes(router) {
|
||||
...enterpriseBilling,
|
||||
taskId: nextTaskRow.id,
|
||||
});
|
||||
await client.query(
|
||||
"UPDATE generation_tasks SET cost_cents = $1, billing_target = 'enterprise_video', billing_refunded = 0, updated_at = NOW() WHERE id = $2",
|
||||
[nextBilling.amountCents, nextTaskRow.id],
|
||||
);
|
||||
return { taskRow: nextTaskRow, reservedBilling: nextBilling, regularBilling: null };
|
||||
}
|
||||
// Regular user: deduct from personal balance
|
||||
@@ -1222,6 +1238,10 @@ function registerAiRoutes(router) {
|
||||
"INSERT INTO transactions (user_id, type, amount_cents, balance_after_cents, description) VALUES ($1, 'deduct', $2, $3, $4)",
|
||||
[req.user.id, -costCents, deducted.balance_cents, `视频生成扣费 ${credits} 积分`],
|
||||
);
|
||||
await client.query(
|
||||
"UPDATE generation_tasks SET cost_cents = $1, billing_target = 'user', billing_refunded = 0, updated_at = NOW() WHERE id = $2",
|
||||
[costCents, nextTaskRow.id],
|
||||
);
|
||||
return { taskRow: nextTaskRow, reservedBilling: null, regularBilling: { costCents, balanceAfterCents: deducted.balance_cents, credits } };
|
||||
});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user