diff --git a/observable_operator.go b/observable_operator.go index 89bc6798..eaff262a 100644 --- a/observable_operator.go +++ b/observable_operator.go @@ -531,12 +531,22 @@ func (o *ObservableImpl) BufferWithTimeOrCount(timespan Duration, count int, opt checkBuffer := func() { mutex.Lock() - if len(buffer) != 0 { - if !Of(buffer).SendContext(ctx, next) { - mutex.Unlock() - return + switch { + case len(buffer) == 0: + case len(buffer) > count: + for { + if !Of(buffer[:count]).SendContext(ctx, next) { + break + } + buffer = buffer[count:] + if len(buffer) < count { + break + } + } + default: + if Of(buffer).SendContext(ctx, next) { + buffer = make([]interface{}, 0) } - buffer = make([]interface{}, 0) } mutex.Unlock() } @@ -1111,44 +1121,40 @@ func (op *firstOrDefaultOperator) gatherNext(_ context.Context, _ Item, _ chan<- // FlatMap transforms the items emitted by an Observable into Observables, then flatten the emissions from those into a single Observable. func (o *ObservableImpl) FlatMap(apply ItemToObservable, opts ...Option) Observable { - f := func(ctx context.Context, next chan Item, option Option, opts ...Option) { - defer close(next) - observe := o.Observe(opts...) - for { - select { - case <-ctx.Done(): - return - case item, ok := <-observe: - if !ok { - return - } - observe2 := apply(item).Observe(opts...) - loop2: - for { - select { - case <-ctx.Done(): - return - case item, ok := <-observe2: - if !ok { - break loop2 - } - if item.Error() { - item.SendContext(ctx, next) - if option.getErrorStrategy() == StopOnError { - return - } - } else { - if !item.SendContext(ctx, next) { - return - } - } - } - } - } + return observable(o, func() operator { + return &flatMapOperator{apply: apply} + }, false, true, opts...) +} + +type flatMapOperator struct { + apply ItemToObservable +} + +func (op *flatMapOperator) next(ctx context.Context, item Item, dst chan<- Item, operatorOptions operatorOptions) { + observe2 := op.apply(item).Observe() + for item := range observe2 { + if item.E != nil { + Error(item.E).SendContext(ctx, dst) + operatorOptions.stop() + return } + Of(item.V).SendContext(ctx, dst) } +} - return customObservableOperator(f, opts...) +func (op *flatMapOperator) err(ctx context.Context, item Item, dst chan<- Item, operatorOptions operatorOptions) { + op.next(ctx, item, dst, operatorOptions) +} + +func (op *flatMapOperator) end(_ context.Context, _ chan<- Item) { +} + +func (op *flatMapOperator) gatherNext(ctx context.Context, item Item, dst chan<- Item, _ operatorOptions) { + switch item.V.(type) { + case *flatMapOperator: + return + } + item.SendContext(ctx, dst) } // ForEach subscribes to the Observable and receives notifications for each element. diff --git a/observable_operator_test.go b/observable_operator_test.go index 182a7821..71c09148 100644 --- a/observable_operator_test.go +++ b/observable_operator_test.go @@ -308,19 +308,14 @@ func Test_Observable_BufferWithTimeOrCount(t *testing.T) { defer cancel() ch := make(chan Item, 1) obs := FromChannel(ch) - obs = obs.BufferWithTimeOrCount(WithDuration(30*time.Millisecond), 100) + obs = obs.BufferWithTimeOrCount(WithDuration(time.Second), 2) go func() { - for i := 0; i < 10; i++ { + for i := 0; i < 5; i++ { ch <- Of(i) } close(ch) }() - Assert(ctx, t, obs, CustomPredicate(func(items []interface{}) error { - if len(items) == 0 { - return errors.New("items should not be nil") - } - return nil - })) + Assert(ctx, t, obs, HasItems([]interface{}{0, 1}, []interface{}{2, 3}, []interface{}{4})) } func Test_Observable_Contain(t *testing.T) {