|
| 1 | +/* |
| 2 | +* Licensed to the Apache Software Foundation (ASF) under one or more |
| 3 | +* contributor license agreements. See the NOTICE file distributed with |
| 4 | +* this work for additional information regarding copyright ownership. |
| 5 | +* The ASF licenses this file to You under the Apache License, Version 2.0 |
| 6 | +* (the "License"); you may not use this file except in compliance with |
| 7 | +* the License. You may obtain a copy of the License at |
| 8 | +* |
| 9 | +* http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +* |
| 11 | +* Unless required by applicable law or agreed to in writing, software |
| 12 | +* distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +* See the License for the specific language governing permissions and |
| 15 | +* limitations under the License. |
| 16 | +*/ |
| 17 | + |
| 18 | +package org.apache.spark.sql |
| 19 | + |
| 20 | +import org.apache.spark.sql.catalyst.analysis |
| 21 | +import org.apache.spark.sql.catalyst.analysis.Star |
| 22 | + |
| 23 | +import scala.language.implicitConversions |
| 24 | + |
| 25 | +import org.apache.spark.sql.catalyst.expressions._ |
| 26 | +import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr} |
| 27 | +import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan} |
| 28 | +import org.apache.spark.sql.types._ |
| 29 | + |
| 30 | + |
| 31 | +object Literal { |
| 32 | + def apply(literal: Boolean): Column = new Column(LiteralExpr(literal)) |
| 33 | + def apply(literal: Byte): Column = new Column(LiteralExpr(literal)) |
| 34 | + def apply(literal: Short): Column = new Column(LiteralExpr(literal)) |
| 35 | + def apply(literal: Int): Column = new Column(LiteralExpr(literal)) |
| 36 | + def apply(literal: Long): Column = new Column(LiteralExpr(literal)) |
| 37 | + def apply(literal: Float): Column = new Column(LiteralExpr(literal)) |
| 38 | + def apply(literal: Double): Column = new Column(LiteralExpr(literal)) |
| 39 | + def apply(literal: String): Column = new Column(LiteralExpr(literal)) |
| 40 | + def apply(literal: BigDecimal): Column = new Column(LiteralExpr(literal)) |
| 41 | + def apply(literal: java.math.BigDecimal): Column = new Column(LiteralExpr(literal)) |
| 42 | + def apply(literal: java.sql.Timestamp): Column = new Column(LiteralExpr(literal)) |
| 43 | + def apply(literal: java.sql.Date): Column = new Column(LiteralExpr(literal)) |
| 44 | + def apply(literal: Array[Byte]): Column = new Column(LiteralExpr(literal)) |
| 45 | + def apply(literal: Null): Column = new Column(LiteralExpr(null)) |
| 46 | +} |
| 47 | + |
| 48 | + |
| 49 | +object Column { |
| 50 | + def unapply(col: Column): Option[Expression] = Some(col.expr) |
| 51 | +} |
| 52 | + |
| 53 | + |
| 54 | +class Column( |
| 55 | + sqlContext: Option[SQLContext], |
| 56 | + plan: Option[LogicalPlan], |
| 57 | + val expr: Expression) |
| 58 | + extends DataFrame(sqlContext, plan) with ExpressionApi[Column] { |
| 59 | + |
| 60 | + def this(expr: Expression) = this(None, None, expr) |
| 61 | + |
| 62 | + def this(name: String) = this(name match { |
| 63 | + case "*" => Star(None) |
| 64 | + case _ if name.endsWith(".*") => Star(Some(name.substring(0, name.length - 2))) |
| 65 | + case _ => analysis.UnresolvedAttribute(name) |
| 66 | + }) |
| 67 | + |
| 68 | + private[this] implicit def toColumn(expr: Expression): Column = { |
| 69 | + val projectedPlan = plan.map { p => |
| 70 | + Project(Seq(expr match { |
| 71 | + case named: NamedExpression => named |
| 72 | + case unnamed: Expression => Alias(unnamed, "col")() |
| 73 | + }), p) |
| 74 | + } |
| 75 | + new Column(sqlContext, projectedPlan, expr) |
| 76 | + } |
| 77 | + |
| 78 | + override def unary_- : Column = UnaryMinus(expr) |
| 79 | + |
| 80 | + override def ||(other: Column): Column = Or(expr, other.expr) |
| 81 | + |
| 82 | + override def unary_~ : Column = BitwiseNot(expr) |
| 83 | + |
| 84 | + override def !==(other: Column): Column = Not(EqualTo(expr, other.expr)) |
| 85 | + |
| 86 | + override def >(other: Column): Column = GreaterThan(expr, other.expr) |
| 87 | + |
| 88 | + override def unary_! : Column = Not(expr) |
| 89 | + |
| 90 | + override def &(other: Column): Column = BitwiseAnd(expr, other.expr) |
| 91 | + |
| 92 | + override def /(other: Column): Column = Divide(expr, other.expr) |
| 93 | + |
| 94 | + override def &&(other: Column): Column = And(expr, other.expr) |
| 95 | + |
| 96 | + override def |(other: Column): Column = BitwiseOr(expr, other.expr) |
| 97 | + |
| 98 | + override def ^(other: Column): Column = BitwiseXor(expr, other.expr) |
| 99 | + |
| 100 | + override def <=>(other: Column): Column = EqualNullSafe(expr, other.expr) |
| 101 | + |
| 102 | + override def ===(other: Column): Column = EqualTo(expr, other.expr) |
| 103 | + |
| 104 | + override def equalTo(other: Column): Column = this === other |
| 105 | + |
| 106 | + override def +(other: Column): Column = Add(expr, other.expr) |
| 107 | + |
| 108 | + override def rlike(other: Column): Column = RLike(expr, other.expr) |
| 109 | + |
| 110 | + override def %(other: Column): Column = Remainder(expr, other.expr) |
| 111 | + |
| 112 | + override def in(list: Column*): Column = In(expr, list.map(_.expr)) |
| 113 | + |
| 114 | + override def getItem(ordinal: Int): Column = GetItem(expr, LiteralExpr(ordinal)) |
| 115 | + |
| 116 | + override def getItem(ordinal: Column): Column = GetItem(expr, ordinal.expr) |
| 117 | + |
| 118 | + override def <=(other: Column): Column = LessThanOrEqual(expr, other.expr) |
| 119 | + |
| 120 | + override def like(other: Column): Column = Like(expr, other.expr) |
| 121 | + |
| 122 | + override def getField(fieldName: String): Column = GetField(expr, fieldName) |
| 123 | + |
| 124 | + override def isNotNull: Column = IsNotNull(expr) |
| 125 | + |
| 126 | + override def substr(startPos: Column, len: Column): Column = |
| 127 | + Substring(expr, startPos.expr, len.expr) |
| 128 | + |
| 129 | + override def <(other: Column): Column = LessThan(expr, other.expr) |
| 130 | + |
| 131 | + override def isNull: Column = IsNull(expr) |
| 132 | + |
| 133 | + override def contains(other: Column): Column = Contains(expr, other.expr) |
| 134 | + |
| 135 | + override def -(other: Column): Column = Subtract(expr, other.expr) |
| 136 | + |
| 137 | + override def desc: Column = SortOrder(expr, Descending) |
| 138 | + |
| 139 | + override def >=(other: Column): Column = GreaterThanOrEqual(expr, other.expr) |
| 140 | + |
| 141 | + override def asc: Column = SortOrder(expr, Ascending) |
| 142 | + |
| 143 | + override def endsWith(other: Column): Column = EndsWith(expr, other.expr) |
| 144 | + |
| 145 | + override def *(other: Column): Column = Multiply(expr, other.expr) |
| 146 | + |
| 147 | + override def startsWith(other: Column): Column = StartsWith(expr, other.expr) |
| 148 | + |
| 149 | + override def as(alias: String): Column = Alias(expr, alias)() |
| 150 | + |
| 151 | + override def cast(to: DataType): Column = Cast(expr, to) |
| 152 | +} |
| 153 | + |
| 154 | + |
| 155 | +class ColumnName(name: String) extends Column(name) { |
| 156 | + |
| 157 | + /** Creates a new AttributeReference of type boolean */ |
| 158 | + def boolean: StructField = StructField(name, BooleanType) |
| 159 | + |
| 160 | + /** Creates a new AttributeReference of type byte */ |
| 161 | + def byte: StructField = StructField(name, ByteType) |
| 162 | + |
| 163 | + /** Creates a new AttributeReference of type short */ |
| 164 | + def short: StructField = StructField(name, ShortType) |
| 165 | + |
| 166 | + /** Creates a new AttributeReference of type int */ |
| 167 | + def int: StructField = StructField(name, IntegerType) |
| 168 | + |
| 169 | + /** Creates a new AttributeReference of type long */ |
| 170 | + def long: StructField = StructField(name, LongType) |
| 171 | + |
| 172 | + /** Creates a new AttributeReference of type float */ |
| 173 | + def float: StructField = StructField(name, FloatType) |
| 174 | + |
| 175 | + /** Creates a new AttributeReference of type double */ |
| 176 | + def double: StructField = StructField(name, DoubleType) |
| 177 | + |
| 178 | + /** Creates a new AttributeReference of type string */ |
| 179 | + def string: StructField = StructField(name, StringType) |
| 180 | + |
| 181 | + /** Creates a new AttributeReference of type date */ |
| 182 | + def date: StructField = StructField(name, DateType) |
| 183 | + |
| 184 | + /** Creates a new AttributeReference of type decimal */ |
| 185 | + def decimal: StructField = StructField(name, DecimalType.Unlimited) |
| 186 | + |
| 187 | + /** Creates a new AttributeReference of type decimal */ |
| 188 | + def decimal(precision: Int, scale: Int): StructField = |
| 189 | + StructField(name, DecimalType(precision, scale)) |
| 190 | + |
| 191 | + /** Creates a new AttributeReference of type timestamp */ |
| 192 | + def timestamp: StructField = StructField(name, TimestampType) |
| 193 | + |
| 194 | + /** Creates a new AttributeReference of type binary */ |
| 195 | + def binary: StructField = StructField(name, BinaryType) |
| 196 | + |
| 197 | + /** Creates a new AttributeReference of type array */ |
| 198 | + def array(dataType: DataType): StructField = StructField(name, ArrayType(dataType)) |
| 199 | + |
| 200 | + /** Creates a new AttributeReference of type map */ |
| 201 | + def map(keyType: DataType, valueType: DataType): StructField = |
| 202 | + map(MapType(keyType, valueType)) |
| 203 | + |
| 204 | + def map(mapType: MapType): StructField = StructField(name, mapType) |
| 205 | + |
| 206 | + /** Creates a new AttributeReference of type struct */ |
| 207 | + def struct(fields: StructField*): StructField = struct(StructType(fields)) |
| 208 | + |
| 209 | + def struct(structType: StructType): StructField = StructField(name, structType) |
| 210 | +} |
0 commit comments