Skip to content

Commit 3113915

Browse files
feat: add dag branch (#75)
1 parent 5e5dc1f commit 3113915

File tree

6 files changed

+350
-34
lines changed

6 files changed

+350
-34
lines changed

compose/dag.go

Lines changed: 70 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,46 +22,47 @@ import (
2222
)
2323

2424
func dagChannelBuilder(dependencies []string) channel {
25+
waitList := make(map[string]bool, len(dependencies))
26+
for _, dep := range dependencies {
27+
waitList[dep] = false
28+
}
2529
return &dagChannel{
2630
values: make(map[string]any),
27-
waitList: dependencies,
31+
waitList: waitList,
2832
}
2933
}
3034

35+
type waitPred struct {
36+
key string
37+
skipped bool
38+
}
39+
3140
type dagChannel struct {
3241
values map[string]any
33-
waitList []string
42+
waitList map[string]bool
3443
value any
44+
skipped bool
3545
}
3646

3747
func (ch *dagChannel) update(ctx context.Context, ins map[string]any) error {
48+
if ch.skipped {
49+
return nil
50+
}
51+
3852
for k, v := range ins {
3953
if _, ok := ch.values[k]; ok {
4054
return fmt.Errorf("dag channel update, calculate node repeatedly: %s", k)
4155
}
4256
ch.values[k] = v
4357
}
4458

45-
for i := range ch.waitList {
46-
if _, ok := ch.values[ch.waitList[i]]; !ok {
47-
return nil
48-
}
49-
}
50-
51-
if len(ch.waitList) == 1 {
52-
ch.value = ch.values[ch.waitList[0]]
53-
return nil
54-
}
55-
v, err := mergeValues(mapToList(ch.values))
56-
if err != nil {
57-
return fmt.Errorf("dag channel merge value fail: %w", err)
58-
}
59-
ch.value = v
60-
61-
return nil
59+
return ch.tryUpdateValue()
6260
}
6361

6462
func (ch *dagChannel) get(ctx context.Context) (any, error) {
63+
if ch.skipped {
64+
return nil, fmt.Errorf("dag channel has been skipped")
65+
}
6566
if ch.value == nil {
6667
return nil, fmt.Errorf("dag channel not ready, value is nil")
6768
}
@@ -71,5 +72,55 @@ func (ch *dagChannel) get(ctx context.Context) (any, error) {
7172
}
7273

7374
func (ch *dagChannel) ready(ctx context.Context) bool {
75+
if ch.skipped {
76+
return false
77+
}
7478
return ch.value != nil
7579
}
80+
81+
func (ch *dagChannel) reportSkip(keys []string) (bool, error) {
82+
for _, k := range keys {
83+
if _, ok := ch.waitList[k]; ok {
84+
ch.waitList[k] = true
85+
}
86+
}
87+
88+
allSkipped := true
89+
for _, skipped := range ch.waitList {
90+
if !skipped {
91+
allSkipped = false
92+
break
93+
}
94+
}
95+
ch.skipped = allSkipped
96+
97+
var err error
98+
if !allSkipped {
99+
err = ch.tryUpdateValue()
100+
}
101+
102+
return allSkipped, err
103+
}
104+
105+
func (ch *dagChannel) tryUpdateValue() error {
106+
var validList []string
107+
for key, skipped := range ch.waitList {
108+
if _, ok := ch.values[key]; !ok && !skipped {
109+
return nil
110+
} else if !skipped {
111+
validList = append(validList, key)
112+
}
113+
}
114+
115+
if len(validList) == 1 {
116+
ch.value = ch.values[validList[0]]
117+
return nil
118+
}
119+
v, err := mergeValues(mapToList(ch.values))
120+
if err != nil {
121+
return err
122+
}
123+
ch.value = v
124+
return nil
125+
126+
}

compose/graph.go

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -723,7 +723,7 @@ func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composa
723723
if isWorkflow(g.cmp) {
724724
eager = true
725725
}
726-
if !eager && opt != nil && opt.getStateEnabled {
726+
if !isWorkflow(g.cmp) && opt != nil && opt.getStateEnabled {
727727
return nil, fmt.Errorf("shouldn't set WithGetStateEnable outside of the Workflow")
728728
}
729729
forbidGetState := true
@@ -745,11 +745,6 @@ func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composa
745745
}
746746
}
747747

748-
// dag doesn't support branch
749-
if runType == runTypeDAG && len(g.branches) > 0 {
750-
return nil, fmt.Errorf("dag doesn't support branch for now")
751-
}
752-
753748
for key := range g.fieldMappingRecords {
754749
// not allowed to map multiple fields to the same field
755750
toMap := make(map[string]bool)
@@ -806,6 +801,17 @@ func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composa
806801

807802
}
808803
}
804+
for start, branches := range g.branches {
805+
for _, branch := range branches {
806+
for end := range branch.endNodes {
807+
if _, ok := invertedEdges[end]; !ok {
808+
invertedEdges[end] = []string{start}
809+
} else {
810+
invertedEdges[end] = append(invertedEdges[end], start)
811+
}
812+
}
813+
}
814+
}
809815

