2 Commits

20 changed files with 164 additions and 1214 deletions
-7
View File
@@ -20,13 +20,6 @@ JWT_EXPIRES_IN=7d
# Connection pool # Connection pool
PG_POOL_MAX=10 PG_POOL_MAX=10
# Provider polling reliability
# Shared across PM2 workers through Postgres-backed poll slots.
TASK_PROVIDER_POLL_MAX_CONCURRENCY=8
TASK_PROVIDER_POLL_SLOT_TTL_MS=30000
TASK_PROVIDER_POLL_REQUEST_TIMEOUT_MS=25000
GRSAI_IMAGE_SUBMIT_TIMEOUT_MS=30000
# CORS (comma separated allowed origins, * for all) # CORS (comma separated allowed origins, * for all)
CORS_ORIGINS=* CORS_ORIGINS=*
+1 -5
View File
@@ -14,11 +14,7 @@
"audit-routes": "node src/cli/auditModelRoutes.js", "audit-routes": "node src/cli/auditModelRoutes.js",
"import-config": "node src/cli/importConfig.js", "import-config": "node src/cli/importConfig.js",
"init-pools": "node src/cli/initPools.js", "init-pools": "node src/cli/initPools.js",
"test:community-routes": "node scripts/communityRouteContract.test.js", "test:community-routes": "node scripts/communityRouteContract.test.js"
"test:enterprise-video-pricing": "node scripts/enterpriseVideoPricingContract.test.js",
"test:key-manager": "node scripts/keyManagerReleaseContract.test.js",
"test:provider-poll-limiter": "node scripts/providerPollLimiterContract.test.js",
"test": "npm run test:community-routes && npm run test:enterprise-video-pricing && npm run test:key-manager && npm run test:provider-poll-limiter"
}, },
"dependencies": { "dependencies": {
"alipay-sdk": "^4.14.0", "alipay-sdk": "^4.14.0",
@@ -1,54 +0,0 @@
const assert = require("node:assert/strict");
const {
calculateEnterpriseVideoCredits,
getEnterpriseVideoCreditRate,
getEnterpriseVideoPricingConfig,
} = require("../src/enterpriseVideoBilling");
function getRule(config, id) {
const rule = config.rules.find((item) => item.id === id);
assert(rule, `missing enterprise video pricing rule: ${id}`);
return rule;
}
const config = getEnterpriseVideoPricingConfig();
assert.equal(config.currency, "CNY");
assert.equal(config.creditsPerCny, 100);
assert.equal(config.billingUnit, "per_second");
assert.deepEqual(config.resolutions, ["720P", "1080P"]);
assert.equal(getRule(config, "happyhorse").rates["720P"], 0.72);
assert.equal(getRule(config, "happyhorse").rates["1080P"], 1.28);
assert.equal(getRule(config, "wanxiang-i2v").rates["720P"], 0.6);
assert.equal(getRule(config, "kling-muted").rates["1080P"], 0.8);
assert.equal(
getEnterpriseVideoCreditRate({
model: "happyhorse-1.0",
resolution: "1080P",
}),
getRule(config, "happyhorse").rates["1080P"],
);
assert.equal(
getEnterpriseVideoCreditRate({
model: "kling-3.0-dashscope",
resolution: "720P",
muted: true,
hasReferenceVideo: false,
}),
getRule(config, "kling-muted").rates["720P"],
);
assert.equal(
calculateEnterpriseVideoCredits({
model: "vidu-q3-turbo",
resolution: "1080P",
durationSeconds: 5,
}),
500,
);
console.log("enterprise video pricing contract tests passed");
-73
View File
@@ -1,73 +0,0 @@
const assert = require("node:assert/strict");
const { createRequire } = require("node:module");
const nodeRequire = createRequire(__filename);
function loadKeyManagerWithPool(pool) {
const dbPath = nodeRequire.resolve("../src/db");
const keyManagerPath = nodeRequire.resolve("../src/keyManager");
const originalDbModule = nodeRequire.cache[dbPath];
const originalKeyManagerModule = nodeRequire.cache[keyManagerPath];
delete nodeRequire.cache[keyManagerPath];
nodeRequire.cache[dbPath] = {
id: dbPath,
filename: dbPath,
loaded: true,
exports: {
pool,
withTransaction: async (fn) => fn(pool),
},
};
return {
keyManager: nodeRequire("../src/keyManager"),
restore() {
delete nodeRequire.cache[keyManagerPath];
if (originalKeyManagerModule) nodeRequire.cache[keyManagerPath] = originalKeyManagerModule;
if (originalDbModule) nodeRequire.cache[dbPath] = originalDbModule;
else delete nodeRequire.cache[dbPath];
},
};
}
function createReleasePool() {
const calls = [];
return {
calls,
async query(sql, params) {
calls.push({ sql, params });
if (/WITH candidate AS/i.test(sql)) {
return {
rows: [{
id: 10,
key_id: 20,
lease_user_id: 30,
lease_enterprise_id: 40,
provider: "dashscope",
}],
};
}
if (/UPDATE api_keys SET active_count/i.test(sql)) return { rows: [] };
if (/INSERT INTO usage_logs/i.test(sql)) return { rows: [] };
throw new Error(`Unexpected SQL: ${sql}`);
},
};
}
(async () => {
const pool = createReleasePool();
const { keyManager, restore } = loadKeyManagerWithPool(pool);
try {
const result = await keyManager.releaseKey("lease-token-without-user-context");
assert.equal(result.released, true);
const usageLogCall = pool.calls.find((call) => /INSERT INTO usage_logs/i.test(call.sql));
assert.deepEqual(usageLogCall.params, [30, 40, 20, 20, "release"]);
} finally {
restore();
}
})().catch((error) => {
console.error(error);
process.exitCode = 1;
});
@@ -1,96 +0,0 @@
const assert = require("node:assert/strict");
const { createRequire } = require("node:module");
const nodeRequire = createRequire(__filename);
function loadLimiterWithPool(pool) {
const dbPath = nodeRequire.resolve("../src/db");
const limiterPath = nodeRequire.resolve("../src/providerPollLimiter");
const originalDbModule = nodeRequire.cache[dbPath];
const originalLimiterModule = nodeRequire.cache[limiterPath];
delete nodeRequire.cache[limiterPath];
nodeRequire.cache[dbPath] = {
id: dbPath,
filename: dbPath,
loaded: true,
exports: { pool },
};
return {
limiter: nodeRequire("../src/providerPollLimiter"),
restore() {
delete nodeRequire.cache[limiterPath];
if (originalLimiterModule) nodeRequire.cache[limiterPath] = originalLimiterModule;
if (originalDbModule) nodeRequire.cache[dbPath] = originalDbModule;
else delete nodeRequire.cache[dbPath];
},
};
}
function createPool(options = {}) {
const calls = [];
return {
calls,
async query(sql, params = []) {
calls.push({ sql, params });
if (/CREATE TABLE IF NOT EXISTS generation_provider_poll_slots/i.test(sql)) return { rows: [] };
if (/WITH candidate AS/i.test(sql)) {
if (options.noAvailableSlot) return { rows: [] };
return { rows: [{ scope: params[0], slot_no: 2 }] };
}
if (/DELETE FROM generation_provider_poll_slots/i.test(sql)) return { rows: [] };
throw new Error(`Unexpected SQL: ${sql}`);
},
};
}
(async () => {
const previousLimit = process.env.TASK_PROVIDER_POLL_MAX_CONCURRENCY;
process.env.TASK_PROVIDER_POLL_MAX_CONCURRENCY = "3";
const pool = createPool();
const { limiter, restore } = loadLimiterWithPool(pool);
try {
const outcome = await limiter.withProviderPollSlot(101, async () => "polled");
assert.equal(outcome.acquired, true);
assert.equal(outcome.value, "polled");
const acquireCall = pool.calls.find((call) => /WITH candidate AS/i.test(call.sql));
assert.equal(acquireCall.params[1], 3);
assert.equal(acquireCall.params[3], 101);
const releaseCall = pool.calls.find((call) => /DELETE FROM generation_provider_poll_slots/i.test(call.sql));
assert.equal(releaseCall.params[0], acquireCall.params[0]);
assert.equal(releaseCall.params[1], 2);
assert.equal(releaseCall.params[2], acquireCall.params[2]);
} finally {
if (previousLimit === undefined) delete process.env.TASK_PROVIDER_POLL_MAX_CONCURRENCY;
else process.env.TASK_PROVIDER_POLL_MAX_CONCURRENCY = previousLimit;
restore();
}
const saturatedPool = createPool({ noAvailableSlot: true });
const { limiter: saturatedLimiter, restore: restoreSaturated } = loadLimiterWithPool(saturatedPool);
try {
let called = false;
const outcome = await saturatedLimiter.withProviderPollSlot(202, async () => {
called = true;
return "should-not-run";
});
assert.equal(outcome.acquired, false);
assert.equal(outcome.value, undefined);
assert.equal(called, false);
assert.equal(
saturatedPool.calls.some((call) => /DELETE FROM generation_provider_poll_slots/i.test(call.sql)),
false,
);
} finally {
restoreSaturated();
}
})().catch((error) => {
console.error(error);
process.exitCode = 1;
});
+38 -192
View File
@@ -5,7 +5,6 @@ const { EventEmitter } = require("node:events");
const { pool } = require("./db"); const { pool } = require("./db");
const { refundTaskBillingOnFailure } = require("./billing"); const { refundTaskBillingOnFailure } = require("./billing");
const { putObject, isOssConfigured } = require("./ossClient"); const { putObject, isOssConfigured } = require("./ossClient");
const { withProviderPollSlot } = require("./providerPollLimiter");
const taskEvents = new EventEmitter(); const taskEvents = new EventEmitter();
taskEvents.setMaxListeners(200); taskEvents.setMaxListeners(200);
@@ -19,12 +18,10 @@ const TASK_EVENT_ORIGIN = `${process.pid}-${crypto.randomUUID()}`;
const POLLER_OWNER_ID = `${process.pid}-${crypto.randomUUID()}`; const POLLER_OWNER_ID = `${process.pid}-${crypto.randomUUID()}`;
const POLLER_OWNER_STALE_MS = Number(process.env.TASK_POLLER_OWNER_STALE_MS || 20_000); const POLLER_OWNER_STALE_MS = Number(process.env.TASK_POLLER_OWNER_STALE_MS || 20_000);
const POLLER_RECOVERY_INTERVAL_MS = Number(process.env.TASK_POLLER_RECOVERY_INTERVAL_MS || 30_000); const POLLER_RECOVERY_INTERVAL_MS = Number(process.env.TASK_POLLER_RECOVERY_INTERVAL_MS || 30_000);
const PROVIDER_POLL_REQUEST_TIMEOUT_MS = Number(process.env.TASK_PROVIDER_POLL_REQUEST_TIMEOUT_MS || 25_000);
let taskEventListenerClient = null; let taskEventListenerClient = null;
let taskEventListenerStarting = null; let taskEventListenerStarting = null;
let pollerStoreReady = null; let pollerStoreReady = null;
let pollerRecoveryTimer = null; let pollerRecoveryTimer = null;
let staleTaskCleanupStartupTimer = null;
function normalizeTaskProgress(value) { function normalizeTaskProgress(value) {
const numeric = Number(value); const numeric = Number(value);
@@ -155,23 +152,6 @@ async function clearPollerState(taskDbId) {
await pool.query("DELETE FROM generation_task_pollers WHERE task_id = $1", [taskDbId]); await pool.query("DELETE FROM generation_task_pollers WHERE task_id = $1", [taskDbId]);
} }
async function orphanOwnedPollerState() {
await ensureTaskPollerStore();
await pool.query(
"UPDATE generation_task_pollers SET owner_id = NULL, owner_heartbeat_at = NULL, updated_at = NOW() WHERE owner_id = $1",
[POLLER_OWNER_ID],
);
}
async function getPersistedLeaseToken(taskDbId) {
await ensureTaskPollerStore();
const { rows } = await pool.query(
"SELECT lease_token FROM generation_task_pollers WHERE task_id = $1 LIMIT 1",
[taskDbId],
);
return rows[0]?.lease_token || null;
}
async function getLeaseKey(leaseToken) { async function getLeaseKey(leaseToken) {
if (!leaseToken) return null; if (!leaseToken) return null;
const { rows } = await pool.query( const { rows } = await pool.query(
@@ -291,12 +271,6 @@ async function updateTaskInDb(taskId, updates) {
}); });
} }
if (nextUpdates.status === "completed") {
await markTaskBillingAccepted(taskId).catch((err) => {
console.error(`[aiTaskWorker] billing accept error for task ${taskId}:`, err.message);
});
}
if (nextUpdates.status === "failed") { if (nextUpdates.status === "failed") {
await refundTaskBillingOnFailure(taskId).catch((err) => { await refundTaskBillingOnFailure(taskId).catch((err) => {
console.error(`[aiTaskWorker] refund error for task ${taskId}:`, err.message); console.error(`[aiTaskWorker] refund error for task ${taskId}:`, err.message);
@@ -304,13 +278,6 @@ async function updateTaskInDb(taskId, updates) {
} }
} }
async function markTaskBillingAccepted(taskId) {
await pool.query(
"UPDATE credit_ledger SET status = 'charged', updated_at = NOW() WHERE task_id = $1 AND status = 'reserved'",
[taskId],
);
}
function persistTaskResultUrlToOssInBackground(task) { function persistTaskResultUrlToOssInBackground(task) {
if (!task?.id || !task?.result_url) return; if (!task?.id || !task?.result_url) return;
@@ -665,22 +632,9 @@ function extractErrorMessage(json, fallback) {
} }
async function fetchJson(url, headers) { async function fetchJson(url, headers) {
const controller = new AbortController(); const res = await fetch(url, { method: "GET", headers });
const timeoutMs = Number.isFinite(PROVIDER_POLL_REQUEST_TIMEOUT_MS) && PROVIDER_POLL_REQUEST_TIMEOUT_MS > 0 if (!res.ok) return { ok: false, json: null };
? PROVIDER_POLL_REQUEST_TIMEOUT_MS return { ok: true, json: await res.json() };
: 25_000;
const timer = setTimeout(() => controller.abort(), timeoutMs);
timer.unref?.();
try {
const res = await fetch(url, { method: "GET", headers, signal: controller.signal });
if (!res.ok) return { ok: false, json: null };
return { ok: true, json: await res.json() };
} catch (err) {
return { ok: false, json: null, error: err };
} finally {
clearTimeout(timer);
}
} }
async function pollGrsaiImage(_taskId, providerTaskId, apiKey, baseUrl, resultEndpoint) { async function pollGrsaiImage(_taskId, providerTaskId, apiKey, baseUrl, resultEndpoint) {
@@ -850,31 +804,26 @@ function startPolling(taskDbId, { providerTaskId, apiKey, type, providerConfig,
} }
let attempts = 0; let attempts = 0;
let polling = false;
let skippedPolls = 0;
const maxPollAttempts = getMaxPollAttempts(type, providerConfig); const maxPollAttempts = getMaxPollAttempts(type, providerConfig);
const interval = setInterval(async () => { const interval = setInterval(async () => {
if (polling) return; attempts++;
polling = true; if (attempts > maxPollAttempts) {
clearInterval(interval);
activePollers.delete(taskDbId);
if (leaseToken && keyManager) await keyManager.releaseKey(leaseToken).catch(() => {});
if (typeof onTaskFailed === "function") {
const handled = await onTaskFailed("Task timed out").catch((fallbackErr) => {
console.error(`[aiTaskWorker] fallback error for task ${taskDbId}:`, fallbackErr.message);
return false;
});
if (handled) return;
}
await updateTaskInDb(taskDbId, { status: "failed", error: "Task timed out" });
await clearPollerState(taskDbId).catch(() => {});
return;
}
try { try {
if (attempts >= maxPollAttempts) {
clearInterval(interval);
activePollers.delete(taskDbId);
if (leaseToken && keyManager) await keyManager.releaseKey(leaseToken).catch(() => {});
if (typeof onTaskFailed === "function") {
await clearPollerState(taskDbId).catch(() => {});
const handled = await onTaskFailed("Task timed out").catch((fallbackErr) => {
console.error(`[aiTaskWorker] fallback error for task ${taskDbId}:`, fallbackErr.message);
return false;
});
if (handled) return;
}
await updateTaskInDb(taskDbId, { status: "failed", error: "Task timed out" });
await clearPollerState(taskDbId).catch(() => {});
return;
}
// Check if task was cancelled by user // Check if task was cancelled by user
const { rows: [taskRow] } = await pool.query("SELECT status FROM generation_tasks WHERE id = $1", [taskDbId]); const { rows: [taskRow] } = await pool.query("SELECT status FROM generation_tasks WHERE id = $1", [taskDbId]);
if (!taskRow || taskRow.status === "cancelled") { if (!taskRow || taskRow.status === "cancelled") {
@@ -886,29 +835,15 @@ function startPolling(taskDbId, { providerTaskId, apiKey, type, providerConfig,
} }
await refreshPollerHeartbeat(taskDbId).catch(() => {}); await refreshPollerHeartbeat(taskDbId).catch(() => {});
const pollOutcome = await withProviderPollSlot(taskDbId, async () => { let result;
attempts++; if (type === "image") {
if (type === "image") { if (providerConfig.transport === "dashscope-image") {
if (providerConfig.transport === "dashscope-image") { result = await pollDashscopeImage(taskDbId, providerTaskId, apiKey);
return pollDashscopeImage(taskDbId, providerTaskId, apiKey); } else {
} result = await pollGrsaiImage(taskDbId, providerTaskId, apiKey, providerConfig.baseUrl, providerConfig.resultEndpoint || "/result");
return pollGrsaiImage(taskDbId, providerTaskId, apiKey, providerConfig.baseUrl, providerConfig.resultEndpoint || "/result");
} }
return pollVideoTask(taskDbId, providerTaskId, apiKey, providerConfig); } else {
}); result = await pollVideoTask(taskDbId, providerTaskId, apiKey, providerConfig);
if (!pollOutcome.acquired) {
skippedPolls++;
if (skippedPolls % 20 === 0) {
console.info(`[aiTaskWorker] task ${taskDbId} waiting for provider poll slot (skipped=${skippedPolls})`);
}
return;
}
skippedPolls = 0;
const result = pollOutcome.value;
if (!result) {
return;
} }
if (result.status === "completed" || result.status === "failed") { if (result.status === "completed" || result.status === "failed") {
@@ -916,7 +851,6 @@ function startPolling(taskDbId, { providerTaskId, apiKey, type, providerConfig,
activePollers.delete(taskDbId); activePollers.delete(taskDbId);
if (leaseToken && keyManager) await keyManager.releaseKey(leaseToken).catch(() => {}); if (leaseToken && keyManager) await keyManager.releaseKey(leaseToken).catch(() => {});
if (result.status === "failed" && typeof onTaskFailed === "function") { if (result.status === "failed" && typeof onTaskFailed === "function") {
await clearPollerState(taskDbId).catch(() => {});
const handled = await onTaskFailed(result.error || "Task failed").catch((fallbackErr) => { const handled = await onTaskFailed(result.error || "Task failed").catch((fallbackErr) => {
console.error(`[aiTaskWorker] fallback error for task ${taskDbId}:`, fallbackErr.message); console.error(`[aiTaskWorker] fallback error for task ${taskDbId}:`, fallbackErr.message);
return false; return false;
@@ -931,8 +865,6 @@ function startPolling(taskDbId, { providerTaskId, apiKey, type, providerConfig,
} }
} catch (err) { } catch (err) {
console.error(`[aiTaskWorker] poll error for task ${taskDbId}:`, err.message); console.error(`[aiTaskWorker] poll error for task ${taskDbId}:`, err.message);
} finally {
polling = false;
} }
}, POLL_INTERVAL_MS); }, POLL_INTERVAL_MS);
@@ -948,29 +880,6 @@ function stopPolling(taskDbId) {
clearPollerState(taskDbId).catch(() => {}); clearPollerState(taskDbId).catch(() => {});
} }
async function cancelTaskRuntimeState(taskDbId, keyManager) {
const poller = activePollers.get(taskDbId);
if (poller) {
clearInterval(poller.interval);
activePollers.delete(taskDbId);
}
const leaseToken = poller?.leaseToken || await getPersistedLeaseToken(taskDbId).catch(() => null);
await clearPollerState(taskDbId).catch(() => {});
if (leaseToken && keyManager) {
await keyManager.releaseKey(leaseToken).catch((err) => {
console.error(`[aiTaskWorker] failed to release lease for cancelled task ${taskDbId}:`, err.message);
});
}
await publishTaskEvent({
taskId: taskDbId,
status: "cancelled",
progress: 100,
resultUrl: null,
error: "任务已取消",
});
}
function getActiveCount() { function getActiveCount() {
return activePollers.size; return activePollers.size;
} }
@@ -980,7 +889,7 @@ async function recoverRunnablePollers() {
const staleInterval = `${Math.max(5, Math.ceil(POLLER_OWNER_STALE_MS / 1000))} seconds`; const staleInterval = `${Math.max(5, Math.ceil(POLLER_OWNER_STALE_MS / 1000))} seconds`;
const { rows } = await pool.query( const { rows } = await pool.query(
` `
SELECT p.task_id, p.updated_at SELECT p.task_id
FROM generation_task_pollers p FROM generation_task_pollers p
JOIN generation_tasks t ON t.id = p.task_id JOIN generation_tasks t ON t.id = p.task_id
WHERE t.status IN ('pending', 'running') WHERE t.status IN ('pending', 'running')
@@ -1003,7 +912,6 @@ async function recoverRunnablePollers() {
const apiKey = await getLeaseKey(poller.lease_token); const apiKey = await getLeaseKey(poller.lease_token);
if (apiKey == null) { if (apiKey == null) {
console.warn(`[aiTaskWorker] cannot recover task ${taskId}: active lease not found`); console.warn(`[aiTaskWorker] cannot recover task ${taskId}: active lease not found`);
await releaseUnrecoverableTask(taskId, "任务执行状态已失效,已自动释放");
continue; continue;
} }
@@ -1015,51 +923,11 @@ async function recoverRunnablePollers() {
providerConfig: parseProviderConfig(poller.provider_config_json), providerConfig: parseProviderConfig(poller.provider_config_json),
leaseToken: poller.lease_token, leaseToken: poller.lease_token,
keyManager: require("./keyManager"), keyManager: require("./keyManager"),
onTaskFailed: async (failureMessage) => {
await updateTaskInDb(taskId, { status: "failed", error: failureMessage || "Task failed" });
return true;
},
skipPersist: true, skipPersist: true,
}); });
} }
} }
async function releaseUnrecoverableTask(taskId, message) {
const { rows } = await pool.query(
`
UPDATE generation_tasks t
SET status = 'failed', error = $2, completed_at = NOW(), updated_at = NOW()
FROM generation_task_pollers p
WHERE t.id = $1
AND p.task_id = t.id
AND p.owner_id = $3
AND t.status IN ('pending', 'running')
RETURNING t.*
`,
[taskId, message, POLLER_OWNER_ID],
);
const task = rows[0];
if (!task) return false;
const leaseToken = await getPersistedLeaseToken(taskId).catch(() => null);
await clearPollerState(taskId).catch(() => {});
if (leaseToken) {
await require("./keyManager").releaseKey(leaseToken).catch((err) => {
console.error(`[aiTaskWorker] failed to release lease for unrecoverable task ${taskId}:`, err.message);
});
}
await publishTaskEvent(formatTaskEvent(task));
await createTaskLifecycleNotification(task).catch((err) => {
console.error(`[aiTaskWorker] notification error for unrecoverable task ${taskId}:`, err.message);
});
await refundTaskBillingOnFailure(taskId).catch((err) => {
console.error(`[aiTaskWorker] refund error for unrecoverable task ${taskId}:`, err.message);
});
console.warn(`[aiTaskWorker] released unrecoverable task ${taskId}: ${message}`);
return true;
}
// --- Periodic stale task cleanup --- // --- Periodic stale task cleanup ---
// Runs every 5 minutes, marks tasks stuck in 'pending'/'running' for too long as 'failed'. // Runs every 5 minutes, marks tasks stuck in 'pending'/'running' for too long as 'failed'.
// This catches cases where the worker crashed, the provider API never responded, // This catches cases where the worker crashed, the provider API never responded,
@@ -1071,32 +939,26 @@ async function runStaleTaskCleanup() {
try { try {
const { rows } = await pool.query( const { rows } = await pool.query(
`UPDATE generation_tasks `UPDATE generation_tasks
SET status = 'failed', error = '任务超时自动释放', completed_at = NOW(), updated_at = NOW() SET status = 'failed', error = '任务超时自动释放', updated_at = NOW()
WHERE status IN ('pending', 'running') WHERE status IN ('pending', 'running')
AND GREATEST(updated_at, COALESCE(last_poll_at, created_at)) < NOW() - INTERVAL '10 minutes' AND GREATEST(updated_at, COALESCE(last_poll_at, created_at)) < NOW() - INTERVAL '10 minutes'
RETURNING *`, RETURNING id`,
); );
for (const row of rows) { for (const row of rows) {
await publishTaskEvent({
taskId: row.id,
status: "failed",
progress: null,
resultUrl: null,
error: "任务超时自动释放",
});
// Also stop any active poller for this task // Also stop any active poller for this task
const poller = activePollers.get(row.id); const poller = activePollers.get(row.id);
if (poller) { if (poller) {
clearInterval(poller.interval); clearInterval(poller.interval);
activePollers.delete(row.id); activePollers.delete(row.id);
} }
const leaseToken = poller?.leaseToken || await getPersistedLeaseToken(row.id).catch(() => null);
await clearPollerState(row.id).catch(() => {}); await clearPollerState(row.id).catch(() => {});
if (leaseToken) {
await require("./keyManager").releaseKey(leaseToken).catch((err) => {
console.error(`[aiTaskWorker] failed to release lease for stale task ${row.id}:`, err.message);
});
}
await publishTaskEvent(formatTaskEvent(row));
await createTaskLifecycleNotification(row).catch((err) => {
console.error(`[aiTaskWorker] notification error for stale task ${row.id}:`, err.message);
});
await refundTaskBillingOnFailure(row.id).catch((err) => {
console.error(`[aiTaskWorker] refund error for stale task ${row.id}:`, err.message);
});
} }
if (rows.length > 0) { if (rows.length > 0) {
console.log(`[aiTaskWorker] Cleaned up ${rows.length} stale task(s)`); console.log(`[aiTaskWorker] Cleaned up ${rows.length} stale task(s)`);
@@ -1170,14 +1032,10 @@ function startStaleTaskCleanup() {
if (staleTaskCleanupTimer) return; if (staleTaskCleanupTimer) return;
staleTaskCleanupTimer = setInterval(runStaleTaskCleanup, STALE_TASK_CLEANUP_INTERVAL_MS); staleTaskCleanupTimer = setInterval(runStaleTaskCleanup, STALE_TASK_CLEANUP_INTERVAL_MS);
// Run once shortly after startup // Run once shortly after startup
staleTaskCleanupStartupTimer = setTimeout(runStaleTaskCleanup, 10_000); setTimeout(runStaleTaskCleanup, 10_000);
} }
function stopStaleTaskCleanup() { function stopStaleTaskCleanup() {
if (staleTaskCleanupStartupTimer) {
clearTimeout(staleTaskCleanupStartupTimer);
staleTaskCleanupStartupTimer = null;
}
if (staleTaskCleanupTimer) { if (staleTaskCleanupTimer) {
clearInterval(staleTaskCleanupTimer); clearInterval(staleTaskCleanupTimer);
staleTaskCleanupTimer = null; staleTaskCleanupTimer = null;
@@ -1203,21 +1061,9 @@ function stopPollerRecovery() {
} }
} }
async function stopAllPollers() {
for (const [taskId, poller] of activePollers.entries()) {
clearInterval(poller.interval);
activePollers.delete(taskId);
}
await orphanOwnedPollerState().catch((err) => {
console.error("[aiTaskWorker] failed to orphan owned poller state:", err.message);
});
}
module.exports = { module.exports = {
startPolling, startPolling,
stopPolling, stopPolling,
stopAllPollers,
cancelTaskRuntimeState,
updateTaskInDb, updateTaskInDb,
getActiveCount, getActiveCount,
extractProviderTaskId, extractProviderTaskId,
+20 -20
View File
@@ -6,6 +6,7 @@ const { getJwtSecret } = require("./securityConfig");
const JWT_SECRET = getJwtSecret(); const JWT_SECRET = getJwtSecret();
const JWT_EXPIRES_IN = process.env.JWT_EXPIRES_IN || "7d"; const JWT_EXPIRES_IN = process.env.JWT_EXPIRES_IN || "7d";
const MAX_CONCURRENT_SESSIONS = 2;
const USER_CONTEXT_SELECT = ` const USER_CONTEXT_SELECT = `
SELECT SELECT
@@ -169,26 +170,25 @@ function verifyToken(token) {
async function startUserSession(userId, userAgent) { async function startUserSession(userId, userAgent) {
const sessionId = crypto.randomUUID(); const sessionId = crypto.randomUUID();
const client = await pool.connect(); await pool.query(
try { "INSERT INTO user_sessions (id, user_id, user_agent, created_at) VALUES ($1, $2, $3, NOW())",
await client.query("BEGIN"); [sessionId, userId, userAgent || null],
await client.query("SELECT id FROM users WHERE id = $1 FOR UPDATE", [userId]); );
await client.query("DELETE FROM user_sessions WHERE user_id = $1", [userId]); await pool.query(
await client.query( `DELETE FROM user_sessions
"INSERT INTO user_sessions (id, user_id, user_agent, created_at) VALUES ($1, $2, $3, NOW())", WHERE user_id = $1
[sessionId, userId, userAgent || null], AND id NOT IN (
); SELECT id FROM user_sessions
await client.query( WHERE user_id = $1
"UPDATE users SET current_session_id = $1, current_session_started_at = NOW(), updated_at = NOW() WHERE id = $2", ORDER BY created_at DESC
[sessionId, userId], LIMIT $2
); )`,
await client.query("COMMIT"); [userId, MAX_CONCURRENT_SESSIONS],
} catch (error) { );
await client.query("ROLLBACK"); await pool.query(
throw error; "UPDATE users SET current_session_id = $1, current_session_started_at = NOW(), updated_at = NOW() WHERE id = $2",
} finally { [sessionId, userId],
client.release(); );
}
return sessionId; return sessionId;
} }
+1 -3
View File
@@ -18,11 +18,9 @@ const { pool, withTransaction } = require("./db");
const { calculateCostMills, getModelPrice } = require("./pricing"); const { calculateCostMills, getModelPrice } = require("./pricing");
const CREDIT_UNITS_PER_CREDIT = 100; const CREDIT_UNITS_PER_CREDIT = 100;
const CREDITS_PER_CNY = 100;
const CREDIT_UNITS_PER_CNY_CENT = 100; const CREDIT_UNITS_PER_CNY_CENT = 100;
const CREDIT_UNITS_PER_CNY_MILL = 10; const CREDIT_UNITS_PER_CNY_MILL = 10;
const IMAGE_GENERATION_FLAT_COST_CREDITS = 20; const IMAGE_GENERATION_FLAT_COST_CENTS = 20 * CREDIT_UNITS_PER_CREDIT;
const IMAGE_GENERATION_FLAT_COST_CENTS = IMAGE_GENERATION_FLAT_COST_CREDITS * CREDIT_UNITS_PER_CREDIT;
function creditsToCreditUnits(credits) { function creditsToCreditUnits(credits) {
return Math.max(0, Math.round(Number(credits || 0) * CREDIT_UNITS_PER_CREDIT)); return Math.max(0, Math.round(Number(credits || 0) * CREDIT_UNITS_PER_CREDIT));
+29 -88
View File
@@ -19,56 +19,6 @@ const ENTERPRISE_VIDEO_ALLOWED_MODELS = new Set([
"pixverse-c1-i2v", "pixverse-c1-i2v",
]); ]);
const CREDITS_PER_CNY = 100;
const CREDIT_UNITS_PER_CREDIT = 100;
const ENTERPRISE_VIDEO_RESOLUTIONS = ["720P", "1080P"];
const ENTERPRISE_VIDEO_DEFAULT_RESOLUTION = "1080P";
const ENTERPRISE_VIDEO_PRICING_RULES = [
{
id: "happyhorse",
modelIncludes: ["happyhorse"],
rates: { "720P": 0.72, "1080P": 1.28 },
},
{
id: "wanxiang-i2v",
modelIncludes: ["wan2.7-i2v", "wanxiang"],
rates: { "720P": 0.6, "1080P": 1 },
},
{
id: "wan-animate-s2v",
modelIncludes: ["animate-mix", "s2v"],
rates: { "720P": 0.6, "1080P": 1 },
},
{
id: "kling-muted-reference",
modelIncludes: ["kling"],
when: { muted: true, hasReferenceVideo: true },
rates: { "720P": 0.9, "1080P": 1.2 },
},
{
id: "kling-muted",
modelIncludes: ["kling"],
when: { muted: true, hasReferenceVideo: false },
rates: { "720P": 0.6, "1080P": 0.8 },
},
{
id: "kling-default",
modelIncludes: ["kling"],
rates: { "720P": 0.9, "1080P": 1.2 },
},
{
id: "vidu",
modelIncludes: ["vidu"],
rates: { "720P": 0.6, "1080P": 1 },
},
{
id: "pixverse",
modelIncludes: ["pixverse"],
rates: { "720P": 0.6, "1080P": 1 },
},
];
function normalizeModel(value) { function normalizeModel(value) {
return String(value || "").trim().toLowerCase(); return String(value || "").trim().toLowerCase();
} }
@@ -83,21 +33,6 @@ function normalizeEnterpriseVideoDuration(value) {
return Math.max(1, Math.ceil(numeric)); return Math.max(1, Math.ceil(numeric));
} }
function enterpriseVideoPricingRuleMatches(rule, input, model) {
if (!rule.modelIncludes.some((pattern) => model.includes(pattern))) return false;
if (!rule.when) return true;
if (Object.prototype.hasOwnProperty.call(rule.when, "muted") && Boolean(input.muted) !== rule.when.muted) {
return false;
}
if (
Object.prototype.hasOwnProperty.call(rule.when, "hasReferenceVideo") &&
Boolean(input.hasReferenceVideo) !== rule.when.hasReferenceVideo
) {
return false;
}
return true;
}
function isEnterpriseVideoBillingUser(user) { function isEnterpriseVideoBillingUser(user) {
return Boolean(user?.enterpriseId); return Boolean(user?.enterpriseId);
} }
@@ -131,10 +66,33 @@ function getEnterpriseVideoCreditRate(input) {
const resolution = normalizeEnterpriseVideoResolution(input.resolution || input.quality); const resolution = normalizeEnterpriseVideoResolution(input.resolution || input.quality);
const model = normalizeModel(input.model || input.requestedModel); const model = normalizeModel(input.model || input.requestedModel);
const rule = ENTERPRISE_VIDEO_PRICING_RULES.find((candidate) => if (model.includes("happyhorse")) {
enterpriseVideoPricingRuleMatches(candidate, input, model), return resolution === "720P" ? 0.72 : 1.28;
); }
if (rule) return rule.rates[resolution] ?? rule.rates[ENTERPRISE_VIDEO_DEFAULT_RESOLUTION];
if (model.includes("wan2.7-i2v") || model.includes("wanxiang")) {
return resolution === "720P" ? 0.6 : 1;
}
if (model.includes("animate-mix") || model.includes("s2v")) {
return resolution === "720P" ? 0.6 : 1;
}
if (model.includes("kling")) {
if (input.muted) {
if (input.hasReferenceVideo) return resolution === "720P" ? 0.9 : 1.2;
return resolution === "720P" ? 0.6 : 0.8;
}
return resolution === "720P" ? 0.9 : 1.2;
}
if (model.includes("vidu")) {
return resolution === "720P" ? 0.6 : 1.0;
}
if (model.includes("pixverse")) {
return resolution === "720P" ? 0.6 : 1.0;
}
const error = new Error(`Unsupported enterprise video model: ${input.model || input.requestedModel}`); const error = new Error(`Unsupported enterprise video model: ${input.model || input.requestedModel}`);
error.status = 403; error.status = 403;
@@ -142,25 +100,9 @@ function getEnterpriseVideoCreditRate(input) {
throw error; throw error;
} }
function getEnterpriseVideoPricingConfig() {
return {
currency: "CNY",
creditsPerCny: CREDITS_PER_CNY,
billingUnit: "per_second",
defaultResolution: ENTERPRISE_VIDEO_DEFAULT_RESOLUTION,
resolutions: [...ENTERPRISE_VIDEO_RESOLUTIONS],
rules: ENTERPRISE_VIDEO_PRICING_RULES.map((rule) => ({
id: rule.id,
modelIncludes: [...rule.modelIncludes],
when: rule.when ? { ...rule.when } : undefined,
rates: { ...rule.rates },
})),
};
}
function calculateEnterpriseVideoCredits(input) { function calculateEnterpriseVideoCredits(input) {
const duration = normalizeEnterpriseVideoDuration(input.durationSeconds || input.duration); const duration = normalizeEnterpriseVideoDuration(input.durationSeconds || input.duration);
return Number((getEnterpriseVideoCreditRate(input) * duration * CREDITS_PER_CNY).toFixed(2)); return Number((getEnterpriseVideoCreditRate(input) * duration).toFixed(2));
} }
function calculateEnterpriseVideoCost(input) { function calculateEnterpriseVideoCost(input) {
@@ -171,7 +113,7 @@ function calculateEnterpriseVideoCost(input) {
resolution, resolution,
durationSeconds, durationSeconds,
}); });
const rateCentsPerSecond = Math.round(rateCreditsPerSecond * CREDITS_PER_CNY * CREDIT_UNITS_PER_CREDIT); const rateCentsPerSecond = Math.round(rateCreditsPerSecond * 100);
return { return {
resolution, resolution,
durationSeconds, durationSeconds,
@@ -288,7 +230,6 @@ module.exports = {
assertEnterpriseVideoModelAllowed, assertEnterpriseVideoModelAllowed,
calculateEnterpriseVideoCost, calculateEnterpriseVideoCost,
calculateEnterpriseVideoCredits, calculateEnterpriseVideoCredits,
getEnterpriseVideoPricingConfig,
getEnterpriseVideoCreditRate, getEnterpriseVideoCreditRate,
isEnterpriseVideoBillingUser, isEnterpriseVideoBillingUser,
isEnterpriseVideoModelAllowed, isEnterpriseVideoModelAllowed,
+25 -50
View File
@@ -3,17 +3,8 @@ const express = require('express')
const rateLimit = require('express-rate-limit') const rateLimit = require('express-rate-limit')
const cors = require('cors') const cors = require('cors')
const helmet = require('helmet') const helmet = require('helmet')
const { startSettlementWorker, stopSettlementWorker } = require('./settlementWorker') const { startSettlementWorker } = require('./settlementWorker')
const { startProviderHealthMonitor, stopProviderHealthMonitor } = require('./providerHealthMonitor') const { startProviderHealthMonitor } = require('./providerHealthMonitor')
const {
startStaleTaskCleanup,
startTaskEventListener,
startPollerRecovery,
stopStaleTaskCleanup,
stopTaskEventListener,
stopPollerRecovery,
stopAllPollers,
} = require('./aiTaskWorker')
const { ensureDatabase } = require('./dbSetup') const { ensureDatabase } = require('./dbSetup')
const { assertRuntimeSecurityConfig } = require('./securityConfig') const { assertRuntimeSecurityConfig } = require('./securityConfig')
const { loadPriceCache } = require('./pricing') const { loadPriceCache } = require('./pricing')
@@ -26,7 +17,6 @@ const PORT = Number(process.env.PORT) || 3600
const HOST = process.env.HOST || '0.0.0.0' const HOST = process.env.HOST || '0.0.0.0'
const IS_PRODUCTION = process.env.NODE_ENV === 'production' const IS_PRODUCTION = process.env.NODE_ENV === 'production'
let server = null let server = null
let staleLeaseCleanupTimer = null
// CORS: in production, require explicit allowlist; in dev, allow all with credentials // CORS: in production, require explicit allowlist; in dev, allow all with credentials
function buildCorsOptions() { function buildCorsOptions() {
@@ -143,18 +133,18 @@ async function main() {
// Periodic stale lease cleanup (every 5 min) // Periodic stale lease cleanup (every 5 min)
const { cleanStaleLeases } = require('./keyManager') const { cleanStaleLeases } = require('./keyManager')
staleLeaseCleanupTimer = setInterval(() => { setInterval(() => {
cleanStaleLeases().then((cleaned) => { cleanStaleLeases().then((cleaned) => {
if (cleaned > 0) console.log(`[cleanup] Released ${cleaned} stale lease(s)`) if (cleaned > 0) console.log(`[cleanup] Released ${cleaned} stale lease(s)`)
}).catch((err) => { }).catch((err) => {
console.error('[cleanup] error:', err) console.error('[cleanup] error:', err)
}) })
}, 5 * 60 * 1000) }, 5 * 60 * 1000)
if (staleLeaseCleanupTimer.unref) staleLeaseCleanupTimer.unref()
startSettlementWorker() startSettlementWorker()
startProviderHealthMonitor() startProviderHealthMonitor()
const { startStaleTaskCleanup, startTaskEventListener, startPollerRecovery } = require('./aiTaskWorker')
await startTaskEventListener() await startTaskEventListener()
startPollerRecovery() startPollerRecovery()
startStaleTaskCleanup() startStaleTaskCleanup()
@@ -185,47 +175,32 @@ process.on('uncaughtException', (err) => {
// ── Graceful shutdown ─────────────────────────────────────────────────── // ── Graceful shutdown ───────────────────────────────────────────────────
let shuttingDown = false let shuttingDown = false
async function shutdownRuntimeState() { function gracefulShutdown(signal) {
if (staleLeaseCleanupTimer) {
clearInterval(staleLeaseCleanupTimer)
staleLeaseCleanupTimer = null
}
stopSettlementWorker()
stopProviderHealthMonitor()
stopPollerRecovery()
stopStaleTaskCleanup()
await Promise.allSettled([stopTaskEventListener(), stopAllPollers()])
}
function closeServer() {
if (!server || !server.listening) return Promise.resolve()
return new Promise((resolve) => {
server.close(() => {
console.log('[shutdown] Server closed, cleaning up...')
resolve()
})
})
}
async function gracefulShutdown(signal) {
if (shuttingDown) return if (shuttingDown) return
shuttingDown = true shuttingDown = true
console.log('[shutdown] Received ' + signal + ', draining connections...') console.log('[shutdown] Received ' + signal + ', draining connections...')
setTimeout(() => { if (server && server.listening) {
console.error('[shutdown] Forced exit after timeout') server.close(() => {
process.exit(1) console.log('[shutdown] Server closed, cleaning up...')
}, 15000).unref() const { stopProviderHealthMonitor } = require('./providerHealthMonitor')
stopProviderHealthMonitor()
const { stopTaskEventListener, stopPollerRecovery } = require('./aiTaskWorker')
stopPollerRecovery()
void stopTaskEventListener()
const { pool } = require('./db')
pool.end().then(() => {
console.log('[shutdown] Database pool closed')
process.exit(0)
}).catch(() => process.exit(0))
})
try { // Force exit after timeout
await shutdownRuntimeState() setTimeout(() => {
await closeServer() console.error('[shutdown] Forced exit after timeout')
const { pool } = require('./db') process.exit(1)
await pool.end() }, 15000).unref()
console.log('[shutdown] Database pool closed') } else {
process.exit(0)
} catch (err) {
console.error('[shutdown] error:', err)
process.exit(0) process.exit(0)
} }
} }
+3 -9
View File
@@ -284,7 +284,7 @@ async function releaseLeaseInternal(leaseToken, user, options = {}) {
const { rows } = await client.query( const { rows } = await client.query(
` `
WITH candidate AS ( WITH candidate AS (
SELECT l.id, l.key_id, l.user_id, l.enterprise_id, k.provider SELECT l.id, l.key_id, k.provider
FROM key_leases l FROM key_leases l
JOIN api_keys k ON k.id = l.key_id JOIN api_keys k ON k.id = l.key_id
WHERE l.lease_token = $1 AND l.released_at IS NULL WHERE l.lease_token = $1 AND l.released_at IS NULL
@@ -297,7 +297,7 @@ async function releaseLeaseInternal(leaseToken, user, options = {}) {
WHERE id = (SELECT id FROM candidate) WHERE id = (SELECT id FROM candidate)
RETURNING id, key_id RETURNING id, key_id
) )
SELECT r.id, r.key_id, c.user_id AS lease_user_id, c.enterprise_id AS lease_enterprise_id, c.provider SELECT r.id, r.key_id, c.provider
FROM released r FROM released r
JOIN candidate c ON c.key_id = r.key_id JOIN candidate c ON c.key_id = r.key_id
`, `,
@@ -339,13 +339,7 @@ async function releaseLeaseInternal(leaseToken, user, options = {}) {
INSERT INTO usage_logs (user_id, enterprise_id, provider, key_id, action) INSERT INTO usage_logs (user_id, enterprise_id, provider, key_id, action)
VALUES ($1, $2, (SELECT provider FROM api_keys WHERE id = $3), $4, $5) VALUES ($1, $2, (SELECT provider FROM api_keys WHERE id = $3), $4, $5)
`, `,
[ [userId, enterpriseId, lease.key_id, lease.key_id, "release"],
userId || lease.lease_user_id,
enterpriseId || lease.lease_enterprise_id,
lease.key_id,
lease.key_id,
"release",
],
); );
return { return {
+4 -17
View File
@@ -7,7 +7,6 @@
*/ */
const { pool } = require("./db"); const { pool } = require("./db");
const { recordProviderSuccess, recordProviderFailure, getAllBreakerStats } = require("./providerCircuitBreaker");
const CHECK_INTERVAL_MS = 5 * 60 * 1000; const CHECK_INTERVAL_MS = 5 * 60 * 1000;
const DASHSCOPE_TEST_MODEL = "qwen-max"; const DASHSCOPE_TEST_MODEL = "qwen-max";
@@ -22,15 +21,6 @@ const providerHealthCache = {
grsai: { status: "unknown", lastCheck: null, lastError: null, details: null }, grsai: { status: "unknown", lastCheck: null, lastError: null, details: null },
}; };
function recordProbeOutcome(provider, result, latencyMs) {
if (!provider) return;
if (result?.ok) {
recordProviderSuccess(provider, latencyMs);
} else {
recordProviderFailure(provider);
}
}
async function getDashScopeKey() { async function getDashScopeKey() {
const { rows } = await pool.query( const { rows } = await pool.query(
"SELECT id, api_key FROM api_keys WHERE provider LIKE '%dashscope%' AND enabled = 1 ORDER BY id LIMIT 1" "SELECT id, api_key FROM api_keys WHERE provider LIKE '%dashscope%' AND enabled = 1 ORDER BY id LIMIT 1"
@@ -130,10 +120,8 @@ async function runHealthCheck() {
// ── DashScope ── // ── DashScope ──
const dashKey = await getDashScopeKey(); const dashKey = await getDashScopeKey();
if (dashKey) { if (dashKey) {
const startedAt = Date.now();
try { try {
const result = await probeDashScope(dashKey); const result = await probeDashScope(dashKey);
recordProbeOutcome("dashscope", result, Date.now() - startedAt);
const prev = providerHealthCache.dashscope.status; const prev = providerHealthCache.dashscope.status;
providerHealthCache.dashscope = { providerHealthCache.dashscope = {
status: result.status, status: result.status,
@@ -156,7 +144,6 @@ async function runHealthCheck() {
} }
} }
} catch (err) { } catch (err) {
recordProviderFailure("dashscope");
providerHealthCache.dashscope = { providerHealthCache.dashscope = {
status: "timeout", status: "timeout",
lastCheck: new Date().toISOString(), lastCheck: new Date().toISOString(),
@@ -177,10 +164,8 @@ async function runHealthCheck() {
// ── GrsAI ── // ── GrsAI ──
const grsaiKey = await getGrsaiKey(); const grsaiKey = await getGrsaiKey();
if (grsaiKey) { if (grsaiKey) {
const startedAt = Date.now();
try { try {
const result = await probeGrsai(grsaiKey); const result = await probeGrsai(grsaiKey);
recordProbeOutcome("grsai", result, Date.now() - startedAt);
const prev = providerHealthCache.grsai.status; const prev = providerHealthCache.grsai.status;
providerHealthCache.grsai = { providerHealthCache.grsai = {
status: result.status, status: result.status,
@@ -201,7 +186,6 @@ async function runHealthCheck() {
} }
} }
} catch (err) { } catch (err) {
recordProviderFailure("grsai");
providerHealthCache.grsai = { providerHealthCache.grsai = {
status: "timeout", status: "timeout",
lastCheck: new Date().toISOString(), lastCheck: new Date().toISOString(),
@@ -220,7 +204,10 @@ async function runHealthCheck() {
} }
// ── Circuit breaker summary ── // ── Circuit breaker summary ──
providerHealthCache.circuitBreaker = getAllBreakerStats(); try {
const cb = require("./providerCircuitBreaker");
providerHealthCache.circuitBreaker = cb.getProviderStatusMap ? cb.getProviderStatusMap() : null;
} catch {}
// ── Admin low-balance alert ── // ── Admin low-balance alert ──
try { try {
-120
View File
@@ -1,120 +0,0 @@
"use strict";
const crypto = require("node:crypto");
const { pool } = require("./db");
const DEFAULT_MAX_CONCURRENCY = 8;
const DEFAULT_SLOT_TTL_MS = 30_000;
const POLL_SCOPE = "generation-provider-poll:global";
const OWNER_ID = `${process.pid}-${crypto.randomUUID()}`;
let storeReady = null;
function normalizePositiveInteger(value, fallback) {
const numeric = Number(value);
if (!Number.isFinite(numeric) || numeric <= 0) return fallback;
return Math.max(1, Math.trunc(numeric));
}
function getMaxConcurrency() {
return normalizePositiveInteger(process.env.TASK_PROVIDER_POLL_MAX_CONCURRENCY, DEFAULT_MAX_CONCURRENCY);
}
function getSlotTtlInterval() {
const ttlMs = normalizePositiveInteger(process.env.TASK_PROVIDER_POLL_SLOT_TTL_MS, DEFAULT_SLOT_TTL_MS);
return `${Math.max(1, Math.ceil(ttlMs / 1000))} seconds`;
}
async function ensureProviderPollLimiterStore() {
if (storeReady) return storeReady;
storeReady = pool.query(`
CREATE TABLE IF NOT EXISTS generation_provider_poll_slots (
scope TEXT NOT NULL,
slot_no INTEGER NOT NULL,
owner_id TEXT NOT NULL,
task_id INTEGER,
expires_at TIMESTAMPTZ NOT NULL,
acquired_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
PRIMARY KEY (scope, slot_no)
);
CREATE INDEX IF NOT EXISTS idx_generation_provider_poll_slots_expires
ON generation_provider_poll_slots(expires_at);
`).catch((err) => {
storeReady = null;
throw err;
});
return storeReady;
}
async function acquireProviderPollSlot(taskId, options = {}) {
await ensureProviderPollLimiterStore();
const scope = options.scope || POLL_SCOPE;
const maxConcurrency = normalizePositiveInteger(options.maxConcurrency, getMaxConcurrency());
const ttlInterval = options.ttlInterval || getSlotTtlInterval();
const { rows } = await pool.query(
`
WITH candidate AS (
SELECT s.slot_no
FROM generate_series(1, $2::integer) AS s(slot_no)
LEFT JOIN generation_provider_poll_slots l
ON l.scope = $1 AND l.slot_no = s.slot_no
WHERE l.scope IS NULL OR l.expires_at < NOW()
ORDER BY s.slot_no ASC
LIMIT 1
),
claimed AS (
INSERT INTO generation_provider_poll_slots (
scope, slot_no, owner_id, task_id, expires_at, acquired_at, updated_at
)
SELECT $1, slot_no, $3, $4, NOW() + ($5::text)::interval, NOW(), NOW()
FROM candidate
ON CONFLICT (scope, slot_no) DO UPDATE SET
owner_id = EXCLUDED.owner_id,
task_id = EXCLUDED.task_id,
expires_at = EXCLUDED.expires_at,
acquired_at = NOW(),
updated_at = NOW()
WHERE generation_provider_poll_slots.expires_at < NOW()
RETURNING scope, slot_no
)
SELECT scope, slot_no FROM claimed
`,
[scope, maxConcurrency, OWNER_ID, taskId || null, ttlInterval],
);
const slot = rows[0];
return slot ? { scope: slot.scope, slotNo: slot.slot_no, ownerId: OWNER_ID } : null;
}
async function releaseProviderPollSlot(slot) {
if (!slot?.scope || !slot?.slotNo) return;
await ensureProviderPollLimiterStore();
await pool.query(
"DELETE FROM generation_provider_poll_slots WHERE scope = $1 AND slot_no = $2 AND owner_id = $3",
[slot.scope, slot.slotNo, slot.ownerId || OWNER_ID],
);
}
async function withProviderPollSlot(taskId, fn, options = {}) {
const slot = await acquireProviderPollSlot(taskId, options);
if (!slot) return { acquired: false, value: undefined };
try {
return { acquired: true, value: await fn() };
} finally {
await releaseProviderPollSlot(slot).catch((err) => {
console.error(`[providerPollLimiter] failed to release poll slot ${slot.scope}:${slot.slotNo}:`, err.message);
});
}
}
module.exports = {
acquireProviderPollSlot,
ensureProviderPollLimiterStore,
getMaxConcurrency,
normalizePositiveInteger,
releaseProviderPollSlot,
withProviderPollSlot,
};
+16 -124
View File
@@ -4,7 +4,7 @@ const crypto = require("node:crypto");
const { requireAuth, keyManager, preauthorizeCall, pool, withTransaction, deductImageGenerationCredits } = require("./context"); const { requireAuth, keyManager, preauthorizeCall, pool, withTransaction, deductImageGenerationCredits } = require("./context");
const { putObject, isOssConfigured } = require("../ossClient"); const { putObject, isOssConfigured } = require("../ossClient");
const { buildImageProviderDebug, resolveImageProviderCandidates, resolveVideoProvider, resolveTextProvider, getPostUrl } = require("../aiProviderRouter"); const { buildImageProviderDebug, resolveImageProviderCandidates, resolveVideoProvider, resolveTextProvider, getPostUrl } = require("../aiProviderRouter");
const { shouldSkipProvider, recordProviderSuccess, recordProviderFailure, getAdaptiveTimeout } = require("../providerCircuitBreaker"); const { shouldSkipProvider, recordProviderSuccess, recordProviderFailure } = require("../providerCircuitBreaker");
const { const {
isEnterpriseVideoBillingUser, isEnterpriseVideoBillingUser,
markEnterpriseVideoCreditsAccepted, markEnterpriseVideoCreditsAccepted,
@@ -16,7 +16,6 @@ const {
} = require("../enterpriseVideoBilling"); } = require("../enterpriseVideoBilling");
const { const {
startPolling, startPolling,
cancelTaskRuntimeState,
updateTaskInDb, updateTaskInDb,
extractProviderTaskId, extractProviderTaskId,
extractImageUrl, extractImageUrl,
@@ -60,7 +59,6 @@ function toViapiAccessibleUrl(url) {
const SUPER_RESOLVE_POLL_INTERVAL_MS = 3000; const SUPER_RESOLVE_POLL_INTERVAL_MS = 3000;
const SUPER_RESOLVE_MAX_POLL_ATTEMPTS = 200; const SUPER_RESOLVE_MAX_POLL_ATTEMPTS = 200;
const IMAGE_PROVIDER_SUBMIT_TIMEOUT_MS = 90_000; const IMAGE_PROVIDER_SUBMIT_TIMEOUT_MS = 90_000;
const GRSAI_IMAGE_SUBMIT_TIMEOUT_MS = Number(process.env.GRSAI_IMAGE_SUBMIT_TIMEOUT_MS || 30_000);
const GEMINI_IMAGE_SUBMIT_TIMEOUT_MS = 180_000; const GEMINI_IMAGE_SUBMIT_TIMEOUT_MS = 180_000;
const DASHSCOPE_VIDEO_STYLE_ENDPOINT = "https://dashscope.aliyuncs.com/api/v1/services/aigc/video-generation/video-synthesis"; const DASHSCOPE_VIDEO_STYLE_ENDPOINT = "https://dashscope.aliyuncs.com/api/v1/services/aigc/video-generation/video-synthesis";
const DASHSCOPE_IMAGE_EDIT_ENDPOINT = "https://dashscope.aliyuncs.com/api/v1/services/aigc/image2image/image-synthesis"; const DASHSCOPE_IMAGE_EDIT_ENDPOINT = "https://dashscope.aliyuncs.com/api/v1/services/aigc/image2image/image-synthesis";
@@ -99,18 +97,6 @@ function clampImageQualityForModel(model, quality) {
return normalized; return normalized;
} }
function isDashscopeWan27Limited2KScene(params) {
const model = String(params?.model || "").toLowerCase();
if (model !== "wan2.7-image-pro") return false;
const hasReferenceImages = Array.isArray(params.referenceUrls) && params.referenceUrls.some(Boolean);
return hasReferenceImages || getGridCount(params.gridMode) > 1;
}
function resolveDashscopeImageQuality(params) {
const quality = clampImageQualityForModel(params.model, params.quality);
return isDashscopeWan27Limited2KScene(params) && quality === "4K" ? "2K" : quality;
}
function clampGrsaiImageQualityForModel(model, quality) { function clampGrsaiImageQualityForModel(model, quality) {
const normalized = normalizeQuality(quality, "1K"); const normalized = normalizeQuality(quality, "1K");
const maxQuality = GRSAI_IMAGE_MAX_QUALITY.get(String(model || "").toLowerCase()); const maxQuality = GRSAI_IMAGE_MAX_QUALITY.get(String(model || "").toLowerCase());
@@ -348,25 +334,18 @@ async function assertUserGenerationConcurrencyLimit(userId, client = pool) {
[userId], [userId],
); );
const { rows: limitRows } = await client.query(
"SELECT max_concurrency FROM users WHERE id = $1",
[userId],
);
const rawLimit = Number(limitRows[0]?.max_concurrency);
const concurrencyLimit = Number.isFinite(rawLimit) && rawLimit > 0 ? rawLimit : MAX_USER_ACTIVE_GENERATION_TASKS;
const { rows } = await client.query( const { rows } = await client.query(
"SELECT COUNT(*)::int AS active_count FROM generation_tasks WHERE user_id = $1 AND status IN ('pending', 'running')", "SELECT COUNT(*)::int AS active_count FROM generation_tasks WHERE user_id = $1 AND status IN ('pending', 'running')",
[userId], [userId],
); );
const activeCount = Number(rows[0]?.active_count ?? rows[0]?.count ?? 0); const activeCount = Number(rows[0]?.active_count ?? rows[0]?.count ?? 0);
if (activeCount < concurrencyLimit) return; if (activeCount < MAX_USER_ACTIVE_GENERATION_TASKS) return;
const error = new Error(`最多只能同时进行${concurrencyLimit}个任务`); const error = new Error(GENERATION_CONCURRENCY_LIMIT_MESSAGE);
error.status = 429; error.status = 429;
error.code = "GENERATION_CONCURRENCY_LIMIT"; error.code = "GENERATION_CONCURRENCY_LIMIT";
error.activeCount = activeCount; error.activeCount = activeCount;
error.maxActiveTasks = concurrencyLimit; error.maxActiveTasks = MAX_USER_ACTIVE_GENERATION_TASKS;
throw error; throw error;
} }
@@ -490,22 +469,17 @@ function buildDashscopeImageBody(params) {
if (url) content.push({ image: url }); if (url) content.push({ image: url });
} }
content.push({ text: params.prompt }); content.push({ text: params.prompt });
const quality = resolveDashscopeImageQuality(params); const quality = clampImageQualityForModel(params.model, params.quality);
const gridCount = getGridCount(params.gridMode);
const parameters = {
size: mapAspectRatioToDashscopeSize(params.ratio, quality),
n: gridCount,
watermark: false,
};
if (gridCount > 1) {
parameters.enable_sequential = true;
}
return { return {
model: params.model, model: params.model,
input: { input: {
messages: [{ role: "user", content }], messages: [{ role: "user", content }],
}, },
parameters, parameters: {
size: mapAspectRatioToDashscopeSize(params.ratio, quality),
n: params.gridMode === "grid-4" ? 4 : params.gridMode === "grid-9" ? 9 : 1,
watermark: false,
},
}; };
} }
@@ -1073,16 +1047,6 @@ function registerAiRoutes(router) {
error.costCents = billingResult.costCents; error.costCents = billingResult.costCents;
throw error; throw error;
} }
if (billingResult.costCents > 0) {
await client.query(
"UPDATE generation_tasks SET cost_cents = $1, billing_target = $2, billing_refunded = 0, updated_at = NOW() WHERE id = $3",
[
billingResult.costCents,
billingResult.deductionType === "enterprise_image_flat" ? "enterprise_image" : "user",
nextTaskRow.id,
],
);
}
return { taskRow: nextTaskRow, imageBilling: billingResult }; return { taskRow: nextTaskRow, imageBilling: billingResult };
}); });
const preauth = { authorized: true, estimatedCostCents: 0, billingMode: imageBilling.deductionType }; const preauth = { authorized: true, estimatedCostCents: 0, billingMode: imageBilling.deductionType };
@@ -1097,11 +1061,9 @@ function registerAiRoutes(router) {
}, },
providerDebug: buildImageProviderDebug(model), providerDebug: buildImageProviderDebug(model),
}); });
submitImageWithProviderFallback(taskRow.id, providerCandidates, req.user, preauth, params).catch(async (err) => { submitImageWithProviderFallback(taskRow.id, providerCandidates, req.user, preauth, params).catch((err) => {
console.error("[ai/image] submit error:", err.message); console.error("[ai/image] submit error:", err.message);
await updateTaskInDb(taskRow.id, { status: "failed", error: err.message }).catch((updateErr) => { updateTaskInDb(taskRow.id, { status: "failed", error: err.message });
console.error(`[ai/image] failed to persist task ${taskRow.id} failure:`, updateErr.message);
});
}); });
} catch (err) { } catch (err) {
console.error("[ai/image] error:", err.message); console.error("[ai/image] error:", err.message);
@@ -1213,10 +1175,6 @@ function registerAiRoutes(router) {
...enterpriseBilling, ...enterpriseBilling,
taskId: nextTaskRow.id, taskId: nextTaskRow.id,
}); });
await client.query(
"UPDATE generation_tasks SET cost_cents = $1, billing_target = 'enterprise_video', billing_refunded = 0, updated_at = NOW() WHERE id = $2",
[nextBilling.amountCents, nextTaskRow.id],
);
return { taskRow: nextTaskRow, reservedBilling: nextBilling, regularBilling: null }; return { taskRow: nextTaskRow, reservedBilling: nextBilling, regularBilling: null };
} }
// Regular user: deduct from personal balance // Regular user: deduct from personal balance
@@ -1239,10 +1197,6 @@ function registerAiRoutes(router) {
"INSERT INTO transactions (user_id, type, amount_cents, balance_after_cents, description) VALUES ($1, 'deduct', $2, $3, $4)", "INSERT INTO transactions (user_id, type, amount_cents, balance_after_cents, description) VALUES ($1, 'deduct', $2, $3, $4)",
[req.user.id, -costCents, deducted.balance_cents, `视频生成扣费 ${credits} 积分`], [req.user.id, -costCents, deducted.balance_cents, `视频生成扣费 ${credits} 积分`],
); );
await client.query(
"UPDATE generation_tasks SET cost_cents = $1, billing_target = 'user', billing_refunded = 0, updated_at = NOW() WHERE id = $2",
[costCents, nextTaskRow.id],
);
return { taskRow: nextTaskRow, reservedBilling: null, regularBilling: { costCents, balanceAfterCents: deducted.balance_cents, credits } }; return { taskRow: nextTaskRow, reservedBilling: null, regularBilling: { costCents, balanceAfterCents: deducted.balance_cents, credits } };
}); });
@@ -1764,35 +1718,6 @@ function registerAiRoutes(router) {
} }
}); });
const streamTaskStatusPoll = async (taskId, userId, emit) => {
const { rows } = await pool.query(
"SELECT * FROM generation_tasks WHERE id = $1 AND user_id = $2",
[taskId, userId],
);
const row = rows[0];
if (!row) return { found: false, terminal: true };
if (row.status === "pending" || row.status === "running") {
pool.query(
"UPDATE generation_tasks SET last_poll_at = NOW() WHERE id = $1",
[taskId],
).catch(() => {});
}
const event = {
taskId: row.id,
status: row.status,
progress: Number(row.progress || 0),
resultUrl: row.result_url || null,
error: row.error || null,
};
emit(event);
return {
found: true,
terminal: ["completed", "failed", "cancelled"].includes(row.status),
};
};
router.get("/ai/tasks/:taskId/stream", requireAuth, async (req, res) => { router.get("/ai/tasks/:taskId/stream", requireAuth, async (req, res) => {
const { taskId } = req.params; const { taskId } = req.params;
try { try {
@@ -1824,43 +1749,16 @@ function registerAiRoutes(router) {
return; return;
} }
let closed = false;
let lastSnapshot = JSON.stringify(initial);
let dbPollTimer = null;
const endStream = () => {
if (closed) return;
closed = true;
if (dbPollTimer) clearInterval(dbPollTimer);
taskEvents.off(`task:${taskId}`, onUpdate);
res.end();
};
const emitIfChanged = (evt) => {
if (closed) return;
const snapshot = JSON.stringify(evt);
if (snapshot === lastSnapshot) return;
lastSnapshot = snapshot;
res.write(`data: ${snapshot}\n\n`);
};
const onUpdate = (evt) => { const onUpdate = (evt) => {
emitIfChanged(evt); res.write(`data: ${JSON.stringify(evt)}\n\n`);
if (["completed", "failed", "cancelled"].includes(evt.status)) { if (["completed", "failed", "cancelled"].includes(evt.status)) {
endStream(); res.end();
} }
}; };
taskEvents.on(`task:${taskId}`, onUpdate); taskEvents.on(`task:${taskId}`, onUpdate);
dbPollTimer = setInterval(() => {
streamTaskStatusPoll(taskId, req.user.id, emitIfChanged)
.then((result) => {
if (!result.found || result.terminal) endStream();
})
.catch((pollErr) => {
console.error(`[ai/task-stream] db poll failed for task ${taskId}:`, pollErr.message);
});
}, 3000);
req.on("close", () => { req.on("close", () => {
endStream(); taskEvents.off(`task:${taskId}`, onUpdate);
}); });
} catch (err) { } catch (err) {
if (!res.headersSent) res.status(err.name === "AbortError" ? 504 : 500).json({ error: err.name === "AbortError" ? "AI 上游响应超时,请重试" : err.message }); if (!res.headersSent) res.status(err.name === "AbortError" ? 504 : 500).json({ error: err.name === "AbortError" ? "AI 上游响应超时,请重试" : err.message });
@@ -1877,7 +1775,6 @@ function registerAiRoutes(router) {
[taskId, req.user.id], [taskId, req.user.id],
); );
if (rows.length === 0) return res.status(404).json({ error: "Task not found or not in active state" }); if (rows.length === 0) return res.status(404).json({ error: "Task not found or not in active state" });
await cancelTaskRuntimeState(taskId, keyManager);
res.json({ id: rows[0].id, status: rows[0].status }); res.json({ id: rows[0].id, status: rows[0].status });
} catch (err) { } catch (err) {
console.error("[ai/task-cancel] error:", err.message); console.error("[ai/task-cancel] error:", err.message);
@@ -2036,12 +1933,7 @@ async function submitImageToProvider(taskDbId, providerConfig, slotResult, param
const { headers, body } = buildImageRequest(providerConfig, params, slotResult.apiKey); const { headers, body } = buildImageRequest(providerConfig, params, slotResult.apiKey);
await updateTaskInDb(taskDbId, { status: "running", progress: 10 }); await updateTaskInDb(taskDbId, { status: "running", progress: 10 });
const defaultSubmitTimeout = providerConfig.transport === "gemini-image" const submitTimeout = providerConfig.transport === "gemini-image" ? GEMINI_IMAGE_SUBMIT_TIMEOUT_MS : IMAGE_PROVIDER_SUBMIT_TIMEOUT_MS;
? GEMINI_IMAGE_SUBMIT_TIMEOUT_MS
: providerConfig.transport === "grsai-image"
? GRSAI_IMAGE_SUBMIT_TIMEOUT_MS
: IMAGE_PROVIDER_SUBMIT_TIMEOUT_MS;
const submitTimeout = getAdaptiveTimeout(providerConfig.provider, defaultSubmitTimeout);
const response = await fetchWithTimeout(url, { method: "POST", headers, body: JSON.stringify(body) }, submitTimeout); const response = await fetchWithTimeout(url, { method: "POST", headers, body: JSON.stringify(body) }, submitTimeout);
if (!response.ok) { if (!response.ok) {
const errText = await response.text().catch(() => "provider error"); const errText = await response.text().catch(() => "provider error");
+8 -139
View File
@@ -1,11 +1,10 @@
"use strict"; "use strict";
const { getUserContextById, requireAuth, verifyToken } = require("../auth"); const { getUserContextById, verifyToken } = require("../auth");
const { pool, withTransaction } = require("../db"); const { pool, withTransaction } = require("../db");
const { loadBetaInviteCodes, normalizeBetaInviteCode } = require("../betaInviteCodes"); const { loadBetaInviteCodes, normalizeBetaInviteCode } = require("../betaInviteCodes");
const REVIEW_USERNAMES = new Set(["xqy1912"]); const REVIEW_USERNAMES = new Set(["xqy1912"]);
const EMAIL_PATTERN = /^[^\s@]+@[^\s@]+\.[^\s@]+$/;
function cleanText(value, maxLength) { function cleanText(value, maxLength) {
return String(value || "").trim().slice(0, maxLength); return String(value || "").trim().slice(0, maxLength);
@@ -16,17 +15,6 @@ function cleanTextArray(value, maxItems = 20, maxLength = 200) {
return value.map((item) => cleanText(item, maxLength)).filter(Boolean).slice(0, maxItems); return value.map((item) => cleanText(item, maxLength)).filter(Boolean).slice(0, maxItems);
} }
function normalizeEmail(email) {
return String(email || "").trim().toLowerCase();
}
function validateEmail(email) {
const normalized = normalizeEmail(email);
if (!normalized) return "请填写用于接收内测码的邮箱";
if (!EMAIL_PATTERN.test(normalized)) return "邮箱格式不正确";
return null;
}
function parseJson(value, fallback) { function parseJson(value, fallback) {
if (!value || typeof value !== "string") return fallback; if (!value || typeof value !== "string") return fallback;
try { try {
@@ -44,27 +32,6 @@ function safeJsonString(value, fallback) {
} }
} }
function buildSmtpTransportOptions(scope) {
const prefix = scope ? `${scope}_` : "";
return {
host: process.env[`${prefix}SMTP_HOST`] || process.env.SMTP_HOST,
port: Number(process.env[`${prefix}SMTP_PORT`] || process.env.SMTP_PORT) || 587,
secure: String(process.env[`${prefix}SMTP_SECURE`] || process.env.SMTP_SECURE || "") === "1",
auth: {
user: process.env[`${prefix}SMTP_USER`] || process.env.SMTP_USER,
pass: process.env[`${prefix}SMTP_PASS`] || process.env.SMTP_PASS,
},
};
}
function formatEmailAddress(address, displayName) {
const email = String(address || "").trim();
const name = String(displayName || "").trim();
if (!name) return email;
const escapedName = name.replace(/"/g, '\\"');
return `"${escapedName}" <${email}>`;
}
function getRequestIp(req) { function getRequestIp(req) {
const forwardedFor = String(req.headers["x-forwarded-for"] || "").split(",")[0].trim(); const forwardedFor = String(req.headers["x-forwarded-for"] || "").split(",")[0].trim();
return forwardedFor || req.socket?.remoteAddress || ""; return forwardedFor || req.socket?.remoteAddress || "";
@@ -107,7 +74,6 @@ async function ensureBetaApplicationSchema() {
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
user_id INTEGER REFERENCES users(id) ON DELETE SET NULL, user_id INTEGER REFERENCES users(id) ON DELETE SET NULL,
name TEXT, name TEXT,
email TEXT,
phone TEXT, phone TEXT,
wechat TEXT, wechat TEXT,
industry TEXT, industry TEXT,
@@ -122,7 +88,6 @@ async function ensureBetaApplicationSchema() {
want_feature_json TEXT NOT NULL DEFAULT '[]', want_feature_json TEXT NOT NULL DEFAULT '[]',
self_statement TEXT, self_statement TEXT,
signature TEXT, signature TEXT,
application_date TEXT,
agree_rules INTEGER NOT NULL DEFAULT 0, agree_rules INTEGER NOT NULL DEFAULT 0,
status TEXT NOT NULL DEFAULT 'pending', status TEXT NOT NULL DEFAULT 'pending',
invite_code TEXT, invite_code TEXT,
@@ -138,19 +103,12 @@ async function ensureBetaApplicationSchema() {
ON beta_applications(status, created_at DESC); ON beta_applications(status, created_at DESC);
CREATE INDEX IF NOT EXISTS idx_beta_applications_user_created CREATE INDEX IF NOT EXISTS idx_beta_applications_user_created
ON beta_applications(user_id, created_at DESC); ON beta_applications(user_id, created_at DESC);
ALTER TABLE beta_applications
ADD COLUMN IF NOT EXISTS email TEXT;
ALTER TABLE beta_applications
ADD COLUMN IF NOT EXISTS application_date TEXT;
CREATE INDEX IF NOT EXISTS idx_beta_applications_email
ON beta_applications(LOWER(email));
`); `);
} }
function normalizeApplicationBody(body) { function normalizeApplicationBody(body) {
return { return {
name: cleanText(body?.name, 120), name: cleanText(body?.name, 120),
email: normalizeEmail(body?.email),
phone: cleanText(body?.phone, 60), phone: cleanText(body?.phone, 60),
wechat: cleanText(body?.wechat, 120), wechat: cleanText(body?.wechat, 120),
industry: cleanText(body?.industry, 160), industry: cleanText(body?.industry, 160),
@@ -165,7 +123,6 @@ function normalizeApplicationBody(body) {
wantFeature: cleanTextArray(body?.wantFeature ?? body?.want_feature), wantFeature: cleanTextArray(body?.wantFeature ?? body?.want_feature),
selfStatement: cleanText(body?.selfStatement ?? body?.self_statement, 5000), selfStatement: cleanText(body?.selfStatement ?? body?.self_statement, 5000),
signature: cleanText(body?.signature, 120), signature: cleanText(body?.signature, 120),
applicationDate: cleanText(body?.applicationDate ?? body?.application_date, 120),
agreeRules: body?.agreeRules === true || body?.agree_rules === true || body?.agreeRules === 1 || body?.agree_rules === 1, agreeRules: body?.agreeRules === true || body?.agree_rules === true || body?.agreeRules === 1 || body?.agree_rules === 1,
}; };
} }
@@ -176,7 +133,6 @@ function formatApplication(row) {
userId: row.user_id == null ? null : Number(row.user_id), userId: row.user_id == null ? null : Number(row.user_id),
username: row.username || null, username: row.username || null,
name: row.name || "", name: row.name || "",
email: row.email || "",
phone: row.phone || "", phone: row.phone || "",
wechat: row.wechat || "", wechat: row.wechat || "",
industry: row.industry || "", industry: row.industry || "",
@@ -191,7 +147,6 @@ function formatApplication(row) {
wantFeature: parseJson(row.want_feature_json, []), wantFeature: parseJson(row.want_feature_json, []),
selfStatement: row.self_statement || "", selfStatement: row.self_statement || "",
signature: row.signature || "", signature: row.signature || "",
applicationDate: row.application_date || "",
agreeRules: Boolean(row.agree_rules), agreeRules: Boolean(row.agree_rules),
status: row.status || "pending", status: row.status || "pending",
inviteCode: row.invite_code || null, inviteCode: row.invite_code || null,
@@ -263,112 +218,29 @@ async function createNotification(client, userId, input) {
); );
} }
function buildReviewEmailContent(application, action, inviteCode, reviewNote) {
const name = application.name || "内测申请人";
if (action === "approve") {
const text = [
`${name},您好:`,
"",
"您的 OmniAI 内测申请已通过。",
`内测码:${inviteCode}`,
"",
"请在注册页面填写该内测码完成账号注册。内测码仅限本人使用,请勿转发。",
"",
"OmniAI 团队",
].join("\n");
const html = `
<div style="font-family:Arial,'Microsoft YaHei',sans-serif;max-width:560px;margin:0 auto;padding:24px;color:#222">
<h2 style="margin:0 0 16px;color:#166534">OmniAI 内测申请已通过</h2>
<p>${name},您好:</p>
<p>您的 OmniAI 内测申请已通过。</p>
<p style="padding:14px 16px;background:#f0fdf4;border:1px solid #bbf7d0;border-radius:8px;font-size:20px;font-weight:700;letter-spacing:1px;color:#166534">内测码:${inviteCode}</p>
<p>请在注册页面填写该内测码完成账号注册。内测码仅限本人使用,请勿转发。</p>
<p style="margin-top:24px;color:#666">OmniAI 团队</p>
</div>
`;
return { subject: "[OmniAI] 内测申请已通过", text, html };
}
const reason = reviewNote || "很遗憾,您的内测申请暂未通过。";
const text = [
`${name},您好:`,
"",
"您未通过 OmniAI 内测申请。",
`审核备注:${reason}`,
"",
"感谢您的关注。",
"",
"OmniAI 团队",
].join("\n");
const html = `
<div style="font-family:Arial,'Microsoft YaHei',sans-serif;max-width:560px;margin:0 auto;padding:24px;color:#222">
<h2 style="margin:0 0 16px;color:#991b1b">OmniAI 内测申请未通过</h2>
<p>${name},您好:</p>
<p>您未通过 OmniAI 内测申请。</p>
<p style="padding:12px 14px;background:#fef2f2;border:1px solid #fecaca;border-radius:8px;color:#7f1d1d">审核备注:${reason}</p>
<p>感谢您的关注。</p>
<p style="margin-top:24px;color:#666">OmniAI 团队</p>
</div>
`;
return { subject: "[OmniAI] 内测申请未通过", text, html };
}
async function sendBetaApplicationReviewEmail(application, action, inviteCode, reviewNote) {
const email = normalizeEmail(application.email);
const emailError = validateEmail(email);
if (emailError) {
const err = new Error(`申请邮箱无效,无法发送审核结果:${emailError}`);
err.status = 409;
throw err;
}
const provider = String(process.env.EMAIL_PROVIDER || "mock").trim().toLowerCase();
const content = buildReviewEmailContent(application, action, inviteCode, reviewNote);
if (provider === "smtp") {
const nodemailer = require("nodemailer");
const smtpOptions = buildSmtpTransportOptions("BETA");
const transporter = nodemailer.createTransport(smtpOptions);
const fromAddress = process.env.BETA_SMTP_FROM || process.env.SMTP_FROM || smtpOptions.auth.user;
const fromName = process.env.BETA_SMTP_FROM_NAME || process.env.SMTP_FROM_NAME || "万物可爱";
await transporter.sendMail({
from: formatEmailAddress(fromAddress, fromName),
to: email,
subject: content.subject,
text: content.text,
html: content.html,
});
return { provider: "smtp" };
}
console.log(`[beta-application-email:${action}] ${email} ${content.subject}`);
return { provider: "mock" };
}
function registerBetaApplicationRoutes(router) { function registerBetaApplicationRoutes(router) {
router.post("/beta-applications", optionalAuth, async (req, res) => { router.post("/beta-applications", optionalAuth, async (req, res) => {
try { try {
await ensureBetaApplicationSchema(); await ensureBetaApplicationSchema();
const app = normalizeApplicationBody(req.body); const app = normalizeApplicationBody(req.body);
const emailError = validateEmail(app.email); if (!app.name || !app.phone || !app.wechat || !app.selfStatement || !app.signature || !app.agreeRules) {
if (!app.name || emailError || !app.phone || !app.wechat || !app.selfStatement || !app.signature || !app.applicationDate || !app.agreeRules) { return res.status(400).json({ error: "请填写姓名、手机号、微信、申请自述、签名并同意内测规则" });
return res.status(400).json({ error: emailError || "请填写姓名、手机号、微信、申请自述、签名、申请日期并同意内测规则" });
} }
const { rows } = await pool.query( const { rows } = await pool.query(
` `
INSERT INTO beta_applications ( INSERT INTO beta_applications (
user_id, name, email, phone, wechat, industry, company, city, user_id, name, phone, wechat, industry, company, city,
ai_tools, ai_duration, ai_track, ai_direction_json, ai_tools, ai_duration, ai_track, ai_direction_json,
weekly_usage, feedback_willing, want_feature_json, weekly_usage, feedback_willing, want_feature_json,
self_statement, signature, application_date, agree_rules, ip_address, user_agent self_statement, signature, agree_rules, ip_address, user_agent
) )
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19)
RETURNING id, status, created_at RETURNING id, status, created_at
`, `,
[ [
req.user?.id || null, req.user?.id || null,
app.name, app.name,
app.email,
app.phone, app.phone,
app.wechat, app.wechat,
app.industry || null, app.industry || null,
@@ -383,7 +255,6 @@ function registerBetaApplicationRoutes(router) {
safeJsonString(app.wantFeature, []), safeJsonString(app.wantFeature, []),
app.selfStatement, app.selfStatement,
app.signature, app.signature,
app.applicationDate,
app.agreeRules ? 1 : 0, app.agreeRules ? 1 : 0,
getRequestIp(req), getRequestIp(req),
cleanText(req.headers["user-agent"], 1000) || null, cleanText(req.headers["user-agent"], 1000) || null,
@@ -403,7 +274,7 @@ function registerBetaApplicationRoutes(router) {
} }
}); });
router.get("/admin/beta-applications", requireAuth, requireBetaApplicationReviewer, async (req, res) => { router.get("/admin/beta-applications", requireBetaApplicationReviewer, async (req, res) => {
try { try {
await ensureBetaApplicationSchema(); await ensureBetaApplicationSchema();
const status = cleanText(req.query.status, 32); const status = cleanText(req.query.status, 32);
@@ -434,7 +305,7 @@ function registerBetaApplicationRoutes(router) {
} }
}); });
router.patch("/admin/beta-applications/:id", requireAuth, requireBetaApplicationReviewer, async (req, res) => { router.patch("/admin/beta-applications/:id", requireBetaApplicationReviewer, async (req, res) => {
const id = Number(req.params.id); const id = Number(req.params.id);
const action = cleanText(req.body?.action, 32); const action = cleanText(req.body?.action, 32);
const reviewNote = cleanText(req.body?.reviewNote ?? req.body?.review_note, 1000) || null; const reviewNote = cleanText(req.body?.reviewNote ?? req.body?.review_note, 1000) || null;
@@ -482,8 +353,6 @@ function registerBetaApplicationRoutes(router) {
); );
const updated = rows[0]; const updated = rows[0];
await sendBetaApplicationReviewEmail(updated, action, inviteCode, reviewNote);
if (updated.user_id) { if (updated.user_id) {
if (action === "approve") { if (action === "approve") {
await createNotification(client, updated.user_id, { await createNotification(client, updated.user_id, {
-96
View File
@@ -1,96 +0,0 @@
const express = require("express");
const { requireAuth, requireAdmin } = require("../auth");
const { pool } = require("../db");
const { creditUserBalance } = require("../billing");
const router = express.Router();
router.post("/bug-feedback", requireAuth, async (req, res) => {
const userId = req.user.id;
const { title, description, screenshotUrl } = req.body;
if (!title || String(title).trim().length === 0) return res.status(400).json({ error: "标题不能为空" });
if (!description || String(description).trim().length === 0) return res.status(400).json({ error: "描述不能为空" });
if (String(title).length > 200) return res.status(400).json({ error: "标题不能超过200字" });
if (String(description).length > 5000) return res.status(400).json({ error: "描述不能超过5000字" });
try {
const result = await pool.query(
`INSERT INTO bug_feedback (user_id, title, description, screenshot_url) VALUES ($1, $2, $3, $4) RETURNING id, status, created_at`,
[userId, String(title).trim(), String(description).trim(), screenshotUrl || null]
);
res.json({ feedback: { id: result.rows[0].id, status: result.rows[0].status, createdAt: result.rows[0].created_at } });
} catch (err) {
console.error("[bug-feedback] submit failed:", err.message);
res.status(500).json({ error: "提交失败,请稍后重试" });
}
});
router.get("/bug-feedback/mine", requireAuth, async (req, res) => {
const userId = req.user.id;
try {
const result = await pool.query(
`SELECT id, title, description, screenshot_url, status, admin_note, created_at FROM bug_feedback WHERE user_id = $1 ORDER BY created_at DESC LIMIT 50`,
[userId]
);
res.json({ feedbacks: result.rows.map(r => ({ id: r.id, title: r.title, description: r.description, screenshotUrl: r.screenshot_url, status: r.status, adminNote: r.admin_note, createdAt: r.created_at })) });
} catch (err) {
console.error("[bug-feedback] list mine failed:", err.message);
res.status(500).json({ error: "获取反馈列表失败" });
}
});
router.get("/admin/bug-feedback", requireAuth, requireAdmin, async (req, res) => {
const status = req.query.status || null;
const limit = Math.min(Number(req.query.limit) || 20, 100);
const offset = Number(req.query.offset) || 0;
try {
const where = status ? "WHERE bf.status = $1" : "";
const params = status ? [status, limit, offset] : [limit, offset];
const countWhere = status ? "WHERE status = $1" : "";
const countParams = status ? [status] : [];
const [dataRes, countRes] = await Promise.all([
pool.query(`SELECT bf.id, bf.title, bf.description, bf.screenshot_url, bf.status, bf.admin_note, bf.reward_credited, bf.created_at, u.username FROM bug_feedback bf JOIN users u ON u.id = bf.user_id ${where} ORDER BY bf.created_at DESC LIMIT $${status ? 2 : 1} OFFSET $${status ? 3 : 2}`, params),
pool.query(`SELECT COUNT(*)::int AS total FROM bug_feedback ${countWhere}`, countParams),
]);
res.json({
feedbacks: dataRes.rows.map(r => ({ id: r.id, title: r.title, description: r.description, screenshotUrl: r.screenshot_url, status: r.status, adminNote: r.admin_note, rewardCredited: r.reward_credited, username: r.username, createdAt: r.created_at })),
total: countRes.rows[0].total,
});
} catch (err) {
console.error("[admin/bug-feedback] list failed:", err.message);
res.status(500).json({ error: "获取反馈列表失败" });
}
});
router.patch("/admin/bug-feedback/:id", requireAuth, requireAdmin, async (req, res) => {
const feedbackId = Number(req.params.id);
const { status, adminNote } = req.body;
if (!["approved", "rejected"].includes(status)) return res.status(400).json({ error: "状态只能是 approved 或 rejected" });
const client = await pool.connect();
try {
await client.query("BEGIN");
const existing = await client.query("SELECT id, user_id, status, reward_credited FROM bug_feedback WHERE id = $1 FOR UPDATE", [feedbackId]);
if (existing.rows.length === 0) { await client.query("ROLLBACK"); return res.status(404).json({ error: "反馈不存在" }); }
const row = existing.rows[0];
await client.query("UPDATE bug_feedback SET status = $1, admin_note = $2, updated_at = NOW() WHERE id = $3", [status, adminNote || null, feedbackId]);
let rewardCredited = row.reward_credited;
if (status === "approved" && !row.reward_credited) {
await creditUserBalance(row.user_id, 100, "Bug反馈奖励 1 积分");
await client.query("UPDATE bug_feedback SET reward_credited = TRUE WHERE id = $1", [feedbackId]);
rewardCredited = true;
}
await client.query("COMMIT");
res.json({ success: true, rewardCredited });
} catch (err) {
await client.query("ROLLBACK");
console.error("[admin/bug-feedback] patch failed:", err.message);
res.status(500).json({ error: "操作失败" });
} finally {
client.release();
}
});
module.exports = router;
+14 -109
View File
@@ -212,123 +212,28 @@ function hashEmailCode(email, code) {
return crypto.createHash("sha256").update(email + ":" + code + ":" + secret).digest("hex"); return crypto.createHash("sha256").update(email + ":" + code + ":" + secret).digest("hex");
} }
function buildSmtpTransportOptions(scope) {
const prefix = scope ? `${scope}_` : "";
return {
host: process.env[`${prefix}SMTP_HOST`] || process.env.SMTP_HOST,
port: Number(process.env[`${prefix}SMTP_PORT`] || process.env.SMTP_PORT) || 587,
secure: String(process.env[`${prefix}SMTP_SECURE`] || process.env.SMTP_SECURE || "") === "1",
auth: {
user: process.env[`${prefix}SMTP_USER`] || process.env.SMTP_USER,
pass: process.env[`${prefix}SMTP_PASS`] || process.env.SMTP_PASS,
},
};
}
function formatEmailAddress(address, displayName) {
const email = String(address || "").trim();
const name = String(displayName || "").trim();
if (!name) return email;
const escapedName = name.replace(/"/g, '\\"');
return `"${escapedName}" <${email}>`;
}
function escapeEmailHtml(value) {
return String(value || "")
.replace(/&/g, "&amp;")
.replace(/</g, "&lt;")
.replace(/>/g, "&gt;")
.replace(/"/g, "&quot;");
}
function buildEmailCodeContent(code, purpose) {
const purposeText = purpose === "register" ? "注册" : purpose === "login" ? "登录" : "重置密码";
const ttlText = String(EMAIL_CODE_TTL_MINUTES);
const safeCode = escapeEmailHtml(code);
const safePurposeText = escapeEmailHtml(purposeText);
const preheader = `您的 OmniAI ${purposeText}验证码是 ${code}${ttlText} 分钟内有效。`;
return {
subject: "[OmniAI] 邮箱验证码",
text:
`您的验证码是:${code}\n` +
`用途:${purposeText}\n` +
`有效期:${ttlText} 分钟\n` +
"请勿将验证码转发给他人。如非本人操作,请忽略此邮件。",
html: `<!doctype html>
<html lang="zh-CN">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>OmniAI 邮箱验证码</title>
</head>
<body style="margin:0;padding:0;background:#f4f7fb;color:#1f2937;font-family:-apple-system,BlinkMacSystemFont,'Segoe UI','PingFang SC','Microsoft YaHei',Arial,sans-serif;">
<div style="display:none;max-height:0;overflow:hidden;opacity:0;color:transparent;">${escapeEmailHtml(preheader)}</div>
<table role="presentation" width="100%" cellspacing="0" cellpadding="0" style="width:100%;background:#f4f7fb;margin:0;padding:28px 12px;">
<tr>
<td align="center">
<table role="presentation" width="100%" cellspacing="0" cellpadding="0" style="width:100%;max-width:560px;background:#ffffff;border-radius:16px;overflow:hidden;border:1px solid #e5ebf3;box-shadow:0 18px 45px rgba(31,41,55,0.08);">
<tr>
<td style="padding:28px 28px 20px;background:#101827;color:#ffffff;">
<div style="font-size:13px;letter-spacing:2px;text-transform:uppercase;color:#a7f3d0;font-weight:700;">OmniAI</div>
<h1 style="margin:10px 0 0;font-size:24px;line-height:1.35;font-weight:800;color:#ffffff;">万物可爱邮箱验证</h1>
<p style="margin:10px 0 0;font-size:14px;line-height:1.8;color:#cbd5e1;">请使用下方验证码完成${safePurposeText}操作。</p>
</td>
</tr>
<tr>
<td style="padding:28px;">
<div style="border:1px solid #dbe6f4;background:#f8fbff;border-radius:14px;padding:22px 18px;text-align:center;">
<div style="font-size:13px;color:#64748b;margin-bottom:10px;">验证码</div>
<div style="font-size:38px;line-height:1.2;letter-spacing:8px;font-weight:800;color:#0f766e;font-family:'SFMono-Regular',Consolas,'Liberation Mono',monospace;">${safeCode}</div>
<div style="font-size:13px;color:#64748b;margin-top:14px;">${ttlText} 分钟内有效</div>
</div>
<table role="presentation" width="100%" cellspacing="0" cellpadding="0" style="margin-top:22px;border-collapse:collapse;">
<tr>
<td style="padding:12px 0;border-bottom:1px solid #edf2f7;color:#64748b;font-size:14px;">用途</td>
<td align="right" style="padding:12px 0;border-bottom:1px solid #edf2f7;color:#111827;font-size:14px;font-weight:700;">${safePurposeText}</td>
</tr>
<tr>
<td style="padding:12px 0;border-bottom:1px solid #edf2f7;color:#64748b;font-size:14px;">有效期</td>
<td align="right" style="padding:12px 0;border-bottom:1px solid #edf2f7;color:#111827;font-size:14px;font-weight:700;">${ttlText} 分钟</td>
</tr>
</table>
<div style="margin-top:22px;padding:14px 16px;border-radius:12px;background:#fff7ed;border:1px solid #fed7aa;color:#9a3412;font-size:13px;line-height:1.8;">
请勿将验证码转发给他人。万物可爱工作人员不会向您索要邮箱验证码。
</div>
<p style="margin:22px 0 0;color:#64748b;font-size:13px;line-height:1.8;">如果不是您本人操作,可以直接忽略此邮件。</p>
</td>
</tr>
<tr>
<td style="padding:18px 28px;background:#f8fafc;border-top:1px solid #edf2f7;color:#94a3b8;font-size:12px;line-height:1.7;text-align:center;">
此邮件由系统自动发送,请勿直接回复。<br>OmniAI · 万物可爱
</td>
</tr>
</table>
</td>
</tr>
</table>
</body>
</html>`,
};
}
async function sendEmailCode(email, code, purpose) { async function sendEmailCode(email, code, purpose) {
const provider = String(process.env.EMAIL_PROVIDER || "mock").trim().toLowerCase(); const provider = String(process.env.EMAIL_PROVIDER || "mock").trim().toLowerCase();
if (provider === "smtp") { if (provider === "smtp") {
const nodemailer = require("nodemailer"); const nodemailer = require("nodemailer");
const smtpOptions = buildSmtpTransportOptions("SYSTEM"); const transporter = nodemailer.createTransport({
const transporter = nodemailer.createTransport(smtpOptions); host: process.env.SMTP_HOST,
const fromAddress = process.env.SYSTEM_SMTP_FROM || process.env.SMTP_FROM || smtpOptions.auth.user; port: Number(process.env.SMTP_PORT) || 587,
const fromName = process.env.SYSTEM_SMTP_FROM_NAME || process.env.SMTP_FROM_NAME || "万物可爱"; secure: process.env.SMTP_SECURE === "1",
auth: {
user: process.env.SMTP_USER,
pass: process.env.SMTP_PASS,
},
});
const content = buildEmailCodeContent(code, purpose); const purposeText = purpose === "register" ? "注册" : purpose === "login" ? "登录" : "重置密码";
await transporter.sendMail({ await transporter.sendMail({
from: formatEmailAddress(fromAddress, fromName), from: process.env.SMTP_FROM || process.env.SMTP_USER,
to: email, to: email,
subject: content.subject, subject: "[OmniAI] \u90ae\u7bb1\u9a8c\u8bc1\u7801",
text: content.text, text: "\u60a8\u7684\u9a8c\u8bc1\u7801\u662f\uff1a" + code + "\n\u7528\u9014\uff1a" + purposeText + "\n\u6709\u6548\u671f\uff1a" + String(process.env.EMAIL_CODE_TTL_MINUTES || 10) + " \u5206\u949f\n\u5982\u679c\u4e0d\u662f\u60a8\u672c\u4eba\u64cd\u4f5c\uff0c\u8bf7\u5ffd\u7565\u6b64\u90ae\u4ef6\u3002",
html: content.html, html: "<div style=\"font-family:sans-serif;max-width:480px;margin:0 auto;padding:24px\"><h2 style=\"color:#333\">OmniAI \u90ae\u7bb1\u9a8c\u8bc1</h2><p style=\"font-size:16px;color:#555\">\u60a8\u7684\u9a8c\u8bc1\u7801\u662f\uff1a</p><p style=\"font-size:32px;font-weight:bold;letter-spacing:6px;color:#1677ff;margin:16px 0\">" + code + "</p><p style=\"color:#888\">\u7528\u9014\uff1a" + purposeText + "</p><p style=\"color:#888\">\u6709\u6548\u671f\uff1a" + String(process.env.EMAIL_CODE_TTL_MINUTES || 10) + " \u5206\u949f</p><hr style=\"border:none;border-top:1px solid #eee;margin:24px 0\"><p style=\"color:#aaa;font-size:13px\">\u5982\u679c\u4e0d\u662f\u60a8\u672c\u4eba\u64cd\u4f5c\uff0c\u8bf7\u5ffd\u7565\u6b64\u90ae\u4ef6\u3002</p></div>",
}); });
return { provider: "smtp" }; return { provider: "smtp" };
} }
-2
View File
@@ -21,7 +21,6 @@ const { registerBetaApplicationRoutes } = require('./betaApplications')
const { registerDraftRoutes } = require('./drafts'); const { registerDraftRoutes } = require('./drafts');
const { registerFileExtractRoutes } = require('./fileExtract'); const { registerFileExtractRoutes } = require('./fileExtract');
const mountClientErrorRoutes = require('./clientErrors') const mountClientErrorRoutes = require('./clientErrors')
const bugFeedbackRouter = require("./bugFeedback")
const router = express.Router() const router = express.Router()
@@ -53,7 +52,6 @@ registerNotificationRoutes(router)
registerBetaApplicationRoutes(router) registerBetaApplicationRoutes(router)
registerDraftRoutes(router) registerDraftRoutes(router)
registerFileExtractRoutes(router) registerFileExtractRoutes(router)
router.use(bugFeedbackRouter)
registerHealthRoutes(router) registerHealthRoutes(router)
module.exports = router module.exports = router
+1 -6
View File
@@ -1,16 +1,11 @@
const { keyManager, listModelPrices, pool } = require("./context"); const { keyManager, listModelPrices, pool } = require("./context");
const { getEnterpriseVideoPricingConfig } = require("../enterpriseVideoBilling");
function registerPriceRoutes(router) { function registerPriceRoutes(router) {
// ── Public ─────────────────────────────────────────────────────────── // ── Public ───────────────────────────────────────────────────────────
router.get("/prices", async (_req, res) => { router.get("/prices", async (_req, res) => {
const prices = await listModelPrices({ enabledOnly: true }); const prices = await listModelPrices({ enabledOnly: true });
res.json({ res.json(prices);
prices,
modelPrices: prices,
enterpriseVideoPricing: getEnterpriseVideoPricingConfig(),
});
}); });
} }
+3 -3
View File
@@ -137,7 +137,7 @@ function registerUserRoutes(router) {
WHEN billing_refunded = 1 THEN 0 WHEN billing_refunded = 1 THEN 0
WHEN cost_cents > 0 THEN cost_cents WHEN cost_cents > 0 THEN cost_cents
WHEN status = 'completed' AND type = 'image' THEN 2000 WHEN status = 'completed' AND type = 'image' THEN 2000
WHEN status = 'completed' AND type = 'video' THEN 50000 WHEN status = 'completed' AND type = 'video' THEN 500
ELSE 0 ELSE 0
END END
), 0) AS used_cents ), 0) AS used_cents
@@ -172,7 +172,7 @@ function registerUserRoutes(router) {
else if (model.includes("wan2.7-i2v") || model.includes("wanxiang")) rate = res === "720P" ? 0.6 : 1; else if (model.includes("wan2.7-i2v") || model.includes("wanxiang")) rate = res === "720P" ? 0.6 : 1;
else if (model.includes("animate-mix") || model.includes("s2v")) rate = res === "720P" ? 0.6 : 1; else if (model.includes("animate-mix") || model.includes("s2v")) rate = res === "720P" ? 0.6 : 1;
else if (model.includes("kling")) rate = res === "720P" ? 0.6 : 0.8; else if (model.includes("kling")) rate = res === "720P" ? 0.6 : 0.8;
estimatedCents = Math.ceil(rate * dur * 10000); estimatedCents = Math.ceil(rate * dur * 100);
} }
} }
} catch { } catch {
@@ -210,7 +210,7 @@ function registerUserRoutes(router) {
WHEN billing_refunded = 1 THEN 0 WHEN billing_refunded = 1 THEN 0
WHEN cost_cents > 0 THEN cost_cents WHEN cost_cents > 0 THEN cost_cents
WHEN status = 'completed' AND type = 'image' THEN 2000 WHEN status = 'completed' AND type = 'image' THEN 2000
WHEN status = 'completed' AND type = 'video' THEN 50000 WHEN status = 'completed' AND type = 'video' THEN 500
ELSE 0 ELSE 0
END END
), 0) AS used_cents ), 0) AS used_cents