Skip to content

Commit 2281caa

Browse files
roblabatColin McDonnell
andauthored
#1171 support for refine, superRefine, transform and lazy in discriminatedUnion (#1290)
* #1171 * fix tests * add superRefine in tests * add support for lazy * fix typings * fixe typings for asserted lazy * fix * clean console.log from debug * Clean up discriminatedUnion * Fix deno test Co-authored-by: Colin McDonnell <[email protected]>
1 parent 22ac512 commit 2281caa

File tree

4 files changed

+154
-116
lines changed

4 files changed

+154
-116
lines changed

deno/lib/__tests__/discriminatedUnions.test.ts

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ test("valid - discriminator value of various primitive types", () => {
2525
z.object({ type: z.literal(null), val: z.literal(7) }),
2626
z.object({ type: z.literal("undefined"), val: z.literal(8) }),
2727
z.object({ type: z.literal(undefined), val: z.literal(9) }),
28+
z.object({ type: z.literal("transform"), val: z.literal(10) }),
29+
z.object({ type: z.literal("refine"), val: z.literal(11) }),
30+
z.object({ type: z.literal("superRefine"), val: z.literal(12) }),
2831
]);
2932

3033
expect(schema.parse({ type: "1", val: 1 })).toEqual({ type: "1", val: 1 });
@@ -126,9 +129,7 @@ test("wrong schema - missing discriminator", () => {
126129
]);
127130
throw new Error();
128131
} catch (e: any) {
129-
expect(e.message).toEqual(
130-
"The discriminator value could not be extracted from all the provided schemas"
131-
);
132+
expect(e.message.includes("could not be extracted")).toBe(true);
132133
}
133134
});
134135

@@ -140,9 +141,7 @@ test("wrong schema - duplicate discriminator values", () => {
140141
]);
141142
throw new Error();
142143
} catch (e: any) {
143-
expect(e.message).toEqual(
144-
"Some of the discriminator values are not unique"
145-
);
144+
expect(e.message.includes("has duplicate value")).toEqual(true);
146145
}
147146
});
148147

deno/lib/types.ts

Lines changed: 72 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2297,33 +2297,46 @@ export class ZodUnion<T extends ZodUnionOptions> extends ZodType<
22972297
/////////////////////////////////////////////////////
22982298
/////////////////////////////////////////////////////
22992299

2300-
export type ZodDiscriminatedUnionOption<
2301-
Discriminator extends string,
2302-
DiscriminatorValue extends Primitive
2303-
> = ZodObject<
2304-
{ [key in Discriminator]: ZodLiteral<DiscriminatorValue> } & ZodRawShape,
2305-
any,
2306-
any
2307-
>;
2300+
const getDiscriminator = <T extends ZodTypeAny>(
2301+
type: T
2302+
): Primitive[] | null => {
2303+
if (type instanceof ZodLazy) {
2304+
return getDiscriminator(type.schema);
2305+
} else if (type instanceof ZodEffects) {
2306+
return getDiscriminator(type.innerType());
2307+
} else if (type instanceof ZodLiteral) {
2308+
return [type.value];
2309+
} else if (type instanceof ZodEnum) {
2310+
return type.options;
2311+
} else if (type instanceof ZodUndefined) {
2312+
return [undefined];
2313+
} else if (type instanceof ZodNull) {
2314+
return [null];
2315+
} else {
2316+
return null;
2317+
}
2318+
};
2319+
2320+
export type ZodDiscriminatedUnionOption<Discriminator extends string> =
2321+
ZodObject<{ [key in Discriminator]: ZodTypeAny } & ZodRawShape, any, any>;
23082322

23092323
export interface ZodDiscriminatedUnionDef<
23102324
Discriminator extends string,
2311-
DiscriminatorValue extends Primitive,
2312-
Option extends ZodDiscriminatedUnionOption<Discriminator, DiscriminatorValue>
2325+
Options extends ZodDiscriminatedUnionOption<any>[]
23132326
> extends ZodTypeDef {
23142327
discriminator: Discriminator;
2315-
options: Map<DiscriminatorValue, Option>;
2328+
options: Options;
2329+
optionsMap: Map<Primitive, ZodDiscriminatedUnionOption<any>>;
23162330
typeName: ZodFirstPartyTypeKind.ZodDiscriminatedUnion;
23172331
}
23182332

