Skip to content

Commit f30e0ca

Browse files
authored
feat(adk): add plan task tool (#736)
Change-Id: I61cd3709e78b5e1ef1fe169572913e7e35946d56
1 parent a74b18b commit f30e0ca

File tree

12 files changed

+2577
-0
lines changed

12 files changed

+2577
-0
lines changed
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/*
2+
* Copyright 2025 CloudWeGo Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package plantask
18+
19+
import (
20+
"context"
21+
"errors"
22+
"path/filepath"
23+
"strings"
24+
"sync"
25+
)
26+
27+
type inMemoryBackend struct {
28+
files map[string]string
29+
mu sync.RWMutex
30+
}
31+
32+
func newInMemoryBackend() *inMemoryBackend {
33+
return &inMemoryBackend{
34+
files: make(map[string]string),
35+
}
36+
}
37+
38+
func (b *inMemoryBackend) LsInfo(ctx context.Context, req *LsInfoRequest) ([]FileInfo, error) {
39+
b.mu.RLock()
40+
defer b.mu.RUnlock()
41+
42+
reqPath := strings.TrimSuffix(req.Path, "/")
43+
var result []FileInfo
44+
for path := range b.files {
45+
dir := filepath.Dir(path)
46+
if dir == reqPath {
47+
result = append(result, FileInfo{Path: path})
48+
}
49+
}
50+
return result, nil
51+
}
52+
53+
func (b *inMemoryBackend) Read(ctx context.Context, req *ReadRequest) (string, error) {
54+
b.mu.RLock()
55+
defer b.mu.RUnlock()
56+
57+
content, ok := b.files[req.FilePath]
58+
if !ok {
59+
return "", errors.New("file not found")
60+
}
61+
return content, nil
62+
}
63+
64+
func (b *inMemoryBackend) Write(ctx context.Context, req *WriteRequest) error {
65+
b.mu.Lock()
66+
defer b.mu.Unlock()
67+
68+
b.files[req.FilePath] = req.Content
69+
return nil
70+
}
71+
72+
func (b *inMemoryBackend) Delete(ctx context.Context, req *DeleteRequest) error {
73+
b.mu.Lock()
74+
defer b.mu.Unlock()
75+
76+
delete(b.files, req.FilePath)
77+
return nil
78+
}
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/*
2+
* Copyright 2025 CloudWeGo Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package plantask
18+
19+
import (
20+
"context"
21+
"fmt"
22+
"sync"
23+
24+
"github.com/cloudwego/eino/adk"
25+
)
26+
27+
// Config is the configuration for the tool search middleware.
28+
type Config struct {
29+
Backend Backend
30+
BaseDir string
31+
}
32+
33+
// New creates a new plantask middleware that provides task management tools for agents.
34+
// It adds TaskCreate, TaskGet, TaskUpdate, and TaskList tools to the agent's tool set,
35+
// allowing agents to create and manage structured task lists during coding sessions.
36+
func New(ctx context.Context, config *Config) (adk.ChatModelAgentMiddleware, error) {
37+
if config == nil {
38+
return nil, fmt.Errorf("config is required")
39+
}
40+
if config.Backend == nil {
41+
return nil, fmt.Errorf("backend is required")
42+
}
43+
if config.BaseDir == "" {
44+
return nil, fmt.Errorf("baseDir is required")
45+
}
46+
47+
return &middleware{backend: config.Backend, baseDir: config.BaseDir}, nil
48+
}
49+
50+
type middleware struct {
51+
adk.BaseChatModelAgentMiddleware
52+
backend Backend
53+
baseDir string
54+
}
55+
56+
func (m *middleware) BeforeAgent(ctx context.Context, runCtx *adk.ChatModelAgentContext) (context.Context, *adk.ChatModelAgentContext, error) {
57+
if runCtx == nil {
58+
return ctx, runCtx, nil
59+
}
60+
61+
nRunCtx := *runCtx
62+
lock := sync.Mutex{}
63+
nRunCtx.Tools = append(nRunCtx.Tools,
64+
newTaskCreateTool(m.backend, m.baseDir, &lock),
65+
newTaskGetTool(m.backend, m.baseDir, &lock),
66+
newTaskUpdateTool(m.backend, m.baseDir, &lock),
67+
newTaskListTool(m.backend, m.baseDir, &lock),
68+
)
69+
70+
return ctx, &nRunCtx, nil
71+
}
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
/*
2+
* Copyright 2025 CloudWeGo Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package plantask
18+
19+
import (
20+
"context"
21+
"sync"
22+
"testing"
23+
24+
"github.com/stretchr/testify/assert"
25+
26+
"github.com/cloudwego/eino/adk"
27+
"github.com/cloudwego/eino/components/tool"
28+
)
29+
30+
func TestNew(t *testing.T) {
31+
ctx := context.Background()
32+
33+
_, err := New(ctx, nil)
34+
assert.Error(t, err)
35+
assert.Contains(t, err.Error(), "config is required")
36+
37+
_, err = New(ctx, &Config{})
38+
assert.Error(t, err)
39+
assert.Contains(t, err.Error(), "backend is required")
40+
41+
_, err = New(ctx, &Config{Backend: newInMemoryBackend()})
42+
assert.Error(t, err)
43+
assert.Contains(t, err.Error(), "baseDir is required")
44+
45+
m, err := New(ctx, &Config{Backend: newInMemoryBackend(), BaseDir: "/tmp/tasks"})
46+
assert.NoError(t, err)
47+
assert.NotNil(t, m)
48+
}
49+
50+
func TestMiddlewareBeforeAgent(t *testing.T) {
51+
ctx := context.Background()
52+
backend := newInMemoryBackend()
53+
baseDir := "/tmp/tasks"
54+
55+
m, err := New(ctx, &Config{Backend: backend, BaseDir: baseDir})
56+
assert.NoError(t, err)
57+
58+
mw := m.(*middleware)
59+
60+
ctx, runCtx, err := mw.BeforeAgent(ctx, nil)
61+
assert.NoError(t, err)
62+
assert.Nil(t, runCtx)
63+
64+
runCtx = &adk.ChatModelAgentContext{
65+
Tools: []tool.BaseTool{},
66+
}
67+
ctx, newRunCtx, err := mw.BeforeAgent(ctx, runCtx)
68+
assert.NoError(t, err)
69+
assert.NotNil(t, newRunCtx)
70+
assert.Len(t, newRunCtx.Tools, 4)
71+
72+
toolNames := make([]string, 0, 4)
73+
for _, t := range newRunCtx.Tools {
74+
info, _ := t.Info(ctx)
75+
toolNames = append(toolNames, info.Name)
76+
}
77+
assert.Contains(t, toolNames, "TaskCreate")
78+
assert.Contains(t, toolNames, "TaskGet")
79+
assert.Contains(t, toolNames, "TaskUpdate")
80+
assert.Contains(t, toolNames, "TaskList")
81+
}
82+
83+
func TestIntegration(t *testing.T) {
84+
ctx := context.Background()
85+
backend := newInMemoryBackend()
86+
baseDir := "/tmp/tasks"
87+
lock := &sync.Mutex{}
88+
89+
createTool := newTaskCreateTool(backend, baseDir, lock)
90+
getTool := newTaskGetTool(backend, baseDir, lock)
91+
updateTool := newTaskUpdateTool(backend, baseDir, lock)
92+
listTool := newTaskListTool(backend, baseDir, lock)
93+
94+
result, err := createTool.InvokableRun(ctx, `{"subject": "Task 1", "description": "First task"}`)
95+
assert.NoError(t, err)
96+
assert.Contains(t, result, "Task #1")
97+
98+
result, err = createTool.InvokableRun(ctx, `{"subject": "Task 2", "description": "Second task"}`)
99+
assert.NoError(t, err)
100+
assert.Contains(t, result, "Task #2")
101+
102+
_, err = updateTool.InvokableRun(ctx, `{"taskId": "2", "addBlockedBy": ["1"]}`)
103+
assert.NoError(t, err)
104+
105+
result, err = listTool.InvokableRun(ctx, `{}`)
106+
assert.NoError(t, err)
107+
assert.Contains(t, result, "#1 [pending] Task 1")
108+
assert.Contains(t, result, "#2 [pending] Task 2")
109+
assert.Contains(t, result, "[blocked by #1]")
110+
111+
_, err = updateTool.InvokableRun(ctx, `{"taskId": "1", "status": "in_progress"}`)
112+
assert.NoError(t, err)
113+
114+
result, err = getTool.InvokableRun(ctx, `{"taskId": "1"}`)
115+
assert.NoError(t, err)
116+
assert.Contains(t, result, "Status: in_progress")
117+
118+
_, err = updateTool.InvokableRun(ctx, `{"taskId": "1", "status": "completed"}`)
119+
assert.NoError(t, err)
120+
121+
result, err = listTool.InvokableRun(ctx, `{}`)
122+
assert.NoError(t, err)
123+
assert.Contains(t, result, "#1 [completed] Task 1")
124+
}

adk/middlewares/plantask/task.go

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
/*
2+
* Copyright 2025 CloudWeGo Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package plantask
18+
19+
import (
20+
"context"
21+
"regexp"
22+
23+
"github.com/cloudwego/eino/adk/middlewares/filesystem"
24+
)
25+
26+
var validTaskIDRegex = regexp.MustCompile(`^\d+$`)
27+
28+
const highWatermarkFileName = ".highwatermark"
29+
30+
type task struct {
31+
ID string `json:"id"`
32+
Subject string `json:"subject"`
33+
Description string `json:"description"`
34+
Status string `json:"status"`
35+
Blocks []string `json:"blocks"`
36+
BlockedBy []string `json:"blockedBy"`
37+
ActiveForm string `json:"activeForm,omitempty"`
38+
Owner string `json:"owner,omitempty"`
39+
Metadata map[string]any `json:"metadata,omitempty"`
40+
}
41+
42+
type taskOut struct {
43+
Result string `json:"result"`
44+
}
45+
46+
const (
47+
taskStatusPending = "pending"
48+
taskStatusInProgress = "in_progress"
49+
taskStatusCompleted = "completed"
50+
taskStatusDeleted = "deleted"
51+
)
52+
53+
type FileInfo = filesystem.FileInfo
54+
type LsInfoRequest = filesystem.LsInfoRequest
55+
type ReadRequest = filesystem.ReadRequest
56+
type WriteRequest = filesystem.WriteRequest
57+
58+
type DeleteRequest struct {
59+
FilePath string
60+
}
61+
62+
// Backend defines the storage interface for task persistence.
63+
// Implementations can use local filesystem, remote storage, or any other storage backend.
64+
type Backend interface {
65+
// LsInfo lists file information in the specified directory.
66+
LsInfo(ctx context.Context, req *LsInfoRequest) ([]FileInfo, error)
67+
// Read reads the content of a file.
68+
Read(ctx context.Context, req *ReadRequest) (string, error)
69+
// Write writes content to a file, creating it if it doesn't exist.
70+
Write(ctx context.Context, req *WriteRequest) error
71+
// Delete removes a file from storage.
72+
Delete(ctx context.Context, req *DeleteRequest) error
73+
}
74+
75+
func isValidTaskID(taskID string) bool {
76+
return validTaskIDRegex.MatchString(taskID)
77+
}
78+
79+
func appendUnique(slice []string, items ...string) []string {
80+
seen := make(map[string]struct{}, len(slice))
81+
for _, s := range slice {
82+
seen[s] = struct{}{}
83+
}
84+
for _, item := range items {
85+
if _, exists := seen[item]; !exists {
86+
slice = append(slice, item)
87+
seen[item] = struct{}{}
88+
}
89+
}
90+
return slice
91+
}
92+
93+
func hasCyclicDependency(taskMap map[string]*task, blockerID, blockedID string) bool {
94+
if blockerID == blockedID {
95+
return true
96+
}
97+
98+
visited := make(map[string]bool)
99+
return canReach(taskMap, blockedID, blockerID, visited)
100+
}
101+
102+
func canReach(taskMap map[string]*task, fromID, toID string, visited map[string]bool) bool {
103+
if fromID == toID {
104+
return true
105+
}
106+
if visited[fromID] {
107+
return false
108+
}
109+
visited[fromID] = true
110+
111+
fromTask, exists := taskMap[fromID]
112+
if !exists {
113+
return false
114+
}
115+
116+
for _, blockedID := range fromTask.Blocks {
117+
if canReach(taskMap, blockedID, toID, visited) {
118+
return true
119+
}
120+
}
121+
122+
return false
123+
}

0 commit comments

Comments
 (0)