Skip to content

Commit 2bd6c08

Browse files
author
Colin McDonnell
committed
Clean up discriminatedUnion
1 parent 87a3caa commit 2bd6c08

File tree

4 files changed

+134
-350
lines changed

4 files changed

+134
-350
lines changed

deno/lib/__tests__/discriminatedUnions.test.ts

Lines changed: 5 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,6 @@ test("valid", () => {
1515
});
1616

1717
test("valid - discriminator value of various primitive types", () => {
18-
interface lazy {
19-
type: "lazy typed";
20-
val: 17;
21-
}
22-
2318
const schema = z.discriminatedUnion("type", [
2419
z.object({ type: z.literal("1"), val: z.literal(1) }),
2520
z.object({ type: z.literal(1), val: z.literal(2) }),
@@ -30,43 +25,9 @@ test("valid - discriminator value of various primitive types", () => {
3025
z.object({ type: z.literal(null), val: z.literal(7) }),
3126
z.object({ type: z.literal("undefined"), val: z.literal(8) }),
3227
z.object({ type: z.literal(undefined), val: z.literal(9) }),
33-
z
34-
.object({ type: z.literal("transform"), val: z.literal(10) })
35-
.transform((val) => ({
36-
val: val.val,
37-
})),
38-
z
39-
.object({ type: z.literal("refine"), val: z.literal(11) })
40-
.refine(() => true),
41-
z
42-
.object({ type: z.literal("superRefine"), val: z.literal(12) })
43-
.superRefine(() => {}),
44-
z.lazy(() => z.object({ type: z.literal("lazy"), val: z.literal(13) })),
45-
z.lazy(() =>
46-
z
47-
.object({ type: z.literal("chained 1"), val: z.literal(14) })
48-
.transform((val) => ({
49-
val: val.val,
50-
}))
51-
),
52-
z
53-
.lazy(() =>
54-
z.object({ type: z.literal("chained 2"), val: z.literal(15) })
55-
)
56-
.transform((val) => ({
57-
val: val.val,
58-
})),
59-
z
60-
.lazy(() =>
61-
z.object({ type: z.literal("chained 3"), val: z.literal(16) })
62-
)
63-
.transform((val) => ({
64-
val: val.val,
65-
}))
66-
.refine(() => true),
67-
z.lazy(() =>
68-
z.object({ type: z.literal("lazy typed"), val: z.literal(17) })
69-
) as z.ZodType<lazy>,
28+
z.object({ type: z.literal("transform"), val: z.literal(10) }),
29+
z.object({ type: z.literal("refine"), val: z.literal(11) }),
30+
z.object({ type: z.literal("superRefine"), val: z.literal(12) }),
7031
]);
7132

7233
expect(schema.parse({ type: "1", val: 1 })).toEqual({ type: "1", val: 1 });
@@ -99,34 +60,6 @@ test("valid - discriminator value of various primitive types", () => {
9960
type: undefined,
10061
val: 9,
10162
});
102-
expect(schema.parse({ type: "transform", val: 10 })).toEqual({
103-
val: 10,
104-
});
105-
expect(schema.parse({ type: "refine", val: 11 })).toEqual({
106-
type: "refine",
107-
val: 11,
108-
});
109-
expect(schema.parse({ type: "superRefine", val: 12 })).toEqual({
110-
type: "superRefine",
111-
val: 12,
112-
});
113-
expect(schema.parse({ type: "lazy", val: 13 })).toEqual({
114-
type: "lazy",
115-
val: 13,
116-
});
117-
expect(schema.parse({ type: "chained 1", val: 14 })).toEqual({
118-
val: 14,
119-
});
120-
expect(schema.parse({ type: "chained 2", val: 15 })).toEqual({
121-
val: 15,
122-
});
123-
expect(schema.parse({ type: "chained 3", val: 16 })).toEqual({
124-
val: 16,
125-
});
126-
expect(schema.parse({ type: "lazy typed", val: 17 })).toEqual({
127-
type: "lazy typed",
128-
val: 17,
129-
});
13063
});
13164

13265
test("invalid - null", () => {
@@ -196,9 +129,7 @@ test("wrong schema - missing discriminator", () => {
196129
]);
197130
throw new Error();
198131
} catch (e: any) {
199-
expect(e.message).toEqual(
200-
"The discriminator value could not be extracted from all the provided schemas"
201-
);
132+
expect(e.message).toContain("could not be extracted");
202133
}
203134
});
204135

@@ -210,9 +141,7 @@ test("wrong schema - duplicate discriminator values", () => {
210141
]);
211142
throw new Error();
212143
} catch (e: any) {
213-
expect(e.message).toEqual(
214-
"Some of the discriminator values are not unique"
215-
);
144+
expect(e.message).toContain("has duplicate value");
216145
}
217146
});
218147

