Files
omniai-server/src/keyManager.js
T

426 lines
11 KiB
JavaScript

const crypto = require("node:crypto");
const { pool, withTransaction } = require("./db");
// Leases still unreleased after this many minutes are treated as abandoned and reclaimed by cleanStaleLeases().
const STALE_LEASE_MINUTES = 30;
// How long a request waits in the queue for a free slot when the caller passes no (null/undefined) waitTimeoutMs.
const DEFAULT_WAIT_TIMEOUT_MS = 25_000;
// Upper bound applied to any caller-supplied wait timeout (5 minutes).
const MAX_WAIT_TIMEOUT_MS = 5 * 60 * 1000;
// provider -> array of pending waiter records; waiters are appended at the tail and served from the head.
const providerWaitQueues = new Map();
// provider -> bookkeeping for the drain pass currently running for that provider, so concurrent
// drain requests for the same provider share a single pass.
const providerDrainPromises = new Map();
/**
 * Coerces a caller-supplied wait timeout into whole milliseconds.
 *
 * @param {*} value - anything convertible with Number().
 * @returns {number} 0 when the value is non-numeric, non-finite or not positive (meaning "do not
 *   wait"); otherwise the truncated value, clamped to MAX_WAIT_TIMEOUT_MS.
 */
function normalizeWaitTimeoutMs(value) {
  const millis = Number(value);
  const isUsable = Number.isFinite(millis) && millis > 0;
  return isUsable ? Math.min(Math.trunc(millis), MAX_WAIT_TIMEOUT_MS) : 0;
}
/**
 * Returns the wait queue for `provider`, creating and registering an empty one on first use.
 *
 * @param {string} provider
 * @returns {Array<object>} the live queue array stored in providerWaitQueues.
 */
function getProviderWaitQueue(provider) {
  let queue = providerWaitQueues.get(provider);
  if (!queue) {
    queue = [];
    providerWaitQueues.set(provider, queue);
  }
  return queue;
}
/**
 * Drops settled waiters from `provider`'s queue and returns the waiters still pending.
 *
 * The stored queue is replaced by the filtered array. When nothing is pending, the provider's
 * entry is removed from providerWaitQueues entirely.
 *
 * @param {string} provider
 * @returns {Array<object>} pending waiters in queue order (empty array when none).
 */
function compactProviderWaitQueue(provider) {
  const stored = providerWaitQueues.get(provider) ?? [];
  const pending = stored.filter((waiter) => !waiter.settled);
  if (pending.length > 0) {
    providerWaitQueues.set(provider, pending);
    return pending;
  }
  providerWaitQueues.delete(provider);
  return [];
}
/**
 * Removes `waiter` from its provider's queue (if present) and compacts that queue.
 *
 * @param {{ provider: string }} waiter - waiter record created by enqueueAcquireWaiter.
 */
function removeWaiter(waiter) {
  const queue = providerWaitQueues.get(waiter.provider);
  if (!queue) {
    // No queue registered for this provider: nothing to remove or compact.
    return;
  }
  const position = queue.indexOf(waiter);
  if (position !== -1) {
    queue.splice(position, 1);
  }
  compactProviderWaitQueue(waiter.provider);
}
/**
 * Completes a waiter exactly once: cancels its timer, detaches its abort listener, removes it
 * from the queue and then settles its promise.
 *
 * @param {object} waiter - waiter record created by enqueueAcquireWaiter.
 * @param {Error|object|null} outcome - an Error rejects the waiter's promise; anything else
 *   (a lease, or null on timeout) resolves it.
 */
function settleWaiter(waiter, outcome) {
  if (waiter.settled) {
    // Already completed by a timeout, abort or drain; later outcomes are ignored.
    return;
  }
  waiter.settled = true;
  clearTimeout(waiter.timer);

  const { signal, abortHandler } = waiter;
  if (signal && abortHandler) {
    signal.removeEventListener("abort", abortHandler);
  }

  removeWaiter(waiter);

  if (outcome instanceof Error) {
    waiter.reject(outcome);
  } else {
    waiter.resolve(outcome);
  }
}
/**
 * Normalizes the `user` argument accepted by the public API into ids used for lease/usage rows.
 *
 * @param {object|*} user - either a user record (`{ id, enterpriseId, accountType }`) or a bare
 *   user id.
 * @returns {{ userId: *, enterpriseId: *, accountType: string }} for a record, its fields with
 *   null/undefined enterpriseId -> null and accountType -> "personal"; for anything else, the value
 *   itself as userId with no enterprise and a "personal" account.
 */