810816
inputChannels := &chanCall{
811817
writeTo: g.edges[START],
@@ -833,6 +839,12 @@ func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composa
833839
edgeHandlerManager: &edgeHandlerManager{h: g.handlerOnEdges},
834840
}
835841

842+
successors := make(map[string][]string)
843+
for ch := range r.chanSubscribeTo {
844+
successors[ch] = getSuccessors(r.chanSubscribeTo[ch])
845+
}
846+
r.successors = successors
847+
836848
if g.stateGenerator != nil {
837849
r.runCtx = func(ctx context.Context) context.Context {
838850
return context.WithValue(ctx, stateKey{}, &internalState{
@@ -868,6 +880,17 @@ func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composa
868880
return r.toComposableRunnable(), nil
869881
}
870882

883+
func getSuccessors(c *chanCall) []string {
884+
ret := make([]string, len(c.writeTo))
885+
copy(ret, c.writeTo)
886+
for _, branch := range c.writeToBranches {
887+
for node := range branch.endNodes {
888+
ret = append(ret, node)
889+
}
890+
}
891+
return ret
892+
}
893+
871894
type subGraphCompileCallback struct {
872895
closure func(ctx context.Context, info *GraphInfo)
873896
}
@@ -1043,6 +1066,14 @@ func validateDAG(chanSubscribeTo map[string]*chanCall, invertedEdges map[string]
10431066
}
10441067
m[subNode]--
10451068
}
1069+
for _, subBranch := range chanSubscribeTo[node].writeToBranches {
1070+
for subNode := range subBranch.endNodes {
1071+
if subNode == END {
1072+
continue
1073+
}
1074+
m[subNode]--
1075+
}
1076+
}
10461077
m[node] = -1
10471078
}
10481079
}

compose/graph_manager.go

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ type channel interface {
3030
update(context.Context, map[string]any) error
3131
get(context.Context) (any, error)
3232
ready(context.Context) bool
33+
reportSkip([]string) (bool, error)
3334
}
3435