23192333
export class ZodDiscriminatedUnion<
23202334
Discriminator extends string,
2321-
DiscriminatorValue extends Primitive,
2322-
Option extends ZodDiscriminatedUnionOption<Discriminator, DiscriminatorValue>
2335+
Options extends ZodDiscriminatedUnionOption<Discriminator>[]
23232336
> extends ZodType<
2324-
Option["_output"],
2325-
ZodDiscriminatedUnionDef<Discriminator, DiscriminatorValue, Option>,
2326-
Option["_input"]
2337+
output<Options[number]>,
2338+
ZodDiscriminatedUnionDef<Discriminator, Options>,
2339+
input<Options[number]>
23272340
> {
23282341
_parse(input: ParseInput): ParseReturnType<this["_output"]> {
23292342
const { ctx } = this._processInputParams(input);
@@ -2338,13 +2351,13 @@ export class ZodDiscriminatedUnion<
23382351
}
23392352

23402353
const discriminator = this.discriminator;
2341-
const discriminatorValue: DiscriminatorValue = ctx.data[discriminator];
2342-
const option = this.options.get(discriminatorValue);
2354+
const discriminatorValue: string = ctx.data[discriminator];
2355+
const option = this.optionsMap.get(discriminatorValue);
23432356

23442357
if (!option) {
23452358
addIssueToContext(ctx, {
23462359
code: ZodIssueCode.invalid_union_discriminator,
2347-
options: this.validDiscriminatorValues,
2360+
options: Array.from(this.optionsMap.keys()),
23482361
path: [discriminator],
23492362
});
23502363
return INVALID;
@@ -2355,28 +2368,28 @@ export class ZodDiscriminatedUnion<
23552368
data: ctx.data,
23562369
path: ctx.path,
23572370
parent: ctx,
2358-
});
2371+
}) as any;
23592372
} else {
23602373
return option._parseSync({
23612374
data: ctx.data,
23622375
path: ctx.path,
23632376
parent: ctx,
2364-
});
2377+
}) as any;
23652378
}
23662379
}
23672380

23682381
get discriminator() {
23692382
return this._def.discriminator;
23702383
}
23712384

2372-
get validDiscriminatorValues() {
2373-
return Array.from(this.options.keys());
2374-
}
2375-
23762385
get options() {
23772386
return this._def.options;
23782387
}
23792388

