diff --git a/README.md b/README.md index 7e439982..0c1aa31a 100644 --- a/README.md +++ b/README.md @@ -137,7 +137,7 @@ a local testing environment, as shown in the following steps: use the local testing settings specified in `tests/config.json`, instead of the CI settings 3. Run the tests manually by using the command\ - `deno test --unstable -A` + `deno test -A` ## Deno compatibility diff --git a/client.ts b/client.ts index 96d01780..91e491ff 100644 --- a/client.ts +++ b/client.ts @@ -67,9 +67,7 @@ export abstract class QueryClient { #assertOpenConnection() { if (this.#terminated) { - throw new Error( - "Connection to the database has been terminated", - ); + throw new Error("Connection to the database has been terminated"); } } @@ -243,9 +241,7 @@ export abstract class QueryClient { async #executeQuery( _query: Query, ): Promise>; - async #executeQuery( - query: Query, - ): Promise { + async #executeQuery(query: Query): Promise { return await this.#connection.query(query); } @@ -397,9 +393,7 @@ export abstract class QueryClient { query: TemplateStringsArray, ...args: unknown[] ): Promise>; - async queryObject< - T = Record, - >( + async queryObject>( query_template_or_config: | string | QueryObjectOptions diff --git a/client/error.ts b/client/error.ts index 70d3786c..60d0f917 100644 --- a/client/error.ts +++ b/client/error.ts @@ -25,14 +25,8 @@ export class PostgresError extends Error { } export class TransactionError extends Error { - constructor( - transaction_name: string, - cause: PostgresError, - ) { - super( - `The transaction "${transaction_name}" has been aborted`, - { cause }, - ); + constructor(transaction_name: string, cause: PostgresError) { + super(`The transaction "${transaction_name}" has been aborted`, { cause }); this.name = "TransactionError"; } } diff --git a/connection/connection.ts b/connection/connection.ts index 1764a25b..9ce6f65a 100644 --- a/connection/connection.ts +++ b/connection/connection.ts @@ -86,9 +86,7 @@ function assertSuccessfulAuthentication(auth_message: Message) { throw new PostgresError(parseNoticeMessage(auth_message)); } - if ( - auth_message.type !== INCOMING_AUTHENTICATION_MESSAGES.AUTHENTICATION - ) { + if (auth_message.type !== INCOMING_AUTHENTICATION_MESSAGES.AUTHENTICATION) { throw new Error(`Unexpected auth response: ${auth_message.type}.`); } @@ -118,10 +116,7 @@ export class Connection { #onDisconnection: () => Promise; #packetWriter = new PacketWriter(); #pid?: number; - #queryLock: DeferredStack = new DeferredStack( - 1, - [undefined], - ); + #queryLock: DeferredStack = new DeferredStack(1, [undefined]); // TODO // Find out what the secret key is for #secretKey?: number; @@ -180,10 +175,7 @@ export class Connection { async #serverAcceptsTLS(): Promise { const writer = this.#packetWriter; writer.clear(); - writer - .addInt32(8) - .addInt32(80877103) - .join(); + writer.addInt32(8).addInt32(80877103).join(); await this.#bufWriter.write(writer.flush()); await this.#bufWriter.flush(); @@ -216,16 +208,20 @@ export class Connection { // TODO: recognize other parameters writer.addCString("user").addCString(this.#connection_params.user); writer.addCString("database").addCString(this.#connection_params.database); - writer.addCString("application_name").addCString( - this.#connection_params.applicationName, - ); + writer + .addCString("application_name") + .addCString(this.#connection_params.applicationName); const connection_options = Object.entries(this.#connection_params.options); if (connection_options.length > 0) { // The database expects options in the --key=value - writer.addCString("options").addCString( - connection_options.map(([key, value]) => `--${key}=${value}`).join(" "), - ); + writer + .addCString("options") + .addCString( + connection_options + .map(([key, value]) => `--${key}=${value}`) + .join(" "), + ); } // terminator after all parameters were writter @@ -236,10 +232,7 @@ export class Connection { writer.clear(); - const finalBuffer = writer - .addInt32(bodyLength) - .add(bodyBuffer) - .join(); + const finalBuffer = writer.addInt32(bodyLength).add(bodyBuffer).join(); await this.#bufWriter.write(finalBuffer); await this.#bufWriter.flush(); @@ -248,7 +241,7 @@ export class Connection { } async #openConnection(options: ConnectOptions) { - // @ts-ignore This will throw in runtime if the options passed to it are socket related and deno is running + // @ts-expect-error This will throw in runtime if the options passed to it are socket related and deno is running // on stable this.#conn = await Deno.connect(options); this.#bufWriter = new BufWriter(this.#conn); @@ -257,9 +250,7 @@ export class Connection { async #openSocketConnection(path: string, port: number) { if (Deno.build.os === "windows") { - throw new Error( - "Socket connection is only available on UNIX systems", - ); + throw new Error("Socket connection is only available on UNIX systems"); } const socket = await Deno.stat(path); @@ -296,10 +287,7 @@ export class Connection { this.connected = false; this.#packetWriter = new PacketWriter(); this.#pid = undefined; - this.#queryLock = new DeferredStack( - 1, - [undefined], - ); + this.#queryLock = new DeferredStack(1, [undefined]); this.#secretKey = undefined; this.#tls = undefined; this.#transport = undefined; @@ -319,14 +307,10 @@ export class Connection { this.#closeConnection(); const { - hostname, host_type, + hostname, port, - tls: { - enabled: tls_enabled, - enforce: tls_enforced, - caCertificates, - }, + tls: { caCertificates, enabled: tls_enabled, enforce: tls_enforced }, } = this.#connection_params; if (host_type === "socket") { @@ -341,12 +325,11 @@ export class Connection { if (tls_enabled) { // If TLS is disabled, we don't even try to connect. - const accepts_tls = await this.#serverAcceptsTLS() - .catch((e) => { - // Make sure to close the connection if the TLS validation throws - this.#closeConnection(); - throw e; - }); + const accepts_tls = await this.#serverAcceptsTLS().catch((e) => { + // Make sure to close the connection if the TLS validation throws + this.#closeConnection(); + throw e; + }); // https://www.postgresql.org/docs/14/protocol-flow.html#id-1.10.5.7.11 if (accepts_tls) { @@ -657,24 +640,18 @@ export class Connection { `Unexpected message in SASL finalization: ${maybe_sasl_continue.type}`, ); } - const sasl_final = utf8.decode( - maybe_sasl_final.reader.readAllBytes(), - ); + const sasl_final = utf8.decode(maybe_sasl_final.reader.readAllBytes()); await client.receiveResponse(sasl_final); // Return authentication result return this.#readMessage(); } - async #simpleQuery( - query: Query, - ): Promise; + async #simpleQuery(query: Query): Promise; async #simpleQuery( query: Query, ): Promise; - async #simpleQuery( - query: Query, - ): Promise { + async #simpleQuery(query: Query): Promise { this.#packetWriter.clear(); const buffer = this.#packetWriter.addCString(query.text).flush(0x51); @@ -757,9 +734,7 @@ export class Connection { await this.#bufWriter.write(buffer); } - async #appendArgumentsToMessage( - query: Query, - ) { + async #appendArgumentsToMessage(query: Query) { this.#packetWriter.clear(); const hasBinaryArgs = query.args.some((arg) => arg instanceof Uint8Array); @@ -830,10 +805,7 @@ export class Connection { // TODO // Rename process function to a more meaningful name and move out of class - async #processErrorUnsafe( - msg: Message, - recoverable = true, - ) { + async #processErrorUnsafe(msg: Message, recoverable = true) { const error = new PostgresError(parseNoticeMessage(msg)); if (recoverable) { let maybe_ready_message = await this.#readMessage(); @@ -930,15 +902,9 @@ export class Connection { return result; } - async query( - query: Query, - ): Promise; - async query( - query: Query, - ): Promise; - async query( - query: Query, - ): Promise { + async query(query: Query): Promise; + async query(query: Query): Promise; + async query(query: Query): Promise { if (!this.connected) { await this.startup(true); } diff --git a/connection/connection_params.ts b/connection/connection_params.ts index 38c46711..e03e052d 100644 --- a/connection/connection_params.ts +++ b/connection/connection_params.ts @@ -59,12 +59,7 @@ export interface ConnectionOptions { } /** https://www.postgresql.org/docs/14/libpq-ssl.html#LIBPQ-SSL-PROTECTION */ -type TLSModes = - | "disable" - | "prefer" - | "require" - | "verify-ca" - | "verify-full"; +type TLSModes = "disable" | "prefer" | "require" | "verify-ca" | "verify-full"; // TODO // Refactor enabled and enforce into one single option for 1.0 @@ -121,11 +116,7 @@ export interface ClientConfiguration { } function formatMissingParams(missingParams: string[]) { - return `Missing connection parameters: ${ - missingParams.join( - ", ", - ) - }`; + return `Missing connection parameters: ${missingParams.join(", ")}`; } /** @@ -201,24 +192,25 @@ function parseOptionsArgument(options: string): Record { } else if (/^--\w/.test(args[x])) { transformed_args.push(args[x].slice(2)); } else { - throw new Error( - `Value "${args[x]}" is not a valid options argument`, - ); + throw new Error(`Value "${args[x]}" is not a valid options argument`); } } - return transformed_args.reduce((options, x) => { - if (!/.+=.+/.test(x)) { - throw new Error(`Value "${x}" is not a valid options argument`); - } + return transformed_args.reduce( + (options, x) => { + if (!/.+=.+/.test(x)) { + throw new Error(`Value "${x}" is not a valid options argument`); + } - const key = x.slice(0, x.indexOf("=")); - const value = x.slice(x.indexOf("=") + 1); + const key = x.slice(0, x.indexOf("=")); + const value = x.slice(x.indexOf("=") + 1); - options[key] = value; + options[key] = value; - return options; - }, {} as Record); + return options; + }, + {} as Record, + ); } function parseOptionsFromUri(connection_string: string): ClientOptions { @@ -237,14 +229,11 @@ function parseOptionsFromUri(connection_string: string): ClientOptions { // Treat as sslmode=require sslmode: uri.params.ssl === "true" ? "require" - : uri.params.sslmode as TLSModes, + : (uri.params.sslmode as TLSModes), user: uri.user || uri.params.user, }; } catch (e) { - throw new ConnectionParamsError( - `Could not parse the connection string`, - e, - ); + throw new ConnectionParamsError("Could not parse the connection string", e); } if (!["postgres", "postgresql"].includes(postgres_uri.driver)) { @@ -255,7 +244,7 @@ function parseOptionsFromUri(connection_string: string): ClientOptions { // No host by default means socket connection const host_type = postgres_uri.host - ? (isAbsolute(postgres_uri.host) ? "socket" : "tcp") + ? isAbsolute(postgres_uri.host) ? "socket" : "tcp" : "socket"; const options = postgres_uri.options @@ -302,7 +291,10 @@ function parseOptionsFromUri(connection_string: string): ClientOptions { } const DEFAULT_OPTIONS: - & Omit + & Omit< + ClientConfiguration, + "database" | "user" | "hostname" + > & { host: string; socket: string } = { applicationName: "deno_postgres", connection: { @@ -360,18 +352,13 @@ export function createParams( if (parsed_host.protocol === "file:") { host = fromFileUrl(parsed_host); } else { - throw new Error( - "The provided host is not a file path", - ); + throw new Error("The provided host is not a file path"); } } else { host = socket; } } catch (e) { - throw new ConnectionParamsError( - `Could not parse host "${socket}"`, - e, - ); + throw new ConnectionParamsError(`Could not parse host "${socket}"`, e); } } else { host = provided_host ?? DEFAULT_OPTIONS.host; @@ -414,7 +401,7 @@ export function createParams( if (host_type === "socket" && params?.tls) { throw new ConnectionParamsError( - `No TLS options are allowed when host type is set to "socket"`, + 'No TLS options are allowed when host type is set to "socket"', ); } const tls_enabled = !!(params?.tls?.enabled ?? DEFAULT_OPTIONS.tls.enabled); @@ -429,7 +416,8 @@ export function createParams( // TODO // Perhaps username should be taken from the PC user as a default? const connection_options = { - applicationName: params.applicationName ?? pgEnv.applicationName ?? + applicationName: params.applicationName ?? + pgEnv.applicationName ?? DEFAULT_OPTIONS.applicationName, connection: { attempts: params?.connection?.attempts ?? diff --git a/connection/message.ts b/connection/message.ts index edf40866..af032b07 100644 --- a/connection/message.ts +++ b/connection/message.ts @@ -34,9 +34,10 @@ export interface Notice { routine?: string; } -export function parseBackendKeyMessage( - message: Message, -): { pid: number; secret_key: number } { +export function parseBackendKeyMessage(message: Message): { + pid: number; + secret_key: number; +} { return { pid: message.reader.readInt32(), secret_key: message.reader.readInt32(), diff --git a/connection/message_code.ts b/connection/message_code.ts index 966a02ae..ede4ed09 100644 --- a/connection/message_code.ts +++ b/connection/message_code.ts @@ -33,13 +33,13 @@ export const INCOMING_TLS_MESSAGES = { export const INCOMING_QUERY_MESSAGES = { BIND_COMPLETE: "2", - PARSE_COMPLETE: "1", COMMAND_COMPLETE: "C", DATA_ROW: "D", EMPTY_QUERY: "I", - NO_DATA: "n", NOTICE_WARNING: "N", + NO_DATA: "n", PARAMETER_STATUS: "S", + PARSE_COMPLETE: "1", READY: "Z", ROW_DESCRIPTION: "T", } as const; diff --git a/connection/scram.ts b/connection/scram.ts index b197035c..96abeb80 100644 --- a/connection/scram.ts +++ b/connection/scram.ts @@ -128,9 +128,7 @@ async function deriveKeySignatures( /** Escapes "=" and "," in a string. */ function escape(str: string): string { - return str - .replace(/=/g, "=3D") - .replace(/,/g, "=2C"); + return str.replace(/=/g, "=3D").replace(/,/g, "=2C"); } function generateRandomNonce(size: number): string { @@ -228,6 +226,8 @@ export class Client { throw new Error(Reason.BadSalt); } + if (!salt) throw new Error(Reason.BadSalt); + const iterCount = parseInt(attrs.i) | 0; if (iterCount <= 0) { throw new Error(Reason.BadIterationCount); diff --git a/docker-compose.yml b/docker-compose.yml index 93c0f17a..be919039 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -63,7 +63,7 @@ services: build: . # Name the image to be reused in no_check_tests image: postgres/tests - command: sh -c "/wait && deno test --unstable -A --parallel --check" + command: sh -c "/wait && deno test -A --parallel --check" depends_on: - postgres_clear - postgres_md5 @@ -74,7 +74,7 @@ services: no_check_tests: image: postgres/tests - command: sh -c "/wait && deno test --unstable -A --parallel --no-check" + command: sh -c "/wait && deno test -A --parallel --no-check" depends_on: - tests environment: diff --git a/pool.ts b/pool.ts index 3488e799..934e4ff7 100644 --- a/pool.ts +++ b/pool.ts @@ -183,21 +183,18 @@ export class Pool { */ async #initialize() { const initialized = this.#lazy ? 0 : this.#size; - const clients = Array.from( - { length: this.#size }, - async (_e, index) => { - const client: PoolClient = new PoolClient( - this.#connection_params, - () => this.#available_connections!.push(client), - ); - - if (index < initialized) { - await client.connect(); - } - - return client; - }, - ); + const clients = Array.from({ length: this.#size }, async (_e, index) => { + const client: PoolClient = new PoolClient( + this.#connection_params, + () => this.#available_connections!.push(client), + ); + + if (index < initialized) { + await client.connect(); + } + + return client; + }); this.#available_connections = new DeferredAccessStack( await Promise.all(clients), @@ -206,7 +203,8 @@ export class Pool { ); this.#ended = false; - } /** + } + /** * This will return the number of initialized clients in the pool */ diff --git a/query/array_parser.ts b/query/array_parser.ts index 1db591d0..9fd043bd 100644 --- a/query/array_parser.ts +++ b/query/array_parser.ts @@ -34,13 +34,13 @@ class ArrayParser { const character = this.source[this.position++]; if (character === "\\") { return { - value: this.source[this.position++], escaped: true, + value: this.source[this.position++], }; } return { - value: character, escaped: false, + value: character, }; } diff --git a/query/decode.ts b/query/decode.ts index 8d61d34f..b09940d6 100644 --- a/query/decode.ts +++ b/query/decode.ts @@ -1,4 +1,5 @@ import { Oid } from "./oid.ts"; +import { bold, yellow } from "../deps.ts"; import { decodeBigint, decodeBigintArray, @@ -14,6 +15,8 @@ import { decodeDateArray, decodeDatetime, decodeDatetimeArray, + decodeFloat, + decodeFloatArray, decodeInt, decodeIntArray, decodeJson, @@ -58,138 +61,150 @@ function decodeBinary() { throw new Error("Not implemented!"); } -// deno-lint-ignore no-explicit-any -function decodeText(value: Uint8Array, typeOid: number): any { +function decodeText(value: Uint8Array, typeOid: number) { const strValue = decoder.decode(value); - switch (typeOid) { - case Oid.bpchar: - case Oid.char: - case Oid.cidr: - case Oid.float4: - case Oid.float8: - case Oid.inet: - case Oid.macaddr: - case Oid.name: - case Oid.numeric: - case Oid.oid: - case Oid.regclass: - case Oid.regconfig: - case Oid.regdictionary: - case Oid.regnamespace: - case Oid.regoper: - case Oid.regoperator: - case Oid.regproc: - case Oid.regprocedure: - case Oid.regrole: - case Oid.regtype: - case Oid.text: - case Oid.time: - case Oid.timetz: - case Oid.uuid: - case Oid.varchar: - case Oid.void: - return strValue; - case Oid.bpchar_array: - case Oid.char_array: - case Oid.cidr_array: - case Oid.float4_array: - case Oid.float8_array: - case Oid.inet_array: - case Oid.macaddr_array: - case Oid.name_array: - case Oid.numeric_array: - case Oid.oid_array: - case Oid.regclass_array: - case Oid.regconfig_array: - case Oid.regdictionary_array: - case Oid.regnamespace_array: - case Oid.regoper_array: - case Oid.regoperator_array: - case Oid.regproc_array: - case Oid.regprocedure_array: - case Oid.regrole_array: - case Oid.regtype_array: - case Oid.text_array: - case Oid.time_array: - case Oid.timetz_array: - case Oid.uuid_array: - case Oid.varchar_array: - return decodeStringArray(strValue); - case Oid.int2: - case Oid.int4: - case Oid.xid: - return decodeInt(strValue); - case Oid.int2_array: - case Oid.int4_array: - case Oid.xid_array: - return decodeIntArray(strValue); - case Oid.bool: - return decodeBoolean(strValue); - case Oid.bool_array: - return decodeBooleanArray(strValue); - case Oid.box: - return decodeBox(strValue); - case Oid.box_array: - return decodeBoxArray(strValue); - case Oid.circle: - return decodeCircle(strValue); - case Oid.circle_array: - return decodeCircleArray(strValue); - case Oid.bytea: - return decodeBytea(strValue); - case Oid.byte_array: - return decodeByteaArray(strValue); - case Oid.date: - return decodeDate(strValue); - case Oid.date_array: - return decodeDateArray(strValue); - case Oid.int8: - return decodeBigint(strValue); - case Oid.int8_array: - return decodeBigintArray(strValue); - case Oid.json: - case Oid.jsonb: - return decodeJson(strValue); - case Oid.json_array: - case Oid.jsonb_array: - return decodeJsonArray(strValue); - case Oid.line: - return decodeLine(strValue); - case Oid.line_array: - return decodeLineArray(strValue); - case Oid.lseg: - return decodeLineSegment(strValue); - case Oid.lseg_array: - return decodeLineSegmentArray(strValue); - case Oid.path: - return decodePath(strValue); - case Oid.path_array: - return decodePathArray(strValue); - case Oid.point: - return decodePoint(strValue); - case Oid.point_array: - return decodePointArray(strValue); - case Oid.polygon: - return decodePolygon(strValue); - case Oid.polygon_array: - return decodePolygonArray(strValue); - case Oid.tid: - return decodeTid(strValue); - case Oid.tid_array: - return decodeTidArray(strValue); - case Oid.timestamp: - case Oid.timestamptz: - return decodeDatetime(strValue); - case Oid.timestamp_array: - case Oid.timestamptz_array: - return decodeDatetimeArray(strValue); - default: - // A separate category for not handled values - // They might or might not be represented correctly as strings, - // returning them to the user as raw strings allows them to parse - // them as they see fit - return strValue; + try { + switch (typeOid) { + case Oid.bpchar: + case Oid.char: + case Oid.cidr: + case Oid.float8: + case Oid.inet: + case Oid.macaddr: + case Oid.name: + case Oid.numeric: + case Oid.oid: + case Oid.regclass: + case Oid.regconfig: + case Oid.regdictionary: + case Oid.regnamespace: + case Oid.regoper: + case Oid.regoperator: + case Oid.regproc: + case Oid.regprocedure: + case Oid.regrole: + case Oid.regtype: + case Oid.text: + case Oid.time: + case Oid.timetz: + case Oid.uuid: + case Oid.varchar: + case Oid.void: + return strValue; + case Oid.bpchar_array: + case Oid.char_array: + case Oid.cidr_array: + case Oid.float8_array: + case Oid.inet_array: + case Oid.macaddr_array: + case Oid.name_array: + case Oid.numeric_array: + case Oid.oid_array: + case Oid.regclass_array: + case Oid.regconfig_array: + case Oid.regdictionary_array: + case Oid.regnamespace_array: + case Oid.regoper_array: + case Oid.regoperator_array: + case Oid.regproc_array: + case Oid.regprocedure_array: + case Oid.regrole_array: + case Oid.regtype_array: + case Oid.text_array: + case Oid.time_array: + case Oid.timetz_array: + case Oid.uuid_array: + case Oid.varchar_array: + return decodeStringArray(strValue); + case Oid.float4: + return decodeFloat(strValue); + case Oid.float4_array: + return decodeFloatArray(strValue); + case Oid.int2: + case Oid.int4: + case Oid.xid: + return decodeInt(strValue); + case Oid.int2_array: + case Oid.int4_array: + case Oid.xid_array: + return decodeIntArray(strValue); + case Oid.bool: + return decodeBoolean(strValue); + case Oid.bool_array: + return decodeBooleanArray(strValue); + case Oid.box: + return decodeBox(strValue); + case Oid.box_array: + return decodeBoxArray(strValue); + case Oid.circle: + return decodeCircle(strValue); + case Oid.circle_array: + return decodeCircleArray(strValue); + case Oid.bytea: + return decodeBytea(strValue); + case Oid.byte_array: + return decodeByteaArray(strValue); + case Oid.date: + return decodeDate(strValue); + case Oid.date_array: + return decodeDateArray(strValue); + case Oid.int8: + return decodeBigint(strValue); + case Oid.int8_array: + return decodeBigintArray(strValue); + case Oid.json: + case Oid.jsonb: + return decodeJson(strValue); + case Oid.json_array: + case Oid.jsonb_array: + return decodeJsonArray(strValue); + case Oid.line: + return decodeLine(strValue); + case Oid.line_array: + return decodeLineArray(strValue); + case Oid.lseg: + return decodeLineSegment(strValue); + case Oid.lseg_array: + return decodeLineSegmentArray(strValue); + case Oid.path: + return decodePath(strValue); + case Oid.path_array: + return decodePathArray(strValue); + case Oid.point: + return decodePoint(strValue); + case Oid.point_array: + return decodePointArray(strValue); + case Oid.polygon: + return decodePolygon(strValue); + case Oid.polygon_array: + return decodePolygonArray(strValue); + case Oid.tid: + return decodeTid(strValue); + case Oid.tid_array: + return decodeTidArray(strValue); + case Oid.timestamp: + case Oid.timestamptz: + return decodeDatetime(strValue); + case Oid.timestamp_array: + case Oid.timestamptz_array: + return decodeDatetimeArray(strValue); + default: + // A separate category for not handled values + // They might or might not be represented correctly as strings, + // returning them to the user as raw strings allows them to parse + // them as they see fit + return strValue; + } + } catch (_e) { + console.error( + bold(yellow(`Error decoding type Oid ${typeOid} value`)) + + _e.message + + "\n" + + bold("Defaulting to null."), + ); + // If an error occurred during decoding, return null + return null; } } diff --git a/query/decoders.ts b/query/decoders.ts index c5435836..6553d66b 100644 --- a/query/decoders.ts +++ b/query/decoders.ts @@ -28,24 +28,44 @@ export function decodeBigint(value: string): bigint { } export function decodeBigintArray(value: string) { - return parseArray(value, (x) => BigInt(x)); + return parseArray(value, decodeBigint); } export function decodeBoolean(value: string): boolean { - return value[0] === "t"; + const v = value.toLowerCase(); + return ( + v === "t" || + v === "true" || + v === "y" || + v === "yes" || + v === "on" || + v === "1" + ); } export function decodeBooleanArray(value: string) { - return parseArray(value, (x) => x[0] === "t"); + return parseArray(value, decodeBoolean); } export function decodeBox(value: string): Box { - const [a, b] = value.match(/\(.*?\)/g) || []; + const points = value.match(/\(.*?\)/g) || []; - return { - a: decodePoint(a || ""), - b: decodePoint(b), - }; + if (points.length !== 2) { + throw new Error( + `Invalid Box: "${value}". Box must have only 2 point, ${points.length} given.`, + ); + } + + const [a, b] = points; + + try { + return { + a: decodePoint(a), + b: decodePoint(b), + }; + } catch (e) { + throw new Error(`Invalid Box: "${value}" : ${e.message}`); + } } export function decodeBoxArray(value: string) { @@ -60,7 +80,7 @@ export function decodeBytea(byteaStr: string): Uint8Array { } } -export function decodeByteaArray(value: string): unknown[] { +export function decodeByteaArray(value: string) { return parseArray(value, decodeBytea); } @@ -104,14 +124,24 @@ function decodeByteaHex(byteaStr: string): Uint8Array { } export function decodeCircle(value: string): Circle { - const [point, radius] = value.substring(1, value.length - 1).split( - /,(?![^(]*\))/, - ) as [string, Float8]; + const [point, radius] = value + .substring(1, value.length - 1) + .split(/,(?![^(]*\))/) as [string, Float8]; - return { - point: decodePoint(point), - radius: radius, - }; + if (Number.isNaN(parseFloat(radius))) { + throw new Error( + `Invalid Circle: "${value}". Circle radius "${radius}" must be a valid number.`, + ); + } + + try { + return { + point: decodePoint(point), + radius: radius, + }; + } catch (e) { + throw new Error(`Invalid Circle: "${value}" : ${e.message}`); + } } export function decodeCircleArray(value: string) { @@ -186,12 +216,18 @@ export function decodeInt(value: string): number { return parseInt(value, 10); } -// deno-lint-ignore no-explicit-any -export function decodeIntArray(value: string): any { - if (!value) return null; +export function decodeIntArray(value: string) { return parseArray(value, decodeInt); } +export function decodeFloat(value: string): number { + return parseFloat(value); +} + +export function decodeFloatArray(value: string) { + return parseArray(value, decodeFloat); +} + export function decodeJson(value: string): unknown { return JSON.parse(value); } @@ -201,12 +237,28 @@ export function decodeJsonArray(value: string): unknown[] { } export function decodeLine(value: string): Line { - const [a, b, c] = value.substring(1, value.length - 1).split(",") as [ + const equationConsts = value.substring(1, value.length - 1).split(",") as [ Float8, Float8, Float8, ]; + if (equationConsts.length !== 3) { + throw new Error( + `Invalid Line: "${value}". Line in linear equation format must have 3 constants, ${equationConsts.length} given.`, + ); + } + + equationConsts.forEach((c) => { + if (Number.isNaN(parseFloat(c))) { + throw new Error( + `Invalid Line: "${value}". Line constant "${c}" must be a valid number.`, + ); + } + }); + + const [a, b, c] = equationConsts; + return { a: a, b: b, @@ -219,14 +271,24 @@ export function decodeLineArray(value: string) { } export function decodeLineSegment(value: string): LineSegment { - const [a, b] = value - .substring(1, value.length - 1) - .match(/\(.*?\)/g) || []; + const points = value.substring(1, value.length - 1).match(/\(.*?\)/g) || []; - return { - a: decodePoint(a || ""), - b: decodePoint(b), - }; + if (points.length !== 2) { + throw new Error( + `Invalid Line Segment: "${value}". Line segments must have only 2 point, ${points.length} given.`, + ); + } + + const [a, b] = points; + + try { + return { + a: decodePoint(a), + b: decodePoint(b), + }; + } catch (e) { + throw new Error(`Invalid Line Segment: "${value}" : ${e.message}`); + } } export function decodeLineSegmentArray(value: string) { @@ -238,7 +300,13 @@ export function decodePath(value: string): Path { // since encapsulated commas are separators for the point coordinates const points = value.substring(1, value.length - 1).split(/,(?![^(]*\))/); - return points.map(decodePoint); + return points.map((point) => { + try { + return decodePoint(point); + } catch (e) { + throw new Error(`Invalid Path: "${value}" : ${e.message}`); + } + }); } export function decodePathArray(value: string) { @@ -246,14 +314,23 @@ export function decodePathArray(value: string) { } export function decodePoint(value: string): Point { - const [x, y] = value.substring(1, value.length - 1).split(",") as [ - Float8, - Float8, - ]; + const coordinates = value + .substring(1, value.length - 1) + .split(",") as Float8[]; + + if (coordinates.length !== 2) { + throw new Error( + `Invalid Point: "${value}". Points must have only 2 coordinates, ${coordinates.length} given.`, + ); + } + + const [x, y] = coordinates; if (Number.isNaN(parseFloat(x)) || Number.isNaN(parseFloat(y))) { throw new Error( - `Invalid point value: "${Number.isNaN(parseFloat(x)) ? x : y}"`, + `Invalid Point: "${value}". Coordinate "${ + Number.isNaN(parseFloat(x)) ? x : y + }" must be a valid number.`, ); } @@ -268,7 +345,11 @@ export function decodePointArray(value: string) { } export function decodePolygon(value: string): Polygon { - return decodePath(value); + try { + return decodePath(value); + } catch (e) { + throw new Error(`Invalid Polygon: "${value}" : ${e.message}`); + } } export function decodePolygonArray(value: string) { diff --git a/query/query.ts b/query/query.ts index e58aa85a..42862e68 100644 --- a/query/query.ts +++ b/query/query.ts @@ -43,7 +43,10 @@ export enum ResultType { } export class RowDescription { - constructor(public columnCount: number, public columns: Column[]) {} + constructor( + public columnCount: number, + public columns: Column[], + ) {} } /** @@ -95,9 +98,7 @@ function normalizeObjectQueryArgs( args: Record, ): Record { const normalized_args = Object.fromEntries( - Object.entries(args).map(( - [key, value], - ) => [key.toLowerCase(), value]), + Object.entries(args).map(([key, value]) => [key.toLowerCase(), value]), ); if (Object.keys(normalized_args).length !== Object.keys(args).length) { @@ -197,8 +198,9 @@ export class QueryResult { } } -export class QueryArrayResult = Array> - extends QueryResult { +export class QueryArrayResult< + T extends Array = Array, +> extends QueryResult { public rows: T[] = []; insertRow(row_data: Uint8Array[]) { @@ -234,19 +236,14 @@ function findDuplicatesInArray(array: string[]): string[] { } function snakecaseToCamelcase(input: string) { - return input - .split("_") - .reduce( - (res, word, i) => { - if (i !== 0) { - word = word[0].toUpperCase() + word.slice(1); - } + return input.split("_").reduce((res, word, i) => { + if (i !== 0) { + word = word[0].toUpperCase() + word.slice(1); + } - res += word; - return res; - }, - "", - ); + res += word; + return res; + }, ""); } export class QueryObjectResult< @@ -283,8 +280,8 @@ export class QueryObjectResult< snakecaseToCamelcase(column.name) ); } else { - column_names = this.rowDescription.columns.map((column) => - column.name + column_names = this.rowDescription.columns.map( + (column) => column.name, ); } @@ -293,7 +290,9 @@ export class QueryObjectResult< if (duplicates.length) { throw new Error( `Field names ${ - duplicates.map((str) => `"${str}"`).join(", ") + duplicates + .map((str) => `"${str}"`) + .join(", ") } are duplicated in the result of the query`, ); } @@ -360,15 +359,8 @@ export class Query { this.text = config_or_text; this.args = args.map(encodeArgument); } else { - let { - args = [], - camelcase, - encoder = encodeArgument, - fields, - // deno-lint-ignore no-unused-vars - name, - text, - } = config_or_text; + const { camelcase, encoder = encodeArgument, fields } = config_or_text; + let { args = [], text } = config_or_text; // Check that the fields passed are valid and can be used to map // the result of the query diff --git a/query/transaction.ts b/query/transaction.ts index a5088cfd..3f8dfe92 100644 --- a/query/transaction.ts +++ b/query/transaction.ts @@ -156,7 +156,7 @@ export class Transaction { #assertTransactionOpen() { if (this.#client.session.current_transaction !== this.name) { throw new Error( - `This transaction has not been started yet, make sure to use the "begin" method to do so`, + 'This transaction has not been started yet, make sure to use the "begin" method to do so', ); } } @@ -183,9 +183,7 @@ export class Transaction { async begin() { if (this.#client.session.current_transaction !== null) { if (this.#client.session.current_transaction === this.name) { - throw new Error( - "This transaction is already open", - ); + throw new Error("This transaction is already open"); } throw new Error( @@ -338,9 +336,9 @@ export class Transaction { async getSnapshot(): Promise { this.#assertTransactionOpen(); - const { rows } = await this.queryObject< - { snapshot: string } - >`SELECT PG_EXPORT_SNAPSHOT() AS SNAPSHOT;`; + const { rows } = await this.queryObject<{ + snapshot: string; + }>`SELECT PG_EXPORT_SNAPSHOT() AS SNAPSHOT;`; return rows[0].snapshot; } @@ -419,7 +417,7 @@ export class Transaction { } try { - return await this.#executeQuery(query) as QueryArrayResult; + return (await this.#executeQuery(query)) as QueryArrayResult; } catch (e) { if (e instanceof PostgresError) { await this.commit(); @@ -504,9 +502,7 @@ export class Transaction { query: TemplateStringsArray, ...args: unknown[] ): Promise>; - async queryObject< - T = Record, - >( + async queryObject>( query_template_or_config: | string | QueryObjectOptions @@ -536,7 +532,7 @@ export class Transaction { } try { - return await this.#executeQuery(query) as QueryObjectResult; + return (await this.#executeQuery(query)) as QueryObjectResult; } catch (e) { if (e instanceof PostgresError) { await this.commit(); @@ -614,9 +610,13 @@ export class Transaction { async rollback(options?: { savepoint?: string | Savepoint }): Promise; async rollback(options?: { chain?: boolean }): Promise; async rollback( - savepoint_or_options?: string | Savepoint | { - savepoint?: string | Savepoint; - } | { chain?: boolean }, + savepoint_or_options?: + | string + | Savepoint + | { + savepoint?: string | Savepoint; + } + | { chain?: boolean }, ): Promise { this.#assertTransactionOpen(); @@ -627,8 +627,9 @@ export class Transaction { ) { savepoint_option = savepoint_or_options; } else { - savepoint_option = - (savepoint_or_options as { savepoint?: string | Savepoint })?.savepoint; + savepoint_option = ( + savepoint_or_options as { savepoint?: string | Savepoint } + )?.savepoint; } let savepoint_name: string | undefined; @@ -652,8 +653,8 @@ export class Transaction { // If a savepoint is provided, rollback to that savepoint, continue the transaction if (typeof savepoint_option !== "undefined") { - const ts_savepoint = this.#savepoints.find(({ name }) => - name === savepoint_name + const ts_savepoint = this.#savepoints.find( + ({ name }) => name === savepoint_name, ); if (!ts_savepoint) { throw new Error( diff --git a/tests/README.md b/tests/README.md index 4cd45602..c17f1a58 100644 --- a/tests/README.md +++ b/tests/README.md @@ -9,8 +9,13 @@ need to modify the configuration. From within the project directory, run: -``` +```sh +# run on host deno test --allow-read --allow-net --allow-env + +# run in docker container +docker-compose build --no-cache +docker-compose run tests ``` ## Docker Configuration diff --git a/tests/data_types_test.ts b/tests/data_types_test.ts index d2741f3c..6c9fab8f 100644 --- a/tests/data_types_test.ts +++ b/tests/data_types_test.ts @@ -4,7 +4,7 @@ import { generateSimpleClientTest } from "./helpers.ts"; import type { Box, Circle, - Float4, + // Float4, Float8, Line, LineSegment, @@ -856,22 +856,22 @@ Deno.test( Deno.test( "float4", testClient(async (client) => { - const result = await client.queryArray<[Float4, Float4]>( + const result = await client.queryArray<[number, number]>( "SELECT '1'::FLOAT4, '17.89'::FLOAT4", ); - assertEquals(result.rows[0], ["1", "17.89"]); + assertEquals(result.rows[0], [1, 17.89]); }), ); Deno.test( "float4 array", testClient(async (client) => { - const result = await client.queryArray<[[Float4, Float4]]>( + const result = await client.queryArray<[[number, number]]>( "SELECT ARRAY['12.25'::FLOAT4, '4789']", ); - assertEquals(result.rows[0][0], ["12.25", "4789"]); + assertEquals(result.rows[0][0], [12.25, 4789]); }), ); diff --git a/tests/decode_test.ts b/tests/decode_test.ts new file mode 100644 index 00000000..000cbab4 --- /dev/null +++ b/tests/decode_test.ts @@ -0,0 +1,250 @@ +import { + decodeBigint, + decodeBigintArray, + decodeBoolean, + decodeBooleanArray, + decodeBox, + decodeCircle, + decodeDate, + decodeDatetime, + decodeFloat, + decodeInt, + decodeJson, + decodeLine, + decodeLineSegment, + decodePath, + decodePoint, + decodeTid, +} from "../query/decoders.ts"; +import { assertEquals, assertThrows } from "./test_deps.ts"; + +Deno.test("decodeBigint", function () { + assertEquals(decodeBigint("18014398509481984"), 18014398509481984n); +}); + +Deno.test("decodeBigintArray", function () { + assertEquals( + decodeBigintArray( + "{17365398509481972,9007199254740992,-10414398509481984}", + ), + [17365398509481972n, 9007199254740992n, -10414398509481984n], + ); +}); + +Deno.test("decodeBoolean", function () { + assertEquals(decodeBoolean("True"), true); + assertEquals(decodeBoolean("yEs"), true); + assertEquals(decodeBoolean("T"), true); + assertEquals(decodeBoolean("t"), true); + assertEquals(decodeBoolean("YeS"), true); + assertEquals(decodeBoolean("On"), true); + assertEquals(decodeBoolean("1"), true); + assertEquals(decodeBoolean("no"), false); + assertEquals(decodeBoolean("off"), false); + assertEquals(decodeBoolean("0"), false); + assertEquals(decodeBoolean("F"), false); + assertEquals(decodeBoolean("false"), false); + assertEquals(decodeBoolean("n"), false); + assertEquals(decodeBoolean(""), false); +}); + +Deno.test("decodeBooleanArray", function () { + assertEquals(decodeBooleanArray("{True,0,T}"), [true, false, true]); + assertEquals(decodeBooleanArray("{no,Y,1}"), [false, true, true]); +}); + +Deno.test("decodeBox", function () { + assertEquals(decodeBox("(12.4,2),(33,4.33)"), { + a: { x: "12.4", y: "2" }, + b: { x: "33", y: "4.33" }, + }); + let testValue = "(12.4,2)"; + assertThrows( + () => decodeBox(testValue), + Error, + `Invalid Box: "${testValue}". Box must have only 2 point, 1 given.`, + ); + testValue = "(12.4,2),(123,123,123),(9303,33)"; + assertThrows( + () => decodeBox(testValue), + Error, + `Invalid Box: "${testValue}". Box must have only 2 point, 3 given.`, + ); + testValue = "(0,0),(123,123,123)"; + assertThrows( + () => decodeBox(testValue), + Error, + `Invalid Box: "${testValue}" : Invalid Point: "(123,123,123)". Points must have only 2 coordinates, 3 given.`, + ); + testValue = "(0,0),(100,r100)"; + assertThrows( + () => decodeBox(testValue), + Error, + `Invalid Box: "${testValue}" : Invalid Point: "(100,r100)". Coordinate "r100" must be a valid number.`, + ); +}); + +Deno.test("decodeCircle", function () { + assertEquals(decodeCircle("<(12.4,2),3.5>"), { + point: { x: "12.4", y: "2" }, + radius: "3.5", + }); + let testValue = "<(c21 23,2),3.5>"; + assertThrows( + () => decodeCircle(testValue), + Error, + `Invalid Circle: "${testValue}" : Invalid Point: "(c21 23,2)". Coordinate "c21 23" must be a valid number.`, + ); + testValue = "<(33,2),mn23 3.5>"; + assertThrows( + () => decodeCircle(testValue), + Error, + `Invalid Circle: "${testValue}". Circle radius "mn23 3.5" must be a valid number.`, + ); +}); + +Deno.test("decodeDate", function () { + assertEquals(decodeDate("2021-08-01"), new Date("2021-08-01 00:00:00-00")); +}); + +Deno.test("decodeDatetime", function () { + assertEquals( + decodeDatetime("2021-08-01"), + new Date("2021-08-01 00:00:00-00"), + ); + assertEquals( + decodeDatetime("1997-12-17 07:37:16-08"), + new Date("1997-12-17 07:37:16-08"), + ); +}); + +Deno.test("decodeFloat", function () { + assertEquals(decodeFloat("3.14"), 3.14); + assertEquals(decodeFloat("q743 44 23i4"), NaN); +}); + +Deno.test("decodeInt", function () { + assertEquals(decodeInt("42"), 42); + assertEquals(decodeInt("q743 44 23i4"), NaN); +}); + +Deno.test("decodeJson", function () { + assertEquals( + decodeJson( + '{"key_1": "MY VALUE", "key_2": null, "key_3": 10, "key_4": {"subkey_1": true, "subkey_2": ["1",2]}}', + ), + { + key_1: "MY VALUE", + key_2: null, + key_3: 10, + key_4: { subkey_1: true, subkey_2: ["1", 2] }, + }, + ); + assertThrows(() => decodeJson("{ 'eqw' ; ddd}")); +}); + +Deno.test("decodeLine", function () { + assertEquals(decodeLine("{100,50,0}"), { a: "100", b: "50", c: "0" }); + let testValue = "{100,50,0,100}"; + assertThrows( + () => decodeLine("{100,50,0,100}"), + Error, + `Invalid Line: "${testValue}". Line in linear equation format must have 3 constants, 4 given.`, + ); + testValue = "{100,d3km,0}"; + assertThrows( + () => decodeLine(testValue), + Error, + `Invalid Line: "${testValue}". Line constant "d3km" must be a valid number.`, + ); +}); + +Deno.test("decodeLineSegment", function () { + assertEquals(decodeLineSegment("((100,50),(350,350))"), { + a: { x: "100", y: "50" }, + b: { x: "350", y: "350" }, + }); + let testValue = "((100,50),(r344,350))"; + assertThrows( + () => decodeLineSegment(testValue), + Error, + `Invalid Line Segment: "${testValue}" : Invalid Point: "(r344,350)". Coordinate "r344" must be a valid number.`, + ); + testValue = "((100),(r344,350))"; + assertThrows( + () => decodeLineSegment(testValue), + Error, + `Invalid Line Segment: "${testValue}" : Invalid Point: "(100)". Points must have only 2 coordinates, 1 given.`, + ); + testValue = "((100,50))"; + assertThrows( + () => decodeLineSegment(testValue), + Error, + `Invalid Line Segment: "${testValue}". Line segments must have only 2 point, 1 given.`, + ); + testValue = "((100,50),(350,350),(100,100))"; + assertThrows( + () => decodeLineSegment(testValue), + Error, + `Invalid Line Segment: "${testValue}". Line segments must have only 2 point, 3 given.`, + ); +}); + +Deno.test("decodePath", function () { + assertEquals(decodePath("[(100,50),(350,350)]"), [ + { x: "100", y: "50" }, + { x: "350", y: "350" }, + ]); + assertEquals(decodePath("[(1,10),(2,20),(3,30)]"), [ + { x: "1", y: "10" }, + { x: "2", y: "20" }, + { x: "3", y: "30" }, + ]); + let testValue = "((100,50),(350,kjf334))"; + assertThrows( + () => decodePath(testValue), + Error, + `Invalid Path: "${testValue}" : Invalid Point: "(350,kjf334)". Coordinate "kjf334" must be a valid number.`, + ); + testValue = "((100,50,9949))"; + assertThrows( + () => decodePath(testValue), + Error, + `Invalid Path: "${testValue}" : Invalid Point: "(100,50,9949)". Points must have only 2 coordinates, 3 given.`, + ); +}); + +Deno.test("decodePoint", function () { + assertEquals(decodePoint("(10.555,50.8)"), { x: "10.555", y: "50.8" }); + let testValue = "(1000)"; + assertThrows( + () => decodePoint(testValue), + Error, + `Invalid Point: "${testValue}". Points must have only 2 coordinates, 1 given.`, + ); + testValue = "(100.100,50,350)"; + assertThrows( + () => decodePoint(testValue), + Error, + `Invalid Point: "${testValue}". Points must have only 2 coordinates, 3 given.`, + ); + testValue = "(1,r344)"; + assertThrows( + () => decodePoint(testValue), + Error, + `Invalid Point: "${testValue}". Coordinate "r344" must be a valid number.`, + ); + testValue = "(cd 213ee,100)"; + assertThrows( + () => decodePoint(testValue), + Error, + `Invalid Point: "${testValue}". Coordinate "cd 213ee" must be a valid number.`, + ); +}); + +Deno.test("decodeTid", function () { + assertEquals(decodeTid("(19714398509481984,29383838509481984)"), [ + 19714398509481984n, + 29383838509481984n, + ]); +}); diff --git a/utils/deferred.ts b/utils/deferred.ts index e6378c50..7ff60702 100644 --- a/utils/deferred.ts +++ b/utils/deferred.ts @@ -7,11 +7,7 @@ export class DeferredStack { #queue: Array>; #size: number; - constructor( - max?: number, - ls?: Iterable, - creator?: () => Promise, - ) { + constructor(max?: number, ls?: Iterable, creator?: () => Promise) { this.#elements = ls ? [...ls] : []; this.#creator = creator; this.#max_size = max || 10; @@ -100,9 +96,7 @@ export class DeferredAccessStack { this.#elements.map((e) => this.#checkElementInitialization(e)), ); - return initialized - .filter((initialized) => initialized === true) - .length; + return initialized.filter((initialized) => initialized === true).length; } async pop(): Promise { @@ -117,7 +111,7 @@ export class DeferredAccessStack { element = await d; } - if (!await this.#checkElementInitialization(element)) { + if (!(await this.#checkElementInitialization(element))) { await this.#initializeElement(element); } return element; diff --git a/utils/utils.ts b/utils/utils.ts index 3add6096..ae7ccee8 100644 --- a/utils/utils.ts +++ b/utils/utils.ts @@ -43,6 +43,20 @@ export interface Uri { user: string; } +type ConnectionInfo = { + driver?: string; + user?: string; + password?: string; + full_host?: string; + path?: string; + params?: string; +}; + +type ParsedHost = { + host?: string; + port?: string; +}; + /** * This function parses valid connection strings according to https://www.postgresql.org/docs/14/libpq-connect.html#LIBPQ-CONNSTRING * @@ -53,6 +67,7 @@ export function parseConnectionUri(uri: string): Uri { /(?\w+):\/{2}((?[^\/?#\s:]+?)?(:(?[^\/?#\s]+)?)?@)?(?[^\/?#\s]+)?(\/(?[^?#\s]*))?(\?(?[^#\s]+))?.*/, ); if (!parsed_uri) throw new Error("Could not parse the provided URL"); + let { driver = "", full_host = "", @@ -60,26 +75,17 @@ export function parseConnectionUri(uri: string): Uri { password = "", path = "", user = "", - }: { - driver?: string; - user?: string; - password?: string; - full_host?: string; - path?: string; - params?: string; - } = parsed_uri.groups ?? {}; + }: ConnectionInfo = parsed_uri.groups ?? {}; const parsed_host = full_host.match( /(?(\[.+\])|(.*?))(:(?[\w]*))?$/, ); if (!parsed_host) throw new Error(`Could not parse "${full_host}" host`); + let { host = "", port = "", - }: { - host?: string; - port?: string; - } = parsed_host.groups ?? {}; + }: ParsedHost = parsed_host.groups ?? {}; try { if (host) { @@ -87,9 +93,7 @@ export function parseConnectionUri(uri: string): Uri { } } catch (_e) { console.error( - bold( - yellow("Failed to decode URL host") + "\nDefaulting to raw host", - ), + bold(yellow("Failed to decode URL host") + "\nDefaulting to raw host"), ); }