Skip to content

Commit 19ec359

Browse files
committed
feat: 添加客户端事件循环测试用例和改进连接注册逻辑
1 parent 2b6b1a6 commit 19ec359

File tree

2 files changed

+162
-38
lines changed

2 files changed

+162
-38
lines changed

client_event_loop.go

Lines changed: 75 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,82 +2,113 @@ package pulse
22

33
import (
44
"context"
5+
"fmt"
56
"net"
6-
"runtime"
77
"sync"
88
"sync/atomic"
99

1010
"github.com/antlabs/pulse/core"
1111
)
1212

1313
type ClientEventLoop struct {
14-
pollers []core.PollingApi
15-
conns []*core.SafeConns[Conn]
16-
callback Callback
17-
options Options
18-
next uint32 // 用于轮询分配
19-
ctx context.Context
14+
*MultiEventLoop
15+
next uint32 // 轮询计数器
16+
conns *core.SafeConns[Conn] // 每个事件循环的连接管理器
17+
callback Callback // 回调函数
18+
ctx context.Context // 上下文
2019
}
2120

2221
func NewClientEventLoop(ctx context.Context, opts ...func(*Options)) *ClientEventLoop {
23-
var options Options
24-
for _, opt := range opts {
25-
opt(&options)
26-
}
27-
n := runtime.NumCPU()
28-
pollers := make([]core.PollingApi, n)
29-
conns := make([]*core.SafeConns[Conn], n)
30-
for i := 0; i < n; i++ {
31-
pollers[i], _ = core.Create(core.TriggerTypeEdge)
32-
conns[i] = &core.SafeConns[Conn]{}
33-
conns[i].Init(core.GetMaxFd())
22+
multiLoop, err := NewMultiEventLoop(ctx, opts...)
23+
if err != nil {
24+
panic(err)
3425
}
26+
27+
// 初始化连接管理器
28+
conns := core.SafeConns[Conn]{}
29+
conns.Init(core.GetMaxFd())
30+
3531
return &ClientEventLoop{
36-
pollers: pollers,
37-
conns: conns,
38-
callback: options.callback,
39-
options: options,
40-
ctx: ctx,
32+
MultiEventLoop: multiLoop,
33+
conns: &conns,
34+
callback: multiLoop.options.callback,
35+
ctx: ctx,
4136
}
4237
}
4338

4439
func (loop *ClientEventLoop) RegisterConn(conn net.Conn) error {
40+
// 1. 获取文件描述符
4541
fd, err := core.GetFdFromConn(conn)
4642
if err != nil {
47-
return err
43+
return fmt.Errorf("failed to get fd from connection: %w", err)
4844
}
45+
46+
// 2. 关闭原始连接(因为我们要使用文件描述符)
4947
if err := conn.Close(); err != nil {
50-
return err
48+
return fmt.Errorf("failed to close original connection: %w", err)
5149
}
52-
idx := atomic.AddUint32(&loop.next, 1) % uint32(len(loop.pollers))
53-
c := newConn(fd, loop.conns[idx], nil, TaskTypeInEventLoop, loop.pollers[idx], 4096, false)
54-
loop.conns[idx].Add(fd, c)
55-
loop.callback.OnOpen(c)
56-
return loop.pollers[idx].AddRead(fd)
50+
51+
// 3. 选择事件循环(轮询分配)
52+
eventLoopIndex := loop.selectEventLoop()
53+
eventLoop := loop.MultiEventLoop.eventLoops[eventLoopIndex]
54+
55+
// 4. 创建新连接
56+
connInstance := loop.createConn(fd, loop.conns, eventLoop)
57+
58+
// 5. 添加到连接管理器
59+
loop.conns.Add(fd, connInstance)
60+
61+
// 6. 调用回调函数
62+
if loop.callback != nil {
63+
loop.callback.OnOpen(connInstance)
64+
}
65+
66+
// 7. 添加到事件循环
67+
return eventLoop.AddRead(fd)
68+
}
69+
70+
// selectEventLoop 选择事件循环(轮询分配)
71+
func (loop *ClientEventLoop) selectEventLoop() int {
72+
return int(atomic.AddUint32(&loop.next, 1) % uint32(len(loop.MultiEventLoop.eventLoops)))
73+
}
74+
75+
// createConn 创建连接实例
76+
func (loop *ClientEventLoop) createConn(fd int, safeConns *core.SafeConns[Conn], eventLoop core.PollingApi) *Conn {
77+
return newConn(
78+
fd,
79+
safeConns,
80+
nil, // 客户端模式不需要任务池
81+
TaskTypeInEventLoop,
82+
eventLoop,
83+
loop.MultiEventLoop.options.eventLoopReadBufferSize,
84+
loop.MultiEventLoop.options.flowBackPressureRemoveRead,
85+
)
5786
}
5887

5988
func (loop *ClientEventLoop) Serve() {
60-
n := len(loop.pollers)
89+
n := len(loop.MultiEventLoop.eventLoops)
6190
var wg sync.WaitGroup
6291
wg.Add(n)
6392
defer wg.Wait()
6493

6594
for i := 0; i < n; i++ {
6695
go func(idx int) {
6796
defer wg.Done()
68-
buf := make([]byte, loop.options.eventLoopReadBufferSize)
97+
buf := make([]byte, loop.MultiEventLoop.options.eventLoopReadBufferSize)
6998
for {
7099
select {
71100
case <-loop.ctx.Done():
72101
return
73102
default:
74103
}
75-
_, err := loop.pollers[idx].Poll(0, func(fd int, state core.State, pollErr error) {
76-
c := loop.conns[idx].GetUnsafe(fd)
104+
_, err := loop.MultiEventLoop.eventLoops[idx].Poll(0, func(fd int, state core.State, pollErr error) {
105+
c := loop.conns.GetUnsafe(fd)
77106
if pollErr != nil {
78107
if c != nil {
79108
c.Close()
80-
loop.callback.OnClose(c, pollErr)
109+
if loop.callback != nil {
110+
loop.callback.OnClose(c, pollErr)
111+
}
81112
}
82113
return
83114
}
@@ -90,15 +121,21 @@ func (loop *ClientEventLoop) Serve() {
90121
c.mu.Unlock()
91122
if err != nil {
92123
c.Close()
93-
loop.callback.OnClose(c, err)
124+
if loop.callback != nil {
125+
loop.callback.OnClose(c, err)
126+
}
94127
return
95128
}
96129
if n == 0 {
97130
c.Close()
98-
loop.callback.OnClose(c, nil)
131+
if loop.callback != nil {
132+
loop.callback.OnClose(c, nil)
133+
}
99134
return
100135
}
101-
loop.callback.OnData(c, buf[:n])
136+
if loop.callback != nil {
137+
loop.callback.OnData(c, buf[:n])
138+
}
102139
}
103140
})
104141
if err != nil {

client_event_loop_test.go

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
package pulse
2+
3+
import (
4+
"context"
5+
"net"
6+
"testing"
7+
)
8+
9+
func TestClientEventLoop_RegisterConn(t *testing.T) {
10+
// 创建客户端事件循环
11+
ctx, cancel := context.WithCancel(context.Background())
12+
defer cancel()
13+
14+
loop := NewClientEventLoop(ctx, WithCallback(&testCallback{}))
15+
16+
// 创建测试连接
17+
conn, err := net.Dial("tcp", "127.0.0.1:8080")
18+
if err != nil {
19+
// 如果连接失败,这是预期的,因为我们没有启动服务器
20+
// 我们只是想测试 RegisterConn 函数本身
21+
t.Logf("Expected connection failure: %v", err)
22+
return
23+
}
24+
25+
// 测试注册连接
26+
err = loop.RegisterConn(conn)
27+
if err != nil {
28+
t.Errorf("RegisterConn failed: %v", err)
29+
}
30+
}
31+
32+
type testCallback struct{}
33+
34+
func (tc *testCallback) OnOpen(c *Conn) {
35+
// 测试回调
36+
}
37+
38+
func (tc *testCallback) OnData(c *Conn, data []byte) {
39+
// 测试回调
40+
}
41+
42+
func (tc *testCallback) OnClose(c *Conn, err error) {
43+
// 测试回调
44+
}
45+
46+
func TestClientEventLoop_NewClientEventLoop(t *testing.T) {
47+
ctx, cancel := context.WithCancel(context.Background())
48+
defer cancel()
49+
50+
loop := NewClientEventLoop(ctx, WithCallback(&testCallback{}))
51+
52+
// 验证基本字段是否正确初始化
53+
if loop.MultiEventLoop == nil {
54+
t.Error("MultiEventLoop should not be nil")
55+
}
56+
57+
if loop.callback == nil {
58+
t.Error("callback should not be nil")
59+
}
60+
61+
if loop.ctx == nil {
62+
t.Error("ctx should not be nil")
63+
}
64+
65+
}
66+
67+
func TestClientEventLoop_SelectEventLoop(t *testing.T) {
68+
ctx, cancel := context.WithCancel(context.Background())
69+
defer cancel()
70+
71+
loop := NewClientEventLoop(ctx, WithCallback(&testCallback{}))
72+
73+
// 测试事件循环选择
74+
index1 := loop.selectEventLoop()
75+
index2 := loop.selectEventLoop()
76+
77+
if index1 < 0 || index1 >= len(loop.MultiEventLoop.eventLoops) {
78+
t.Errorf("Invalid event loop index: %d", index1)
79+
}
80+
81+
if index2 < 0 || index2 >= len(loop.MultiEventLoop.eventLoops) {
82+
t.Errorf("Invalid event loop index: %d", index2)
83+
}
84+
85+
// 验证轮询分配(虽然不能保证每次都不同,但应该合理分布)
86+
t.Logf("Selected event loops: %d, %d", index1, index2)
87+
}

0 commit comments

Comments
 (0)