diff --git a/src/tools/collections.ts b/src/tools/collections.ts index e732d73..f94ef12 100644 --- a/src/tools/collections.ts +++ b/src/tools/collections.ts @@ -93,7 +93,7 @@ export const collectionsTools: BaseTool[] = [ countQuery += ` AND "teamId" = $1`; } const countResult = await pool.query(countQuery, countParams); - const totalCount = parseInt(countResult.rows[0].count); + const totalCount = parseInt(countResult.rows[0].count, 10); return { content: [ @@ -614,7 +614,7 @@ export const collectionsTools: BaseTool[] = [ // Get total count const countQuery = 'SELECT COUNT(*) FROM documents WHERE "collectionId" = $1 AND "deletedAt" IS NULL'; const countResult = await pool.query(countQuery, [collectionId]); - const totalCount = parseInt(countResult.rows[0].count); + const totalCount = parseInt(countResult.rows[0].count, 10); return { content: [ @@ -842,7 +842,7 @@ export const collectionsTools: BaseTool[] = [ // Get total count const countQuery = 'SELECT COUNT(*) FROM collection_users WHERE "collectionId" = $1'; const countResult = await pool.query(countQuery, [collectionId]); - const totalCount = parseInt(countResult.rows[0].count); + const totalCount = parseInt(countResult.rows[0].count, 10); return { content: [ @@ -1069,7 +1069,7 @@ export const collectionsTools: BaseTool[] = [ // Get total count const countQuery = 'SELECT COUNT(*) FROM collection_groups WHERE "collectionId" = $1'; const countResult = await pool.query(countQuery, [collectionId]); - const totalCount = parseInt(countResult.rows[0].count); + const totalCount = parseInt(countResult.rows[0].count, 10); return { content: [ diff --git a/src/tools/groups.ts b/src/tools/groups.ts index be2e725..bd7b4b6 100644 --- a/src/tools/groups.ts +++ b/src/tools/groups.ts @@ -80,7 +80,7 @@ const listGroups: BaseTool = { { data: { groups: result.rows, - total: result.rows.length > 0 ? parseInt(result.rows[0].total) : 0, + total: result.rows.length > 0 ? parseInt(result.rows[0].total, 10) : 0, limit, offset, }, diff --git a/src/tools/revisions.ts b/src/tools/revisions.ts index 1a9b421..906d80c 100644 --- a/src/tools/revisions.ts +++ b/src/tools/revisions.ts @@ -87,7 +87,7 @@ const listRevisions: BaseTool = { pagination: { limit, offset, - total: parseInt(countQuery.rows[0].total), + total: parseInt(countQuery.rows[0].total, 10), }, }, null, diff --git a/src/tools/users.ts b/src/tools/users.ts index 37793ef..5c8777e 100644 --- a/src/tools/users.ts +++ b/src/tools/users.ts @@ -109,7 +109,7 @@ const listUsers: BaseTool = { { data: { users: result.rows, - total: result.rows.length > 0 ? parseInt(result.rows[0].total) : 0, + total: result.rows.length > 0 ? parseInt(result.rows[0].total, 10) : 0, limit, offset, }, diff --git a/src/utils/transaction.ts b/src/utils/transaction.ts index 8b6095d..6374bbb 100644 --- a/src/utils/transaction.ts +++ b/src/utils/transaction.ts @@ -235,6 +235,30 @@ export async function withReadOnlyTransaction( } } +/** + * Validate and sanitize savepoint name to prevent SQL injection + * Only allows alphanumeric characters and underscores + */ +function sanitizeSavepointName(name: string): string { + // Remove any non-alphanumeric characters except underscore + const sanitized = name.replace(/[^a-zA-Z0-9_]/g, ''); + + if (sanitized.length === 0) { + throw new Error('Savepoint name must contain at least one alphanumeric character'); + } + + if (sanitized.length > 63) { + throw new Error('Savepoint name must be 63 characters or less'); + } + + // Ensure it doesn't start with a number + if (/^[0-9]/.test(sanitized)) { + return `sp_${sanitized}`; + } + + return sanitized; +} + /** * Savepoint helper for nested transaction-like behavior */ @@ -245,7 +269,7 @@ export class Savepoint { constructor(client: PoolClient, name: string) { this.client = client; - this.name = name; + this.name = sanitizeSavepointName(name); } /**