deno/lib/types.ts

Lines changed: 62 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import {
1818
ParsePath,
1919
ParseReturnType,
2020
ParseStatus,
21-
SyncParseReturnType
21+
SyncParseReturnType,
2222
} from "./helpers/parseUtil.ts";
2323
import { partialUtil } from "./helpers/partialUtil.ts";
2424
import { Primitive } from "./helpers/typeAliases.ts";
@@ -30,7 +30,7 @@ import {
3030
ZodError,
3131
ZodErrorMap,
3232
ZodIssue,
33-
ZodIssueCode
33+
ZodIssueCode,
3434
} from "./ZodError.ts";
3535

3636
///////////////////////////////////////
@@ -2137,80 +2137,46 @@ export class ZodUnion<T extends ZodUnionOptions> extends ZodType<
21372137
/////////////////////////////////////////////////////
21382138
/////////////////////////////////////////////////////
21392139

2140-
type ZodSourceType<T extends ZodTypeAny> = T extends ZodLazy<infer U>
2141-
? ZodSourceType<U>
2142-
: T extends ZodEffects<infer U>
2143-
? ZodSourceType<U>
2144-
: T;
2145-
2146-
type Prev = [never, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
2147-
2148-
type ZodOriginType<T extends ZodTypeAny, D extends Prev[number] = 9> = [
2149-
D
2150-
] extends [never]
2151-
? never
2152-
:
2153-
| T
2154-
| ZodEffects<ZodOriginType<T, Prev[D]>>
2155-
| ZodLazy<ZodOriginType<T, Prev[D]>>;
2156-
2157-
function getSourceType<T extends ZodTypeAny>(type: T): ZodSourceType<T> {
2158-
if (type._def.typeName === ZodFirstPartyTypeKind.ZodLazy) {
2159-
return getSourceType(
2160-
(type as unknown as ZodLazy<ZodTypeAny>).schema
2161-
) as ZodSourceType<T>;
2162-
} else if (type._def.typeName === ZodFirstPartyTypeKind.ZodEffects) {
2163-
return getSourceType(
2164-
(type as unknown as ZodEffects<ZodTypeAny>).sourceType()
2165-
) as ZodSourceType<T>;
2140+
const getDiscriminator = <T extends ZodTypeAny>(
2141+
type: T
2142+
): Primitive[] | null => {
2143+
if (type instanceof ZodLazy) {
2144+
return getDiscriminator(type.schema);
2145+
} else if (type instanceof ZodEffects) {
2146+
return getDiscriminator(type.innerType());
2147+
} else if (type instanceof ZodLiteral) {
2148+
return [type.value];
2149+
} else if (type instanceof ZodEnum) {
2150+
return type.options;
2151+
} else if (type instanceof ZodUndefined) {
2152+
return [undefined];
2153+
} else if (type instanceof ZodNull) {
2154+
return [null];
21662155
} else {
2167-
return type as ZodSourceType<T>;
2156+
return null;
21682157
}
2169-
}
2170-
2171-
type ZodDiscriminatedUnionOptionBase<
2172-
Discriminator extends string,
2173-
DiscriminatorValue extends Primitive
2174-
> = ZodObject<
2175-
{ [key in Discriminator]: ZodLiteral<DiscriminatorValue> } & ZodRawShape,
2176-
any,
2177-
any
2178-
>;
2179-
2180-
type ZodDiscriminatedUnionType<Discriminator extends string> = Record<
2181-
string,
2182-
any
2183-
> & {
2184-
[key in Discriminator]: Primitive;
21852158
};
21862159

2187-
export type ZodDiscriminatedUnionOption<
2188-
Discriminator extends string,
2189-
DiscriminatorValue extends Primitive
2190-
> =
2191-
| ZodOriginType<
2192-
ZodDiscriminatedUnionOptionBase<Discriminator, DiscriminatorValue>
2193-
>
2194-
| ZodType<any, any, ZodDiscriminatedUnionType<Discriminator>>;
2160+
export type ZodDiscriminatedUnionOption<Discriminator extends string> =
2161+
ZodObject<{ [key in Discriminator]: ZodTypeAny } & ZodRawShape, any, any>;
21952162

21962163
export interface ZodDiscriminatedUnionDef<
21972164
Discriminator extends string,
2198-
DiscriminatorValue extends Primitive,
2199-
Option extends ZodDiscriminatedUnionOption<Discriminator, DiscriminatorValue>
2165+
Options extends ZodDiscriminatedUnionOption<any>[]
22002166
> extends ZodTypeDef {
22012167
discriminator: Discriminator;
2202-
options: Map<DiscriminatorValue, Option>;
2168+
options: Options;
2169+
optionsMap: Map<Primitive, ZodDiscriminatedUnionOption<any>>;
22032170
typeName: ZodFirstPartyTypeKind.ZodDiscriminatedUnion;
22042171
}
22052172

22062173
export class ZodDiscriminatedUnion<
22072174
Discriminator extends string,
2208-
DiscriminatorValue extends Primitive,
2209-
Option extends ZodDiscriminatedUnionOption<Discriminator, DiscriminatorValue>
2175+
Options extends ZodDiscriminatedUnionOption<Discriminator>[]
22102176
> extends ZodType<
2211-
output<Option>,
2212-
ZodDiscriminatedUnionDef<Discriminator, DiscriminatorValue, Option>,
2213-
input<Option>
2177+
output<Options[number]>,
2178+
ZodDiscriminatedUnionDef<Discriminator, Options>,
2179+
input<Options[number]>
22142180
> {
22152181
_parse(input: ParseInput): ParseReturnType<this["_output"]> {
22162182
const { ctx } = this._processInputParams(input);
@@ -2225,13 +2191,13 @@ export class ZodDiscriminatedUnion<
22252191
}
22262192

22272193
const discriminator = this.discriminator;
2228-
const discriminatorValue: DiscriminatorValue = ctx.data[discriminator];
2229-
const option = this.options.get(discriminatorValue);
2194+
const discriminatorValue: string = ctx.data[discriminator];
2195+
const option = this.optionsMap.get(discriminatorValue);
22302196

22312197
if (!option) {
22322198
addIssueToContext(ctx, {
22332199
code: ZodIssueCode.invalid_union_discriminator,
2234-
options: this.validDiscriminatorValues,
2200+
options: Array.from(this.optionsMap.keys()),
22352201
path: [discriminator],
22362202
});
22372203
return INVALID;
@@ -2242,28 +2208,28 @@ export class ZodDiscriminatedUnion<
22422208
data: ctx.data,
22432209
path: ctx.path,
22442210
parent: ctx,
2245-
});
2211+
}) as any;
22462212
} else {
22472213
return option._parseSync({
22482214
data: ctx.data,
22492215
path: ctx.path,
22502216
parent: ctx,
2251-
});
2217+
}) as any;
22522218
}
22532219
}
22542220