function normalizeUserContext(user) {
  const isRecord = Boolean(user) && typeof user === "object";
  return {
    userId: isRecord ? user.id : user,
    enterpriseId: isRecord ? (user.enterpriseId ?? null) : null,
    accountType: isRecord ? (user.accountType ?? "personal") : "personal",
  };
}
/**
 * Attempts to lease a slot on `provider`'s least-loaded enabled key without waiting.
 *
 * In one transaction this bumps the chosen key's active_count/total_used, inserts a key_leases
 * row carrying a fresh lease token and the estimated cost, and logs an "acquire" usage row.
 *
 * @param {string} provider
 * @param {object|*} user - user record or bare user id (see normalizeUserContext).
 * @param {{ estimatedCostCents?: number }} [preauthResult]
 * @returns {Promise<{leaseToken: string, apiKey: string, provider: string, keyLabel: string}|null>}
 *   the lease, or null when every enabled key is at max_concurrency (or locked by another
 *   transaction, since the candidate row is picked with SKIP LOCKED).
 */
async function tryAcquireKey(provider, user, preauthResult) {
  const context = normalizeUserContext(user);
  const token = crypto.randomUUID();
  const costCents = preauthResult?.estimatedCostCents || 0;

  const reserveSlotSql = `
WITH candidate AS (
SELECT id, api_key, label
FROM api_keys
WHERE provider = $1 AND enabled = 1 AND active_count < max_concurrency
ORDER BY active_count ASC, id ASC
FOR UPDATE SKIP LOCKED
LIMIT 1
),
updated AS (
UPDATE api_keys
SET active_count = active_count + 1,
total_used = total_used + 1
WHERE id = (SELECT id FROM candidate)
RETURNING id, api_key, label
)
SELECT id, api_key, label FROM updated
`;
  const insertLeaseSql = `
INSERT INTO key_leases (key_id, user_id, lease_token, estimated_cost_cents, enterprise_id)
VALUES ($1, $2, $3, $4, $5)
`;
  const insertUsageSql = `
INSERT INTO usage_logs (user_id, enterprise_id, provider, key_id, action)
VALUES ($1, $2, $3, $4, $5)
`;

  return withTransaction(async (client) => {
    const reservation = await client.query(reserveSlotSql, [provider]);
    const [slot] = reservation.rows;
    if (!slot) {
      return null;
    }

    await client.query(insertLeaseSql, [slot.id, context.userId, token, costCents, context.enterpriseId]);
    await client.query(insertUsageSql, [context.userId, context.enterpriseId, provider, slot.id, "acquire"]);

    // "pool-slot" rows carry no real credential; callers get an empty key for them.
    const apiKey = slot.api_key === "pool-slot" ? "" : slot.api_key;
    return {
      leaseToken: token,
      apiKey,
      provider,
      keyLabel: slot.label || `Key #${slot.id}`,
    };
  });
}
async function drainProviderWaitQueue(provider) {
const existingDrain = providerDrainPromises.get(provider);
if (existingDrain) {
return existingDrain;
}
const drainPromise = (async () => {
while (true) {
const queue = compactProviderWaitQueue(provider);
const nextWaiter = queue[0];
if (!nextWaiter) {
break;
}
let acquiredLease = null;
try {
acquiredLease = await tryAcquireKey(provider, nextWaiter.user, nextWaiter.preauthResult);
} catch (error) {
settleWaiter(nextWaiter, error instanceof Error ? error : new Error(String(error)));
continue;
}
if (!acquiredLease) {
break;
}
if (nextWaiter.settled || nextWaiter.signal?.aborted) {
await releaseLeaseInternal(acquiredLease.leaseToken, nextWaiter.user, { skipDrain: true });
continue;
}
settleWaiter(nextWaiter, acquiredLease);
}
})().finally(() => {
providerDrainPromises.delete(provider);
});
providerDrainPromises.set(provider, drainPromise);
await drainPromise;
}
/**
 * Queues a request for the next free slot on `provider`.
 *
 * @param {string} provider
 * @param {object|*} user - stored on the waiter and later passed to tryAcquireKey.
 * @param {object} [preauthResult] - stored on the waiter and later passed to tryAcquireKey.
 * @param {{ waitTimeoutMs?: number, signal?: AbortSignal }} [options] - waitTimeoutMs defaults to
 *   DEFAULT_WAIT_TIMEOUT_MS and is normalized by normalizeWaitTimeoutMs.
 * @returns {Promise<object|null>} resolves with a lease when a drain pass serves this waiter, with
 *   null immediately if the timeout normalizes to 0, or with null when the timeout elapses.
 *   Rejects with Error("Cancelled") if the signal is already aborted or aborts while waiting.
 */
