Skip to content

Commit 29ad2cb

Browse files
authored
Expose DataStreamWriter.ForeachBatch API (#549)
1 parent 1cd9cca commit 29ad2cb

File tree

33 files changed

+1731
-156
lines changed

33 files changed

+1731
-156
lines changed

src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Streaming/DataStreamWriterTests.cs

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System.Collections.Generic;
77
using System.IO;
88
using System.Linq;
9+
using System.Threading;
910
using Microsoft.Spark.E2ETest.Utils;
1011
using Microsoft.Spark.Sql;
1112
using Microsoft.Spark.Sql.Streaming;
@@ -67,6 +68,69 @@ public void TestSignaturesV2_3_X()
6768
Assert.IsType<DataStreamWriter>(dsw.Trigger(Trigger.Once()));
6869
}
6970

71+
[SkipIfSparkVersionIsLessThan(Versions.V2_4_0)]
72+
public void TestForeachBatch()
73+
{
74+
// Temporary folder to put our test stream input.
75+
using var srcTempDirectory = new TemporaryDirectory();
76+
// Temporary folder to write ForeachBatch output.
77+
using var dstTempDirectory = new TemporaryDirectory();
78+
79+
Func<Column, Column> outerUdf = Udf<int, int>(i => i + 100);
80+
81+
// id column: [0, 1, ..., 9]
82+
WriteCsv(0, 10, Path.Combine(srcTempDirectory.Path, "input1.csv"));
83+
84+
DataStreamWriter dsw = _spark
85+
.ReadStream()
86+
.Schema("id INT")
87+
.Csv(srcTempDirectory.Path)
88+
.WriteStream()
89+
.ForeachBatch((df, id) =>
90+
{
91+
Func<Column, Column> innerUdf = Udf<int, int>(i => i + 200);
92+
df.Select(outerUdf(innerUdf(Col("id"))))
93+
.Write()
94+
.Csv(Path.Combine(dstTempDirectory.Path, id.ToString()));
95+
});
96+
97+
StreamingQuery sq = dsw.Start();
98+
99+
// Process until all available data in the source has been processed and committed
100+
// to the ForeachBatch sink.
101+
sq.ProcessAllAvailable();
102+
103+
// Add new file to the source path. The spark stream will read any new files
104+
// added to the source path.
105+
// id column: [10, 11, ..., 19]
106+
WriteCsv(10, 10, Path.Combine(srcTempDirectory.Path, "input2.csv"));
107+
108+
// Process until all available data in the source has been processed and committed
109+
// to the ForeachBatch sink.
110+
sq.ProcessAllAvailable();
111+
sq.Stop();
112+
113+
// Verify folders in the destination path.
114+
string[] csvPaths =
115+
Directory.GetDirectories(dstTempDirectory.Path).OrderBy(s => s).ToArray();
116+
var expectedPaths = new string[]
117+
{
118+
Path.Combine(dstTempDirectory.Path, "0"),
119+
Path.Combine(dstTempDirectory.Path, "1"),
120+
};
121+
Assert.True(expectedPaths.SequenceEqual(csvPaths));
122+
123+
// Read the generated csv paths and verify contents.
124+
DataFrame df = _spark
125+
.Read()
126+
.Schema("id INT")
127+
.Csv(csvPaths[0], csvPaths[1])
128+
.Sort("id");
129+
130+
IEnumerable<int> actualIds = df.Collect().Select(r => r.GetAs<int>("id"));
131+
Assert.True(Enumerable.Range(300, 20).SequenceEqual(actualIds));
132+
}
133+
70134
[SkipIfSparkVersionIsLessThan(Versions.V2_4_0)]
71135
public void TestForeach()
72136
{
@@ -200,6 +264,15 @@ private void TestAndValidateForeach(
200264
foreachWriterOutputDF.Collect().Select(r => r.Values));
201265
}
202266

