1 Commits

Author SHA1 Message Date
stringadmin 2249f7ce36 Harden generation task lifecycle 2026-06-04 21:08:14 +08:00
15 changed files with 350 additions and 853 deletions
+276 -351
View File
@@ -5,6 +5,8 @@ const { EventEmitter } = require("node:events");
const { pool } = require("./db"); const { pool } = require("./db");
const { refundTaskBillingOnFailure } = require("./billing"); const { refundTaskBillingOnFailure } = require("./billing");
const { putObject, isOssConfigured } = require("./ossClient"); const { putObject, isOssConfigured } = require("./ossClient");
const keyManager = require("./keyManager");
const { resolveImageProviderCandidates, resolveVideoProvider } = require("./aiProviderRouter");
const taskEvents = new EventEmitter(); const taskEvents = new EventEmitter();
taskEvents.setMaxListeners(200); taskEvents.setMaxListeners(200);
@@ -13,15 +15,10 @@ const activePollers = new Map();
const POLL_INTERVAL_MS = 3000; const POLL_INTERVAL_MS = 3000;
const MAX_POLL_ATTEMPTS = 120; const MAX_POLL_ATTEMPTS = 120;
const GRS_IMAGE_MAX_POLL_ATTEMPTS = Number(process.env.GRSAI_IMAGE_MAX_POLL_ATTEMPTS || 60); const GRS_IMAGE_MAX_POLL_ATTEMPTS = Number(process.env.GRSAI_IMAGE_MAX_POLL_ATTEMPTS || 60);
const TASK_EVENT_CHANNEL = "generation_task_events"; const STALE_TASK_TIMEOUT_MINUTES = Math.max(10, Number(process.env.STALE_GENERATION_TASK_MINUTES || 120));
const TASK_EVENT_ORIGIN = `${process.pid}-${crypto.randomUUID()}`; const RESULT_PERSIST_RETRY_LIMIT = Math.max(1, Number(process.env.RESULT_PERSIST_RETRY_LIMIT || 5));
const POLLER_OWNER_ID = `${process.pid}-${crypto.randomUUID()}`; const RESULT_PERSIST_RETRY_BATCH_SIZE = Math.max(1, Number(process.env.RESULT_PERSIST_RETRY_BATCH_SIZE || 25));
const POLLER_OWNER_STALE_MS = Number(process.env.TASK_POLLER_OWNER_STALE_MS || 20_000); const TASK_STARTUP_RECOVERY_LIMIT = Math.max(1, Number(process.env.TASK_STARTUP_RECOVERY_LIMIT || 50));
const POLLER_RECOVERY_INTERVAL_MS = Number(process.env.TASK_POLLER_RECOVERY_INTERVAL_MS || 30_000);
let taskEventListenerClient = null;
let taskEventListenerStarting = null;
let pollerStoreReady = null;
let pollerRecoveryTimer = null;
function normalizeTaskProgress(value) { function normalizeTaskProgress(value) {
const numeric = Number(value); const numeric = Number(value);
@@ -39,156 +36,6 @@ function formatTaskEvent(row) {
}; };
} }
function emitTaskEvent(event) {
if (!event?.taskId) return;
taskEvents.emit(`task:${event.taskId}`, event);
}
async function publishTaskEvent(event) {
if (!event?.taskId) return;
emitTaskEvent(event);
try {
await pool.query("SELECT pg_notify($1, $2)", [
TASK_EVENT_CHANNEL,
JSON.stringify({ origin: TASK_EVENT_ORIGIN, event }),
]);
} catch (err) {
console.error(`[aiTaskWorker] task event publish failed for task ${event.taskId}:`, err.message);
}
}
function serializeProviderConfig(providerConfig) {
if (!providerConfig || typeof providerConfig !== "object") return {};
const allowedKeys = [
"provider",
"transport",
"protocol",
"baseUrl",
"endpoint",
"resultEndpoint",
"model",
"requestedModel",
];
const result = {};
for (const key of allowedKeys) {
if (providerConfig[key] !== undefined) result[key] = providerConfig[key];
}
return result;
}
function parseProviderConfig(value) {
if (!value) return {};
if (typeof value === "object") return value;
try {
const parsed = JSON.parse(value);
return parsed && typeof parsed === "object" ? parsed : {};
} catch {
return {};
}
}
async function ensureTaskPollerStore() {
if (pollerStoreReady) return pollerStoreReady;
pollerStoreReady = pool.query(`
CREATE TABLE IF NOT EXISTS generation_task_pollers (
task_id INTEGER PRIMARY KEY REFERENCES generation_tasks(id) ON DELETE CASCADE,
provider_task_id TEXT NOT NULL,
task_type TEXT NOT NULL,
provider_config_json TEXT NOT NULL,
lease_token TEXT,
owner_id TEXT,
owner_heartbeat_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_generation_task_pollers_owner
ON generation_task_pollers(owner_heartbeat_at);
`).catch((err) => {
pollerStoreReady = null;
throw err;
});
return pollerStoreReady;
}
async function persistPollerState(taskDbId, { providerTaskId, type, providerConfig, leaseToken }) {
await ensureTaskPollerStore();
await pool.query(
`
INSERT INTO generation_task_pollers (
task_id, provider_task_id, task_type, provider_config_json, lease_token,
owner_id, owner_heartbeat_at, updated_at
)
VALUES ($1, $2, $3, $4, $5, $6, NOW(), NOW())
ON CONFLICT (task_id) DO UPDATE SET
provider_task_id = EXCLUDED.provider_task_id,
task_type = EXCLUDED.task_type,
provider_config_json = EXCLUDED.provider_config_json,
lease_token = EXCLUDED.lease_token,
owner_id = EXCLUDED.owner_id,
owner_heartbeat_at = NOW(),
updated_at = NOW()
`,
[
taskDbId,
providerTaskId,
type,
JSON.stringify(serializeProviderConfig(providerConfig)),
leaseToken || null,
POLLER_OWNER_ID,
],
);
}
async function refreshPollerHeartbeat(taskDbId) {
await ensureTaskPollerStore();
await pool.query(
"UPDATE generation_task_pollers SET owner_id = $1, owner_heartbeat_at = NOW(), updated_at = NOW() WHERE task_id = $2",
[POLLER_OWNER_ID, taskDbId],
);
}
async function clearPollerState(taskDbId) {
await ensureTaskPollerStore();
await pool.query("DELETE FROM generation_task_pollers WHERE task_id = $1", [taskDbId]);
}
async function getLeaseKey(leaseToken) {
if (!leaseToken) return null;
const { rows } = await pool.query(
`
SELECT k.api_key
FROM key_leases l
JOIN api_keys k ON k.id = l.key_id
WHERE l.lease_token = $1
AND l.released_at IS NULL
AND k.enabled = 1
LIMIT 1
`,
[leaseToken],
);
const apiKey = rows[0]?.api_key;
return apiKey === "pool-slot" ? "" : apiKey || null;
}
async function claimPoller(taskId) {
await ensureTaskPollerStore();
const staleInterval = `${Math.max(5, Math.ceil(POLLER_OWNER_STALE_MS / 1000))} seconds`;
const { rows } = await pool.query(
`
UPDATE generation_task_pollers
SET owner_id = $1, owner_heartbeat_at = NOW(), updated_at = NOW()
WHERE task_id = $2
AND (
owner_heartbeat_at IS NULL
OR owner_heartbeat_at < NOW() - ($3::text)::interval
)
RETURNING *
`,
[POLLER_OWNER_ID, taskId, staleInterval],
);
return rows[0] || null;
}
async function createTaskLifecycleNotification(task) { async function createTaskLifecycleNotification(task) {
if (!task || !task.user_id || !task.id) return; if (!task || !task.user_id || !task.id) return;
@@ -251,14 +98,15 @@ async function updateTaskInDb(taskId, updates) {
if (fields.length === 0) return; if (fields.length === 0) return;
values.push(taskId); values.push(taskId);
const protectCancelled = nextUpdates.status !== "cancelled" ? " AND status <> 'cancelled'" : "";
const { rows } = await pool.query( const { rows } = await pool.query(
`UPDATE generation_tasks SET ${fields.join(", ")} WHERE id = $${idx} RETURNING *`, `UPDATE generation_tasks SET ${fields.join(", ")} WHERE id = $${idx}${protectCancelled} RETURNING *`,
values, values,
); );
let updatedTask = rows[0]; let updatedTask = rows[0];
if (updatedTask) { if (updatedTask) {
await publishTaskEvent(formatTaskEvent(updatedTask)); taskEvents.emit(`task:${taskId}`, formatTaskEvent(updatedTask));
} }
if (nextUpdates.status === "completed" && updatedTask?.result_url) { if (nextUpdates.status === "completed" && updatedTask?.result_url) {
@@ -283,20 +131,66 @@ function persistTaskResultUrlToOssInBackground(task) {
Promise.resolve() Promise.resolve()
.then(async () => { .then(async () => {
const durableUrl = await persistResultUrlToOss(task); await persistTaskResultUrlToOss(task);
if (!durableUrl || durableUrl === task.result_url) return;
await pool.query(
"UPDATE generation_tasks SET result_url = $1, updated_at = NOW() WHERE id = $2 AND result_url = $3",
[durableUrl, task.id, task.result_url],
);
console.info(`[aiTaskWorker] task ${task.id} result persisted to OSS after completion`);
}) })
.catch((error) => { .catch((error) => {
console.warn(`[aiTaskWorker] background result persistence failed for task ${task.id}:`, error.message); console.warn(`[aiTaskWorker] background result persistence failed for task ${task.id}:`, error.message);
}); });
} }
async function markResultPersistence(taskId, status, error = null, durableUrl = null, previousUrl = null) {
const fields = [
"result_persist_status = $1",
"result_persist_attempts = result_persist_attempts + 1",
"result_persist_error = $2",
"updated_at = NOW()",
];
const values = [status, error ? String(error).slice(0, 1000) : null];
let idx = values.length + 1;
if (status === "succeeded") {
fields.push("result_persisted_at = NOW()");
}
if (durableUrl) {
fields.push(`result_url = $${idx++}`);
values.push(durableUrl);
}
values.push(taskId);
let where = `id = $${idx}`;
if (previousUrl) {
idx += 1;
values.push(previousUrl);
where += ` AND result_url = $${idx}`;
}
await pool.query(`UPDATE generation_tasks SET ${fields.join(", ")} WHERE ${where}`, values);
}
async function persistTaskResultUrlToOss(task) {
if (!task?.id || !task?.result_url) return null;
if (isOwnPersistedResultUrl(task.result_url)) {
await markResultPersistence(task.id, "succeeded", null, null);
return task.result_url;
}
if (!isOssConfigured()) {
await markResultPersistence(task.id, "failed", "OSS is not configured");
return null;
}
const durableUrl = await persistResultUrlToOss(task);
if (!durableUrl) {
await markResultPersistence(task.id, "failed", "Result URL could not be copied to OSS");
return null;
}
await markResultPersistence(task.id, "succeeded", null, durableUrl, task.result_url);
console.info(`[aiTaskWorker] task ${task.id} result persisted to OSS after completion`);
return durableUrl;
}
function asObject(value) { function asObject(value) {
return value && typeof value === "object" && !Array.isArray(value) ? value : undefined; return value && typeof value === "object" && !Array.isArray(value) ? value : undefined;
} }
@@ -795,45 +689,43 @@ function getMaxPollAttempts(type, providerConfig) {
return MAX_POLL_ATTEMPTS; return MAX_POLL_ATTEMPTS;
} }
function startPolling(taskDbId, { providerTaskId, apiKey, type, providerConfig, leaseToken, keyManager, onTaskFailed, skipPersist = false }) { async function releasePollingLease(poller) {
if (!poller?.leaseToken || !poller?.keyManager) return;
await poller.keyManager.releaseKey(poller.leaseToken).catch((err) => {
console.warn(`[aiTaskWorker] release lease failed for task ${poller.taskDbId}:`, err.message);
});
}
function startPolling(taskDbId, { providerTaskId, apiKey, type, providerConfig, leaseToken, keyManager, onTaskFailed }) {
if (activePollers.has(taskDbId)) return; if (activePollers.has(taskDbId)) return;
if (!skipPersist) {
persistPollerState(taskDbId, { providerTaskId, type, providerConfig, leaseToken }).catch((err) => {
console.error(`[aiTaskWorker] failed to persist poller state for task ${taskDbId}:`, err.message);
});
}
let attempts = 0; let attempts = 0;
let polling = false;
const maxPollAttempts = getMaxPollAttempts(type, providerConfig); const maxPollAttempts = getMaxPollAttempts(type, providerConfig);
const interval = setInterval(async () => { const interval = setInterval(async () => {
if (polling) return;
polling = true;
attempts++; attempts++;
if (attempts > maxPollAttempts) {
clearInterval(interval);
activePollers.delete(taskDbId);
if (leaseToken && keyManager) await keyManager.releaseKey(leaseToken).catch(() => {});
if (typeof onTaskFailed === "function") {
const handled = await onTaskFailed("Task timed out").catch((fallbackErr) => {
console.error(`[aiTaskWorker] fallback error for task ${taskDbId}:`, fallbackErr.message);
return false;
});
if (handled) return;
}
await updateTaskInDb(taskDbId, { status: "failed", error: "Task timed out" });
await clearPollerState(taskDbId).catch(() => {});
return;
}
try { try {
// Check if task was cancelled by user if (attempts > maxPollAttempts) {
const { rows: [taskRow] } = await pool.query("SELECT status FROM generation_tasks WHERE id = $1", [taskDbId]); await stopPolling(taskDbId, { releaseLease: true });
if (!taskRow || taskRow.status === "cancelled") { if (typeof onTaskFailed === "function") {
clearInterval(interval); const handled = await onTaskFailed("Task timed out").catch((fallbackErr) => {
activePollers.delete(taskDbId); console.error(`[aiTaskWorker] fallback error for task ${taskDbId}:`, fallbackErr.message);
await clearPollerState(taskDbId).catch(() => {}); return false;
if (leaseToken && keyManager) await keyManager.releaseKey(leaseToken).catch(() => {}); });
if (handled) return;
}
await updateTaskInDb(taskDbId, { status: "failed", error: "Task timed out" });
return;
}
const { rows: [taskRow] } = await pool.query("SELECT status FROM generation_tasks WHERE id = $1", [taskDbId]);
if (!taskRow || taskRow.status === "cancelled") {
await stopPolling(taskDbId, { releaseLease: true });
return; return;
} }
await refreshPollerHeartbeat(taskDbId).catch(() => {});
let result; let result;
if (type === "image") { if (type === "image") {
@@ -847,9 +739,7 @@ function startPolling(taskDbId, { providerTaskId, apiKey, type, providerConfig,
} }
if (result.status === "completed" || result.status === "failed") { if (result.status === "completed" || result.status === "failed") {
clearInterval(interval); await stopPolling(taskDbId, { releaseLease: true });
activePollers.delete(taskDbId);
if (leaseToken && keyManager) await keyManager.releaseKey(leaseToken).catch(() => {});
if (result.status === "failed" && typeof onTaskFailed === "function") { if (result.status === "failed" && typeof onTaskFailed === "function") {
const handled = await onTaskFailed(result.error || "Task failed").catch((fallbackErr) => { const handled = await onTaskFailed(result.error || "Task failed").catch((fallbackErr) => {
console.error(`[aiTaskWorker] fallback error for task ${taskDbId}:`, fallbackErr.message); console.error(`[aiTaskWorker] fallback error for task ${taskDbId}:`, fallbackErr.message);
@@ -860,179 +750,228 @@ function startPolling(taskDbId, { providerTaskId, apiKey, type, providerConfig,
} }
await updateTaskInDb(taskDbId, result); await updateTaskInDb(taskDbId, result);
if (result.status === "completed" || result.status === "failed") {
await clearPollerState(taskDbId).catch(() => {});
}
} catch (err) { } catch (err) {
console.error(`[aiTaskWorker] poll error for task ${taskDbId}:`, err.message); console.error(`[aiTaskWorker] poll error for task ${taskDbId}:`, err.message);
} finally {
polling = false;
} }
}, POLL_INTERVAL_MS); }, POLL_INTERVAL_MS);
activePollers.set(taskDbId, { interval, leaseToken }); activePollers.set(taskDbId, { taskDbId, interval, leaseToken, keyManager });
} }
function stopPolling(taskDbId) { async function stopPolling(taskDbId, options = {}) {
const poller = activePollers.get(taskDbId); const poller = activePollers.get(taskDbId);
if (poller) { if (!poller) return;
clearInterval(poller.interval);
activePollers.delete(taskDbId); clearInterval(poller.interval);
activePollers.delete(taskDbId);
if (options.releaseLease) {
await releasePollingLease(poller);
} }
clearPollerState(taskDbId).catch(() => {}); }
async function cancelTask(taskId, userId) {
const { rows } = await pool.query(
`UPDATE generation_tasks
SET status = 'cancelled', completed_at = NOW(), updated_at = NOW()
WHERE id = $1 AND user_id = $2 AND status IN ('pending', 'running')
RETURNING *`,
[taskId, userId],
);
const task = rows[0];
if (!task) return null;
await stopPolling(task.id, { releaseLease: true });
taskEvents.emit(`task:${task.id}`, formatTaskEvent(task));
return task;
} }
function getActiveCount() { function getActiveCount() {
return activePollers.size; return activePollers.size;
} }
async function recoverRunnablePollers() { const STALE_TASK_CLEANUP_INTERVAL_MS = 5 * 60 * 1000;
await ensureTaskPollerStore(); let staleTaskCleanupTimer = null;
const staleInterval = `${Math.max(5, Math.ceil(POLLER_OWNER_STALE_MS / 1000))} seconds`; const TASK_RESULT_PERSIST_RETRY_INTERVAL_MS = 5 * 60 * 1000;
const { rows } = await pool.query( let taskResultPersistenceRetryTimer = null;
` let taskStartupRecoveryTimer = null;
SELECT p.task_id let taskStaleCleanupRunning = false;
FROM generation_task_pollers p let taskResultPersistenceRetryRunning = false;
JOIN generation_tasks t ON t.id = p.task_id let taskStartupRecoveryRunning = false;
WHERE t.status IN ('pending', 'running')
AND (
p.owner_heartbeat_at IS NULL
OR p.owner_heartbeat_at < NOW() - ($1::text)::interval
)
ORDER BY p.owner_heartbeat_at NULLS FIRST, p.updated_at ASC
LIMIT 20
`,
[staleInterval],
);
for (const row of rows) { function parseTaskParams(paramsJson) {
const taskId = row.task_id; if (!paramsJson) return {};
if (activePollers.has(taskId)) continue; if (typeof paramsJson === "object") return paramsJson;
const poller = await claimPoller(taskId); try {
if (!poller || activePollers.has(taskId)) continue; const parsed = JSON.parse(paramsJson);
return parsed && typeof parsed === "object" ? parsed : {};
const apiKey = await getLeaseKey(poller.lease_token); } catch {
if (apiKey == null) { return {};
console.warn(`[aiTaskWorker] cannot recover task ${taskId}: active lease not found`);
continue;
}
console.info(`[aiTaskWorker] recovering poller for task ${taskId}`);
startPolling(taskId, {
providerTaskId: poller.provider_task_id,
apiKey,
type: poller.task_type,
providerConfig: parseProviderConfig(poller.provider_config_json),
leaseToken: poller.lease_token,
keyManager: require("./keyManager"),
skipPersist: true,
});
} }
} }
// --- Periodic stale task cleanup --- function resolveProviderConfigForRecovery(task) {
// Runs every 5 minutes, marks tasks stuck in 'pending'/'running' for too long as 'failed'. const params = parseTaskParams(task.params_json);
// This catches cases where the worker crashed, the provider API never responded,
// or the cancel request failed silently on the client side. if (task.type === "video") {
const STALE_TASK_CLEANUP_INTERVAL_MS = 5 * 60 * 1000; if (params.model === "video-style-transform" || params.operation === "video-style-super-resolution") {
let staleTaskCleanupTimer = null; return { provider: "dashscope", protocol: "wan-i2v", baseUrl: "https://dashscope.aliyuncs.com" };
}
if (params.model === "aliyun-video-super-resolve" || params.model === "aliyun-erase-subtitles") {
return null;
}
return resolveVideoProvider(params.model);
}
if (task.type === "image") {
if (params.operation === "image-super-resolution" || params.operation === "image-edit") {
return { provider: "dashscope", transport: "dashscope-image" };
}
const candidates = resolveImageProviderCandidates(params.model);
return candidates[0] || null;
}
return null;
}
function normalizeRecoveryUser(task) {
return {
id: task.user_id,
enterpriseId: task.enterprise_id ?? null,
accountType: task.enterprise_id ? "enterprise" : "personal",
};
}
async function runStaleTaskCleanup() { async function runStaleTaskCleanup() {
if (taskStaleCleanupRunning) return;
taskStaleCleanupRunning = true;
try { try {
const { rows } = await pool.query( const { rows } = await pool.query(
`UPDATE generation_tasks `UPDATE generation_tasks
SET status = 'failed', error = '任务超时自动释放', updated_at = NOW() SET status = 'failed',
error = 'Task timed out and was released automatically',
completed_at = NOW(),
updated_at = NOW()
WHERE status IN ('pending', 'running') WHERE status IN ('pending', 'running')
AND GREATEST(updated_at, COALESCE(last_poll_at, created_at)) < NOW() - INTERVAL '10 minutes' AND GREATEST(updated_at, COALESCE(last_poll_at, created_at)) < NOW() - ($1::int * INTERVAL '1 minute')
RETURNING id`, RETURNING *`,
[STALE_TASK_TIMEOUT_MINUTES],
); );
for (const row of rows) { for (const row of rows) {
await publishTaskEvent({ await stopPolling(row.id, { releaseLease: true });
taskId: row.id, taskEvents.emit(`task:${row.id}`, formatTaskEvent(row));
status: "failed", await refundTaskBillingOnFailure(row.id).catch((err) => {
progress: null, console.error(`[aiTaskWorker] stale task refund error for task ${row.id}:`, err.message);
resultUrl: null,
error: "任务超时自动释放",
}); });
// Also stop any active poller for this task
const poller = activePollers.get(row.id);
if (poller) {
clearInterval(poller.interval);
activePollers.delete(row.id);
}
await clearPollerState(row.id).catch(() => {});
} }
if (rows.length > 0) { if (rows.length > 0) {
console.log(`[aiTaskWorker] Cleaned up ${rows.length} stale task(s)`); console.log(`[aiTaskWorker] Cleaned up ${rows.length} stale task(s)`);
} }
} catch (err) { } catch (err) {
console.error("[aiTaskWorker] Stale task cleanup failed:", err.message); console.error("[aiTaskWorker] Stale task cleanup failed:", err.message);
}
}
async function startTaskEventListener() {
if (taskEventListenerClient) return;
if (taskEventListenerStarting) return taskEventListenerStarting;
taskEventListenerStarting = (async () => {
const client = await pool.connect();
let released = false;
const releaseClient = () => {
if (released) return;
released = true;
taskEventListenerClient = null;
try {
client.release();
} catch {}
};
client.on("notification", (message) => {
if (message.channel !== TASK_EVENT_CHANNEL || !message.payload) return;
try {
const payload = JSON.parse(message.payload);
if (payload?.origin === TASK_EVENT_ORIGIN) return;
emitTaskEvent(payload?.event || payload);
} catch (err) {
console.error("[aiTaskWorker] task event notification parse failed:", err.message);
}
});
client.on("error", (err) => {
console.error("[aiTaskWorker] task event listener error:", err.message);
releaseClient();
setTimeout(() => {
startTaskEventListener().catch((restartErr) => {
console.error("[aiTaskWorker] task event listener restart failed:", restartErr.message);
});
}, 5000).unref?.();
});
await client.query(`LISTEN ${TASK_EVENT_CHANNEL}`);
taskEventListenerClient = client;
console.log(`[aiTaskWorker] listening for task events on ${TASK_EVENT_CHANNEL}`);
})();
try {
await taskEventListenerStarting;
} finally { } finally {
taskEventListenerStarting = null; taskStaleCleanupRunning = false;
} }
} }
async function stopTaskEventListener() { async function runResultPersistenceRetry() {
const client = taskEventListenerClient; if (taskResultPersistenceRetryRunning) return;
taskEventListenerClient = null; taskResultPersistenceRetryRunning = true;
if (!client) return;
try { try {
await client.query(`UNLISTEN ${TASK_EVENT_CHANNEL}`); const { rows } = await pool.query(
} catch {} `SELECT *
client.release(); FROM generation_tasks
WHERE status = 'completed'
AND result_url IS NOT NULL
AND result_url ~* '^https?://'
AND result_url !~* '/users/[^/]+/generation-results/'
AND result_persist_status IN ('pending', 'failed')
AND result_persist_attempts < $1
ORDER BY updated_at ASC
LIMIT $2`,
[RESULT_PERSIST_RETRY_LIMIT, RESULT_PERSIST_RETRY_BATCH_SIZE],
);
for (const row of rows) {
await persistTaskResultUrlToOss(row);
}
if (rows.length > 0) {
console.log(`[aiTaskWorker] Retried OSS result persistence for ${rows.length} task(s)`);
}
} catch (err) {
console.error("[aiTaskWorker] Result persistence retry failed:", err.message);
} finally {
taskResultPersistenceRetryRunning = false;
}
}
async function runTaskStartupRecovery() {
if (taskStartupRecoveryRunning) return;
taskStartupRecoveryRunning = true;
try {
const { rows } = await pool.query(
`SELECT gt.*, u.enterprise_id
FROM generation_tasks gt
JOIN users u ON u.id = gt.user_id
WHERE gt.status = 'running'
AND gt.provider_task_id IS NOT NULL
AND GREATEST(gt.updated_at, COALESCE(gt.last_poll_at, gt.created_at)) >= NOW() - ($1::int * INTERVAL '1 minute')
ORDER BY gt.updated_at DESC
LIMIT $2`,
[STALE_TASK_TIMEOUT_MINUTES, TASK_STARTUP_RECOVERY_LIMIT],
);
let recovered = 0;
for (const task of rows) {
if (activePollers.has(task.id)) continue;
let providerConfig;
try {
providerConfig = resolveProviderConfigForRecovery(task);
} catch (err) {
console.warn(`[aiTaskWorker] task ${task.id} recovery skipped: ${err.message}`);
continue;
}
if (!providerConfig?.provider) continue;
const slotResult = await keyManager.acquireKey(providerConfig.provider, normalizeRecoveryUser(task), null, { waitTimeoutMs: 0 });
if (!slotResult) {
console.warn(`[aiTaskWorker] task ${task.id} recovery waiting for provider capacity: ${providerConfig.provider}`);
continue;
}
startPolling(task.id, {
providerTaskId: task.provider_task_id,
apiKey: slotResult.apiKey,
type: task.type,
providerConfig,
leaseToken: slotResult.leaseToken,
keyManager,
});
recovered += 1;
}
if (recovered > 0) {
console.log(`[aiTaskWorker] Recovered ${recovered} running task poller(s) after startup`);
}
} catch (err) {
console.error("[aiTaskWorker] Startup task recovery failed:", err.message);
} finally {
taskStartupRecoveryRunning = false;
}
} }
function startStaleTaskCleanup() { function startStaleTaskCleanup() {
if (staleTaskCleanupTimer) return; if (staleTaskCleanupTimer) return;
staleTaskCleanupTimer = setInterval(runStaleTaskCleanup, STALE_TASK_CLEANUP_INTERVAL_MS); staleTaskCleanupTimer = setInterval(runStaleTaskCleanup, STALE_TASK_CLEANUP_INTERVAL_MS);
// Run once shortly after startup taskResultPersistenceRetryTimer = setInterval(runResultPersistenceRetry, TASK_RESULT_PERSIST_RETRY_INTERVAL_MS);
taskStartupRecoveryTimer = setTimeout(runTaskStartupRecovery, 5_000);
setTimeout(runStaleTaskCleanup, 10_000); setTimeout(runStaleTaskCleanup, 10_000);
setTimeout(runResultPersistenceRetry, 15_000);
} }
function stopStaleTaskCleanup() { function stopStaleTaskCleanup() {
@@ -1040,30 +979,20 @@ function stopStaleTaskCleanup() {
clearInterval(staleTaskCleanupTimer); clearInterval(staleTaskCleanupTimer);
staleTaskCleanupTimer = null; staleTaskCleanupTimer = null;
} }
} if (taskResultPersistenceRetryTimer) {
clearInterval(taskResultPersistenceRetryTimer);
function startPollerRecovery() { taskResultPersistenceRetryTimer = null;
if (pollerRecoveryTimer) return; }
ensureTaskPollerStore() if (taskStartupRecoveryTimer) {
.then(() => recoverRunnablePollers()) clearTimeout(taskStartupRecoveryTimer);
.catch((err) => console.error("[aiTaskWorker] initial poller recovery failed:", err.message)); taskStartupRecoveryTimer = null;
pollerRecoveryTimer = setInterval(() => {
recoverRunnablePollers().catch((err) => {
console.error("[aiTaskWorker] poller recovery failed:", err.message);
});
}, POLLER_RECOVERY_INTERVAL_MS);
}
function stopPollerRecovery() {
if (pollerRecoveryTimer) {
clearInterval(pollerRecoveryTimer);
pollerRecoveryTimer = null;
} }
} }
module.exports = { module.exports = {
startPolling, startPolling,
stopPolling, stopPolling,
cancelTask,
updateTaskInDb, updateTaskInDb,
getActiveCount, getActiveCount,
extractProviderTaskId, extractProviderTaskId,
@@ -1073,10 +1002,6 @@ module.exports = {
parseKlingCredential, parseKlingCredential,
createKlingJwt, createKlingJwt,
taskEvents, taskEvents,
startTaskEventListener,
stopTaskEventListener,
startPollerRecovery,
stopPollerRecovery,
startStaleTaskCleanup, startStaleTaskCleanup,
stopStaleTaskCleanup, stopStaleTaskCleanup,
}; };
+25 -38
View File
@@ -1,44 +1,29 @@
/** /**
* Billing module — handles balance deduction, package quotas, * Billing module — handles balance deduction (in cents), package quotas,
* transactions, and key-lease pre-authorization. * transactions, and key-lease pre-authorization.
* *
* Unit conventions: * Money conventions:
* - payment_orders.amount_cents / packages.price_cents: cash amount in CNY cents. * - balance: cents (分, 1/100 CNY) — stored in users.balance_cents and enterprises.balance_cents
* - users.balance_cents / enterprises.balance_cents / transactions.amount_cents: * - prices: mills (厘, 1/1000 CNY) — stored in model_prices.*_mills
* credit units, where 100 units = 1 platform credit. * - cost calculation: mills → convert to cents at deduction time (divide by 10, floor)
* - model_prices.*_mills: CNY mills. 1 CNY = 100 credits, so 1 mill = 10 credit units. * - transactions: cents — amount_cents, balance_after_cents
* *
* Flow: * Flow:
* - Enterprise admin recharges enterprise pool → distributes to employee users * - Enterprise admin recharges enterprise pool → distributes to employee users
* - API deductions come from users.balance_cents (per-user credit balance) * - API deductions come from users.balance_cents (per-user)
* - Personal users recharge their own users.balance_cents directly * - Personal users recharge their own users.balance_cents directly
*/ */
const { pool, withTransaction } = require("./db"); const { pool, withTransaction } = require("./db");
const { calculateCostMills, getModelPrice } = require("./pricing"); const { calculateCostMills, getModelPrice } = require("./pricing");
const CREDIT_UNITS_PER_CREDIT = 100; const IMAGE_GENERATION_FLAT_COST_CENTS = 20;
const CREDIT_UNITS_PER_CNY_CENT = 100;
const CREDIT_UNITS_PER_CNY_MILL = 10;
const IMAGE_GENERATION_FLAT_COST_CENTS = 20 * CREDIT_UNITS_PER_CREDIT;
function creditsToCreditUnits(credits) {
return Math.max(0, Math.round(Number(credits || 0) * CREDIT_UNITS_PER_CREDIT));
}
function formatCreditsFromCents(amountCents) { function formatCreditsFromCents(amountCents) {
const value = Number(amountCents || 0) / CREDIT_UNITS_PER_CREDIT; const value = Number(amountCents || 0) / 100;
return Number.isInteger(value) ? String(value) : String(Number(value.toFixed(2))); return Number.isInteger(value) ? String(value) : String(Number(value.toFixed(2)));
} }
function cashCentsToCreditUnits(amountCents) {
return Math.max(0, Math.round(Number(amountCents || 0) * CREDIT_UNITS_PER_CNY_CENT));
}
function millsToCreditUnits(mills) {
return Math.max(0, Math.round(Number(mills || 0) * CREDIT_UNITS_PER_CNY_MILL));
}
async function recordEnterpriseCreditLedger(client, entry) { async function recordEnterpriseCreditLedger(client, entry) {
const enterpriseId = entry?.enterpriseId || null; const enterpriseId = entry?.enterpriseId || null;
const userId = entry?.userId || null; const userId = entry?.userId || null;
@@ -129,6 +114,10 @@ async function getEnterpriseName(enterpriseId) {
return rows[0] ? rows[0].name : null; return rows[0] ? rows[0].name : null;
} }
function millsToCents(mills) {
return Math.floor(mills / 10);
}
// ── Atomic balance helpers ─────────────────────────────────────────── // ── Atomic balance helpers ───────────────────────────────────────────
async function atomicDeductUserBalance(client, userId, amountCents) { async function atomicDeductUserBalance(client, userId, amountCents) {
@@ -178,7 +167,7 @@ async function preauthorizeCall(userId, provider) {
const { rows } = await pool.query( const { rows } = await pool.query(
` `
SELECT COALESCE(CAST(ROUND(AVG(cost_estimate * 10000)::numeric) AS INTEGER), 0) AS avg_cents SELECT COALESCE(CAST(ROUND(AVG(cost_estimate * 100)::numeric) AS INTEGER), 0) AS avg_cents
FROM api_call_logs FROM api_call_logs
WHERE provider = $1 WHERE provider = $1
AND status = 'success' AND status = 'success'
@@ -196,9 +185,10 @@ async function preauthorizeCall(userId, provider) {
const bufferedEstimate = Math.ceil(estimatedCostCents * 1.2); const bufferedEstimate = Math.ceil(estimatedCostCents * 1.2);
if (balanceCents < bufferedEstimate) { if (balanceCents < bufferedEstimate) {
const credits = Math.floor(balanceCents / 100);
return { return {
authorized: false, authorized: false,
message: `账户积分不足,请充值 (当前 ${formatCreditsFromCents(balanceCents)} 积分,预估需要 ${formatCreditsFromCents(bufferedEstimate)} 积分)`, message: `账户积分不足,请充值 (当前 ${credits} 积分,预估需要 ${Math.ceil(bufferedEstimate / 100)} 积分)`,
}; };
} }
@@ -215,9 +205,9 @@ async function deductForApiCall(userId, model, promptTokens, completionTokens) {
return { success: true, costCents: 0, deductionType: "none", message: "No pricing" }; return { success: true, costCents: 0, deductionType: "none", message: "No pricing" };
} }
const costCents = millsToCreditUnits(costMills); const costCents = millsToCents(costMills);
if (costCents <= 0) { if (costCents <= 0) {
return { success: true, costCents: 0, deductionType: "none", message: "Cost below minimum credit unit" }; return { success: true, costCents: 0, deductionType: "none", message: "Cost below 1 cent" };
} }
const billingState = await getUserBillingState(userId); const billingState = await getUserBillingState(userId);
@@ -418,7 +408,7 @@ async function tryDeductFromUserBalance(userId, enterpriseId, amountCents, ledge
userId, userId,
-amountCents, -amountCents,
newBal, newBal,
`API 调用扣费 ${formatCreditsFromCents(amountCents)} 积分`, `API 调用扣费 ${Math.ceil(amountCents / 100)} 积分`,
], ],
); );
@@ -439,15 +429,16 @@ async function tryDeductFromUserBalance(userId, enterpriseId, amountCents, ledge
if (newBalanceCents == null) { if (newBalanceCents == null) {
const currentBalance = await getUserBalanceCents(userId); const currentBalance = await getUserBalanceCents(userId);
const credits = Math.floor((currentBalance || 0) / 100);
return { return {
success: false, success: false,
message: `积分不足 (当前 ${formatCreditsFromCents(currentBalance || 0)} 积分,需要 ${formatCreditsFromCents(amountCents)} 积分)`, message: `积分不足 (当前 ${credits} 积分,需要 ${Math.ceil(amountCents / 100)} 积分)`,
}; };
} }
return { return {
success: true, success: true,
message: `Deducted ${formatCreditsFromCents(amountCents)} credits, balance: ${formatCreditsFromCents(newBalanceCents)} credits`, message: `Deducted ${Math.ceil(amountCents / 100)} credits, balance: ${Math.floor(newBalanceCents / 100)} credits`,
}; };
} }
@@ -493,7 +484,7 @@ async function settleLease(leaseId, actualCostCents) {
userId, userId,
-diffCents, -diffCents,
newBal, newBal,
`API 预估差额扣费 ${formatCreditsFromCents(diffCents)} 积分`, `API 预估差额扣费 ${Math.ceil(diffCents / 100)} 积分`,
], ],
); );
} }
@@ -512,7 +503,7 @@ async function settleLease(leaseId, actualCostCents) {
userId, userId,
refundCents, refundCents,
newBal, newBal,
`API 预估差额退回 ${formatCreditsFromCents(refundCents)} 积分`, `API 预估差额退回 ${Math.ceil(refundCents / 100)} 积分`,
], ],
); );
} }
@@ -637,7 +628,7 @@ async function distributeCredits(enterpriseId, targetUserId, amountCents, adminU
targetUserId, targetUserId,
amountCents, amountCents,
newUserBal, newUserBal,
`从企业池获得 ${formatCreditsFromCents(amountCents)} 积分`, `从企业池获得 ${Math.floor(amountCents / 100)} 积分`,
adminUserId, adminUserId,
], ],
); );
@@ -770,8 +761,4 @@ module.exports = {
preauthorizeCall, preauthorizeCall,
settleLease, settleLease,
forceSettleLease, forceSettleLease,
creditsToCreditUnits,
cashCentsToCreditUnits,
millsToCreditUnits,
formatCreditsFromCents,
}; };
+13
View File
@@ -353,6 +353,18 @@ async function migrateGenerationTasksBillingColumns(client) {
); );
} }
async function migrateGenerationTaskResultPersistence(client) {
await addColumnIfMissing("generation_tasks", "result_persist_status TEXT NOT NULL DEFAULT 'pending'");
await addColumnIfMissing("generation_tasks", "result_persist_attempts INTEGER NOT NULL DEFAULT 0");
await addColumnIfMissing("generation_tasks", "result_persist_error TEXT");
await addColumnIfMissing("generation_tasks", "result_persisted_at TIMESTAMPTZ");
await client.query(`
CREATE INDEX IF NOT EXISTS idx_generation_tasks_result_persist_retry
ON generation_tasks(result_persist_status, updated_at)
WHERE status = 'completed' AND result_url IS NOT NULL
`);
}
async function ensureModelPriceSeed() { async function ensureModelPriceSeed() {
const columns = await getColumnNames("model_prices"); const columns = await getColumnNames("model_prices");
const useMills = columns.includes("input_price_mills"); const useMills = columns.includes("input_price_mills");
@@ -959,6 +971,7 @@ async function ensureSchema() {
await runMigration("030_generation_tasks_user_status_index", migrateGenerationTasksUserStatusIndex); await runMigration("030_generation_tasks_user_status_index", migrateGenerationTasksUserStatusIndex);
await runMigration("031_generation_tasks_billing_columns", migrateGenerationTasksBillingColumns); await runMigration("031_generation_tasks_billing_columns", migrateGenerationTasksBillingColumns);
await runMigration("032_ecommerce_video_history", migrateEcommerceVideoHistorySchema); await runMigration("032_ecommerce_video_history", migrateEcommerceVideoHistorySchema);
await runMigration("033_generation_task_result_persistence", migrateGenerationTaskResultPersistence);
await ensureModelPriceSeed(); await ensureModelPriceSeed();
} }
+1 -6
View File
@@ -144,9 +144,7 @@ async function main() {
startSettlementWorker() startSettlementWorker()
startProviderHealthMonitor() startProviderHealthMonitor()
const { startStaleTaskCleanup, startTaskEventListener, startPollerRecovery } = require('./aiTaskWorker') const { startStaleTaskCleanup } = require('./aiTaskWorker')
await startTaskEventListener()
startPollerRecovery()
startStaleTaskCleanup() startStaleTaskCleanup()
server = app.listen(PORT, HOST, () => { server = app.listen(PORT, HOST, () => {
@@ -185,9 +183,6 @@ function gracefulShutdown(signal) {
console.log('[shutdown] Server closed, cleaning up...') console.log('[shutdown] Server closed, cleaning up...')
const { stopProviderHealthMonitor } = require('./providerHealthMonitor') const { stopProviderHealthMonitor } = require('./providerHealthMonitor')
stopProviderHealthMonitor() stopProviderHealthMonitor()
const { stopTaskEventListener, stopPollerRecovery } = require('./aiTaskWorker')
stopPollerRecovery()
void stopTaskEventListener()
const { pool } = require('./db') const { pool } = require('./db')
pool.end().then(() => { pool.end().then(() => {
console.log('[shutdown] Database pool closed') console.log('[shutdown] Database pool closed')
+3 -2
View File
@@ -284,7 +284,7 @@ async function releaseLeaseInternal(leaseToken, user, options = {}) {
const { rows } = await client.query( const { rows } = await client.query(
` `
WITH candidate AS ( WITH candidate AS (
SELECT l.id, l.key_id, k.provider SELECT l.id, l.key_id, l.user_id, l.enterprise_id, k.provider
FROM key_leases l FROM key_leases l
JOIN api_keys k ON k.id = l.key_id JOIN api_keys k ON k.id = l.key_id
WHERE l.lease_token = $1 AND l.released_at IS NULL WHERE l.lease_token = $1 AND l.released_at IS NULL
@@ -298,6 +298,7 @@ async function releaseLeaseInternal(leaseToken, user, options = {}) {
RETURNING id, key_id RETURNING id, key_id
) )
SELECT r.id, r.key_id, c.provider SELECT r.id, r.key_id, c.provider
, c.user_id, c.enterprise_id
FROM released r FROM released r
JOIN candidate c ON c.key_id = r.key_id JOIN candidate c ON c.key_id = r.key_id
`, `,
@@ -339,7 +340,7 @@ async function releaseLeaseInternal(leaseToken, user, options = {}) {
INSERT INTO usage_logs (user_id, enterprise_id, provider, key_id, action) INSERT INTO usage_logs (user_id, enterprise_id, provider, key_id, action)
VALUES ($1, $2, (SELECT provider FROM api_keys WHERE id = $3), $4, $5) VALUES ($1, $2, (SELECT provider FROM api_keys WHERE id = $3), $4, $5)
`, `,
[userId, enterpriseId, lease.key_id, lease.key_id, "release"], [userId ?? lease.user_id, enterpriseId ?? lease.enterprise_id, lease.key_id, lease.key_id, "release"],
); );
return { return {
+5 -7
View File
@@ -1,7 +1,7 @@
const fs = require("node:fs"); const fs = require("node:fs");
const { AlipaySdk } = require("alipay-sdk"); const { AlipaySdk } = require("alipay-sdk");
const { pool, withTransaction } = require("./db"); const { pool, withTransaction } = require("./db");
const { cashCentsToCreditUnits, creditBalance, creditUserBalance, activatePackage, formatCreditsFromCents } = require("./billing"); const { creditBalance, creditUserBalance, activatePackage } = require("./billing");
let alipayInstance = null; let alipayInstance = null;
@@ -130,19 +130,17 @@ async function handlePaymentSuccess(orderNo, tradeNo) {
); );
if (order.type === "personal_recharge" && order.user_id) { if (order.type === "personal_recharge" && order.user_id) {
const creditUnits = cashCentsToCreditUnits(order.amount_cents);
await creditUserBalance( await creditUserBalance(
order.user_id, order.user_id,
creditUnits, order.amount_cents,
`支付宝充值 ${formatCreditsFromCents(creditUnits)} 积分`, `支付宝充值 ${Math.floor(order.amount_cents / 100)} 积分`,
orderNo, orderNo,
); );
} else if (order.type === "recharge") { } else if (order.type === "recharge") {
const creditUnits = cashCentsToCreditUnits(order.amount_cents);
await creditBalance( await creditBalance(
order.enterprise_id, order.enterprise_id,
creditUnits, order.amount_cents,
`支付宝充值 ${formatCreditsFromCents(creditUnits)} 积分`, `支付宝充值 ${Math.floor(order.amount_cents / 100)} 积分`,
orderNo, orderNo,
); );
} else if (order.type === "package" && order.package_id) { } else if (order.type === "package" && order.package_id) {
+5 -7
View File
@@ -2,7 +2,7 @@ const _crypto = require("node:crypto");
const fs = require("node:fs"); const fs = require("node:fs");
const WxPay = require("wechatpay-node-v3"); const WxPay = require("wechatpay-node-v3");
const { pool, withTransaction } = require("./db"); const { pool, withTransaction } = require("./db");
const { cashCentsToCreditUnits, creditBalance, creditUserBalance, activatePackage, formatCreditsFromCents } = require("./billing"); const { creditBalance, creditUserBalance, activatePackage } = require("./billing");
let wxPayInstance = null; let wxPayInstance = null;
@@ -140,19 +140,17 @@ async function handlePaymentSuccess(orderNo, transactionId) {
); );
if (order.type === "personal_recharge" && order.user_id) { if (order.type === "personal_recharge" && order.user_id) {
const creditUnits = cashCentsToCreditUnits(order.amount_cents);
await creditUserBalance( await creditUserBalance(
order.user_id, order.user_id,
creditUnits, order.amount_cents,
`微信充值 ${formatCreditsFromCents(creditUnits)} 积分`, `微信充值 ${Math.floor(order.amount_cents / 100)} 积分`,
orderNo, orderNo,
); );
} else if (order.type === "recharge") { } else if (order.type === "recharge") {
const creditUnits = cashCentsToCreditUnits(order.amount_cents);
await creditBalance( await creditBalance(
order.enterprise_id, order.enterprise_id,
creditUnits, order.amount_cents,
`微信充值 ${formatCreditsFromCents(creditUnits)} 积分`, `微信充值 ${Math.floor(order.amount_cents / 100)} 积分`,
orderNo, orderNo,
); );
} else if (order.type === "package" && order.package_id) { } else if (order.type === "package" && order.package_id) {
+1 -1
View File
@@ -176,7 +176,7 @@ async function getAverageCostCents(provider) {
const { rows } = await pool.query( const { rows } = await pool.query(
` `
SELECT CAST(ROUND(AVG(CASE SELECT CAST(ROUND(AVG(CASE
WHEN cost_estimate IS NOT NULL THEN cost_estimate * 10000 WHEN cost_estimate IS NOT NULL THEN cost_estimate * 100
ELSE 0 ELSE 0
END)::numeric) AS INTEGER) AS avg_cents END)::numeric) AS INTEGER) AS avg_cents
FROM api_call_logs FROM api_call_logs
+6 -21
View File
@@ -6,8 +6,6 @@ const {
listModelPrices, listModelPrices,
loadPriceCache, loadPriceCache,
creditUserBalance, creditUserBalance,
creditsToCreditUnits,
formatCreditsFromCents,
pool, pool,
validateUsername, validateUsername,
validatePassword, validatePassword,
@@ -158,18 +156,14 @@ function registerAdminRoutes(router) {
router.post("/admin/users/:id/credit", requireAuth, requireAdmin, async (req, res) => { router.post("/admin/users/:id/credit", requireAuth, requireAdmin, async (req, res) => {
const targetUserId = Number(req.params.id); const targetUserId = Number(req.params.id);
const { amountCredits, amountCents } = req.body; const { amountCents } = req.body;
const creditUnits = if (!amountCents || amountCents <= 0) return res.status(400).json({ error: "积分必须大于 0" });
amountCredits !== undefined && amountCredits !== null && amountCredits !== ""
? creditsToCreditUnits(amountCredits)
: Number(amountCents);
if (!creditUnits || creditUnits <= 0) return res.status(400).json({ error: "积分必须大于 0" });
try { try {
const newBalance = await creditUserBalance( const newBalance = await creditUserBalance(
targetUserId, targetUserId,
creditUnits, amountCents,
`管理员 ${req.user.username} 发放 ${formatCreditsFromCents(creditUnits)} 积分`, `管理员 ${req.user.username} 发放 ${Math.floor(amountCents / 100)} 积分`,
); );
res.json({ success: true, newBalanceCents: newBalance }); res.json({ success: true, newBalanceCents: newBalance });
} catch (err) { } catch (err) {
@@ -553,8 +547,6 @@ function registerAdminRoutes(router) {
name, name,
description = "", description = "",
priceCents, priceCents,
credits,
amountCredits,
creditsCents = 0, creditsCents = 0,
imageQuota = 0, imageQuota = 0,
videoQuota = 0, videoQuota = 0,
@@ -580,9 +572,7 @@ function registerAdminRoutes(router) {
name, name,
description, description,
Number(priceCents), Number(priceCents),
credits !== undefined || amountCredits !== undefined Number(creditsCents || 0),
? creditsToCreditUnits(credits ?? amountCredits)
: Number(creditsCents || 0),
Number(imageQuota || 0), Number(imageQuota || 0),
Number(videoQuota || 0), Number(videoQuota || 0),
Number(textQuota || 0), Number(textQuota || 0),
@@ -609,8 +599,6 @@ function registerAdminRoutes(router) {
name, name,
description, description,
priceCents, priceCents,
credits,
amountCredits,
creditsCents, creditsCents,
imageQuota, imageQuota,
videoQuota, videoQuota,
@@ -635,10 +623,7 @@ function registerAdminRoutes(router) {
updates.push(`price_cents = $${idx++}`); updates.push(`price_cents = $${idx++}`);
params.push(Number(priceCents)); params.push(Number(priceCents));
} }
if (credits !== undefined || amountCredits !== undefined) { if (creditsCents !== undefined) {
updates.push(`credits_cents = $${idx++}`);
params.push(creditsToCreditUnits(credits ?? amountCredits));
} else if (creditsCents !== undefined) {
updates.push(`credits_cents = $${idx++}`); updates.push(`credits_cents = $${idx++}`);
params.push(Number(creditsCents)); params.push(Number(creditsCents));
} }
+4 -6
View File
@@ -16,6 +16,7 @@ const {
} = require("../enterpriseVideoBilling"); } = require("../enterpriseVideoBilling");
const { const {
startPolling, startPolling,
cancelTask,
updateTaskInDb, updateTaskInDb,
extractProviderTaskId, extractProviderTaskId,
extractImageUrl, extractImageUrl,
@@ -1770,12 +1771,9 @@ function registerAiRoutes(router) {
if (!Number.isFinite(taskId)) return res.status(400).json({ error: "Invalid task id" }); if (!Number.isFinite(taskId)) return res.status(400).json({ error: "Invalid task id" });
try { try {
const { rows } = await pool.query( const task = await cancelTask(taskId, req.user.id);
"UPDATE generation_tasks SET status = 'cancelled', updated_at = NOW() WHERE id = $1 AND user_id = $2 AND status IN ('pending', 'running') RETURNING id, status", if (!task) return res.status(404).json({ error: "Task not found or not in active state" });
[taskId, req.user.id], res.json({ id: task.id, status: task.status });
);
if (rows.length === 0) return res.status(404).json({ error: "Task not found or not in active state" });
res.json({ id: rows[0].id, status: rows[0].status });
} catch (err) { } catch (err) {
console.error("[ai/task-cancel] error:", err.message); console.error("[ai/task-cancel] error:", err.message);
res.status(500).json({ error: "取消任务失败" }); res.status(500).json({ error: "取消任务失败" });
-388
View File
@@ -1,388 +0,0 @@
"use strict";
const { getUserContextById, verifyToken } = require("../auth");
const { pool, withTransaction } = require("../db");
const { loadBetaInviteCodes, normalizeBetaInviteCode } = require("../betaInviteCodes");
const REVIEW_USERNAMES = new Set(["xqy1912"]);
function cleanText(value, maxLength) {
return String(value || "").trim().slice(0, maxLength);
}
function cleanTextArray(value, maxItems = 20, maxLength = 200) {
if (!Array.isArray(value)) return [];
return value.map((item) => cleanText(item, maxLength)).filter(Boolean).slice(0, maxItems);
}
function parseJson(value, fallback) {
if (!value || typeof value !== "string") return fallback;
try {
return JSON.parse(value);
} catch {
return fallback;
}
}
function safeJsonString(value, fallback) {
try {
return JSON.stringify(value ?? fallback);
} catch {
return JSON.stringify(fallback);
}
}
function getRequestIp(req) {
const forwardedFor = String(req.headers["x-forwarded-for"] || "").split(",")[0].trim();
return forwardedFor || req.socket?.remoteAddress || "";
}
async function optionalAuth(req, _res, next) {
const authHeader = req.headers.authorization;
if (!authHeader?.startsWith("Bearer ")) {
next();
return;
}
try {
const payload = verifyToken(authHeader.slice(7));
const user = await getUserContextById(payload.userId);
if (user?.enabled) req.user = user;
} catch {
// Public application submission should still work without a valid session.
}
next();
}
function canReviewBetaApplications(user) {
if (!user) return false;
const role = String(user.role || "").trim().toLowerCase();
const username = String(user.username || "").trim().toLowerCase();
return role === "admin" || REVIEW_USERNAMES.has(username);
}
function requireBetaApplicationReviewer(req, res, next) {
if (!canReviewBetaApplications(req.user)) {
return res.status(403).json({ error: "无权审核内测申请" });
}
next();
}
async function ensureBetaApplicationSchema() {
await pool.query(`
CREATE TABLE IF NOT EXISTS beta_applications (
id SERIAL PRIMARY KEY,
user_id INTEGER REFERENCES users(id) ON DELETE SET NULL,
name TEXT,
phone TEXT,
wechat TEXT,
industry TEXT,
company TEXT,
city TEXT,
ai_tools TEXT,
ai_duration TEXT,
ai_track TEXT,
ai_direction_json TEXT NOT NULL DEFAULT '[]',
weekly_usage TEXT,
feedback_willing TEXT,
want_feature_json TEXT NOT NULL DEFAULT '[]',
self_statement TEXT,
signature TEXT,
agree_rules INTEGER NOT NULL DEFAULT 0,
status TEXT NOT NULL DEFAULT 'pending',
invite_code TEXT,
review_note TEXT,
reviewed_by INTEGER REFERENCES users(id) ON DELETE SET NULL,
reviewed_at TIMESTAMPTZ,
ip_address TEXT,
user_agent TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_beta_applications_status_created
ON beta_applications(status, created_at DESC);
CREATE INDEX IF NOT EXISTS idx_beta_applications_user_created
ON beta_applications(user_id, created_at DESC);
`);
}
function normalizeApplicationBody(body) {
return {
name: cleanText(body?.name, 120),
phone: cleanText(body?.phone, 60),
wechat: cleanText(body?.wechat, 120),
industry: cleanText(body?.industry, 160),
company: cleanText(body?.company, 200),
city: cleanText(body?.city, 120),
aiTools: cleanText(body?.aiTools ?? body?.ai_tools, 1000),
aiDuration: cleanText(body?.aiDuration ?? body?.ai_duration, 120),
aiTrack: cleanText(body?.aiTrack ?? body?.ai_track, 160),
aiDirection: cleanTextArray(body?.aiDirection ?? body?.ai_direction),
weeklyUsage: cleanText(body?.weeklyUsage ?? body?.weekly_usage, 120),
feedbackWilling: cleanText(body?.feedbackWilling ?? body?.feedback_willing, 160),
wantFeature: cleanTextArray(body?.wantFeature ?? body?.want_feature),
selfStatement: cleanText(body?.selfStatement ?? body?.self_statement, 5000),
signature: cleanText(body?.signature, 120),
agreeRules: body?.agreeRules === true || body?.agree_rules === true || body?.agreeRules === 1 || body?.agree_rules === 1,
};
}
function formatApplication(row) {
return {
id: Number(row.id),
userId: row.user_id == null ? null : Number(row.user_id),
username: row.username || null,
name: row.name || "",
phone: row.phone || "",
wechat: row.wechat || "",
industry: row.industry || "",
company: row.company || "",
city: row.city || "",
aiTools: row.ai_tools || "",
aiDuration: row.ai_duration || "",
aiTrack: row.ai_track || "",
aiDirection: parseJson(row.ai_direction_json, []),
weeklyUsage: row.weekly_usage || "",
feedbackWilling: row.feedback_willing || "",
wantFeature: parseJson(row.want_feature_json, []),
selfStatement: row.self_statement || "",
signature: row.signature || "",
agreeRules: Boolean(row.agree_rules),
status: row.status || "pending",
inviteCode: row.invite_code || null,
reviewNote: row.review_note || null,
reviewedBy: row.reviewed_by == null ? null : Number(row.reviewed_by),
reviewerUsername: row.reviewer_username || null,
reviewedAt: row.reviewed_at || null,
ipAddress: row.ip_address || null,
userAgent: row.user_agent || null,
createdAt: row.created_at,
updatedAt: row.updated_at,
};
}
async function selectApplicationById(client, id) {
const { rows } = await client.query(
`
SELECT a.*, u.username, reviewer.username AS reviewer_username
FROM beta_applications a
LEFT JOIN users u ON u.id = a.user_id
LEFT JOIN users reviewer ON reviewer.id = a.reviewed_by
WHERE a.id = $1
LIMIT 1
`,
[id],
);
return rows[0] || null;
}
async function issueNextBetaInviteCode(client) {
const codes = Array.from(loadBetaInviteCodes()).map(normalizeBetaInviteCode).filter(Boolean).sort();
for (const code of codes) {
const { rows } = await client.query(
`
SELECT 1
FROM beta_invite_code_uses
WHERE code = $1
UNION ALL
SELECT 1
FROM beta_applications
WHERE invite_code = $1 AND status = 'approved'
LIMIT 1
`,
[code],
);
if (rows.length === 0) return code;
}
return null;
}
async function createNotification(client, userId, input) {
if (!userId) return;
await client.query(
`
INSERT INTO web_notifications (
user_id, type, title, description, target_type, target_id, metadata_json
)
VALUES ($1, $2, $3, $4, $5, $6, $7)
`,
[
userId,
input.type,
input.title,
input.description || null,
input.targetType || "beta_application",
input.targetId ? String(input.targetId) : null,
safeJsonString(input.metadata, {}),
],
);
}
function registerBetaApplicationRoutes(router) {
router.post("/beta-applications", optionalAuth, async (req, res) => {
try {
await ensureBetaApplicationSchema();
const app = normalizeApplicationBody(req.body);
if (!app.name || !app.phone || !app.wechat || !app.selfStatement || !app.signature || !app.agreeRules) {
return res.status(400).json({ error: "请填写姓名、手机号、微信、申请自述、签名并同意内测规则" });
}
const { rows } = await pool.query(
`
INSERT INTO beta_applications (
user_id, name, phone, wechat, industry, company, city,
ai_tools, ai_duration, ai_track, ai_direction_json,
weekly_usage, feedback_willing, want_feature_json,
self_statement, signature, agree_rules, ip_address, user_agent
)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19)
RETURNING id, status, created_at
`,
[
req.user?.id || null,
app.name,
app.phone,
app.wechat,
app.industry || null,
app.company || null,
app.city || null,
app.aiTools || null,
app.aiDuration || null,
app.aiTrack || null,
safeJsonString(app.aiDirection, []),
app.weeklyUsage || null,
app.feedbackWilling || null,
safeJsonString(app.wantFeature, []),
app.selfStatement,
app.signature,
app.agreeRules ? 1 : 0,
getRequestIp(req),
cleanText(req.headers["user-agent"], 1000) || null,
],
);
res.status(201).json({
application: {
id: rows[0].id,
status: rows[0].status,
createdAt: rows[0].created_at,
},
});
} catch (err) {
console.error("[beta-applications] create failed:", err.message);
res.status(500).json({ error: "提交内测申请失败" });
}
});
router.get("/admin/beta-applications", requireBetaApplicationReviewer, async (req, res) => {
try {
await ensureBetaApplicationSchema();
const status = cleanText(req.query.status, 32);
const params = [];
const where = [];
if (status) {
params.push(status);
where.push(`a.status = $${params.length}`);
}
const { rows } = await pool.query(
`
SELECT a.*, u.username, reviewer.username AS reviewer_username
FROM beta_applications a
LEFT JOIN users u ON u.id = a.user_id
LEFT JOIN users reviewer ON reviewer.id = a.reviewed_by
${where.length ? `WHERE ${where.join(" AND ")}` : ""}
ORDER BY
CASE a.status WHEN 'pending' THEN 0 WHEN 'approved' THEN 1 ELSE 2 END,
a.created_at DESC
LIMIT 300
`,
params,
);
res.json({ applications: rows.map(formatApplication) });
} catch (err) {
console.error("[admin/beta-applications] list failed:", err.message);
res.status(500).json({ error: "读取内测申请失败" });
}
});
router.patch("/admin/beta-applications/:id", requireBetaApplicationReviewer, async (req, res) => {
const id = Number(req.params.id);
const action = cleanText(req.body?.action, 32);
const reviewNote = cleanText(req.body?.reviewNote ?? req.body?.review_note, 1000) || null;
if (!Number.isFinite(id)) return res.status(400).json({ error: "申请 ID 不正确" });
if (action !== "approve" && action !== "reject") return res.status(400).json({ error: "审核动作不正确" });
try {
await ensureBetaApplicationSchema();
const application = await withTransaction(async (client) => {
const current = await selectApplicationById(client, id);
if (!current) {
const err = new Error("申请不存在");
err.status = 404;
throw err;
}
if (current.status !== "pending") {
const err = new Error("该申请已审核");
err.status = 409;
throw err;
}
let inviteCode = null;
if (action === "approve") {
inviteCode = await issueNextBetaInviteCode(client);
if (!inviteCode) {
const err = new Error("暂无可用内测码,请先补充内测码");
err.status = 409;
throw err;
}
}
const { rows } = await client.query(
`
UPDATE beta_applications
SET status = $1,
invite_code = $2,
review_note = $3,
reviewed_by = $4,
reviewed_at = NOW(),
updated_at = NOW()
WHERE id = $5
RETURNING *
`,
[action === "approve" ? "approved" : "rejected", inviteCode, reviewNote, req.user.id, id],
);
const updated = rows[0];
if (updated.user_id) {
if (action === "approve") {
await createNotification(client, updated.user_id, {
type: "review_passed",
title: "内测申请已通过",
description: `您的内测申请已通过,内测码:${inviteCode}`,
targetId: updated.id,
metadata: { inviteCode },
});
} else {
await createNotification(client, updated.user_id, {
type: "review_rejected",
title: "您未通过内测申请",
description: reviewNote || "很遗憾,您的内测申请暂未通过。",
targetId: updated.id,
});
}
}
return selectApplicationById(client, id);
});
res.json({ application: formatApplication(application) });
} catch (err) {
const status = Number(err.status || 500);
if (status >= 400 && status < 500) return res.status(status).json({ error: err.message });
console.error("[admin/beta-applications] review failed:", err.message);
res.status(500).json({ error: "审核内测申请失败" });
}
});
}
module.exports = { registerBetaApplicationRoutes, canReviewBetaApplications };
-4
View File
@@ -32,8 +32,6 @@ const {
getUserEnterpriseId, getUserEnterpriseId,
getEnterpriseName, getEnterpriseName,
preauthorizeCall, preauthorizeCall,
creditsToCreditUnits,
formatCreditsFromCents,
} = require("../billing"); } = require("../billing");
const wechatPay = require("../paymentWechat"); const wechatPay = require("../paymentWechat");
const alipay = require("../paymentAlipay"); const alipay = require("../paymentAlipay");
@@ -795,8 +793,6 @@ module.exports = {
getUserEnterpriseId, getUserEnterpriseId,
getEnterpriseName, getEnterpriseName,
preauthorizeCall, preauthorizeCall,
creditsToCreditUnits,
formatCreditsFromCents,
wechatPay, wechatPay,
alipay, alipay,
crypto, crypto,
+8 -17
View File
@@ -2,7 +2,6 @@ const {
requireAuth, requireAuth,
requireEnterpriseAdmin, requireEnterpriseAdmin,
distributeCredits, distributeCredits,
creditsToCreditUnits,
getEnterpriseFinancials, getEnterpriseFinancials,
getEnterpriseName, getEnterpriseName,
pool, pool,
@@ -303,33 +302,25 @@ function registerEnterpriseRoutes(router) {
}); });
router.post("/enterprise/distribute", requireAuth, requireEnterpriseAdmin, async (req, res) => { router.post("/enterprise/distribute", requireAuth, requireEnterpriseAdmin, async (req, res) => {
const { userId, amountCredits, amountCents, distributions } = req.body; const { userId, amountCents, distributions } = req.body;
try { try {
if (distributions && Array.isArray(distributions)) { if (distributions && Array.isArray(distributions)) {
for (const d of distributions) { for (const d of distributions) {
const creditUnits = if (!d.userId || !d.amountCents || d.amountCents <= 0) {
d.amountCredits !== undefined && d.amountCredits !== null && d.amountCredits !== ""
? creditsToCreditUnits(d.amountCredits)
: Number(d.amountCents);
if (!d.userId || !creditUnits || creditUnits <= 0) {
return res return res
.status(400) .status(400)
.json({ error: "每条分发记录必须包含有效的 userId 和 amountCredits" }); .json({ error: "每条分发记录必须包含有效的 userId 和 amountCents" });
} }
await distributeCredits(req.user.enterpriseId, d.userId, creditUnits, req.user.id); await distributeCredits(req.user.enterpriseId, d.userId, d.amountCents, req.user.id);
} }
res.json({ success: true, count: distributions.length }); res.json({ success: true, count: distributions.length });
} else if (userId && (amountCredits || amountCents)) { } else if (userId && amountCents) {
const creditUnits = if (amountCents <= 0) return res.status(400).json({ error: "分发积分必须大于0" });
amountCredits !== undefined && amountCredits !== null && amountCredits !== ""
? creditsToCreditUnits(amountCredits)
: Number(amountCents);
if (!creditUnits || creditUnits <= 0) return res.status(400).json({ error: "分发积分必须大于0" });
const result = await distributeCredits( const result = await distributeCredits(
req.user.enterpriseId, req.user.enterpriseId,
userId, userId,
creditUnits, amountCents,
req.user.id, req.user.id,
); );
res.json({ success: true, ...result }); res.json({ success: true, ...result });
@@ -358,7 +349,7 @@ function registerEnterpriseRoutes(router) {
u.username, u.username,
u.balance_cents AS current_balance_cents, u.balance_cents AS current_balance_cents,
COUNT(acl.id) AS total_calls, COUNT(acl.id) AS total_calls,
COALESCE(SUM(CASE WHEN acl.cost_estimate IS NOT NULL THEN CAST(ROUND((acl.cost_estimate * 10000)::numeric) AS INTEGER) ELSE 0 END), 0) AS total_cost_cents, COALESCE(SUM(CASE WHEN acl.cost_estimate IS NOT NULL THEN CAST(ROUND((acl.cost_estimate * 100)::numeric) AS INTEGER) ELSE 0 END), 0) AS total_cost_cents,
MAX(acl.created_at) AS last_active MAX(acl.created_at) AS last_active
FROM users u FROM users u
LEFT JOIN api_call_logs acl ON acl.user_id = u.id AND acl.status = 'success' LEFT JOIN api_call_logs acl ON acl.user_id = u.id AND acl.status = 'success'
-2
View File
@@ -17,7 +17,6 @@ const { registerConversationRoutes } = require('./conversations')
const { registerReportRoutes } = require('./reports') const { registerReportRoutes } = require('./reports')
const { registerAssetRoutes } = require('./assets') const { registerAssetRoutes } = require('./assets')
const { registerNotificationRoutes } = require('./notifications') const { registerNotificationRoutes } = require('./notifications')
const { registerBetaApplicationRoutes } = require('./betaApplications')
const { registerDraftRoutes } = require('./drafts'); const { registerDraftRoutes } = require('./drafts');
const { registerFileExtractRoutes } = require('./fileExtract'); const { registerFileExtractRoutes } = require('./fileExtract');
const mountClientErrorRoutes = require('./clientErrors') const mountClientErrorRoutes = require('./clientErrors')
@@ -49,7 +48,6 @@ registerConversationRoutes(router)
registerReportRoutes(router) registerReportRoutes(router)
registerAssetRoutes(router) registerAssetRoutes(router)
registerNotificationRoutes(router) registerNotificationRoutes(router)
registerBetaApplicationRoutes(router)
registerDraftRoutes(router) registerDraftRoutes(router)
registerFileExtractRoutes(router) registerFileExtractRoutes(router)
registerHealthRoutes(router) registerHealthRoutes(router)
+3 -3
View File
@@ -136,7 +136,7 @@ function registerUserRoutes(router) {
CASE CASE
WHEN billing_refunded = 1 THEN 0 WHEN billing_refunded = 1 THEN 0
WHEN cost_cents > 0 THEN cost_cents WHEN cost_cents > 0 THEN cost_cents
WHEN status = 'completed' AND type = 'image' THEN 2000 WHEN status = 'completed' AND type = 'image' THEN 20
WHEN status = 'completed' AND type = 'video' THEN 500 WHEN status = 'completed' AND type = 'video' THEN 500
ELSE 0 ELSE 0
END END
@@ -162,7 +162,7 @@ function registerUserRoutes(router) {
resolution = params.resolution || params.quality || params.ratio || null; resolution = params.resolution || params.quality || params.ratio || null;
if (row.status === "completed") { if (row.status === "completed") {
if (row.type === "image") { if (row.type === "image") {
estimatedCents = 2000; estimatedCents = 20;
} else if (row.type === "video") { } else if (row.type === "video") {
const dur = params.duration || 5; const dur = params.duration || 5;
const res = String(params.resolution || params.quality || "").toUpperCase(); const res = String(params.resolution || params.quality || "").toUpperCase();
@@ -209,7 +209,7 @@ function registerUserRoutes(router) {
CASE CASE
WHEN billing_refunded = 1 THEN 0 WHEN billing_refunded = 1 THEN 0
WHEN cost_cents > 0 THEN cost_cents WHEN cost_cents > 0 THEN cost_cents
WHEN status = 'completed' AND type = 'image' THEN 2000 WHEN status = 'completed' AND type = 'image' THEN 20
WHEN status = 'completed' AND type = 'video' THEN 500 WHEN status = 'completed' AND type = 'video' THEN 500
ELSE 0 ELSE 0
END END