function enqueueAcquireWaiter(provider, user, preauthResult, options = {}) {
  const requestedMs = options.waitTimeoutMs ?? DEFAULT_WAIT_TIMEOUT_MS;
  const timeoutMs = normalizeWaitTimeoutMs(requestedMs);
  if (timeoutMs <= 0) {
    // Waiting disabled: report "no key available" right away.
    return Promise.resolve(null);
  }

  return new Promise((resolve, reject) => {
    const { signal } = options;
    if (signal?.aborted) {
      reject(new Error("Cancelled"));
      return;
    }

    const waiter = {
      provider,
      user,
      preauthResult,
      resolve,
      reject,
      settled: false,
      signal,
      abortHandler: null,
      timer: null,
    };
    const onAbort = () => {
      settleWaiter(waiter, new Error("Cancelled"));
    };
    waiter.abortHandler = onAbort;
    if (signal) {
      signal.addEventListener("abort", onAbort, { once: true });
    }
    waiter.timer = setTimeout(() => {
      settleWaiter(waiter, null);
    }, timeoutMs);

    getProviderWaitQueue(provider).push(waiter);
  });
}
/**
 * Reclaims leases held longer than STALE_LEASE_MINUTES and returns their slots to the pool.
 *
 * Stale leases are marked released with one guarded `released_at IS NULL` UPDATE inside a
 * transaction, and api_keys.active_count is decremented only for the leases this call actually
 * released. A lease released concurrently, whether by releaseKey or by another cleanStaleLeases
 * running from a parallel acquireKey, is therefore never decremented twice. The previous
 * select-then-update version could double-decrement and let keys exceed max_concurrency.
 * Lease rows are all locked before any api_keys row, and api_keys rows in a fixed order, so
 * concurrent cleaners and single-lease releases take locks in a consistent order.
 * After commit, every affected provider's wait queue is drained so queued requests can use the
 * reclaimed capacity.
 *
 * @returns {Promise<number>} number of leases reclaimed by this call.
 */
async function cleanStaleLeases() {
  const cutoff = new Date(Date.now() - STALE_LEASE_MINUTES * 60 * 1000).toISOString();

  // Cheap read-only probe so the common "nothing is stale" case costs one query, not a transaction.
  const { rows: probe } = await pool.query(
    "SELECT 1 FROM key_leases WHERE released_at IS NULL AND leased_at < $1 LIMIT 1",
    [cutoff],
  );
  if (probe.length === 0) {
    return 0;
  }

  const { reclaimedCount, providers } = await withTransaction(async (client) => {
    const { rows: releasedLeases } = await client.query(
      `
      UPDATE key_leases
      SET released_at = NOW()
      WHERE released_at IS NULL AND leased_at < $1
      RETURNING key_id
      `,
      [cutoff],
    );

    const releasedPerKey = new Map();
    for (const { key_id: keyId } of releasedLeases) {
      releasedPerKey.set(keyId, (releasedPerKey.get(keyId) ?? 0) + 1);
    }

    const keyIds = [...releasedPerKey.keys()].sort((a, b) => (a < b ? -1 : a > b ? 1 : 0));
    const affectedProviders = new Set();
    for (const keyId of keyIds) {
      const { rows } = await client.query(
        "UPDATE api_keys SET active_count = GREATEST(0, active_count - $2) WHERE id = $1 RETURNING provider",
        [keyId, releasedPerKey.get(keyId)],
      );
      if (rows[0]?.provider) {
        affectedProviders.add(rows[0].provider);
      }
    }
    return { reclaimedCount: releasedLeases.length, providers: [...affectedProviders] };
  });

  await Promise.all(providers.map((provider) => drainProviderWaitQueue(provider)));
  return reclaimedCount;
}
/**
 * Leases a slot on one of `provider`'s API keys for `user`, queueing when every key is busy.
 *
 * Stale leases are reclaimed first. If no slot is free right away, the request waits in the
 * provider's queue for up to `options.waitTimeoutMs` (DEFAULT_WAIT_TIMEOUT_MS when null/undefined,
 * clamped to MAX_WAIT_TIMEOUT_MS).
 *
 * @param {string} provider
 * @param {object|*} user - user record ({ id, enterpriseId, accountType }) or a bare user id.
 * @param {{ estimatedCostCents?: number }} [preauthResult]
 * @param {{ waitTimeoutMs?: number, signal?: AbortSignal }} [options]
 * @returns {Promise<{leaseToken: string, apiKey: string, provider: string, keyLabel: string}|null>}
 *   the lease, or null when waiting is disabled or no slot became available before the timeout.
 * @throws {Error} "Cancelled" when `options.signal` is or becomes aborted while waiting; database
 *   errors from stale-lease cleanup or acquisition.
 */