3536
type edgeHandlerManager struct {
@@ -108,8 +109,9 @@ func (p *preBranchHandlerManager) handle(nodeKey string, idx int, value any, isS
108109
}
109110

110111
type channelManager struct {
111-
isStream bool
112-
channels map[string]channel
112+
isStream bool
113+
successors map[string][]string
114+
channels map[string]channel
113115

114116
edgeHandlerManager *edgeHandlerManager
115117
preNodeHandlerManager *preNodeHandlerManager
@@ -163,6 +165,37 @@ func (c *channelManager) updateAndGet(ctx context.Context, values map[string]map
163165
return c.getFromReadyChannels(ctx, isStream)
164166
}
165167

168+
func (c *channelManager) reportBranch(from string, skippedNodes []string) error {
169+
var nKeys []string
170+
for _, node := range skippedNodes {
171+
skipped, err := c.channels[node].reportSkip([]string{from})
172+
if err != nil {
173+
return err
174+
}
175+
if skipped {
176+
nKeys = append(nKeys, node)
177+
}
178+
}
179+
180+
for i := 0; i < len(nKeys); i++ {
181+
key := nKeys[i]
182+
if _, ok := c.successors[key]; !ok {
183+
return fmt.Errorf("unknown node: %s", key)
184+
}
185+
for _, successor := range c.successors[key] {
186+
skipped, err := c.channels[successor].reportSkip([]string{key})
187+
if err != nil {
188+
return err
189+
}
190+
if skipped {
191+
nKeys = append(nKeys, successor)
192+
}
193+
// todo: detect if end node has been skipped?
194+
}
195+
}
196+
return nil
197+
}
198+
166199
type task struct {
167200
ctx context.Context
168201
nodeKey string

compose/graph_run.go

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ type chanBuilder func(d []string) channel
6464
type runner struct {
6565
chanSubscribeTo map[string]*chanCall
6666
invertedEdges map[string][]string
67+
successors map[string][]string
6768
inputChannels *chanCall
6869

6970
chanBuilder chanBuilder // could be nil
@@ -176,7 +177,7 @@ func (r *runner) run(ctx context.Context, isStream bool, input any, opts ...Opti
176177
}
177178

178179
// 1. Calculate active edges and resolve their values.
179-
writeChannelValues, err := r.resolveCompletedTasks(ctx, completedTasks, isStream)
180+
writeChannelValues, err := r.resolveCompletedTasks(ctx, completedTasks, isStream, cm)
180181
if err != nil {
181182
return nil, err
182183
}
@@ -233,13 +234,13 @@ func (r *runner) createTasks(ctx context.Context, nodeMap map[string]any, optMap
233234
return nextTasks, nil
234235
}
235236

236-
func (r *runner) resolveCompletedTasks(ctx context.Context, completedTasks []*task, isStream bool) (map[string]map[string]any, error) {
237+
func (r *runner) resolveCompletedTasks(ctx context.Context, completedTasks []*task, isStream bool, cm *channelManager) (map[string]map[string]any, error) {
237238
writeChannelValues := make(map[string]map[string]any)
238239
for _, t := range completedTasks {
239240
// update channel & new_next_tasks
240241
vs := copyItem(t.output, len(t.call.writeTo)+len(t.call.writeToBranches)*2)
241242
nextNodeKeys, err := r.calculateNext(ctx, t.nodeKey, t.call,
242-
vs[len(t.call.writeTo)+len(t.call.writeToBranches):], isStream)
243+
vs[len(t.call.writeTo)+len(t.call.writeToBranches):], isStream, cm)
243244
if err != nil {
244245
return nil, fmt.Errorf("calculate next step fail, node: %s, error: %w", t.nodeKey, err)
245246
}
@@ -253,7 +254,7 @@ func (r *runner) resolveCompletedTasks(ctx context.Context, completedTasks []*ta
253254
return writeChannelValues, nil
254255
}
255256

256-
func (r *runner) calculateNext(ctx context.Context, curNodeKey string, startChan *chanCall, input []any, isStream bool) ([]string, error) {
257+
func (r *runner) calculateNext(ctx context.Context, curNodeKey string, startChan *chanCall, input []any, isStream bool, cm *channelManager) ([]string, error) {
257258
if len(input) < len(startChan.writeToBranches) {
258259
// unreachable
259260
return nil, errors.New("calculate next input length is shorter than branches")
@@ -266,6 +267,7 @@ func (r *runner) calculateNext(ctx context.Context, curNodeKey string, startChan
266267
ret := make([]string, 0, len(startChan.writeTo))
267268
ret = append(ret, startChan.writeTo...)
268269

270+
skippedNodes := make(map[string]struct{})
269271
for i, branch := range startChan.writeToBranches {
270272
// check branch input type if needed
271273
var err error
@@ -305,8 +307,33 @@ func (r *runner) calculateNext(ctx context.Context, curNodeKey string, startChan
305307
return nil, errors.New("invoke branch result isn't string")
306308
}
307309
}
310+
311+
for node := range branch.endNodes {
312+
if node != w {
313+
skippedNodes[node] = struct{}{}
314+
}
315+
}
316+
308317
ret = append(ret, w)
309318
}
319+
320+
// When a node has multiple branches,
321+
// there may be a situation where a succeeding node is selected by some branches and discarded by the other branches,
322+
// in which case the succeeding node should not be skipped.
323+
var skippedNodeList []string
324+
for _, selected := range ret {
325+
if _, ok := skippedNodes[selected]; ok {
326+
delete(skippedNodes, selected)
327+
}
328+
}
329+
for skipped := range skippedNodes {
330+
skippedNodeList = append(skippedNodeList, skipped)
331+
}
332+
333+
err := cm.reportBranch(curNodeKey, skippedNodeList)
334+
if err != nil {
335+
return nil, err
336+
}
310337
return ret, nil
311338
}
312339

@@ -337,8 +364,9 @@ func (r *runner) initChannelManager(isStream bool) *channelManager {
337364
chs[END] = builder(r.invertedEdges[END])
338365

339366
return &channelManager{
340-
isStream: isStream,
341-
channels: chs,
367+
isStream: isStream,
368+
channels: chs,
369+
successors: r.successors,
342370

343371
edgeHandlerManager: r.edgeHandlerManager,
344372
preNodeHandlerManager: r.preNodeHandlerManager,

0 commit comments

Comments
 (0)