Initial commit: OmniAI backend server
This commit is contained in:
@@ -0,0 +1,793 @@
|
||||
const express = require("express");
|
||||
const bcrypt = require("bcryptjs");
|
||||
const {
|
||||
requireAuth,
|
||||
requireAdmin,
|
||||
requireEnterpriseAdmin,
|
||||
requireManagementAccess,
|
||||
login,
|
||||
generateToken,
|
||||
startUserSession,
|
||||
getUserContextById,
|
||||
isSystemAdmin,
|
||||
generateUniqueEnterpriseCode,
|
||||
} = require("../auth");
|
||||
const keyManager = require("../keyManager");
|
||||
const {
|
||||
calculateCost,
|
||||
calculateCostMills,
|
||||
listModelPrices,
|
||||
normalizeModelPriceRow,
|
||||
getAverageCostCents,
|
||||
loadPriceCache,
|
||||
} = require("../pricing");
|
||||
const {
|
||||
deductForApiCall,
|
||||
deductImageGenerationCredits,
|
||||
creditBalance,
|
||||
creditUserBalance,
|
||||
activatePackage,
|
||||
distributeCredits,
|
||||
getEnterpriseFinancials,
|
||||
getUserEnterpriseId,
|
||||
getEnterpriseName,
|
||||
preauthorizeCall,
|
||||
} = require("../billing");
|
||||
const wechatPay = require("../paymentWechat");
|
||||
const alipay = require("../paymentAlipay");
|
||||
const crypto = require("node:crypto");
|
||||
const { pool, withTransaction } = require("../db");
|
||||
const {
|
||||
computeNextRevision,
|
||||
normalizeRevisionValue,
|
||||
shouldRejectStaleRevision,
|
||||
} = require("../projectRevisionLogic");
|
||||
const { loadBetaInviteCodes } = require("../betaInviteCodes");
|
||||
|
||||
const USERNAME_PATTERN = /^[a-zA-Z0-9_\u4e00-\u9fa5]+$/;
|
||||
const PRICE_CATEGORIES = new Set(["text", "image", "video"]);
|
||||
const PRICE_TYPES = new Set(["token", "flat"]);
|
||||
const PHONE_PATTERN = /^\+?[0-9]{6,20}$/;
|
||||
const EMAIL_PATTERN = /^[^\s@]+@[^\s@]+\.[^\s@]+$/;
|
||||
const SMS_PURPOSES = new Set(["register", "login"]);
|
||||
const SMS_CODE_TTL_MINUTES = Math.max(1, Number(process.env.SMS_CODE_TTL_MINUTES) || 10);
|
||||
const SMS_CODE_COOLDOWN_SECONDS = Math.max(10, Number(process.env.SMS_CODE_COOLDOWN_SECONDS) || 60);
|
||||
const SMS_CODE_MAX_ATTEMPTS = Math.max(1, Number(process.env.SMS_CODE_MAX_ATTEMPTS) || 5);
|
||||
|
||||
function validateUsername(username) {
|
||||
if (!username) return "缺少用户名";
|
||||
if (username.length < 2 || username.length > 30) return "用户名长度必须在 2 到 30 之间";
|
||||
if (!USERNAME_PATTERN.test(username)) return "用户名只能包含字母、数字、下划线或中文";
|
||||
return null;
|
||||
}
|
||||
|
||||
function validatePassword(password) {
|
||||
if (!password) return "缺少密码";
|
||||
if (password.length < 6) return "密码至少 6 位";
|
||||
return null;
|
||||
}
|
||||
|
||||
function normalizePhone(phone) {
|
||||
return String(phone || "")
|
||||
.trim()
|
||||
.replace(/[\s-]/g, "");
|
||||
}
|
||||
|
||||
function validatePhone(phone) {
|
||||
const normalized = normalizePhone(phone);
|
||||
if (!normalized) return "缺少手机号";
|
||||
if (!PHONE_PATTERN.test(normalized)) return "手机号格式不正确";
|
||||
return null;
|
||||
}
|
||||
|
||||
function normalizeEmail(email) {
|
||||
return String(email || "").trim().toLowerCase();
|
||||
}
|
||||
|
||||
function validateEmail(email) {
|
||||
const normalized = normalizeEmail(email);
|
||||
if (!normalized) return "缺少邮箱";
|
||||
if (normalized.length > 200 || !EMAIL_PATTERN.test(normalized)) return "邮箱格式不正确";
|
||||
return null;
|
||||
}
|
||||
|
||||
function hashSmsCode(phone, code) {
|
||||
const secret = process.env.SMS_CODE_SECRET || process.env.JWT_SECRET || "omniai-dev-sms-secret";
|
||||
return crypto.createHash("sha256").update(`${phone}:${code}:${secret}`).digest("hex");
|
||||
}
|
||||
|
||||
function generateSmsCode() {
|
||||
return String(Math.floor(100000 + Math.random() * 900000));
|
||||
}
|
||||
|
||||
async function sendSmsCode(phone, code, purpose) {
|
||||
const provider = String(process.env.SMS_PROVIDER || "mock")
|
||||
.trim()
|
||||
.toLowerCase();
|
||||
if (provider === "http") {
|
||||
const endpoint = process.env.SMS_HTTP_ENDPOINT;
|
||||
if (!endpoint) throw new Error("SMS_HTTP_ENDPOINT 未配置");
|
||||
|
||||
const response = await fetch(endpoint, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
...(process.env.SMS_HTTP_TOKEN
|
||||
? { Authorization: `Bearer ${process.env.SMS_HTTP_TOKEN}` }
|
||||
: {}),
|
||||
},
|
||||
body: JSON.stringify({ phone, code, purpose }),
|
||||
});
|
||||
if (!response.ok) {
|
||||
throw new Error(`短信平台返回 HTTP ${response.status}`);
|
||||
}
|
||||
return { provider };
|
||||
}
|
||||
|
||||
console.log(`[sms:${purpose}] ${phone} verification sent (mock provider)`);
|
||||
return {
|
||||
provider: "mock",
|
||||
devCode: process.env.SMS_DEV_RETURN_CODE === "1" ? code : undefined,
|
||||
};
|
||||
}
|
||||
|
||||
async function createLoginResultForUserId(userId, req) {
|
||||
const user = await getUserContextById(userId);
|
||||
if (!user?.enabled) return null;
|
||||
const userAgent = req?.headers?.["user-agent"] || null;
|
||||
const sessionId = await startUserSession(user.id, userAgent);
|
||||
const userWithSession = {
|
||||
...user,
|
||||
sessionId,
|
||||
sessionStartedAt: new Date().toISOString(),
|
||||
};
|
||||
return {
|
||||
token: generateToken(userWithSession, sessionId),
|
||||
user: userWithSession,
|
||||
};
|
||||
}
|
||||
|
||||
function sanitizeUsernameSeed(seed, fallback) {
|
||||
const normalized = String(seed || "")
|
||||
.trim()
|
||||
.replace(/[^\w\u4e00-\u9fa5]/g, "_")
|
||||
.replace(/_+/g, "_")
|
||||
.replace(/^_+|_+$/g, "");
|
||||
const safe = normalized || fallback;
|
||||
return safe.length > 24 ? safe.slice(0, 24) : safe;
|
||||
}
|
||||
|
||||
async function generateUniqueUsername(seed, fallback) {
|
||||
const base = sanitizeUsernameSeed(seed, fallback);
|
||||
for (let attempt = 0; attempt < 10; attempt++) {
|
||||
const suffix = crypto.randomBytes(3).toString("hex");
|
||||
const username = `${base}_${suffix}`.slice(0, 30);
|
||||
const { rows } = await pool.query("SELECT 1 FROM users WHERE username = $1", [username]);
|
||||
if (rows.length === 0) return username;
|
||||
}
|
||||
return `${fallback}_${Date.now().toString(36)}`.slice(0, 30);
|
||||
}
|
||||
|
||||
async function consumeSmsCode(phone, code, purpose) {
|
||||
const { rows } = await pool.query(
|
||||
`
|
||||
SELECT id, code_hash, attempts
|
||||
FROM sms_verification_codes
|
||||
WHERE phone = $1
|
||||
AND purpose = $2
|
||||
AND consumed_at IS NULL
|
||||
AND expires_at > NOW()
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1
|
||||
`,
|
||||
[phone, purpose],
|
||||
);
|
||||
|
||||
const row = rows[0];
|
||||
if (!row) return false;
|
||||
|
||||
if (Number(row.attempts || 0) >= SMS_CODE_MAX_ATTEMPTS) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const expectedHash = hashSmsCode(phone, String(code || "").trim());
|
||||
if (row.code_hash !== expectedHash) {
|
||||
await pool.query("UPDATE sms_verification_codes SET attempts = attempts + 1 WHERE id = $1", [
|
||||
row.id,
|
||||
]);
|
||||
return false;
|
||||
}
|
||||
|
||||
await pool.query("UPDATE sms_verification_codes SET consumed_at = NOW() WHERE id = $1", [row.id]);
|
||||
return true;
|
||||
}
|
||||
|
||||
function getWechatLoginConfig() {
|
||||
const appId = process.env.WECHAT_LOGIN_APP_ID || process.env.WECHAT_APP_ID || "";
|
||||
const appSecret = process.env.WECHAT_LOGIN_APP_SECRET || process.env.WECHAT_APP_SECRET || "";
|
||||
const redirectUri = process.env.WECHAT_LOGIN_REDIRECT_URI || "";
|
||||
return { appId, appSecret, redirectUri };
|
||||
}
|
||||
|
||||
async function fetchWechatJson(url) {
|
||||
const response = await fetch(url);
|
||||
const payload = await response.json();
|
||||
if (!response.ok || payload.errcode) {
|
||||
throw new Error(payload.errmsg || `微信接口返回 HTTP ${response.status}`);
|
||||
}
|
||||
return payload;
|
||||
}
|
||||
|
||||
async function exchangeWechatCode(code) {
|
||||
const { appId, appSecret } = getWechatLoginConfig();
|
||||
if (!appId || !appSecret) {
|
||||
throw new Error("微信开放平台 AppID/AppSecret 未配置");
|
||||
}
|
||||
|
||||
const tokenUrl = new URL("https://api.weixin.qq.com/sns/oauth2/access_token");
|
||||
tokenUrl.searchParams.set("appid", appId);
|
||||
tokenUrl.searchParams.set("secret", appSecret);
|
||||
tokenUrl.searchParams.set("code", code);
|
||||
tokenUrl.searchParams.set("grant_type", "authorization_code");
|
||||
|
||||
const tokenPayload = await fetchWechatJson(tokenUrl.toString());
|
||||
const accessToken = tokenPayload.access_token;
|
||||
const openid = tokenPayload.openid;
|
||||
if (!accessToken || !openid) {
|
||||
throw new Error("微信登录未返回 openid");
|
||||
}
|
||||
|
||||
let profile = {};
|
||||
try {
|
||||
const userInfoUrl = new URL("https://api.weixin.qq.com/sns/userinfo");
|
||||
userInfoUrl.searchParams.set("access_token", accessToken);
|
||||
userInfoUrl.searchParams.set("openid", openid);
|
||||
userInfoUrl.searchParams.set("lang", "zh_CN");
|
||||
profile = await fetchWechatJson(userInfoUrl.toString());
|
||||
} catch (error) {
|
||||
console.warn(
|
||||
"[auth/wechat] userinfo failed",
|
||||
error instanceof Error ? error.message : String(error),
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
openid,
|
||||
unionid: profile.unionid || tokenPayload.unionid || null,
|
||||
nickname: profile.nickname || null,
|
||||
};
|
||||
}
|
||||
|
||||
async function findOrCreateWechatUser(wechatUser) {
|
||||
const { rows: existingRows } = await pool.query(
|
||||
"SELECT id, enabled FROM users WHERE wechat_openid = $1 LIMIT 1",
|
||||
[wechatUser.openid],
|
||||
);
|
||||
if (existingRows.length > 0) {
|
||||
if (!existingRows[0].enabled) {
|
||||
const error = new Error("账号已禁用");
|
||||
error.status = 403;
|
||||
throw error;
|
||||
}
|
||||
return existingRows[0].id;
|
||||
}
|
||||
|
||||
if (loadBetaInviteCodes().size > 0) {
|
||||
const error = new Error("内测阶段请先使用内测码注册账号后再使用微信登录");
|
||||
error.status = 403;
|
||||
throw error;
|
||||
}
|
||||
|
||||
const username = await generateUniqueUsername(
|
||||
wechatUser.nickname || `wx${wechatUser.openid.slice(-6)}`,
|
||||
"wechat",
|
||||
);
|
||||
const randomPasswordHash = await bcrypt.hash(crypto.randomBytes(32).toString("hex"), 10);
|
||||
const { rows } = await pool.query(
|
||||
`
|
||||
INSERT INTO users (username, password_hash, wechat_openid, wechat_unionid, auth_provider, role, max_concurrency, enterprise_id, is_enterprise_admin, balance_cents)
|
||||
VALUES ($1, $2, $3, $4, 'wechat', 'user', 30, null, 0, 0)
|
||||
RETURNING id
|
||||
`,
|
||||
[username, randomPasswordHash, wechatUser.openid, wechatUser.unionid],
|
||||
);
|
||||
return rows[0].id;
|
||||
}
|
||||
|
||||
function validateEnterpriseName(name) {
|
||||
if (!name) return "缺少企业名称";
|
||||
if (name.trim().length < 2 || name.trim().length > 80) return "企业名称长度必须在 2 到 80 之间";
|
||||
return null;
|
||||
}
|
||||
|
||||
function parseNumericValue(value, fieldLabel, { allowNull = true } = {}) {
|
||||
if (value === undefined) return { ok: true, value: undefined };
|
||||
if (value === null || value === "") {
|
||||
return allowNull ? { ok: true, value: null } : { ok: false, error: `${fieldLabel}不能为空` };
|
||||
}
|
||||
const numeric = Number(value);
|
||||
if (!Number.isFinite(numeric) || numeric < 0)
|
||||
return { ok: false, error: `${fieldLabel}必须是非负数字` };
|
||||
return { ok: true, value: numeric };
|
||||
}
|
||||
|
||||
async function ensureEnterpriseExists(enterpriseId) {
|
||||
if (enterpriseId == null) return null;
|
||||
const { rows } = await pool.query("SELECT id, name FROM enterprises WHERE id = $1", [
|
||||
enterpriseId,
|
||||
]);
|
||||
return rows[0] || null;
|
||||
}
|
||||
|
||||
function formatUserRow(row) {
|
||||
return {
|
||||
id: Number(row.id),
|
||||
username: row.username,
|
||||
role: row.role,
|
||||
avatarUrl: row.avatar_url || null,
|
||||
maxConcurrency: Number(row.max_concurrency || 0),
|
||||
enabled: !!row.enabled,
|
||||
enterpriseId: row.enterprise_id == null ? null : Number(row.enterprise_id),
|
||||
enterpriseName: row.enterprise_name || null,
|
||||
isEnterpriseAdmin: !!row.is_enterprise_admin,
|
||||
balanceCents: row.balance_cents != null ? Number(row.balance_cents) : 0,
|
||||
billingMode: row.billing_mode || "credits",
|
||||
betaExpiresAt: row.beta_expires_at || null,
|
||||
createdAt: row.created_at,
|
||||
};
|
||||
}
|
||||
|
||||
function normalizeOssRegion(region) {
|
||||
const trimmed = String(region || "").trim();
|
||||
return trimmed.startsWith("oss-") ? trimmed.slice(4) : trimmed;
|
||||
}
|
||||
|
||||
function buildOssPublicUrl(ossKey) {
|
||||
const publicBaseUrl = String(process.env.OSS_PUBLIC_BASE_URL || "")
|
||||
.trim()
|
||||
.replace(/\/+$/, "");
|
||||
if (publicBaseUrl) {
|
||||
return `${publicBaseUrl}/${ossKey}`;
|
||||
}
|
||||
|
||||
const bucket = String(process.env.OSS_BUCKET || "").trim();
|
||||
const region = normalizeOssRegion(process.env.OSS_REGION || "");
|
||||
if (!bucket || !region) {
|
||||
throw new Error("OSS bucket or region is not configured");
|
||||
}
|
||||
|
||||
return `https://${bucket}.oss-${region}.aliyuncs.com/${ossKey}`;
|
||||
}
|
||||
|
||||
function normalizeAvatarOssKey(value, userId) {
|
||||
if (value === undefined) return { value: undefined };
|
||||
if (value === null) return { value: null };
|
||||
|
||||
const safeUserId = String(userId).replace(/[^a-zA-Z0-9_-]/g, "");
|
||||
const ossKey = String(value || "")
|
||||
.trim()
|
||||
.replace(/^\/+/, "");
|
||||
if (!ossKey) return { value: null };
|
||||
|
||||
const expectedPrefix = `users/${safeUserId}/profile/avatar/`;
|
||||
const allowedPattern = new RegExp(
|
||||
`^users/${safeUserId}/profile/avatar/avatar\\.(jpg|jpeg|png|webp)$`,
|
||||
"i",
|
||||
);
|
||||
if (!ossKey.startsWith(expectedPrefix) || !allowedPattern.test(ossKey)) {
|
||||
return { error: "Invalid avatar OSS key" };
|
||||
}
|
||||
|
||||
return { value: ossKey };
|
||||
}
|
||||
|
||||
function normalizeProfileMediaUrl(value) {
|
||||
if (value === undefined) return { value: undefined };
|
||||
if (value === null || value === "") return { value: null };
|
||||
|
||||
const url = String(value || "").trim();
|
||||
if (!url) return { value: null };
|
||||
if (url.length > 2000) return { error: "资料图片地址过长" };
|
||||
if (url.startsWith("data:")) return { error: "资料图片请先上传到 OSS" };
|
||||
|
||||
try {
|
||||
const parsed = new URL(url);
|
||||
if (parsed.protocol !== "https:" && parsed.protocol !== "http:") {
|
||||
return { error: "资料图片地址格式不正确" };
|
||||
}
|
||||
} catch {
|
||||
return { error: "资料图片地址格式不正确" };
|
||||
}
|
||||
|
||||
return { value: url };
|
||||
}
|
||||
|
||||
function normalizeProjectOssKey(value, userId, projectId) {
|
||||
const safeUserId = String(userId).replace(/[^a-zA-Z0-9_-]/g, "");
|
||||
const safeProjectId = String(projectId || "")
|
||||
.trim()
|
||||
.replace(/[^a-zA-Z0-9_-]/g, "");
|
||||
const ossKey = String(value || "")
|
||||
.trim()
|
||||
.replace(/^\/+/, "");
|
||||
|
||||
if (!safeUserId || !safeProjectId || safeProjectId !== String(projectId || "").trim()) {
|
||||
return { error: "Invalid project OSS key scope" };
|
||||
}
|
||||
|
||||
const expectedKey = `users/${safeUserId}/projects/${safeProjectId}/current/project.json`;
|
||||
if (ossKey !== expectedKey) {
|
||||
return { error: "Invalid project OSS key scope" };
|
||||
}
|
||||
|
||||
return { value: ossKey };
|
||||
}
|
||||
|
||||
function getManagementEnterpriseId(user) {
|
||||
if (!user || isSystemAdmin(user)) return null;
|
||||
return user.enterpriseId || null;
|
||||
}
|
||||
|
||||
function appendEnterpriseScope(whereClauses, params, user, expression, paramIdx) {
|
||||
const enterpriseId = getManagementEnterpriseId(user);
|
||||
if (enterpriseId != null) {
|
||||
whereClauses.push(`${expression} = $${paramIdx}`);
|
||||
params.push(enterpriseId);
|
||||
return paramIdx + 1;
|
||||
}
|
||||
return paramIdx;
|
||||
}
|
||||
|
||||
function readModelPricePayload(body, existing = null) {
|
||||
const modelKey = String(body.modelKey ?? existing?.modelKey ?? "").trim();
|
||||
const displayName = String(body.displayName ?? existing?.displayName ?? "").trim();
|
||||
const category = String(body.category ?? existing?.category ?? "text").trim();
|
||||
const pricingType = String(body.pricingType ?? existing?.pricingType ?? "token").trim();
|
||||
const currency = String(body.currency ?? existing?.currency ?? "CNY").trim() || "CNY";
|
||||
const enabled = body.enabled === undefined ? (existing?.enabled ?? true) : !!body.enabled;
|
||||
|
||||
if (!modelKey) return { error: "缺少模型标识" };
|
||||
if (!displayName) return { error: "缺少显示名称" };
|
||||
if (!PRICE_CATEGORIES.has(category)) return { error: "模型分类无效" };
|
||||
if (!PRICE_TYPES.has(pricingType)) return { error: "计费类型无效" };
|
||||
|
||||
const inputPriceMills = parseNumericValue(body.inputPriceMills, "输入价格(厘)");
|
||||
if (!inputPriceMills.ok) return { error: inputPriceMills.error };
|
||||
const outputPriceMills = parseNumericValue(body.outputPriceMills, "输出价格(厘)");
|
||||
if (!outputPriceMills.ok) return { error: outputPriceMills.error };
|
||||
const flatPriceMills = parseNumericValue(body.flatPriceMills, "固定价格(厘)");
|
||||
if (!flatPriceMills.ok) return { error: flatPriceMills.error };
|
||||
|
||||
const merged = {
|
||||
modelKey,
|
||||
displayName,
|
||||
category,
|
||||
pricingType,
|
||||
currency,
|
||||
enabled,
|
||||
inputPriceMills:
|
||||
inputPriceMills.value !== undefined
|
||||
? inputPriceMills.value
|
||||
: (existing?.inputPriceMills ?? null),
|
||||
outputPriceMills:
|
||||
outputPriceMills.value !== undefined
|
||||
? outputPriceMills.value
|
||||
: (existing?.outputPriceMills ?? null),
|
||||
flatPriceMills:
|
||||
flatPriceMills.value !== undefined
|
||||
? flatPriceMills.value
|
||||
: (existing?.flatPriceMills ?? null),
|
||||
};
|
||||
|
||||
if (pricingType === "token") {
|
||||
if (merged.inputPriceMills == null || merged.outputPriceMills == null)
|
||||
return { error: "按 Token 计费时必须提供输入和输出价格(厘)" };
|
||||
merged.flatPriceMills = null;
|
||||
} else {
|
||||
if (merged.flatPriceMills == null) return { error: "固定计费时必须提供固定价格(厘)" };
|
||||
merged.inputPriceMills = null;
|
||||
merged.outputPriceMills = null;
|
||||
}
|
||||
|
||||
return { value: merged };
|
||||
}
|
||||
|
||||
async function getModelPriceById(id) {
|
||||
const { rows } = await pool.query("SELECT * FROM model_prices WHERE id = $1", [id]);
|
||||
return normalizeModelPriceRow(rows[0]);
|
||||
}
|
||||
|
||||
function getPeriodStart(period) {
|
||||
switch (period) {
|
||||
case "7d":
|
||||
return "NOW() - INTERVAL '7 days'";
|
||||
case "30d":
|
||||
return "NOW() - INTERVAL '30 days'";
|
||||
case "all":
|
||||
return null;
|
||||
default:
|
||||
return "NOW() - INTERVAL '7 days'";
|
||||
}
|
||||
}
|
||||
|
||||
// Fills a SQL day-aggregation result into a continuous 7-day series ending
|
||||
// today, padding missing days with zeros so the trend chart has no gaps.
|
||||
function buildDailyTrend(rows, days = 7) {
|
||||
const byDay = new Map();
|
||||
for (const row of rows || []) {
|
||||
byDay.set(String(row.day), {
|
||||
usedCents: Number(row.used_cents || 0),
|
||||
taskCount: Number(row.task_count || 0),
|
||||
});
|
||||
}
|
||||
const series = [];
|
||||
const today = new Date();
|
||||
for (let i = days - 1; i >= 0; i -= 1) {
|
||||
const d = new Date(today);
|
||||
d.setDate(today.getDate() - i);
|
||||
const key = d.toISOString().slice(0, 10);
|
||||
const hit = byDay.get(key) || { usedCents: 0, taskCount: 0 };
|
||||
series.push({ date: key, usedCents: hit.usedCents, taskCount: hit.taskCount });
|
||||
}
|
||||
return series;
|
||||
}
|
||||
|
||||
function clampPositiveInteger(value, fallback, max) {
|
||||
const numeric = Number(value);
|
||||
if (!Number.isFinite(numeric) || numeric <= 0) return fallback;
|
||||
return Math.min(Math.trunc(numeric), max);
|
||||
}
|
||||
|
||||
function clampNonNegativeInteger(value, fallback, max) {
|
||||
const numeric = Number(value);
|
||||
if (!Number.isFinite(numeric) || numeric < 0) return fallback;
|
||||
return Math.min(Math.trunc(numeric), max);
|
||||
}
|
||||
|
||||
function generateOrderNo() {
|
||||
const timestamp = Date.now().toString(36).toUpperCase();
|
||||
const random = crypto.randomBytes(4).toString("hex").toUpperCase();
|
||||
return `ORD${timestamp}${random}`;
|
||||
}
|
||||
|
||||
const GENERATION_TASK_STATUSES = new Set([
|
||||
"pending",
|
||||
"running",
|
||||
"completed",
|
||||
"failed",
|
||||
"cancelled",
|
||||
]);
|
||||
const GENERATION_TASK_TYPES = new Set(["image", "video"]);
|
||||
|
||||
function clampTaskProgress(value) {
|
||||
const numeric = Number(value);
|
||||
if (!Number.isFinite(numeric)) return 0;
|
||||
return Math.max(0, Math.min(100, Math.trunc(numeric)));
|
||||
}
|
||||
|
||||
function serializeTaskParams(value) {
|
||||
if (!value || typeof value !== "object") return "{}";
|
||||
return JSON.stringify(value);
|
||||
}
|
||||
|
||||
function parseTaskParams(value) {
|
||||
if (!value || typeof value !== "string") return {};
|
||||
try {
|
||||
return JSON.parse(value);
|
||||
} catch {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
function formatGenerationTaskRow(row) {
|
||||
return {
|
||||
id: Number(row.id),
|
||||
projectId: row.project_id,
|
||||
clientQueueId: row.client_queue_id,
|
||||
type: row.type,
|
||||
status: row.status,
|
||||
providerTaskId: row.provider_task_id || null,
|
||||
params: parseTaskParams(row.params_json),
|
||||
resultUrl: row.result_url || null,
|
||||
progress: Number(row.progress || 0),
|
||||
error: row.error || null,
|
||||
dedupeKey: row.dedupe_key || null,
|
||||
sourceDeviceId: row.source_device_id || null,
|
||||
createdAt: row.created_at,
|
||||
updatedAt: row.updated_at,
|
||||
completedAt: row.completed_at || null,
|
||||
};
|
||||
}
|
||||
|
||||
function normalizeGenerationTaskPayload(body) {
|
||||
const clientQueueId = String(body.clientQueueId || body.client_queue_id || "")
|
||||
.trim()
|
||||
.slice(0, 128);
|
||||
const type = String(body.type || "").trim();
|
||||
const status = String(body.status || "pending").trim();
|
||||
|
||||
if (!clientQueueId) return { error: "Missing clientQueueId" };
|
||||
if (!GENERATION_TASK_TYPES.has(type)) return { error: "Invalid task type" };
|
||||
if (!GENERATION_TASK_STATUSES.has(status)) return { error: "Invalid task status" };
|
||||
|
||||
return {
|
||||
value: {
|
||||
clientQueueId,
|
||||
type,
|
||||
status,
|
||||
providerTaskId: body.providerTaskId || body.provider_task_id || null,
|
||||
paramsJson: serializeTaskParams(body.params || body.paramsJson || body.params_json),
|
||||
resultUrl: body.resultUrl || body.result_url || null,
|
||||
progress: clampTaskProgress(body.progress),
|
||||
error: body.error || null,
|
||||
dedupeKey: body.dedupeKey || body.dedupe_key || null,
|
||||
sourceDeviceId: body.sourceDeviceId || body.source_device_id || null,
|
||||
createdAt: body.createdAt || body.created_at || null,
|
||||
completedAt: body.completedAt || body.completed_at || null,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
async function requireOwnedProject(client, userId, projectId) {
|
||||
const { rows } = await client.query("SELECT id FROM projects WHERE id = $1 AND user_id = $2", [
|
||||
projectId,
|
||||
userId,
|
||||
]);
|
||||
return rows.length > 0;
|
||||
}
|
||||
|
||||
async function upsertGenerationTask(client, userId, projectId, payload) {
|
||||
const {
|
||||
rows: [row],
|
||||
} = await client.query(
|
||||
`
|
||||
INSERT INTO generation_tasks (
|
||||
user_id,
|
||||
project_id,
|
||||
client_queue_id,
|
||||
type,
|
||||
status,
|
||||
provider_task_id,
|
||||
params_json,
|
||||
result_url,
|
||||
progress,
|
||||
error,
|
||||
dedupe_key,
|
||||
source_device_id,
|
||||
created_at,
|
||||
updated_at,
|
||||
completed_at
|
||||
)
|
||||
VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12,
|
||||
COALESCE($13::timestamptz, NOW()),
|
||||
NOW(),
|
||||
$14::timestamptz
|
||||
)
|
||||
ON CONFLICT (project_id, client_queue_id) WHERE project_id IS NOT NULL DO UPDATE SET
|
||||
type = EXCLUDED.type,
|
||||
status = EXCLUDED.status,
|
||||
provider_task_id = EXCLUDED.provider_task_id,
|
||||
params_json = EXCLUDED.params_json,
|
||||
result_url = EXCLUDED.result_url,
|
||||
progress = EXCLUDED.progress,
|
||||
error = EXCLUDED.error,
|
||||
dedupe_key = EXCLUDED.dedupe_key,
|
||||
source_device_id = EXCLUDED.source_device_id,
|
||||
updated_at = NOW(),
|
||||
completed_at = EXCLUDED.completed_at
|
||||
RETURNING *
|
||||
`,
|
||||
[
|
||||
userId,
|
||||
projectId,
|
||||
payload.clientQueueId,
|
||||
payload.type,
|
||||
payload.status,
|
||||
payload.providerTaskId,
|
||||
payload.paramsJson,
|
||||
payload.resultUrl,
|
||||
payload.progress,
|
||||
payload.error,
|
||||
payload.dedupeKey,
|
||||
payload.sourceDeviceId,
|
||||
payload.createdAt,
|
||||
payload.completedAt,
|
||||
],
|
||||
);
|
||||
|
||||
return row;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
express,
|
||||
bcrypt,
|
||||
requireAuth,
|
||||
requireAdmin,
|
||||
requireEnterpriseAdmin,
|
||||
requireManagementAccess,
|
||||
login,
|
||||
generateToken,
|
||||
startUserSession,
|
||||
getUserContextById,
|
||||
isSystemAdmin,
|
||||
generateUniqueEnterpriseCode,
|
||||
keyManager,
|
||||
calculateCost,
|
||||
calculateCostMills,
|
||||
listModelPrices,
|
||||
normalizeModelPriceRow,
|
||||
getAverageCostCents,
|
||||
loadPriceCache,
|
||||
deductForApiCall,
|
||||
deductImageGenerationCredits,
|
||||
creditBalance,
|
||||
creditUserBalance,
|
||||
activatePackage,
|
||||
distributeCredits,
|
||||
getEnterpriseFinancials,
|
||||
getUserEnterpriseId,
|
||||
getEnterpriseName,
|
||||
preauthorizeCall,
|
||||
wechatPay,
|
||||
alipay,
|
||||
crypto,
|
||||
pool,
|
||||
withTransaction,
|
||||
computeNextRevision,
|
||||
normalizeRevisionValue,
|
||||
shouldRejectStaleRevision,
|
||||
USERNAME_PATTERN,
|
||||
PRICE_CATEGORIES,
|
||||
PRICE_TYPES,
|
||||
PHONE_PATTERN,
|
||||
EMAIL_PATTERN,
|
||||
SMS_PURPOSES,
|
||||
SMS_CODE_TTL_MINUTES,
|
||||
SMS_CODE_COOLDOWN_SECONDS,
|
||||
SMS_CODE_MAX_ATTEMPTS,
|
||||
validateUsername,
|
||||
validatePassword,
|
||||
normalizePhone,
|
||||
validatePhone,
|
||||
normalizeEmail,
|
||||
validateEmail,
|
||||
hashSmsCode,
|
||||
generateSmsCode,
|
||||
sendSmsCode,
|
||||
createLoginResultForUserId,
|
||||
sanitizeUsernameSeed,
|
||||
generateUniqueUsername,
|
||||
consumeSmsCode,
|
||||
getWechatLoginConfig,
|
||||
fetchWechatJson,
|
||||
exchangeWechatCode,
|
||||
findOrCreateWechatUser,
|
||||
validateEnterpriseName,
|
||||
parseNumericValue,
|
||||
ensureEnterpriseExists,
|
||||
formatUserRow,
|
||||
normalizeOssRegion,
|
||||
buildOssPublicUrl,
|
||||
normalizeAvatarOssKey,
|
||||
normalizeProfileMediaUrl,
|
||||
normalizeProjectOssKey,
|
||||
getManagementEnterpriseId,
|
||||
appendEnterpriseScope,
|
||||
readModelPricePayload,
|
||||
getModelPriceById,
|
||||
getPeriodStart,
|
||||
buildDailyTrend,
|
||||
clampPositiveInteger,
|
||||
clampNonNegativeInteger,
|
||||
generateOrderNo,
|
||||
GENERATION_TASK_STATUSES,
|
||||
GENERATION_TASK_TYPES,
|
||||
clampTaskProgress,
|
||||
serializeTaskParams,
|
||||
parseTaskParams,
|
||||
formatGenerationTaskRow,
|
||||
normalizeGenerationTaskPayload,
|
||||
requireOwnedProject,
|
||||
upsertGenerationTask,
|
||||
};
|
||||
Reference in New Issue
Block a user