async function acquireKey(provider, user, preauthResult, options = {}) {
  await cleanStaleLeases();
  const immediateLease = await tryAcquireKey(provider, user, preauthResult);
  if (immediateLease) {
    return immediateLease;
  }

  const queuedLease = enqueueAcquireWaiter(provider, user, preauthResult, options);
  // A release may have committed, and drained an empty queue, while tryAcquireKey above was in
  // flight. Without a fresh drain this waiter would sit next to an idle slot until the next
  // release or its timeout. Nobody awaits this drain, so failures are logged rather than lost.
  drainProviderWaitQueue(provider).catch((error) => {
    console.error(`[keyManager] wait-queue drain for ${provider} failed:`, error);
  });
  return queuedLease;
}
/**
 * Releases the lease identified by `leaseToken` and frees its slot on the owning API key.
 *
 * The lease row is locked with `FOR UPDATE OF l`. The previous `FOR UPDATE SKIP LOCKED` also
 * locked the joined api_keys row, so any concurrent acquire or release touching the same key made
 * this release silently skip a valid lease and leak its slot until stale cleanup. Waiting on the
 * lease row, instead of skipping it, also means a concurrent release of the same token is reported
 * as `alreadyReleased` once that transaction commits.
 *
 * @param {string} leaseToken - token returned by acquireKey().
 * @param {object|*} user - user record or bare user id (see normalizeUserContext). Used for the
 *   usage log, falling back to the lease owner's ids when it yields none.
 * @param {{ skipDrain?: boolean }} [options] - skipDrain: do not hand the freed slot to queued
 *   waiters (used by the drain loop itself).
 * @returns {Promise<{released: boolean, notFound: boolean, alreadyReleased: boolean, provider: string|null}>}
 */
async function releaseLeaseInternal(leaseToken, user, options = {}) {
  const { userId, enterpriseId } = normalizeUserContext(user);
  const releaseResult = await withTransaction(async (client) => {
    const { rows } = await client.query(
      `
      WITH candidate AS (
        SELECT l.id, l.key_id, l.user_id, l.enterprise_id, k.provider
        FROM key_leases l
        JOIN api_keys k ON k.id = l.key_id
        WHERE l.lease_token = $1 AND l.released_at IS NULL
        LIMIT 1
        FOR UPDATE OF l
      ),
      released AS (
        UPDATE key_leases
        SET released_at = NOW()
        WHERE id = (SELECT id FROM candidate)
        RETURNING id, key_id
      )
      SELECT r.id, r.key_id, c.user_id AS lease_user_id, c.enterprise_id AS lease_enterprise_id, c.provider
      FROM released r
      JOIN candidate c ON c.id = r.id
      `,
      [leaseToken],
    );
    const lease = rows[0];

    if (!lease) {
      // Nothing released: distinguish an unknown token from one that is already released.
      const { rows: existingRows } = await client.query(
        `
        SELECT l.key_id, l.released_at, k.provider
        FROM key_leases l
        JOIN api_keys k ON k.id = l.key_id
        WHERE l.lease_token = $1
        LIMIT 1
        `,
        [leaseToken],
      );
      const existing = existingRows[0];
      if (!existing) {
        return { released: false, notFound: true, alreadyReleased: false, provider: null };
      }
      return {
        released: false,
        notFound: false,
        alreadyReleased: Boolean(existing.released_at),
        provider: existing.provider || null,
      };
    }

    await client.query(
      "UPDATE api_keys SET active_count = GREATEST(0, active_count - 1) WHERE id = $1",
      [lease.key_id],
    );
    await client.query(
      `
      INSERT INTO usage_logs (user_id, enterprise_id, provider, key_id, action)
      VALUES ($1, $2, $3, $4, $5)
      `,
      [
        userId || lease.lease_user_id,
        enterpriseId || lease.lease_enterprise_id,
        lease.provider,
        lease.key_id,
        "release",
      ],
    );
    return {
      released: true,
      notFound: false,
      alreadyReleased: false,
      provider: lease.provider || null,
    };
  });

  if (releaseResult.released && releaseResult.provider && !options.skipDrain) {
    await drainProviderWaitQueue(releaseResult.provider);
  }
  return releaseResult;
}
/**
 * Public entry point for releasing a lease; hands any freed slot to queued waiters.
 *
 * @param {string} leaseToken - token returned by acquireKey().
 * @param {object|*} user - user record or bare user id performing the release.
 * @returns {Promise<{released: boolean, notFound: boolean, alreadyReleased: boolean, provider: string|null}>}
 */
