Skip to content

Commit 0430d86

Browse files
committed
add multiple columns interface
1 parent f98a45f commit 0430d86

File tree

5 files changed

+169
-1
lines changed

5 files changed

+169
-1
lines changed

src/main/scala/com/johnsnowlabs/nlp/HasInputAnnotationCols.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ trait HasInputAnnotationCols extends Params {
3434
new StringArrayParam(this, "inputCols", "the input annotation columns")
3535

3636
/** Overrides required annotators column if different than default */
37-
final def setInputCols(value: Array[String]): this.type = {
37+
def setInputCols(value: Array[String]): this.type = {
3838
require(
3939
value.length == inputAnnotatorTypes.length,
4040
s"setInputCols in ${this.uid} expecting ${inputAnnotatorTypes.length} columns. " +
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Copyright 2017-2021 John Snow Labs
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.johnsnowlabs.nlp
18+
19+
import com.johnsnowlabs.nlp.AnnotatorType.CHUNK
20+
import org.apache.spark.ml.param.{Params, StringArrayParam}
21+
import org.apache.spark.sql.types.StructType
22+
23+
/** Trait for annotators that accept an arbitrary number of input columns,
  * all of which must share one single annotator type.
  *
  * Unlike [[HasInputAnnotationCols]], which validates that the number of input
  * columns matches a fixed-length `inputAnnotatorTypes`, implementors declare a
  * single `inputAnnotatorType` and may receive any number of columns of that type.
  */
trait HasMultipleInputAnnotationCols extends HasInputAnnotationCols {

  /** The single annotator type required for every input column. */
  val inputAnnotatorType: String

  /** One entry per configured input column, each of the shared type.
    * Lazy because it can only be computed after `inputCols` has been set.
    */
  lazy override val inputAnnotatorTypes: Array[String] =
    getInputCols.map(_ => inputAnnotatorType)

  /** Sets the input columns. Deliberately skips the fixed-arity `require`
    * of the parent implementation so any number of columns is accepted.
    */
  override def setInputCols(value: Array[String]): this.type = {
    set(inputCols, value)
  }
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package com.johnsnowlabs.nlp.annotators.multipleannotations
2+
3+
import com.johnsnowlabs.nlp.{AnnotatorApproach, HasMultipleInputAnnotationCols}
4+
import com.johnsnowlabs.nlp.AnnotatorType.{CHUNK, DOCUMENT}
5+
import org.apache.spark.ml.PipelineModel
6+
import org.apache.spark.ml.util.Identifiable
7+
import org.apache.spark.sql.Dataset
8+
9+
10+
/** Example [[AnnotatorApproach]] demonstrating the multiple-input-columns
  * interface: it accepts any number of DOCUMENT columns and produces a
  * [[MultiColumnsModel]] that merges them into one output column.
  */
class MultiColumnApproach(override val uid: String)
  extends AnnotatorApproach[MultiColumnsModel] with HasMultipleInputAnnotationCols {

  def this() = this(Identifiable.randomUID("multiplecolums"))

  override val description: String = "Example multiple columns"

  /** Output annotator type: DOCUMENT
    *
    * @group anno
    */
  override val outputAnnotatorType: AnnotatorType = DOCUMENT

  /** Annotator type shared by every input column: DOCUMENT
    *
    * @group anno
    */
  override val inputAnnotatorType: AnnotatorType = DOCUMENT

  /** No training is required: the dataset is ignored and a configured
    * [[MultiColumnsModel]] is returned directly.
    */
  override def train(dataset: Dataset[_], recursivePipeline: Option[PipelineModel]): MultiColumnsModel = {
    new MultiColumnsModel().setInputCols($(inputCols)).setOutputCol($(outputCol))
  }
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package com.johnsnowlabs.nlp.annotators.multipleannotations
2+
3+
import com.johnsnowlabs.nlp.AnnotatorType.{CHUNK, DOCUMENT}
4+
import com.johnsnowlabs.nlp._
5+
import org.apache.spark.ml.util.Identifiable
6+
7+
8+
/** Model that gathers annotations from an arbitrary number of DOCUMENT input
  * columns and emits them together in a single output column.
  */
class MultiColumnsModel(override val uid: String) extends AnnotatorModel[MultiColumnsModel]
  with HasMultipleInputAnnotationCols
  with HasSimpleAnnotate[MultiColumnsModel] {

  def this() = this(Identifiable.randomUID("MERGE"))

  /** Output annotator type: DOCUMENT
    *
    * @group anno
    */
  override val outputAnnotatorType: AnnotatorType = DOCUMENT

  /** Annotator type required for every input column: DOCUMENT
    *
    * @group anno
    */
  override val inputAnnotatorType: String = DOCUMENT

  /** Combines the annotations collected from all input columns into the
    * single output column. The annotations are passed through unchanged.
    *
    * @param annotations annotations gathered from every configured input column
    * @return the same annotations, emitted under the output column
    */
  override def annotate(annotations: Seq[Annotation]): Seq[Annotation] = {
    annotations
  }
}
42+
43+
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package com.johnsnowlabs.nlp.annotators.multipleannotations
2+
3+
import com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector
4+
import com.johnsnowlabs.nlp.{ContentProvider, DocumentAssembler, LightPipeline, RecursivePipeline, SparkAccessor}
5+
import com.johnsnowlabs.nlp.annotators.{TextMatcher, Tokenizer}
6+
import com.johnsnowlabs.nlp.util.io.ReadAs
7+
import com.johnsnowlabs.tags.FastTest
8+
import org.apache.spark.ml.Pipeline
9+
import org.scalatest.flatspec.AnyFlatSpec
10+
11+
/** Verifies that [[MultiColumnApproach]] accepts several DOCUMENT input
  * columns and merges them into one output column, both through a regular
  * Pipeline transform and through a LightPipeline.
  */
class MultiannotationsSpec extends AnyFlatSpec {
  import SparkAccessor.spark.implicits._

  "A multiple annotator approach" should "transform data" taggedAs FastTest in {
    val data = SparkAccessor.spark.sparkContext.parallelize(Seq("Example text")).toDS().toDF("text")

    val documentAssembler = new DocumentAssembler()
      .setInputCol("text")
      .setOutputCol("document")

    val documentAssembler2 = new DocumentAssembler()
      .setInputCol("text")
      .setOutputCol("document2")

    val documentAssembler3 = new DocumentAssembler()
      .setInputCol("text")
      .setOutputCol("document3")

    val multipleColumns = new MultiColumnApproach()
      .setInputCols("document", "document2", "document3")
      .setOutputCol("merge")

    val pipeline = new Pipeline()
      .setStages(Array(
        documentAssembler,
        documentAssembler2,
        documentAssembler3,
        multipleColumns
      ))

    val pipelineModel = pipeline.fit(data)

    // The merged column must actually appear in the transformed schema.
    val transformed = pipelineModel.transform(data)
    assert(transformed.columns.contains("merge"), "expected output column 'merge' in transformed DataFrame")

    // LightPipeline should expose the merged column with non-empty content.
    val result = new LightPipeline(pipelineModel).annotate("My document")
    assert(result.contains("merge"), "expected key 'merge' in LightPipeline result")
    assert(result("merge").nonEmpty, "expected merged annotations to be non-empty")
  }
}

0 commit comments

Comments
 (0)