2389+
get optionsMap() {
2390+
return this._def.optionsMap;
2391+
}
2392+
23802393
/**
23812394
* The constructor of the discriminated union schema. Its behaviour is very similar to that of the normal z.union() constructor.
23822395
* However, it only allows a union of objects, all of which need to share a discriminator property. This property must
@@ -2387,44 +2400,45 @@ export class ZodDiscriminatedUnion<
23872400
*/
23882401
static create<
23892402
Discriminator extends string,
2390-
DiscriminatorValue extends Primitive,
23912403
Types extends [
2392-
ZodDiscriminatedUnionOption<Discriminator, DiscriminatorValue>,
2393-
ZodDiscriminatedUnionOption<Discriminator, DiscriminatorValue>,
2394-
...ZodDiscriminatedUnionOption<Discriminator, DiscriminatorValue>[]
2404+
ZodDiscriminatedUnionOption<Discriminator>,
2405+
...ZodDiscriminatedUnionOption<Discriminator>[]
23952406
]
23962407
>(
23972408
discriminator: Discriminator,
2398-
types: Types,
2409+
options: Types,
23992410
params?: RawCreateParams
2400-
): ZodDiscriminatedUnion<Discriminator, DiscriminatorValue, Types[number]> {
2411+
): ZodDiscriminatedUnion<Discriminator, Types> {
24012412
// Get all the valid discriminator values
2402-
const options: Map<DiscriminatorValue, Types[number]> = new Map();
2403-
2404-
try {
2405-
types.forEach((type) => {
2406-
const discriminatorValue = type.shape[discriminator].value;
2407-
options.set(discriminatorValue, type);
2408-
});
2409-
} catch (e) {
2410-
throw new Error(
2411-
"The discriminator value could not be extracted from all the provided schemas"
2412-
);
2413-
}
2414-
2415-
// Assert that all the discriminator values are unique
2416-
if (options.size !== types.length) {
2417-
throw new Error("Some of the discriminator values are not unique");
2413+
const optionsMap: Map<Primitive, Types[number]> = new Map();
2414+
2415+
// try {
2416+
for (const type of options) {
2417+
const discriminatorValues = getDiscriminator(type.shape[discriminator]);
2418+
if (!discriminatorValues) {
2419+
throw new Error(
2420+
`A discriminator value for key \`${discriminator}\`could not be extracted from all schema options`
2421+
);
2422+
}
2423+
for (const value of discriminatorValues) {
2424+
if (optionsMap.has(value)) {
2425+
throw new Error(
2426+
`Discriminator property ${discriminator} has duplicate value ${value}`
2427+
);
2428+
}
2429+
optionsMap.set(value, type);
2430+
}
24182431
}
24192432

24202433
return new ZodDiscriminatedUnion<
24212434
Discriminator,
2422-
DiscriminatorValue,
2423-
Types[number]
2435+
// DiscriminatorValue,
2436+
Types
24242437
>({
24252438
typeName: ZodFirstPartyTypeKind.ZodDiscriminatedUnion,
24262439
discriminator,
24272440
options,
2441+
optionsMap,
24282442
...processCreateParams(params),
24292443
});
24302444
}
@@ -3570,13 +3584,19 @@ export interface ZodEffectsDef<T extends ZodTypeAny = ZodTypeAny>
35703584

35713585
export class ZodEffects<
35723586
T extends ZodTypeAny,
3573-
Output = T["_output"],
3574-
Input = T["_input"]
3587+
Output = output<T>,
3588+
Input = input<T>
35753589
> extends ZodType<Output, ZodEffectsDef<T>, Input> {
35763590
innerType() {
35773591
return this._def.schema;
35783592
}
35793593

3594+
sourceType(): T {
3595+
return this._def.schema._def.typeName === ZodFirstPartyTypeKind.ZodEffects
3596+
? (this._def.schema as unknown as ZodEffects<T>).sourceType()
3597+
: (this._def.schema as T);
3598+
}
3599+
35803600
_parse(input: ParseInput): ParseReturnType<this["_output"]> {
35813601
const { status, ctx } = this._processInputParams(input);
35823602

@@ -4161,7 +4181,7 @@ export type ZodFirstPartySchemaTypes =
41614181
| ZodArray<any, any>
41624182
| ZodObject<any, any, any, any, any>
41634183
| ZodUnion<any>
4164-
| ZodDiscriminatedUnion<any, any, any>
4184+
| ZodDiscriminatedUnion<any, any>
41654185
| ZodIntersection<any, any>
41664186
| ZodTuple<any, any>
41674187
| ZodRecord<any, any>

src/__tests__/discriminatedUnions.test.ts

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ test("valid - discriminator value of various primitive types", () => {
2424
z.object({ type: z.literal(null), val: z.literal(7) }),
2525
z.object({ type: z.literal("undefined"), val: z.literal(8) }),
2626
z.object({ type: z.literal(undefined), val: z.literal(9) }),
27+
z.object({ type: z.literal("transform"), val: z.literal(10) }),
28+
z.object({ type: z.literal("refine"), val: z.literal(11) }),
29+
z.object({ type: z.literal("superRefine"), val: z.literal(12) }),
2730
]);
2831

2932
expect(schema.parse({ type: "1", val: 1 })).toEqual({ type: "1", val: 1 });
@@ -125,9 +128,7 @@ test("wrong schema - missing discriminator", () => {
125128
]);
126129
throw new Error();
127130
} catch (e: any) {
128-
expect(e.message).toEqual(
129-
"The discriminator value could not be extracted from all the provided schemas"
130-
);
131+
expect(e.message.includes("could not be extracted")).toBe(true);
131132
}
132133
});
133134

@@ -139,9 +140,7 @@ test("wrong schema - duplicate discriminator values", () => {
139140
]);
140141
throw new Error();
141142
} catch (e: any) {
142-
expect(e.message).toEqual(
143-
"Some of the discriminator values are not unique"
144-
);
143+
expect(e.message.includes("has duplicate value")).toEqual(true);
145144
}
146145
});
147146

0 commit comments

Comments
 (0)