267+
private void WriteCsv(int start, int count, string path)
268+
{
269+
using var streamWriter = new StreamWriter(path);
270+
foreach (int i in Enumerable.Range(start, count))
271+
{
272+
streamWriter.WriteLine(i);
273+
}
274+
}
275+
203276
[Serializable]
204277
private class TestForeachWriter : IForeachWriter
205278
{
Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Concurrent;
7+
using System.Collections.Generic;
8+
using System.IO;
9+
using System.Linq;
10+
using System.Net;
11+
using System.Threading;
12+
using System.Threading.Tasks;
13+
using Microsoft.Spark.Interop.Ipc;
14+
using Microsoft.Spark.Network;
15+
using Moq;
16+
using Xunit;
17+
18+
namespace Microsoft.Spark.UnitTest
19+
{
20+
[Collection("Spark Unit Tests")]
21+
public class CallbackTests
22+
{
23+
private readonly Mock<IJvmBridge> _mockJvm;
24+
25+
public CallbackTests(SparkFixture fixture)
26+
{
27+
_mockJvm = fixture.MockJvm;
28+
}
29+
30+
[Fact]
31+
public async Task TestCallbackIds()
32+
{
33+
int numToRegister = 100;
34+
var callbackServer = new CallbackServer(_mockJvm.Object, false);
35+
var callbackHandler = new TestCallbackHandler();
36+
37+
var ids = new ConcurrentBag<int>();
38+
var tasks = new List<Task>();
39+
for (int i = 0; i < numToRegister; ++i)
40+
{
41+
tasks.Add(
42+
Task.Run(() => ids.Add(callbackServer.RegisterCallback(callbackHandler))));
43+
}
44+
45+
await Task.WhenAll(tasks);
46+
47+
IOrderedEnumerable<int> actualIds = ids.OrderBy(i => i);
48+
IEnumerable<int> expectedIds = Enumerable.Range(1, numToRegister);
49+
Assert.True(expectedIds.SequenceEqual(actualIds));
50+
}
51+
52+
[Fact]
53+
public void TestCallbackServer()
54+
{
55+
var callbackServer = new CallbackServer(_mockJvm.Object, false);
56+
var callbackHandler = new TestCallbackHandler();
57+
58+
callbackHandler.Id = callbackServer.RegisterCallback(callbackHandler);
59+
Assert.Equal(1, callbackHandler.Id);
60+
61+
using ISocketWrapper callbackSocket = SocketFactory.CreateSocket();
62+
callbackServer.Run(callbackSocket);
63+
64+
int connectionNumber = 10;
65+
for (int i = 0; i < connectionNumber; ++i)
66+
{
67+
var ipEndpoint = (IPEndPoint)callbackSocket.LocalEndPoint;
68+
ISocketWrapper clientSocket = SocketFactory.CreateSocket();
69+
clientSocket.Connect(ipEndpoint.Address, ipEndpoint.Port);
70+
71+
WriteAndReadTestData(clientSocket, callbackHandler, i, new CancellationToken());
72+
}
73+
74+
Assert.Equal(connectionNumber, callbackServer.CurrentNumConnections);
75+
76+
IOrderedEnumerable<int> actualValues = callbackHandler.Inputs.OrderBy(i => i);
77+
IEnumerable<int> expectedValues = Enumerable
78+
.Range(0, connectionNumber)
79+
.Select(i => callbackHandler.Apply(i))
80+
.OrderBy(i => i);
81+
Assert.True(expectedValues.SequenceEqual(actualValues));
82+
}
83+
84+
[Fact]
85+
public void TestCallbackHandlers()
86+
{
87+
var tokenSource = new CancellationTokenSource();
88+
var callbackHandlersDict = new ConcurrentDictionary<int, ICallbackHandler>();
89+
int inputToHandler = 1;
90+
{
91+
// Test CallbackConnection using a ICallbackHandler that runs
92+
// normally without error.
93+
var callbackHandler = new TestCallbackHandler
94+
{
95+
Id = 1
96+
};
97+
callbackHandlersDict[callbackHandler.Id] = callbackHandler;
98+
TestCallbackConnection(
99+
callbackHandlersDict,
100+
callbackHandler,
101+
inputToHandler,
102+
tokenSource.Token);
103+
Assert.Single(callbackHandler.Inputs);
104+
Assert.Equal(
105+
callbackHandler.Apply(inputToHandler),
106+
callbackHandler.Inputs.First());
107+
}
108+
{
109+
// Test CallbackConnection using a ICallbackHandler that
110+
// throws an exception.
111+
var callbackHandler = new ThrowsExceptionHandler
112+
{
113+
Id = 2
114+
};
115+
callbackHandlersDict[callbackHandler.Id] = callbackHandler;
116+
TestCallbackConnection(
117+
callbackHandlersDict,
118+
callbackHandler,
119+
inputToHandler,
120+
tokenSource.Token);
121+
Assert.Empty(callbackHandler.Inputs);
122+
}
123+
{
124+
// Test CallbackConnection when cancellation has been requested for the token.
125+
tokenSource.Cancel();
126+
var callbackHandler = new TestCallbackHandler
127+
{
128+
Id = 3
129+
};
130+
callbackHandlersDict[callbackHandler.Id] = callbackHandler;
131+
TestCallbackConnection(
132+
callbackHandlersDict,
133+
callbackHandler,
134+
inputToHandler,
135+
tokenSource.Token);
136+
Assert.Empty(callbackHandler.Inputs);
137+
}
138+
}
139+
140+
private void TestCallbackConnection(
141+
ConcurrentDictionary<int, ICallbackHandler> callbackHandlersDict,
142+
ITestCallbackHandler callbackHandler,
143+
int inputToHandler,
144+
CancellationToken token)
145+
{
146+
using ISocketWrapper serverListener = SocketFactory.CreateSocket();
147+
serverListener.Listen();
148+
149+
var ipEndpoint = (IPEndPoint)serverListener.LocalEndPoint;
150+
ISocketWrapper clientSocket = SocketFactory.CreateSocket();
151+
clientSocket.Connect(ipEndpoint.Address, ipEndpoint.Port);
152+
153+
var callbackConnection = new CallbackConnection(0, clientSocket, callbackHandlersDict);
154+
Task.Run(() => callbackConnection.Run(token));
155+
156+
using ISocketWrapper serverSocket = serverListener.Accept();
157+
WriteAndReadTestData(serverSocket, callbackHandler, inputToHandler, token);
158+
}
159+
160+
private void WriteAndReadTestData(
161+
ISocketWrapper socket,
162+
ITestCallbackHandler callbackHandler,
163+
int inputToHandler,
164+
CancellationToken token)
165+
{
166+
Stream inputStream = socket.InputStream;
167+
Stream outputStream = socket.OutputStream;
168+
169+
SerDe.Write(outputStream, (int)CallbackFlags.CALLBACK);
170+
SerDe.Write(outputStream, callbackHandler.Id);
171+
SerDe.Write(outputStream, sizeof(int));
172+
SerDe.Write(outputStream, inputToHandler);
173+
SerDe.Write(outputStream, (int)CallbackFlags.END_OF_STREAM);
174+
outputStream.Flush();
175+
176+
if (token.IsCancellationRequested)
177+
{
178+
Assert.Throws<IOException>(() => SerDe.ReadInt32(inputStream));
179+
}
180+
else
181+
{
182+
int callbackFlag = SerDe.ReadInt32(inputStream);
183+
if (callbackFlag == (int)CallbackFlags.DOTNET_EXCEPTION_THROWN)
184+
{
185+
string exceptionMessage = SerDe.ReadString(inputStream);
186+
Assert.False(string.IsNullOrEmpty(exceptionMessage));
187+
Assert.Contains(callbackHandler.ExceptionMessage, exceptionMessage);
188+
}
189+
else
190+
{
191+
Assert.Equal((int)CallbackFlags.END_OF_STREAM, callbackFlag);
192+
}
193+
}
194+
}
195+
196+
private class TestCallbackHandler : ICallbackHandler, ITestCallbackHandler
197+
{
198+
public void Run(Stream inputStream) => Inputs.Add(Apply(SerDe.ReadInt32(inputStream)));
199+
200+
public ConcurrentBag<int> Inputs { get; } = new ConcurrentBag<int>();
201+
202+
public int Id { get; set; }
203+
204+
public bool Throws { get; } = false;
205+
206+
public string ExceptionMessage => throw new NotImplementedException();
207+
208+
public int Apply(int i) => 10 * i;
209+
}
210+
211+
private class ThrowsExceptionHandler : ICallbackHandler, ITestCallbackHandler
212+
{
213+
public void Run(Stream inputStream) => throw new Exception(ExceptionMessage);
214+
215+
public ConcurrentBag<int> Inputs { get; } = new ConcurrentBag<int>();
216+
217+
public int Id { get; set; }
218+
219+
public bool Throws { get; } = true;
220+
221+
public string ExceptionMessage { get; } = "Dotnet Callback Handler Exception Message";
222+
223+
public int Apply(int i) => throw new NotImplementedException();
224+
}
225+
226+
private interface ITestCallbackHandler
227+
{
228+
ConcurrentBag<int> Inputs { get; }
229+
230+
int Id { get; set; }
231+
232+
bool Throws { get; }
233+
234+
string ExceptionMessage { get; }
235+
236+
int Apply(int i);
237+
}
238+
}
239+
}

src/csharp/Microsoft.Spark.UnitTest/SparkFixture.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6-
using System.IO;
76
using Microsoft.Spark.Interop;
87
using Microsoft.Spark.Interop.Ipc;
98
using Moq;
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.IO;
7+
using Xunit.Abstractions;
8+
9+
namespace Microsoft.Spark.UnitTest.TestUtils
10+
{
11+
// Tests can subclass this to get Console output to display when using
12+
// xUnit testing framework.
13+
// Workaround found at https://github.com/microsoft/vstest/issues/799
14+
public class XunitConsoleOutHelper : IDisposable
15+
{
16+
private readonly ITestOutputHelper _output;
17+
private readonly TextWriter _originalOut;
18+
private readonly TextWriter _textWriter;
19+
20+
public XunitConsoleOutHelper(ITestOutputHelper output)
21+
{
22+
_output = output;
23+
_originalOut = Console.Out;
24+
_textWriter = new StringWriter();
25+
Console.SetOut(_textWriter);
26+
}
27+
28+
public void Dispose()
29+
{
30+
_output.WriteLine(_textWriter.ToString());
31+
Console.SetOut(_originalOut);
32+
}
33+
}
34+
}

src/csharp/Microsoft.Spark.Worker.UnitTest/CommandExecutorTests.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using System;
65
using System.Collections;
76
using System.Collections.Generic;
87
using System.IO;

src/csharp/Microsoft.Spark.Worker.UnitTest/DaemonWorkerTests.cs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,10 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6-
using System.Collections;
76
using System.Collections.Generic;
8-
using System.IO;
97
using System.Net;
108
using System.Threading.Tasks;
11-
using Microsoft.Spark.Interop.Ipc;
129
using Microsoft.Spark.Network;
13-
using Razorvine.Pickle;
1410
using Xunit;
1511

1612
namespace Microsoft.Spark.Worker.UnitTest

0 commit comments

Comments
 (0)