22552221
get discriminator() {
22562222
return this._def.discriminator;
22572223
}
22582224

2259-
get validDiscriminatorValues() {
2260-
return Array.from(this.options.keys());
2261-
}
2262-
22632225
get options() {
22642226
return this._def.options;
22652227
}
22662228

2229+
get optionsMap() {
2230+
return this._def.optionsMap;
2231+
}
2232+
22672233
/**
22682234
* The constructor of the discriminated union schema. Its behaviour is very similar to that of the normal z.union() constructor.
22692235
* However, it only allows a union of objects, all of which need to share a discriminator property. This property must
@@ -2274,48 +2240,45 @@ export class ZodDiscriminatedUnion<
22742240
*/
22752241
static create<
22762242
Discriminator extends string,
2277-
DiscriminatorValue extends Primitive,
22782243
Types extends [
2279-
ZodDiscriminatedUnionOption<Discriminator, DiscriminatorValue>,
2280-
ZodDiscriminatedUnionOption<Discriminator, DiscriminatorValue>,
2281-
...ZodDiscriminatedUnionOption<Discriminator, DiscriminatorValue>[]
2244+
ZodDiscriminatedUnionOption<Discriminator>,
2245+
...ZodDiscriminatedUnionOption<Discriminator>[]
22822246
]
22832247
>(
22842248
discriminator: Discriminator,
2285-
types: Types,
2249+
options: Types,
22862250
params?: RawCreateParams
2287-
): ZodDiscriminatedUnion<Discriminator, DiscriminatorValue, Types[number]> {
2251+
): ZodDiscriminatedUnion<Discriminator, Types> {
22882252
// Get all the valid discriminator values
2289-
const options: Map<DiscriminatorValue, Types[number]> = new Map();
2290-
2291-
try {
2292-
types.forEach((type) => {
2293-
const discriminatorValue = getSourceType(
2294-
type as ZodOriginType<
2295-
ZodDiscriminatedUnionOptionBase<Discriminator, DiscriminatorValue>
2296-
>
2297-
).shape[discriminator].value;
2298-
options.set(discriminatorValue, type);
2299-
});
2300-
} catch (e) {
2301-
throw new Error(
2302-
"The discriminator value could not be extracted from all the provided schemas"
2303-
);
2304-
}
2305-
2306-
// Assert that all the discriminator values are unique
2307-
if (options.size !== types.length) {
2308-
throw new Error("Some of the discriminator values are not unique");
2253+
const optionsMap: Map<Primitive, Types[number]> = new Map();
2254+
2255+
// try {
2256+
for (const type of options) {
2257+
const discriminatorValues = getDiscriminator(type.shape[discriminator]);
2258+
if (!discriminatorValues) {
2259+
throw new Error(
2260+
`A discriminator value for key \`${discriminator}\`could not be extracted from all schema options`
2261+
);
2262+
}
2263+
for (const value of discriminatorValues) {
2264+
if (optionsMap.has(value)) {
2265+
throw new Error(
2266+
`Discriminator property ${discriminator} has duplicate value ${value}`
2267+
);
2268+
}
2269+
optionsMap.set(value, type);
2270+
}
23092271
}
23102272

23112273
return new ZodDiscriminatedUnion<
23122274
Discriminator,
2313-
DiscriminatorValue,
2314-
Types[number]
2275+
// DiscriminatorValue,
2276+
Types
23152277
>({
23162278
typeName: ZodFirstPartyTypeKind.ZodDiscriminatedUnion,
23172279
discriminator,
23182280
options,
2281+
optionsMap,
23192282
...processCreateParams(params),
23202283
});
23212284
}
@@ -3911,7 +3874,7 @@ export type ZodFirstPartySchemaTypes =
39113874
| ZodArray<any, any>
39123875
| ZodObject<any, any, any, any, any>
39133876
| ZodUnion<any>
3914-
| ZodDiscriminatedUnion<any, any, any>
3877+
| ZodDiscriminatedUnion<any, any>
39153878
| ZodIntersection<any, any>
39163879
| ZodTuple<any, any>
39173880
| ZodRecord<any, any>

0 commit comments

Comments
 (0)