diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/StopWordsRemoverTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/StopWordsRemoverTests.cs new file mode 100644 index 000000000..4bf614a44 --- /dev/null +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/StopWordsRemoverTests.cs @@ -0,0 +1,78 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.IO; +using Microsoft.Spark.E2ETest.Utils; +using Microsoft.Spark.ML.Feature; +using Microsoft.Spark.Sql; +using Microsoft.Spark.Sql.Types; +using Microsoft.Spark.UnitTest.TestUtils; +using Xunit; + +namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature +{ + [Collection("Spark E2E Tests")] + public class StopWordsRemoverTests : FeatureBaseTests + { + private readonly SparkSession _spark; + + public StopWordsRemoverTests(SparkFixture fixture) : base(fixture) + { + _spark = fixture.Spark; + } + + /// + /// Test signatures for APIs up to Spark 2.3.*. + /// + [Fact] + public void TestSignaturesV2_3_X() + { + string expectedUid = "theUidWithOutLocale"; + string expectedInputCol = "input_col"; + string expectedOutputCol = "output_col"; + bool expectedCaseSensitive = false; + var expectedStopWords = new string[] { "test1", "test2" }; + + DataFrame input = _spark.Sql("SELECT split('Hi I heard about Spark', ' ') as input_col"); + + StopWordsRemover stopWordsRemover = new StopWordsRemover(expectedUid) + .SetInputCol(expectedInputCol) + .SetOutputCol(expectedOutputCol) + .SetCaseSensitive(expectedCaseSensitive) + .SetStopWords(expectedStopWords); + + Assert.Equal(expectedUid, stopWordsRemover.Uid()); + Assert.Equal(expectedInputCol, stopWordsRemover.GetInputCol()); + Assert.Equal(expectedOutputCol, stopWordsRemover.GetOutputCol()); + Assert.Equal(expectedCaseSensitive, stopWordsRemover.GetCaseSensitive()); + Assert.Equal(expectedStopWords, stopWordsRemover.GetStopWords()); + Assert.NotEmpty(StopWordsRemover.LoadDefaultStopWords("english")); + + using (var tempDirectory = new TemporaryDirectory()) + { + string savePath = Path.Join(tempDirectory.Path, "StopWordsRemover"); + stopWordsRemover.Save(savePath); + + StopWordsRemover loadedStopWordsRemover = StopWordsRemover.Load(savePath); + Assert.Equal(stopWordsRemover.Uid(), loadedStopWordsRemover.Uid()); + } + + Assert.IsType(stopWordsRemover.TransformSchema(input.Schema())); + Assert.IsType(stopWordsRemover.Transform(input)); + + TestFeatureBase(stopWordsRemover, "inputCol", "input_col"); + } + + /// + /// Test signatures for APIs introduced in Spark 2.4.*. + /// + [SkipIfSparkVersionIsLessThan(Versions.V2_4_0)] + public void TestSignaturesV2_4_X() + { + string expectedLocale = "en_GB"; + StopWordsRemover stopWordsRemover = new StopWordsRemover().SetLocale(expectedLocale); + Assert.Equal(expectedLocale, stopWordsRemover.GetLocale()); + } + } +} diff --git a/src/csharp/Microsoft.Spark/ML/Feature/StopWordsRemover.cs b/src/csharp/Microsoft.Spark/ML/Feature/StopWordsRemover.cs new file mode 100644 index 000000000..19458ea2e --- /dev/null +++ b/src/csharp/Microsoft.Spark/ML/Feature/StopWordsRemover.cs @@ -0,0 +1,176 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using Microsoft.Spark.Interop; +using Microsoft.Spark.Interop.Ipc; +using Microsoft.Spark.Sql; +using Microsoft.Spark.Sql.Types; + +namespace Microsoft.Spark.ML.Feature +{ + /// + /// A feature transformer that filters out stop words from input. + /// + public class StopWordsRemover : FeatureBase, IJvmObjectReferenceProvider + { + private static readonly string s_stopWordsRemoverClassName = + "org.apache.spark.ml.feature.StopWordsRemover"; + + /// + /// Create a without any parameters. + /// + public StopWordsRemover() : base(s_stopWordsRemoverClassName) + { + } + + /// + /// Create a with a UID that is used to give the + /// a unique ID. + /// + /// An immutable unique ID for the object and its derivatives. + public StopWordsRemover(string uid) : base(s_stopWordsRemoverClassName, uid) + { + } + + internal StopWordsRemover(JvmObjectReference jvmObject) : base(jvmObject) + { + } + + JvmObjectReference IJvmObjectReferenceProvider.Reference => _jvmObject; + + /// + /// Sets the column that the should read from. + /// + /// The name of the column to use as the source + /// New object + public StopWordsRemover SetInputCol(string value) => + WrapAsStopWordsRemover(_jvmObject.Invoke("setInputCol", value)); + + /// + /// The will create a new column in the DataFrame, this is the + /// name of the new column. + /// + /// The name of the column to use as the target + /// New object + public StopWordsRemover SetOutputCol(string value) => + WrapAsStopWordsRemover(_jvmObject.Invoke("setOutputCol", value)); + + /// + /// Executes the and transforms the DataFrame to include the new + /// column. + /// + /// The DataFrame to transform + /// + /// New object with the source transformed + /// + public DataFrame Transform(DataFrame source) => + new DataFrame((JvmObjectReference)_jvmObject.Invoke("transform", source)); + + /// + /// Gets the column that the should read from. + /// + /// Input column name + public string GetInputCol() => (string)_jvmObject.Invoke("getInputCol"); + + /// + /// The will create a new column in the DataFrame, this is the + /// name of the new column. + /// + /// The output column name + public string GetOutputCol() => (string)_jvmObject.Invoke("getOutputCol"); + + /// + /// Sets locale for transform. + /// Refer java.util.locale.getavailablelocales() for all available locales. + /// + /// Locale to be used for transform + /// New object + [Since(Versions.V2_4_0)] + public StopWordsRemover SetLocale(string value) => + WrapAsStopWordsRemover(_jvmObject.Invoke("setLocale", value)); + + /// + /// Gets locale for transform + /// + /// The locale + [Since(Versions.V2_4_0)] + public string GetLocale() => (string)_jvmObject.Invoke("getLocale"); + + /// + /// Sets case sensitivity. + /// + /// true if case sensitive, false otherwise + /// New object + public StopWordsRemover SetCaseSensitive(bool value) => + WrapAsStopWordsRemover(_jvmObject.Invoke("setCaseSensitive", value)); + + /// + /// Gets case sensitivity. + /// + /// true if case sensitive, false otherwise + public bool GetCaseSensitive() => (bool)_jvmObject.Invoke("getCaseSensitive"); + + /// + /// Sets custom stop words. + /// + /// Custom stop words + /// New object + public StopWordsRemover SetStopWords(IEnumerable values) => + WrapAsStopWordsRemover(_jvmObject.Invoke("setStopWords", values)); + + /// + /// Gets the custom stop words. + /// + /// Custom stop words + public IEnumerable GetStopWords() => + (IEnumerable)_jvmObject.Invoke("getStopWords"); + + /// + /// Check transform validity and derive the output schema from the input schema. + /// + /// This checks for validity of interactions between parameters during Transform and + /// raises an exception if any parameter value is invalid. + /// + /// Typical implementation should first conduct verification on schema change and parameter + /// validity, including complex parameter interaction checks. + /// + /// + /// The of the which will be transformed. + /// + /// + /// The of the output schema that would have been derived from the + /// input schema, if Transform had been called. + /// + public StructType TransformSchema(StructType value) => + new StructType( + (JvmObjectReference)_jvmObject.Invoke( + "transformSchema", + DataType.FromJson(_jvmObject.Jvm, value.Json))); + + /// + /// Load default stop words of given language for + /// transform. + /// Supported languages: danish, dutch, english, finnish, french, german, + /// hungarian, italian, norwegian, portuguese, russian, spanish, swedish, turkish. + /// + /// Language + /// Default stop words for the given language + public static string[] LoadDefaultStopWords(string language) => + (string[])SparkEnvironment.JvmBridge.CallStaticJavaMethod( + s_stopWordsRemoverClassName, "loadDefaultStopWords", language); + + /// + /// Loads the that was previously saved using Save. + /// + /// The path the previous was saved to + /// New object, loaded from path + public static StopWordsRemover Load(string path) => + WrapAsStopWordsRemover( + SparkEnvironment.JvmBridge.CallStaticJavaMethod(s_stopWordsRemoverClassName, "load", path)); + + private static StopWordsRemover WrapAsStopWordsRemover(object obj) => + new StopWordsRemover((JvmObjectReference)obj); + } +}