async function releaseKey(leaseToken, user) {
  const outcome = await releaseLeaseInternal(leaseToken, user);
  return outcome;
}
/**
 * Number of still-pending waiters for `provider` (compacts the queue as a side effect).
 *
 * @param {string} provider
 * @returns {number}
 */
function getQueuedWaiterCount(provider) {
  const pendingWaiters = compactProviderWaitQueue(provider);
  return pendingWaiters.length;
}
/**
 * Summarizes the keys configured for `provider`.
 *
 * `totalCapacity` counts enabled keys only and `totalActive` counts leases on every key. That
 * matches the previous output. `available` is now the number of slots tryAcquireKey could
 * actually hand out: the sum over enabled keys of `max(0, max_concurrency - active_count)`.
 * The previous `totalCapacity - totalActive` under-reported, and could go negative, when a
 * disabled key still held leases or one key was over its limit, hiding free slots on other keys.
 *
 * @param {string} provider
 * @returns {Promise<{provider: string, keys: Array<object>, totalActive: number,
 *   totalCapacity: number, queuedCount: number, available: number}>}
 */
async function getKeyStatus(provider) {
  const { rows: keys } = await pool.query(
    `
    SELECT id, provider, label, max_concurrency, active_count, total_used, enabled
    FROM api_keys
    WHERE provider = $1
    ORDER BY id
    `,
    [provider],
  );

  let totalCapacity = 0;
  let totalActive = 0;
  let available = 0;
  for (const key of keys) {
    totalActive += key.active_count;
    if (!key.enabled) {
      continue;
    }
    totalCapacity += key.max_concurrency;
    available += Math.max(0, key.max_concurrency - key.active_count);
  }

  return {
    provider,
    keys: keys.map((key) => ({
      id: key.id,
      label: key.label || `Key #${key.id}`,
      active: key.active_count,
      capacity: key.max_concurrency,
      totalUsed: key.total_used,
      enabled: !!key.enabled,
    })),
    totalActive,
    totalCapacity,
    queuedCount: getQueuedWaiterCount(provider),
    available,
  };
}
/**
 * Returns getKeyStatus() for every provider that has at least one enabled key.
 *
 * The per-provider lookups are independent, so they run concurrently. Promise.all keeps the
 * results in the order the providers were returned by the query.
 *
 * @returns {Promise<Array<object>>}
 */
async function getAllStatus() {
  const { rows: providers } = await pool.query(
    "SELECT DISTINCT provider FROM api_keys WHERE enabled = 1",
  );
  return Promise.all(providers.map(({ provider }) => getKeyStatus(provider)));
}
// Public API of the key manager: lease acquisition/release, status reporting and stale-lease
// cleanup. The wait-queue helpers, the drain loop and releaseLeaseInternal stay module-private;
// releaseLeaseInternal is reachable only through releaseKey.
module.exports = {
acquireKey,
releaseKey,
getKeyStatus,
getAllStatus,
cleanStaleLeases,
};