out)
}
@VisibleForTesting
- static class EncryptedMessage extends AbstractReferenceCounted implements FileRegion {
+ static class EncryptedMessage extends AbstractFileRegion {
private final SaslEncryptionBackend backend;
private final boolean isByteBuf;
@@ -166,7 +166,7 @@ static class EncryptedMessage extends AbstractReferenceCounted implements FileRe
* This makes assumptions about how netty treats FileRegion instances, because there's no way
* to know beforehand what will be the size of the encrypted message. Namely, it assumes
* that netty will try to transfer data from this message while
- * transfered() < count(). So these two methods return, technically, wrong data,
+ * transferred() < count(). So these two methods return, technically, wrong data,
* but netty doesn't know better.
*/
@Override
@@ -183,10 +183,45 @@ public long position() {
* Returns an approximation of the amount of data transferred. See {@link #count()}.
*/
@Override
- public long transfered() {
+ public long transferred() {
return transferred;
}
+ @Override
+ public EncryptedMessage touch(Object o) {
+ super.touch(o);
+ if (buf != null) {
+ buf.touch(o);
+ }
+ if (region != null) {
+ region.touch(o);
+ }
+ return this;
+ }
+
+ @Override
+ public EncryptedMessage retain(int increment) {
+ super.retain(increment);
+ if (buf != null) {
+ buf.retain(increment);
+ }
+ if (region != null) {
+ region.retain(increment);
+ }
+ return this;
+ }
+
+ @Override
+ public boolean release(int decrement) {
+ if (region != null) {
+ region.release(decrement);
+ }
+ if (buf != null) {
+ buf.release(decrement);
+ }
+ return super.release(decrement);
+ }
+
/**
* Transfers data from the original message to the channel, encrypting it in the process.
*
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/AbstractFileRegion.java b/common/network-common/src/main/java/org/apache/spark/network/util/AbstractFileRegion.java
new file mode 100644
index 0000000000000..8651297d97ec2
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/AbstractFileRegion.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import io.netty.channel.FileRegion;
+import io.netty.util.AbstractReferenceCounted;
+
+public abstract class AbstractFileRegion extends AbstractReferenceCounted implements FileRegion {
+
+ @Override
+ @SuppressWarnings("deprecation")
+ public final long transfered() {
+ return transferred();
+ }
+
+ @Override
+ public AbstractFileRegion retain() {
+ super.retain();
+ return this;
+ }
+
+ @Override
+ public AbstractFileRegion retain(int increment) {
+ super.retain(increment);
+ return this;
+ }
+
+ @Override
+ public AbstractFileRegion touch() {
+ super.touch();
+ return this;
+ }
+
+ @Override
+ public AbstractFileRegion touch(Object o) {
+ return this;
+ }
+}
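For reference, here is a minimal sketch of what a concrete region looks like on top of the new base class (this is not part of the patch; `ByteBufRegion` is a made-up example). Subclasses now override only `transferred()`; the deprecated `transfered()` is final in `AbstractFileRegion` and delegates to it, so Netty can keep calling `transferTo()` while `transferred() < count()`, exactly as the `EncryptedMessage` comment above assumes.

```java
import java.io.IOException;
import java.nio.channels.WritableByteChannel;

import io.netty.buffer.ByteBuf;
import org.apache.spark.network.util.AbstractFileRegion;

// Hypothetical example: a FileRegion backed by a ByteBuf. Only the abstract methods of
// FileRegion/AbstractReferenceCounted are implemented; the reference-count plumbing and the
// deprecated transfered() come from AbstractFileRegion.
class ByteBufRegion extends AbstractFileRegion {
  private final ByteBuf buf;
  private final long count;
  private long transferred;

  ByteBufRegion(ByteBuf buf) {
    this.buf = buf;
    this.count = buf.readableBytes();
  }

  @Override
  public long position() { return 0; }

  @Override
  public long count() { return count; }

  @Override
  public long transferred() { return transferred; }

  @Override
  public long transferTo(WritableByteChannel target, long position) throws IOException {
    // The position argument is ignored in this sketch; data is written from the buffer's
    // current reader index.
    int written = target.write(buf.nioBuffer());
    buf.skipBytes(written);
    transferred += written;
    return written;
  }

  @Override
  protected void deallocate() {
    buf.release();
  }
}
```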
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
index 50d9651ccbbb2..8e73ab077a5c1 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
@@ -29,7 +29,7 @@
/**
* A customized frame decoder that allows intercepting raw data.
*
- * This behaves like Netty's frame decoder (with harcoded parameters that match this library's
+ * This behaves like Netty's frame decoder (with hard coded parameters that match this library's
* needs), except it allows an interceptor to be installed to read data directly before it's
* framed.
*
diff --git a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java
index bb1c40c4b0e06..bc94f7ca63a96 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java
@@ -56,7 +56,7 @@ private void testServerToClient(Message msg) {
NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE);
while (!serverChannel.outboundMessages().isEmpty()) {
- clientChannel.writeInbound(serverChannel.readOutbound());
+ clientChannel.writeOneInbound(serverChannel.readOutbound());
}
assertEquals(1, clientChannel.inboundMessages().size());
@@ -72,7 +72,7 @@ private void testClientToServer(Message msg) {
NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE);
while (!clientChannel.outboundMessages().isEmpty()) {
- serverChannel.writeInbound(clientChannel.readOutbound());
+ serverChannel.writeOneInbound(clientChannel.readOutbound());
}
assertEquals(1, serverChannel.inboundMessages().size());
diff --git a/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java
index b341c5681e00c..ecb66fcf2ff76 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java
@@ -23,8 +23,7 @@
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
-import io.netty.channel.FileRegion;
-import io.netty.util.AbstractReferenceCounted;
+import org.apache.spark.network.util.AbstractFileRegion;
import org.junit.Test;
import org.mockito.Mockito;
@@ -108,7 +107,7 @@ private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exc
return Unpooled.wrappedBuffer(channel.getData());
}
- private static class TestFileRegion extends AbstractReferenceCounted implements FileRegion {
+ private static class TestFileRegion extends AbstractFileRegion {
private final int writeCount;
private final int writesPerCall;
@@ -130,7 +129,7 @@ public long position() {
}
@Override
- public long transfered() {
+ public long transferred() {
return 8 * written;
}
diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml
index 9968480ab7658..05335df61a664 100644
--- a/common/network-shuffle/pom.xml
+++ b/common/network-shuffle/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.11
- 2.3.0-SNAPSHOT
+ 2.4.0-SNAPSHOT
../../pom.xml
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
index 3f2f20b4149f1..9cac7d00cc6b6 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
@@ -18,11 +18,11 @@
package org.apache.spark.network.shuffle;
import java.io.File;
+import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;
-import java.nio.file.Files;
import java.util.Arrays;
import org.slf4j.Logger;
@@ -165,7 +165,7 @@ private class DownloadCallback implements StreamCallback {
DownloadCallback(int chunkIndex) throws IOException {
this.targetFile = tempFileManager.createTempFile();
- this.channel = Channels.newChannel(Files.newOutputStream(targetFile.toPath()));
+ this.channel = Channels.newChannel(new FileOutputStream(targetFile));
this.chunkIndex = chunkIndex;
}
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java
index 23438a08fa094..6d201b8fe8d7d 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java
@@ -127,7 +127,7 @@ public void jsonSerializationOfExecutorRegistration() throws IOException {
mapper.readValue(shuffleJson, ExecutorShuffleInfo.class);
assertEquals(parsedShuffleInfo, shuffleInfo);
- // Intentionally keep these hard-coded strings in here, to check backwards-compatability.
+ // Intentionally keep these hard-coded strings in here, to check backwards-compatibility.
// its not legacy yet, but keeping this here in case anybody changes it
String legacyAppIdJson = "{\"appId\":\"foo\", \"execId\":\"bar\"}";
assertEquals(appId, mapper.readValue(legacyAppIdJson, AppExecId.class));
diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml
index ec2db6e5bb88c..564e6583c909e 100644
--- a/common/network-yarn/pom.xml
+++ b/common/network-yarn/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.11
- 2.3.0-SNAPSHOT
+ 2.4.0-SNAPSHOT
../../pom.xml
diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml
index 2d59c71cc3757..2f04abe8c7e88 100644
--- a/common/sketch/pom.xml
+++ b/common/sketch/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.11
- 2.3.0-SNAPSHOT
+ 2.4.0-SNAPSHOT
../../pom.xml
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
index c0b425e729595..37803c7a3b104 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
@@ -34,7 +34,7 @@
*
{@link String}
*
* The false positive probability ({@code FPP}) of a Bloom filter is defined as the probability that
- * {@linkplain #mightContain(Object)} will erroneously return {@code true} for an object that hasu
+ * {@linkplain #mightContain(Object)} will erroneously return {@code true} for an object that has
* not actually been put in the {@code BloomFilter}.
*
* The implementation is largely based on the {@code BloomFilter} class from Guava.
diff --git a/common/tags/pom.xml b/common/tags/pom.xml
index f7e586ee777e1..ba127408e1c59 100644
--- a/common/tags/pom.xml
+++ b/common/tags/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.11
- 2.3.0-SNAPSHOT
+ 2.4.0-SNAPSHOT
../../pom.xml
diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml
index a3772a2620088..1527854730394 100644
--- a/common/unsafe/pom.xml
+++ b/common/unsafe/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.11
- 2.3.0-SNAPSHOT
+ 2.4.0-SNAPSHOT
../../pom.xml
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java
index f121b1cd745b8..a6b1f7a16d605 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java
@@ -66,7 +66,7 @@ public static boolean arrayEquals(
i += 1;
}
}
- // for architectures that suport unaligned accesses, chew it up 8 bytes at a time
+ // for architectures that support unaligned accesses, chew it up 8 bytes at a time
if (unaligned || (((leftOffset + i) % 8 == 0) && ((rightOffset + i) % 8 == 0))) {
while (i <= length - 8) {
if (Platform.getLong(leftBase, leftOffset + i) !=
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
index cc9cc429643ad..a9603c1aba051 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
@@ -31,8 +31,7 @@
public class HeapMemoryAllocator implements MemoryAllocator {
@GuardedBy("this")
- private final Map<Long, LinkedList<WeakReference<MemoryBlock>>> bufferPoolsBySize =
- new HashMap<>();
+ private final Map<Long, LinkedList<WeakReference<long[]>>> bufferPoolsBySize = new HashMap<>();
private static final int POOLING_THRESHOLD_BYTES = 1024 * 1024;
@@ -49,13 +48,14 @@ private boolean shouldPool(long size) {
public MemoryBlock allocate(long size) throws OutOfMemoryError {
if (shouldPool(size)) {
synchronized (this) {
- final LinkedList<WeakReference<MemoryBlock>> pool = bufferPoolsBySize.get(size);
+ final LinkedList<WeakReference<long[]>> pool = bufferPoolsBySize.get(size);
if (pool != null) {
while (!pool.isEmpty()) {
- final WeakReference<MemoryBlock> blockReference = pool.pop();
- final MemoryBlock memory = blockReference.get();
- if (memory != null) {
- assert (memory.size() == size);
+ final WeakReference<long[]> arrayReference = pool.pop();
+ final long[] array = arrayReference.get();
+ if (array != null) {
+ assert (array.length * 8L >= size);
+ MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size);
if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) {
memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE);
}
@@ -76,18 +76,36 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError {
@Override
public void free(MemoryBlock memory) {
+ assert (memory.obj != null) :
+ "baseObject was null; are you trying to use the on-heap allocator to free off-heap memory?";
+ assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) :
+ "page has already been freed";
+ assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER)
+ || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) :
+ "TMM-allocated pages must first be freed via TMM.freePage(), not directly in allocator " +
+ "free()";
+
final long size = memory.size();
if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) {
memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE);
}
+
+ // Mark the page as freed (so we can detect double-frees).
+ memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER;
+
+ // As an additional layer of defense against use-after-free bugs, we mutate the
+ // MemoryBlock to null out its reference to the long[] array.
+ long[] array = (long[]) memory.obj;
+ memory.setObjAndOffset(null, 0);
+
if (shouldPool(size)) {
synchronized (this) {
- LinkedList<WeakReference<MemoryBlock>> pool = bufferPoolsBySize.get(size);
+ LinkedList<WeakReference<long[]>> pool = bufferPoolsBySize.get(size);
if (pool == null) {
pool = new LinkedList<>();
bufferPoolsBySize.put(size, pool);
}
- pool.add(new WeakReference<>(memory));
+ pool.add(new WeakReference<>(array));
}
} else {
// Do nothing
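To make the intent of this hunk easier to follow, here is a simplified, self-contained sketch of the pooling scheme (illustrative only; `LongArrayPool` is not a real Spark class). The pool now caches bare `long[]` arrays keyed by the requested size instead of `MemoryBlock` wrappers, so `free()` can defensively null out a `MemoryBlock` while the backing array remains reusable by the next `allocate()`.

```java
import java.lang.ref.WeakReference;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Map;

class LongArrayPool {
  // One free-list of weakly referenced arrays per requested allocation size.
  private final Map<Long, LinkedList<WeakReference<long[]>>> pools = new HashMap<>();

  synchronized long[] acquire(long size) {
    LinkedList<WeakReference<long[]>> pool = pools.get(size);
    if (pool != null) {
      while (!pool.isEmpty()) {
        long[] array = pool.pop().get();
        if (array != null && array.length * 8L >= size) {
          return array; // reuse a previously released array that is still reachable
        }
      }
    }
    // Round up to a whole number of 8-byte words, matching the allocator's granularity.
    return new long[(int) ((size + 7) / 8)];
  }

  synchronized void release(long size, long[] array) {
    pools.computeIfAbsent(size, k -> new LinkedList<>()).add(new WeakReference<>(array));
  }
}
```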
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
index cd1d378bc1470..c333857358d30 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
@@ -26,6 +26,25 @@
*/
public class MemoryBlock extends MemoryLocation {
+ /** Special `pageNumber` value for pages which were not allocated by TaskMemoryManagers */
+ public static final int NO_PAGE_NUMBER = -1;
+
+ /**
+ * Special `pageNumber` value for marking pages that have been freed in the TaskMemoryManager.
+ * We set `pageNumber` to this value in TaskMemoryManager.freePage() so that MemoryAllocator
+ * can detect if pages which were allocated by TaskMemoryManager have been freed in the TMM
+ * before being passed to MemoryAllocator.free() (it is an error to allocate a page in
+ * TaskMemoryManager and then directly free it in a MemoryAllocator without going through
+ * the TMM freePage() call).
+ */
+ public static final int FREED_IN_TMM_PAGE_NUMBER = -2;
+
+ /**
+ * Special `pageNumber` value for pages that have been freed by the MemoryAllocator. This allows
+ * us to detect double-frees.
+ */
+ public static final int FREED_IN_ALLOCATOR_PAGE_NUMBER = -3;
+
private final long length;
/**
@@ -33,7 +52,7 @@ public class MemoryBlock extends MemoryLocation {
* TaskMemoryManager. This field is public so that it can be modified by the TaskMemoryManager,
* which lives in a different package.
*/
- public int pageNumber = -1;
+ public int pageNumber = NO_PAGE_NUMBER;
public MemoryBlock(@Nullable Object obj, long offset, long length) {
super(obj, offset);
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
index 55bcdf1ed7b06..4368fb615ba1e 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
@@ -38,9 +38,20 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError {
public void free(MemoryBlock memory) {
assert (memory.obj == null) :
"baseObject not null; are you trying to use the off-heap allocator to free on-heap memory?";
+ assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) :
+ "page has already been freed";
+ assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER)
+ || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) :
+ "TMM-allocated pages must be freed via TMM.freePage(), not directly in allocator free()";
+
if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) {
memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE);
}
Platform.freeMemory(memory.offset);
+ // As an additional layer of defense against use-after-free bugs, we mutate the
+ // MemoryBlock to reset its pointer.
+ memory.offset = 0;
+ // Mark the page as freed (so we can detect double-frees).
+ memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER;
}
}
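The assertions added here and in `HeapMemoryAllocator.free()` encode a small state machine over `pageNumber`. The sketch below is illustrative only (it is not Spark code); it spells out the transitions those assertions enforce: a TMM-allocated page must first be freed through `TaskMemoryManager.freePage()` (which sets `FREED_IN_TMM_PAGE_NUMBER`) before the allocator frees it, and a second allocator-level free trips the double-free assertion.

```java
// Standalone model of the pageNumber lifecycle used for double-free detection.
final class PageState {
  static final int NO_PAGE_NUMBER = -1;                 // never registered with a TMM
  static final int FREED_IN_TMM_PAGE_NUMBER = -2;       // freed via TaskMemoryManager.freePage()
  static final int FREED_IN_ALLOCATOR_PAGE_NUMBER = -3; // freed by the MemoryAllocator

  int pageNumber = NO_PAGE_NUMBER;

  void freedByTaskMemoryManager() {
    assert pageNumber >= 0 : "only TMM-allocated pages go through TMM.freePage()";
    pageNumber = FREED_IN_TMM_PAGE_NUMBER;
  }

  void freedByAllocator() {
    assert pageNumber != FREED_IN_ALLOCATOR_PAGE_NUMBER : "page has already been freed";
    assert pageNumber == NO_PAGE_NUMBER || pageNumber == FREED_IN_TMM_PAGE_NUMBER
        : "TMM-allocated pages must be freed via TMM.freePage() first";
    pageNumber = FREED_IN_ALLOCATOR_PAGE_NUMBER;
  }
}
```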
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java
index 7ced13d357237..c03caf0076f61 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java
@@ -74,4 +74,29 @@ public static byte[] subStringSQL(byte[] bytes, int pos, int len) {
}
return Arrays.copyOfRange(bytes, start, end);
}
+
+ public static byte[] concat(byte[]... inputs) {
+ // Compute the total length of the result
+ int totalLength = 0;
+ for (int i = 0; i < inputs.length; i++) {
+ if (inputs[i] != null) {
+ totalLength += inputs[i].length;
+ } else {
+ return null;
+ }
+ }
+
+ // Allocate a new byte array, and copy the inputs one by one into it
+ final byte[] result = new byte[totalLength];
+ int offset = 0;
+ for (int i = 0; i < inputs.length; i++) {
+ int len = inputs[i].length;
+ Platform.copyMemory(
+ inputs[i], Platform.BYTE_ARRAY_OFFSET,
+ result, Platform.BYTE_ARRAY_OFFSET + offset,
+ len);
+ offset += len;
+ }
+ return result;
+ }
}
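A hypothetical usage snippet (not part of the patch) showing the semantics of the new `concat()` helper: inputs are copied in order into a single freshly allocated array, and one null input makes the entire result null, mirroring SQL null semantics.

```java
import java.util.Arrays;

import org.apache.spark.unsafe.types.ByteArray;

public class ConcatExample {
  public static void main(String[] args) {
    byte[] a = {1, 2};
    byte[] b = {3};
    // Inputs are copied in order into one result array.
    System.out.println(Arrays.toString(ByteArray.concat(a, b))); // [1, 2, 3]
    // Any null input short-circuits to a null result.
    System.out.println((Object) ByteArray.concat(a, null, b));   // null
  }
}
```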
diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
index 4b141339ec816..62854837b05ed 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
@@ -62,6 +62,52 @@ public void overlappingCopyMemory() {
}
}
+ @Test
+ public void onHeapMemoryAllocatorPoolingReUsesLongArrays() {
+ MemoryBlock block1 = MemoryAllocator.HEAP.allocate(1024 * 1024);
+ Object baseObject1 = block1.getBaseObject();
+ MemoryAllocator.HEAP.free(block1);
+ MemoryBlock block2 = MemoryAllocator.HEAP.allocate(1024 * 1024);
+ Object baseObject2 = block2.getBaseObject();
+ Assert.assertSame(baseObject1, baseObject2);
+ MemoryAllocator.HEAP.free(block2);
+ }
+
+ @Test
+ public void freeingOnHeapMemoryBlockResetsBaseObjectAndOffset() {
+ MemoryBlock block = MemoryAllocator.HEAP.allocate(1024);
+ Assert.assertNotNull(block.getBaseObject());
+ MemoryAllocator.HEAP.free(block);
+ Assert.assertNull(block.getBaseObject());
+ Assert.assertEquals(0, block.getBaseOffset());
+ Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.pageNumber);
+ }
+
+ @Test
+ public void freeingOffHeapMemoryBlockResetsOffset() {
+ MemoryBlock block = MemoryAllocator.UNSAFE.allocate(1024);
+ Assert.assertNull(block.getBaseObject());
+ Assert.assertNotEquals(0, block.getBaseOffset());
+ MemoryAllocator.UNSAFE.free(block);
+ Assert.assertNull(block.getBaseObject());
+ Assert.assertEquals(0, block.getBaseOffset());
+ Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.pageNumber);
+ }
+
+ @Test(expected = AssertionError.class)
+ public void onHeapMemoryAllocatorThrowsAssertionErrorOnDoubleFree() {
+ MemoryBlock block = MemoryAllocator.HEAP.allocate(1024);
+ MemoryAllocator.HEAP.free(block);
+ MemoryAllocator.HEAP.free(block);
+ }
+
+ @Test(expected = AssertionError.class)
+ public void offHeapMemoryAllocatorThrowsAssertionErrorOnDoubleFree() {
+ MemoryBlock block = MemoryAllocator.UNSAFE.allocate(1024);
+ MemoryAllocator.UNSAFE.free(block);
+ MemoryAllocator.UNSAFE.free(block);
+ }
+
@Test
public void memoryDebugFillEnabledInTest() {
Assert.assertTrue(MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED);
@@ -71,9 +117,11 @@ public void memoryDebugFillEnabledInTest() {
MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE);
MemoryBlock onheap1 = MemoryAllocator.HEAP.allocate(1024 * 1024);
+ Object onheap1BaseObject = onheap1.getBaseObject();
+ long onheap1BaseOffset = onheap1.getBaseOffset();
MemoryAllocator.HEAP.free(onheap1);
Assert.assertEquals(
- Platform.getByte(onheap1.getBaseObject(), onheap1.getBaseOffset()),
+ Platform.getByte(onheap1BaseObject, onheap1BaseOffset),
MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE);
MemoryBlock onheap2 = MemoryAllocator.HEAP.allocate(1024 * 1024);
Assert.assertEquals(
diff --git a/core/pom.xml b/core/pom.xml
index fc394d2f0feaa..c7925ea71f10b 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.11
- 2.3.0-SNAPSHOT
+ 2.4.0-SNAPSHOT
../pom.xml
@@ -355,6 +355,14 @@
spark-tags_${scala.binary.version}
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-launcher_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <classifier>tests</classifier>
+ <scope>test</scope>
+ </dependency>
+
1.11.76
@@ -158,8 +158,8 @@
1.9.3
1.2
- 4.5.3
- 4.4.6
+ 4.5.4
+ 4.4.8
3.1
3.4.1
@@ -184,7 +184,7 @@
3.2.10
1.1.1
2.5.1
- 3.0.0
+ 3.0.8
2.25.1
2.9.9
3.5.2
@@ -200,7 +200,7 @@
2.8
1.8
1.0.0
- 0.4.0
+ 0.8.0
${java.home}
@@ -411,6 +411,7 @@
org.bouncycastle
bcprov-jdk15on
+
${bouncycastle.version}
jline
@@ -2240,6 +2242,14 @@
com.fasterxml.jackson.core
jackson-databind
+
+ io.netty
+ netty-buffer
+
+
+ io.netty
+ netty-common
+
io.netty
netty-handler
@@ -2555,6 +2565,9 @@
org.apache.maven.plugins
maven-assembly-plugin
3.1.0
+
+ posix
+
org.apache.maven.plugins
diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala
index 2ef0e7b40d940..adde213e361f0 100644
--- a/project/MimaBuild.scala
+++ b/project/MimaBuild.scala
@@ -88,7 +88,7 @@ object MimaBuild {
def mimaSettings(sparkHome: File, projectRef: ProjectRef) = {
val organization = "org.apache.spark"
- val previousSparkVersion = "2.0.0"
+ val previousSparkVersion = "2.2.0"
val project = projectRef.project
val fullId = "spark-" + project + "_2.11"
mimaDefaultSettings ++
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 965d1e44ac894..7d75ef8ddb7a8 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -34,8 +34,33 @@ import com.typesafe.tools.mima.core.ProblemFilters._
*/
object MimaExcludes {
+ // Exclude rules for 2.4.x
+ lazy val v24excludes = v23excludes ++ Seq(
+ // Converted from case object to case class
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages#RetrieveSparkAppConfig.productElement"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages#RetrieveSparkAppConfig.productArity"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages#RetrieveSparkAppConfig.canEqual"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages#RetrieveSparkAppConfig.productIterator"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages#RetrieveSparkAppConfig.productPrefix"),
+ ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages#RetrieveSparkAppConfig.toString")
+ )
+
// Exclude rules for 2.3.x
lazy val v23excludes = v22excludes ++ Seq(
+ // [SPARK-22897] Expose stageAttemptId in TaskContext
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.stageAttemptNumber"),
+
+ // SPARK-22789: Map-only continuous processing execution
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$8"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$6"),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$9"),
+
+ // SPARK-22372: Make cluster submission use SparkApplication.
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getSecretKeyFromUserCredentials"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.isYarnMode"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getCurrentUserCredentials"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.addSecretKeyToUserCredentials"),
+
// SPARK-18085: Better History Server scalability for many / large applications
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.ExecutorSummary.executorLogs"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.getSparkUI"),
@@ -45,6 +70,8 @@ object MimaExcludes {
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.StorageStatusListener"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorStageSummary.this"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.JobData.this"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkStatusTracker.this"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ui.jobs.JobProgressListener"),
// [SPARK-20495][SQL] Add StorageLevel to cacheTable API
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.cacheTable"),
@@ -82,7 +109,40 @@ object MimaExcludes {
// [SPARK-21087] CrossValidator, TrainValidationSplit expose sub models after fitting: Scala
ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.CrossValidatorModel$CrossValidatorModelWriter"),
- ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel$TrainValidationSplitModelWriter")
+ ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel$TrainValidationSplitModelWriter"),
+
+ // [SPARK-21728][CORE] Allow SparkSubmit to use Logging
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.downloadFileList"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.downloadFile"),
+
+ // [SPARK-21714][CORE][YARN] Avoiding re-uploading remote resources in yarn client mode
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.prepareSubmitEnvironment"),
+
+ // [SPARK-22324][SQL][PYTHON] Upgrade Arrow to 0.8.0
+ ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.network.util.AbstractFileRegion.transfered"),
+
+ // [SPARK-20643][CORE] Add listener implementation to collect app state
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.$default$5"),
+
+ // [SPARK-20648][CORE] Port JobsTab and StageTab to the new UI backend
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.$default$12"),
+
+ // [SPARK-21462][SS] Added batchId to StreamingQueryProgress.json
+ // [SPARK-21409][SS] Expose state store memory usage in SQL metrics and progress updates
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StateOperatorProgress.this"),
+
+ // [SPARK-22278][SS] Expose current event time watermark and current processing time in GroupState
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.GroupState.getCurrentWatermarkMs"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.GroupState.getCurrentProcessingTimeMs"),
+
+ // [SPARK-20542][ML][SQL] Add an API to Bucketizer that can bin multiple columns
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasOutputCols.org$apache$spark$ml$param$shared$HasOutputCols$_setter_$outputCols_="),
+
+ // [SPARK-18619][ML] Make QuantileDiscretizer/Bucketizer/StringIndexer/RFormula inherit from HasHandleInvalid
+ ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.Bucketizer.getHandleInvalid"),
+ ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.StringIndexer.getHandleInvalid"),
+ ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.QuantileDiscretizer.getHandleInvalid"),
+ ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.StringIndexerModel.getHandleInvalid")
)
// Exclude rules for 2.2.x
@@ -1060,10 +1120,16 @@ object MimaExcludes {
// [SPARK-21680][ML][MLLIB]optimzie Vector coompress
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.toSparseWithSize"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Vector.toSparseWithSize")
+ ) ++ Seq(
+ // [SPARK-3181][ML]Implement huber loss for LinearRegression.
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasLoss.org$apache$spark$ml$param$shared$HasLoss$_setter_$loss_="),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasLoss.getLoss"),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasLoss.loss")
)
}
def excludes(version: String) = version match {
+ case v if v.startsWith("2.4") => v24excludes
case v if v.startsWith("2.3") => v23excludes
case v if v.startsWith("2.2") => v22excludes
case v if v.startsWith("2.1") => v21excludes
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index cf000e5751e2a..a0ae82639034e 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -54,16 +54,15 @@ object BuildCommons {
"tags", "sketch", "kvstore"
).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects
- val optionallyEnabledProjects@Seq(mesos, yarn,
+ val optionallyEnabledProjects@Seq(kubernetes, mesos, yarn,
streamingFlumeSink, streamingFlume,
streamingKafka, sparkGangliaLgpl, streamingKinesisAsl,
- dockerIntegrationTests, hadoopCloud,
- kubernetes, _*) =
- Seq("mesos", "yarn",
+ dockerIntegrationTests, hadoopCloud, _*) =
+ Seq("kubernetes", "mesos", "yarn",
"streaming-flume-sink", "streaming-flume",
"streaming-kafka-0-8", "ganglia-lgpl", "streaming-kinesis-asl",
"docker-integration-tests", "hadoop-cloud",
- "kubernetes", "kubernetes-integration-tests",
+ "kubernetes-integration-tests",
"kubernetes-integration-tests-spark-jobs", "kubernetes-integration-tests-spark-jobs-helpers",
"kubernetes-docker-minimal-bundle"
).map(ProjectRef(buildLocation, _))
@@ -247,14 +246,14 @@ object SparkBuild extends PomBuild {
javacOptions in Compile ++= Seq(
"-encoding", "UTF-8",
- "-source", javacJVMVersion.value,
- "-Xlint:unchecked"
+ "-source", javacJVMVersion.value
),
- // This -target option cannot be set in the Compile configuration scope since `javadoc` doesn't
- // play nicely with it; see https://github.com/sbt/sbt/issues/355#issuecomment-3817629 for
- // additional discussion and explanation.
+ // The -target and -Xlint:unchecked options cannot be set in the Compile configuration scope since
+ // `javadoc` doesn't play nicely with them; see https://github.com/sbt/sbt/issues/355#issuecomment-3817629
+ // for additional discussion and explanation.
javacOptions in (Compile, compile) ++= Seq(
- "-target", javacJVMVersion.value
+ "-target", javacJVMVersion.value,
+ "-Xlint:unchecked"
),
scalacOptions in Compile ++= Seq(
@@ -262,6 +261,21 @@ object SparkBuild extends PomBuild {
"-sourcepath", (baseDirectory in ThisBuild).value.getAbsolutePath // Required for relative source links in scaladoc
),
+ // Remove certain packages from Scaladoc
+ scalacOptions in (Compile, doc) := Seq(
+ "-groups",
+ "-skip-packages", Seq(
+ "org.apache.spark.api.python",
+ "org.apache.spark.network",
+ "org.apache.spark.deploy",
+ "org.apache.spark.util.collection"
+ ).mkString(":"),
+ "-doc-title", "Spark " + version.value.replaceAll("-SNAPSHOT", "") + " ScalaDoc"
+ ) ++ {
+ // Do not attempt to scaladoc javadoc comments under 2.12 since it can't handle inner classes
+ if (scalaBinaryVersion.value == "2.12") Seq("-no-java-comments") else Seq.empty
+ },
+
// Implements -Xfatal-warnings, ignoring deprecation warnings.
// Code snippet taken from https://issues.scala-lang.org/browse/SI-8410.
compile in Compile := {
@@ -688,9 +702,9 @@ object Unidoc {
publish := {},
unidocProjectFilter in(ScalaUnidoc, unidoc) :=
- inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, tags, streamingKafka010, sqlKafka010),
+ inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, yarn, tags, streamingKafka010, sqlKafka010),
unidocProjectFilter in(JavaUnidoc, unidoc) :=
- inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, tags, streamingKafka010, sqlKafka010),
+ inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, yarn, tags, streamingKafka010, sqlKafka010),
unidocAllClasspaths in (ScalaUnidoc, unidoc) := {
ignoreClasspaths((unidocAllClasspaths in (ScalaUnidoc, unidoc)).value)
@@ -845,18 +859,7 @@ object TestSettings {
}
Seq.empty[File]
}).value,
- concurrentRestrictions in Global += Tags.limit(Tags.Test, 1),
- // Remove certain packages from Scaladoc
- scalacOptions in (Compile, doc) := Seq(
- "-groups",
- "-skip-packages", Seq(
- "org.apache.spark.api.python",
- "org.apache.spark.network",
- "org.apache.spark.deploy",
- "org.apache.spark.util.collection"
- ).mkString(":"),
- "-doc-title", "Spark " + version.value.replaceAll("-SNAPSHOT", "") + " ScalaDoc"
- )
+ concurrentRestrictions in Global += Tags.limit(Tags.Test, 1)
)
}
diff --git a/python/README.md b/python/README.md
index 84ec88141cb00..3f17fdb98a081 100644
--- a/python/README.md
+++ b/python/README.md
@@ -29,4 +29,4 @@ The Python packaging for Spark is not intended to replace all of the other use c
## Python Requirements
-At its core PySpark depends on Py4J (currently version 0.10.6), but additional sub-packages have their own requirements (including numpy and pandas).
+At its core PySpark depends on Py4J (currently version 0.10.6), but some additional sub-packages have their own extra requirements for certain features (including numpy, pandas, and pyarrow).
diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py
index a6767cee9bf28..d4470b5bf2900 100644
--- a/python/pyspark/ml/base.py
+++ b/python/pyspark/ml/base.py
@@ -18,13 +18,52 @@
from abc import ABCMeta, abstractmethod
import copy
+import threading
from pyspark import since
-from pyspark.ml.param import Params
from pyspark.ml.param.shared import *
from pyspark.ml.common import inherit_doc
from pyspark.sql.functions import udf
-from pyspark.sql.types import StructField, StructType, DoubleType
+from pyspark.sql.types import StructField, StructType
+
+
+class _FitMultipleIterator(object):
+ """
+ Used by default implementation of Estimator.fitMultiple to produce models in a thread safe
+ iterator. This class handles the simple case of fitMultiple where each param map should be
+ fit independently.
+
+ :param fitSingleModel: Function: (int => Model) which fits an estimator to a dataset.
+ `fitSingleModel` may be called up to `numModels` times, with a unique index each time.
+ Each call to `fitSingleModel` with an index should return the Model associated with
+ that index.
+ :param numModels: Number of models this iterator should produce.
+
+ See Estimator.fitMultiple for more info.
+ """
+ def __init__(self, fitSingleModel, numModels):
+ """
+
+ """
+ self.fitSingleModel = fitSingleModel
+ self.numModel = numModels
+ self.counter = 0
+ self.lock = threading.Lock()
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ with self.lock:
+ index = self.counter
+ if index >= self.numModel:
+ raise StopIteration("No models remaining.")
+ self.counter += 1
+ return index, self.fitSingleModel(index)
+
+ def next(self):
+ """For python2 compatibility."""
+ return self.__next__()
@inherit_doc
@@ -47,6 +86,27 @@ def _fit(self, dataset):
"""
raise NotImplementedError()
+ @since("2.3.0")
+ def fitMultiple(self, dataset, paramMaps):
+ """
+ Fits a model to the input dataset for each param map in `paramMaps`.
+
+ :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`.
+ :param paramMaps: A Sequence of param maps.
+ :return: A thread safe iterable which contains one model for each param map. Each
+ call to `next(modelIterator)` will return `(index, model)` where model was fit
+ using `paramMaps[index]`. `index` values may not be sequential.
+
+ .. note:: DeveloperApi
+ .. note:: Experimental
+ """
+ estimator = self.copy()
+
+ def fitSingleModel(index):
+ return estimator.fit(dataset, paramMaps[index])
+
+ return _FitMultipleIterator(fitSingleModel, len(paramMaps))
+
@since("1.3.0")
def fit(self, dataset, params=None):
"""
@@ -61,7 +121,10 @@ def fit(self, dataset, params=None):
if params is None:
params = dict()
if isinstance(params, (list, tuple)):
- return [self.fit(dataset, paramMap) for paramMap in params]
+ models = [None] * len(params)
+ for index, model in self.fitMultiple(dataset, params):
+ models[index] = model
+ return models
elif isinstance(params, dict):
if params:
return self.copy(params)._fit(dataset)
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 608f2a5715497..eb79b193103e2 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -45,6 +45,7 @@
'NGram',
'Normalizer',
'OneHotEncoder',
+ 'OneHotEncoderEstimator', 'OneHotEncoderModel',
'PCA', 'PCAModel',
'PolynomialExpansion',
'QuantileDiscretizer',
@@ -57,6 +58,7 @@
'Tokenizer',
'VectorAssembler',
'VectorIndexer', 'VectorIndexerModel',
+ 'VectorSizeHint',
'VectorSlicer',
'Word2Vec', 'Word2VecModel']
@@ -713,9 +715,9 @@ class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures,
* Numeric columns:
For numeric features, the hash value of the column name is used to map the
- feature value to its index in the feature vector. Numeric features are never
- treated as categorical, even when they are integers. You must explicitly
- convert numeric columns containing categorical features to strings first.
+ feature value to its index in the feature vector. By default, numeric features
+ are not treated as categorical (even when they are integers). To treat them
+ as categorical, specify the relevant columns in `categoricalCols`.
* String columns:
For categorical features, the hash value of the string "column_name=value"
@@ -740,6 +742,8 @@ class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures,
>>> hasher = FeatureHasher(inputCols=cols, outputCol="features")
>>> hasher.transform(df).head().features
SparseVector(262144, {51871: 1.0, 63643: 1.0, 174475: 2.0, 253195: 1.0})
+ >>> hasher.setCategoricalCols(["real"]).transform(df).head().features
+ SparseVector(262144, {51871: 1.0, 63643: 1.0, 171257: 1.0, 253195: 1.0})
>>> hasherPath = temp_path + "/hasher"
>>> hasher.save(hasherPath)
>>> loadedHasher = FeatureHasher.load(hasherPath)
@@ -751,10 +755,14 @@ class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures,
.. versionadded:: 2.3.0
"""
+ categoricalCols = Param(Params._dummy(), "categoricalCols",
+ "numeric columns to treat as categorical",
+ typeConverter=TypeConverters.toListString)
+
@keyword_only
- def __init__(self, numFeatures=1 << 18, inputCols=None, outputCol=None):
+ def __init__(self, numFeatures=1 << 18, inputCols=None, outputCol=None, categoricalCols=None):
"""
- __init__(self, numFeatures=1 << 18, inputCols=None, outputCol=None)
+ __init__(self, numFeatures=1 << 18, inputCols=None, outputCol=None, categoricalCols=None)
"""
super(FeatureHasher, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.FeatureHasher", self.uid)
@@ -764,14 +772,28 @@ def __init__(self, numFeatures=1 << 18, inputCols=None, outputCol=None):
@keyword_only
@since("2.3.0")
- def setParams(self, numFeatures=1 << 18, inputCols=None, outputCol=None):
+ def setParams(self, numFeatures=1 << 18, inputCols=None, outputCol=None, categoricalCols=None):
"""
- setParams(self, numFeatures=1 << 18, inputCols=None, outputCol=None)
+ setParams(self, numFeatures=1 << 18, inputCols=None, outputCol=None, categoricalCols=None)
Sets params for this FeatureHasher.
"""
kwargs = self._input_kwargs
return self._set(**kwargs)
+ @since("2.3.0")
+ def setCategoricalCols(self, value):
+ """
+ Sets the value of :py:attr:`categoricalCols`.
+ """
+ return self._set(categoricalCols=value)
+
+ @since("2.3.0")
+ def getCategoricalCols(self):
+ """
+ Gets the value of categoricalCols or its default value.
+ """
+ return self.getOrDefault(self.categoricalCols)
+
@inherit_doc
class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, JavaMLReadable,
@@ -1556,6 +1578,9 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
.. note:: This is different from scikit-learn's OneHotEncoder,
which keeps all categories. The output vectors are sparse.
+ .. note:: Deprecated in 2.3.0. :py:class:`OneHotEncoderEstimator` will be renamed to
+ :py:class:`OneHotEncoder` and this :py:class:`OneHotEncoder` will be removed in 3.0.0.
+
.. seealso::
:py:class:`StringIndexer` for converting categorical values into
@@ -1620,6 +1645,118 @@ def getDropLast(self):
return self.getOrDefault(self.dropLast)
+@inherit_doc
+class OneHotEncoderEstimator(JavaEstimator, HasInputCols, HasOutputCols, HasHandleInvalid,
+ JavaMLReadable, JavaMLWritable):
+ """
+ A one-hot encoder that maps a column of category indices to a column of binary vectors, with
+ at most a single one-value per row that indicates the input category index.
+ For example with 5 categories, an input value of 2.0 would map to an output vector of
+ `[0.0, 0.0, 1.0, 0.0]`.
+ The last category is not included by default (configurable via `dropLast`),
+ because it makes the vector entries sum up to one, and hence linearly dependent.
+ So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`.
+
+ Note: This is different from scikit-learn's OneHotEncoder, which keeps all categories.
+ The output vectors are sparse.
+
+ When `handleInvalid` is configured to 'keep', an extra "category" indicating invalid values is
+ added as the last category. So when `dropLast` is true, invalid values are encoded as an
+ all-zeros vector.
+
+ Note: When encoding multi-column by using `inputCols` and `outputCols` params, input/output
+ cols come in pairs, specified by the order in the arrays, and each pair is treated
+ independently.
+
+ See `StringIndexer` for converting categorical values into category indices
+
+ >>> from pyspark.ml.linalg import Vectors
+ >>> df = spark.createDataFrame([(0.0,), (1.0,), (2.0,)], ["input"])
+ >>> ohe = OneHotEncoderEstimator(inputCols=["input"], outputCols=["output"])
+ >>> model = ohe.fit(df)
+ >>> model.transform(df).head().output
+ SparseVector(2, {0: 1.0})
+ >>> ohePath = temp_path + "/oheEstimator"
+ >>> ohe.save(ohePath)
+ >>> loadedOHE = OneHotEncoderEstimator.load(ohePath)
+ >>> loadedOHE.getInputCols() == ohe.getInputCols()
+ True
+ >>> modelPath = temp_path + "/ohe-model"
+ >>> model.save(modelPath)
+ >>> loadedModel = OneHotEncoderModel.load(modelPath)
+ >>> loadedModel.categorySizes == model.categorySizes
+ True
+
+ .. versionadded:: 2.3.0
+ """
+
+ handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data during " +
+ "transform(). Options are 'keep' (invalid data presented as an extra " +
+ "categorical feature) or error (throw an error). Note that this Param " +
+ "is only used during transform; during fitting, invalid data will " +
+ "result in an error.",
+ typeConverter=TypeConverters.toString)
+
+ dropLast = Param(Params._dummy(), "dropLast", "whether to drop the last category",
+ typeConverter=TypeConverters.toBoolean)
+
+ @keyword_only
+ def __init__(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True):
+ """
+ __init__(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True)
+ """
+ super(OneHotEncoderEstimator, self).__init__()
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.feature.OneHotEncoderEstimator", self.uid)
+ self._setDefault(handleInvalid="error", dropLast=True)
+ kwargs = self._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ @since("2.3.0")
+ def setParams(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True):
+ """
+ setParams(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True)
+ Sets params for this OneHotEncoderEstimator.
+ """
+ kwargs = self._input_kwargs
+ return self._set(**kwargs)
+
+ @since("2.3.0")
+ def setDropLast(self, value):
+ """
+ Sets the value of :py:attr:`dropLast`.
+ """
+ return self._set(dropLast=value)
+
+ @since("2.3.0")
+ def getDropLast(self):
+ """
+ Gets the value of dropLast or its default value.
+ """
+ return self.getOrDefault(self.dropLast)
+
+ def _create_model(self, java_model):
+ return OneHotEncoderModel(java_model)
+
+
+class OneHotEncoderModel(JavaModel, JavaMLReadable, JavaMLWritable):
+ """
+ Model fitted by :py:class:`OneHotEncoderEstimator`.
+
+ .. versionadded:: 2.3.0
+ """
+
+ @property
+ @since("2.3.0")
+ def categorySizes(self):
+ """
+ Original number of categories for each feature being encoded.
+ The array contains one value for each input column, in order.
+ """
+ return self._call_java("categorySizes")
+
+
@inherit_doc
class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
JavaMLWritable):
@@ -3466,6 +3603,84 @@ def selectedFeatures(self):
return self._call_java("selectedFeatures")
+@inherit_doc
+class VectorSizeHint(JavaTransformer, HasInputCol, HasHandleInvalid, JavaMLReadable,
+ JavaMLWritable):
+ """
+ .. note:: Experimental
+
+ A feature transformer that adds size information to the metadata of a vector column.
+ VectorAssembler needs size information for its input columns and cannot be used on streaming
+ dataframes without this metadata.
+
+ .. note:: VectorSizeHint modifies `inputCol` to include size metadata and does not have an
+ outputCol.
+
+ >>> from pyspark.ml.linalg import Vectors
+ >>> from pyspark.ml import Pipeline, PipelineModel
+ >>> data = [(Vectors.dense([1., 2., 3.]), 4.)]
+ >>> df = spark.createDataFrame(data, ["vector", "float"])
+ >>>
+ >>> sizeHint = VectorSizeHint(inputCol="vector", size=3, handleInvalid="skip")
+ >>> vecAssembler = VectorAssembler(inputCols=["vector", "float"], outputCol="assembled")
+ >>> pipeline = Pipeline(stages=[sizeHint, vecAssembler])
+ >>>
+ >>> pipelineModel = pipeline.fit(df)
+ >>> pipelineModel.transform(df).head().assembled
+ DenseVector([1.0, 2.0, 3.0, 4.0])
+ >>> vectorSizeHintPath = temp_path + "/vector-size-hint-pipeline"
+ >>> pipelineModel.save(vectorSizeHintPath)
+ >>> loadedPipeline = PipelineModel.load(vectorSizeHintPath)
+ >>> loaded = loadedPipeline.transform(df).head().assembled
+ >>> expected = pipelineModel.transform(df).head().assembled
+ >>> loaded == expected
+ True
+
+ .. versionadded:: 2.3.0
+ """
+
+ size = Param(Params._dummy(), "size", "Size of vectors in column.",
+ typeConverter=TypeConverters.toInt)
+
+ handleInvalid = Param(Params._dummy(), "handleInvalid",
+ "How to handle invalid vectors in inputCol. Invalid vectors include "
+ "nulls and vectors with the wrong size. The options are `skip` (filter "
+ "out rows with invalid vectors), `error` (throw an error) and "
+ "`optimistic` (do not check the vector size, and keep all rows). "
+ "`error` by default.",
+ TypeConverters.toString)
+
+ @keyword_only
+ def __init__(self, inputCol=None, size=None, handleInvalid="error"):
+ """
+ __init__(self, inputCol=None, size=None, handleInvalid="error")
+ """
+ super(VectorSizeHint, self).__init__()
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorSizeHint", self.uid)
+ self._setDefault(handleInvalid="error")
+ self.setParams(**self._input_kwargs)
+
+ @keyword_only
+ @since("2.3.0")
+ def setParams(self, inputCol=None, size=None, handleInvalid="error"):
+ """
+ setParams(self, inputCol=None, size=None, handleInvalid="error")
+ Sets params for this VectorSizeHint.
+ """
+ kwargs = self._input_kwargs
+ return self._set(**kwargs)
+
+ @since("2.3.0")
+ def getSize(self):
+ """ Gets size param, the size of vectors in `inputCol`."""
+ return self.getOrDefault(self.size)
+
+ @since("2.3.0")
+ def setSize(self, value):
+ """ Sets size param, the size of vectors in `inputCol`."""
+ return self._set(size=value)
+
+
if __name__ == "__main__":
import doctest
import tempfile
diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py
index 7d14f05295572..c9b840276f675 100644
--- a/python/pyspark/ml/image.py
+++ b/python/pyspark/ml/image.py
@@ -108,12 +108,23 @@ def toNDArray(self, image):
"""
Converts an image to an array with metadata.
- :param image: The image to be converted.
+ :param `Row` image: A row that contains the image to be converted. It should
+ have the attributes specified in `ImageSchema.imageSchema`.
:return: a `numpy.ndarray` that is an image.
.. versionadded:: 2.3.0
"""
+ if not isinstance(image, Row):
+ raise TypeError(
+ "image argument should be pyspark.sql.types.Row; however, "
+ "it got [%s]." % type(image))
+
+ if any(not hasattr(image, f) for f in self.imageFields):
+ raise ValueError(
+ "image argument should have attributes specified in "
+ "ImageSchema.imageSchema [%s]." % ", ".join(self.imageFields))
+
height = image.height
width = image.width
nChannels = image.nChannels
@@ -127,15 +138,20 @@ def toImage(self, array, origin=""):
"""
Converts an array with metadata to a two-dimensional image.
- :param array array: The array to convert to image.
+ :param `numpy.ndarray` array: The array to convert to image.
:param str origin: Path to the image, optional.
:return: a :class:`Row` that is a two dimensional image.
.. versionadded:: 2.3.0
"""
+ if not isinstance(array, np.ndarray):
+ raise TypeError(
+ "array argument should be numpy.ndarray; however, it got [%s]." % type(array))
+
if array.ndim != 3:
raise ValueError("Invalid array shape")
+
height, width, nChannels = array.shape
ocvTypes = ImageSchema.ocvTypes
if nChannels == 1:
@@ -146,7 +162,12 @@ def toImage(self, array, origin=""):
mode = ocvTypes["CV_8UC4"]
else:
raise ValueError("Invalid number of channels")
- data = bytearray(array.astype(dtype=np.uint8).ravel())
+
+ # Running `bytearray(numpy.array([1]))` fails with some combinations of Python and NumPy
+ # versions, for example Python 3.6.0 with NumPy 1.13.3.
+ # Here, we avoid the issue by converting the array to bytes first.
+ data = bytearray(array.astype(dtype=np.uint8).ravel().tobytes())
+
# Creating new Row with _create_row(), because Row(name = value, ... )
# orders fields by name, which conflicts with expected schema order
# when the new DataFrame is created by UDF
@@ -180,9 +201,8 @@ def readImages(self, path, recursive=False, numPartitions=-1,
.. versionadded:: 2.3.0
"""
- ctx = SparkContext._active_spark_context
- spark = SparkSession(ctx)
- image_schema = ctx._jvm.org.apache.spark.ml.image.ImageSchema
+ spark = SparkSession.builder.getOrCreate()
+ image_schema = spark._jvm.org.apache.spark.ml.image.ImageSchema
jsession = spark._jsparkSession
jresult = image_schema.readImages(path, jsession, recursive, numPartitions,
dropImageFailures, float(sampleRatio), seed)
@@ -192,7 +212,7 @@ def readImages(self, path, recursive=False, numPartitions=-1,
ImageSchema = _ImageSchema()
-# Monkey patch to disallow instantization of this class.
+# Monkey patch to disallow instantiation of this class.
def _disallow_instance(_):
raise RuntimeError("Creating instance of _ImageSchema class is disallowed.")
_ImageSchema.__init__ = _disallow_instance
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index 130d1a0bae7f0..db951d81de1e7 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -119,10 +119,12 @@ def get$Name(self):
("inputCol", "input column name.", None, "TypeConverters.toString"),
("inputCols", "input column names.", None, "TypeConverters.toListString"),
("outputCol", "output column name.", "self.uid + '__output'", "TypeConverters.toString"),
+ ("outputCols", "output column names.", None, "TypeConverters.toListString"),
("numFeatures", "number of features.", None, "TypeConverters.toInt"),
("checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). " +
- "E.g. 10 means that the cache will get checkpointed every 10 iterations.", None,
- "TypeConverters.toInt"),
+ "E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: " +
+ "this setting will be ignored if the checkpoint directory is not set in the SparkContext.",
+ None, "TypeConverters.toInt"),
("seed", "random seed.", "hash(type(self).__name__)", "TypeConverters.toInt"),
("tol", "the convergence tolerance for iterative algorithms (>= 0).", None,
"TypeConverters.toFloat"),
@@ -154,7 +156,8 @@ def get$Name(self):
("aggregationDepth", "suggested depth for treeAggregate (>= 2).", "2",
"TypeConverters.toInt"),
("parallelism", "the number of threads to use when running parallel algorithms (>= 1).",
- "1", "TypeConverters.toInt")]
+ "1", "TypeConverters.toInt"),
+ ("loss", "the loss function to be optimized.", None, "TypeConverters.toString")]
code = []
for name, doc, defaultValueStr, typeConverter in shared:
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index 4041d9c43b236..474c38764e5a1 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -256,6 +256,29 @@ def getOutputCol(self):
return self.getOrDefault(self.outputCol)
+class HasOutputCols(Params):
+ """
+ Mixin for param outputCols: output column names.
+ """
+
+ outputCols = Param(Params._dummy(), "outputCols", "output column names.", typeConverter=TypeConverters.toListString)
+
+ def __init__(self):
+ super(HasOutputCols, self).__init__()
+
+ def setOutputCols(self, value):
+ """
+ Sets the value of :py:attr:`outputCols`.
+ """
+ return self._set(outputCols=value)
+
+ def getOutputCols(self):
+ """
+ Gets the value of outputCols or its default value.
+ """
+ return self.getOrDefault(self.outputCols)
+
+
class HasNumFeatures(Params):
"""
Mixin for param numFeatures: number of features.
@@ -281,10 +304,10 @@ def getNumFeatures(self):
class HasCheckpointInterval(Params):
"""
- Mixin for param checkpointInterval: set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.
+ Mixin for param checkpointInterval: set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.
"""
- checkpointInterval = Param(Params._dummy(), "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.", typeConverter=TypeConverters.toInt)
+ checkpointInterval = Param(Params._dummy(), "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.", typeConverter=TypeConverters.toInt)
def __init__(self):
super(HasCheckpointInterval, self).__init__()
@@ -632,6 +655,29 @@ def getParallelism(self):
return self.getOrDefault(self.parallelism)
+class HasLoss(Params):
+ """
+ Mixin for param loss: the loss function to be optimized.
+ """
+
+ loss = Param(Params._dummy(), "loss", "the loss function to be optimized.", typeConverter=TypeConverters.toString)
+
+ def __init__(self):
+ super(HasLoss, self).__init__()
+
+ def setLoss(self, value):
+ """
+ Sets the value of :py:attr:`loss`.
+ """
+ return self._set(loss=value)
+
+ def getLoss(self):
+ """
+ Gets the value of loss or its default value.
+ """
+ return self.getOrDefault(self.loss)
+
+
class DecisionTreeParams(Params):
"""
Mixin for Decision Tree parameters.
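Editor's note: a minimal sketch of how the new HasOutputCols and HasLoss mixins generated above can be combined on a plain Params subclass. The class name MyParams and the parameter values are illustrative only, not part of the patch.

    from pyspark.ml.param.shared import HasOutputCols, HasLoss

    class MyParams(HasOutputCols, HasLoss):
        """Hypothetical params holder mixing in the two new shared params."""
        pass

    p = MyParams()
    p.setOutputCols(["out1", "out2"])   # validated by TypeConverters.toListString
    p.setLoss("squaredError")           # validated by TypeConverters.toString
    print(p.getOutputCols())            # ['out1', 'out2']
    print(p.getLoss())                  # 'squaredError'
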
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 9d5b768091cf4..f0812bd1d4a39 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -39,23 +39,26 @@
@inherit_doc
class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept,
- HasStandardization, HasSolver, HasWeightCol, HasAggregationDepth,
+ HasStandardization, HasSolver, HasWeightCol, HasAggregationDepth, HasLoss,
JavaMLWritable, JavaMLReadable):
"""
Linear regression.
- The learning objective is to minimize the squared error, with regularization.
- The specific squared error loss function used is: L = 1/2n ||A coefficients - y||^2^
+ The learning objective is to minimize the specified loss function, with regularization.
+ This supports two kinds of loss:
- This supports multiple types of regularization:
-
- * none (a.k.a. ordinary least squares)
+ * squaredError (a.k.a. squared loss)
+ * huber (a hybrid of squared error for relatively small errors and absolute error for \
+ relatively large ones, with the scale parameter estimated from the training data)
- * L2 (ridge regression)
+ This supports multiple types of regularization:
- * L1 (Lasso)
+ * none (a.k.a. ordinary least squares)
+ * L2 (ridge regression)
+ * L1 (Lasso)
+ * L2 + L1 (elastic net)
- * L2 + L1 (elastic net)
+ Note: Fitting with huber loss only supports none and L2 regularization.
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([
@@ -98,19 +101,28 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " +
"options: auto, normal, l-bfgs.", typeConverter=TypeConverters.toString)
+ loss = Param(Params._dummy(), "loss", "The loss function to be optimized. Supported " +
+ "options: squaredError, huber.", typeConverter=TypeConverters.toString)
+
+ epsilon = Param(Params._dummy(), "epsilon", "The shape parameter to control the amount of " +
+ "robustness. Must be > 1.0. Only valid when loss is huber",
+ typeConverter=TypeConverters.toFloat)
+
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
- standardization=True, solver="auto", weightCol=None, aggregationDepth=2):
+ standardization=True, solver="auto", weightCol=None, aggregationDepth=2,
+ loss="squaredError", epsilon=1.35):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
- standardization=True, solver="auto", weightCol=None, aggregationDepth=2)
+ standardization=True, solver="auto", weightCol=None, aggregationDepth=2, \
+ loss="squaredError", epsilon=1.35)
"""
super(LinearRegression, self).__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.regression.LinearRegression", self.uid)
- self._setDefault(maxIter=100, regParam=0.0, tol=1e-6)
+ self._setDefault(maxIter=100, regParam=0.0, tol=1e-6, loss="squaredError", epsilon=1.35)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -118,11 +130,13 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
@since("1.4.0")
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
- standardization=True, solver="auto", weightCol=None, aggregationDepth=2):
+ standardization=True, solver="auto", weightCol=None, aggregationDepth=2,
+ loss="squaredError", epsilon=1.35):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
- standardization=True, solver="auto", weightCol=None, aggregationDepth=2)
+ standardization=True, solver="auto", weightCol=None, aggregationDepth=2, \
+ loss="squaredError", epsilon=1.35)
Sets params for linear regression.
"""
kwargs = self._input_kwargs
@@ -131,6 +145,20 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
def _create_model(self, java_model):
return LinearRegressionModel(java_model)
+ @since("2.3.0")
+ def setEpsilon(self, value):
+ """
+ Sets the value of :py:attr:`epsilon`.
+ """
+ return self._set(epsilon=value)
+
+ @since("2.3.0")
+ def getEpsilon(self):
+ """
+ Gets the value of epsilon or its default value.
+ """
+ return self.getOrDefault(self.epsilon)
+
class LinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable):
"""
@@ -155,6 +183,14 @@ def intercept(self):
"""
return self._call_java("intercept")
+ @property
+ @since("2.3.0")
+ def scale(self):
+ """
+ The value by which \|y - X'w\| is scaled down when loss is "huber", otherwise 1.0.
+ """
+ return self._call_java("scale")
+
@property
@since("2.0.0")
def summary(self):
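Editor's note: to show the new huber-loss surface end to end, here is a hedged usage sketch. It assumes an active SparkSession named `spark` and reuses the sample dataset path exercised by the test added below.

    from pyspark.ml.regression import LinearRegression

    df = spark.read.format("libsvm").load("data/mllib/sample_linear_regression_data.txt")
    lr = LinearRegression(loss="huber", epsilon=1.35)   # epsilon is only used with huber loss
    model = lr.fit(df)
    print(model.coefficients, model.intercept)
    print(model.scale)   # estimated scale for huber; 1.0 when loss="squaredError"
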
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 2258d61c95333..1af2b91da900d 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -44,6 +44,7 @@
import numpy as np
from numpy import abs, all, arange, array, array_equal, inf, ones, tile, zeros
import inspect
+import py4j
from pyspark import keyword_only, SparkContext
from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer, UnaryTransformer
@@ -67,11 +68,11 @@
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaParams, JavaWrapper
from pyspark.serializers import PickleSerializer
-from pyspark.sql import DataFrame, Row, SparkSession
+from pyspark.sql import DataFrame, Row, SparkSession, HiveContext
from pyspark.sql.functions import rand
from pyspark.sql.types import DoubleType, IntegerType
from pyspark.storagelevel import *
-from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
+from pyspark.tests import QuietTest, ReusedPySparkTestCase as PySparkTestCase
ser = PickleSerializer()
@@ -1725,6 +1726,27 @@ def test_offset(self):
self.assertTrue(np.isclose(model.intercept, -1.561613, atol=1E-4))
+class LinearRegressionTest(SparkSessionTestCase):
+
+ def test_linear_regression_with_huber_loss(self):
+
+ data_path = "data/mllib/sample_linear_regression_data.txt"
+ df = self.spark.read.format("libsvm").load(data_path)
+
+ lir = LinearRegression(loss="huber", epsilon=2.0)
+ model = lir.fit(df)
+
+ expectedCoefficients = [0.136, 0.7648, -0.7761, 2.4236, 0.537,
+ 1.2612, -0.333, -0.5694, -0.6311, 0.6053]
+ expectedIntercept = 0.1607
+ expectedScale = 9.758
+
+ self.assertTrue(
+ np.allclose(model.coefficients.toArray(), expectedCoefficients, atol=1E-3))
+ self.assertTrue(np.isclose(model.intercept, expectedIntercept, atol=1E-3))
+ self.assertTrue(np.isclose(model.scale, expectedScale, atol=1E-3))
+
+
class LogisticRegressionTest(SparkSessionTestCase):
def test_binomial_logistic_regression_with_bound(self):
@@ -1836,6 +1858,56 @@ def test_read_images(self):
self.assertEqual(ImageSchema.imageFields, expected)
self.assertEqual(ImageSchema.undefinedImageType, "Undefined")
+ with QuietTest(self.sc):
+ self.assertRaisesRegexp(
+ TypeError,
+ "image argument should be pyspark.sql.types.Row; however",
+ lambda: ImageSchema.toNDArray("a"))
+
+ with QuietTest(self.sc):
+ self.assertRaisesRegexp(
+ ValueError,
+ "image argument should have attributes specified in",
+ lambda: ImageSchema.toNDArray(Row(a=1)))
+
+ with QuietTest(self.sc):
+ self.assertRaisesRegexp(
+ TypeError,
+ "array argument should be numpy.ndarray; however, it got",
+ lambda: ImageSchema.toImage("a"))
+
+
+class ImageReaderTest2(PySparkTestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ super(ImageReaderTest2, cls).setUpClass()
+ # Note that here we enable Hive support.
+ cls.spark = None
+ try:
+ cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
+ except py4j.protocol.Py4JError:
+ cls.tearDownClass()
+ raise unittest.SkipTest("Hive is not available")
+ except TypeError:
+ cls.tearDownClass()
+ raise unittest.SkipTest("Hive is not available")
+ cls.spark = HiveContext._createForTesting(cls.sc)
+
+ @classmethod
+ def tearDownClass(cls):
+ super(ImageReaderTest2, cls).tearDownClass()
+ if cls.spark is not None:
+ cls.spark.sparkSession.stop()
+ cls.spark = None
+
+ def test_read_images_multiple_times(self):
+ # This test case is to check if `ImageSchema.readImages` tries to
+ # initiate Hive client multiple times. See SPARK-22651.
+ data_path = 'data/mllib/images/kittens'
+ ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True)
+ ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True)
+
class ALSTest(SparkSessionTestCase):
@@ -2308,6 +2380,21 @@ def test_unary_transformer_transform(self):
self.assertEqual(res.input + shiftVal, res.output)
+class EstimatorTest(unittest.TestCase):
+
+ def testDefaultFitMultiple(self):
+ N = 4
+ data = MockDataset()
+ estimator = MockEstimator()
+ params = [{estimator.fake: i} for i in range(N)]
+ modelIter = estimator.fitMultiple(data, params)
+ indexList = []
+ for index, model in modelIter:
+ self.assertEqual(model.getFake(), index)
+ indexList.append(index)
+ self.assertEqual(sorted(indexList), list(range(N)))
+
+
if __name__ == "__main__":
from pyspark.ml.tests import *
if xmlrunner:
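Editor's note: the new EstimatorTest above exercises the default fitMultiple contract with mock objects; a hedged sketch of the same contract against a real estimator might look like the following. Here `training_df` is a placeholder DataFrame and an active SparkSession is assumed.

    from pyspark.ml.classification import LogisticRegression

    lr = LogisticRegression()
    epm = [{lr.maxIter: i} for i in (5, 10, 20)]
    # fitMultiple returns a thread-safe iterator of (index, model) pairs,
    # possibly out of order when models are fit in parallel.
    for index, model in lr.fitMultiple(training_df, epm):
        print(index, model.uid)
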
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 47351133524e7..6c0cad6cbaaa1 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -31,6 +31,28 @@
'TrainValidationSplitModel']
+def _parallelFitTasks(est, train, eva, validation, epm):
+ """
+ Creates a list of callables which can be called from different threads to fit and evaluate
+ an estimator in parallel. Each callable returns an `(index, metric)` pair.
+
+ :param est: Estimator, the estimator to be fit.
+ :param train: DataFrame, training data set, used for fitting.
+ :param eva: Evaluator, used to compute `metric`.
+ :param validation: DataFrame, validation data set, used for evaluation.
+ :param epm: Sequence of ParamMap, param maps to be used during fitting & evaluation.
+ :return: a list of callables; each returns an `(index, metric)` pair, where `index`
+ points into `epm` and `metric` is the associated metric value.
+ """
+ modelIter = est.fitMultiple(train, epm)
+
+ def singleTask():
+ index, model = next(modelIter)
+ metric = eva.evaluate(model.transform(validation, epm[index]))
+ return index, metric
+
+ return [singleTask] * len(epm)
+
+
class ParamGridBuilder(object):
r"""
Builder for a param grid used in grid search-based model selection.
@@ -266,15 +288,9 @@ def _fit(self, dataset):
validation = df.filter(condition).cache()
train = df.filter(~condition).cache()
- def singleTrain(paramMap):
- model = est.fit(train, paramMap)
- # TODO: duplicate evaluator to take extra params from input
- metric = eva.evaluate(model.transform(validation, paramMap))
- return metric
-
- currentFoldMetrics = pool.map(singleTrain, epm)
- for j in range(numModels):
- metrics[j] += (currentFoldMetrics[j] / nFolds)
+ tasks = _parallelFitTasks(est, train, eva, validation, epm)
+ for j, metric in pool.imap_unordered(lambda f: f(), tasks):
+ metrics[j] += (metric / nFolds)
validation.unpersist()
train.unpersist()
@@ -523,13 +539,11 @@ def _fit(self, dataset):
validation = df.filter(condition).cache()
train = df.filter(~condition).cache()
- def singleTrain(paramMap):
- model = est.fit(train, paramMap)
- metric = eva.evaluate(model.transform(validation, paramMap))
- return metric
-
+ tasks = _parallelFitTasks(est, train, eva, validation, epm)
pool = ThreadPool(processes=min(self.getParallelism(), numModels))
- metrics = pool.map(singleTrain, epm)
+ metrics = [None] * numModels
+ for j, metric in pool.imap_unordered(lambda f: f(), tasks):
+ metrics[j] = metric
train.unpersist()
validation.unpersist()
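Editor's note: for reference, a minimal sketch of how the `_parallelFitTasks` helper introduced above is meant to be driven. The names est, train, eva, validation and epm are placeholders for an Estimator, two DataFrames, an Evaluator and a list of param maps.

    from multiprocessing.pool import ThreadPool
    from pyspark.ml.tuning import _parallelFitTasks

    tasks = _parallelFitTasks(est, train, eva, validation, epm)
    pool = ThreadPool(processes=4)
    metrics = [None] * len(epm)
    # Each task fits one param map and evaluates it; results may arrive out of order,
    # so the returned index is used to slot the metric back into place.
    for index, metric in pool.imap_unordered(lambda f: f(), tasks):
        metrics[index] = metric
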
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 37e7cf3fa662e..88d6a191babca 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -223,27 +223,14 @@ def _create_batch(series, timezone):
series = [series]
series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)
- # If a nullable integer series has been promoted to floating point with NaNs, need to cast
- # NOTE: this is not necessary with Arrow >= 0.7
- def cast_series(s, t):
- if type(t) == pa.TimestampType:
- # NOTE: convert to 'us' with astype here, unit ignored in `from_pandas` see ARROW-1680
- return _check_series_convert_timestamps_internal(s.fillna(0), timezone)\
- .values.astype('datetime64[us]', copy=False)
- # NOTE: can not compare None with pyarrow.DataType(), fixed with Arrow >= 0.7.1
- elif t is not None and t == pa.date32():
- # TODO: this converts the series to Python objects, possibly avoid with Arrow >= 0.8
- return s.dt.date
- elif t is None or s.dtype == t.to_pandas_dtype():
- return s
- else:
- return s.fillna(0).astype(t.to_pandas_dtype(), copy=False)
-
- # Some object types don't support masks in Arrow, see ARROW-1721
def create_array(s, t):
- casted = cast_series(s, t)
- mask = None if casted.dtype == 'object' else s.isnull()
- return pa.Array.from_pandas(casted, mask=mask, type=t)
+ mask = s.isnull()
+ # Ensure timestamp series are in expected form for Spark internal representation
+ if t is not None and pa.types.is_timestamp(t):
+ s = _check_series_convert_timestamps_internal(s.fillna(0), timezone)
+ # TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2
+ return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False)
+ return pa.Array.from_pandas(s, mask=mask, type=t)
arrs = [create_array(s, t) for s, t in series]
return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
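Editor's note: the rewritten create_array above relies on Arrow's explicit null mask instead of the old cast_series workaround. A small standalone illustration, assuming only pandas and pyarrow are installed; the values are arbitrary.

    import pandas as pd
    import pyarrow as pa

    s = pd.Series([1.0, None, 3.0])
    # The mask marks nulls explicitly, so object-dtype quirks no longer matter.
    arr = pa.Array.from_pandas(s, mask=s.isnull())
    print(arr.null_count)   # 1
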
diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py
index 659bc65701a0c..156603128d063 100644
--- a/python/pyspark/sql/catalog.py
+++ b/python/pyspark/sql/catalog.py
@@ -227,15 +227,15 @@ def dropGlobalTempView(self, viewName):
@ignore_unicode_prefix
@since(2.0)
def registerFunction(self, name, f, returnType=StringType()):
- """Registers a python function (including lambda function) as a UDF
- so it can be used in SQL statements.
+ """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction`
+ as a UDF. The registered UDF can be used in SQL statements.
In addition to a name and the function itself, the return type can be optionally specified.
When the return type is not given it defaults to a string and conversion will automatically
be done. For any other return type, the produced object must match the specified type.
:param name: name of the UDF
- :param f: python function
+ :param f: a Python function, or a wrapped/native UserDefinedFunction
:param returnType: a :class:`pyspark.sql.types.DataType` object
:return: a wrapped :class:`UserDefinedFunction`
@@ -255,9 +255,26 @@ def registerFunction(self, name, f, returnType=StringType()):
>>> _ = spark.udf.register("stringLengthInt", len, IntegerType())
>>> spark.sql("SELECT stringLengthInt('test')").collect()
[Row(stringLengthInt(test)=4)]
+
+ >>> import random
+ >>> from pyspark.sql.functions import udf
+ >>> from pyspark.sql.types import IntegerType, StringType
+ >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
+ >>> newRandom_udf = spark.catalog.registerFunction("random_udf", random_udf, StringType())
+ >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP
+ [Row(random_udf()=u'82')]
+ >>> spark.range(1).select(newRandom_udf()).collect() # doctest: +SKIP
+ [Row(random_udf()=u'62')]
"""
- udf = UserDefinedFunction(f, returnType=returnType, name=name,
- evalType=PythonEvalType.SQL_BATCHED_UDF)
+
+ # This is to check whether the input function is a wrapped/native UserDefinedFunction
+ if hasattr(f, 'asNondeterministic'):
+ udf = UserDefinedFunction(f.func, returnType=returnType, name=name,
+ evalType=PythonEvalType.SQL_BATCHED_UDF,
+ deterministic=f.deterministic)
+ else:
+ udf = UserDefinedFunction(f, returnType=returnType, name=name,
+ evalType=PythonEvalType.SQL_BATCHED_UDF)
self._jsparkSession.udf().registerPython(name, udf._judf)
return udf._wrapped()
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index b1e723cdecef3..b8d86cc098e94 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -175,15 +175,15 @@ def range(self, start, end=None, step=1, numPartitions=None):
@ignore_unicode_prefix
@since(1.2)
def registerFunction(self, name, f, returnType=StringType()):
- """Registers a python function (including lambda function) as a UDF
- so it can be used in SQL statements.
+ """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction`
+ as a UDF. The registered UDF can be used in SQL statements.
In addition to a name and the function itself, the return type can be optionally specified.
When the return type is not given it defaults to a string and conversion will automatically
be done. For any other return type, the produced object must match the specified type.
:param name: name of the UDF
- :param f: python function
+ :param f: a Python function, or a wrapped/native UserDefinedFunction
:param returnType: a :class:`pyspark.sql.types.DataType` object
:return: a wrapped :class:`UserDefinedFunction`
@@ -203,6 +203,16 @@ def registerFunction(self, name, f, returnType=StringType()):
>>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(stringLengthInt(test)=4)]
+
+ >>> import random
+ >>> from pyspark.sql.functions import udf
+ >>> from pyspark.sql.types import IntegerType, StringType
+ >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
+ >>> newRandom_udf = sqlContext.registerFunction("random_udf", random_udf, StringType())
+ >>> sqlContext.sql("SELECT random_udf()").collect() # doctest: +SKIP
+ [Row(random_udf()=u'82')]
+ >>> sqlContext.range(1).select(newRandom_udf()).collect() # doctest: +SKIP
+ [Row(random_udf()=u'62')]
"""
return self.sparkSession.catalog.registerFunction(name, f, returnType)
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 9864dc98c1f33..95eca76fa9888 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -368,6 +368,20 @@ def checkpoint(self, eager=True):
jdf = self._jdf.checkpoint(eager)
return DataFrame(jdf, self.sql_ctx)
+ @since(2.3)
+ def localCheckpoint(self, eager=True):
+ """Returns a locally checkpointed version of this Dataset. Checkpointing can be used to
+ truncate the logical plan of this DataFrame, which is especially useful in iterative
+ algorithms where the plan may grow exponentially. Local checkpoints are stored in the
+ executors using the caching subsystem and therefore they are not reliable.
+
+ :param eager: Whether to checkpoint this DataFrame immediately
+
+ .. note:: Experimental
+ """
+ jdf = self._jdf.localCheckpoint(eager)
+ return DataFrame(jdf, self.sql_ctx)
+
@since(2.1)
def withWatermark(self, eventTime, delayThreshold):
"""Defines an event time watermark for this :class:`DataFrame`. A watermark tracks a point
@@ -1892,7 +1906,9 @@ def toPandas(self):
if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true":
try:
from pyspark.sql.types import _check_dataframe_localize_timestamps
+ from pyspark.sql.utils import require_minimum_pyarrow_version
import pyarrow
+ require_minimum_pyarrow_version()
tables = self._collectAsArrow()
if tables:
table = pyarrow.concat_tables(tables)
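Editor's note: a hedged sketch of the new DataFrame.localCheckpoint API added above, assuming an active SparkSession `spark`; the loop is purely illustrative.

    df = spark.range(1000)
    for i in range(10):
        df = df.withColumn("v%d" % i, df["id"] * i)
        if i % 5 == 4:
            # Truncate the growing logical plan; local checkpoints live in executor
            # storage, so they trade fault tolerance for speed.
            df = df.localCheckpoint(eager=True)
    print(df.count())
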
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 4e0faddb1c0df..f7b3f29764040 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1099,7 +1099,7 @@ def trunc(date, format):
"""
Returns date truncated to the unit specified by the format.
- :param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm'
+ :param format: 'year', 'yyyy', 'yy' or 'month', 'mon', 'mm'
>>> df = spark.createDataFrame([('1997-02-28',)], ['d'])
>>> df.select(trunc(df.d, 'year').alias('year')).collect()
@@ -1111,6 +1111,24 @@ def trunc(date, format):
return Column(sc._jvm.functions.trunc(_to_java_column(date), format))
+@since(2.3)
+def date_trunc(format, timestamp):
+ """
+ Returns timestamp truncated to the unit specified by the format.
+
+ :param format: 'year', 'yyyy', 'yy', 'month', 'mon', 'mm',
+ 'day', 'dd', 'hour', 'minute', 'second', 'week', 'quarter'
+
+ >>> df = spark.createDataFrame([('1997-02-28 05:02:11',)], ['t'])
+ >>> df.select(date_trunc('year', df.t).alias('year')).collect()
+ [Row(year=datetime.datetime(1997, 1, 1, 0, 0))]
+ >>> df.select(date_trunc('mon', df.t).alias('month')).collect()
+ [Row(month=datetime.datetime(1997, 2, 1, 0, 0))]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.date_trunc(format, _to_java_column(timestamp)))
+
+
@since(1.5)
def next_day(date, dayOfWeek):
"""
@@ -1356,7 +1374,8 @@ def hash(*cols):
@ignore_unicode_prefix
def concat(*cols):
"""
- Concatenates multiple input string columns together into a single string column.
+ Concatenates multiple input columns together into a single column.
+ If all inputs are binary, concat returns the result as binary; otherwise, it returns it as a string.
>>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
>>> df.select(concat(df.s, df.d).alias('s')).collect()
@@ -1830,14 +1849,14 @@ def explode_outer(col):
+---+----------+----+-----+
>>> df.select("id", "a_map", explode_outer("an_array")).show()
- +---+-------------+----+
- | id| a_map| col|
- +---+-------------+----+
- | 1|Map(x -> 1.0)| foo|
- | 1|Map(x -> 1.0)| bar|
- | 2| Map()|null|
- | 3| null|null|
- +---+-------------+----+
+ +---+----------+----+
+ | id| a_map| col|
+ +---+----------+----+
+ | 1|[x -> 1.0]| foo|
+ | 1|[x -> 1.0]| bar|
+ | 2| []|null|
+ | 3| null|null|
+ +---+----------+----+
"""
sc = SparkContext._active_spark_context
jc = sc._jvm.functions.explode_outer(_to_java_column(col))
@@ -1862,14 +1881,14 @@ def posexplode_outer(col):
| 3| null|null|null| null|
+---+----------+----+----+-----+
>>> df.select("id", "a_map", posexplode_outer("an_array")).show()
- +---+-------------+----+----+
- | id| a_map| pos| col|
- +---+-------------+----+----+
- | 1|Map(x -> 1.0)| 0| foo|
- | 1|Map(x -> 1.0)| 1| bar|
- | 2| Map()|null|null|
- | 3| null|null|null|
- +---+-------------+----+----+
+ +---+----------+----+----+
+ | id| a_map| pos| col|
+ +---+----------+----+----+
+ | 1|[x -> 1.0]| 0| foo|
+ | 1|[x -> 1.0]| 1| bar|
+ | 2| []|null|null|
+ | 3| null|null|null|
+ +---+----------+----+----+
"""
sc = SparkContext._active_spark_context
jc = sc._jvm.functions.posexplode_outer(_to_java_column(col))
@@ -2075,9 +2094,14 @@ class PandasUDFType(object):
def udf(f=None, returnType=StringType()):
"""Creates a user defined function (UDF).
- .. note:: The user-defined functions must be deterministic. Due to optimization,
- duplicate invocations may be eliminated or the function may even be invoked more times than
- it is present in the query.
+ .. note:: The user-defined functions are considered deterministic by default. Due to
+ optimization, duplicate invocations may be eliminated or the function may even be invoked
+ more times than it is present in the query. If your function is not deterministic, call
+ `asNondeterministic` on the user defined function. E.g.:
+
+ >>> from pyspark.sql.types import IntegerType
+ >>> import random
+ >>> random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic()
.. note:: The user-defined functions do not support conditional expressions or short circuiting
in boolean expressions; such expressions end up being evaluated in full internally. If the functions
@@ -2141,16 +2165,17 @@ def pandas_udf(f=None, returnType=None, functionType=None):
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> from pyspark.sql.types import IntegerType, StringType
- >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType())
- >>> @pandas_udf(StringType())
+ >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType()) # doctest: +SKIP
+ >>> @pandas_udf(StringType()) # doctest: +SKIP
... def to_upper(s):
... return s.str.upper()
...
- >>> @pandas_udf("integer", PandasUDFType.SCALAR)
+ >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP
... def add_one(x):
... return x + 1
...
- >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
+ >>> df = spark.createDataFrame([(1, "John Doe", 21)],
+ ... ("id", "name", "age")) # doctest: +SKIP
>>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\
... .show() # doctest: +SKIP
+----------+--------------+------------+
@@ -2159,6 +2184,11 @@ def pandas_udf(f=None, returnType=None, functionType=None):
| 8| JOHN DOE| 22|
+----------+--------------+------------+
+ .. note:: The length of `pandas.Series` within a scalar UDF is not that of the whole input
+ column, but is the length of an internal batch used for each call to the function.
+ Therefore, this length can be used to size each returned `pandas.Series`, but it cannot
+ be relied upon as the length of the whole column.
+
2. GROUP_MAP
A group map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame`
@@ -2171,8 +2201,8 @@ def pandas_udf(f=None, returnType=None, functionType=None):
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
- ... ("id", "v"))
- >>> @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP)
+ ... ("id", "v")) # doctest: +SKIP
+ >>> @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) # doctest: +SKIP
... def normalize(pdf):
... v = pdf.v
... return pdf.assign(v=(v - v.mean()) / v.std())
@@ -2189,7 +2219,17 @@ def pandas_udf(f=None, returnType=None, functionType=None):
.. seealso:: :meth:`pyspark.sql.GroupedData.apply`
- .. note:: The user-defined function must be deterministic.
+ .. note:: The user-defined functions are considered deterministic by default. Due to
+ optimization, duplicate invocations may be eliminated or the function may even be invoked
+ more times than it is present in the query. If your function is not deterministic, call
+ `asNondeterministic` on the user defined function. E.g.:
+
+ >>> @pandas_udf('double', PandasUDFType.SCALAR) # doctest: +SKIP
+ ... def random(v):
+ ... import numpy as np
+ ... import pandas as pd
+ ... return pd.Series(np.random.randn(len(v)))
+ >>> random = random.asNondeterministic() # doctest: +SKIP
.. note:: The user-defined functions do not support conditional expressions or short circuiting
in boolean expressions; such expressions end up being evaluated in full internally. If the functions
diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
index 4d47dd6a3e878..09fae46adf014 100644
--- a/python/pyspark/sql/group.py
+++ b/python/pyspark/sql/group.py
@@ -218,7 +218,7 @@ def apply(self, udf):
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
... ("id", "v"))
- >>> @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP)
+ >>> @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) # doctest: +SKIP
... def normalize(pdf):
... v = pdf.v
... return pdf.assign(v=(v - v.mean()) / v.std())
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 1ad974e9aa4c7..49af1bcee5ef8 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -304,7 +304,7 @@ def parquet(self, *paths):
@ignore_unicode_prefix
@since(1.6)
- def text(self, paths):
+ def text(self, paths, wholetext=False):
"""
Loads text files and returns a :class:`DataFrame` whose schema starts with a
string column named "value", and followed by partitioned columns if there
@@ -313,11 +313,16 @@ def text(self, paths):
Each line in the text file is a new row in the resulting DataFrame.
:param paths: string, or list of strings, for input path(s).
+ :param wholetext: if true, read each file from input path(s) as a single row.
>>> df = spark.read.text('python/test_support/sql/text-test.txt')
>>> df.collect()
[Row(value=u'hello'), Row(value=u'this')]
+ >>> df = spark.read.text('python/test_support/sql/text-test.txt', wholetext=True)
+ >>> df.collect()
+ [Row(value=u'hello\\nthis')]
"""
+ self._set_opts(wholetext=wholetext)
if isinstance(paths, basestring):
paths = [paths]
return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(paths)))
@@ -328,7 +333,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None,
negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
- columnNameOfCorruptRecord=None, multiLine=None):
+ columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None):
"""Loads a CSV file and returns the result as a :class:`DataFrame`.
This function will go through the input once to determine the input schema if
@@ -339,17 +344,17 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
or RDD of Strings storing CSV rows.
:param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema
or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
- :param sep: sets the single character as a separator for each field and value.
+ :param sep: sets a single character as a separator for each field and value.
If None is set, it uses the default value, ``,``.
:param encoding: decodes the CSV files by the given encoding type. If None is set,
it uses the default value, ``UTF-8``.
- :param quote: sets the single character used for escaping quoted values where the
+ :param quote: sets a single character used for escaping quoted values where the
separator can be part of the value. If None is set, it uses the default
value, ``"``. If you would like to turn off quotations, you need to set an
empty string.
- :param escape: sets the single character used for escaping quotes inside an already
+ :param escape: sets a single character used for escaping quotes inside an already
quoted value. If None is set, it uses the default value, ``\``.
- :param comment: sets the single character used for skipping lines beginning with this
+ :param comment: sets a single character used for skipping lines beginning with this
character. By default (None), it is disabled.
:param header: uses the first line as names of columns. If None is set, it uses the
default value, ``false``.
@@ -405,6 +410,10 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
``spark.sql.columnNameOfCorruptRecord``.
:param multiLine: parse records, which may span multiple lines. If None is
set, it uses the default value, ``false``.
+ :param charToEscapeQuoteEscaping: sets a single character used for escaping the escape for
+ the quote character. If None is set, the default value is
+ the escape character when escape and quote characters are
+ different, ``\0`` otherwise.
>>> df = spark.read.csv('python/test_support/sql/ages.csv')
>>> df.dtypes
@@ -422,7 +431,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns,
maxCharsPerColumn=maxCharsPerColumn,
maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode,
- columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine)
+ columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine,
+ charToEscapeQuoteEscaping=charToEscapeQuoteEscaping)
if isinstance(path, basestring):
path = [path]
if type(path) == list:
@@ -809,7 +819,8 @@ def text(self, path, compression=None):
@since(2.0)
def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None,
header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None,
- timestampFormat=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None):
+ timestampFormat=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None,
+ charToEscapeQuoteEscaping=None):
"""Saves the content of the :class:`DataFrame` in CSV format at the specified path.
:param path: the path in any Hadoop supported file system
@@ -824,12 +835,12 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No
:param compression: compression codec to use when saving to file. This can be one of the
known case-insensitive shorten names (none, bzip2, gzip, lz4,
snappy and deflate).
- :param sep: sets the single character as a separator for each field and value. If None is
+ :param sep: sets a single character as a separator for each field and value. If None is
set, it uses the default value, ``,``.
- :param quote: sets the single character used for escaping quoted values where the
+ :param quote: sets a single character used for escaping quoted values where the
separator can be part of the value. If None is set, it uses the default
value, ``"``. If an empty string is set, it uses ``u0000`` (null character).
- :param escape: sets the single character used for escaping quotes inside an already
+ :param escape: sets a single character used for escaping quotes inside an already
quoted value. If None is set, it uses the default value, ``\``
:param escapeQuotes: a flag indicating whether values containing quotes should always
be enclosed in quotes. If None is set, it uses the default value
@@ -855,6 +866,10 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No
:param ignoreTrailingWhiteSpace: a flag indicating whether or not trailing whitespaces from
values being written should be skipped. If None is set, it
uses the default value, ``true``.
+ :param charToEscapeQuoteEscaping: sets a single character used for escaping the escape for
+ the quote character. If None is set, the default value is
+ the escape character when escape and quote characters are
+ different, ``\0`` otherwise.
>>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data'))
"""
@@ -863,7 +878,8 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No
nullValue=nullValue, escapeQuotes=escapeQuotes, quoteAll=quoteAll,
dateFormat=dateFormat, timestampFormat=timestampFormat,
ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
- ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace)
+ ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace,
+ charToEscapeQuoteEscaping=charToEscapeQuoteEscaping)
self._jwrite.csv(path)
@since(1.5)
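Editor's note: a short, hedged sketch of the reader/writer options added above (wholetext and charToEscapeQuoteEscaping). The paths are placeholders and a SparkSession `spark` is assumed.

    # Read every file under the path as a single row instead of one row per line.
    df_whole = spark.read.text("/path/to/logs", wholetext=True)

    # Use '#' to escape the escape character that precedes quotes inside quoted values.
    (spark.range(3).selectExpr("cast(id AS string) AS s")
        .write.csv("/tmp/csv-out", quote='"', escape='\\', charToEscapeQuoteEscaping='#'))
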
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index e2435e09af23d..604021c1f45cc 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -325,11 +325,12 @@ def range(self, start, end=None, step=1, numPartitions=None):
return DataFrame(jdf, self._wrapped)
- def _inferSchemaFromList(self, data):
+ def _inferSchemaFromList(self, data, names=None):
"""
Infer schema from list of Row or tuple.
:param data: list of Row or tuple
+ :param names: list of column names
:return: :class:`pyspark.sql.types.StructType`
"""
if not data:
@@ -338,12 +339,12 @@ def _inferSchemaFromList(self, data):
if type(first) is dict:
warnings.warn("inferring schema from dict is deprecated,"
"please use pyspark.sql.Row instead")
- schema = reduce(_merge_type, map(_infer_schema, data))
+ schema = reduce(_merge_type, (_infer_schema(row, names) for row in data))
if _has_nulltype(schema):
raise ValueError("Some of types cannot be determined after inferring")
return schema
- def _inferSchema(self, rdd, samplingRatio=None):
+ def _inferSchema(self, rdd, samplingRatio=None, names=None):
"""
Infer schema from an RDD of Row or tuple.
@@ -360,10 +361,10 @@ def _inferSchema(self, rdd, samplingRatio=None):
"Use pyspark.sql.Row instead")
if samplingRatio is None:
- schema = _infer_schema(first)
+ schema = _infer_schema(first, names=names)
if _has_nulltype(schema):
for row in rdd.take(100)[1:]:
- schema = _merge_type(schema, _infer_schema(row))
+ schema = _merge_type(schema, _infer_schema(row, names=names))
if not _has_nulltype(schema):
break
else:
@@ -372,7 +373,7 @@ def _inferSchema(self, rdd, samplingRatio=None):
else:
if samplingRatio < 0.99:
rdd = rdd.sample(False, float(samplingRatio))
- schema = rdd.map(_infer_schema).reduce(_merge_type)
+ schema = rdd.map(lambda row: _infer_schema(row, names)).reduce(_merge_type)
return schema
def _createFromRDD(self, rdd, schema, samplingRatio):
@@ -380,7 +381,7 @@ def _createFromRDD(self, rdd, schema, samplingRatio):
Create an RDD for DataFrame from an existing RDD, returns the RDD and schema.
"""
if schema is None or isinstance(schema, (list, tuple)):
- struct = self._inferSchema(rdd, samplingRatio)
+ struct = self._inferSchema(rdd, samplingRatio, names=schema)
converter = _create_converter(struct)
rdd = rdd.map(converter)
if isinstance(schema, (list, tuple)):
@@ -406,7 +407,7 @@ def _createFromLocal(self, data, schema):
data = list(data)
if schema is None or isinstance(schema, (list, tuple)):
- struct = self._inferSchemaFromList(data)
+ struct = self._inferSchemaFromList(data, names=schema)
converter = _create_converter(struct)
data = map(converter, data)
if isinstance(schema, (list, tuple)):
@@ -458,21 +459,23 @@ def _convert_from_pandas(self, pdf, schema, timezone):
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if isinstance(field.dataType, TimestampType):
s = _check_series_convert_timestamps_tz_local(pdf[field.name], timezone)
- if not copied and s is not pdf[field.name]:
- # Copy once if the series is modified to prevent the original Pandas
- # DataFrame from being updated
- pdf = pdf.copy()
- copied = True
- pdf[field.name] = s
+ if s is not pdf[field.name]:
+ if not copied:
+ # Copy once if the series is modified to prevent the original
+ # Pandas DataFrame from being updated
+ pdf = pdf.copy()
+ copied = True
+ pdf[field.name] = s
else:
for column, series in pdf.iteritems():
- s = _check_series_convert_timestamps_tz_local(pdf[column], timezone)
- if not copied and s is not pdf[column]:
- # Copy once if the series is modified to prevent the original Pandas
- # DataFrame from being updated
- pdf = pdf.copy()
- copied = True
- pdf[column] = s
+ s = _check_series_convert_timestamps_tz_local(series, timezone)
+ if s is not series:
+ if not copied:
+ # Copy once if the series is modified to prevent the original
+ # Pandas DataFrame from being updated
+ pdf = pdf.copy()
+ copied = True
+ pdf[column] = s
# Convert pandas.DataFrame to list of numpy records
np_records = pdf.to_records(index=False)
@@ -493,12 +496,14 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
data types will be used to coerce the data in Pandas to Arrow conversion.
"""
from pyspark.serializers import ArrowSerializer, _create_batch
- from pyspark.sql.types import from_arrow_schema, to_arrow_type, \
- _old_pandas_exception_message, TimestampType
- try:
- from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
- except ImportError as e:
- raise ImportError(_old_pandas_exception_message(e))
+ from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType
+ from pyspark.sql.utils import require_minimum_pandas_version, \
+ require_minimum_pyarrow_version
+
+ require_minimum_pandas_version()
+ require_minimum_pyarrow_version()
+
+ from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
# Determine arrow types to coerce data when creating batches
if isinstance(schema, StructType):
@@ -643,7 +648,9 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
# If no schema supplied by user then get the names of columns only
if schema is None:
- schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in data.columns]
+ schema = [str(x) if not isinstance(x, basestring) else
+ (x.encode('utf-8') if not isinstance(x, str) else x)
+ for x in data.columns]
if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \
and len(data) > 0:
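Editor's note: the name-aware schema inference added above surfaces directly in createDataFrame; a small sketch mirroring the new tests further down, with a SparkSession `spark` assumed.

    # An explicit column name list now overrides dict keys ...
    df = spark.createDataFrame([{'a': 1}], ["b"])
    print(df.columns)    # ['b']

    # ... and missing names are padded with positional ones.
    df2 = spark.createDataFrame([["a", "b"]], ["col1"])
    print(df2.columns)   # ['col1', '_2']
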
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index 0cf702143c773..24ae3776a217b 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -490,6 +490,23 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
else:
raise TypeError("path can be only a single string")
+ @since(2.3)
+ def orc(self, path):
+ """Loads a ORC file stream, returning the result as a :class:`DataFrame`.
+
+ .. note:: Evolving.
+
+ >>> orc_sdf = spark.readStream.schema(sdf_schema).orc(tempfile.mkdtemp())
+ >>> orc_sdf.isStreaming
+ True
+ >>> orc_sdf.schema == sdf_schema
+ True
+ """
+ if isinstance(path, basestring):
+ return self._df(self._jreader.orc(path))
+ else:
+ raise TypeError("path can be only a single string")
+
@since(2.0)
def parquet(self, path):
"""Loads a Parquet file stream, returning the result as a :class:`DataFrame`.
@@ -543,7 +560,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None,
negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
- columnNameOfCorruptRecord=None, multiLine=None):
+ columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None):
"""Loads a CSV file stream and returns the result as a :class:`DataFrame`.
This function will go through the input once to determine the input schema if
@@ -555,17 +572,17 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
:param path: string, or list of strings, for input path(s).
:param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema
or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
- :param sep: sets the single character as a separator for each field and value.
+ :param sep: sets a single character as a separator for each field and value.
If None is set, it uses the default value, ``,``.
:param encoding: decodes the CSV files by the given encoding type. If None is set,
it uses the default value, ``UTF-8``.
- :param quote: sets the single character used for escaping quoted values where the
+ :param quote: sets a single character used for escaping quoted values where the
separator can be part of the value. If None is set, it uses the default
value, ``"``. If you would like to turn off quotations, you need to set an
empty string.
- :param escape: sets the single character used for escaping quotes inside an already
+ :param escape: sets a single character used for escaping quotes inside an already
quoted value. If None is set, it uses the default value, ``\``.
- :param comment: sets the single character used for skipping lines beginning with this
+ :param comment: sets a single character used for skipping lines beginning with this
character. By default (None), it is disabled.
:param header: uses the first line as names of columns. If None is set, it uses the
default value, ``false``.
@@ -621,6 +638,10 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
``spark.sql.columnNameOfCorruptRecord``.
:param multiLine: parse one record, which may span multiple lines. If None is
set, it uses the default value, ``false``.
+ :param charToEscapeQuoteEscaping: sets a single character used for escaping the escape for
+ the quote character. If None is set, the default value is
+ the escape character when escape and quote characters are
+ different, ``\0`` otherwise.
>>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema)
>>> csv_sdf.isStreaming
@@ -636,7 +657,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns,
maxCharsPerColumn=maxCharsPerColumn,
maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode,
- columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine)
+ columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine,
+ charToEscapeQuoteEscaping=charToEscapeQuoteEscaping)
if isinstance(path, basestring):
return self._df(self._jreader.csv(path))
else:
@@ -771,6 +793,10 @@ def trigger(self, processingTime=None, once=None):
.. note:: Evolving.
:param processingTime: a processing time interval as a string, e.g. '5 seconds', '1 minute'.
+ Set a trigger that runs a query periodically based on the processing
+ time. Only one trigger can be set.
+ :param once: if set to True, set a trigger that processes only one batch of data in a
+ streaming query and then terminates the query. Only one trigger can be set.
>>> # trigger the query for execution every 5 seconds
>>> writer = sdf.writeStream.trigger(processingTime='5 seconds')
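Editor's note: a hedged sketch combining the new streaming ORC source and the clarified once trigger documented above. The input path is a placeholder, a SparkSession `spark` is assumed, and a schema must be supplied up front for file streams.

    from pyspark.sql.types import StructType, StructField, StringType

    schema = StructType([StructField("value", StringType())])
    sdf = spark.readStream.schema(schema).orc("/path/to/orc-dir")

    query = (sdf.writeStream
             .trigger(once=True)        # process a single batch, then stop
             .format("memory")
             .queryName("orc_once")
             .start())
    query.awaitTermination()
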
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index b4d32d8de8a22..80a94a91a87b3 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -53,7 +53,8 @@
try:
import pandas
try:
- import pandas.api
+ from pyspark.sql.utils import require_minimum_pandas_version
+ require_minimum_pandas_version()
_have_pandas = True
except:
_have_old_pandas = True
@@ -67,6 +68,7 @@
from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier
from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings
from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings
+from pyspark.sql.types import _merge_type
from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests
from pyspark.sql.functions import UserDefinedFunction, sha2, lit
from pyspark.sql.window import Window
@@ -377,6 +379,55 @@ def test_udf2(self):
[res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
self.assertEqual(4, res[0])
+ def test_udf3(self):
+ twoargs = self.spark.catalog.registerFunction(
+ "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y), IntegerType())
+ self.assertEqual(twoargs.deterministic, True)
+ [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
+ self.assertEqual(row[0], 5)
+
+ def test_nondeterministic_udf(self):
+ # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
+ from pyspark.sql.functions import udf
+ import random
+ udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
+ self.assertEqual(udf_random_col.deterministic, False)
+ df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND'))
+ udf_add_ten = udf(lambda rand: rand + 10, IntegerType())
+ [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect()
+ self.assertEqual(row[0] + 10, row[1])
+
+ def test_nondeterministic_udf2(self):
+ import random
+ from pyspark.sql.functions import udf
+ random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic()
+ self.assertEqual(random_udf.deterministic, False)
+ random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf, StringType())
+ self.assertEqual(random_udf1.deterministic, False)
+ [row] = self.spark.sql("SELECT randInt()").collect()
+ self.assertEqual(row[0], "6")
+ [row] = self.spark.range(1).select(random_udf1()).collect()
+ self.assertEqual(row[0], "6")
+ [row] = self.spark.range(1).select(random_udf()).collect()
+ self.assertEqual(row[0], 6)
+ # render_doc() reproduces the help() exception without printing output
+ pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType()))
+ pydoc.render_doc(random_udf)
+ pydoc.render_doc(random_udf1)
+ pydoc.render_doc(udf(lambda x: x).asNondeterministic)
+
+ def test_nondeterministic_udf_in_aggregate(self):
+ from pyspark.sql.functions import udf, sum
+ import random
+ udf_random_col = udf(lambda: int(100 * random.random()), 'int').asNondeterministic()
+ df = self.spark.range(10)
+
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(AnalysisException, "nondeterministic"):
+ df.groupby('id').agg(sum(udf_random_col())).collect()
+ with self.assertRaisesRegexp(AnalysisException, "nondeterministic"):
+ df.agg(sum(udf_random_col())).collect()
+
def test_chained_udf(self):
self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType())
[row] = self.spark.sql("SELECT double(1)").collect()
@@ -557,7 +608,6 @@ def test_read_multiple_orc_file(self):
def test_udf_with_input_file_name(self):
from pyspark.sql.functions import udf, input_file_name
- from pyspark.sql.types import StringType
sourceFile = udf(lambda path: path, StringType())
filePath = "python/test_support/sql/people1.json"
row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
@@ -565,7 +615,6 @@ def test_udf_with_input_file_name(self):
def test_udf_with_input_file_name_for_hadooprdd(self):
from pyspark.sql.functions import udf, input_file_name
- from pyspark.sql.types import StringType
def filename(path):
return path
@@ -625,7 +674,6 @@ def test_udf_with_string_return_type(self):
def test_udf_shouldnt_accept_noncallable_object(self):
from pyspark.sql.functions import UserDefinedFunction
- from pyspark.sql.types import StringType
non_callable = None
self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType())
@@ -851,6 +899,15 @@ def test_infer_schema(self):
result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
self.assertEqual(1, result.head()[0])
+ def test_infer_schema_not_enough_names(self):
+ df = self.spark.createDataFrame([["a", "b"]], ["col1"])
+ self.assertEqual(df.columns, ['col1', '_2'])
+
+ def test_infer_schema_fails(self):
+ with self.assertRaisesRegexp(TypeError, 'field a'):
+ self.spark.createDataFrame(self.spark.sparkContext.parallelize([[1, 1], ["x", 1]]),
+ schema=["a", "b"], samplingRatio=0.99)
+
def test_infer_nested_schema(self):
NestedRow = Row("f1", "f2")
nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}),
@@ -871,6 +928,10 @@ def test_infer_nested_schema(self):
df = self.spark.createDataFrame(rdd)
self.assertEqual(Row(field1=1, field2=u'row1'), df.first())
+ def test_create_dataframe_from_dict_respects_schema(self):
+ df = self.spark.createDataFrame([{'a': 1}], ["b"])
+ self.assertEqual(df.columns, ['b'])
+
def test_create_dataframe_from_objects(self):
data = [MyObject(1, "1"), MyObject(2, "2")]
df = self.spark.createDataFrame(data)
@@ -1289,7 +1350,6 @@ def test_between_function(self):
df.filter(df.a.between(df.b, df.c)).collect())
def test_struct_type(self):
- from pyspark.sql.types import StructType, StringType, StructField
struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
struct2 = StructType([StructField("f1", StringType(), True),
StructField("f2", StringType(), True, None)])
@@ -1358,7 +1418,6 @@ def test_parse_datatype_string(self):
_parse_datatype_string("a INT, c DOUBLE"))
def test_metadata_null(self):
- from pyspark.sql.types import StructType, StringType, StructField
schema = StructType([StructField("f1", StringType(), True, None),
StructField("f2", StringType(), True, {'a': None})])
rdd = self.sc.parallelize([["a", "b"], ["c", "d"]])
@@ -1727,6 +1786,92 @@ def test_infer_long_type(self):
self.assertEqual(_infer_type(2**61), LongType())
self.assertEqual(_infer_type(2**71), LongType())
+ def test_merge_type(self):
+ self.assertEqual(_merge_type(LongType(), NullType()), LongType())
+ self.assertEqual(_merge_type(NullType(), LongType()), LongType())
+
+ self.assertEqual(_merge_type(LongType(), LongType()), LongType())
+
+ self.assertEqual(_merge_type(
+ ArrayType(LongType()),
+ ArrayType(LongType())
+ ), ArrayType(LongType()))
+ with self.assertRaisesRegexp(TypeError, 'element in array'):
+ _merge_type(ArrayType(LongType()), ArrayType(DoubleType()))
+
+ self.assertEqual(_merge_type(
+ MapType(StringType(), LongType()),
+ MapType(StringType(), LongType())
+ ), MapType(StringType(), LongType()))
+ with self.assertRaisesRegexp(TypeError, 'key of map'):
+ _merge_type(
+ MapType(StringType(), LongType()),
+ MapType(DoubleType(), LongType()))
+ with self.assertRaisesRegexp(TypeError, 'value of map'):
+ _merge_type(
+ MapType(StringType(), LongType()),
+ MapType(StringType(), DoubleType()))
+
+ self.assertEqual(_merge_type(
+ StructType([StructField("f1", LongType()), StructField("f2", StringType())]),
+ StructType([StructField("f1", LongType()), StructField("f2", StringType())])
+ ), StructType([StructField("f1", LongType()), StructField("f2", StringType())]))
+ with self.assertRaisesRegexp(TypeError, 'field f1'):
+ _merge_type(
+ StructType([StructField("f1", LongType()), StructField("f2", StringType())]),
+ StructType([StructField("f1", DoubleType()), StructField("f2", StringType())]))
+
+ self.assertEqual(_merge_type(
+ StructType([StructField("f1", StructType([StructField("f2", LongType())]))]),
+ StructType([StructField("f1", StructType([StructField("f2", LongType())]))])
+ ), StructType([StructField("f1", StructType([StructField("f2", LongType())]))]))
+ with self.assertRaisesRegexp(TypeError, 'field f2 in field f1'):
+ _merge_type(
+ StructType([StructField("f1", StructType([StructField("f2", LongType())]))]),
+ StructType([StructField("f1", StructType([StructField("f2", StringType())]))]))
+
+ self.assertEqual(_merge_type(
+ StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]),
+ StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())])
+ ), StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]))
+ with self.assertRaisesRegexp(TypeError, 'element in array field f1'):
+ _merge_type(
+ StructType([
+ StructField("f1", ArrayType(LongType())),
+ StructField("f2", StringType())]),
+ StructType([
+ StructField("f1", ArrayType(DoubleType())),
+ StructField("f2", StringType())]))
+
+ self.assertEqual(_merge_type(
+ StructType([
+ StructField("f1", MapType(StringType(), LongType())),
+ StructField("f2", StringType())]),
+ StructType([
+ StructField("f1", MapType(StringType(), LongType())),
+ StructField("f2", StringType())])
+ ), StructType([
+ StructField("f1", MapType(StringType(), LongType())),
+ StructField("f2", StringType())]))
+ with self.assertRaisesRegexp(TypeError, 'value of map field f1'):
+ _merge_type(
+ StructType([
+ StructField("f1", MapType(StringType(), LongType())),
+ StructField("f2", StringType())]),
+ StructType([
+ StructField("f1", MapType(StringType(), DoubleType())),
+ StructField("f2", StringType())]))
+
+ self.assertEqual(_merge_type(
+ StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]),
+ StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))])
+ ), StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]))
+ with self.assertRaisesRegexp(TypeError, 'key of map element in array field f1'):
+ _merge_type(
+ StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]),
+ StructType([StructField("f1", ArrayType(MapType(DoubleType(), LongType())))])
+ )
+
def test_filter_with_datetime(self):
time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000)
date = time.date()
@@ -2600,7 +2745,7 @@ def test_to_pandas(self):
@unittest.skipIf(not _have_old_pandas, "Old Pandas not installed")
def test_to_pandas_old(self):
with QuietTest(self.sc):
- with self.assertRaisesRegexp(ImportError, 'Pandas \(.*\) must be installed'):
+ with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'):
self._to_pandas()
@unittest.skipIf(not _have_pandas, "Pandas not installed")
@@ -2643,7 +2788,7 @@ def test_create_dataframe_from_old_pandas(self):
pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)],
"d": [pd.Timestamp.now().date()]})
with QuietTest(self.sc):
- with self.assertRaisesRegexp(ImportError, 'Pandas \(.*\) must be installed'):
+ with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'):
self.spark.createDataFrame(pdf)
@@ -3141,6 +3286,7 @@ class ArrowTests(ReusedSQLTestCase):
@classmethod
def setUpClass(cls):
from datetime import datetime
+ from decimal import Decimal
ReusedSQLTestCase.setUpClass()
# Synchronize default timezone between Python and Java
@@ -3157,11 +3303,15 @@ def setUpClass(cls):
StructField("3_long_t", LongType(), True),
StructField("4_float_t", FloatType(), True),
StructField("5_double_t", DoubleType(), True),
- StructField("6_date_t", DateType(), True),
- StructField("7_timestamp_t", TimestampType(), True)])
- cls.data = [(u"a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
- (u"b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
- (u"c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
+ StructField("6_decimal_t", DecimalType(38, 18), True),
+ StructField("7_date_t", DateType(), True),
+ StructField("8_timestamp_t", TimestampType(), True)])
+ cls.data = [(u"a", 1, 10, 0.2, 2.0, Decimal("2.0"),
+ datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
+ (u"b", 2, 20, 0.4, 4.0, Decimal("4.0"),
+ datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
+ (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"),
+ datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
@classmethod
def tearDownClass(cls):
@@ -3189,10 +3339,11 @@ def create_pandas_data_frame(self):
return pd.DataFrame(data=data_dict)
def test_unsupported_datatype(self):
- schema = StructType([StructField("decimal", DecimalType(), True)])
+ schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
df = self.spark.createDataFrame([(None,)], schema=schema)
with QuietTest(self.sc):
- self.assertRaises(Exception, lambda: df.toPandas())
+ with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
+ df.toPandas()
def test_null_conversion(self):
df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] +
@@ -3292,7 +3443,7 @@ def test_createDataFrame_respect_session_timezone(self):
self.assertNotEqual(result_ny, result_la)
# Correct result_la by adjusting for the 3 hour difference between Los Angeles and New York
- result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '7_timestamp_t' else v
+ result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '8_timestamp_t' else v
for k, v in row.asDict().items()})
for row in result_la]
self.assertEqual(result_ny, result_la_corrected)
@@ -3316,11 +3467,11 @@ def test_createDataFrame_with_incorrect_schema(self):
def test_createDataFrame_with_names(self):
pdf = self.create_pandas_data_frame()
# Test that schema as a list of column names gets applied
- df = self.spark.createDataFrame(pdf, schema=list('abcdefg'))
- self.assertEquals(df.schema.fieldNames(), list('abcdefg'))
+ df = self.spark.createDataFrame(pdf, schema=list('abcdefgh'))
+ self.assertEquals(df.schema.fieldNames(), list('abcdefgh'))
# Test that schema as tuple of column names gets applied
- df = self.spark.createDataFrame(pdf, schema=tuple('abcdefg'))
- self.assertEquals(df.schema.fieldNames(), list('abcdefg'))
+ df = self.spark.createDataFrame(pdf, schema=tuple('abcdefgh'))
+ self.assertEquals(df.schema.fieldNames(), list('abcdefgh'))
def test_createDataFrame_column_name_encoding(self):
import pandas as pd
@@ -3339,10 +3490,11 @@ def test_createDataFrame_with_single_data_type(self):
self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int")
def test_createDataFrame_does_not_modify_input(self):
+ import pandas as pd
# Some series get converted for Spark to consume; this makes sure the input is unchanged
pdf = self.create_pandas_data_frame()
# Use a nanosecond value to make sure it is not truncated
- pdf.ix[0, '7_timestamp_t'] = 1
+ pdf.ix[0, '8_timestamp_t'] = pd.Timestamp(1)
# Integers with nulls will get NaNs filled with 0 and will be casted
pdf.ix[1, '2_int_t'] = None
pdf_copy = pdf.copy(deep=True)
@@ -3355,7 +3507,42 @@ def test_schema_conversion_roundtrip(self):
schema_rt = from_arrow_schema(arrow_schema)
self.assertEquals(self.schema, schema_rt)
+ def test_createDataFrame_with_array_type(self):
+ import pandas as pd
+ pdf = pd.DataFrame({"a": [[1, 2], [3, 4]], "b": [[u"x", u"y"], [u"y", u"z"]]})
+ df, df_arrow = self._createDataFrame_toggle(pdf)
+ result = df.collect()
+ result_arrow = df_arrow.collect()
+ expected = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)]
+ for r in range(len(expected)):
+ for e in range(len(expected[r])):
+ self.assertTrue(expected[r][e] == result_arrow[r][e] and
+ result[r][e] == result_arrow[r][e])
+
+ def test_toPandas_with_array_type(self):
+ expected = [([1, 2], [u"x", u"y"]), ([3, 4], [u"y", u"z"])]
+ array_schema = StructType([StructField("a", ArrayType(IntegerType())),
+ StructField("b", ArrayType(StringType()))])
+ df = self.spark.createDataFrame(expected, schema=array_schema)
+ pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
+ result = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)]
+ result_arrow = [tuple(list(e) for e in rec) for rec in pdf_arrow.to_records(index=False)]
+ for r in range(len(expected)):
+ for e in range(len(expected[r])):
+ self.assertTrue(expected[r][e] == result_arrow[r][e] and
+ result[r][e] == result_arrow[r][e])
+
+ def test_createDataFrame_with_int_col_names(self):
+ import numpy as np
+ import pandas as pd
+ pdf = pd.DataFrame(np.random.rand(4, 2))
+ df, df_arrow = self._createDataFrame_toggle(pdf)
+ pdf_col_names = [str(c) for c in pdf.columns]
+ self.assertEqual(pdf_col_names, df.columns)
+ self.assertEqual(pdf_col_names, df_arrow.columns)
+
+@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
class PandasUDFTests(ReusedSQLTestCase):
def test_pandas_udf_basic(self):
from pyspark.rdd import PythonEvalType
@@ -3503,6 +3690,18 @@ def tearDownClass(cls):
time.tzset()
ReusedSQLTestCase.tearDownClass()
+ @property
+ def random_udf(self):
+ from pyspark.sql.functions import pandas_udf
+
+ @pandas_udf('double')
+ def random_udf(v):
+ import pandas as pd
+ import numpy as np
+ return pd.Series(np.random.random(len(v)))
+ random_udf = random_udf.asNondeterministic()
+ return random_udf
+
def test_vectorized_udf_basic(self):
from pyspark.sql.functions import pandas_udf, col
df = self.spark.range(10).select(
@@ -3511,6 +3710,7 @@ def test_vectorized_udf_basic(self):
col('id').alias('long'),
col('id').cast('float').alias('float'),
col('id').cast('double').alias('double'),
+ col('id').cast('decimal').alias('decimal'),
col('id').cast('boolean').alias('bool'))
f = lambda x: x
str_f = pandas_udf(f, StringType())
@@ -3518,10 +3718,12 @@ def test_vectorized_udf_basic(self):
long_f = pandas_udf(f, LongType())
float_f = pandas_udf(f, FloatType())
double_f = pandas_udf(f, DoubleType())
+ decimal_f = pandas_udf(f, DecimalType())
bool_f = pandas_udf(f, BooleanType())
res = df.select(str_f(col('str')), int_f(col('int')),
long_f(col('long')), float_f(col('float')),
- double_f(col('double')), bool_f(col('bool')))
+ double_f(col('double')), decimal_f('decimal'),
+ bool_f(col('bool')))
self.assertEquals(df.collect(), res.collect())
def test_vectorized_udf_null_boolean(self):
@@ -3587,6 +3789,16 @@ def test_vectorized_udf_null_double(self):
res = df.select(double_f(col('double')))
self.assertEquals(df.collect(), res.collect())
+ def test_vectorized_udf_null_decimal(self):
+ from decimal import Decimal
+ from pyspark.sql.functions import pandas_udf, col
+ data = [(Decimal(3.0),), (Decimal(5.0),), (Decimal(-1.0),), (None,)]
+ schema = StructType().add("decimal", DecimalType(38, 18))
+ df = self.spark.createDataFrame(data, schema)
+ decimal_f = pandas_udf(lambda x: x, DecimalType(38, 18))
+ res = df.select(decimal_f(col('decimal')))
+ self.assertEquals(df.collect(), res.collect())
+
def test_vectorized_udf_null_string(self):
from pyspark.sql.functions import pandas_udf, col
data = [("foo",), (None,), ("bar",), ("bar",)]
@@ -3604,6 +3816,7 @@ def test_vectorized_udf_datatype_string(self):
col('id').alias('long'),
col('id').cast('float').alias('float'),
col('id').cast('double').alias('double'),
+ col('id').cast('decimal').alias('decimal'),
col('id').cast('boolean').alias('bool'))
f = lambda x: x
str_f = pandas_udf(f, 'string')
@@ -3611,12 +3824,32 @@ def test_vectorized_udf_datatype_string(self):
long_f = pandas_udf(f, 'long')
float_f = pandas_udf(f, 'float')
double_f = pandas_udf(f, 'double')
+ decimal_f = pandas_udf(f, 'decimal(38, 18)')
bool_f = pandas_udf(f, 'boolean')
res = df.select(str_f(col('str')), int_f(col('int')),
long_f(col('long')), float_f(col('float')),
- double_f(col('double')), bool_f(col('bool')))
+ double_f(col('double')), decimal_f('decimal'),
+ bool_f(col('bool')))
self.assertEquals(df.collect(), res.collect())
+ def test_vectorized_udf_array_type(self):
+ from pyspark.sql.functions import pandas_udf, col
+ data = [([1, 2],), ([3, 4],)]
+ array_schema = StructType([StructField("array", ArrayType(IntegerType()))])
+ df = self.spark.createDataFrame(data, schema=array_schema)
+ array_f = pandas_udf(lambda x: x, ArrayType(IntegerType()))
+ result = df.select(array_f(col('array')))
+ self.assertEquals(df.collect(), result.collect())
+
+ def test_vectorized_udf_null_array(self):
+ from pyspark.sql.functions import pandas_udf, col
+ data = [([1, 2],), (None,), (None,), ([3, 4],), (None,)]
+ array_schema = StructType([StructField("array", ArrayType(IntegerType()))])
+ df = self.spark.createDataFrame(data, schema=array_schema)
+ array_f = pandas_udf(lambda x: x, ArrayType(IntegerType()))
+ result = df.select(array_f(col('array')))
+ self.assertEquals(df.collect(), result.collect())
+
def test_vectorized_udf_complex(self):
from pyspark.sql.functions import pandas_udf, col, expr
df = self.spark.range(10).select(
@@ -3671,9 +3904,9 @@ def test_vectorized_udf_chained(self):
def test_vectorized_udf_wrong_return_type(self):
from pyspark.sql.functions import pandas_udf, col
df = self.spark.range(10)
- f = pandas_udf(lambda x: x * 1.0, StringType())
+ f = pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType()))
with QuietTest(self.sc):
- with self.assertRaisesRegexp(Exception, 'Invalid.*type'):
+ with self.assertRaisesRegexp(Exception, 'Unsupported.*type.*conversion'):
df.select(f(col('id'))).collect()
def test_vectorized_udf_return_scalar(self):
@@ -3710,12 +3943,12 @@ def test_vectorized_udf_varargs(self):
def test_vectorized_udf_unsupported_types(self):
from pyspark.sql.functions import pandas_udf, col
- schema = StructType([StructField("dt", DecimalType(), True)])
+ schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
df = self.spark.createDataFrame([(None,)], schema=schema)
- f = pandas_udf(lambda x: x, DecimalType())
+ f = pandas_udf(lambda x: x, MapType(StringType(), IntegerType()))
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
- df.select(f(col('dt'))).collect()
+ df.select(f(col('map'))).collect()
def test_vectorized_udf_null_date(self):
from pyspark.sql.functions import pandas_udf, col
@@ -3791,6 +4024,7 @@ def gen_timestamps(id):
def test_vectorized_udf_check_config(self):
from pyspark.sql.functions import pandas_udf, col
+ import pandas as pd
orig_value = self.spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch", None)
self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 3)
try:
@@ -3798,11 +4032,11 @@ def test_vectorized_udf_check_config(self):
@pandas_udf(returnType=LongType())
def check_records_per_batch(x):
- self.assertTrue(x.size <= 3)
- return x
+ return pd.Series(x.size).repeat(x.size)
- result = df.select(check_records_per_batch(col("id")))
- self.assertEqual(df.collect(), result.collect())
+ result = df.select(check_records_per_batch(col("id"))).collect()
+ for (r,) in result:
+ self.assertTrue(r <= 3)
finally:
if orig_value is None:
self.spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch")
@@ -3851,6 +4085,33 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self):
finally:
self.spark.conf.set("spark.sql.session.timeZone", orig_tz)
+ def test_nondeterministic_udf(self):
+ # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
+ from pyspark.sql.functions import udf, pandas_udf, col
+
+ @pandas_udf('double')
+ def plus_ten(v):
+ return v + 10
+ random_udf = self.random_udf
+
+ df = self.spark.range(10).withColumn('rand', random_udf(col('id')))
+ result1 = df.withColumn('plus_ten(rand)', plus_ten(df['rand'])).toPandas()
+
+ self.assertEqual(random_udf.deterministic, False)
+ self.assertTrue(result1['plus_ten(rand)'].equals(result1['rand'] + 10))
+
+ def test_nondeterministic_udf_in_aggregate(self):
+ from pyspark.sql.functions import pandas_udf, sum
+
+ df = self.spark.range(10)
+ random_udf = self.random_udf
+
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'):
+ df.groupby(df.id).agg(sum(random_udf(df.id))).collect()
+ with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'):
+ df.agg(sum(random_udf(df.id))).collect()
+
@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
class GroupbyApplyTests(ReusedSQLTestCase):
@@ -3974,12 +4235,12 @@ def test_wrong_return_type(self):
foo = pandas_udf(
lambda pdf: pdf,
- 'id long, v string',
+ 'id long, v map',
PandasUDFType.GROUP_MAP
)
with QuietTest(self.sc):
- with self.assertRaisesRegexp(Exception, 'Invalid.*type'):
+ with self.assertRaisesRegexp(Exception, 'Unsupported.*type.*conversion'):
df.groupby('id').apply(foo).sort('id').toPandas()
def test_wrong_args(self):
@@ -4009,7 +4270,8 @@ def test_wrong_args(self):
def test_unsupported_types(self):
from pyspark.sql.functions import pandas_udf, col, PandasUDFType
schema = StructType(
- [StructField("id", LongType(), True), StructField("dt", DecimalType(), True)])
+ [StructField("id", LongType(), True),
+ StructField("map", MapType(StringType(), IntegerType()), True)])
df = self.spark.createDataFrame([(1, None,)], schema=schema)
f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUP_MAP)
with QuietTest(self.sc):
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 78abc32a35a1c..0dc5823f72a3c 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1073,7 +1073,7 @@ def _infer_type(obj):
raise TypeError("not supported type: %s" % type(obj))
-def _infer_schema(row):
+def _infer_schema(row, names=None):
"""Infer the schema from dict/namedtuple/object"""
if isinstance(row, dict):
items = sorted(row.items())
@@ -1084,7 +1084,10 @@ def _infer_schema(row):
elif hasattr(row, "_fields"): # namedtuple
items = zip(row._fields, tuple(row))
else:
- names = ['_%d' % i for i in range(1, len(row) + 1)]
+ if names is None:
+ names = ['_%d' % i for i in range(1, len(row) + 1)]
+ elif len(names) < len(row):
+ names.extend('_%d' % i for i in range(len(names) + 1, len(row) + 1))
items = zip(names, row)
elif hasattr(row, "__dict__"): # object
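A hedged sketch of what the new names parameter does for plain tuples; _infer_schema is a private helper, so this is illustration only and assumes a pyspark build with this patch applied:

    from pyspark.sql.types import _infer_schema

    # Explicit names are applied positionally to a plain tuple.
    print(_infer_schema((1, "a"), names=["id", "name"]).fieldNames())   # ['id', 'name']

    # With fewer names than columns, the remainder is padded as _2, _3, ...
    print(_infer_schema((1, "a", 2.0), names=["id"]).fieldNames())      # ['id', '_2', '_3']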
@@ -1109,19 +1112,27 @@ def _has_nulltype(dt):
return isinstance(dt, NullType)
-def _merge_type(a, b):
+def _merge_type(a, b, name=None):
+ if name is None:
+ new_msg = lambda msg: msg
+ new_name = lambda n: "field %s" % n
+ else:
+ new_msg = lambda msg: "%s: %s" % (name, msg)
+ new_name = lambda n: "field %s in %s" % (n, name)
+
if isinstance(a, NullType):
return b
elif isinstance(b, NullType):
return a
elif type(a) is not type(b):
# TODO: type cast (such as int -> long)
- raise TypeError("Can not merge type %s and %s" % (type(a), type(b)))
+ raise TypeError(new_msg("Can not merge type %s and %s" % (type(a), type(b))))
# same type
if isinstance(a, StructType):
nfs = dict((f.name, f.dataType) for f in b.fields)
- fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType())))
+ fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()),
+ name=new_name(f.name)))
for f in a.fields]
names = set([f.name for f in fields])
for n in nfs:
@@ -1130,11 +1141,12 @@ def _merge_type(a, b):
return StructType(fields)
elif isinstance(a, ArrayType):
- return ArrayType(_merge_type(a.elementType, b.elementType), True)
+ return ArrayType(_merge_type(a.elementType, b.elementType,
+ name='element in array %s' % name), True)
elif isinstance(a, MapType):
- return MapType(_merge_type(a.keyType, b.keyType),
- _merge_type(a.valueType, b.valueType),
+ return MapType(_merge_type(a.keyType, b.keyType, name='key of map %s' % name),
+ _merge_type(a.valueType, b.valueType, name='value of map %s' % name),
True)
else:
return a
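With the name threading above, merge failures now report the path to the offending field instead of a bare type mismatch. A minimal sketch, again using the private _merge_type helper for illustration only:

    from pyspark.sql.types import (ArrayType, DoubleType, LongType,
                                   StructField, StructType, _merge_type)

    try:
        _merge_type(
            StructType([StructField("f1", ArrayType(LongType()))]),
            StructType([StructField("f1", ArrayType(DoubleType()))]))
    except TypeError as e:
        # The message is expected to mention "element in array field f1".
        print(e)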
@@ -1617,7 +1629,7 @@ def to_arrow_type(dt):
elif type(dt) == DoubleType:
arrow_type = pa.float64()
elif type(dt) == DecimalType:
- arrow_type = pa.decimal(dt.precision, dt.scale)
+ arrow_type = pa.decimal128(dt.precision, dt.scale)
elif type(dt) == StringType:
arrow_type = pa.string()
elif type(dt) == DateType:
@@ -1625,6 +1637,8 @@ def to_arrow_type(dt):
elif type(dt) == TimestampType:
# Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
arrow_type = pa.timestamp('us', tz='UTC')
+ elif type(dt) == ArrayType:
+ arrow_type = pa.list_(to_arrow_type(dt.elementType))
else:
raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
return arrow_type
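A short sketch of the updated forward mapping, assuming pyarrow >= 0.8.0 is installed: decimals now map to the 128-bit Arrow decimal and ArrayType maps to an Arrow list type.

    import pyarrow as pa
    from pyspark.sql.types import ArrayType, DecimalType, IntegerType, to_arrow_type

    assert to_arrow_type(DecimalType(38, 18)) == pa.decimal128(38, 18)
    assert to_arrow_type(ArrayType(IntegerType())) == pa.list_(pa.int32())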
@@ -1642,30 +1656,31 @@ def to_arrow_schema(schema):
def from_arrow_type(at):
""" Convert pyarrow type to Spark data type.
"""
- # TODO: newer pyarrow has is_boolean(at) functions that would be better to check type
- import pyarrow as pa
- if at == pa.bool_():
+ import pyarrow.types as types
+ if types.is_boolean(at):
spark_type = BooleanType()
- elif at == pa.int8():
+ elif types.is_int8(at):
spark_type = ByteType()
- elif at == pa.int16():
+ elif types.is_int16(at):
spark_type = ShortType()
- elif at == pa.int32():
+ elif types.is_int32(at):
spark_type = IntegerType()
- elif at == pa.int64():
+ elif types.is_int64(at):
spark_type = LongType()
- elif at == pa.float32():
+ elif types.is_float32(at):
spark_type = FloatType()
- elif at == pa.float64():
+ elif types.is_float64(at):
spark_type = DoubleType()
- elif type(at) == pa.DecimalType:
+ elif types.is_decimal(at):
spark_type = DecimalType(precision=at.precision, scale=at.scale)
- elif at == pa.string():
+ elif types.is_string(at):
spark_type = StringType()
- elif at == pa.date32():
+ elif types.is_date32(at):
spark_type = DateType()
- elif type(at) == pa.TimestampType:
+ elif types.is_timestamp(at):
spark_type = TimestampType()
+ elif types.is_list(at):
+ spark_type = ArrayType(from_arrow_type(at.value_type))
else:
raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
return spark_type
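The reverse mapping is now driven by the pyarrow.types predicates, so parametrised types (decimals, timestamps, lists) are matched by kind rather than by exact instance. A rough round-trip sketch under the same pyarrow assumption:

    import pyarrow as pa
    from pyspark.sql.types import ArrayType, StringType, TimestampType, from_arrow_type

    assert from_arrow_type(pa.list_(pa.string())) == ArrayType(StringType())
    # Any timestamp unit/timezone maps back to Spark's TimestampType.
    assert from_arrow_type(pa.timestamp('us', tz='UTC')) == TimestampType()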
@@ -1679,13 +1694,6 @@ def from_arrow_schema(arrow_schema):
for field in arrow_schema])
-def _old_pandas_exception_message(e):
- """ Create an error message for importing old Pandas.
- """
- msg = "note: Pandas (>=0.19.2) must be installed and available on calling Python process"
- return "%s\n%s" % (_exception_message(e), msg)
-
-
def _check_dataframe_localize_timestamps(pdf, timezone):
"""
Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone
@@ -1694,10 +1702,10 @@ def _check_dataframe_localize_timestamps(pdf, timezone):
:param timezone: the timezone to convert. if None then use local timezone
:return pandas.DataFrame where any timezone aware columns have been converted to tz-naive
"""
- try:
- from pandas.api.types import is_datetime64tz_dtype
- except ImportError as e:
- raise ImportError(_old_pandas_exception_message(e))
+ from pyspark.sql.utils import require_minimum_pandas_version
+ require_minimum_pandas_version()
+
+ from pandas.api.types import is_datetime64tz_dtype
tz = timezone or 'tzlocal()'
for column, series in pdf.iteritems():
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
@@ -1715,10 +1723,10 @@ def _check_series_convert_timestamps_internal(s, timezone):
:param timezone: the timezone to convert. if None then use local timezone
:return pandas.Series where if it is a timestamp, has been UTC normalized without a time zone
"""
- try:
- from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
- except ImportError as e:
- raise ImportError(_old_pandas_exception_message(e))
+ from pyspark.sql.utils import require_minimum_pandas_version
+ require_minimum_pandas_version()
+
+ from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if is_datetime64_dtype(s.dtype):
tz = timezone or 'tzlocal()'
@@ -1738,11 +1746,11 @@ def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone):
:param to_timezone: the timezone to convert to. if None then use local timezone
:return pandas.Series where if it is a timestamp, has been converted to tz-naive
"""
- try:
- import pandas as pd
- from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype
- except ImportError as e:
- raise ImportError(_old_pandas_exception_message(e))
+ from pyspark.sql.utils import require_minimum_pandas_version
+ require_minimum_pandas_version()
+
+ import pandas as pd
+ from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype
from_tz = from_timezone or 'tzlocal()'
to_tz = to_timezone or 'tzlocal()'
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index d98ecc3e86a2a..7f7be9d339b9f 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -34,26 +34,31 @@ def _wrap_function(sc, func, returnType):
def _create_udf(f, returnType, evalType):
- if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF:
+
+ if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF or \
+ evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF:
import inspect
+ from pyspark.sql.utils import require_minimum_pyarrow_version
+
+ require_minimum_pyarrow_version()
argspec = inspect.getargspec(f)
- if len(argspec.args) == 0 and argspec.varargs is None:
+
+ if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF and len(argspec.args) == 0 and \
+ argspec.varargs is None:
raise ValueError(
"Invalid function: 0-arg pandas_udfs are not supported. "
"Instead, create a 1-arg pandas_udf and ignore the arg in your function."
)
- elif evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF:
- import inspect
- argspec = inspect.getargspec(f)
- if len(argspec.args) != 1:
+ if evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF and len(argspec.args) != 1:
raise ValueError(
"Invalid function: pandas_udfs with function type GROUP_MAP "
"must take a single arg that is a pandas DataFrame."
)
# Set the name of the UserDefinedFunction object to be the name of function f
- udf_obj = UserDefinedFunction(f, returnType=returnType, name=None, evalType=evalType)
+ udf_obj = UserDefinedFunction(
+ f, returnType=returnType, name=None, evalType=evalType, deterministic=True)
return udf_obj._wrapped()
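SCALAR and GROUP_MAP pandas UDFs now share the definition-time checks (a pyarrow version check plus the arity validation), so misuse fails before any job runs. A hedged sketch, assuming pyspark with this patch and pyarrow >= 0.8.0 on the calling process:

    from pyspark.sql.functions import pandas_udf, PandasUDFType

    # A GROUP_MAP pandas UDF must take exactly one pandas.DataFrame argument,
    # so this raises ValueError as soon as the UDF is defined.
    try:
        pandas_udf(lambda pdf, extra: pdf, 'id long', PandasUDFType.GROUP_MAP)
    except ValueError as e:
        print(e)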
@@ -64,8 +69,10 @@ class UserDefinedFunction(object):
.. versionadded:: 1.3
"""
def __init__(self, func,
- returnType=StringType(), name=None,
- evalType=PythonEvalType.SQL_BATCHED_UDF):
+ returnType=StringType(),
+ name=None,
+ evalType=PythonEvalType.SQL_BATCHED_UDF,
+ deterministic=True):
if not callable(func):
raise TypeError(
"Invalid function: not a function or callable (__call__ is not defined): "
@@ -89,6 +96,7 @@ def __init__(self, func,
func.__name__ if hasattr(func, '__name__')
else func.__class__.__name__)
self.evalType = evalType
+ self.deterministic = deterministic
@property
def returnType(self):
@@ -126,7 +134,7 @@ def _create_judf(self):
wrapped_func = _wrap_function(sc, self.func, self.returnType)
jdt = spark._jsparkSession.parseDataType(self.returnType.json())
judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
- self._name, wrapped_func, jdt, self.evalType)
+ self._name, wrapped_func, jdt, self.evalType, self.deterministic)
return judf
def __call__(self, *cols):
@@ -134,6 +142,9 @@ def __call__(self, *cols):
sc = SparkContext._active_spark_context
return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))
+ # This function improves the online help system in the interactive interpreter (for
+ # example, the built-in help / pydoc.help) by wrapping the UDF with the docstring and
+ # argument annotations of the original function. (See: SPARK-19161)
def _wrapped(self):
"""
Wrap this udf with a function and attach docstring from func
@@ -158,5 +169,16 @@ def wrapper(*args):
wrapper.func = self.func
wrapper.returnType = self.returnType
wrapper.evalType = self.evalType
-
+ wrapper.deterministic = self.deterministic
+ wrapper.asNondeterministic = functools.wraps(
+ self.asNondeterministic)(lambda: self.asNondeterministic()._wrapped())
return wrapper
+
+ def asNondeterministic(self):
+ """
+ Updates this UserDefinedFunction to be nondeterministic.
+
+ .. versionadded:: 2.3
+ """
+ self.deterministic = False
+ return self
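The wrapper also re-exposes asNondeterministic, so the flag survives the _wrapped() decoration. A short usage sketch, assuming an active SparkSession bound to the name spark:

    import random
    from pyspark.sql.functions import udf
    from pyspark.sql.types import DoubleType

    # Marking the UDF nondeterministic tells the optimizer not to duplicate or
    # collapse its evaluation; the tests above also show it is rejected inside aggregates.
    rand_udf = udf(lambda: random.random(), DoubleType()).asNondeterministic()
    spark.range(3).select(rand_udf().alias('r')).show()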
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index 7bc6a59ad3b26..08c34c6dccc5e 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -110,3 +110,23 @@ def toJArray(gateway, jtype, arr):
for i in range(0, len(arr)):
jarr[i] = arr[i]
return jarr
+
+
+def require_minimum_pandas_version():
+ """ Raise ImportError if minimum version of Pandas is not installed
+ """
+ from distutils.version import LooseVersion
+ import pandas
+ if LooseVersion(pandas.__version__) < LooseVersion('0.19.2'):
+ raise ImportError("Pandas >= 0.19.2 must be installed on calling Python process; "
+ "however, your version was %s." % pandas.__version__)
+
+
+def require_minimum_pyarrow_version():
+ """ Raise ImportError if minimum version of pyarrow is not installed
+ """
+ from distutils.version import LooseVersion
+ import pyarrow
+ if LooseVersion(pyarrow.__version__) < LooseVersion('0.8.0'):
+ raise ImportError("pyarrow >= 0.8.0 must be installed on calling Python process; "
+ "however, your version was %s." % pyarrow.__version__)
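Both helpers raise a uniform ImportError that embeds the detected version, which is what the updated test regex 'Pandas >= .* must be installed' matches. Calling them is a no-op when the requirement is met:

    from pyspark.sql.utils import (require_minimum_pandas_version,
                                   require_minimum_pyarrow_version)

    # Each call either returns silently or raises
    # ImportError("... must be installed ...; however, your version was X.Y.Z").
    require_minimum_pandas_version()
    require_minimum_pyarrow_version()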
diff --git a/python/pyspark/streaming/flume.py b/python/pyspark/streaming/flume.py
index 5a975d050b0d8..5de448114ece8 100644
--- a/python/pyspark/streaming/flume.py
+++ b/python/pyspark/streaming/flume.py
@@ -20,6 +20,8 @@
from io import BytesIO
else:
from StringIO import StringIO
+import warnings
+
from py4j.protocol import Py4JJavaError
from pyspark.storagelevel import StorageLevel
diff --git a/python/pyspark/version.py b/python/pyspark/version.py
index 12dd53b9d2902..b9c2c4ced71d5 100644
--- a/python/pyspark/version.py
+++ b/python/pyspark/version.py
@@ -16,4 +16,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = "2.3.0.dev0"
+__version__ = "2.4.0.dev0"
diff --git a/python/setup.py b/python/setup.py
index 310670e697a83..251d4526d4dd0 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -201,7 +201,7 @@ def _supports_symlinks():
extras_require={
'ml': ['numpy>=1.7'],
'mllib': ['numpy>=1.7'],
- 'sql': ['pandas>=0.19.2']
+ 'sql': ['pandas>=0.19.2', 'pyarrow>=0.8.0']
},
classifiers=[
'Development Status :: 5 - Production/Stable',
@@ -210,6 +210,7 @@ def _supports_symlinks():
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.4',
'Programming Language :: Python :: 3.5',
+ 'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: Implementation :: CPython',
'Programming Language :: Python :: Implementation :: PyPy']
)
diff --git a/repl/pom.xml b/repl/pom.xml
index 1cb0098d0eca3..6f4a863c48bc7 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.3.0-SNAPSHOT</version>
+ <version>2.4.0-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala
new file mode 100644
index 0000000000000..724ce9af49f77
--- /dev/null
+++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.repl
+
+import scala.tools.nsc.interpreter.{ExprTyper, IR}
+
+trait SparkExprTyper extends ExprTyper {
+
+ import repl._
+ import global.{reporter => _, Import => _, _}
+ import naming.freshInternalVarName
+
+ def doInterpret(code: String): IR.Result = {
+ // interpret/interpretSynthetic may change the phase,
+ // which would have unintended effects on types.
+ val savedPhase = phase
+ try interpretSynthetic(code) finally phase = savedPhase
+ }
+
+ override def symbolOfLine(code: String): Symbol = {
+ def asExpr(): Symbol = {
+ val name = freshInternalVarName()
+ // Typing it with a lazy val would give us the right type, but runs
+ // into compiler bugs with things like existentials, so we compile it
+ // behind a def and strip the NullaryMethodType which wraps the expr.
+ val line = "def " + name + " = " + code
+
+ doInterpret(line) match {
+ case IR.Success =>
+ val sym0 = symbolOfTerm(name)
+ // drop NullaryMethodType
+ sym0.cloneSymbol setInfo exitingTyper(sym0.tpe_*.finalResultType)
+ case _ => NoSymbol
+ }
+ }
+
+ def asDefn(): Symbol = {
+ val old = repl.definedSymbolList.toSet
+
+ doInterpret(code) match {
+ case IR.Success =>
+ repl.definedSymbolList filterNot old match {
+ case Nil => NoSymbol
+ case sym :: Nil => sym
+ case syms => NoSymbol.newOverloaded(NoPrefix, syms)
+ }
+ case _ => NoSymbol
+ }
+ }
+
+ def asError(): Symbol = {
+ doInterpret(code)
+ NoSymbol
+ }
+
+ beSilentDuring(asExpr()) orElse beSilentDuring(asDefn()) orElse asError()
+ }
+
+}
diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala
index 3ce7cc7c85f74..e69441a475e9a 100644
--- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala
+++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala
@@ -35,6 +35,10 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter)
def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out)
def this() = this(None, new JPrintWriter(Console.out, true))
+ override def createInterpreter(): Unit = {
+ intp = new SparkILoopInterpreter(settings, out)
+ }
+
val initializationCommands: Seq[String] = Seq(
"""
@transient val spark = if (org.apache.spark.repl.Main.sparkSession != null) {
diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala
new file mode 100644
index 0000000000000..e736607a9a6b9
--- /dev/null
+++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala
@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.repl
+
+import scala.collection.mutable
+import scala.tools.nsc.Settings
+import scala.tools.nsc.interpreter._
+
+class SparkILoopInterpreter(settings: Settings, out: JPrintWriter) extends IMain(settings, out) {
+ self =>
+
+ override lazy val memberHandlers = new {
+ val intp: self.type = self
+ } with MemberHandlers {
+ import intp.global._
+
+ override def chooseHandler(member: intp.global.Tree): MemberHandler = member match {
+ case member: Import => new SparkImportHandler(member)
+ case _ => super.chooseHandler(member)
+ }
+
+ class SparkImportHandler(imp: Import) extends ImportHandler(imp: Import) {
+
+ override def targetType: Type = intp.global.rootMirror.getModuleIfDefined("" + expr) match {
+ case NoSymbol => intp.typeOfExpression("" + expr)
+ case sym => sym.tpe
+ }
+
+ private def safeIndexOf(name: Name, s: String): Int = fixIndexOf(name, pos(name, s))
+ private def fixIndexOf(name: Name, idx: Int): Int = if (idx == name.length) -1 else idx
+ private def pos(name: Name, s: String): Int = {
+ var i = name.pos(s.charAt(0), 0)
+ val sLen = s.length()
+ if (sLen == 1) return i
+ while (i + sLen <= name.length) {
+ var j = 1
+ while (s.charAt(j) == name.charAt(i + j)) {
+ j += 1
+ if (j == sLen) return i
+ }
+ i = name.pos(s.charAt(0), i + 1)
+ }
+ name.length
+ }
+
+ private def isFlattenedSymbol(sym: Symbol): Boolean =
+ sym.owner.isPackageClass &&
+ sym.name.containsName(nme.NAME_JOIN_STRING) &&
+ sym.owner.info.member(sym.name.take(
+ safeIndexOf(sym.name, nme.NAME_JOIN_STRING))) != NoSymbol
+
+ private def importableTargetMembers =
+ importableMembers(exitingTyper(targetType)).filterNot(isFlattenedSymbol).toList
+
+ def isIndividualImport(s: ImportSelector): Boolean =
+ s.name != nme.WILDCARD && s.rename != nme.WILDCARD
+ def isWildcardImport(s: ImportSelector): Boolean =
+ s.name == nme.WILDCARD
+
+ // non-wildcard imports
+ private def individualSelectors = selectors filter isIndividualImport
+
+ override val importsWildcard: Boolean = selectors exists isWildcardImport
+
+ lazy val importableSymbolsWithRenames: List[(Symbol, Name)] = {
+ val selectorRenameMap =
+ individualSelectors.flatMap(x => x.name.bothNames zip x.rename.bothNames).toMap
+ importableTargetMembers flatMap (m => selectorRenameMap.get(m.name) map (m -> _))
+ }
+
+ override lazy val individualSymbols: List[Symbol] = importableSymbolsWithRenames map (_._1)
+ override lazy val wildcardSymbols: List[Symbol] =
+ if (importsWildcard) importableTargetMembers else Nil
+
+ }
+
+ }
+
+ object expressionTyper extends {
+ val repl: SparkILoopInterpreter.this.type = self
+ } with SparkExprTyper { }
+
+ override def symbolOfLine(code: String): global.Symbol =
+ expressionTyper.symbolOfLine(code)
+
+ override def typeOfExpression(expr: String, silent: Boolean): global.Type =
+ expressionTyper.typeOfExpression(expr, silent)
+
+
+ import global.Name
+ override def importsCode(wanted: Set[Name], wrapper: Request#Wrapper,
+ definesClass: Boolean, generousImports: Boolean): ComputedImports = {
+
+ import global._
+ import definitions.{ ObjectClass, ScalaPackage, JavaLangPackage, PredefModule }
+ import memberHandlers._
+
+ val header, code, trailingBraces, accessPath = new StringBuilder
+ val currentImps = mutable.HashSet[Name]()
+ // only emit predef import header if name not resolved in history, loosely
+ var predefEscapes = false
+
+ /**
+ * Narrow down the list of requests from which imports
+ * should be taken. Removes requests which cannot contribute
+ * useful imports for the specified set of wanted names.
+ */
+ case class ReqAndHandler(req: Request, handler: MemberHandler)
+
+ def reqsToUse: List[ReqAndHandler] = {
+ /**
+ * Loop through a list of MemberHandlers and select which ones to keep.
+ * 'wanted' is the set of names that need to be imported.
+ */
+ def select(reqs: List[ReqAndHandler], wanted: Set[Name]): List[ReqAndHandler] = {
+ // Single symbol imports might be implicits! See bug #1752. Rather than
+ // try to finesse this, we will mimic all imports for now.
+ def keepHandler(handler: MemberHandler) = handler match {
+ // While defining classes in class based mode - implicits are not needed.
+ case h: ImportHandler if isClassBased && definesClass =>
+ h.importedNames.exists(x => wanted.contains(x))
+ case _: ImportHandler => true
+ case x if generousImports => x.definesImplicit ||
+ (x.definedNames exists (d => wanted.exists(w => d.startsWith(w))))
+ case x => x.definesImplicit ||
+ (x.definedNames exists wanted)
+ }
+
+ reqs match {
+ case Nil =>
+ predefEscapes = wanted contains PredefModule.name ; Nil
+ case rh :: rest if !keepHandler(rh.handler) => select(rest, wanted)
+ case rh :: rest =>
+ import rh.handler._
+ val augment = rh match {
+ case ReqAndHandler(_, _: ImportHandler) => referencedNames
+ case _ => Nil
+ }
+ val newWanted = wanted ++ augment -- definedNames -- importedNames
+ rh :: select(rest, newWanted)
+ }
+ }
+
+ /** Flatten the handlers out and pair each with the original request */
+ select(allReqAndHandlers reverseMap { case (r, h) => ReqAndHandler(r, h) }, wanted).reverse
+ }
+
+ // add code for a new object to hold some imports
+ def addWrapper() {
+ import nme.{ INTERPRETER_IMPORT_WRAPPER => iw }
+ code append (wrapper.prewrap format iw)
+ trailingBraces append wrapper.postwrap
+ accessPath append s".$iw"
+ currentImps.clear()
+ }
+
+ def maybeWrap(names: Name*) = if (names exists currentImps) addWrapper()
+
+ def wrapBeforeAndAfter[T](op: => T): T = {
+ addWrapper()
+ try op finally addWrapper()
+ }
+
+ // imports from Predef are relocated to the template header to allow hiding.
+ def checkHeader(h: ImportHandler) = h.referencedNames contains PredefModule.name
+
+ // loop through previous requests, adding imports for each one
+ wrapBeforeAndAfter {
+ // Reuse a single temporary value when importing from a line with multiple definitions.
+ val tempValLines = mutable.Set[Int]()
+ for (ReqAndHandler(req, handler) <- reqsToUse) {
+ val objName = req.lineRep.readPathInstance
+ handler match {
+ case h: ImportHandler if checkHeader(h) =>
+ header.clear()
+ header append f"${h.member}%n"
+ // If the user entered an import, then just use it; add an import wrapping
+ // level if the import might conflict with some other import
+ case x: ImportHandler if x.importsWildcard =>
+ wrapBeforeAndAfter(code append (x.member + "\n"))
+ case x: ImportHandler =>
+ maybeWrap(x.importedNames: _*)
+ code append (x.member + "\n")
+ currentImps ++= x.importedNames
+
+ case x if isClassBased =>
+ for (sym <- x.definedSymbols) {
+ maybeWrap(sym.name)
+ x match {
+ case _: ClassHandler =>
+ code.append(s"import ${objName}${req.accessPath}.`${sym.name}`\n")
+ case _ =>
+ val valName = s"${req.lineRep.packageName}${req.lineRep.readName}"
+ if (!tempValLines.contains(req.lineRep.lineId)) {
+ code.append(s"val $valName: ${objName}.type = $objName\n")
+ tempValLines += req.lineRep.lineId
+ }
+ code.append(s"import ${valName}${req.accessPath}.`${sym.name}`\n")
+ }
+ currentImps += sym.name
+ }
+ // For other requests, import each defined name.
+ // import them explicitly instead of with _, so that
+ // ambiguity errors will not be generated. Also, quote
+ // the name of the variable, so that we don't need to
+ // handle quoting keywords separately.
+ case x =>
+ for (sym <- x.definedSymbols) {
+ maybeWrap(sym.name)
+ code append s"import ${x.path}\n"
+ currentImps += sym.name
+ }
+ }
+ }
+ }
+
+ val computedHeader = if (predefEscapes) header.toString else ""
+ ComputedImports(computedHeader, code.toString, trailingBraces.toString, accessPath.toString)
+ }
+
+ private def allReqAndHandlers =
+ prevRequestList flatMap (req => req.handlers map (req -> _))
+
+}
diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 905b41cdc1594..cdd5cdd841740 100644
--- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -227,4 +227,35 @@ class ReplSuite extends SparkFunSuite {
assertDoesNotContain("error: not found: value sc", output)
}
+ test("spark-shell should find imported types in class constructors and extends clause") {
+ val output = runInterpreter("local",
+ """
+ |import org.apache.spark.Partition
+ |class P(p: Partition)
+ |class P(val index: Int) extends Partition
+ """.stripMargin)
+ assertDoesNotContain("error: not found: type Partition", output)
+ }
+
+ test("spark-shell should shadow val/def definitions correctly") {
+ val output1 = runInterpreter("local",
+ """
+ |def myMethod() = "first definition"
+ |val tmp = myMethod(); val out = tmp
+ |def myMethod() = "second definition"
+ |val tmp = myMethod(); val out = s"$tmp aabbcc"
+ """.stripMargin)
+ assertContains("second definition aabbcc", output1)
+
+ val output2 = runInterpreter("local",
+ """
+ |val a = 1
+ |val b = a; val c = b;
+ |val a = 2
+ |val b = a; val c = b;
+ |s"!!$b!!"
+ """.stripMargin)
+ assertContains("!!2!!", output2)
+ }
+
}
diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml
index a4b18c527c969..2240d0e84eb5c 100644
--- a/resource-managers/kubernetes/core/pom.xml
+++ b/resource-managers/kubernetes/core/pom.xml
@@ -20,7 +20,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.3.0-SNAPSHOT</version>
+ <version>2.4.0-SNAPSHOT</version>
<relativePath>../../../pom.xml</relativePath>
@@ -29,7 +29,7 @@
<name>Spark Project Kubernetes</name>
<sbt.project.name>kubernetes</sbt.project.name>
- <kubernetes.client.version>2.2.13</kubernetes.client.version>
+ <kubernetes.client.version>3.0.0</kubernetes.client.version>
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/LoggingPodStatusWatcher.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/LoggingPodStatusWatcher.scala
index c645b008d736d..fdc6e0a61f73a 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/LoggingPodStatusWatcher.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/LoggingPodStatusWatcher.scala
@@ -109,7 +109,7 @@ private[k8s] class LoggingPodStatusWatcherImpl(
("namespace", pod.getMetadata.getNamespace()),
("labels", pod.getMetadata.getLabels().asScala.mkString(", ")),
("pod uid", pod.getMetadata.getUid),
- ("creation time", pod.getMetadata.getCreationTimestamp()),
+ ("creation time", pod.getMetadata.getCreationTimestamp.getTime),
// spec details
("service account name", pod.getSpec.getServiceAccountName()),
@@ -117,7 +117,7 @@ private[k8s] class LoggingPodStatusWatcherImpl(
("node name", pod.getSpec.getNodeName()),
// status
- ("start time", pod.getStatus.getStartTime),
+ ("start time", pod.getStatus.getStartTime.getTime),
("container images",
pod.getStatus.getContainerStatuses()
.asScala
@@ -162,7 +162,7 @@ private[k8s] class LoggingPodStatusWatcherImpl(
case running: ContainerStateRunning =>
Seq(
("Container state", "Running"),
- ("Container started at", running.getStartedAt))
+ ("Container started at", running.getStartedAt.getTime))
case waiting: ContainerStateWaiting =>
Seq(
("Container state", "Waiting"),
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala
index 1e97ece88de1e..6c70c6509c208 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala
@@ -487,4 +487,3 @@ private object KubernetesClusterSchedulerBackend {
" Consider boosting spark executor memory overhead."
}
}
-
diff --git a/resource-managers/kubernetes/docker-minimal-bundle/pom.xml b/resource-managers/kubernetes/docker-minimal-bundle/pom.xml
index 85827a17a963f..80951d691c644 100644
--- a/resource-managers/kubernetes/docker-minimal-bundle/pom.xml
+++ b/resource-managers/kubernetes/docker-minimal-bundle/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.3.0-SNAPSHOT</version>
+ <version>2.4.0-SNAPSHOT</version>
<relativePath>../../../pom.xml</relativePath>
diff --git a/resource-managers/kubernetes/integration-tests-spark-jobs-helpers/pom.xml b/resource-managers/kubernetes/integration-tests-spark-jobs-helpers/pom.xml
index c4fe21604fb44..91a7482546419 100644
--- a/resource-managers/kubernetes/integration-tests-spark-jobs-helpers/pom.xml
+++ b/resource-managers/kubernetes/integration-tests-spark-jobs-helpers/pom.xml
@@ -20,7 +20,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.3.0-SNAPSHOT</version>
+ <version>2.4.0-SNAPSHOT</version>
<relativePath>../../../pom.xml</relativePath>
diff --git a/resource-managers/kubernetes/integration-tests-spark-jobs/pom.xml b/resource-managers/kubernetes/integration-tests-spark-jobs/pom.xml
index db5ec86e0dfd8..eab95c5806eb8 100644
--- a/resource-managers/kubernetes/integration-tests-spark-jobs/pom.xml
+++ b/resource-managers/kubernetes/integration-tests-spark-jobs/pom.xml
@@ -20,7 +20,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.3.0-SNAPSHOT</version>
+ <version>2.4.0-SNAPSHOT</version>
<relativePath>../../../pom.xml</relativePath>
diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml
index 2f1587e0d6bb4..58b9dd00e5294 100644
--- a/resource-managers/kubernetes/integration-tests/pom.xml
+++ b/resource-managers/kubernetes/integration-tests/pom.xml
@@ -20,7 +20,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.3.0-SNAPSHOT</version>
+ <version>2.4.0-SNAPSHOT</version>
<relativePath>../../../pom.xml</relativePath>
diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml
index de8f1c913651d..3995d0afeb5f4 100644
--- a/resource-managers/mesos/pom.xml
+++ b/resource-managers/mesos/pom.xml
@@ -20,7 +20,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.3.0-SNAPSHOT</version>
+ <version>2.4.0-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
@@ -29,7 +29,7 @@
<name>Spark Project Mesos</name>
<sbt.project.name>mesos</sbt.project.name>
- <mesos.version>1.3.0</mesos.version>
+ <mesos.version>1.4.0</mesos.version>
<mesos.classifier>shaded-protobuf</mesos.classifier>
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala
index ff60b88c6d533..68f6921153d89 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala
@@ -77,10 +77,17 @@ private[mesos] class MesosSubmitRequestServlet(
private def buildDriverDescription(request: CreateSubmissionRequest): MesosDriverDescription = {
// Required fields, including the main class because python is not yet supported
val appResource = Option(request.appResource).getOrElse {
- throw new SubmitRestMissingFieldException("Application jar is missing.")
+ throw new SubmitRestMissingFieldException("Application jar 'appResource' is missing.")
}
val mainClass = Option(request.mainClass).getOrElse {
- throw new SubmitRestMissingFieldException("Main class is missing.")
+ throw new SubmitRestMissingFieldException("Main class 'mainClass' is missing.")
+ }
+ val appArgs = Option(request.appArgs).getOrElse {
+ throw new SubmitRestMissingFieldException("Application arguments 'appArgs' are missing.")
+ }
+ val environmentVariables = Option(request.environmentVariables).getOrElse {
+ throw new SubmitRestMissingFieldException("Environment variables 'environmentVariables' " +
+ "are missing.")
}
// Optional fields
@@ -91,8 +98,6 @@ private[mesos] class MesosSubmitRequestServlet(
val superviseDriver = sparkProperties.get("spark.driver.supervise")
val driverMemory = sparkProperties.get("spark.driver.memory")
val driverCores = sparkProperties.get("spark.driver.cores")
- val appArgs = request.appArgs
- val environmentVariables = request.environmentVariables
val name = request.sparkProperties.getOrElse("spark.app.name", mainClass)
// Construct driver description
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
index c41283e4a3e39..d224a7325820a 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
@@ -36,7 +36,6 @@ import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionRes
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.util.Utils
-
/**
* Tracks the current state of a Mesos Task that runs a Spark driver.
* @param driverDescription Submitted driver description from
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala
index c392061fdb358..53f5f61cca486 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala
@@ -92,6 +92,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend(
private[this] var stopCalled: Boolean = false
private val launcherBackend = new LauncherBackend() {
+ override protected def conf: SparkConf = sc.conf
+
override protected def onStopRequest(): Unit = {
stopSchedulerBackend()
setState(SparkAppHandle.State.KILLED)
@@ -400,13 +402,20 @@ private[spark] class MesosCoarseGrainedSchedulerBackend(
val offerMem = getResource(offer.getResourcesList, "mem")
val offerCpus = getResource(offer.getResourcesList, "cpus")
val offerPorts = getRangeResource(offer.getResourcesList, "ports")
+ val offerReservationInfo = offer
+ .getResourcesList
+ .asScala
+ .find { r => r.getReservation != null }
val id = offer.getId.getValue
if (tasks.contains(offer.getId)) { // accept
val offerTasks = tasks(offer.getId)
logDebug(s"Accepting offer: $id with attributes: $offerAttributes " +
- s"mem: $offerMem cpu: $offerCpus ports: $offerPorts." +
+ offerReservationInfo.map(resInfo =>
+ s"reservation info: ${resInfo.getReservation.toString}").getOrElse("") +
+ s"mem: $offerMem cpu: $offerCpus ports: $offerPorts " +
+ s"resources: ${offer.getResourcesList.asScala.mkString(",")}." +
s" Launching ${offerTasks.size} Mesos tasks.")
for (task <- offerTasks) {
@@ -416,7 +425,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend(
val ports = getRangeResource(task.getResourcesList, "ports").mkString(",")
logDebug(s"Launching Mesos task: ${taskId.getValue} with mem: $mem cpu: $cpus" +
- s" ports: $ports")
+ s" ports: $ports" + s" on slave with slave id: ${task.getSlaveId.getValue} ")
}
driver.launchTasks(
@@ -431,7 +440,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend(
} else {
declineOffer(
driver,
- offer)
+ offer,
+ Some("Offer was declined due to unmet task launch constraints."))
}
}
}
@@ -513,6 +523,9 @@ private[spark] class MesosCoarseGrainedSchedulerBackend(
totalGpusAcquired += taskGPUs
gpusByTaskId(taskId) = taskGPUs
}
+ } else {
+ logDebug(s"Cannot launch a task for offer with id: $offerId on slave " +
+ s"with id: $slaveId. Requirements were not met for this offer.")
}
}
}
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
index 6fcb30af8a733..e75450369ad85 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
@@ -28,7 +28,8 @@ import com.google.common.base.Splitter
import org.apache.mesos.{MesosSchedulerDriver, Protos, Scheduler, SchedulerDriver}
import org.apache.mesos.Protos.{TaskState => MesosTaskState, _}
import org.apache.mesos.Protos.FrameworkInfo.Capability
-import org.apache.mesos.protobuf.{ByteString, GeneratedMessage}
+import org.apache.mesos.Protos.Resource.ReservationInfo
+import org.apache.mesos.protobuf.{ByteString, GeneratedMessageV3}
import org.apache.spark.{SparkConf, SparkContext, SparkException}
import org.apache.spark.TaskState
@@ -36,8 +37,6 @@ import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.util.Utils
-
-
/**
* Shared trait for implementing a Mesos Scheduler. This holds common state and helper
* methods that the Mesos scheduler will use.
@@ -46,6 +45,8 @@ trait MesosSchedulerUtils extends Logging {
// Lock used to wait for scheduler to be registered
private final val registerLatch = new CountDownLatch(1)
+ private final val ANY_ROLE = "*"
+
/**
* Creates a new MesosSchedulerDriver that communicates to the Mesos master.
*
@@ -175,17 +176,36 @@ trait MesosSchedulerUtils extends Logging {
registerLatch.countDown()
}
- def createResource(name: String, amount: Double, role: Option[String] = None): Resource = {
+ private def setReservationInfo(
+ reservationInfo: Option[ReservationInfo],
+ role: Option[String],
+ builder: Resource.Builder): Unit = {
+ if (!role.contains(ANY_ROLE)) {
+ reservationInfo.foreach { res => builder.setReservation(res) }
+ }
+ }
+
+ def createResource(
+ name: String,
+ amount: Double,
+ role: Option[String] = None,
+ reservationInfo: Option[ReservationInfo] = None): Resource = {
val builder = Resource.newBuilder()
.setName(name)
.setType(Value.Type.SCALAR)
.setScalar(Value.Scalar.newBuilder().setValue(amount).build())
-
role.foreach { r => builder.setRole(r) }
-
+ setReservationInfo(reservationInfo, role, builder)
builder.build()
}
+ private def getReservation(resource: Resource): Option[ReservationInfo] = {
+ if (resource.hasReservation) {
+ Some(resource.getReservation)
+ } else {
+ None
+ }
+ }
/**
* Partition the existing set of resources into two groups, those remaining to be
* scheduled and those requested to be used for a new task.
@@ -203,14 +223,17 @@ trait MesosSchedulerUtils extends Logging {
var requestedResources = new ArrayBuffer[Resource]
val remainingResources = resources.asScala.map {
case r =>
+ val reservation = getReservation(r)
if (remain > 0 &&
r.getType == Value.Type.SCALAR &&
r.getScalar.getValue > 0.0 &&
r.getName == resourceName) {
val usage = Math.min(remain, r.getScalar.getValue)
- requestedResources += createResource(resourceName, usage, Some(r.getRole))
+ requestedResources += createResource(resourceName, usage,
+ Option(r.getRole), reservation)
remain -= usage
- createResource(resourceName, r.getScalar.getValue - usage, Some(r.getRole))
+ createResource(resourceName, r.getScalar.getValue - usage,
+ Option(r.getRole), reservation)
} else {
r
}
@@ -228,16 +251,6 @@ trait MesosSchedulerUtils extends Logging {
(attr.getName, attr.getText.getValue.split(',').toSet)
}
-
- /** Build a Mesos resource protobuf object */
- protected def createResource(resourceName: String, quantity: Double): Protos.Resource = {
- Resource.newBuilder()
- .setName(resourceName)
- .setType(Value.Type.SCALAR)
- .setScalar(Value.Scalar.newBuilder().setValue(quantity).build())
- .build()
- }
-
/**
* Converts the attributes from the resource offer into a Map of name to Attribute Value
* The attribute values are the mesos attribute types and they are
@@ -245,7 +258,8 @@ trait MesosSchedulerUtils extends Logging {
* @param offerAttributes the attributes offered
* @return
*/
- protected def toAttributeMap(offerAttributes: JList[Attribute]): Map[String, GeneratedMessage] = {
+ protected def toAttributeMap(offerAttributes: JList[Attribute])
+ : Map[String, GeneratedMessageV3] = {
offerAttributes.asScala.map { attr =>
val attrValue = attr.getType match {
case Value.Type.SCALAR => attr.getScalar
@@ -266,7 +280,7 @@ trait MesosSchedulerUtils extends Logging {
*/
def matchesAttributeRequirements(
slaveOfferConstraints: Map[String, Set[String]],
- offerAttributes: Map[String, GeneratedMessage]): Boolean = {
+ offerAttributes: Map[String, GeneratedMessageV3]): Boolean = {
slaveOfferConstraints.forall {
// offer has the required attribute and subsumes the required values for that attribute
case (name, requiredValues) =>
@@ -427,10 +441,10 @@ trait MesosSchedulerUtils extends Logging {
// partition port offers
val (resourcesWithoutPorts, portResources) = filterPortResources(offeredResources)
- val portsAndRoles = requestedPorts.
- map(x => (x, findPortAndGetAssignedRangeRole(x, portResources)))
+ val portsAndResourceInfo = requestedPorts.
+ map { x => (x, findPortAndGetAssignedResourceInfo(x, portResources)) }
- val assignedPortResources = createResourcesFromPorts(portsAndRoles)
+ val assignedPortResources = createResourcesFromPorts(portsAndResourceInfo)
// ignore non-assigned port resources, they will be declined implicitly by mesos
// no need for splitting port resources.
@@ -450,16 +464,25 @@ trait MesosSchedulerUtils extends Logging {
managedPortNames.map(conf.getLong(_, 0)).filter( _ != 0)
}
+ private case class RoleResourceInfo(
+ role: String,
+ resInfo: Option[ReservationInfo])
+
/** Creates a mesos resource for a specific port number. */
- private def createResourcesFromPorts(portsAndRoles: List[(Long, String)]) : List[Resource] = {
- portsAndRoles.flatMap{ case (port, role) =>
- createMesosPortResource(List((port, port)), Some(role))}
+ private def createResourcesFromPorts(
+ portsAndResourcesInfo: List[(Long, RoleResourceInfo)])
+ : List[Resource] = {
+ portsAndResourcesInfo.flatMap { case (port, rInfo) =>
+ createMesosPortResource(List((port, port)), Option(rInfo.role), rInfo.resInfo)}
}
/** Helper to create mesos resources for specific port ranges. */
private def createMesosPortResource(
ranges: List[(Long, Long)],
- role: Option[String] = None): List[Resource] = {
+ role: Option[String] = None,
+ reservationInfo: Option[ReservationInfo] = None): List[Resource] = {
+    // For the port ranges we are going to use (user-defined ports fall in them), create Mesos
+    // resources; each range has a role associated with it.
ranges.map { case (rangeStart, rangeEnd) =>
val rangeValue = Value.Range.newBuilder()
.setBegin(rangeStart)
@@ -468,7 +491,8 @@ trait MesosSchedulerUtils extends Logging {
.setName("ports")
.setType(Value.Type.RANGES)
.setRanges(Value.Ranges.newBuilder().addRange(rangeValue))
- role.foreach(r => builder.setRole(r))
+ role.foreach { r => builder.setRole(r) }
+ setReservationInfo(reservationInfo, role, builder)
builder.build()
}
}
@@ -477,19 +501,21 @@ trait MesosSchedulerUtils extends Logging {
* Helper to assign a port to an offered range and get the latter's role
* info to use it later on.
*/
- private def findPortAndGetAssignedRangeRole(port: Long, portResources: List[Resource])
- : String = {
+ private def findPortAndGetAssignedResourceInfo(port: Long, portResources: List[Resource])
+ : RoleResourceInfo = {
val ranges = portResources.
- map(resource =>
- (resource.getRole, resource.getRanges.getRangeList.asScala
- .map(r => (r.getBegin, r.getEnd)).toList))
+ map { resource =>
+ val reservation = getReservation(resource)
+ (RoleResourceInfo(resource.getRole, reservation),
+ resource.getRanges.getRangeList.asScala.map(r => (r.getBegin, r.getEnd)).toList)
+ }
- val rangePortRole = ranges
- .find { case (role, rangeList) => rangeList
+ val rangePortResourceInfo = ranges
+ .find { case (resourceInfo, rangeList) => rangeList
.exists{ case (rangeStart, rangeEnd) => rangeStart <= port & rangeEnd >= port}}
// this is safe since we have previously validated the ranges (see the checkPorts method)
- rangePortRole.map{ case (role, rangeList) => role}.get
+ rangePortResourceInfo.map{ case (resourceInfo, rangeList) => resourceInfo}.get
}
/** Retrieves the port resources from a list of mesos offered resources */
@@ -564,3 +590,4 @@ trait MesosSchedulerUtils extends Logging {
}
}
}
+
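A minimal sketch (not part of the patch) of how the reservation-aware helpers above are meant to be used. It assumes the code runs inside a class that mixes in MesosSchedulerUtils, and "some-framework-principal" is a hypothetical reservation principal.

// Sketch only: createResource/setReservationInfo come from MesosSchedulerUtils.
import org.apache.mesos.Protos.Resource
import org.apache.mesos.Protos.Resource.ReservationInfo

val reservation: ReservationInfo =
  ReservationInfo.newBuilder().setPrincipal("some-framework-principal").build()
// setReservationInfo only copies the reservation onto the builder when the role is not ANY_ROLE,
// so role-specific resources keep their reservation while wildcard-role resources do not.
val cpus: Resource = createResource("cpus", 2.0, Some("spark-role"), Some(reservation))
assert(cpus.hasReservation)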
diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml
index 43a7ce95bd3de..37e25ceecb883 100644
--- a/resource-managers/yarn/pom.xml
+++ b/resource-managers/yarn/pom.xml
@@ -20,7 +20,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.11</artifactId>
-   <version>2.3.0-SNAPSHOT</version>
+   <version>2.4.0-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index ca0aa0ea3bc73..4d5e3bb043671 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -56,11 +56,28 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends
// TODO: Currently, task to container is computed once (TaskSetManager) - which need not be
// optimal as more containers are available. Might need to handle this better.
- private val sparkConf = new SparkConf()
- private val yarnConf: YarnConfiguration = SparkHadoopUtil.get.newConfiguration(sparkConf)
- .asInstanceOf[YarnConfiguration]
private val isClusterMode = args.userClass != null
+ private val sparkConf = new SparkConf()
+ if (args.propertiesFile != null) {
+ Utils.getPropertiesFromFile(args.propertiesFile).foreach { case (k, v) =>
+ sparkConf.set(k, v)
+ }
+ }
+
+ private val securityMgr = new SecurityManager(sparkConf)
+
+ // Set system properties for each config entry. This covers two use cases:
+ // - The default configuration stored by the SparkHadoopUtil class
+ // - The user application creating a new SparkConf in cluster mode
+ //
+ // Both cases create a new SparkConf object which reads these configs from system properties.
+ sparkConf.getAll.foreach { case (k, v) =>
+ sys.props(k) = v
+ }
+
+ private val yarnConf = new YarnConfiguration(SparkHadoopUtil.newConfiguration(sparkConf))
+
private val ugi = {
val original = UserGroupInformation.getCurrentUser()
@@ -311,7 +328,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends
val credentialManager = new YARNHadoopDelegationTokenManager(
sparkConf,
yarnConf,
- conf => YarnSparkHadoopUtil.get.hadoopFSsToAccess(sparkConf, conf))
+ conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf))
val credentialRenewer =
new AMCredentialRenewer(sparkConf, yarnConf, credentialManager)
@@ -323,13 +340,10 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends
credentialRenewerThread.join()
}
- // Call this to force generation of secret so it gets populated into the Hadoop UGI.
- val securityMgr = new SecurityManager(sparkConf)
-
if (isClusterMode) {
- runDriver(securityMgr)
+ runDriver()
} else {
- runExecutorLauncher(securityMgr)
+ runExecutorLauncher()
}
} catch {
case e: Exception =>
@@ -410,15 +424,11 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends
_sparkConf: SparkConf,
_rpcEnv: RpcEnv,
driverRef: RpcEndpointRef,
- uiAddress: Option[String],
- securityMgr: SecurityManager) = {
+ uiAddress: Option[String]) = {
val appId = client.getAttemptId().getApplicationId().toString()
val attemptId = client.getAttemptId().getAttemptId().toString()
- val historyAddress =
- _sparkConf.get(HISTORY_SERVER_ADDRESS)
- .map { text => SparkHadoopUtil.get.substituteHadoopVariables(text, yarnConf) }
- .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" }
- .getOrElse("")
+ val historyAddress = ApplicationMaster
+ .getHistoryServerAddress(_sparkConf, yarnConf, appId, attemptId)
val driverUrl = RpcEndpointAddress(
_sparkConf.get("spark.driver.host"),
@@ -463,7 +473,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends
YarnSchedulerBackend.ENDPOINT_NAME)
}
- private def runDriver(securityMgr: SecurityManager): Unit = {
+ private def runDriver(): Unit = {
addAmIpFilter(None)
userClassThread = startUserApplication()
@@ -479,7 +489,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends
val driverRef = createSchedulerRef(
sc.getConf.get("spark.driver.host"),
sc.getConf.get("spark.driver.port"))
- registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.webUrl), securityMgr)
+ registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.webUrl))
registered = true
} else {
// Sanity check; should never happen in normal operation, since sc should only be null
@@ -498,15 +508,14 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends
}
}
- private def runExecutorLauncher(securityMgr: SecurityManager): Unit = {
+ private def runExecutorLauncher(): Unit = {
val hostname = Utils.localHostName
val amCores = sparkConf.get(AM_CORES)
rpcEnv = RpcEnv.create("sparkYarnAM", hostname, hostname, -1, sparkConf, securityMgr,
amCores, true)
val driverRef = waitForSparkDriver()
addAmIpFilter(Some(driverRef))
- registerAM(sparkConf, rpcEnv, driverRef, sparkConf.getOption("spark.driver.appUIAddress"),
- securityMgr)
+ registerAM(sparkConf, rpcEnv, driverRef, sparkConf.getOption("spark.driver.appUIAddress"))
registered = true
// In client mode the actor will stop the reporter thread.
@@ -686,6 +695,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends
if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) {
// TODO(davies): add R dependencies here
}
+
val mainMethod = userClassLoader.loadClass(args.userClass)
.getMethod("main", classOf[Array[String]])
@@ -809,15 +819,6 @@ object ApplicationMaster extends Logging {
def main(args: Array[String]): Unit = {
SignalUtils.registerLogger(log)
val amArgs = new ApplicationMasterArguments(args)
-
- // Load the properties file with the Spark configuration and set entries as system properties,
- // so that user code run inside the AM also has access to them.
- // Note: we must do this before SparkHadoopUtil instantiated
- if (amArgs.propertiesFile != null) {
- Utils.getPropertiesFromFile(amArgs.propertiesFile).foreach { case (k, v) =>
- sys.props(k) = v
- }
- }
master = new ApplicationMaster(amArgs)
System.exit(master.run())
}
@@ -830,6 +831,16 @@ object ApplicationMaster extends Logging {
master.getAttemptId
}
+ private[spark] def getHistoryServerAddress(
+ sparkConf: SparkConf,
+ yarnConf: YarnConfiguration,
+ appId: String,
+ attemptId: String): String = {
+ sparkConf.get(HISTORY_SERVER_ADDRESS)
+ .map { text => SparkHadoopUtil.get.substituteHadoopVariables(text, yarnConf) }
+ .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" }
+ .getOrElse("")
+ }
}
/**
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 99e7d46ca5c96..8cd3cd9746a3a 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -48,7 +48,7 @@ import org.apache.hadoop.yarn.exceptions.ApplicationNotFoundException
import org.apache.hadoop.yarn.util.Records
import org.apache.spark.{SecurityManager, SparkConf, SparkException}
-import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.deploy.{SparkApplication, SparkHadoopUtil}
import org.apache.spark.deploy.yarn.config._
import org.apache.spark.deploy.yarn.security.YARNHadoopDelegationTokenManager
import org.apache.spark.internal.Logging
@@ -58,18 +58,14 @@ import org.apache.spark.util.{CallerContext, Utils}
private[spark] class Client(
val args: ClientArguments,
- val hadoopConf: Configuration,
val sparkConf: SparkConf)
extends Logging {
import Client._
import YarnSparkHadoopUtil._
- def this(clientArgs: ClientArguments, spConf: SparkConf) =
- this(clientArgs, SparkHadoopUtil.get.newConfiguration(spConf), spConf)
-
private val yarnClient = YarnClient.createYarnClient
- private val yarnConf = new YarnConfiguration(hadoopConf)
+ private val hadoopConf = new YarnConfiguration(SparkHadoopUtil.newConfiguration(sparkConf))
private val isClusterMode = sparkConf.get("spark.submit.deployMode", "client") == "cluster"
@@ -104,6 +100,8 @@ private[spark] class Client(
private var amKeytabFileName: String = null
private val launcherBackend = new LauncherBackend() {
+ override protected def conf: SparkConf = sparkConf
+
override def onStopRequest(): Unit = {
if (isClusterMode && appId != null) {
yarnClient.killApplication(appId)
@@ -125,7 +123,7 @@ private[spark] class Client(
private val credentialManager = new YARNHadoopDelegationTokenManager(
sparkConf,
hadoopConf,
- conf => YarnSparkHadoopUtil.get.hadoopFSsToAccess(sparkConf, conf))
+ conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf))
def reportLauncherState(state: SparkAppHandle.State): Unit = {
launcherBackend.setState(state)
@@ -134,8 +132,6 @@ private[spark] class Client(
def stop(): Unit = {
launcherBackend.close()
yarnClient.stop()
- // Unset YARN mode system env variable, to allow switching between cluster types.
- System.clearProperty("SPARK_YARN_MODE")
}
/**
@@ -152,7 +148,7 @@ private[spark] class Client(
// Setup the credentials before doing anything else,
// so we don't have issues at any point.
setupCredentials()
- yarnClient.init(yarnConf)
+ yarnClient.init(hadoopConf)
yarnClient.start()
logInfo("Requesting a new application from cluster with %d NodeManagers"
@@ -398,7 +394,7 @@ private[spark] class Client(
if (SparkHadoopUtil.get.isProxyUser(currentUser)) {
currentUser.addCredentials(credentials)
}
- logDebug(YarnSparkHadoopUtil.get.dumpTokens(credentials).mkString("\n"))
+ logDebug(SparkHadoopUtil.get.dumpTokens(credentials).mkString("\n"))
}
// If we use principal and keytab to login, also credentials can be renewed some time
@@ -758,12 +754,14 @@ private[spark] class Client(
// Save the YARN configuration into a separate file that will be overlaid on top of the
// cluster's Hadoop conf.
confStream.putNextEntry(new ZipEntry(SPARK_HADOOP_CONF_FILE))
- yarnConf.writeXml(confStream)
+ hadoopConf.writeXml(confStream)
confStream.closeEntry()
- // Save Spark configuration to a file in the archive.
+ // Save Spark configuration to a file in the archive, but filter out the app's secret.
val props = new Properties()
- sparkConf.getAll.foreach { case (k, v) => props.setProperty(k, v) }
+ sparkConf.getAll.foreach { case (k, v) =>
+ props.setProperty(k, v)
+ }
// Override spark.yarn.key to point to the location in distributed cache which will be used
// by AM.
Option(amKeytabFileName).foreach { k => props.setProperty(KEYTAB.key, k) }
@@ -786,8 +784,7 @@ private[spark] class Client(
pySparkArchives: Seq[String]): HashMap[String, String] = {
logInfo("Setting up the launch environment for our AM container")
val env = new HashMap[String, String]()
- populateClasspath(args, yarnConf, sparkConf, env, sparkConf.get(DRIVER_CLASS_PATH))
- env("SPARK_YARN_MODE") = "true"
+ populateClasspath(args, hadoopConf, sparkConf, env, sparkConf.get(DRIVER_CLASS_PATH))
env("SPARK_YARN_STAGING_DIR") = stagingDirPath.toString
env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName()
if (loginFromKeytab) {
@@ -861,6 +858,7 @@ private[spark] class Client(
} else {
Nil
}
+
val launchEnv = setupLaunchEnv(appStagingDirPath, pySparkArchives)
val localResources = prepareLocalResources(appStagingDirPath, pySparkArchives)
@@ -991,7 +989,11 @@ private[spark] class Client(
logDebug("YARN AM launch context:")
logDebug(s" user class: ${Option(args.userClass).getOrElse("N/A")}")
logDebug(" env:")
- launchEnv.foreach { case (k, v) => logDebug(s" $k -> $v") }
+ if (log.isDebugEnabled) {
+ Utils.redact(sparkConf, launchEnv.toSeq).foreach { case (k, v) =>
+ logDebug(s" $k -> $v")
+ }
+ }
logDebug(" resources:")
localResources.foreach { case (k, v) => logDebug(s" $k -> $v")}
logDebug(" command:")
@@ -1185,24 +1187,6 @@ private[spark] class Client(
private object Client extends Logging {
- def main(argStrings: Array[String]) {
- if (!sys.props.contains("SPARK_SUBMIT")) {
- logWarning("WARNING: This client is deprecated and will be removed in a " +
- "future version of Spark. Use ./bin/spark-submit with \"--master yarn\"")
- }
-
- // Set an env variable indicating we are running in YARN mode.
- // Note that any env variable with the SPARK_ prefix gets propagated to all (remote) processes
- System.setProperty("SPARK_YARN_MODE", "true")
- val sparkConf = new SparkConf
- // SparkSubmit would use yarn cache to distribute files & jars in yarn mode,
- // so remove them from sparkConf here for yarn mode.
- sparkConf.remove("spark.jars")
- sparkConf.remove("spark.files")
- val args = new ClientArguments(argStrings)
- new Client(args, sparkConf).run()
- }
-
// Alias for the user jar
val APP_JAR_NAME: String = "__app__.jar"
@@ -1437,15 +1421,20 @@ private object Client extends Logging {
}
/**
- * Return whether the two file systems are the same.
+ * Return whether the two URIs represent the same file system.
*/
- private def compareFs(srcFs: FileSystem, destFs: FileSystem): Boolean = {
- val srcUri = srcFs.getUri()
- val dstUri = destFs.getUri()
+ private[spark] def compareUri(srcUri: URI, dstUri: URI): Boolean = {
+
if (srcUri.getScheme() == null || srcUri.getScheme() != dstUri.getScheme()) {
return false
}
+ val srcAuthority = srcUri.getAuthority()
+ val dstAuthority = dstUri.getAuthority()
+ if (srcAuthority != null && !srcAuthority.equalsIgnoreCase(dstAuthority)) {
+ return false
+ }
+
var srcHost = srcUri.getHost()
var dstHost = dstUri.getHost()
@@ -1463,6 +1452,17 @@ private object Client extends Logging {
}
Objects.equal(srcHost, dstHost) && srcUri.getPort() == dstUri.getPort()
+
+ }
+
+ /**
+ * Return whether the two file systems are the same.
+ */
+ protected def compareFs(srcFs: FileSystem, destFs: FileSystem): Boolean = {
+ val srcUri = srcFs.getUri()
+ val dstUri = destFs.getUri()
+
+ compareUri(srcUri, dstUri)
}
/**
@@ -1506,3 +1506,16 @@ private object Client extends Logging {
}
}
+
+private[spark] class YarnClusterApplication extends SparkApplication {
+
+ override def start(args: Array[String], conf: SparkConf): Unit = {
+ // SparkSubmit would use yarn cache to distribute files & jars in yarn mode,
+ // so remove them from sparkConf here for yarn mode.
+ conf.remove("spark.jars")
+ conf.remove("spark.files")
+
+ new Client(new ClientArguments(args), conf).run()
+ }
+
+}
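With Client.main removed, cluster-mode submission now goes through the SparkApplication interface. A hedged sketch of what an invocation looks like from inside Spark's own packages; the jar and class names below are illustrative placeholders, and in practice SparkSubmit instantiates the class and calls start() itself.

import org.apache.spark.SparkConf

// Placeholders: "app.jar" and "org.example.Main" stand in for real submission arguments.
val conf = new SparkConf().setMaster("yarn").set("spark.submit.deployMode", "cluster")
new YarnClusterApplication().start(
  Array("--jar", "app.jar", "--class", "org.example.Main"),
  conf)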
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
index 7052fb347106b..506adb363aa90 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
@@ -41,6 +41,7 @@ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef}
import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason}
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RetrieveLastAllocatedExecutorId
+import org.apache.spark.scheduler.cluster.SchedulerBackendUtils
import org.apache.spark.util.{Clock, SystemClock, ThreadUtils}
/**
@@ -109,7 +110,7 @@ private[yarn] class YarnAllocator(
sparkConf.get(EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS).getOrElse(-1L)
@volatile private var targetNumExecutors =
- YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sparkConf)
+ SchedulerBackendUtils.getInitialTargetExecutorNumber(sparkConf)
private var currentNodeBlacklist = Set.empty[String]
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
index 72f4d273ab53b..c1ae12aabb8cc 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
@@ -92,7 +92,7 @@ private[spark] class YarnRMClient extends Logging {
/** Returns the attempt ID. */
def getAttemptId(): ApplicationAttemptId = {
- YarnSparkHadoopUtil.get.getContainerId.getApplicationAttemptId()
+ YarnSparkHadoopUtil.getContainerId.getApplicationAttemptId()
}
/** Returns the configuration for the AmIpFilter to add to the Spark UI. */
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
index 3d9f99f57bed7..f406fabd61860 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
@@ -17,21 +17,14 @@
package org.apache.spark.deploy.yarn
-import java.nio.charset.StandardCharsets.UTF_8
-import java.util.regex.Matcher
-import java.util.regex.Pattern
+import java.util.regex.{Matcher, Pattern}
import scala.collection.mutable.{HashMap, ListBuffer}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.hadoop.io.Text
-import org.apache.hadoop.mapred.{JobConf, Master}
-import org.apache.hadoop.security.Credentials
-import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.yarn.api.ApplicationConstants
import org.apache.hadoop.yarn.api.records.{ApplicationAccessType, ContainerId, Priority}
-import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.util.ConverterUtils
import org.apache.spark.{SecurityManager, SparkConf, SparkException}
@@ -43,87 +36,10 @@ import org.apache.spark.internal.config._
import org.apache.spark.launcher.YarnCommandBuilderUtils
import org.apache.spark.util.Utils
-
-/**
- * Contains util methods to interact with Hadoop from spark.
- */
-class YarnSparkHadoopUtil extends SparkHadoopUtil {
+object YarnSparkHadoopUtil {
private var credentialUpdater: CredentialUpdater = _
- override def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) {
- dest.addCredentials(source.getCredentials())
- }
-
- // Note that all params which start with SPARK are propagated all the way through, so if in yarn
- // mode, this MUST be set to true.
- override def isYarnMode(): Boolean = { true }
-
- // Return an appropriate (subclass) of Configuration. Creating a config initializes some Hadoop
- // subsystems. Always create a new config, don't reuse yarnConf.
- override def newConfiguration(conf: SparkConf): Configuration = {
- val hadoopConf = new YarnConfiguration(super.newConfiguration(conf))
- hadoopConf.addResource(Client.SPARK_HADOOP_CONF_FILE)
- hadoopConf
- }
-
- // Add any user credentials to the job conf which are necessary for running on a secure Hadoop
- // cluster
- override def addCredentials(conf: JobConf) {
- val jobCreds = conf.getCredentials()
- jobCreds.mergeAll(UserGroupInformation.getCurrentUser().getCredentials())
- }
-
- override def addSecretKeyToUserCredentials(key: String, secret: String) {
- val creds = new Credentials()
- creds.addSecretKey(new Text(key), secret.getBytes(UTF_8))
- addCurrentUserCredentials(creds)
- }
-
- override def getSecretKeyFromUserCredentials(key: String): Array[Byte] = {
- val credentials = getCurrentUserCredentials()
- if (credentials != null) credentials.getSecretKey(new Text(key)) else null
- }
-
- private[spark] override def startCredentialUpdater(sparkConf: SparkConf): Unit = {
- val hadoopConf = newConfiguration(sparkConf)
- val credentialManager = new YARNHadoopDelegationTokenManager(
- sparkConf,
- hadoopConf,
- conf => YarnSparkHadoopUtil.get.hadoopFSsToAccess(sparkConf, conf))
- credentialUpdater = new CredentialUpdater(sparkConf, hadoopConf, credentialManager)
- credentialUpdater.start()
- }
-
- private[spark] override def stopCredentialUpdater(): Unit = {
- if (credentialUpdater != null) {
- credentialUpdater.stop()
- credentialUpdater = null
- }
- }
-
- private[spark] def getContainerId: ContainerId = {
- val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name())
- ConverterUtils.toContainerId(containerIdString)
- }
-
- /** The filesystems for which YARN should fetch delegation tokens. */
- private[spark] def hadoopFSsToAccess(
- sparkConf: SparkConf,
- hadoopConf: Configuration): Set[FileSystem] = {
- val filesystemsToAccess = sparkConf.get(FILESYSTEMS_TO_ACCESS)
- .map(new Path(_).getFileSystem(hadoopConf))
- .toSet
-
- val stagingFS = sparkConf.get(STAGING_DIR)
- .map(new Path(_).getFileSystem(hadoopConf))
- .getOrElse(FileSystem.get(hadoopConf))
-
- filesystemsToAccess + stagingFS
- }
-}
-
-object YarnSparkHadoopUtil {
// Additional memory overhead
// 10% was arrived at experimentally. In the interest of minimizing memory waste while covering
// the common cases. Memory overhead tends to grow with container size.
@@ -133,20 +49,10 @@ object YarnSparkHadoopUtil {
val ANY_HOST = "*"
- val DEFAULT_NUMBER_EXECUTORS = 2
-
// All RM requests are issued with same priority : we do not (yet) have any distinction between
// request types (like map/reduce in hadoop for example)
val RM_REQUEST_PRIORITY = Priority.newInstance(1)
- def get: YarnSparkHadoopUtil = {
- val yarnMode = java.lang.Boolean.parseBoolean(
- System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))
- if (!yarnMode) {
- throw new SparkException("YarnSparkHadoopUtil is not available in non-YARN mode!")
- }
- SparkHadoopUtil.get.asInstanceOf[YarnSparkHadoopUtil]
- }
/**
* Add a path variable to the given environment map.
* If the map already contains this key, append the value to the existing value instead.
@@ -280,26 +186,41 @@ object YarnSparkHadoopUtil {
)
}
- /**
- * Getting the initial target number of executors depends on whether dynamic allocation is
- * enabled.
- * If not using dynamic allocation it gets the number of executors requested by the user.
- */
- def getInitialTargetExecutorNumber(
- conf: SparkConf,
- numExecutors: Int = DEFAULT_NUMBER_EXECUTORS): Int = {
- if (Utils.isDynamicAllocationEnabled(conf)) {
- val minNumExecutors = conf.get(DYN_ALLOCATION_MIN_EXECUTORS)
- val initialNumExecutors = Utils.getDynamicAllocationInitialExecutors(conf)
- val maxNumExecutors = conf.get(DYN_ALLOCATION_MAX_EXECUTORS)
- require(initialNumExecutors >= minNumExecutors && initialNumExecutors <= maxNumExecutors,
- s"initial executor number $initialNumExecutors must between min executor number " +
- s"$minNumExecutors and max executor number $maxNumExecutors")
+ def getContainerId: ContainerId = {
+ val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name())
+ ConverterUtils.toContainerId(containerIdString)
+ }
- initialNumExecutors
- } else {
- conf.get(EXECUTOR_INSTANCES).getOrElse(numExecutors)
+ /** The filesystems for which YARN should fetch delegation tokens. */
+ def hadoopFSsToAccess(
+ sparkConf: SparkConf,
+ hadoopConf: Configuration): Set[FileSystem] = {
+ val filesystemsToAccess = sparkConf.get(FILESYSTEMS_TO_ACCESS)
+ .map(new Path(_).getFileSystem(hadoopConf))
+ .toSet
+
+ val stagingFS = sparkConf.get(STAGING_DIR)
+ .map(new Path(_).getFileSystem(hadoopConf))
+ .getOrElse(FileSystem.get(hadoopConf))
+
+ filesystemsToAccess + stagingFS
+ }
+
+ def startCredentialUpdater(sparkConf: SparkConf): Unit = {
+ val hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf)
+ val credentialManager = new YARNHadoopDelegationTokenManager(
+ sparkConf,
+ hadoopConf,
+ conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf))
+ credentialUpdater = new CredentialUpdater(sparkConf, hadoopConf, credentialManager)
+ credentialUpdater.start()
+ }
+
+ def stopCredentialUpdater(): Unit = {
+ if (credentialUpdater != null) {
+ credentialUpdater.stop()
+ credentialUpdater = null
}
}
-}
+}
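Call sites change from the old instance lookup to plain object calls. A small sketch under stated assumptions:

import org.apache.hadoop.conf.Configuration
import org.apache.spark.SparkConf
import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil

val sparkConf = new SparkConf()
val hadoopConf = new Configuration()
// Before this change these helpers lived on a SparkHadoopUtil subclass and required
// SPARK_YARN_MODE to be set; they are now plain object methods.
val filesystems = YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, hadoopConf)
// getContainerId reads the YARN CONTAINER_ID env var, so this line only works inside a container.
val attemptId = YarnSparkHadoopUtil.getContainerId.getApplicationAttemptId()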
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala
index e1af8ba087d6e..3ba3ae5ab4401 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala
@@ -217,20 +217,12 @@ package object config {
.intConf
.createWithDefault(1)
- private[spark] val DRIVER_MEMORY_OVERHEAD = ConfigBuilder("spark.yarn.driver.memoryOverhead")
- .bytesConf(ByteUnit.MiB)
- .createOptional
-
/* Executor configuration. */
private[spark] val EXECUTOR_CORES = ConfigBuilder("spark.executor.cores")
.intConf
.createWithDefault(1)
- private[spark] val EXECUTOR_MEMORY_OVERHEAD = ConfigBuilder("spark.yarn.executor.memoryOverhead")
- .bytesConf(ByteUnit.MiB)
- .createOptional
-
private[spark] val EXECUTOR_NODE_LABEL_EXPRESSION =
ConfigBuilder("spark.yarn.executor.nodeLabelExpression")
.doc("Node label expression for executors.")
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala
index 6134757a82fdc..eaf2cff111a49 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala
@@ -62,7 +62,7 @@ private[yarn] class AMCredentialRenewer(
private val credentialRenewerThread: ScheduledExecutorService =
ThreadUtils.newDaemonSingleThreadScheduledExecutor("Credential Refresh Thread")
- private val hadoopUtil = YarnSparkHadoopUtil.get
+ private val hadoopUtil = SparkHadoopUtil.get
private val credentialsFile = sparkConf.get(CREDENTIALS_FILE_PATH)
private val daysToKeepFiles = sparkConf.get(CREDENTIALS_FILE_MAX_RETENTION)
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
index d482376d14dd7..0c6206eebe41d 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -52,7 +52,7 @@ private[spark] class YarnClientSchedulerBackend(
logDebug("ClientArguments called with: " + argsArrayBuf.mkString(" "))
val args = new ClientArguments(argsArrayBuf.toArray)
- totalExpectedExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(conf)
+ totalExpectedExecutors = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf)
client = new Client(args, conf)
bindToYarn(client.submitApplication(), None)
@@ -66,7 +66,7 @@ private[spark] class YarnClientSchedulerBackend(
// reads the credentials from HDFS, just like the executors and updates its own credentials
// cache.
if (conf.contains("spark.yarn.credentials.file")) {
- YarnSparkHadoopUtil.get.startCredentialUpdater(conf)
+ YarnSparkHadoopUtil.startCredentialUpdater(conf)
}
monitorThread = asyncMonitorApplication()
monitorThread.start()
@@ -153,7 +153,7 @@ private[spark] class YarnClientSchedulerBackend(
client.reportLauncherState(SparkAppHandle.State.FINISHED)
super.stop()
- YarnSparkHadoopUtil.get.stopCredentialUpdater()
+ YarnSparkHadoopUtil.stopCredentialUpdater()
client.stop()
logInfo("Stopped")
}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala
index 4f3d5ebf403e0..62bf9818ee248 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala
@@ -34,14 +34,14 @@ private[spark] class YarnClusterSchedulerBackend(
val attemptId = ApplicationMaster.getAttemptId
bindToYarn(attemptId.getApplicationId(), Some(attemptId))
super.start()
- totalExpectedExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sc.conf)
+ totalExpectedExecutors = SchedulerBackendUtils.getInitialTargetExecutorNumber(sc.conf)
}
override def getDriverLogUrls: Option[Map[String, String]] = {
var driverLogs: Option[Map[String, String]] = None
try {
val yarnConf = new YarnConfiguration(sc.hadoopConfiguration)
- val containerId = YarnSparkHadoopUtil.get.getContainerId
+ val containerId = YarnSparkHadoopUtil.getContainerId
val httpAddress = System.getenv(Environment.NM_HOST.name()) +
":" + System.getenv(Environment.NM_HTTP_PORT.name())
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
index 415a29fd887e8..bb615c36cd97f 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -20,6 +20,7 @@ package org.apache.spark.scheduler.cluster
import java.util.concurrent.atomic.{AtomicBoolean}
import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.ExecutionContext.Implicits.global
import scala.util.{Failure, Success}
import scala.util.control.NonFatal
@@ -245,14 +246,7 @@ private[spark] abstract class YarnSchedulerBackend(
Future.successful(RemoveExecutor(executorId, SlaveLost("AM is not yet registered.")))
}
- removeExecutorMessage
- .flatMap { message =>
- driverEndpoint.ask[Boolean](message)
- }(ThreadUtils.sameThread)
- .onFailure {
- case NonFatal(e) => logError(
- s"Error requesting driver to remove executor $executorId after disconnection.", e)
- }(ThreadUtils.sameThread)
+ removeExecutorMessage.foreach { message => driverEndpoint.send(message) }
}
override def receive: PartialFunction[Any, Unit] = {
@@ -265,12 +259,10 @@ private[spark] abstract class YarnSchedulerBackend(
addWebUIFilter(filterName, filterParams, proxyBase)
case r @ RemoveExecutor(executorId, reason) =>
- logWarning(reason.toString)
- driverEndpoint.ask[Boolean](r).onFailure {
- case e =>
- logError("Error requesting driver to remove executor" +
- s" $executorId for reason $reason", e)
- }(ThreadUtils.sameThread)
+ if (!stopped.get) {
+ logWarning(s"Requesting driver to remove executor $executorId for reason $reason")
+ driverEndpoint.send(r)
+ }
}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala
new file mode 100644
index 0000000000000..695a82f3583e6
--- /dev/null
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+
+class ApplicationMasterSuite extends SparkFunSuite {
+
+ test("history url with hadoop and spark substitutions") {
+ val host = "rm.host.com"
+ val port = 18080
+ val sparkConf = new SparkConf()
+
+ sparkConf.set("spark.yarn.historyServer.address",
+ "http://${hadoopconf-yarn.resourcemanager.hostname}:${spark.history.ui.port}")
+ val yarnConf = new YarnConfiguration()
+ yarnConf.set("yarn.resourcemanager.hostname", host)
+ val appId = "application_123_1"
+ val attemptId = appId + "_1"
+
+ val shsAddr = ApplicationMaster
+ .getHistoryServerAddress(sparkConf, yarnConf, appId, attemptId)
+
+ assert(shsAddr === s"http://${host}:${port}/history/${appId}/${attemptId}")
+ }
+}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
index 7cc3075eb766c..dc38e34bf5591 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
@@ -62,18 +62,14 @@ abstract class BaseYarnClusterSuite
protected var hadoopConfDir: File = _
private var logConfDir: File = _
- var oldSystemProperties: Properties = null
-
def newYarnConfig(): YarnConfiguration
override def beforeAll() {
super.beforeAll()
- oldSystemProperties = SerializationUtils.clone(System.getProperties)
tempDir = Utils.createTempDir()
logConfDir = new File(tempDir, "log4j")
logConfDir.mkdir()
- System.setProperty("SPARK_YARN_MODE", "true")
val logConfFile = new File(logConfDir, "log4j.properties")
Files.write(LOG4J_CONF, logConfFile, StandardCharsets.UTF_8)
@@ -124,7 +120,6 @@ abstract class BaseYarnClusterSuite
try {
yarnCluster.stop()
} finally {
- System.setProperties(oldSystemProperties)
super.afterAll()
}
}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
index 6cf68427921fd..7fa597167f3f0 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
@@ -24,7 +24,6 @@ import java.util.Properties
import scala.collection.JavaConverters._
import scala.collection.mutable.{HashMap => MutableHashMap}
-import org.apache.commons.lang3.SerializationUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.MRJobConfig
@@ -36,34 +35,18 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.util.Records
import org.mockito.Matchers.{eq => meq, _}
import org.mockito.Mockito._
-import org.scalatest.{BeforeAndAfterAll, Matchers}
+import org.scalatest.Matchers
import org.apache.spark.{SparkConf, SparkFunSuite, TestUtils}
import org.apache.spark.deploy.yarn.config._
-import org.apache.spark.util.{ResetSystemProperties, SparkConfWithEnv, Utils}
+import org.apache.spark.util.{SparkConfWithEnv, Utils}
-class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll
- with ResetSystemProperties {
+class ClientSuite extends SparkFunSuite with Matchers {
import Client._
var oldSystemProperties: Properties = null
- override def beforeAll(): Unit = {
- super.beforeAll()
- oldSystemProperties = SerializationUtils.clone(System.getProperties)
- System.setProperty("SPARK_YARN_MODE", "true")
- }
-
- override def afterAll(): Unit = {
- try {
- System.setProperties(oldSystemProperties)
- oldSystemProperties = null
- } finally {
- super.afterAll()
- }
- }
-
test("default Yarn application classpath") {
getDefaultYarnApplicationClasspath should be(Fixtures.knownDefYarnAppCP)
}
@@ -185,7 +168,6 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll
}
test("configuration and args propagate through createApplicationSubmissionContext") {
- val conf = new Configuration()
// When parsing tags, duplicates and leading/trailing whitespace should be removed.
// Spaces between non-comma strings should be preserved as single tags. Empty strings may or
// may not be removed depending on the version of Hadoop being used.
@@ -200,7 +182,7 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll
val getNewApplicationResponse = Records.newRecord(classOf[GetNewApplicationResponse])
val containerLaunchContext = Records.newRecord(classOf[ContainerLaunchContext])
- val client = new Client(args, conf, sparkConf)
+ val client = new Client(args, sparkConf)
client.createApplicationSubmissionContext(
new YarnClientApplication(getNewApplicationResponse, appContext),
containerLaunchContext)
@@ -375,6 +357,39 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll
sparkConf.get(SECONDARY_JARS) should be (Some(Seq(new File(jar2.toURI).getName)))
}
+ private val matching = Seq(
+ ("files URI match test1", "file:///file1", "file:///file2"),
+ ("files URI match test2", "file:///c:file1", "file://c:file2"),
+ ("files URI match test3", "file://host/file1", "file://host/file2"),
+ ("wasb URI match test", "wasb://bucket1@user", "wasb://bucket1@user/"),
+ ("hdfs URI match test", "hdfs:/path1", "hdfs:/path1")
+ )
+
+ matching.foreach { t =>
+ test(t._1) {
+ assert(Client.compareUri(new URI(t._2), new URI(t._3)),
+ s"No match between ${t._2} and ${t._3}")
+ }
+ }
+
+ private val unmatching = Seq(
+ ("files URI unmatch test1", "file:///file1", "file://host/file2"),
+ ("files URI unmatch test2", "file://host/file1", "file:///file2"),
+ ("files URI unmatch test3", "file://host/file1", "file://host2/file2"),
+ ("wasb URI unmatch test1", "wasb://bucket1@user", "wasb://bucket2@user/"),
+ ("wasb URI unmatch test2", "wasb://bucket1@user", "wasb://bucket1@user2/"),
+ ("s3 URI unmatch test", "s3a://user@pass:bucket1/", "s3a://user2@pass2:bucket1/"),
+ ("hdfs URI unmatch test1", "hdfs://namenode1/path1", "hdfs://namenode1:8080/path2"),
+ ("hdfs URI unmatch test2", "hdfs://namenode1:8020/path1", "hdfs://namenode1:8080/path2")
+ )
+
+ unmatching.foreach { t =>
+ test(t._1) {
+ assert(!Client.compareUri(new URI(t._2), new URI(t._3)),
+ s"match between ${t._2} and ${t._3}")
+ }
+ }
+
object Fixtures {
val knownDefYarnAppCP: Seq[String] =
@@ -407,15 +422,14 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll
private def createClient(
sparkConf: SparkConf,
- conf: Configuration = new Configuration(),
args: Array[String] = Array()): Client = {
val clientArgs = new ClientArguments(args)
- spy(new Client(clientArgs, conf, sparkConf))
+ spy(new Client(clientArgs, sparkConf))
}
private def classpath(client: Client): Array[String] = {
val env = new MutableHashMap[String, String]()
- populateClasspath(null, client.hadoopConf, client.sparkConf, env)
+ populateClasspath(null, new Configuration(), client.sparkConf, env)
classpath(env)
}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index 080a20cdf87e4..f941f81f32013 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -132,7 +132,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
"spark.executor.cores" -> "1",
"spark.executor.memory" -> "512m",
"spark.executor.instances" -> "2",
- // Sending some senstive information, which we'll make sure gets redacted
+ // Sending some sensitive information, which we'll make sure gets redacted
"spark.executorEnv.HADOOP_CREDSTORE_PASSWORD" -> YarnClusterDriver.SECRET_PASSWORD,
"spark.yarn.appMasterEnv.HADOOP_CREDSTORE_PASSWORD" -> YarnClusterDriver.SECRET_PASSWORD
))
@@ -152,8 +152,13 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
))
}
- test("run Spark in yarn-cluster mode with using SparkHadoopUtil.conf") {
- testYarnAppUseSparkHadoopUtilConf()
+ test("yarn-cluster should respect conf overrides in SparkHadoopUtil (SPARK-16414)") {
+ val result = File.createTempFile("result", null, tempDir)
+ val finalState = runSpark(false,
+ mainClassName(YarnClusterDriverUseSparkHadoopUtilConf.getClass),
+ appArgs = Seq("key=value", result.getAbsolutePath()),
+ extraConf = Map("spark.hadoop.key" -> "value"))
+ checkResult(finalState, result)
}
test("run Spark in yarn-client mode with additional jar") {
@@ -261,15 +266,6 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
checkResult(finalState, result)
}
- private def testYarnAppUseSparkHadoopUtilConf(): Unit = {
- val result = File.createTempFile("result", null, tempDir)
- val finalState = runSpark(false,
- mainClassName(YarnClusterDriverUseSparkHadoopUtilConf.getClass),
- appArgs = Seq("key=value", result.getAbsolutePath()),
- extraConf = Map("spark.hadoop.key" -> "value"))
- checkResult(finalState, result)
- }
-
private def testWithAddJar(clientMode: Boolean): Unit = {
val originalJar = TestUtils.createJarWithFiles(Map("test.resource" -> "ORIGINAL"), tempDir)
val driverResult = File.createTempFile("driver", null, tempDir)
@@ -518,7 +514,7 @@ private object YarnClusterDriver extends Logging with Matchers {
s"Driver logs contain sensitive info (${SECRET_PASSWORD}): \n${log} "
)
}
- val containerId = YarnSparkHadoopUtil.get.getContainerId
+ val containerId = YarnSparkHadoopUtil.getContainerId
val user = Utils.getCurrentUserName()
assert(urlStr.endsWith(s"/node/containerlogs/$containerId/$user/stderr?start=-4096"))
}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
index a057618b39950..f21353aa007c8 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
@@ -71,14 +71,10 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging
test("Yarn configuration override") {
val key = "yarn.nodemanager.hostname"
- val default = new YarnConfiguration()
-
val sparkConf = new SparkConf()
.set("spark.hadoop." + key, "someHostName")
- val yarnConf = new YarnSparkHadoopUtil().newConfiguration(sparkConf)
-
- yarnConf.getClass() should be (classOf[YarnConfiguration])
- yarnConf.get(key) should not be default.get(key)
+ val yarnConf = new YarnConfiguration(SparkHadoopUtil.get.newConfiguration(sparkConf))
+ yarnConf.get(key) should be ("someHostName")
}
@@ -145,45 +141,4 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging
}
- test("check different hadoop utils based on env variable") {
- try {
- System.setProperty("SPARK_YARN_MODE", "true")
- assert(SparkHadoopUtil.get.getClass === classOf[YarnSparkHadoopUtil])
- System.setProperty("SPARK_YARN_MODE", "false")
- assert(SparkHadoopUtil.get.getClass === classOf[SparkHadoopUtil])
- } finally {
- System.clearProperty("SPARK_YARN_MODE")
- }
- }
-
-
-
- // This test needs to live here because it depends on isYarnMode returning true, which can only
- // happen in the YARN module.
- test("security manager token generation") {
- try {
- System.setProperty("SPARK_YARN_MODE", "true")
- val initial = SparkHadoopUtil.get
- .getSecretKeyFromUserCredentials(SecurityManager.SECRET_LOOKUP_KEY)
- assert(initial === null || initial.length === 0)
-
- val conf = new SparkConf()
- .set(SecurityManager.SPARK_AUTH_CONF, "true")
- .set(SecurityManager.SPARK_AUTH_SECRET_CONF, "unused")
- val sm = new SecurityManager(conf)
-
- val generated = SparkHadoopUtil.get
- .getSecretKeyFromUserCredentials(SecurityManager.SECRET_LOOKUP_KEY)
- assert(generated != null)
- val genString = new Text(generated).toString()
- assert(genString != "unused")
- assert(sm.getSecretKey() === genString)
- } finally {
- // removeSecretKey() was only added in Hadoop 2.6, so instead we just set the secret
- // to an empty string.
- SparkHadoopUtil.get.addSecretKeyToUserCredentials(SecurityManager.SECRET_LOOKUP_KEY, "")
- System.clearProperty("SPARK_YARN_MODE")
- }
- }
-
}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala
index c918998bde07c..3c7cdc0f1dab8 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala
@@ -31,24 +31,15 @@ class YARNHadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers
override def beforeAll(): Unit = {
super.beforeAll()
-
- System.setProperty("SPARK_YARN_MODE", "true")
-
sparkConf = new SparkConf()
hadoopConf = new Configuration()
}
- override def afterAll(): Unit = {
- super.afterAll()
-
- System.clearProperty("SPARK_YARN_MODE")
- }
-
test("Correctly loads credential providers") {
credentialManager = new YARNHadoopDelegationTokenManager(
sparkConf,
hadoopConf,
- conf => YarnSparkHadoopUtil.get.hadoopFSsToAccess(sparkConf, conf))
+ conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf))
credentialManager.credentialProviders.get("yarn-test") should not be (None)
}
diff --git a/scalastyle-config.xml b/scalastyle-config.xml
index 7bdd3fac773a3..e2fa5754afaee 100644
--- a/scalastyle-config.xml
+++ b/scalastyle-config.xml
@@ -93,7 +93,7 @@ This file is divided into 3 sections:
-
+
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 9e2ced30407d4..7d23637e28342 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -22,7 +22,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.11</artifactId>
-   <version>2.3.0-SNAPSHOT</version>
+   <version>2.4.0-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
@@ -134,7 +134,7 @@
        <groupId>org.scalatest</groupId>
        <artifactId>scalatest-maven-plugin</artifactId>
-         <argLine>-ea -Xmx4g -Xss4m -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m</argLine>
+         <argLine>-ea -Xmx4g -Xss4m -XX:ReservedCodeCacheSize=${CodeCacheSize}</argLine>
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index 6fe995f650d55..39d5e4ed56628 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -73,18 +73,22 @@ statement
| ALTER DATABASE identifier SET DBPROPERTIES tablePropertyList #setDatabaseProperties
| DROP DATABASE (IF EXISTS)? identifier (RESTRICT | CASCADE)? #dropDatabase
| createTableHeader ('(' colTypeList ')')? tableProvider
- (OPTIONS options=tablePropertyList)?
- (PARTITIONED BY partitionColumnNames=identifierList)?
- bucketSpec? locationSpec?
- (COMMENT comment=STRING)?
- (TBLPROPERTIES tableProps=tablePropertyList)?
+ ((OPTIONS options=tablePropertyList) |
+ (PARTITIONED BY partitionColumnNames=identifierList) |
+ bucketSpec |
+ locationSpec |
+ (COMMENT comment=STRING) |
+ (TBLPROPERTIES tableProps=tablePropertyList))*
(AS? query)? #createTable
| createTableHeader ('(' columns=colTypeList ')')?
- (COMMENT comment=STRING)?
- (PARTITIONED BY '(' partitionColumns=colTypeList ')')?
- bucketSpec? skewSpec?
- rowFormat? createFileFormat? locationSpec?
- (TBLPROPERTIES tablePropertyList)?
+ ((COMMENT comment=STRING) |
+ (PARTITIONED BY '(' partitionColumns=colTypeList ')') |
+ bucketSpec |
+ skewSpec |
+ rowFormat |
+ createFileFormat |
+ locationSpec |
+ (TBLPROPERTIES tableProps=tablePropertyList))*
(AS? query)? #createHiveTable
| CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier
LIKE source=tableIdentifier locationSpec? #createTableLike
@@ -137,7 +141,7 @@ statement
(LIKE? pattern=STRING)? #showTables
| SHOW TABLE EXTENDED ((FROM | IN) db=identifier)?
LIKE pattern=STRING partitionSpec? #showTable
- | SHOW DATABASES (LIKE pattern=STRING)? #showDatabases
+ | SHOW DATABASES (LIKE? pattern=STRING)? #showDatabases
| SHOW TBLPROPERTIES table=tableIdentifier
('(' key=tablePropertyKey ')')? #showTblProperties
| SHOW COLUMNS (FROM | IN) tableIdentifier
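A hedged illustration of what the relaxed grammar accepts, assuming a SparkSession named spark: the table-level clauses may now appear in any order, and LIKE becomes optional in SHOW DATABASES.

// COMMENT before PARTITIONED BY was rejected by the old, fixed-order rule.
spark.sql(
  """CREATE TABLE demo (a INT, b STRING)
    |USING parquet
    |COMMENT 'clause order no longer matters'
    |PARTITIONED BY (b)
    |TBLPROPERTIES ('k' = 'v')
  """.stripMargin)

// The LIKE keyword may now be omitted.
spark.sql("SHOW DATABASES 'temp*'")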
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index 64ab01ca57403..d18542b188f71 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -294,7 +294,7 @@ public void setNullAt(int ordinal) {
assertIndexIsValid(ordinal);
BitSetMethods.set(baseObject, baseOffset + 8, ordinal);
- /* we assume the corrresponding column was already 0 or
+ /* we assume the corresponding column was already 0 or
will be set to 0 later by the caller side */
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java
new file mode 100644
index 0000000000000..f0f66bae245fd
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen;
+
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A helper class to write {@link UTF8String}s to an internal buffer and build the concatenated
+ * {@link UTF8String} at the end.
+ */
+public class UTF8StringBuilder {
+
+ private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH;
+
+ private byte[] buffer;
+ private int cursor = Platform.BYTE_ARRAY_OFFSET;
+
+ public UTF8StringBuilder() {
+ // Since initial buffer size is 16 in `StringBuilder`, we set the same size here
+ this.buffer = new byte[16];
+ }
+
+ // Grows the buffer by at least `neededSize`
+ private void grow(int neededSize) {
+ if (neededSize > ARRAY_MAX - totalSize()) {
+ throw new UnsupportedOperationException(
+ "Cannot grow internal buffer by size " + neededSize + " because the size after growing " +
+ "exceeds size limitation " + ARRAY_MAX);
+ }
+ final int length = totalSize() + neededSize;
+ if (buffer.length < length) {
+ int newLength = length < ARRAY_MAX / 2 ? length * 2 : ARRAY_MAX;
+ final byte[] tmp = new byte[newLength];
+ Platform.copyMemory(
+ buffer,
+ Platform.BYTE_ARRAY_OFFSET,
+ tmp,
+ Platform.BYTE_ARRAY_OFFSET,
+ totalSize());
+ buffer = tmp;
+ }
+ }
+
+ private int totalSize() {
+ return cursor - Platform.BYTE_ARRAY_OFFSET;
+ }
+
+ public void append(UTF8String value) {
+ grow(value.numBytes());
+ value.writeToMemory(buffer, cursor);
+ cursor += value.numBytes();
+ }
+
+ public void append(String value) {
+ append(UTF8String.fromString(value));
+ }
+
+ public UTF8String build() {
+ return UTF8String.fromBytes(buffer, 0, totalSize());
+ }
+}
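A short usage sketch (not part of the patch) for the new builder, exercising the append and build methods defined above:

import org.apache.spark.sql.catalyst.expressions.codegen.UTF8StringBuilder
import org.apache.spark.unsafe.types.UTF8String

val sb = new UTF8StringBuilder()
sb.append("Hello, ")                       // the String overload wraps into a UTF8String
sb.append(UTF8String.fromString("world"))
assert(sb.build() == UTF8String.fromString("Hello, world"))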
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 65040f1af4b04..9a4bf0075a178 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -63,6 +63,7 @@ object ScalaReflection extends ScalaReflection {
private def dataTypeFor(tpe: `Type`): DataType = cleanUpReflectionObjects {
tpe.dealias match {
+ case t if t <:< definitions.NullTpe => NullType
case t if t <:< definitions.IntTpe => IntegerType
case t if t <:< definitions.LongTpe => LongType
case t if t <:< definitions.DoubleTpe => DoubleType
@@ -712,6 +713,9 @@ object ScalaReflection extends ScalaReflection {
/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
def schemaFor(tpe: `Type`): Schema = cleanUpReflectionObjects {
tpe.dealias match {
+      // this must be the first case, since in Scala the Null type conforms to all reference
+      // types; otherwise it would wrongly match a later case, which is Option as of now
+ case t if t <:< definitions.NullTpe => Schema(NullType, nullable = true)
case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) =>
val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
Schema(udt, nullable = true)
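A hedged check of the new Null handling: schemaFor should now resolve the Null type to NullType instead of falling into the Option case.

import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types.NullType

// Before this change, Null would wrongly match the Option case and produce a wrong schema.
val schema = ScalaReflection.schemaFor[Null]
assert(schema.dataType == NullType && schema.nullable)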
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala
index 57f7a80bedc6c..6d587abd8fd4d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala
@@ -31,7 +31,7 @@ class TableAlreadyExistsException(db: String, table: String)
extends AnalysisException(s"Table or view '$table' already exists in database '$db'")
class TempTableAlreadyExistsException(table: String)
- extends AnalysisException(s"Temporary table '$table' already exists")
+ extends AnalysisException(s"Temporary view '$table' already exists")
class PartitionAlreadyExistsException(db: String, table: String, spec: TablePartitionSpec)
extends AnalysisException(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e5c93b5f0e059..35b35110e491f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.analysis
-import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.AnalysisException
@@ -53,6 +52,7 @@ object SimpleAnalyzer extends Analyzer(
/**
* Provides a way to keep state during the analysis, this enables us to decouple the concerns
* of analysis environment from the catalog.
+ * The state that is kept here is per-query.
*
* Note this is thread local.
*
@@ -71,6 +71,8 @@ object AnalysisContext {
}
def get: AnalysisContext = value.get()
+ def reset(): Unit = value.remove()
+
private def set(context: AnalysisContext): Unit = value.set(context)
def withAnalysisContext[A](database: Option[String])(f: => A): A = {
@@ -96,6 +98,17 @@ class Analyzer(
this(catalog, conf, conf.optimizerMaxIterations)
}
+ override def execute(plan: LogicalPlan): LogicalPlan = {
+ AnalysisContext.reset()
+ try {
+ executeSameContext(plan)
+ } finally {
+ AnalysisContext.reset()
+ }
+ }
+
+ private def executeSameContext(plan: LogicalPlan): LogicalPlan = super.execute(plan)
+
def resolver: Resolver = conf.resolver
protected val fixedPoint = FixedPoint(maxIterations)
@@ -151,7 +164,7 @@ class Analyzer(
TimeWindowing ::
ResolveInlineTables(conf) ::
ResolveTimeZone(conf) ::
- TypeCoercion.typeCoercionRules ++
+ TypeCoercion.typeCoercionRules(conf) ++
extendedResolutionRules : _*),
Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
Batch("View", Once,
@@ -165,18 +178,19 @@ class Analyzer(
Batch("Subquery", Once,
UpdateOuterReferences),
Batch("Cleanup", fixedPoint,
- CleanupAliases)
+ CleanupAliases,
+ EliminateBarriers)
)
/**
* Analyze cte definitions and substitute child plan with analyzed cte definitions.
*/
object CTESubstitution extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case With(child, relations) =>
substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) {
case (resolved, (name, relation)) =>
- resolved :+ name -> execute(substituteCTE(relation, resolved))
+ resolved :+ name -> executeSameContext(substituteCTE(relation, resolved))
})
case other => other
}
@@ -200,7 +214,7 @@ class Analyzer(
* Substitute child plan with WindowSpecDefinitions.
*/
object WindowsSubstitution extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
// Lookup WindowSpecDefinitions. This rule works with unresolved children.
case WithWindowDefinition(windowDefinitions, child) =>
child.transform {
@@ -221,28 +235,26 @@ class Analyzer(
*/
object ResolveAliases extends Rule[LogicalPlan] {
private def assignAliases(exprs: Seq[NamedExpression]) = {
- exprs.zipWithIndex.map {
- case (expr, i) =>
- expr.transformUp { case u @ UnresolvedAlias(child, optGenAliasFunc) =>
- child match {
- case ne: NamedExpression => ne
- case go @ GeneratorOuter(g: Generator) if g.resolved => MultiAlias(go, Nil)
- case e if !e.resolved => u
- case g: Generator => MultiAlias(g, Nil)
- case c @ Cast(ne: NamedExpression, _, _) => Alias(c, ne.name)()
- case e: ExtractValue => Alias(e, toPrettySQL(e))()
- case e if optGenAliasFunc.isDefined =>
- Alias(child, optGenAliasFunc.get.apply(e))()
- case e => Alias(e, toPrettySQL(e))()
- }
+ exprs.map(_.transformUp { case u @ UnresolvedAlias(child, optGenAliasFunc) =>
+ child match {
+ case ne: NamedExpression => ne
+ case go @ GeneratorOuter(g: Generator) if g.resolved => MultiAlias(go, Nil)
+ case e if !e.resolved => u
+ case g: Generator => MultiAlias(g, Nil)
+ case c @ Cast(ne: NamedExpression, _, _) => Alias(c, ne.name)()
+ case e: ExtractValue => Alias(e, toPrettySQL(e))()
+ case e if optGenAliasFunc.isDefined =>
+ Alias(child, optGenAliasFunc.get.apply(e))()
+ case e => Alias(e, toPrettySQL(e))()
}
- }.asInstanceOf[Seq[NamedExpression]]
+ }
+ ).asInstanceOf[Seq[NamedExpression]]
}
private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) =
exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined)
- def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) =>
Aggregate(groups, assignAliases(aggs), child)
@@ -602,7 +614,7 @@ class Analyzer(
"avoid errors. Increase the value of spark.sql.view.maxNestedViewDepth to work " +
"aroud this.")
}
- execute(child)
+ executeSameContext(child)
}
view.copy(child = newChild)
case p @ SubqueryAlias(_, view: View) =>
@@ -611,7 +623,7 @@ class Analyzer(
case _ => plan
}
- def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved =>
EliminateSubqueryAliases(lookupTableFromCatalog(u)) match {
case v: View =>
@@ -672,6 +684,12 @@ class Analyzer(
s"between $left and $right")
right.collect {
+ // For `AnalysisBarrier`, recursively de-duplicate its child.
+ case oldVersion: AnalysisBarrier
+ if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
+ val newVersion = dedupRight(left, oldVersion.child)
+ (oldVersion, AnalysisBarrier(newVersion))
+
// Handle base relations that might appear more than once.
case oldVersion: MultiInstanceRelation
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
@@ -692,7 +710,7 @@ class Analyzer(
(oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))
case oldVersion: Generate
- if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty =>
+ if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
val newOutput = oldVersion.generatorOutput.map(_.newInstance())
(oldVersion, oldVersion.copy(generatorOutput = newOutput))
@@ -712,7 +730,7 @@ class Analyzer(
right
case Some((oldRelation, newRelation)) =>
val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output))
- val newRight = right transformUp {
+ right transformUp {
case r if r == oldRelation => newRelation
} transformUp {
case other => other transformExpressions {
@@ -722,7 +740,6 @@ class Analyzer(
s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites))
}
}
- newRight
}
}
@@ -799,7 +816,7 @@ class Analyzer(
case _ => e.mapChildren(resolve(_, q))
}
- def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case p: LogicalPlan if !p.childrenResolved => p
// If the projection list contains Stars, expand it.
@@ -957,7 +974,8 @@ class Analyzer(
protected[sql] def resolveExpression(
expr: Expression,
plan: LogicalPlan,
- throws: Boolean = false) = {
+ throws: Boolean = false): Expression = {
+ if (expr.resolved) return expr
// Resolve expression in one round.
// If throws == false or the desired attribute doesn't exist
// (like try to resolve `a.b` but `a` doesn't exist), fail and return the origin one.
@@ -993,7 +1011,7 @@ class Analyzer(
* have no effect on the results.
*/
object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case p if !p.childrenResolved => p
// Replace the index with the related attribute for ORDER BY,
// which is a 1-base position of the projection list.
@@ -1049,7 +1067,7 @@ class Analyzer(
}}
}
- override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case agg @ Aggregate(groups, aggs, child)
if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) &&
groups.exists(!_.resolved) =>
@@ -1073,102 +1091,79 @@ class Analyzer(
   * The HAVING clause could also use grouping columns that are not present in the SELECT.
*/
object ResolveMissingReferences extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
// Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
+ case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa
case sa @ Sort(_, _, child: Aggregate) => sa
case s @ Sort(order, _, child) if !s.resolved && child.resolved =>
- try {
- val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder])
- val requiredAttrs = AttributeSet(newOrder).filter(_.resolved)
- val missingAttrs = requiredAttrs -- child.outputSet
- if (missingAttrs.nonEmpty) {
- // Add missing attributes and then project them away after the sort.
- Project(child.output,
- Sort(newOrder, s.global, addMissingAttr(child, missingAttrs)))
- } else if (newOrder != order) {
- s.copy(order = newOrder)
- } else {
- s
- }
- } catch {
- // Attempting to resolve it might fail. When this happens, return the original plan.
- // Users will see an AnalysisException for resolution failure of missing attributes
- // in Sort
- case ae: AnalysisException => s
+ val (newOrder, newChild) = resolveExprsAndAddMissingAttrs(order, child)
+ val ordering = newOrder.map(_.asInstanceOf[SortOrder])
+ if (child.output == newChild.output) {
+ s.copy(order = ordering)
+ } else {
+ // Add missing attributes and then project them away.
+ val newSort = s.copy(order = ordering, child = newChild)
+ Project(child.output, newSort)
}
case f @ Filter(cond, child) if !f.resolved && child.resolved =>
- try {
- val newCond = resolveExpressionRecursively(cond, child)
- val requiredAttrs = newCond.references.filter(_.resolved)
- val missingAttrs = requiredAttrs -- child.outputSet
- if (missingAttrs.nonEmpty) {
- // Add missing attributes and then project them away.
- Project(child.output,
- Filter(newCond, addMissingAttr(child, missingAttrs)))
- } else if (newCond != cond) {
- f.copy(condition = newCond)
- } else {
- f
- }
- } catch {
- // Attempting to resolve it might fail. When this happens, return the original plan.
- // Users will see an AnalysisException for resolution failure of missing attributes
- case ae: AnalysisException => f
+ val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(cond), child)
+ if (child.output == newChild.output) {
+ f.copy(condition = newCond.head)
+ } else {
+ // Add missing attributes and then project them away.
+ val newFilter = Filter(newCond.head, newChild)
+ Project(child.output, newFilter)
}
}
- /**
- * Add the missing attributes into projectList of Project/Window or aggregateExpressions of
- * Aggregate.
- */
- private def addMissingAttr(plan: LogicalPlan, missingAttrs: AttributeSet): LogicalPlan = {
- if (missingAttrs.isEmpty) {
- return plan
- }
- plan match {
- case p: Project =>
- val missing = missingAttrs -- p.child.outputSet
- Project(p.projectList ++ missingAttrs, addMissingAttr(p.child, missing))
- case a: Aggregate =>
- // all the missing attributes should be grouping expressions
- // TODO: push down AggregateExpression
- missingAttrs.foreach { attr =>
- if (!a.groupingExpressions.exists(_.semanticEquals(attr))) {
- throw new AnalysisException(s"Can't add $attr to ${a.simpleString}")
- }
- }
- val newAggregateExpressions = a.aggregateExpressions ++ missingAttrs
- a.copy(aggregateExpressions = newAggregateExpressions)
- case g: Generate =>
- // If join is false, we will convert it to true for getting from the child the missing
- // attributes that its child might have or could have.
- val missing = missingAttrs -- g.child.outputSet
- g.copy(join = true, child = addMissingAttr(g.child, missing))
- case d: Distinct =>
- throw new AnalysisException(s"Can't add $missingAttrs to $d")
- case u: UnaryNode =>
- u.withNewChildren(addMissingAttr(u.child, missingAttrs) :: Nil)
- case other =>
- throw new AnalysisException(s"Can't add $missingAttrs to $other")
- }
- }
-
- /**
- * Resolve the expression on a specified logical plan and it's child (recursively), until
- * the expression is resolved or meet a non-unary node or Subquery.
- */
- @tailrec
- private def resolveExpressionRecursively(expr: Expression, plan: LogicalPlan): Expression = {
- val resolved = resolveExpression(expr, plan)
- if (resolved.resolved) {
- resolved
+ private def resolveExprsAndAddMissingAttrs(
+ exprs: Seq[Expression], plan: LogicalPlan): (Seq[Expression], LogicalPlan) = {
+ if (exprs.forall(_.resolved)) {
+ // All given expressions are resolved, no need to continue anymore.
+ (exprs, plan)
} else {
plan match {
- case u: UnaryNode if !u.isInstanceOf[SubqueryAlias] =>
- resolveExpressionRecursively(resolved, u.child)
- case other => resolved
+ // For `AnalysisBarrier`, recursively resolve expressions and add missing attributes via
+ // its child.
+ case barrier: AnalysisBarrier =>
+ val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(exprs, barrier.child)
+ (newExprs, AnalysisBarrier(newChild))
+
+ case p: Project =>
+ val maybeResolvedExprs = exprs.map(resolveExpression(_, p))
+ val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child)
+ val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs)
+ (newExprs, Project(p.projectList ++ missingAttrs, newChild))
+
+ case a @ Aggregate(groupExprs, aggExprs, child) =>
+ val maybeResolvedExprs = exprs.map(resolveExpression(_, a))
+ val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child)
+ val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs)
+ if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) {
+ // All the missing attributes are grouping expressions, valid case.
+ (newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild))
+ } else {
+ // Need to add non-grouping attributes, invalid case.
+ (exprs, a)
+ }
+
+ case g: Generate =>
+ val maybeResolvedExprs = exprs.map(resolveExpression(_, g))
+ val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, g.child)
+ (newExprs, g.copy(unrequiredChildIndex = Nil, child = newChild))
+
+ // For `Distinct` and `SubqueryAlias`, we can't recursively resolve and add attributes
+        // via their children.
+ case u: UnaryNode if !u.isInstanceOf[Distinct] && !u.isInstanceOf[SubqueryAlias] =>
+ val maybeResolvedExprs = exprs.map(resolveExpression(_, u))
+ val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, u.child)
+ (newExprs, u.withNewChildren(Seq(newChild)))
+
+        // For other operators, we can't recursively resolve and add attributes via their children.
+ case other =>
+ (exprs.map(resolveExpression(_, other)), other)
}
}
}
@@ -1197,7 +1192,7 @@ class Analyzer(
* Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s.
*/
object ResolveFunctions extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case q: LogicalPlan =>
q transformExpressions {
case u if !u.childrenResolved => u // Skip until children are resolved.
@@ -1288,7 +1283,7 @@ class Analyzer(
do {
// Try to resolve the subquery plan using the regular analyzer.
previous = current
- current = execute(current)
+ current = executeSameContext(current)
// Use the outer references to resolve the subquery plan if it isn't resolved yet.
val i = plans.iterator
@@ -1334,7 +1329,7 @@ class Analyzer(
/**
   * Resolve and rewrite all subqueries in an operator tree.
*/
- def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
// In case of HAVING (a filter after an aggregate) we use both the aggregate and
// its child for resolution.
case f @ Filter(_, a: Aggregate) if f.childrenResolved =>
@@ -1350,7 +1345,7 @@ class Analyzer(
*/
object ResolveSubqueryColumnAliases extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case u @ UnresolvedSubqueryColumnAliases(columnNames, child) if child.resolved =>
// Resolves output attributes if a query has alias names in its subquery:
// e.g., SELECT * FROM (SELECT 1 AS a, 1 AS b) t(col1, col2)
@@ -1373,7 +1368,7 @@ class Analyzer(
* Turns projections that contain aggregate expressions into aggregations.
*/
object GlobalAggregates extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case Project(projectList, child) if containsAggregates(projectList) =>
Aggregate(Nil, projectList, child)
}
@@ -1399,19 +1394,19 @@ class Analyzer(
* underlying aggregate operator and then projected away after the original operator.
*/
object ResolveAggregateFunctions extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
- case filter @ Filter(havingCondition,
- aggregate @ Aggregate(grouping, originalAggExprs, child))
- if aggregate.resolved =>
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+ case Filter(cond, AnalysisBarrier(agg: Aggregate)) =>
+ apply(Filter(cond, agg)).mapChildren(AnalysisBarrier)
+ case f @ Filter(cond, agg @ Aggregate(grouping, originalAggExprs, child)) if agg.resolved =>
// Try resolving the condition of the filter as though it is in the aggregate clause
try {
val aggregatedCondition =
Aggregate(
grouping,
- Alias(havingCondition, "havingCondition")() :: Nil,
+ Alias(cond, "havingCondition")() :: Nil,
child)
- val resolvedOperator = execute(aggregatedCondition)
+ val resolvedOperator = executeSameContext(aggregatedCondition)
def resolvedAggregateFilter =
resolvedOperator
.asInstanceOf[Aggregate]
@@ -1430,7 +1425,7 @@ class Analyzer(
// Grouping functions are handled in the rule [[ResolveGroupingAnalytics]].
case e: Expression if grouping.exists(_.semanticEquals(e)) &&
!ResolveGroupingAnalytics.hasGroupingFunction(e) &&
- !aggregate.output.exists(_.semanticEquals(e)) =>
+ !agg.output.exists(_.semanticEquals(e)) =>
e match {
case ne: NamedExpression =>
aggregateExpressions += ne
@@ -1444,21 +1439,23 @@ class Analyzer(
// Push the aggregate expressions into the aggregate (if any).
if (aggregateExpressions.nonEmpty) {
- Project(aggregate.output,
+ Project(agg.output,
Filter(transformedAggregateFilter,
- aggregate.copy(aggregateExpressions = originalAggExprs ++ aggregateExpressions)))
+ agg.copy(aggregateExpressions = originalAggExprs ++ aggregateExpressions)))
} else {
- filter
+ f
}
} else {
- filter
+ f
}
} catch {
// Attempting to resolve in the aggregate can result in ambiguity. When this happens,
// just return the original plan.
- case ae: AnalysisException => filter
+ case ae: AnalysisException => f
}
+ case Sort(sortOrder, global, AnalysisBarrier(aggregate: Aggregate)) =>
+ apply(Sort(sortOrder, global, aggregate)).mapChildren(AnalysisBarrier)
case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved =>
// Try resolving the ordering as though it is in the aggregate clause.
@@ -1467,7 +1464,8 @@ class Analyzer(
val aliasedOrdering =
unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")())
val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
- val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate]
+ val resolvedAggregate: Aggregate =
+ executeSameContext(aggregatedOrdering).asInstanceOf[Aggregate]
val resolvedAliasedOrdering: Seq[Alias] =
resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]]
@@ -1571,7 +1569,7 @@ class Analyzer(
}
}
- def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case Project(projectList, _) if projectList.exists(hasNestedGenerator) =>
val nestedGenerator = projectList.find(hasNestedGenerator).get
throw new AnalysisException("Generators are not supported when it's nested in " +
@@ -1595,7 +1593,7 @@ class Analyzer(
resolvedGenerator =
Generate(
generator,
- join = projectList.size > 1, // Only join if there are other expressions in SELECT.
+ unrequiredChildIndex = Nil,
outer = outer,
qualifier = None,
generatorOutput = ResolveGenerate.makeGeneratorOutput(generator, names),
@@ -1629,7 +1627,7 @@ class Analyzer(
* that wrap the [[Generator]].
*/
object ResolveGenerate extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case g: Generate if !g.child.resolved || !g.generator.resolved => g
case g: Generate if !g.resolved =>
g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name)))
@@ -1946,7 +1944,7 @@ class Analyzer(
* put them into an inner Project and finally project them away at the outer Project.
*/
object PullOutNondeterministic extends Rule[LogicalPlan] {
- override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case p if !p.resolved => p // Skip unresolved nodes.
case p: Project => p
case f: Filter => f
@@ -1991,7 +1989,7 @@ class Analyzer(
* and we should return null if the input is null.
*/
object HandleNullInputsForUDF extends Rule[LogicalPlan] {
- override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case p if !p.resolved => p // Skip unresolved nodes.
case p => p transformExpressionsUp {
@@ -2056,7 +2054,7 @@ class Analyzer(
* Then apply a Project on a normal Join to eliminate natural or using join.
*/
object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] {
- override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case j @ Join(left, right, UsingJoin(joinType, usingCols), condition)
if left.resolved && right.resolved && j.duplicateResolved =>
commonNaturalJoinProcessing(left, right, joinType, usingCols, None)
@@ -2121,7 +2119,7 @@ class Analyzer(
* to the given input attributes.
*/
object ResolveDeserializer extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case p if !p.childrenResolved => p
case p if p.resolved => p
@@ -2207,7 +2205,7 @@ class Analyzer(
* constructed is an inner class.
*/
object ResolveNewInstance extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case p if !p.childrenResolved => p
case p if p.resolved => p
@@ -2241,7 +2239,7 @@ class Analyzer(
"type of the field in the target object")
}
- def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case p if !p.childrenResolved => p
case p if p.resolved => p
@@ -2300,7 +2298,7 @@ object CleanupAliases extends Rule[LogicalPlan] {
case other => trimAliases(other)
}
- override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case Project(projectList, child) =>
val cleanedProjectList =
projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
@@ -2329,6 +2327,13 @@ object CleanupAliases extends Rule[LogicalPlan] {
}
}
+/** Removes the [[AnalysisBarrier]] nodes at the end of analysis. */
+object EliminateBarriers extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
+ case AnalysisBarrier(child) => child
+ }
+}
+
/**
* Ignore event time watermark in batch query, which is only supported in Structured Streaming.
* TODO: add this rule into analyzer rule list.
@@ -2379,7 +2384,7 @@ object TimeWindowing extends Rule[LogicalPlan] {
* @return the logical plan that will generate the time windows using the Expand operator, with
* the Filter operator for correctness and Project for usability.
*/
- def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case p: LogicalPlan if p.children.size == 1 =>
val child = p.children.head
val windowExpressions =
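
The execute/executeSameContext split in the Analyzer hunks above boils down to a thread-local,
per-query context that is cleared at the top-level entry point and reused by nested analysis calls
(views, CTEs, subqueries). A stripped-down sketch of that pattern, with invented names rather than
the real Analyzer API:

    object QueryContext {
      // One mutable context per analyzing thread, reset for every top-level query.
      private val value = new ThreadLocal[StringBuilder]() {
        override def initialValue(): StringBuilder = new StringBuilder
      }
      def get: StringBuilder = value.get()
      def reset(): Unit = value.remove()
    }

    class ToyAnalyzer {
      // Top-level entry point: start from a fresh context and always clean it up.
      def execute(plan: String): String = {
        QueryContext.reset()
        try executeSameContext(plan) finally QueryContext.reset()
      }
      // Nested invocations reuse whatever context the current query has already set up.
      private def executeSameContext(plan: String): String = {
        QueryContext.get.append(plan)
        plan
      }
    }
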
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index b5e8bdd79869e..bbcec5627bd49 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -78,8 +78,6 @@ trait CheckAnalysis extends PredicateHelper {
// We transform up and order the rules so as to catch the first possible failure instead
// of the result of cascading resolution failures.
plan.foreachUp {
- case p if p.analyzed => // Skip already analyzed sub-plans
-
case u: UnresolvedRelation =>
u.failAnalysis(s"Table or view not found: ${u.tableIdentifier}")
@@ -353,8 +351,6 @@ trait CheckAnalysis extends PredicateHelper {
case o if !o.resolved => failAnalysis(s"unresolved operator ${o.simpleString}")
case _ =>
}
-
- plan.foreach(_.setAnalyzed())
}
/**
@@ -612,8 +608,8 @@ trait CheckAnalysis extends PredicateHelper {
// allows to have correlation under it
// but must not host any outer references.
// Note:
- // Generator with join=false is treated as Category 4.
- case g: Generate if g.join =>
+ // Generator with requiredChildOutput.isEmpty is treated as Category 4.
+ case g: Generate if g.requiredChildOutput.nonEmpty =>
failOnInvalidOuterReference(g)
// Category 4: Any other operators not in the above 3 categories
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
index fd2ac78b25dbd..a8100b9b24aac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
@@ -58,7 +58,7 @@ import org.apache.spark.sql.types._
* - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE
*/
// scalastyle:on
-object DecimalPrecision extends Rule[LogicalPlan] {
+object DecimalPrecision extends TypeCoercionRule {
import scala.math.{max, min}
private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType
@@ -78,7 +78,7 @@ object DecimalPrecision extends Rule[LogicalPlan] {
PromotePrecision(Cast(e, dataType))
}
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformUp {
// fix decimal precision for expressions
case q => q.transformExpressionsUp(
decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 11538bd31b4fd..5ddb39822617d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -392,6 +392,7 @@ object FunctionRegistry {
expression[ToUnixTimestamp]("to_unix_timestamp"),
expression[ToUTCTimestamp]("to_utc_timestamp"),
expression[TruncDate]("trunc"),
+ expression[TruncTimestamp]("date_trunc"),
expression[UnixTimestamp]("unix_timestamp"),
expression[DayOfWeek]("dayofweek"),
expression[WeekOfYear]("weekofyear"),
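
With TruncTimestamp registered under the name date_trunc, the function becomes callable from SQL;
a hedged usage sketch (assumes a SparkSession named `spark` and that the truncation unit is the
first argument):

    // Truncate a timestamp down to the containing hour.
    spark.sql("SELECT date_trunc('HOUR', '2015-03-05T09:32:05.359')").show()
    // expected value: 2015-03-05 09:00:00
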
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
index 7358f9ee36921..a214e59302cd9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
@@ -103,7 +103,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
})
)
- override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) =>
val resolvedFunc = builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match {
case Some(tvf) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala
index 072dc954879ca..7a0aa08289efa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
import scala.util.control.NonFatal
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, AttributeSet, Cast, CheckOverflow, Expression, ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Multiply, PreciseTimestampConversion, PredicateHelper, Subtract, TimeAdd, TimeSub, UnaryMinus}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._
@@ -238,6 +238,8 @@ object StreamingJoinHelper extends PredicateHelper with Logging {
collect(child, !negate)
case CheckOverflow(child, _) =>
collect(child, negate)
+ case PromotePrecision(child) =>
+ collect(child, negate)
case Cast(child, dataType, _) =>
dataType match {
case _: NumericType | _: TimestampType => collect(child, negate)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala
index 860d20f897690..f9fd0df9e4010 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala
@@ -33,7 +33,7 @@ class SubstituteUnresolvedOrdinals(conf: SQLConf) extends Rule[LogicalPlan] {
case _ => false
}
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) =>
val newOrders = s.order.map {
case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _, _) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 28be955e08a0d..e8669c4637d06 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -22,10 +22,12 @@ import javax.annotation.Nullable
import scala.annotation.tailrec
import scala.collection.mutable
+import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -44,19 +46,19 @@ import org.apache.spark.sql.types._
*/
object TypeCoercion {
- val typeCoercionRules =
- PropagateTypes ::
- InConversion ::
+ def typeCoercionRules(conf: SQLConf): List[Rule[LogicalPlan]] =
+ InConversion ::
WidenSetOperationTypes ::
PromoteStrings ::
DecimalPrecision ::
BooleanEquality ::
FunctionArgumentConversion ::
+ ConcatCoercion(conf) ::
+ EltCoercion(conf) ::
CaseWhenCoercion ::
IfCoercion ::
StackCoercion ::
Division ::
- PropagateTypes ::
ImplicitTypeCasts ::
DateTimeOperations ::
WindowFrameCoercion ::
@@ -220,38 +222,6 @@ object TypeCoercion {
private def haveSameType(exprs: Seq[Expression]): Boolean =
exprs.map(_.dataType).distinct.length == 1
- /**
- * Applies any changes to [[AttributeReference]] data types that are made by other rules to
- * instances higher in the query tree.
- */
- object PropagateTypes extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-
- // No propagation required for leaf nodes.
- case q: LogicalPlan if q.children.isEmpty => q
-
- // Don't propagate types from unresolved children.
- case q: LogicalPlan if !q.childrenResolved => q
-
- case q: LogicalPlan =>
- val inputMap = q.inputSet.toSeq.map(a => (a.exprId, a)).toMap
- q transformExpressions {
- case a: AttributeReference =>
- inputMap.get(a.exprId) match {
- // This can happen when an Attribute reference is born in a non-leaf node, for
- // example due to a call to an external script like in the Transform operator.
- // TODO: Perhaps those should actually be aliases?
- case None => a
- // Leave the same if the dataTypes match.
- case Some(newType) if a.dataType == newType.dataType => a
- case Some(newType) =>
- logDebug(s"Promoting $a to $newType in ${q.simpleString}")
- newType
- }
- }
- }
- }
-
/**
* Widens numeric types and converts strings to numbers when appropriate.
*
@@ -280,9 +250,7 @@ object TypeCoercion {
*/
object WidenSetOperationTypes extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
- case p if p.analyzed => p
-
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case s @ SetOperation(left, right) if s.childrenResolved &&
left.output.length == right.output.length && !s.resolved =>
val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
@@ -345,7 +313,7 @@ object TypeCoercion {
/**
* Promotes strings that appear in arithmetic expressions.
*/
- object PromoteStrings extends Rule[LogicalPlan] {
+ object PromoteStrings extends TypeCoercionRule {
private def castExpr(expr: Expression, targetType: DataType): Expression = {
(expr.dataType, targetType) match {
case (NullType, dt) => Literal.create(null, targetType)
@@ -354,13 +322,16 @@ object TypeCoercion {
}
}
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+ override protected def coerceTypes(
+ plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
      // Skip nodes whose children have not been resolved yet.
case e if !e.childrenResolved => e
- case a @ BinaryArithmetic(left @ StringType(), right) =>
+ case a @ BinaryArithmetic(left @ StringType(), right)
+ if right.dataType != CalendarIntervalType =>
a.makeCopy(Array(Cast(left, DoubleType), right))
- case a @ BinaryArithmetic(left, right @ StringType()) =>
+ case a @ BinaryArithmetic(left, right @ StringType())
+ if left.dataType != CalendarIntervalType =>
a.makeCopy(Array(left, Cast(right, DoubleType)))
// For equality between string and timestamp we cast the string to a timestamp
@@ -403,7 +374,7 @@ object TypeCoercion {
* operator type is found the original expression will be returned and an
* Analysis Exception will be raised at the type checking phase.
*/
- object InConversion extends Rule[LogicalPlan] {
+ object InConversion extends TypeCoercionRule {
private def flattenExpr(expr: Expression): Seq[Expression] = {
expr match {
// Multi columns in IN clause is represented as a CreateNamedStruct.
@@ -413,7 +384,8 @@ object TypeCoercion {
}
}
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+ override protected def coerceTypes(
+ plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
      // Skip nodes whose children have not been resolved yet.
case e if !e.childrenResolved => e
@@ -472,7 +444,7 @@ object TypeCoercion {
private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal.ONE)
private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal.ZERO)
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
      // Skip nodes whose children have not been resolved yet.
case e if !e.childrenResolved => e
@@ -512,8 +484,9 @@ object TypeCoercion {
/**
* This ensure that the types for various functions are as expected.
*/
- object FunctionArgumentConversion extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+ object FunctionArgumentConversion extends TypeCoercionRule {
+ override protected def coerceTypes(
+ plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
      // Skip nodes whose children have not been resolved yet.
case e if !e.childrenResolved => e
@@ -602,8 +575,9 @@ object TypeCoercion {
* Hive only performs integral division with the DIV operator. The arguments to / are always
* converted to fractional types.
*/
- object Division extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+ object Division extends TypeCoercionRule {
+ override protected def coerceTypes(
+ plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
      // Skip nodes that have not been resolved yet,
      // as this is an extra rule which should be applied last.
case e if !e.childrenResolved => e
@@ -624,8 +598,9 @@ object TypeCoercion {
/**
* Coerces the type of different branches of a CASE WHEN statement to a common type.
*/
- object CaseWhenCoercion extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+ object CaseWhenCoercion extends TypeCoercionRule {
+ override protected def coerceTypes(
+ plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual =>
val maybeCommonType = findWiderCommonType(c.valueTypes)
maybeCommonType.map { commonType =>
@@ -654,8 +629,9 @@ object TypeCoercion {
/**
* Coerces the type of different branches of If statement to a common type.
*/
- object IfCoercion extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+ object IfCoercion extends TypeCoercionRule {
+ override protected def coerceTypes(
+ plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case e if !e.childrenResolved => e
// Find tightest common type for If, if the true value and false value have different types.
case i @ If(pred, left, right) if left.dataType != right.dataType =>
@@ -674,8 +650,8 @@ object TypeCoercion {
/**
* Coerces NullTypes in the Stack expression to the column types of the corresponding positions.
*/
- object StackCoercion extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ object StackCoercion extends TypeCoercionRule {
+ override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case s @ Stack(children) if s.childrenResolved && s.hasFoldableNumRows =>
Stack(children.zipWithIndex.map {
// The first child is the number of rows for stack.
@@ -687,6 +663,56 @@ object TypeCoercion {
}
}
+ /**
+ * Coerces the types of [[Concat]] children to expected ones.
+ *
+ * If `spark.sql.function.concatBinaryAsString` is false and all children types are binary,
+ * the expected types are binary. Otherwise, the expected ones are strings.
+ */
+ case class ConcatCoercion(conf: SQLConf) extends TypeCoercionRule {
+
+ override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p =>
+ p transformExpressionsUp {
+ // Skip nodes if unresolved or empty children
+ case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c
+ case c @ Concat(children) if conf.concatBinaryAsString ||
+ !children.map(_.dataType).forall(_ == BinaryType) =>
+ val newChildren = c.children.map { e =>
+ ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e)
+ }
+ c.copy(children = newChildren)
+ }
+ }
+ }
+
+ /**
+ * Coerces the types of [[Elt]] children to expected ones.
+ *
+ * If `spark.sql.function.eltOutputAsString` is false and all children types are binary,
+ * the expected types are binary. Otherwise, the expected ones are strings.
+ */
+ case class EltCoercion(conf: SQLConf) extends TypeCoercionRule {
+
+ override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p =>
+ p transformExpressionsUp {
+ // Skip nodes if unresolved or not enough children
+ case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c
+ case c @ Elt(children) =>
+ val index = children.head
+ val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index)
+ val newInputs = if (conf.eltOutputAsString ||
+ !children.tail.map(_.dataType).forall(_ == BinaryType)) {
+ children.tail.map { e =>
+ ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e)
+ }
+ } else {
+ children.tail
+ }
+ c.copy(children = newIndex +: newInputs)
+ }
+ }
+ }
+
/**
* Turns Add/Subtract of DateType/TimestampType/StringType and CalendarIntervalType
* to TimeAdd/TimeSub
@@ -695,7 +721,7 @@ object TypeCoercion {
private val acceptedTypes = Seq(DateType, TimestampType, StringType)
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
      // Skip nodes whose children have not been resolved yet.
case e if !e.childrenResolved => e
@@ -711,8 +737,9 @@ object TypeCoercion {
/**
* Casts types according to the expected input types for [[Expression]]s.
*/
- object ImplicitTypeCasts extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+ object ImplicitTypeCasts extends TypeCoercionRule {
+ override protected def coerceTypes(
+ plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
      // Skip nodes whose children have not been resolved yet.
case e if !e.childrenResolved => e
@@ -828,8 +855,9 @@ object TypeCoercion {
/**
* Cast WindowFrame boundaries to the type they operate upon.
*/
- object WindowFrameCoercion extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+ object WindowFrameCoercion extends TypeCoercionRule {
+ override protected def coerceTypes(
+ plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case s @ WindowSpecDefinition(_, Seq(order), SpecifiedWindowFrame(RangeFrame, lower, upper))
if order.resolved =>
s.copy(frameSpecification = SpecifiedWindowFrame(
@@ -850,3 +878,46 @@ object TypeCoercion {
}
}
}
+
+trait TypeCoercionRule extends Rule[LogicalPlan] with Logging {
+ /**
+   * Applies the type-coercion transformation defined by `coerceTypes`, then propagates any
+   * resulting [[AttributeReference]] data type changes to instances higher in the query tree.
+ */
+ def apply(plan: LogicalPlan): LogicalPlan = {
+ val newPlan = coerceTypes(plan)
+ if (plan.fastEquals(newPlan)) {
+ plan
+ } else {
+ propagateTypes(newPlan)
+ }
+ }
+
+ protected def coerceTypes(plan: LogicalPlan): LogicalPlan
+
+ private def propagateTypes(plan: LogicalPlan): LogicalPlan = plan transformUp {
+ // No propagation required for leaf nodes.
+ case q: LogicalPlan if q.children.isEmpty => q
+
+ // Don't propagate types from unresolved children.
+ case q: LogicalPlan if !q.childrenResolved => q
+
+ case q: LogicalPlan =>
+ val inputMap = q.inputSet.toSeq.map(a => (a.exprId, a)).toMap
+ q transformExpressions {
+ case a: AttributeReference =>
+ inputMap.get(a.exprId) match {
+ // This can happen when an Attribute reference is born in a non-leaf node, for
+ // example due to a call to an external script like in the Transform operator.
+ // TODO: Perhaps those should actually be aliases?
+ case None => a
+ // Leave the same if the dataTypes match.
+ case Some(newType) if a.dataType == newType.dataType => a
+ case Some(newType) =>
+ logDebug(
+ s"Promoting $a from ${a.dataType} to ${newType.dataType} in ${q.simpleString}")
+ newType
+ }
+ }
+ }
+}
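
With the new trait, each coercion rule only supplies coerceTypes and inherits the single
type-propagation pass that replaces the old PropagateTypes rule. A hedged sketch of a rule written
against this trait (the coercion itself is invented for illustration and is not part of this patch):

    import org.apache.spark.sql.catalyst.expressions.{BinaryArithmetic, Cast}
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.types.{IntegerType, LongType}

    // Hypothetical rule: when one side of an arithmetic expression is INT and the other LONG,
    // cast the INT side up. Only coerceTypes is implemented; the trait's apply() runs it and,
    // if anything changed, propagates the new attribute types up the tree.
    object WidenIntInArithmetic extends TypeCoercionRule {
      override protected def coerceTypes(plan: LogicalPlan): LogicalPlan =
        plan transformAllExpressions {
          // Skip nodes whose children have not been resolved yet.
          case e if !e.childrenResolved => e
          case b @ BinaryArithmetic(left, right)
              if left.dataType == IntegerType && right.dataType == LongType =>
            b.makeCopy(Array(Cast(left, LongType), right))
        }
    }
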
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index 04502d04d9509..b55043c270644 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, MonotonicallyIncreasingID}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, CurrentDate, CurrentTimestamp, MonotonicallyIncreasingID}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans._
@@ -339,6 +339,29 @@ object UnsupportedOperationChecker {
}
}
+ def checkForContinuous(plan: LogicalPlan, outputMode: OutputMode): Unit = {
+ checkForStreaming(plan, outputMode)
+
+ plan.foreachUp { implicit subPlan =>
+ subPlan match {
+ case (_: Project | _: Filter | _: MapElements | _: MapPartitions |
+ _: DeserializeToObject | _: SerializeFromObject) =>
+ case node if node.nodeName == "StreamingRelationV2" =>
+ case node =>
+ throwError(s"Continuous processing does not support ${node.nodeName} operations.")
+ }
+
+ subPlan.expressions.foreach { e =>
+ if (e.collectLeaves().exists {
+ case (_: CurrentTimestamp | _: CurrentDate) => true
+ case _ => false
+ }) {
+ throwError(s"Continuous processing does not support current time operations.")
+ }
+ }
+ }
+ }
+
private def throwErrorIf(
condition: Boolean,
msg: String)(implicit operator: LogicalPlan): Unit = {
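
In practice checkForContinuous means a continuous query may only stack map-like operators on top of
the source; a hedged sketch against the public streaming API (assumes a SparkSession named `spark`,
the rate source, and the continuous trigger this feature targets):

    import org.apache.spark.sql.streaming.Trigger

    val stream = spark.readStream.format("rate").load()

    // Accepted: only Project/Filter-style operators above the streaming relation.
    stream.select("value").where("value % 2 = 0")
      .writeStream
      .format("console")
      .trigger(Trigger.Continuous("1 second"))
      .start()

    // Rejected by checkForContinuous: aggregation is not in the supported list, and
    // expressions such as current_timestamp() would also be refused.
    // stream.groupBy("value").count().writeStream...
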
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala
index a27aa845bf0ae..af1f9165b0044 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala
@@ -38,7 +38,7 @@ case class ResolveTimeZone(conf: SQLConf) extends Rule[LogicalPlan] {
}
override def apply(plan: LogicalPlan): LogicalPlan =
- plan.resolveExpressions(transformTimeZoneExprs)
+ plan.transformAllExpressions(transformTimeZoneExprs)
def resolveTimeZones(e: Expression): Expression = e.transform(transformTimeZoneExprs)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala
index ea46dd7282401..20216087b0158 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.internal.SQLConf
* view resolution, in this way, we are able to get the correct view column ordering and
* omit the extra columns that we don't require);
* 1.2. Else set the child output attributes to `queryOutput`.
- * 2. Map the `queryQutput` to view output by index, if the corresponding attributes don't match,
+ * 2. Map the `queryOutput` to view output by index, if the corresponding attributes don't match,
* try to up cast and alias the attribute in `queryOutput` to the attribute in the view output.
* 3. Add a Project over the child, with the new output generated by the previous steps.
* If the view output doesn't have the same number of columns neither with the child output, nor
@@ -48,7 +48,7 @@ import org.apache.spark.sql.internal.SQLConf
* completely resolved during the batch of Resolution.
*/
case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport {
- override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case v @ View(desc, output, child) if child.resolved && output != child.output =>
val resolver = conf.resolver
val queryColumnNames = desc.viewQueryColumnNames
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 7c100afcd738f..59cb26d5e6c36 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -359,12 +359,12 @@ package object dsl {
def generate(
generator: Generator,
- join: Boolean = false,
+ unrequiredChildIndex: Seq[Int] = Nil,
outer: Boolean = false,
alias: Option[String] = None,
outputNames: Seq[String] = Nil): LogicalPlan =
- Generate(generator, join = join, outer = outer, alias,
- outputNames.map(UnresolvedAttribute(_)), logicalPlan)
+ Generate(generator, unrequiredChildIndex, outer,
+ alias, outputNames.map(UnresolvedAttribute(_)), logicalPlan)
def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan =
InsertIntoTable(
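
Test code built with this dsl now states which child output columns the Generate no longer needs,
instead of passing the old join flag; a hedged sketch (the relation and attribute are made up):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.expressions.Explode
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
    import org.apache.spark.sql.types.IntegerType

    val t = LocalRelation('arr.array(IntegerType))
    // Equivalent of the old `join = false`: the only child column (index 0) is not required
    // in the Generate output, so it can be pruned.
    val exploded = t.generate(Explode('arr), unrequiredChildIndex = Seq(0))
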
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
index 65e497afc12cd..d848ba18356d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
@@ -31,7 +31,7 @@ package org.apache.spark.sql.catalyst.expressions
* - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`.
* - Other comparisons ([[GreaterThan]], [[LessThan]]) are reversed by `hashCode`.
*/
-object Canonicalize extends {
+object Canonicalize {
def execute(e: Expression): Expression = {
expressionReorder(ignoreNamesTypes(e))
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 12baddf1bf7ac..a95ebe301b9d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -181,7 +181,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure(
- s"cannot cast ${child.dataType} to $dataType")
+ s"cannot cast ${child.dataType.simpleString} to ${dataType.simpleString}")
}
}
@@ -206,6 +206,84 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.dateToString(d)))
case TimestampType => buildCast[Long](_,
t => UTF8String.fromString(DateTimeUtils.timestampToString(t, timeZone)))
+ case ArrayType(et, _) =>
+ buildCast[ArrayData](_, array => {
+ val builder = new UTF8StringBuilder
+ builder.append("[")
+ if (array.numElements > 0) {
+ val toUTF8String = castToString(et)
+ if (!array.isNullAt(0)) {
+ builder.append(toUTF8String(array.get(0, et)).asInstanceOf[UTF8String])
+ }
+ var i = 1
+ while (i < array.numElements) {
+ builder.append(",")
+ if (!array.isNullAt(i)) {
+ builder.append(" ")
+ builder.append(toUTF8String(array.get(i, et)).asInstanceOf[UTF8String])
+ }
+ i += 1
+ }
+ }
+ builder.append("]")
+ builder.build()
+ })
+ case MapType(kt, vt, _) =>
+ buildCast[MapData](_, map => {
+ val builder = new UTF8StringBuilder
+ builder.append("[")
+ if (map.numElements > 0) {
+ val keyArray = map.keyArray()
+ val valueArray = map.valueArray()
+ val keyToUTF8String = castToString(kt)
+ val valueToUTF8String = castToString(vt)
+ builder.append(keyToUTF8String(keyArray.get(0, kt)).asInstanceOf[UTF8String])
+ builder.append(" ->")
+ if (!valueArray.isNullAt(0)) {
+ builder.append(" ")
+ builder.append(valueToUTF8String(valueArray.get(0, vt)).asInstanceOf[UTF8String])
+ }
+ var i = 1
+ while (i < map.numElements) {
+ builder.append(", ")
+ builder.append(keyToUTF8String(keyArray.get(i, kt)).asInstanceOf[UTF8String])
+ builder.append(" ->")
+ if (!valueArray.isNullAt(i)) {
+ builder.append(" ")
+ builder.append(valueToUTF8String(valueArray.get(i, vt))
+ .asInstanceOf[UTF8String])
+ }
+ i += 1
+ }
+ }
+ builder.append("]")
+ builder.build()
+ })
+ case StructType(fields) =>
+ buildCast[InternalRow](_, row => {
+ val builder = new UTF8StringBuilder
+ builder.append("[")
+ if (row.numFields > 0) {
+ val st = fields.map(_.dataType)
+ val toUTF8StringFuncs = st.map(castToString)
+ if (!row.isNullAt(0)) {
+ builder.append(toUTF8StringFuncs(0)(row.get(0, st(0))).asInstanceOf[UTF8String])
+ }
+ var i = 1
+ while (i < row.numFields) {
+ builder.append(",")
+ if (!row.isNullAt(i)) {
+ builder.append(" ")
+ builder.append(toUTF8StringFuncs(i)(row.get(i, st(i))).asInstanceOf[UTF8String])
+ }
+ i += 1
+ }
+ }
+ builder.append("]")
+ builder.build()
+ })
+ case udt: UserDefinedType[_] =>
+ buildCast[Any](_, o => UTF8String.fromString(udt.deserialize(o).toString))
case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString))
}
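
Taken together, the interpreted paths above give complex values a readable bracketed rendering when
cast to string; a hedged example of the intended results (assumes a SparkSession named `spark`; the
expected strings follow the append logic above):

    spark.sql("SELECT cast(array(1, 2, 3) AS string)").collect()
    // => [1, 2, 3]
    spark.sql("SELECT cast(map('a', 1, 'b', 2) AS string)").collect()
    // => [a -> 1, b -> 2]
    spark.sql("SELECT cast(named_struct('x', 1, 'y', 'z') AS string)").collect()
    // => [1, z]
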
@@ -548,8 +626,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast))
}
- // three function arguments are: child.primitive, result.primitive and result.isNull
- // it returns the code snippets to be put in null safe evaluation region
+ // The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull`
+  // in the parameter list, because the returned code will be put in the null-safe evaluation region.
private[this] type CastFunction = (String, String, String) => String
private[this] def nullSafeCastFunction(
@@ -584,19 +662,136 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
throw new SparkException(s"Cannot cast $from to $to.")
}
- // Since we need to cast child expressions recursively inside ComplexTypes, such as Map's
+ // Since we need to cast input expressions recursively inside ComplexTypes, such as Map's
// Key and Value, Struct's field, we need to name out all the variable names involved in a cast.
- private[this] def castCode(ctx: CodegenContext, childPrim: String, childNull: String,
- resultPrim: String, resultNull: String, resultType: DataType, cast: CastFunction): String = {
+ private[this] def castCode(ctx: CodegenContext, input: String, inputIsNull: String,
+ result: String, resultIsNull: String, resultType: DataType, cast: CastFunction): String = {
s"""
- boolean $resultNull = $childNull;
- ${ctx.javaType(resultType)} $resultPrim = ${ctx.defaultValue(resultType)};
- if (!$childNull) {
- ${cast(childPrim, resultPrim, resultNull)}
+ boolean $resultIsNull = $inputIsNull;
+ ${ctx.javaType(resultType)} $result = ${ctx.defaultValue(resultType)};
+ if (!$inputIsNull) {
+ ${cast(input, result, resultIsNull)}
}
"""
}
+ private def writeArrayToStringBuilder(
+ et: DataType,
+ array: String,
+ buffer: String,
+ ctx: CodegenContext): String = {
+ val elementToStringCode = castToStringCode(et, ctx)
+ val funcName = ctx.freshName("elementToString")
+ val elementToStringFunc = ctx.addNewFunction(funcName,
+ s"""
+ |private UTF8String $funcName(${ctx.javaType(et)} element) {
+ | UTF8String elementStr = null;
+ | ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)}
+ | return elementStr;
+ |}
+ """.stripMargin)
+
+ val loopIndex = ctx.freshName("loopIndex")
+ s"""
+ |$buffer.append("[");
+ |if ($array.numElements() > 0) {
+ | if (!$array.isNullAt(0)) {
+ | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, "0")}));
+ | }
+ | for (int $loopIndex = 1; $loopIndex < $array.numElements(); $loopIndex++) {
+ | $buffer.append(",");
+ | if (!$array.isNullAt($loopIndex)) {
+ | $buffer.append(" ");
+ | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, loopIndex)}));
+ | }
+ | }
+ |}
+ |$buffer.append("]");
+ """.stripMargin
+ }
+
+ private def writeMapToStringBuilder(
+ kt: DataType,
+ vt: DataType,
+ map: String,
+ buffer: String,
+ ctx: CodegenContext): String = {
+
+ def dataToStringFunc(func: String, dataType: DataType) = {
+ val funcName = ctx.freshName(func)
+ val dataToStringCode = castToStringCode(dataType, ctx)
+ ctx.addNewFunction(funcName,
+ s"""
+ |private UTF8String $funcName(${ctx.javaType(dataType)} data) {
+ | UTF8String dataStr = null;
+ | ${dataToStringCode("data", "dataStr", null /* resultIsNull won't be used */)}
+ | return dataStr;
+ |}
+ """.stripMargin)
+ }
+
+ val keyToStringFunc = dataToStringFunc("keyToString", kt)
+ val valueToStringFunc = dataToStringFunc("valueToString", vt)
+ val loopIndex = ctx.freshName("loopIndex")
+ s"""
+ |$buffer.append("[");
+ |if ($map.numElements() > 0) {
+ | $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, "0")}));
+ | $buffer.append(" ->");
+ | if (!$map.valueArray().isNullAt(0)) {
+ | $buffer.append(" ");
+ | $buffer.append($valueToStringFunc(${ctx.getValue(s"$map.valueArray()", vt, "0")}));
+ | }
+ | for (int $loopIndex = 1; $loopIndex < $map.numElements(); $loopIndex++) {
+ | $buffer.append(", ");
+ | $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, loopIndex)}));
+ | $buffer.append(" ->");
+ | if (!$map.valueArray().isNullAt($loopIndex)) {
+ | $buffer.append(" ");
+ | $buffer.append($valueToStringFunc(
+ | ${ctx.getValue(s"$map.valueArray()", vt, loopIndex)}));
+ | }
+ | }
+ |}
+ |$buffer.append("]");
+ """.stripMargin
+ }
+
+ private def writeStructToStringBuilder(
+ st: Seq[DataType],
+ row: String,
+ buffer: String,
+ ctx: CodegenContext): String = {
+ val structToStringCode = st.zipWithIndex.map { case (ft, i) =>
+ val fieldToStringCode = castToStringCode(ft, ctx)
+ val field = ctx.freshName("field")
+ val fieldStr = ctx.freshName("fieldStr")
+ s"""
+ |${if (i != 0) s"""$buffer.append(",");""" else ""}
+ |if (!$row.isNullAt($i)) {
+ | ${if (i != 0) s"""$buffer.append(" ");""" else ""}
+ |
+ | // Append field $i to the string buffer
+ | ${ctx.javaType(ft)} $field = ${ctx.getValue(row, ft, s"$i")};
+ | UTF8String $fieldStr = null;
+ | ${fieldToStringCode(field, fieldStr, null /* resultIsNull won't be used */)}
+ | $buffer.append($fieldStr);
+ |}
+ """.stripMargin
+ }
+
+ val writeStructCode = ctx.splitExpressions(
+ expressions = structToStringCode,
+ funcName = "fieldToString",
+ arguments = ("InternalRow", row) :: (classOf[UTF8StringBuilder].getName, buffer) :: Nil)
+
+ s"""
+ |$buffer.append("[");
+ |$writeStructCode
+ |$buffer.append("]");
+ """.stripMargin
+ }
+
private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = {
from match {
case BinaryType =>
@@ -605,9 +800,49 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString(
org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c));"""
case TimestampType =>
- val tz = ctx.addReferenceMinorObj(timeZone)
+ val tz = ctx.addReferenceObj("timeZone", timeZone)
(c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString(
org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));"""
+ case ArrayType(et, _) =>
+ (c, evPrim, evNull) => {
+ val buffer = ctx.freshName("buffer")
+ val bufferClass = classOf[UTF8StringBuilder].getName
+ val writeArrayElemCode = writeArrayToStringBuilder(et, c, buffer, ctx)
+ s"""
+ |$bufferClass $buffer = new $bufferClass();
+ |$writeArrayElemCode;
+ |$evPrim = $buffer.build();
+ """.stripMargin
+ }
+ case MapType(kt, vt, _) =>
+ (c, evPrim, evNull) => {
+ val buffer = ctx.freshName("buffer")
+ val bufferClass = classOf[UTF8StringBuilder].getName
+ val writeMapElemCode = writeMapToStringBuilder(kt, vt, c, buffer, ctx)
+ s"""
+ |$bufferClass $buffer = new $bufferClass();
+ |$writeMapElemCode;
+ |$evPrim = $buffer.build();
+ """.stripMargin
+ }
+ case StructType(fields) =>
+ (c, evPrim, evNull) => {
+ val row = ctx.freshName("row")
+ val buffer = ctx.freshName("buffer")
+ val bufferClass = classOf[UTF8StringBuilder].getName
+ val writeStructCode = writeStructToStringBuilder(fields.map(_.dataType), row, buffer, ctx)
+ s"""
+ |InternalRow $row = $c;
+ |$bufferClass $buffer = new $bufferClass();
+ |$writeStructCode
+ |$evPrim = $buffer.build();
+ """.stripMargin
+ }
+ case udt: UserDefinedType[_] =>
+ val udtRef = ctx.addReferenceObj("udt", udt)
+ (c, evPrim, evNull) => {
+ s"$evPrim = UTF8String.fromString($udtRef.deserialize($c).toString());"
+ }
case _ =>
(c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));"
}
@@ -633,7 +868,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
}
"""
case TimestampType =>
- val tz = ctx.addReferenceMinorObj(timeZone)
+ val tz = ctx.addReferenceObj("timeZone", timeZone)
(c, evPrim, evNull) =>
s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToDays($c / 1000L, $tz);"
case _ =>
@@ -713,7 +948,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
from: DataType,
ctx: CodegenContext): CastFunction = from match {
case StringType =>
- val tz = ctx.addReferenceMinorObj(timeZone)
+ val tz = ctx.addReferenceObj("timeZone", timeZone)
val longOpt = ctx.freshName("longOpt")
(c, evPrim, evNull) =>
s"""
@@ -730,7 +965,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
case _: IntegralType =>
(c, evPrim, evNull) => s"$evPrim = ${longToTimeStampCode(c)};"
case DateType =>
- val tz = ctx.addReferenceMinorObj(timeZone)
+ val tz = ctx.addReferenceObj("timeZone", timeZone)
(c, evPrim, evNull) =>
s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.daysToMillis($c, $tz) * 1000;"
case DecimalType() =>
@@ -799,16 +1034,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType =>
- val wrapper = ctx.freshName("wrapper")
- ctx.addMutableState("UTF8String.IntWrapper", wrapper,
- s"$wrapper = new UTF8String.IntWrapper();")
+ val wrapper = ctx.freshName("intWrapper")
(c, evPrim, evNull) =>
s"""
+ UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
if ($c.toByte($wrapper)) {
$evPrim = (byte) $wrapper.value;
} else {
$evNull = true;
}
+ $wrapper = null;
"""
case BooleanType =>
(c, evPrim, evNull) => s"$evPrim = $c ? (byte) 1 : (byte) 0;"
@@ -826,16 +1061,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
from: DataType,
ctx: CodegenContext): CastFunction = from match {
case StringType =>
- val wrapper = ctx.freshName("wrapper")
- ctx.addMutableState("UTF8String.IntWrapper", wrapper,
- s"$wrapper = new UTF8String.IntWrapper();")
+ val wrapper = ctx.freshName("intWrapper")
(c, evPrim, evNull) =>
s"""
+ UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
if ($c.toShort($wrapper)) {
$evPrim = (short) $wrapper.value;
} else {
$evNull = true;
}
+ $wrapper = null;
"""
case BooleanType =>
(c, evPrim, evNull) => s"$evPrim = $c ? (short) 1 : (short) 0;"
@@ -851,16 +1086,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType =>
- val wrapper = ctx.freshName("wrapper")
- ctx.addMutableState("UTF8String.IntWrapper", wrapper,
- s"$wrapper = new UTF8String.IntWrapper();")
+ val wrapper = ctx.freshName("intWrapper")
(c, evPrim, evNull) =>
s"""
+ UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
if ($c.toInt($wrapper)) {
$evPrim = $wrapper.value;
} else {
$evNull = true;
}
+ $wrapper = null;
"""
case BooleanType =>
(c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
@@ -876,17 +1111,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType =>
- val wrapper = ctx.freshName("wrapper")
- ctx.addMutableState("UTF8String.LongWrapper", wrapper,
- s"$wrapper = new UTF8String.LongWrapper();")
+ val wrapper = ctx.freshName("longWrapper")
(c, evPrim, evNull) =>
s"""
+ UTF8String.LongWrapper $wrapper = new UTF8String.LongWrapper();
if ($c.toLong($wrapper)) {
$evPrim = $wrapper.value;
} else {
$evNull = true;
}
+ $wrapper = null;
"""
case BooleanType =>
(c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;"
@@ -1014,8 +1249,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
case (fromField, toField) => nullSafeCastFunction(fromField.dataType, toField.dataType, ctx)
}
val rowClass = classOf[GenericInternalRow].getName
- val result = ctx.freshName("result")
- val tmpRow = ctx.freshName("tmpRow")
+ val tmpResult = ctx.freshName("tmpResult")
+ val tmpInput = ctx.freshName("tmpInput")
val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) =>
val fromFieldPrim = ctx.freshName("ffp")
@@ -1024,37 +1259,33 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
val toFieldNull = ctx.freshName("tfn")
val fromType = ctx.javaType(from.fields(i).dataType)
s"""
- boolean $fromFieldNull = $tmpRow.isNullAt($i);
+ boolean $fromFieldNull = $tmpInput.isNullAt($i);
if ($fromFieldNull) {
- $result.setNullAt($i);
+ $tmpResult.setNullAt($i);
} else {
$fromType $fromFieldPrim =
- ${ctx.getValue(tmpRow, from.fields(i).dataType, i.toString)};
+ ${ctx.getValue(tmpInput, from.fields(i).dataType, i.toString)};
${castCode(ctx, fromFieldPrim,
fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)}
if ($toFieldNull) {
- $result.setNullAt($i);
+ $tmpResult.setNullAt($i);
} else {
- ${ctx.setColumn(result, to.fields(i).dataType, i, toFieldPrim)};
+ ${ctx.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)};
}
}
"""
}
- val fieldsEvalCodes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) {
- ctx.splitExpressions(
- expressions = fieldsEvalCode,
- funcName = "castStruct",
- arguments = ("InternalRow", tmpRow) :: (rowClass, result) :: Nil)
- } else {
- fieldsEvalCode.mkString("\n")
- }
+ val fieldsEvalCodes = ctx.splitExpressions(
+ expressions = fieldsEvalCode,
+ funcName = "castStruct",
+ arguments = ("InternalRow", tmpInput) :: (rowClass, tmpResult) :: Nil)
- (c, evPrim, evNull) =>
+ (input, result, resultIsNull) =>
s"""
- final $rowClass $result = new $rowClass(${fieldsCasts.length});
- final InternalRow $tmpRow = $c;
+ final $rowClass $tmpResult = new $rowClass(${fieldsCasts.length});
+ final InternalRow $tmpInput = $input;
$fieldsEvalCodes
- $evPrim = $result;
+ $result = $tmpResult;
"""
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 743782a6453e9..4568714933095 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -119,8 +119,7 @@ abstract class Expression extends TreeNode[Expression] {
// TODO: support whole stage codegen too
if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") {
- val globalIsNull = ctx.freshName("globalIsNull")
- ctx.addMutableState(ctx.JAVA_BOOLEAN, globalIsNull)
+ val globalIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "globalIsNull")
val localIsNull = eval.isNull
eval.isNull = globalIsNull
s"$globalIsNull = $localIsNull;"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
index 821d784a01342..11fb579dfa88c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
@@ -65,10 +65,9 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val countTerm = ctx.freshName("count")
- val partitionMaskTerm = ctx.freshName("partitionMask")
- ctx.addMutableState(ctx.JAVA_LONG, countTerm)
- ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm)
+ val countTerm = ctx.addMutableState(ctx.JAVA_LONG, "count")
+ val partitionMaskTerm = "partitionMask"
+ ctx.addImmutableStateIfNotExists(ctx.JAVA_LONG, partitionMaskTerm)
ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 179853032035e..388ef42883ad3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -76,11 +76,6 @@ case class ScalaUDF(
}.foreach(println)
*/
-
- // Accessors used in genCode
- def userDefinedFunc(): AnyRef = function
- def getChildren(): Seq[Expression] = children
-
private[this] val f = children.size match {
case 0 =>
val func = function.asInstanceOf[() => Any]
@@ -981,50 +976,19 @@ case class ScalaUDF(
}
// scalastyle:on line.size.limit
-
- // Generate codes used to convert the arguments to Scala type for user-defined functions
- private[this] def genCodeForConverter(ctx: CodegenContext, index: Int): String = {
- val converterClassName = classOf[Any => Any].getName
- val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$"
- val expressionClassName = classOf[Expression].getName
- val scalaUDFClassName = classOf[ScalaUDF].getName
-
- val converterTerm = ctx.freshName("converter")
- val expressionIdx = ctx.references.size - 1
- ctx.addMutableState(converterClassName, converterTerm,
- s"$converterTerm = ($converterClassName)$typeConvertersClassName" +
- s".createToScalaConverter(((${expressionClassName})((($scalaUDFClassName)" +
- s"references[$expressionIdx]).getChildren().apply($index))).dataType());")
- converterTerm
- }
-
override def doGenCode(
ctx: CodegenContext,
ev: ExprCode): ExprCode = {
-
- val scalaUDF = ctx.addReferenceObj("scalaUDF", this)
val converterClassName = classOf[Any => Any].getName
- val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$"
-
- // Generate codes used to convert the returned value of user-defined functions to Catalyst type
- val catalystConverterTerm = ctx.freshName("catalystConverter")
- ctx.addMutableState(converterClassName, catalystConverterTerm,
- s"$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" +
- s".createToCatalystConverter($scalaUDF.dataType());")
+ // The type converters for inputs and the result.
+ val converters: Array[Any => Any] = children.map { c =>
+ CatalystTypeConverters.createToScalaConverter(c.dataType)
+ }.toArray :+ CatalystTypeConverters.createToCatalystConverter(dataType)
+ val convertersTerm = ctx.addReferenceObj("converters", converters, s"$converterClassName[]")
+ val errorMsgTerm = ctx.addReferenceObj("errMsg", udfErrorMessage)
val resultTerm = ctx.freshName("result")
- // This must be called before children expressions' codegen
- // because ctx.references is used in genCodeForConverter
- val converterTerms = children.indices.map(genCodeForConverter(ctx, _))
-
- // Initialize user-defined function
- val funcClassName = s"scala.Function${children.size}"
-
- val funcTerm = ctx.freshName("udf")
- ctx.addMutableState(funcClassName, funcTerm,
- s"$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();")
-
// codegen for children expressions
val evals = children.map(_.genCode(ctx))
@@ -1032,38 +996,42 @@ case class ScalaUDF(
// We need to get the boxedType of dataType's javaType here. Because for the dataType
// such as IntegerType, its javaType is `int` and the returned type of user-defined
// function is Object. Trying to convert an Object to `int` will cause casting exception.
- val evalCode = evals.map(_.code).mkString
- val (converters, funcArguments) = converterTerms.zipWithIndex.map { case (converter, i) =>
- val eval = evals(i)
+ val evalCode = evals.map(_.code).mkString("\n")
+ val (funcArgs, initArgs) = evals.zipWithIndex.map { case (eval, i) =>
val argTerm = ctx.freshName("arg")
- val convert = s"Object $argTerm = ${eval.isNull} ? null : $converter.apply(${eval.value});"
- (convert, argTerm)
+ val convert = s"$convertersTerm[$i].apply(${eval.value})"
+ val initArg = s"Object $argTerm = ${eval.isNull} ? null : $convert;"
+ (argTerm, initArg)
}.unzip
- val getFuncResult = s"$funcTerm.apply(${funcArguments.mkString(", ")})"
+ val udf = ctx.addReferenceObj("udf", function, s"scala.Function${children.length}")
+ val getFuncResult = s"$udf.apply(${funcArgs.mkString(", ")})"
+ val resultConverter = s"$convertersTerm[${children.length}]"
val callFunc =
s"""
- ${ctx.boxedType(dataType)} $resultTerm = null;
- try {
- $resultTerm = (${ctx.boxedType(dataType)})$catalystConverterTerm.apply($getFuncResult);
- } catch (Exception e) {
- throw new org.apache.spark.SparkException($scalaUDF.udfErrorMessage(), e);
- }
- """
-
- ev.copy(code = s"""
- $evalCode
- ${converters.mkString("\n")}
- $callFunc
+ |${ctx.boxedType(dataType)} $resultTerm = null;
+ |try {
+ | $resultTerm = (${ctx.boxedType(dataType)})$resultConverter.apply($getFuncResult);
+ |} catch (Exception e) {
+ | throw new org.apache.spark.SparkException($errorMsgTerm, e);
+ |}
+ """.stripMargin
- boolean ${ev.isNull} = $resultTerm == null;
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
- if (!${ev.isNull}) {
- ${ev.value} = $resultTerm;
- }""")
+ ev.copy(code =
+ s"""
+ |$evalCode
+ |${initArgs.mkString("\n")}
+ |$callFunc
+ |
+ |boolean ${ev.isNull} = $resultTerm == null;
+ |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ |if (!${ev.isNull}) {
+ | ${ev.value} = $resultTerm;
+ |}
+ """.stripMargin)
}
- private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType)
+ private[this] val resultConverter = CatalystTypeConverters.createToCatalystConverter(dataType)
lazy val udfErrorMessage = {
val funcCls = function.getClass.getSimpleName
@@ -1079,6 +1047,6 @@ case class ScalaUDF(
throw new SparkException(udfErrorMessage, e)
}
- converter(result)
+ resultConverter(result)
}
}
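As a hedged sketch of what doGenCode now emits for a one-argument UDF returning int (the reference indices and fresh names are assumptions):

    Object arg_0 = isNull_0 ? null
      : ((scala.Function1[]) references[0] /* converters */)[0].apply(value_0);
    java.lang.Integer result_0 = null;
    try {
      result_0 = (java.lang.Integer) ((scala.Function1[]) references[0] /* converters */)[1]
        .apply(((scala.Function1) references[2] /* udf */).apply(arg_0));
    } catch (Exception e) {
      throw new org.apache.spark.SparkException((java.lang.String) references[1] /* errMsg */, e);
    }
    boolean isNull_1 = result_0 == null;
    int value_1 = -1;
    if (!isNull_1) {
      value_1 = result_0;
    }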
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
index 4fa18d6b3209b..a160b9b275290 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
@@ -43,8 +43,8 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic {
override protected def evalInternal(input: InternalRow): Int = partitionId
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val idTerm = ctx.freshName("partitionId")
- ctx.addMutableState(ctx.JAVA_INT, idTerm)
+ val idTerm = "partitionId"
+ ctx.addImmutableStateIfNotExists(ctx.JAVA_INT, idTerm)
ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false")
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
index 7facb9dad9a76..a45854a3b5146 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
@@ -132,7 +132,7 @@ case class ApproximatePercentile(
case TimestampType => value.asInstanceOf[Long].toDouble
case n: NumericType => n.numeric.toDouble(value.asInstanceOf[n.InternalType])
case other: DataType =>
- throw new UnsupportedOperationException(s"Unexpected data type $other")
+ throw new UnsupportedOperationException(s"Unexpected data type ${other.simpleString}")
}
buffer.add(doubleValue)
}
@@ -157,7 +157,7 @@ case class ApproximatePercentile(
case DoubleType => doubleResult
case _: DecimalType => doubleResult.map(Decimal(_))
case other: DataType =>
- throw new UnsupportedOperationException(s"Unexpected data type $other")
+ throw new UnsupportedOperationException(s"Unexpected data type ${other.simpleString}")
}
if (result.length == 0) {
null
@@ -296,8 +296,8 @@ object ApproximatePercentile {
Ints.BYTES + Doubles.BYTES + Longs.BYTES +
// length of summary.sampled
Ints.BYTES +
- // summary.sampled, Array[Stat(value: Double, g: Int, delta: Int)]
- summaries.sampled.length * (Doubles.BYTES + Ints.BYTES + Ints.BYTES)
+ // summary.sampled, Array[Stat(value: Double, g: Long, delta: Long)]
+ summaries.sampled.length * (Doubles.BYTES + Longs.BYTES + Longs.BYTES)
}
final def serialize(obj: PercentileDigest): Array[Byte] = {
@@ -312,8 +312,8 @@ object ApproximatePercentile {
while (i < summary.sampled.length) {
val stat = summary.sampled(i)
buffer.putDouble(stat.value)
- buffer.putInt(stat.g)
- buffer.putInt(stat.delta)
+ buffer.putLong(stat.g)
+ buffer.putLong(stat.delta)
i += 1
}
buffer.array()
@@ -330,8 +330,8 @@ object ApproximatePercentile {
var i = 0
while (i < sampledLength) {
val value = buffer.getDouble()
- val g = buffer.getInt()
- val delta = buffer.getInt()
+ val g = buffer.getLong()
+ val delta = buffer.getLong()
sampled(i) = Stats(value, g, delta)
i += 1
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index d98f7b3d8efe6..8bb14598a6d7b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -602,23 +602,36 @@ case class Least(children: Seq[Expression]) extends Expression {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evalChildren = children.map(_.genCode(ctx))
- ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
- ctx.addMutableState(ctx.javaType(dataType), ev.value)
- def updateEval(eval: ExprCode): String = {
+ ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
+ val evals = evalChildren.map(eval =>
s"""
- ${eval.code}
- if (!${eval.isNull} && (${ev.isNull} ||
- ${ctx.genGreater(dataType, ev.value, eval.value)})) {
- ${ev.isNull} = false;
- ${ev.value} = ${eval.value};
- }
- """
- }
- val codes = ctx.splitExpressions(evalChildren.map(updateEval))
- ev.copy(code = s"""
- ${ev.isNull} = true;
- ${ev.value} = ${ctx.defaultValue(dataType)};
- $codes""")
+ |${eval.code}
+ |if (!${eval.isNull} && (${ev.isNull} ||
+ | ${ctx.genGreater(dataType, ev.value, eval.value)})) {
+ | ${ev.isNull} = false;
+ | ${ev.value} = ${eval.value};
+ |}
+ """.stripMargin
+ )
+
+ val resultType = ctx.javaType(dataType)
+ val codes = ctx.splitExpressionsWithCurrentInputs(
+ expressions = evals,
+ funcName = "least",
+ extraArguments = Seq(resultType -> ev.value),
+ returnType = resultType,
+ makeSplitFunction = body =>
+ s"""
+ |$body
+ |return ${ev.value};
+ """.stripMargin,
+ foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
+ ev.copy(code =
+ s"""
+ |${ev.isNull} = true;
+ |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ |$codes
+ """.stripMargin)
}
}
@@ -668,22 +681,35 @@ case class Greatest(children: Seq[Expression]) extends Expression {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evalChildren = children.map(_.genCode(ctx))
- ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
- ctx.addMutableState(ctx.javaType(dataType), ev.value)
- def updateEval(eval: ExprCode): String = {
+ ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
+ val evals = evalChildren.map(eval =>
s"""
- ${eval.code}
- if (!${eval.isNull} && (${ev.isNull} ||
- ${ctx.genGreater(dataType, eval.value, ev.value)})) {
- ${ev.isNull} = false;
- ${ev.value} = ${eval.value};
- }
- """
- }
- val codes = ctx.splitExpressions(evalChildren.map(updateEval))
- ev.copy(code = s"""
- ${ev.isNull} = true;
- ${ev.value} = ${ctx.defaultValue(dataType)};
- $codes""")
+ |${eval.code}
+ |if (!${eval.isNull} && (${ev.isNull} ||
+ | ${ctx.genGreater(dataType, eval.value, ev.value)})) {
+ | ${ev.isNull} = false;
+ | ${ev.value} = ${eval.value};
+ |}
+ """.stripMargin
+ )
+
+ val resultType = ctx.javaType(dataType)
+ val codes = ctx.splitExpressionsWithCurrentInputs(
+ expressions = evals,
+ funcName = "greatest",
+ extraArguments = Seq(resultType -> ev.value),
+ returnType = resultType,
+ makeSplitFunction = body =>
+ s"""
+ |$body
+ |return ${ev.value};
+ """.stripMargin,
+ foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
+ ev.copy(code =
+ s"""
+ |${ev.isNull} = true;
+ |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ |$codes
+ """.stripMargin)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 668c816b3fd8d..2c714c228e6c9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -29,7 +29,7 @@ import scala.util.control.NonFatal
import com.google.common.cache.{CacheBuilder, CacheLoader}
import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException}
import org.codehaus.commons.compiler.CompileException
-import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, JaninoRuntimeException, SimpleCompiler}
+import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, InternalCompilerException, SimpleCompiler}
import org.codehaus.janino.util.ClassFile
import org.apache.spark.{SparkEnv, TaskContext, TaskKilledException}
@@ -109,28 +109,14 @@ class CodegenContext {
*
* Returns the code to access it.
*
- * This is for minor objects not to store the object into field but refer it from the references
- * field at the time of use because number of fields in class is limited so we should reduce it.
+ * This does not store the object in a field, but refers to it from the references field at the
+ * time of use, because the number of fields in a class is limited, so we should reduce it.
*/
- def addReferenceMinorObj(obj: Any, className: String = null): String = {
+ def addReferenceObj(objName: String, obj: Any, className: String = null): String = {
val idx = references.length
references += obj
val clsName = Option(className).getOrElse(obj.getClass.getName)
- s"(($clsName) references[$idx])"
- }
-
- /**
- * Add an object to `references`, create a class member to access it.
- *
- * Returns the name of class member.
- */
- def addReferenceObj(name: String, obj: Any, className: String = null): String = {
- val term = freshName(name)
- val idx = references.length
- references += obj
- val clsName = Option(className).getOrElse(obj.getClass.getName)
- addMutableState(clsName, term, s"$term = ($clsName) references[$idx];")
- term
+ s"(($clsName) references[$idx] /* $objName */)"
}
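As a hedged before/after sketch, using the timeZone reference from the Cast changes earlier in this patch (index and names illustrative):

    // before: a dedicated member field plus an init() assignment
    //   private java.util.TimeZone timeZone_0;
    //   timeZone_0 = (java.util.TimeZone) references[0];
    // after: the returned string is an inline access used directly at each call site
    //   ((java.util.TimeZone) references[0] /* timeZone */)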
/**
@@ -142,7 +128,7 @@ class CodegenContext {
* `currentVars` to null, or set `currentVars(i)` to null for certain columns, before calling
* `Expression.genCode`.
*/
- final var INPUT_ROW = "i"
+ var INPUT_ROW = "i"
/**
* Holding a list of generated columns as input of current operator, will be used by
@@ -151,22 +137,83 @@ class CodegenContext {
var currentVars: Seq[ExprCode] = null
/**
- * Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a
- * 3-tuple: java type, variable name, code to init it.
- * As an example, ("int", "count", "count = 0;") will produce code:
+ * Holding expressions' inlined mutable states like `MonotonicallyIncreasingID.count` as a
+ * 2-tuple: java type, variable name.
+ * As an example, ("int", "count") will produce code:
* {{{
* private int count;
* }}}
- * as a member variable, and add
- * {{{
- * count = 0;
- * }}}
- * to the constructor.
+ * as a member variable
*
* They will be kept as member variables in generated classes like `SpecificProjection`.
+ *
+ * Exposed for tests only.
+ */
+ private[catalyst] val inlinedMutableStates: mutable.ArrayBuffer[(String, String)] =
+ mutable.ArrayBuffer.empty[(String, String)]
+
+ /**
+ * The mapping between mutable state types and corresponding compacted arrays.
+ * The keys are java type string. The values are [[MutableStateArrays]] which encapsulates
+ * the compacted arrays for the mutable states with the same java type.
+ *
+ * Exposed for tests only.
+ */
+ private[catalyst] val arrayCompactedMutableStates: mutable.Map[String, MutableStateArrays] =
+ mutable.Map.empty[String, MutableStateArrays]
+
+ // An array holds the code that will initialize each state
+ // Exposed for tests only.
+ private[catalyst] val mutableStateInitCode: mutable.ArrayBuffer[String] =
+ mutable.ArrayBuffer.empty[String]
+
+ // Tracks the names of all the mutable states.
+ private val mutableStateNames: mutable.HashSet[String] = mutable.HashSet.empty
+
+ /**
+ * This class holds the names of the compacted arrays used for the mutable states of a certain
+ * java type, and tracks the next available slot in the current compacted array.
*/
- val mutableStates: mutable.ArrayBuffer[(String, String, String)] =
- mutable.ArrayBuffer.empty[(String, String, String)]
+ class MutableStateArrays {
+ val arrayNames = mutable.ListBuffer.empty[String]
+ createNewArray()
+
+ private[this] var currentIndex = 0
+
+ private def createNewArray() = {
+ val newArrayName = freshName("mutableStateArray")
+ mutableStateNames += newArrayName
+ arrayNames.append(newArrayName)
+ }
+
+ def getCurrentIndex: Int = currentIndex
+
+ /**
+ * Returns a reference to the next available slot in the current compacted array. The size of
+ * each compacted array is controlled by the constant `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`.
+ * Once that threshold is reached, a new compacted array is created.
+ */
+ def getNextSlot(): String = {
+ if (currentIndex < CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT) {
+ val res = s"${arrayNames.last}[$currentIndex]"
+ currentIndex += 1
+ res
+ } else {
+ createNewArray()
+ currentIndex = 1
+ s"${arrayNames.last}[0]"
+ }
+ }
+
+ }
+
+ /**
+ * A map containing the mutable states which have been defined so far using
+ * `addImmutableStateIfNotExists`. Each entry contains the name of the mutable state as key and
+ * its Java type and init code as value.
+ */
+ private val immutableStates: mutable.Map[String, (String, String)] =
+ mutable.Map.empty[String, (String, String)]
/**
* Add a mutable state as a field to the generated class. c.f. the comments above.
@@ -177,11 +224,85 @@ class CodegenContext {
* the list of default imports available.
* Also, generic type arguments are accepted but ignored.
* @param variableName Name of the field.
- * @param initCode The statement(s) to put into the init() method to initialize this field.
+ * @param initFunc Function that produces the statement(s) to put into the init() method to
+ * initialize this field. The argument is the name of the mutable state variable.
* If left blank, the field will be default-initialized.
+ * @param forceInline whether the declaration and initialization code may be inlined rather than
+ * compacted. Set forceInline to `true` in either of the following cases:
+ * 1. the original name of the state must be kept
+ * 2. the state is expected to be generated infrequently
+ * (e.g. not many sort operators in one stage)
+ * @param useFreshName If this is false and the mutable state ends up inlined in the outer
+ * class, the name is not changed
+ * @return the name of the mutable state variable, which is the original name or fresh name if
+ * the variable is inlined to the outer class, or an array access if the variable is to
+ * be stored in an array of variables of the same type.
+ * A variable will be inlined into the outer class when one of the following conditions
+ * is satisfied:
+ * 1. forceInline is true
+ * 2. its type is primitive type and the total number of the inlined mutable variables
+ * is less than `CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD`
+ * 3. its type is multi-dimensional array
+ * When a variable is compacted into an array, the max size of the array for compaction
+ * is given by `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`.
*/
- def addMutableState(javaType: String, variableName: String, initCode: String = ""): Unit = {
- mutableStates += ((javaType, variableName, initCode))
+ def addMutableState(
+ javaType: String,
+ variableName: String,
+ initFunc: String => String = _ => "",
+ forceInline: Boolean = false,
+ useFreshName: Boolean = true): String = {
+
+ // we want to put primitive-type variables in the outer class for performance
+ val canInlinePrimitive = isPrimitiveType(javaType) &&
+ (inlinedMutableStates.length < CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD)
+ if (forceInline || canInlinePrimitive || javaType.contains("[][]")) {
+ val varName = if (useFreshName) freshName(variableName) else variableName
+ val initCode = initFunc(varName)
+ inlinedMutableStates += ((javaType, varName))
+ mutableStateInitCode += initCode
+ mutableStateNames += varName
+ varName
+ } else {
+ val arrays = arrayCompactedMutableStates.getOrElseUpdate(javaType, new MutableStateArrays)
+ val element = arrays.getNextSlot()
+
+ val initCode = initFunc(element)
+ mutableStateInitCode += initCode
+ element
+ }
+ }
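Two usage shapes, taken from the MonotonicallyIncreasingID change earlier in this patch and the GenerateUnsafeProjection change later in it; the concrete returned names are only examples:

    // primitive state, inlined: returns a fresh field name such as "count_0"
    val countTerm = ctx.addMutableState(ctx.JAVA_LONG, "count")
    // object state, compacted: returns an array slot such as "mutableStateArray_0[3]"
    val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter",
      v => s"$v = new $rowWriterClass($bufferHolder, ${inputs.length});")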
+
+ /**
+ * Add an immutable state as a field to the generated class, only if a field with that name does
+ * not exist yet. This helps reduce the number of the generated class' fields, since the same
+ * variable can be reused by many functions.
+ *
+ * Even though the added variables are not declared as final, they should never be reassigned in
+ * the generated code to prevent errors and unexpected behaviors.
+ *
+ * Internally, this method calls `addMutableState`.
+ *
+ * @param javaType Java type of the field.
+ * @param variableName Name of the field.
+ * @param initFunc Function that produces the statement(s) to put into the init() method to
+ * initialize this field. The argument is the name of the mutable state variable.
+ */
+ def addImmutableStateIfNotExists(
+ javaType: String,
+ variableName: String,
+ initFunc: String => String = _ => ""): Unit = {
+ val existingImmutableState = immutableStates.get(variableName)
+ if (existingImmutableState.isEmpty) {
+ addMutableState(javaType, variableName, initFunc, useFreshName = false, forceInline = true)
+ immutableStates(variableName) = (javaType, initFunc(variableName))
+ } else {
+ val (prevJavaType, prevInitCode) = existingImmutableState.get
+ assert(prevJavaType == javaType, s"$variableName has already been defined with type " +
+ s"$prevJavaType and now it is tried to define again with type $javaType.")
+ assert(prevInitCode == initFunc(variableName), s"$variableName has already been defined " +
+ s"with different initialization statements.")
+ }
}
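For example, the SparkPartitionID change earlier in this patch lets every expression in a generated class share one partitionId field:

    val idTerm = "partitionId"
    ctx.addImmutableStateIfNotExists(ctx.JAVA_INT, idTerm)
    ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")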
/**
@@ -190,8 +311,7 @@ class CodegenContext {
* data types like: UTF8String, ArrayData, MapData & InternalRow.
*/
def addBufferedState(dataType: DataType, variableName: String, initCode: String): ExprCode = {
- val value = freshName(variableName)
- addMutableState(javaType(dataType), value, "")
+ val value = addMutableState(javaType(dataType), variableName)
val code = dataType match {
case StringType => s"$value = $initCode.clone();"
case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();"
@@ -203,15 +323,37 @@ class CodegenContext {
def declareMutableStates(): String = {
// It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in
// `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones.
- mutableStates.distinct.map { case (javaType, variableName, _) =>
+ val inlinedStates = inlinedMutableStates.distinct.map { case (javaType, variableName) =>
s"private $javaType $variableName;"
- }.mkString("\n")
+ }
+
+ val arrayStates = arrayCompactedMutableStates.flatMap { case (javaType, mutableStateArrays) =>
+ val numArrays = mutableStateArrays.arrayNames.size
+ mutableStateArrays.arrayNames.zipWithIndex.map { case (arrayName, index) =>
+ val length = if (index + 1 == numArrays) {
+ mutableStateArrays.getCurrentIndex
+ } else {
+ CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT
+ }
+ if (javaType.contains("[]")) {
+ // initializer had a one-dimensional array variable
+ val baseType = javaType.substring(0, javaType.length - 2)
+ s"private $javaType[] $arrayName = new $baseType[$length][];"
+ } else {
+ // initializer had a scalar variable
+ s"private $javaType[] $arrayName = new $javaType[$length];"
+ }
+ }
+ }
+
+ (inlinedStates ++ arrayStates).mkString("\n")
}
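As a hedged sketch of the declarations this can now produce when some states are inlined and others compacted (names, types and lengths illustrative):

    private int count_0;                                                 // inlined primitive state
    private Object[] mutableStateArray_0 = new Object[7];               // compacted Object states
    private InternalRow[][] mutableStateArray_1 = new InternalRow[4][]; // compacted array-typed states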
def initMutableStates(): String = {
// It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in
// `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones.
- val initCodes = mutableStates.distinct.map(_._3 + "\n")
+ val initCodes = mutableStateInitCode.distinct.map(_ + "\n")
+
// The generated initialization code may exceed 64kb function size limit in JVM if there are too
// many mutable states, so split it into multiple functions.
splitExpressions(expressions = initCodes, funcName = "init", arguments = Nil)
@@ -781,23 +923,45 @@ class CodegenContext {
* beyond 1000kb, we declare a private, inner sub-class, and the function is inlined to it
* instead, because classes have a constant pool limit of 65,536 named values.
*
- * Note that we will extract the current inputs of this context and pass them to the generated
- * functions. The input is `INPUT_ROW` for normal codegen path, and `currentVars` for whole
- * stage codegen path. Whole stage codegen path is not supported yet.
+ * Note that, unlike `splitExpressions`, this method extracts the current inputs of this
+ * context and passes them to the generated functions. The input is `INPUT_ROW` for the normal
+ * codegen path, and `currentVars` for the whole stage codegen path. The whole stage codegen
+ * path is not supported yet.
*
* @param expressions the codes to evaluate expressions.
+ * @param funcName the split function name base.
+ * @param extraArguments the list of (type, name) of the arguments of the split function,
+ * except for the current inputs like `ctx.INPUT_ROW`.
+ * @param returnType the return type of the split function.
+ * @param makeSplitFunction makes split function body, e.g. add preparation or cleanup.
+ * @param foldFunctions folds the split function calls.
*/
- def splitExpressions(expressions: Seq[String]): String = {
+ def splitExpressionsWithCurrentInputs(
+ expressions: Seq[String],
+ funcName: String = "apply",
+ extraArguments: Seq[(String, String)] = Nil,
+ returnType: String = "void",
+ makeSplitFunction: String => String = identity,
+ foldFunctions: Seq[String] => String = _.mkString("", ";\n", ";")): String = {
// TODO: support whole stage codegen
if (INPUT_ROW == null || currentVars != null) {
- return expressions.mkString("\n")
+ expressions.mkString("\n")
+ } else {
+ splitExpressions(
+ expressions,
+ funcName,
+ ("InternalRow", INPUT_ROW) +: extraArguments,
+ returnType,
+ makeSplitFunction,
+ foldFunctions)
}
- splitExpressions(expressions, funcName = "apply", arguments = ("InternalRow", INPUT_ROW) :: Nil)
}
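A usage sketch mirroring the Least/Greatest changes earlier in this patch: the running result is threaded through each split function as an extra argument, and the calls are folded back into assignments:

    val codes = ctx.splitExpressionsWithCurrentInputs(
      expressions = evals,
      funcName = "least",
      extraArguments = Seq(resultType -> ev.value),
      returnType = resultType,
      makeSplitFunction = body => s"$body\nreturn ${ev.value};",
      foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))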
/**
* Splits the generated code of expressions into multiple functions, because function has
- * 64kb code size limit in JVM
+ * 64kb code size limit in JVM. If the class to which the function would be inlined would grow
+ * beyond 1000kb, we declare a private, inner sub-class, and the function is inlined to it
+ * instead, because classes have a constant pool limit of 65,536 named values.
*
* @param expressions the codes to evaluate expressions.
* @param funcName the split function name base.
@@ -819,6 +983,15 @@ class CodegenContext {
// inline execution if only one block
blocks.head
} else {
+ if (Utils.isTesting) {
+ // Passing global variables to the split method is dangerous, as any mutation of them inside
+ // the split method is ignored and may lead to unexpected behavior.
+ arguments.foreach { case (_, name) =>
+ assert(!mutableStateNames.contains(name),
+ s"split function argument $name cannot be a global variable.")
+ }
+ }
+
val func = freshName(funcName)
val argString = arguments.map { case (t, name) => s"$t $name" }.mkString(", ")
val functions = blocks.zipWithIndex.map { case (body, i) =>
@@ -854,7 +1027,7 @@ class CodegenContext {
*
* @param expressions the codes to evaluate expressions.
*/
- def buildCodeBlocks(expressions: Seq[String]): Seq[String] = {
+ private def buildCodeBlocks(expressions: Seq[String]): Seq[String] = {
val blocks = new ArrayBuffer[String]()
val blockBuilder = new StringBuilder()
var length = 0
@@ -1003,9 +1176,9 @@ class CodegenContext {
val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
commonExprs.foreach { e =>
val expr = e.head
- val fnName = freshName("evalExpr")
- val isNull = s"${fnName}IsNull"
- val value = s"${fnName}Value"
+ val fnName = freshName("subExpr")
+ val isNull = addMutableState(JAVA_BOOLEAN, "subExprIsNull")
+ val value = addMutableState(javaType(expr.dataType), "subExprValue")
// Generate the code for this expression tree and wrap it in a function.
val eval = expr.genCode(this)
@@ -1031,9 +1204,6 @@ class CodegenContext {
// 2. Less code.
// Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with
// at least two nodes) as the cost of doing it is expected to be low.
- addMutableState(JAVA_BOOLEAN, isNull, s"$isNull = false;")
- addMutableState(javaType(expr.dataType), value,
- s"$value = ${defaultValue(expr.dataType)};")
subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);"
val state = SubExprEliminationState(isNull, value)
@@ -1157,6 +1327,15 @@ object CodeGenerator extends Logging {
// class.
val GENERATED_CLASS_SIZE_THRESHOLD = 1000000
+ // This is the threshold for the number of global variables whose type is a primitive type or a
+ // complex type (e.g. a multi-dimensional array) that will be placed in the outer class
+ val OUTER_CLASS_VARIABLES_THRESHOLD = 10000
+
+ // This is the maximum number of array elements to keep global variables in one Java array
+ // 32767 is the maximum integer value that does not require a constant pool entry in a Java
+ // bytecode instruction
+ val MUTABLESTATEARRAY_SIZE_LIMIT = 32768
+
/**
* Compile the Java source code into a Java class, using Janino.
*
@@ -1218,12 +1397,12 @@ object CodeGenerator extends Logging {
evaluator.cook("generated.java", code.body)
updateAndGetCompilationStats(evaluator)
} catch {
- case e: JaninoRuntimeException =>
+ case e: InternalCompilerException =>
val msg = s"failed to compile: $e"
logError(msg, e)
val maxLines = SQLConf.get.loggingMaxLinesForCodegen
logInfo(s"\n${CodeFormatter.format(code, maxLines)}")
- throw new JaninoRuntimeException(msg, e)
+ throw new InternalCompilerException(msg, e)
case e: CompileException =>
val msg = s"failed to compile: $e"
logError(msg, e)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 5fdbda51b4ad1..b53c0087e7e2d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -57,42 +57,38 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
case _ => true
}.unzip
val exprVals = ctx.generateExpressions(validExpr, useSubexprElimination)
- val projectionCodes = exprVals.zip(index).map {
+
+ // 4-tuples: (code for projection, isNull variable name, value variable name, column index)
+ val projectionCodes: Seq[(String, String, String, Int)] = exprVals.zip(index).map {
case (ev, i) =>
val e = expressions(i)
+ val value = ctx.addMutableState(ctx.javaType(e.dataType), "value")
if (e.nullable) {
- val isNull = s"isNull_$i"
- val value = s"value_$i"
- ctx.addMutableState(ctx.JAVA_BOOLEAN, isNull, s"$isNull = true;")
- ctx.addMutableState(ctx.javaType(e.dataType), value,
- s"$value = ${ctx.defaultValue(e.dataType)};")
- s"""
- ${ev.code}
- $isNull = ${ev.isNull};
- $value = ${ev.value};
- """
+ val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "isNull")
+ (s"""
+ |${ev.code}
+ |$isNull = ${ev.isNull};
+ |$value = ${ev.value};
+ """.stripMargin, isNull, value, i)
} else {
- val value = s"value_$i"
- ctx.addMutableState(ctx.javaType(e.dataType), value,
- s"$value = ${ctx.defaultValue(e.dataType)};")
- s"""
- ${ev.code}
- $value = ${ev.value};
- """
+ (s"""
+ |${ev.code}
+ |$value = ${ev.value};
+ """.stripMargin, ev.isNull, value, i)
}
}
// Evaluate all the subexpressions.
val evalSubexpr = ctx.subexprFunctions.mkString("\n")
- val updates = validExpr.zip(index).map {
- case (e, i) =>
- val ev = ExprCode("", s"isNull_$i", s"value_$i")
+ val updates = validExpr.zip(projectionCodes).map {
+ case (e, (_, isNull, value, i)) =>
+ val ev = ExprCode("", isNull, value)
ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable)
}
- val allProjections = ctx.splitExpressions(projectionCodes)
- val allUpdates = ctx.splitExpressions(updates)
+ val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._1))
+ val allUpdates = ctx.splitExpressionsWithCurrentInputs(updates)
val codeBody = s"""
public java.lang.Object generate(Object[] references) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index 5d35cce1a91cb..3dcbb518ba42a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -49,8 +49,6 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val tmpInput = ctx.freshName("tmpInput")
val output = ctx.freshName("safeRow")
val values = ctx.freshName("values")
- // These expressions could be split into multiple functions
- ctx.addMutableState("Object[]", values, s"$values = null;")
val rowClass = classOf[GenericInternalRow].getName
@@ -66,15 +64,15 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val allFields = ctx.splitExpressions(
expressions = fieldWriters,
funcName = "writeFields",
- arguments = Seq("InternalRow" -> tmpInput)
+ arguments = Seq("InternalRow" -> tmpInput, "Object[]" -> values)
)
- val code = s"""
- final InternalRow $tmpInput = $input;
- $values = new Object[${schema.length}];
- $allFields
- final InternalRow $output = new $rowClass($values);
- $values = null;
- """
+ val code =
+ s"""
+ |final InternalRow $tmpInput = $input;
+ |final Object[] $values = new Object[${schema.length}];
+ |$allFields
+ |final InternalRow $output = new $rowClass($values);
+ """.stripMargin
ExprCode(code, "false", output)
}
@@ -159,7 +157,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
}
"""
}
- val allExpressions = ctx.splitExpressions(expressionCodes)
+ val allExpressions = ctx.splitExpressionsWithCurrentInputs(expressionCodes)
val codeBody = s"""
public java.lang.Object generate(Object[] references) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index b022457865d50..36ffa8dcdd2b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -73,9 +73,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
bufferHolder: String,
isTopLevel: Boolean = false): String = {
val rowWriterClass = classOf[UnsafeRowWriter].getName
- val rowWriter = ctx.freshName("rowWriter")
- ctx.addMutableState(rowWriterClass, rowWriter,
- s"$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});")
+ val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter",
+ v => s"$v = new $rowWriterClass($bufferHolder, ${inputs.length});")
val resetWriter = if (isTopLevel) {
// For top level row writer, it always writes to the beginning of the global buffer holder,
@@ -186,9 +185,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
// Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
val tmpInput = ctx.freshName("tmpInput")
val arrayWriterClass = classOf[UnsafeArrayWriter].getName
- val arrayWriter = ctx.freshName("arrayWriter")
- ctx.addMutableState(arrayWriterClass, arrayWriter,
- s"$arrayWriter = new $arrayWriterClass();")
+ val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter",
+ v => s"$v = new $arrayWriterClass();")
val numElements = ctx.freshName("numElements")
val index = ctx.freshName("index")
@@ -318,13 +316,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case _ => true
}
- val result = ctx.freshName("result")
- ctx.addMutableState("UnsafeRow", result, s"$result = new UnsafeRow(${expressions.length});")
+ val result = ctx.addMutableState("UnsafeRow", "result",
+ v => s"$v = new UnsafeRow(${expressions.length});")
- val holder = ctx.freshName("holder")
val holderClass = classOf[BufferHolder].getName
- ctx.addMutableState(holderClass, holder,
- s"$holder = new $holderClass($result, ${numVarLenFields * 32});")
+ val holder = ctx.addMutableState(holderClass, "holder",
+ v => s"$v = new $holderClass($result, ${numVarLenFields * 32});")
val resetBufferHolder = if (numVarLenFields == 0) {
""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
index be5f5a73b5d47..febf7b0c96c2a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
@@ -70,7 +70,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
// --------------------- copy bitset from row 1 and row 2 --------------------------- //
val copyBitset = Seq.tabulate(outputBitsetWords) { i =>
- val bits = if (bitset1Remainder > 0) {
+ val bits = if (bitset1Remainder > 0 && bitset2Words != 0) {
if (i < bitset1Words - 1) {
s"$getLong(obj1, offset1 + ${i * 8})"
} else if (i == bitset1Words - 1) {
@@ -152,7 +152,9 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
} else {
// Number of bytes to increase for the offset. Note that since in UnsafeRow we store the
// offset in the upper 32 bit of the words, we can just shift the offset to the left by
- // 32 and increment that amount in place.
+ // 32 and increment that amount in place. However, we need to handle the important special
+ // case of a null field, in which case the offset should be zero and should not have a
+ // shift added to it.
val shift =
if (i < schema1.size) {
s"${(outputBitsetWords - bitset1Words + schema2.size) * 8}L"
@@ -160,14 +162,55 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
s"(${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1)"
}
val cursor = offset + outputBitsetWords * 8 + i * 8
- s"$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 32));\n"
+ // UnsafeRow is a little underspecified, so in what follows we'll treat UnsafeRowWriter's
+ // output as a de-facto specification for the internal layout of data.
+ //
+ // Null-valued fields will always have a data offset of 0 because
+ // UnsafeRowWriter.setNullAt(ordinal) sets the null bit and stores 0 to in field's
+ // position in the fixed-length section of the row. As a result, we must NOT add
+ // `shift` to the offset for null fields.
+ //
+ // We could perform a null-check here by inspecting the null-tracking bitmap, but doing
+ // so could be expensive and will add significant bloat to the generated code. Instead,
+ // we'll rely on the invariant "stored offset == 0 for variable-length data type implies
+ // that the field's value is null."
+ //
+ // To establish that this invariant holds, we'll prove that a non-null field can never
+ // have a stored offset of 0. There are two cases to consider:
+ //
+ // 1. The non-null field's data is of non-zero length: reading this field's value
+ // must read data from the variable-length section of the row, so the stored offset
+ // will actually be used in address calculation and must be correct. The offsets
+ // count bytes from the start of the UnsafeRow so these offsets will always be
+ // non-zero because the storage of the offsets themselves takes up space at the
+ // start of the row.
+ // 2. The non-null field's data is of zero length (i.e. its data is empty). In this
+ // case, we have to worry about the possibility that an arbitrary offset value was
+ // stored because we never actually read any bytes using this offset and therefore
+ // would not crash if it was incorrect. The variable-sized data writing paths in
+ // UnsafeRowWriter unconditionally call setOffsetAndSize(ordinal, numBytes) with
+ // no special handling for the case where `numBytes == 0`. Internally,
+ // setOffsetAndSize computes the offset without taking the size into account. Thus
+ // the stored offset is the same non-zero offset that would be used if the field's
+ // dataSize was non-zero (and in (1) above we've shown that case behaves as we
+ // expect).
+ //
+ // Thus it is safe to perform `existingOffset != 0` checks here in the place of
+ // more expensive null-bit checks.
+ s"""
+ |existingOffset = $getLong(buf, $cursor);
+ |if (existingOffset != 0) {
+ | $putLong(buf, $cursor, existingOffset + ($shift << 32));
+ |}
+ """.stripMargin
}
}
val updateOffsets = ctx.splitExpressions(
expressions = updateOffset,
funcName = "copyBitsetFunc",
- arguments = ("long", "numBytesVariableRow1") :: Nil)
+ arguments = ("long", "numBytesVariableRow1") :: Nil,
+ makeSplitFunction = (s: String) => "long existingOffset;\n" + s)
// ------------------------ Finally, put everything together --------------------------- //
val codeBody = s"""
@@ -200,6 +243,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
| $copyFixedLengthRow2
| $copyVariableLengthRow1
| $copyVariableLengthRow2
+ | long existingOffset;
| $updateOffsets
|
| out.pointTo(buf, sizeInBytes);
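To make the invariant concrete, a small hedged sketch of the fixed-length word written for a variable-length field (offset in the upper 32 bits, size in the lower 32, per the UnsafeRowWriter behavior described above):

    val offsetAndSize: Long = (offset.toLong << 32) | numBytes.toLong
    // null field: setNullAt stores 0L, so the word is 0 and no shift is added
    // non-null field (even zero-length): offset is non-zero, so the word is non-zero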
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 57a7f2e207738..3dc2ee03a86e3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -63,7 +63,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
val (preprocess, assigns, postprocess, arrayData) =
GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false)
ev.copy(
- code = preprocess + ctx.splitExpressions(assigns) + postprocess,
+ code = preprocess + assigns + postprocess,
value = arrayData,
isNull = "false")
}
@@ -77,24 +77,22 @@ private [sql] object GenArrayData {
*
* @param ctx a [[CodegenContext]]
* @param elementType data type of underlying array elements
- * @param elementsCode a set of [[ExprCode]] for each element of an underlying array
+ * @param elementsCode concatenated set of [[ExprCode]] for each element of an underlying array
* @param isMapKey if true, throw an exception when the element is null
- * @return (code pre-assignments, assignments to each array elements, code post-assignments,
- * arrayData name)
+ * @return (code pre-assignments, concatenated assignments to each array elements,
+ * code post-assignments, arrayData name)
*/
def genCodeToCreateArrayData(
ctx: CodegenContext,
elementType: DataType,
elementsCode: Seq[ExprCode],
- isMapKey: Boolean): (String, Seq[String], String, String) = {
- val arrayName = ctx.freshName("array")
+ isMapKey: Boolean): (String, String, String, String) = {
val arrayDataName = ctx.freshName("arrayData")
val numElements = elementsCode.length
if (!ctx.isPrimitiveType(elementType)) {
+ val arrayName = ctx.freshName("arrayObject")
val genericArrayClass = classOf[GenericArrayData].getName
- ctx.addMutableState("Object[]", arrayName,
- s"$arrayName = new Object[$numElements];")
val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
val isNullAssignment = if (!isMapKey) {
@@ -110,17 +108,21 @@ private [sql] object GenArrayData {
}
"""
}
+ val assignmentString = ctx.splitExpressionsWithCurrentInputs(
+ expressions = assignments,
+ funcName = "apply",
+ extraArguments = ("Object[]", arrayDataName) :: Nil)
- ("",
- assignments,
+ (s"Object[] $arrayName = new Object[$numElements];",
+ assignmentString,
s"final ArrayData $arrayDataName = new $genericArrayClass($arrayName);",
arrayDataName)
} else {
+ val arrayName = ctx.freshName("array")
val unsafeArraySizeInBytes =
UnsafeArrayData.calculateHeaderPortionInBytes(numElements) +
ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements)
val baseOffset = Platform.BYTE_ARRAY_OFFSET
- ctx.addMutableState("UnsafeArrayData", arrayDataName)
val primitiveValueTypeName = ctx.primitiveTypeName(elementType)
val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
@@ -137,14 +139,18 @@ private [sql] object GenArrayData {
}
"""
}
+ val assignmentString = ctx.splitExpressionsWithCurrentInputs(
+ expressions = assignments,
+ funcName = "apply",
+ extraArguments = ("UnsafeArrayData", arrayDataName) :: Nil)
(s"""
byte[] $arrayName = new byte[$unsafeArraySizeInBytes];
- $arrayDataName = new UnsafeArrayData();
+ UnsafeArrayData $arrayDataName = new UnsafeArrayData();
Platform.putLong($arrayName, $baseOffset, $numElements);
$arrayDataName.pointTo($arrayName, $baseOffset, $unsafeArraySizeInBytes);
""",
- assignments,
+ assignmentString,
"",
arrayDataName)
}
@@ -216,10 +222,10 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
s"""
final boolean ${ev.isNull} = false;
$preprocessKeyData
- ${ctx.splitExpressions(assignKeys)}
+ $assignKeys
$postprocessKeyData
$preprocessValueData
- ${ctx.splitExpressions(assignValues)}
+ $assignValues
$postprocessValueData
final MapData ${ev.value} = new $mapClass($keyArrayData, $valueArrayData);
"""
@@ -350,22 +356,25 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val rowClass = classOf[GenericInternalRow].getName
val values = ctx.freshName("values")
- ctx.addMutableState("Object[]", values, s"$values = null;")
- val valuesCode = ctx.splitExpressions(
- valExprs.zipWithIndex.map { case (e, i) =>
- val eval = e.genCode(ctx)
- s"""
- ${eval.code}
- if (${eval.isNull}) {
- $values[$i] = null;
- } else {
- $values[$i] = ${eval.value};
- }"""
- })
+ val valCodes = valExprs.zipWithIndex.map { case (e, i) =>
+ val eval = e.genCode(ctx)
+ s"""
+ |${eval.code}
+ |if (${eval.isNull}) {
+ | $values[$i] = null;
+ |} else {
+ | $values[$i] = ${eval.value};
+ |}
+ """.stripMargin
+ }
+ val valuesCode = ctx.splitExpressionsWithCurrentInputs(
+ expressions = valCodes,
+ funcName = "createNamedStruct",
+ extraArguments = "Object[]" -> values :: Nil)
ev.copy(code =
s"""
- |$values = new Object[${valExprs.size}];
+ |Object[] $values = new Object[${valExprs.size}];
|$valuesCode
|final InternalRow ${ev.value} = new $rowClass($values);
|$values = null;
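
As a rough usage sketch (not part of the patch), the kind of query that exercises the splitting added above is a CreateNamedStruct or CreateArray with many children, which previously produced one oversized generated method. The snippet assumes a local SparkSession; object and column names are illustrative.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object WideStructSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()
    // A struct with many fields: each field assignment becomes one generated
    // statement, which the splitting above groups into helper methods that
    // receive the backing Object[] as an argument.
    val wideStruct = struct((0 until 300).map(i => lit(i).as(s"c$i")): _*)
    spark.range(1).select(wideStruct.as("s")).show(truncate = false)
    spark.stop()
  }
}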
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 43e643178c899..b444c3a7be92a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -40,7 +40,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
override def checkInputDataTypes(): TypeCheckResult = {
if (predicate.dataType != BooleanType) {
TypeCheckResult.TypeCheckFailure(
- s"type of predicate expression in If should be boolean, not ${predicate.dataType}")
+ "type of predicate expression in If should be boolean, " +
+ s"not ${predicate.dataType.simpleString}")
} else if (!trueValue.dataType.sameType(falseValue.dataType)) {
TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " +
s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).")
@@ -180,13 +181,17 @@ case class CaseWhen(
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- // This variable represents whether the first successful condition is met or not.
- // It is initialized to `false` and it is set to `true` when the first condition which
- // evaluates to `true` is met and therefore is not needed to go on anymore on the computation
- // of the following conditions.
- val conditionMet = ctx.freshName("caseWhenConditionMet")
- ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
- ctx.addMutableState(ctx.javaType(dataType), ev.value)
+ // This variable holds the state of the result:
+ // -1 means the condition is not met yet and the result is unknown.
+ val NOT_MATCHED = -1
+ // 0 means the condition is met and result is not null.
+ val HAS_NONNULL = 0
+ // 1 means the condition is met and result is null.
+ val HAS_NULL = 1
+ // It is initialized to `NOT_MATCHED`, and once it is set to `HAS_NULL` or `HAS_NONNULL`,
+ // we stop evaluating the remaining conditions.
+ val resultState = ctx.freshName("caseWhenResultState")
+ ev.value = ctx.addMutableState(ctx.javaType(dataType), ev.value)
// these blocks are meant to be inside a
// do {
@@ -200,9 +205,8 @@ case class CaseWhen(
|${cond.code}
|if (!${cond.isNull} && ${cond.value}) {
| ${res.code}
- | ${ev.isNull} = ${res.isNull};
+ | $resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL);
| ${ev.value} = ${res.value};
- | $conditionMet = true;
| continue;
|}
""".stripMargin
@@ -212,65 +216,61 @@ case class CaseWhen(
val res = elseExpr.genCode(ctx)
s"""
|${res.code}
- |${ev.isNull} = ${res.isNull};
+ |$resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL);
|${ev.value} = ${res.value};
""".stripMargin
}
val allConditions = cases ++ elseCode
- val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
- allConditions.mkString("\n")
- } else {
- // This generates code like:
- // conditionMet = caseWhen_1(i);
- // if(conditionMet) {
- // continue;
- // }
- // conditionMet = caseWhen_2(i);
- // if(conditionMet) {
- // continue;
- // }
- // ...
- // and the declared methods are:
- // private boolean caseWhen_1234() {
- // boolean conditionMet = false;
- // do {
- // // here the evaluation of the conditions
- // } while (false);
- // return conditionMet;
- // }
- ctx.splitExpressions(allConditions, "caseWhen",
- ("InternalRow", ctx.INPUT_ROW) :: Nil,
- returnType = ctx.JAVA_BOOLEAN,
- makeSplitFunction = {
- func =>
- s"""
- ${ctx.JAVA_BOOLEAN} $conditionMet = false;
- do {
- $func
- } while (false);
- return $conditionMet;
- """
- },
- foldFunctions = { funcCalls =>
- funcCalls.map { funcCall =>
- s"""
- $conditionMet = $funcCall;
- if ($conditionMet) {
- continue;
- }"""
- }.mkString
- })
- }
+ // This generates code like:
+ // caseWhenResultState = caseWhen_1(i);
+ // if(caseWhenResultState != -1) {
+ // continue;
+ // }
+ // caseWhenResultState = caseWhen_2(i);
+ // if(caseWhenResultState != -1) {
+ // continue;
+ // }
+ // ...
+ // and the declared methods are:
+ // private byte caseWhen_1234() {
+ // byte caseWhenResultState = -1;
+ // do {
+ // // here the evaluation of the conditions
+ // } while (false);
+ // return caseWhenResultState;
+ // }
+ val codes = ctx.splitExpressionsWithCurrentInputs(
+ expressions = allConditions,
+ funcName = "caseWhen",
+ returnType = ctx.JAVA_BYTE,
+ makeSplitFunction = func =>
+ s"""
+ |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED;
+ |do {
+ | $func
+ |} while (false);
+ |return $resultState;
+ """.stripMargin,
+ foldFunctions = _.map { funcCall =>
+ s"""
+ |$resultState = $funcCall;
+ |if ($resultState != $NOT_MATCHED) {
+ | continue;
+ |}
+ """.stripMargin
+ }.mkString)
- ev.copy(code = s"""
- ${ev.isNull} = true;
- ${ev.value} = ${ctx.defaultValue(dataType)};
- ${ctx.JAVA_BOOLEAN} $conditionMet = false;
- do {
- $code
- } while (false);""")
+ ev.copy(code =
+ s"""
+ |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED;
+ |do {
+ | $codes
+ |} while (false);
+ |// TRUE if any condition is met and the result is null, or no condition is met.
+ |final boolean ${ev.isNull} = ($resultState != $HAS_NONNULL);
+ """.stripMargin)
}
}
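
A minimal sketch (names assumed to mirror the constants above, not part of the patch) of how the single byte of result state maps back to the nullness the expression finally exposes:

object CaseWhenStateSketch {
  val NOT_MATCHED: Byte = -1
  val HAS_NONNULL: Byte = 0
  val HAS_NULL: Byte = 1

  // The generated code only needs one comparison to derive nullness.
  def isNull(resultState: Byte): Boolean = resultState != HAS_NONNULL

  def main(args: Array[String]): Unit = {
    assert(!isNull(HAS_NONNULL)) // a branch matched with a non-null result
    assert(isNull(HAS_NULL))     // a branch matched but produced null
    assert(isNull(NOT_MATCHED))  // nothing matched and there is no else branch
  }
}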
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index eaf8788888211..424871f2047e9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -23,6 +23,8 @@ import java.util.{Calendar, TimeZone}
import scala.util.control.NonFatal
+import org.apache.commons.lang3.StringEscapeUtils
+
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
@@ -226,7 +228,7 @@ case class Hour(child: Expression, timeZoneId: Option[String] = None)
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val tz = ctx.addReferenceMinorObj(timeZone)
+ val tz = ctx.addReferenceObj("timeZone", timeZone)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
defineCodeGen(ctx, ev, c => s"$dtu.getHours($c, $tz)")
}
@@ -257,7 +259,7 @@ case class Minute(child: Expression, timeZoneId: Option[String] = None)
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val tz = ctx.addReferenceMinorObj(timeZone)
+ val tz = ctx.addReferenceObj("timeZone", timeZone)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
defineCodeGen(ctx, ev, c => s"$dtu.getMinutes($c, $tz)")
}
@@ -288,7 +290,7 @@ case class Second(child: Expression, timeZoneId: Option[String] = None)
}
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val tz = ctx.addReferenceMinorObj(timeZone)
+ val tz = ctx.addReferenceObj("timeZone", timeZone)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
defineCodeGen(ctx, ev, c => s"$dtu.getSeconds($c, $tz)")
}
@@ -442,9 +444,10 @@ case class DayOfWeek(child: Expression) extends UnaryExpression with ImplicitCas
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, time => {
val cal = classOf[Calendar].getName
- val c = ctx.freshName("cal")
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
- ctx.addMutableState(cal, c, s"""$c = $cal.getInstance($dtu.getTimeZone("UTC"));""")
+ val c = "calDayOfWeek"
+ ctx.addImmutableStateIfNotExists(cal, c,
+ v => s"""$v = $cal.getInstance($dtu.getTimeZone("UTC"));""")
s"""
$c.setTimeInMillis($time * 1000L * 3600L * 24L);
${ev.value} = $c.get($cal.DAY_OF_WEEK);
@@ -484,18 +487,18 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, time => {
val cal = classOf[Calendar].getName
- val c = ctx.freshName("cal")
+ val c = "calWeekOfYear"
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
- ctx.addMutableState(cal, c,
+ ctx.addImmutableStateIfNotExists(cal, c, v =>
s"""
- $c = $cal.getInstance($dtu.getTimeZone("UTC"));
- $c.setFirstDayOfWeek($cal.MONDAY);
- $c.setMinimalDaysInFirstWeek(4);
- """)
+ |$v = $cal.getInstance($dtu.getTimeZone("UTC"));
+ |$v.setFirstDayOfWeek($cal.MONDAY);
+ |$v.setMinimalDaysInFirstWeek(4);
+ """.stripMargin)
s"""
- $c.setTimeInMillis($time * 1000L * 3600L * 24L);
- ${ev.value} = $c.get($cal.WEEK_OF_YEAR);
- """
+ |$c.setTimeInMillis($time * 1000L * 3600L * 24L);
+ |${ev.value} = $c.get($cal.WEEK_OF_YEAR);
+ """.stripMargin
})
}
}
@@ -529,7 +532,7 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
- val tz = ctx.addReferenceMinorObj(timeZone)
+ val tz = ctx.addReferenceObj("timeZone", timeZone)
defineCodeGen(ctx, ev, (timestamp, format) => {
s"""UTF8String.fromString($dtu.newDateFormat($format.toString(), $tz)
.format(new java.util.Date($timestamp / 1000)))"""
@@ -691,7 +694,7 @@ abstract class UnixTime
}""")
}
case StringType =>
- val tz = ctx.addReferenceMinorObj(timeZone)
+ val tz = ctx.addReferenceObj("timeZone", timeZone)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
nullSafeCodeGen(ctx, ev, (string, format) => {
s"""
@@ -715,7 +718,7 @@ abstract class UnixTime
${ev.value} = ${eval1.value} / 1000000L;
}""")
case DateType =>
- val tz = ctx.addReferenceMinorObj(timeZone)
+ val tz = ctx.addReferenceObj("timeZone", timeZone)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
val eval1 = left.genCode(ctx)
ev.copy(code = s"""
@@ -827,7 +830,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[
}""")
}
} else {
- val tz = ctx.addReferenceMinorObj(timeZone)
+ val tz = ctx.addReferenceObj("timeZone", timeZone)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
nullSafeCodeGen(ctx, ev, (seconds, f) => {
s"""
@@ -969,7 +972,7 @@ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[S
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val tz = ctx.addReferenceMinorObj(timeZone)
+ val tz = ctx.addReferenceObj("timeZone", timeZone)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
defineCodeGen(ctx, ev, (sd, i) => {
s"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds, $tz)"""
@@ -1007,19 +1010,21 @@ case class FromUTCTimestamp(left: Expression, right: Expression)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
if (right.foldable) {
- val tz = right.eval()
+ val tz = right.eval().asInstanceOf[UTF8String]
if (tz == null) {
ev.copy(code = s"""
|boolean ${ev.isNull} = true;
|long ${ev.value} = 0;
""".stripMargin)
} else {
- val tzTerm = ctx.freshName("tz")
- val utcTerm = ctx.freshName("utc")
val tzClass = classOf[TimeZone].getName
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
- ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $dtu.getTimeZone("$tz");""")
- ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $dtu.getTimeZone("UTC");""")
+ val escapedTz = StringEscapeUtils.escapeJava(tz.toString)
+ val tzTerm = ctx.addMutableState(tzClass, "tz",
+ v => s"""$v = $dtu.getTimeZone("$escapedTz");""")
+ val utcTerm = "tzUTC"
+ ctx.addImmutableStateIfNotExists(tzClass, utcTerm,
+ v => s"""$v = $dtu.getTimeZone("UTC");""")
val eval = left.genCode(ctx)
ev.copy(code = s"""
|${eval.code}
@@ -1065,7 +1070,7 @@ case class TimeSub(start: Expression, interval: Expression, timeZoneId: Option[S
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val tz = ctx.addReferenceMinorObj(timeZone)
+ val tz = ctx.addReferenceObj("timeZone", timeZone)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
defineCodeGen(ctx, ev, (sd, i) => {
s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds, $tz)"""
@@ -1143,7 +1148,7 @@ case class MonthsBetween(date1: Expression, date2: Expression, timeZoneId: Optio
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val tz = ctx.addReferenceMinorObj(timeZone)
+ val tz = ctx.addReferenceObj("timeZone", timeZone)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
defineCodeGen(ctx, ev, (l, r) => {
s"""$dtu.monthsBetween($l, $r, $tz)"""
@@ -1183,19 +1188,21 @@ case class ToUTCTimestamp(left: Expression, right: Expression)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
if (right.foldable) {
- val tz = right.eval()
+ val tz = right.eval().asInstanceOf[UTF8String]
if (tz == null) {
ev.copy(code = s"""
|boolean ${ev.isNull} = true;
|long ${ev.value} = 0;
""".stripMargin)
} else {
- val tzTerm = ctx.freshName("tz")
- val utcTerm = ctx.freshName("utc")
val tzClass = classOf[TimeZone].getName
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
- ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $dtu.getTimeZone("$tz");""")
- ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $dtu.getTimeZone("UTC");""")
+ val escapedTz = StringEscapeUtils.escapeJava(tz.toString)
+ val tzTerm = ctx.addMutableState(tzClass, "tz",
+ v => s"""$v = $dtu.getTimeZone("$escapedTz");""")
+ val utcTerm = "tzUTC"
+ ctx.addImmutableStateIfNotExists(tzClass, utcTerm,
+ v => s"""$v = $dtu.getTimeZone("UTC");""")
val eval = left.genCode(ctx)
ev.copy(code = s"""
|${eval.code}
@@ -1295,80 +1302,79 @@ case class ParseToTimestamp(left: Expression, format: Option[Expression], child:
override def dataType: DataType = TimestampType
}
-/**
- * Returns date truncated to the unit specified by the format.
- */
-// scalastyle:off line.size.limit
-@ExpressionDescription(
- usage = "_FUNC_(date, fmt) - Returns `date` with the time portion of the day truncated to the unit specified by the format model `fmt`.",
- examples = """
- Examples:
- > SELECT _FUNC_('2009-02-12', 'MM');
- 2009-02-01
- > SELECT _FUNC_('2015-10-27', 'YEAR');
- 2015-01-01
- """,
- since = "1.5.0")
-// scalastyle:on line.size.limit
-case class TruncDate(date: Expression, format: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
- override def left: Expression = date
- override def right: Expression = format
-
- override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)
- override def dataType: DataType = DateType
+trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes {
+ val instant: Expression
+ val format: Expression
override def nullable: Boolean = true
- override def prettyName: String = "trunc"
private lazy val truncLevel: Int =
DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String])
- override def eval(input: InternalRow): Any = {
+ /**
+ * @param input the InternalRow holding the time value to truncate
+ * @param maxLevel Maximum level that can be used for truncation (e.g. MONTH for a Date input)
+ * @param truncFunc function: (time, level) => time
+ */
+ protected def evalHelper(input: InternalRow, maxLevel: Int)(
+ truncFunc: (Any, Int) => Any): Any = {
val level = if (format.foldable) {
truncLevel
} else {
DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String])
}
- if (level == -1) {
- // unknown format
+ if (level == DateTimeUtils.TRUNC_INVALID || level > maxLevel) {
+ // unknown format or too large level
null
} else {
- val d = date.eval(input)
- if (d == null) {
+ val t = instant.eval(input)
+ if (t == null) {
null
} else {
- DateTimeUtils.truncDate(d.asInstanceOf[Int], level)
+ truncFunc(t, level)
}
}
}
- override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ protected def codeGenHelper(
+ ctx: CodegenContext,
+ ev: ExprCode,
+ maxLevel: Int,
+ orderReversed: Boolean = false)(
+ truncFunc: (String, String) => String)
+ : ExprCode = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
if (format.foldable) {
- if (truncLevel == -1) {
+ if (truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) {
ev.copy(code = s"""
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""")
} else {
- val d = date.genCode(ctx)
+ val t = instant.genCode(ctx)
+ val truncFuncStr = truncFunc(t.value, truncLevel.toString)
ev.copy(code = s"""
- ${d.code}
- boolean ${ev.isNull} = ${d.isNull};
+ ${t.code}
+ boolean ${ev.isNull} = ${t.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
- ${ev.value} = $dtu.truncDate(${d.value}, $truncLevel);
+ ${ev.value} = $dtu.$truncFuncStr;
}""")
}
} else {
- nullSafeCodeGen(ctx, ev, (dateVal, fmt) => {
+ nullSafeCodeGen(ctx, ev, (left, right) => {
val form = ctx.freshName("form")
+ val (dateVal, fmt) = if (orderReversed) {
+ (right, left)
+ } else {
+ (left, right)
+ }
+ val truncFuncStr = truncFunc(dateVal, form)
s"""
int $form = $dtu.parseTruncLevel($fmt);
- if ($form == -1) {
+ if ($form == -1 || $form > $maxLevel) {
${ev.isNull} = true;
} else {
- ${ev.value} = $dtu.truncDate($dateVal, $form);
+ ${ev.value} = $dtu.$truncFuncStr
}
"""
})
@@ -1376,6 +1382,101 @@ case class TruncDate(date: Expression, format: Expression)
}
}
+/**
+ * Returns date truncated to the unit specified by the format.
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(date, fmt) - Returns `date` with the time portion of the day truncated to the unit specified by the format model `fmt`.
+ `fmt` should be one of ["year", "yyyy", "yy", "mon", "month", "mm"]
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_('2009-02-12', 'MM');
+ 2009-02-01
+ > SELECT _FUNC_('2015-10-27', 'YEAR');
+ 2015-01-01
+ """,
+ since = "1.5.0")
+// scalastyle:on line.size.limit
+case class TruncDate(date: Expression, format: Expression)
+ extends TruncInstant {
+ override def left: Expression = date
+ override def right: Expression = format
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)
+ override def dataType: DataType = DateType
+ override def prettyName: String = "trunc"
+ override val instant = date
+
+ override def eval(input: InternalRow): Any = {
+ evalHelper(input, maxLevel = DateTimeUtils.TRUNC_TO_MONTH) { (d: Any, level: Int) =>
+ DateTimeUtils.truncDate(d.asInstanceOf[Int], level)
+ }
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ codeGenHelper(ctx, ev, maxLevel = DateTimeUtils.TRUNC_TO_MONTH) { (date: String, fmt: String) =>
+ s"truncDate($date, $fmt);"
+ }
+ }
+}
+
+/**
+ * Returns timestamp truncated to the unit specified by the format.
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(fmt, ts) - Returns timestamp `ts` truncated to the unit specified by the format model `fmt`.
+ `fmt` should be one of ["YEAR", "YYYY", "YY", "MON", "MONTH", "MM", "DAY", "DD", "HOUR", "MINUTE", "SECOND", "WEEK", "QUARTER"]
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_('2015-03-05T09:32:05.359', 'YEAR');
+ 2015-01-01T00:00:00
+ > SELECT _FUNC_('2015-03-05T09:32:05.359', 'MM');
+ 2015-03-01T00:00:00
+ > SELECT _FUNC_('2015-03-05T09:32:05.359', 'DD');
+ 2015-03-05T00:00:00
+ > SELECT _FUNC_('2015-03-05T09:32:05.359', 'HOUR');
+ 2015-03-05T09:00:00
+ """,
+ since = "2.3.0")
+// scalastyle:on line.size.limit
+case class TruncTimestamp(
+ format: Expression,
+ timestamp: Expression,
+ timeZoneId: Option[String] = None)
+ extends TruncInstant with TimeZoneAwareExpression {
+ override def left: Expression = format
+ override def right: Expression = timestamp
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(StringType, TimestampType)
+ override def dataType: TimestampType = TimestampType
+ override def prettyName: String = "date_trunc"
+ override val instant = timestamp
+ override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+ copy(timeZoneId = Option(timeZoneId))
+
+ def this(format: Expression, timestamp: Expression) = this(format, timestamp, None)
+
+ override def eval(input: InternalRow): Any = {
+ evalHelper(input, maxLevel = DateTimeUtils.TRUNC_TO_SECOND) { (t: Any, level: Int) =>
+ DateTimeUtils.truncTimestamp(t.asInstanceOf[Long], level, timeZone)
+ }
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val tz = ctx.addReferenceObj("timeZone", timeZone)
+ codeGenHelper(ctx, ev, maxLevel = DateTimeUtils.TRUNC_TO_SECOND, true) {
+ (date: String, fmt: String) =>
+ s"truncTimestamp($date, $fmt, $tz);"
+ }
+ }
+}
+
/**
* Returns the number of days from startDate to endDate.
*/
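
As a usage sketch of the new TruncTimestamp expression (SQL name date_trunc per the ExpressionDescription above), under the assumption of a local SparkSession; the exact output depends on the session time zone:

import org.apache.spark.sql.SparkSession

object DateTruncSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()
    // New in this change: timestamp truncation down to SECOND-level units.
    spark.sql("SELECT date_trunc('HOUR', timestamp '2015-03-05 09:32:05.359')").show(false)
    // The existing trunc() keeps its date-level units (YEAR/MONTH and synonyms).
    spark.sql("SELECT trunc('2015-10-27', 'YEAR')").show(false)
    spark.stop()
  }
}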
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index 752dea23e1f7a..db1579ba28671 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -70,10 +70,12 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
case class PromotePrecision(child: Expression) extends UnaryExpression {
override def dataType: DataType = child.dataType
override def eval(input: InternalRow): Any = child.eval(input)
+ /** Just a simple pass-through for code generation. */
override def genCode(ctx: CodegenContext): ExprCode = child.genCode(ctx)
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev.copy("")
override def prettyName: String = "promote_precision"
override def sql: String = child.sql
+ override lazy val canonicalized: Expression = child.canonicalized
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index f1aa130669266..4f4d49166e88c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -155,8 +155,8 @@ case class Stack(children: Seq[Expression]) extends Generator {
val j = (i - 1) % numFields
if (children(i).dataType != elementSchema.fields(j).dataType) {
return TypeCheckResult.TypeCheckFailure(
- s"Argument ${j + 1} (${elementSchema.fields(j).dataType}) != " +
- s"Argument $i (${children(i).dataType})")
+ s"Argument ${j + 1} (${elementSchema.fields(j).dataType.simpleString}) != " +
+ s"Argument $i (${children(i).dataType.simpleString})")
}
}
TypeCheckResult.TypeCheckSuccess
@@ -199,11 +199,11 @@ case class Stack(children: Seq[Expression]) extends Generator {
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// Rows - we write these into an array.
- val rowData = ctx.freshName("rows")
- ctx.addMutableState("InternalRow[]", rowData, s"$rowData = new InternalRow[$numRows];")
+ val rowData = ctx.addMutableState("InternalRow[]", "rows",
+ v => s"$v = new InternalRow[$numRows];")
val values = children.tail
val dataTypes = values.take(numFields).map(_.dataType)
- val code = ctx.splitExpressions(Seq.tabulate(numRows) { row =>
+ val code = ctx.splitExpressionsWithCurrentInputs(Seq.tabulate(numRows) { row =>
val fields = Seq.tabulate(numFields) { col =>
val index = row * numFields + col
if (index < values.length) values(index) else Literal(null, dataTypes(col))
@@ -214,11 +214,11 @@ case class Stack(children: Seq[Expression]) extends Generator {
// Create the collection.
val wrapperClass = classOf[mutable.WrappedArray[_]].getName
- ctx.addMutableState(
- s"$wrapperClass",
- ev.value,
- s"${ev.value} = $wrapperClass$$.MODULE$$.make($rowData);")
- ev.copy(code = code, isNull = "false")
+ ev.copy(code =
+ s"""
+ |$code
+ |$wrapperClass ${ev.value} = $wrapperClass$$.MODULE$$.make($rowData);
+ """.stripMargin, isNull = "false")
}
}
@@ -249,7 +249,8 @@ abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with
TypeCheckResult.TypeCheckSuccess
case _ =>
TypeCheckResult.TypeCheckFailure(
- s"input to function explode should be array or map type, not ${child.dataType}")
+ "input to function explode should be array or map type, " +
+ s"not ${child.dataType.simpleString}")
}
// hive-compatible default alias for explode function ("col" for array, "key", "value" for map)
@@ -378,7 +379,8 @@ case class Inline(child: Expression) extends UnaryExpression with CollectionGene
TypeCheckResult.TypeCheckSuccess
case _ =>
TypeCheckResult.TypeCheckFailure(
- s"input to function $prettyName should be array of struct type, not ${child.dataType}")
+ s"input to function $prettyName should be array of struct type, " +
+ s"not ${child.dataType.simpleString}")
}
override def elementSchema: StructType = child.dataType match {
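
A small sketch of the generator whose error path is touched above: with consistent column types stack works as before, while a mismatch now reports the compact simpleString form of the types. Assumes a local SparkSession; names are illustrative.

import org.apache.spark.sql.SparkSession

object StackSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()
    // Two rows of (int, string); column types are consistent per column, so this succeeds.
    spark.sql("SELECT stack(2, 1, 'a', 2, 'b')").show()
    // Mixing types within the same output column would now fail with a message
    // using the short type names (e.g. "int" vs "string") rather than the full ones.
    spark.stop()
  }
}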
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index c3289b8299933..055ebf6c0da54 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -270,17 +270,32 @@ abstract class HashExpression[E] extends Expression {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
ev.isNull = "false"
- val childrenHash = ctx.splitExpressions(children.map { child =>
+
+ val childrenHash = children.map { child =>
val childGen = child.genCode(ctx)
childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) {
computeHash(childGen.value, child.dataType, ev.value, ctx)
}
- })
+ }
- ctx.addMutableState(ctx.javaType(dataType), ev.value)
- ev.copy(code = s"""
- ${ev.value} = $seed;
- $childrenHash""")
+ val hashResultType = ctx.javaType(dataType)
+ val codes = ctx.splitExpressionsWithCurrentInputs(
+ expressions = childrenHash,
+ funcName = "computeHash",
+ extraArguments = Seq(hashResultType -> ev.value),
+ returnType = hashResultType,
+ makeSplitFunction = body =>
+ s"""
+ |$body
+ |return ${ev.value};
+ """.stripMargin,
+ foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
+
+ ev.copy(code =
+ s"""
+ |$hashResultType ${ev.value} = $seed;
+ |$codes
+ """.stripMargin)
}
protected def nullSafeElementHash(
@@ -389,13 +404,21 @@ abstract class HashExpression[E] extends Expression {
input: String,
result: String,
fields: Array[StructField]): String = {
- val hashes = fields.zipWithIndex.map { case (field, index) =>
+ val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx)
}
+ val hashResultType = ctx.javaType(dataType)
ctx.splitExpressions(
- expressions = hashes,
- funcName = "getHash",
- arguments = Seq("InternalRow" -> input))
+ expressions = fieldsHash,
+ funcName = "computeHashForStruct",
+ arguments = Seq("InternalRow" -> input, hashResultType -> result),
+ returnType = hashResultType,
+ makeSplitFunction = body =>
+ s"""
+ |$body
+ |return $result;
+ """.stripMargin,
+ foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n"))
}
@tailrec
@@ -610,25 +633,41 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
ev.isNull = "false"
+
val childHash = ctx.freshName("childHash")
- val childrenHash = ctx.splitExpressions(children.map { child =>
+ val childrenHash = children.map { child =>
val childGen = child.genCode(ctx)
val codeToComputeHash = ctx.nullSafeExec(child.nullable, childGen.isNull) {
computeHash(childGen.value, child.dataType, childHash, ctx)
}
s"""
|${childGen.code}
+ |$childHash = 0;
|$codeToComputeHash
|${ev.value} = (31 * ${ev.value}) + $childHash;
- |$childHash = 0;
""".stripMargin
- })
+ }
- ctx.addMutableState(ctx.javaType(dataType), ev.value)
- ctx.addMutableState(ctx.JAVA_INT, childHash, s"$childHash = 0;")
- ev.copy(code = s"""
- ${ev.value} = $seed;
- $childrenHash""")
+ val codes = ctx.splitExpressionsWithCurrentInputs(
+ expressions = childrenHash,
+ funcName = "computeHash",
+ extraArguments = Seq(ctx.JAVA_INT -> ev.value),
+ returnType = ctx.JAVA_INT,
+ makeSplitFunction = body =>
+ s"""
+ |${ctx.JAVA_INT} $childHash = 0;
+ |$body
+ |return ${ev.value};
+ """.stripMargin,
+ foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
+
+
+ ev.copy(code =
+ s"""
+ |${ctx.JAVA_INT} ${ev.value} = $seed;
+ |${ctx.JAVA_INT} $childHash = 0;
+ |$codes
+ """.stripMargin)
}
override def eval(input: InternalRow = null): Int = {
@@ -730,23 +769,29 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
input: String,
result: String,
fields: Array[StructField]): String = {
- val localResult = ctx.freshName("localResult")
val childResult = ctx.freshName("childResult")
- fields.zipWithIndex.map { case (field, index) =>
+ val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
+ val computeFieldHash = nullSafeElementHash(
+ input, index.toString, field.nullable, field.dataType, childResult, ctx)
s"""
- $childResult = 0;
- ${nullSafeElementHash(input, index.toString, field.nullable, field.dataType,
- childResult, ctx)}
- $localResult = (31 * $localResult) + $childResult;
- """
- }.mkString(
- s"""
- int $localResult = 0;
- int $childResult = 0;
- """,
- "",
- s"$result = (31 * $result) + $localResult;"
- )
+ |$childResult = 0;
+ |$computeFieldHash
+ |$result = (31 * $result) + $childResult;
+ """.stripMargin
+ }
+
+ s"${ctx.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions(
+ expressions = fieldsHash,
+ funcName = "computeHashForStruct",
+ arguments = Seq("InternalRow" -> input, ctx.JAVA_INT -> result),
+ returnType = ctx.JAVA_INT,
+ makeSplitFunction = body =>
+ s"""
+ |${ctx.JAVA_INT} $childResult = 0;
+ |$body
+ |return $result;
+ """.stripMargin,
+ foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n"))
}
}
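
For context, a hedged sketch (not part of the patch) of the workload the computeHash / computeHashForStruct splitting above targets: hashing a very wide row, where each field contributes one generated statement. Assumes a local SparkSession; the column count is arbitrary.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, hash, lit}

object WideHashSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()
    // Hundreds of hashed columns would previously accumulate into one huge method;
    // the fold above threads the running hash value through the split functions.
    val wide = spark.range(10).select((0 until 400).map(i => lit(i).as(s"c$i")): _*)
    wide.select(hash(wide.columns.map(col): _*).as("h")).show(5)
    spark.stop()
  }
}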
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index eaeaf08c37b4e..383203a209833 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -290,7 +290,7 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
case FloatType =>
val v = value.asInstanceOf[Float]
if (v.isNaN || v.isInfinite) {
- val boxedValue = ctx.addReferenceMinorObj(v)
+ val boxedValue = ctx.addReferenceObj("boxedValue", v)
val code = s"final $javaType ${ev.value} = ($javaType) $boxedValue;"
ev.copy(code = code)
} else {
@@ -299,7 +299,7 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
case DoubleType =>
val v = value.asInstanceOf[Double]
if (v.isNaN || v.isInfinite) {
- val boxedValue = ctx.addReferenceMinorObj(v)
+ val boxedValue = ctx.addReferenceObj("boxedValue", v)
val code = s"final $javaType ${ev.value} = ($javaType) $boxedValue;"
ev.copy(code = code)
} else {
@@ -309,8 +309,9 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
ev.copy(code = "", value = s"($javaType)$value")
case TimestampType | LongType =>
ev.copy(code = "", value = s"${value}L")
- case other =>
- ev.copy(code = "", value = ctx.addReferenceMinorObj(value, ctx.javaType(dataType)))
+ case _ =>
+ ev.copy(code = "", value = ctx.addReferenceObj("literal", value,
+ ctx.javaType(dataType)))
}
}
}
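
The boxed-reference path above is only taken for non-finite floating-point literals; a quick sketch of expressions that hit it, assuming a local SparkSession:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.lit

object NonFiniteLiteralSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()
    // NaN and Infinity cannot be emitted as plain numeric literals in the generated
    // source, so codegen references the boxed value instead.
    spark.range(1)
      .select(lit(Double.NaN).as("nan"), lit(Float.PositiveInfinity).as("inf"))
      .show()
    spark.stop()
  }
}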
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index b86e271fe2958..4b9006ab5b423 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -81,7 +81,7 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa
// Use unnamed reference that doesn't create a local field here to reduce the number of fields
// because errMsgField is used only when the value is null or false.
- val errMsgField = ctx.addReferenceMinorObj(errMsg)
+ val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
ExprCode(code = s"""${eval.code}
|if (${eval.isNull} || !${eval.value}) {
| throw new RuntimeException($errMsgField);
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index 173e171910b69..470d5da041ea5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -72,26 +72,52 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
- ctx.addMutableState(ctx.javaType(dataType), ev.value)
+ ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
+ // all the evals are meant to be in a do { ... } while (false); loop
val evals = children.map { e =>
val eval = e.genCode(ctx)
s"""
- if (${ev.isNull}) {
- ${eval.code}
- if (!${eval.isNull}) {
- ${ev.isNull} = false;
- ${ev.value} = ${eval.value};
- }
- }
- """
+ |${eval.code}
+ |if (!${eval.isNull}) {
+ | ${ev.isNull} = false;
+ | ${ev.value} = ${eval.value};
+ | continue;
+ |}
+ """.stripMargin
}
- ev.copy(code = s"""
- ${ev.isNull} = true;
- ${ev.value} = ${ctx.defaultValue(dataType)};
- ${ctx.splitExpressions(evals)}""")
+ val resultType = ctx.javaType(dataType)
+ val codes = ctx.splitExpressionsWithCurrentInputs(
+ expressions = evals,
+ funcName = "coalesce",
+ returnType = resultType,
+ makeSplitFunction = func =>
+ s"""
+ |$resultType ${ev.value} = ${ctx.defaultValue(dataType)};
+ |do {
+ | $func
+ |} while (false);
+ |return ${ev.value};
+ """.stripMargin,
+ foldFunctions = _.map { funcCall =>
+ s"""
+ |${ev.value} = $funcCall;
+ |if (!${ev.isNull}) {
+ | continue;
+ |}
+ """.stripMargin
+ }.mkString)
+
+
+ ev.copy(code =
+ s"""
+ |${ev.isNull} = true;
+ |$resultType ${ev.value} = ${ctx.defaultValue(dataType)};
+ |do {
+ | $codes
+ |} while (false);
+ """.stripMargin)
}
}
@@ -358,53 +384,63 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val nonnull = ctx.freshName("nonnull")
+ // all evals are meant to be inside a do { ... } while (false); loop
val evals = children.map { e =>
val eval = e.genCode(ctx)
e.dataType match {
case DoubleType | FloatType =>
s"""
- if ($nonnull < $n) {
- ${eval.code}
- if (!${eval.isNull} && !Double.isNaN(${eval.value})) {
- $nonnull += 1;
- }
- }
- """
+ |if ($nonnull < $n) {
+ | ${eval.code}
+ | if (!${eval.isNull} && !Double.isNaN(${eval.value})) {
+ | $nonnull += 1;
+ | }
+ |} else {
+ | continue;
+ |}
+ """.stripMargin
case _ =>
s"""
- if ($nonnull < $n) {
- ${eval.code}
- if (!${eval.isNull}) {
- $nonnull += 1;
- }
- }
- """
+ |if ($nonnull < $n) {
+ | ${eval.code}
+ | if (!${eval.isNull}) {
+ | $nonnull += 1;
+ | }
+ |} else {
+ | continue;
+ |}
+ """.stripMargin
}
}
- val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
- evals.mkString("\n")
- } else {
- ctx.splitExpressions(
- expressions = evals,
- funcName = "atLeastNNonNulls",
- arguments = ("InternalRow", ctx.INPUT_ROW) :: ("int", nonnull) :: Nil,
- returnType = "int",
- makeSplitFunction = { body =>
- s"""
- $body
- return $nonnull;
- """
- },
- foldFunctions = { funcCalls =>
- funcCalls.map(funcCall => s"$nonnull = $funcCall;").mkString("\n")
- }
- )
- }
-
- ev.copy(code = s"""
- int $nonnull = 0;
- $code
- boolean ${ev.value} = $nonnull >= $n;""", isNull = "false")
+ val codes = ctx.splitExpressionsWithCurrentInputs(
+ expressions = evals,
+ funcName = "atLeastNNonNulls",
+ extraArguments = (ctx.JAVA_INT, nonnull) :: Nil,
+ returnType = ctx.JAVA_INT,
+ makeSplitFunction = body =>
+ s"""
+ |do {
+ | $body
+ |} while (false);
+ |return $nonnull;
+ """.stripMargin,
+ foldFunctions = _.map { funcCall =>
+ s"""
+ |$nonnull = $funcCall;
+ |if ($nonnull >= $n) {
+ | continue;
+ |}
+ """.stripMargin
+ }.mkString)
+
+ ev.copy(code =
+ s"""
+ |${ctx.JAVA_INT} $nonnull = 0;
+ |do {
+ | $codes
+ |} while (false);
+ |${ctx.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n;
+ """.stripMargin, isNull = "false")
}
}
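
A short semantic sketch of what the restructured Coalesce / AtLeastNNonNulls codegen above preserves: evaluation stops at the first qualifying child. Assumes a local SparkSession; data and names are illustrative.

import org.apache.spark.sql.SparkSession

object CoalesceSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()
    // coalesce returns the first non-null argument; the `continue` emitted above
    // skips the remaining children once one is found.
    spark.sql("SELECT coalesce(NULL, NULL, 3, 4)").show() // 3
    // AtLeastNNonNulls backs DataFrameNaFunctions.drop(minNonNulls = ...).
    val df = spark.sql("SELECT 1 AS a, CAST(NULL AS INT) AS b, 3 AS c")
    df.na.drop(minNonNulls = 2).show() // row kept: 2 of the 3 columns are non-null
    spark.stop()
  }
}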
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index e2bc79d98b33d..64da9bb9cdec1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -51,7 +51,7 @@ trait InvokeLike extends Expression with NonSQLExpression {
*
* - generate codes for argument.
* - use ctx.splitExpressions() to not exceed 64kb JVM limit while preparing arguments.
- * - avoid some of nullabilty checking which are not needed because the expression is not
+ * - avoid some nullability checks that are not needed because the expression is not
* nullable.
* - when needNullCheck == true, short circuit if we found one of arguments is null because
* preparing rest of arguments can be skipped in the case.
@@ -62,15 +62,13 @@ trait InvokeLike extends Expression with NonSQLExpression {
def prepareArguments(ctx: CodegenContext): (String, String, String) = {
val resultIsNull = if (needNullCheck) {
- val resultIsNull = ctx.freshName("resultIsNull")
- ctx.addMutableState(ctx.JAVA_BOOLEAN, resultIsNull)
+ val resultIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "resultIsNull")
resultIsNull
} else {
"false"
}
val argValues = arguments.map { e =>
- val argValue = ctx.freshName("argValue")
- ctx.addMutableState(ctx.javaType(e.dataType), argValue)
+ val argValue = ctx.addMutableState(ctx.javaType(e.dataType), "argValue")
argValue
}
@@ -101,7 +99,7 @@ trait InvokeLike extends Expression with NonSQLExpression {
"""
}
}
- val argCode = ctx.splitExpressions(argCodes)
+ val argCode = ctx.splitExpressionsWithCurrentInputs(argCodes)
(argCode, argValues.mkString(", "), resultIsNull)
}
@@ -195,7 +193,8 @@ case class StaticInvoke(
* @param targetObject An expression that will return the object to call the method on.
* @param functionName The name of the method to call.
* @param dataType The expected return type of the function.
- * @param arguments An optional list of expressions, whos evaluation will be passed to the function.
+ * @param arguments An optional list of expressions, whose evaluation will be passed to the
+ * function.
* @param propagateNull When true, and any of the arguments is null, null will be returned instead
* of calling the function.
* @param returnNullable When false, indicating the invoked method will always return
@@ -548,7 +547,7 @@ case class MapObjects private(
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val elementJavaType = ctx.javaType(loopVarDataType)
- ctx.addMutableState(elementJavaType, loopValue)
+ ctx.addMutableState(elementJavaType, loopValue, forceInline = true, useFreshName = false)
val genInputData = inputData.genCode(ctx)
val genFunction = lambdaFunction.genCode(ctx)
val dataLength = ctx.freshName("dataLength")
@@ -644,7 +643,7 @@ case class MapObjects private(
}
val loopNullCheck = if (loopIsNull != "false") {
- ctx.addMutableState(ctx.JAVA_BOOLEAN, loopIsNull)
+ ctx.addMutableState(ctx.JAVA_BOOLEAN, loopIsNull, forceInline = true, useFreshName = false)
inputDataType match {
case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
case _ => s"$loopIsNull = $loopValue == null;"
@@ -808,10 +807,11 @@ case class CatalystToExternalMap private(
val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType]
val keyElementJavaType = ctx.javaType(mapType.keyType)
- ctx.addMutableState(keyElementJavaType, keyLoopValue)
+ ctx.addMutableState(keyElementJavaType, keyLoopValue, forceInline = true, useFreshName = false)
val genKeyFunction = keyLambdaFunction.genCode(ctx)
val valueElementJavaType = ctx.javaType(mapType.valueType)
- ctx.addMutableState(valueElementJavaType, valueLoopValue)
+ ctx.addMutableState(valueElementJavaType, valueLoopValue, forceInline = true,
+ useFreshName = false)
val genValueFunction = valueLambdaFunction.genCode(ctx)
val genInputData = inputData.genCode(ctx)
val dataLength = ctx.freshName("dataLength")
@@ -844,7 +844,8 @@ case class CatalystToExternalMap private(
val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction)
val valueLoopNullCheck = if (valueLoopIsNull != "false") {
- ctx.addMutableState(ctx.JAVA_BOOLEAN, valueLoopIsNull)
+ ctx.addMutableState(ctx.JAVA_BOOLEAN, valueLoopIsNull, forceInline = true,
+ useFreshName = false)
s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);"
} else {
""
@@ -994,8 +995,8 @@ case class ExternalMapToCatalyst private(
val keyElementJavaType = ctx.javaType(keyType)
val valueElementJavaType = ctx.javaType(valueType)
- ctx.addMutableState(keyElementJavaType, key)
- ctx.addMutableState(valueElementJavaType, value)
+ ctx.addMutableState(keyElementJavaType, key, forceInline = true, useFreshName = false)
+ ctx.addMutableState(valueElementJavaType, value, forceInline = true, useFreshName = false)
val (defineEntries, defineKeyValue) = child.dataType match {
case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) =>
@@ -1031,14 +1032,14 @@ case class ExternalMapToCatalyst private(
}
val keyNullCheck = if (keyIsNull != "false") {
- ctx.addMutableState(ctx.JAVA_BOOLEAN, keyIsNull)
+ ctx.addMutableState(ctx.JAVA_BOOLEAN, keyIsNull, forceInline = true, useFreshName = false)
s"$keyIsNull = $key == null;"
} else {
""
}
val valueNullCheck = if (valueIsNull != "false") {
- ctx.addMutableState(ctx.JAVA_BOOLEAN, valueIsNull)
+ ctx.addMutableState(ctx.JAVA_BOOLEAN, valueIsNull, forceInline = true, useFreshName = false)
s"$valueIsNull = $value == null;"
} else {
""
@@ -1106,27 +1107,31 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val rowClass = classOf[GenericRowWithSchema].getName
val values = ctx.freshName("values")
- ctx.addMutableState("Object[]", values)
val childrenCodes = children.zipWithIndex.map { case (e, i) =>
val eval = e.genCode(ctx)
- eval.code + s"""
- if (${eval.isNull}) {
- $values[$i] = null;
- } else {
- $values[$i] = ${eval.value};
- }
- """
+ s"""
+ |${eval.code}
+ |if (${eval.isNull}) {
+ | $values[$i] = null;
+ |} else {
+ | $values[$i] = ${eval.value};
+ |}
+ """.stripMargin
}
- val childrenCode = ctx.splitExpressions(childrenCodes)
+ val childrenCode = ctx.splitExpressionsWithCurrentInputs(
+ expressions = childrenCodes,
+ funcName = "createExternalRow",
+ extraArguments = "Object[]" -> values :: Nil)
val schemaField = ctx.addReferenceObj("schema", schema)
- val code = s"""
- $values = new Object[${children.size}];
- $childrenCode
- final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField);
- """
+ val code =
+ s"""
+ |Object[] $values = new Object[${children.size}];
+ |$childrenCode
+ |final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField);
+ """.stripMargin
ev.copy(code = code, isNull = "false")
}
}
@@ -1144,25 +1149,28 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean)
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// Code to initialize the serializer.
- val serializer = ctx.freshName("serializer")
- val (serializerClass, serializerInstanceClass) = {
+ val (serializer, serializerClass, serializerInstanceClass) = {
if (kryo) {
- (classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName)
+ ("kryoSerializer",
+ classOf[KryoSerializer].getName,
+ classOf[KryoSerializerInstance].getName)
} else {
- (classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName)
+ ("javaSerializer",
+ classOf[JavaSerializer].getName,
+ classOf[JavaSerializerInstance].getName)
}
}
// try conf from env, otherwise create a new one
val env = s"${classOf[SparkEnv].getName}.get()"
val sparkConf = s"new ${classOf[SparkConf].getName}()"
- val serializerInit = s"""
- if ($env == null) {
- $serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();
- } else {
- $serializer = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance();
- }
- """
- ctx.addMutableState(serializerInstanceClass, serializer, serializerInit)
+ ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializer, v =>
+ s"""
+ |if ($env == null) {
+ | $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();
+ |} else {
+ | $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance();
+ |}
+ """.stripMargin)
// Code to serialize.
val input = child.genCode(ctx)
@@ -1190,25 +1198,28 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// Code to initialize the serializer.
- val serializer = ctx.freshName("serializer")
- val (serializerClass, serializerInstanceClass) = {
+ val (serializer, serializerClass, serializerInstanceClass) = {
if (kryo) {
- (classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName)
+ ("kryoSerializer",
+ classOf[KryoSerializer].getName,
+ classOf[KryoSerializerInstance].getName)
} else {
- (classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName)
+ ("javaSerializer",
+ classOf[JavaSerializer].getName,
+ classOf[JavaSerializerInstance].getName)
}
}
// try conf from env, otherwise create a new one
val env = s"${classOf[SparkEnv].getName}.get()"
val sparkConf = s"new ${classOf[SparkConf].getName}()"
- val serializerInit = s"""
- if ($env == null) {
- $serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();
- } else {
- $serializer = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance();
- }
- """
- ctx.addMutableState(serializerInstanceClass, serializer, serializerInit)
+ ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializer, v =>
+ s"""
+ |if ($env == null) {
+ | $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();
+ |} else {
+ | $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance();
+ |}
+ """.stripMargin)
// Code to deserialize.
val input = child.genCode(ctx)
@@ -1244,25 +1255,28 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
val javaBeanInstance = ctx.freshName("javaBean")
val beanInstanceJavaType = ctx.javaType(beanInstance.dataType)
- ctx.addMutableState(beanInstanceJavaType, javaBeanInstance)
val initialize = setters.map {
case (setterMethod, fieldValue) =>
val fieldGen = fieldValue.genCode(ctx)
s"""
- ${fieldGen.code}
- ${javaBeanInstance}.$setterMethod(${fieldGen.value});
- """
+ |${fieldGen.code}
+ |$javaBeanInstance.$setterMethod(${fieldGen.value});
+ """.stripMargin
}
- val initializeCode = ctx.splitExpressions(initialize.toSeq)
+ val initializeCode = ctx.splitExpressionsWithCurrentInputs(
+ expressions = initialize.toSeq,
+ funcName = "initializeJavaBean",
+ extraArguments = beanInstanceJavaType -> javaBeanInstance :: Nil)
- val code = s"""
- ${instanceGen.code}
- ${javaBeanInstance} = ${instanceGen.value};
- if (!${instanceGen.isNull}) {
- $initializeCode
- }
- """
+ val code =
+ s"""
+ |${instanceGen.code}
+ |$beanInstanceJavaType $javaBeanInstance = ${instanceGen.value};
+ |if (!${instanceGen.isNull}) {
+ | $initializeCode
+ |}
+ """.stripMargin
ev.copy(code = code, isNull = instanceGen.isNull, value = instanceGen.value)
}
}
@@ -1303,7 +1317,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil)
// Use unnamed reference that doesn't create a local field here to reduce the number of fields
// because errMsgField is used only when the value is null.
- val errMsgField = ctx.addReferenceMinorObj(errMsg)
+ val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
val code = s"""
${childGen.code}
@@ -1340,7 +1354,7 @@ case class GetExternalRowField(
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// Use unnamed reference that doesn't create a local field here to reduce the number of fields
// because errMsgField is used only when the field is null.
- val errMsgField = ctx.addReferenceMinorObj(errMsg)
+ val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
val row = child.genCode(ctx)
val code = s"""
${row.code}
@@ -1380,7 +1394,7 @@ case class ValidateExternalType(child: Expression, expected: DataType)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// Use unnamed reference that doesn't create a local field here to reduce the number of fields
// because errMsgField is used only when the type doesn't match.
- val errMsgField = ctx.addReferenceMinorObj(errMsg)
+ val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
val input = child.genCode(ctx)
val obj = input.value
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index eb7475354b104..b469f5cb7586a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -195,7 +195,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
}
case _ =>
TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " +
- s"${value.dataType} != ${mismatchOpt.get.dataType}")
+ s"${value.dataType.simpleString} != ${mismatchOpt.get.dataType.simpleString}")
}
} else {
TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
@@ -234,38 +234,66 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val javaDataType = ctx.javaType(value.dataType)
val valueGen = value.genCode(ctx)
val listGen = list.map(_.genCode(ctx))
- ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.value)
- ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
+ // inTmpResult has 3 possible values:
+ // -1 means no matches found and at least one value in the list evaluates to null
+ val HAS_NULL = -1
+ // 0 means no matches found and all values in the list are not null
+ val NOT_MATCHED = 0
+ // 1 means one value in the list is matched
+ val MATCHED = 1
+ val tmpResult = ctx.freshName("inTmpResult")
val valueArg = ctx.freshName("valueArg")
+ // All the blocks are meant to be inside a do { ... } while (false); loop.
+ // Evaluation stops as soon as a matching value is found.
val listCode = listGen.map(x =>
s"""
- if (!${ev.value}) {
- ${x.code}
- if (${x.isNull}) {
- ${ev.isNull} = true;
- } else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) {
- ${ev.isNull} = false;
- ${ev.value} = true;
- }
- }
- """)
- val listCodes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) {
- val args = ("InternalRow", ctx.INPUT_ROW) :: (ctx.javaType(value.dataType), valueArg) :: Nil
- ctx.splitExpressions(expressions = listCode, funcName = "valueIn", arguments = args)
- } else {
- listCode.mkString("\n")
- }
- ev.copy(code = s"""
- ${valueGen.code}
- ${ev.value} = false;
- ${ev.isNull} = ${valueGen.isNull};
- if (!${ev.isNull}) {
- ${ctx.javaType(value.dataType)} $valueArg = ${valueGen.value};
- $listCodes
- }
- """)
+ |${x.code}
+ |if (${x.isNull}) {
+ | $tmpResult = $HAS_NULL; // ${ev.isNull} = true;
+ |} else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) {
+ | $tmpResult = $MATCHED; // ${ev.isNull} = false; ${ev.value} = true;
+ | continue;
+ |}
+ """.stripMargin)
+
+ val codes = ctx.splitExpressionsWithCurrentInputs(
+ expressions = listCode,
+ funcName = "valueIn",
+ extraArguments = (javaDataType, valueArg) :: (ctx.JAVA_BYTE, tmpResult) :: Nil,
+ returnType = ctx.JAVA_BYTE,
+ makeSplitFunction = body =>
+ s"""
+ |do {
+ | $body
+ |} while (false);
+ |return $tmpResult;
+ """.stripMargin,
+ foldFunctions = _.map { funcCall =>
+ s"""
+ |$tmpResult = $funcCall;
+ |if ($tmpResult == $MATCHED) {
+ | continue;
+ |}
+ """.stripMargin
+ }.mkString("\n"))
+
+ ev.copy(code =
+ s"""
+ |${valueGen.code}
+ |byte $tmpResult = $HAS_NULL;
+ |if (!${valueGen.isNull}) {
+ | $tmpResult = $NOT_MATCHED;
+ | $javaDataType $valueArg = ${valueGen.value};
+ | do {
+ | $codes
+ | } while (false);
+ |}
+ |final boolean ${ev.isNull} = ($tmpResult == $HAS_NULL);
+ |final boolean ${ev.value} = ($tmpResult == $MATCHED);
+ """.stripMargin)
}
override def sql: String = {
@@ -300,7 +328,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
}
}
- @transient private[this] lazy val set = child.dataType match {
+ @transient lazy val set: Set[Any] = child.dataType match {
case _: AtomicType => hset
case _: NullType => hset
case _ =>
@@ -308,34 +336,24 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
TreeSet.empty(TypeUtils.getInterpretedOrdering(child.dataType)) ++ hset
}
- def getSet(): Set[Any] = set
-
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val setName = classOf[Set[Any]].getName
- val InSetName = classOf[InSet].getName
+ val setTerm = ctx.addReferenceObj("set", set)
val childGen = child.genCode(ctx)
- ctx.references += this
- val setTerm = ctx.freshName("set")
- val setNull = if (hasNull) {
- s"""
- |if (!${ev.value}) {
- | ${ev.isNull} = true;
- |}
- """.stripMargin
+ val setIsNull = if (hasNull) {
+ s"${ev.isNull} = !${ev.value};"
} else {
""
}
- ctx.addMutableState(setName, setTerm,
- s"$setTerm = (($InSetName)references[${ctx.references.size - 1}]).getSet();")
- ev.copy(code = s"""
- ${childGen.code}
- boolean ${ev.isNull} = ${childGen.isNull};
- boolean ${ev.value} = false;
- if (!${ev.isNull}) {
- ${ev.value} = $setTerm.contains(${childGen.value});
- $setNull
- }
- """)
+ ev.copy(code =
+ s"""
+ |${childGen.code}
+ |${ctx.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull};
+ |${ctx.JAVA_BOOLEAN} ${ev.value} = false;
+ |if (!${ev.isNull}) {
+ | ${ev.value} = $setTerm.contains(${childGen.value});
+ | $setIsNull
+ |}
+ """.stripMargin)
}
override def sql: String = {
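
The rewritten In codegen above folds the old (isNull, value) mutable-state pair into a single byte with three states (HAS_NULL, NOT_MATCHED, MATCHED), so split functions can short-circuit with continue inside a do { ... } while (false) block. The snippet below is a plain-Scala sketch of the same three-valued IN semantics, using Option for SQL NULL; it is illustrative only, not Spark code.

    // Three-valued IN: a null value gives null; a match gives true; no match but a null
    // in the list gives null (unknown); otherwise false. Mirrors HAS_NULL / NOT_MATCHED / MATCHED.
    def evalIn(value: Option[Int], list: Seq[Option[Int]]): Option[Boolean] =
      value.flatMap { v =>
        if (list.contains(Some(v))) Some(true)        // MATCHED
        else if (list.contains(None)) None            // HAS_NULL, result is unknown
        else Some(false)                              // NOT_MATCHED
      }

    assert(evalIn(Some(1), Seq(Some(1), None)) == Some(true))
    assert(evalIn(Some(2), Seq(Some(1), None)) == None)
    assert(evalIn(Some(2), Seq(Some(1), Some(3))) == Some(false))
    assert(evalIn(None, Seq(Some(1))) == None)
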
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index b4aefe6cff73e..8bc936fcbfc31 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -77,9 +77,8 @@ case class Rand(child: Expression) extends RDG {
override protected def evalInternal(input: InternalRow): Double = rng.nextDouble()
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val rngTerm = ctx.freshName("rng")
val className = classOf[XORShiftRandom].getName
- ctx.addMutableState(className, rngTerm)
+ val rngTerm = ctx.addMutableState(className, "rng")
ctx.addPartitionInitializationStatement(
s"$rngTerm = new $className(${seed}L + partitionIndex);")
ev.copy(code = s"""
@@ -112,9 +111,8 @@ case class Randn(child: Expression) extends RDG {
override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val rngTerm = ctx.freshName("rng")
val className = classOf[XORShiftRandom].getName
- ctx.addMutableState(className, rngTerm)
+ val rngTerm = ctx.addMutableState(className, "rng")
ctx.addPartitionInitializationStatement(
s"$rngTerm = new $className(${seed}L + partitionIndex);")
ev.copy(code = s"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index d0d663f63f5db..f3e8f6de58975 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -112,15 +112,14 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val patternClass = classOf[Pattern].getName
val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex"
- val pattern = ctx.freshName("pattern")
if (right.foldable) {
val rVal = right.eval()
if (rVal != null) {
val regexStr =
StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString()))
- ctx.addMutableState(patternClass, pattern,
- s"""$pattern = ${patternClass}.compile("$regexStr");""")
+ val pattern = ctx.addMutableState(patternClass, "patternLike",
+ v => s"""$v = $patternClass.compile("$regexStr");""")
// We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
val eval = left.genCode(ctx)
@@ -139,12 +138,13 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi
""")
}
} else {
+ val pattern = ctx.freshName("pattern")
val rightStr = ctx.freshName("rightStr")
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"""
- String $rightStr = ${eval2}.toString();
- ${patternClass} $pattern = ${patternClass}.compile($escapeFunc($rightStr));
- ${ev.value} = $pattern.matcher(${eval1}.toString()).matches();
+ String $rightStr = $eval2.toString();
+ $patternClass $pattern = $patternClass.compile($escapeFunc($rightStr));
+ ${ev.value} = $pattern.matcher($eval1.toString()).matches();
"""
})
}
@@ -187,15 +187,14 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val patternClass = classOf[Pattern].getName
- val pattern = ctx.freshName("pattern")
if (right.foldable) {
val rVal = right.eval()
if (rVal != null) {
val regexStr =
StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString())
- ctx.addMutableState(patternClass, pattern,
- s"""$pattern = ${patternClass}.compile("$regexStr");""")
+ val pattern = ctx.addMutableState(patternClass, "patternRLike",
+ v => s"""$v = $patternClass.compile("$regexStr");""")
// We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
val eval = left.genCode(ctx)
@@ -215,11 +214,12 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
}
} else {
val rightStr = ctx.freshName("rightStr")
+ val pattern = ctx.freshName("pattern")
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"""
- String $rightStr = ${eval2}.toString();
- ${patternClass} $pattern = ${patternClass}.compile($rightStr);
- ${ev.value} = $pattern.matcher(${eval1}.toString()).find(0);
+ String $rightStr = $eval2.toString();
+ $patternClass $pattern = $patternClass.compile($rightStr);
+ ${ev.value} = $pattern.matcher($eval1.toString()).find(0);
"""
})
}
@@ -316,26 +316,17 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
override def prettyName: String = "regexp_replace"
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val termLastRegex = ctx.freshName("lastRegex")
- val termPattern = ctx.freshName("pattern")
-
- val termLastReplacement = ctx.freshName("lastReplacement")
- val termLastReplacementInUTF8 = ctx.freshName("lastReplacementInUTF8")
-
- val termResult = ctx.freshName("result")
+ val termResult = ctx.freshName("termResult")
val classNamePattern = classOf[Pattern].getCanonicalName
val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName
val matcher = ctx.freshName("matcher")
- ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;")
- ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;")
- ctx.addMutableState("String", termLastReplacement, s"${termLastReplacement} = null;")
- ctx.addMutableState("UTF8String",
- termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;")
- ctx.addMutableState(classNameStringBuffer,
- termResult, s"${termResult} = new $classNameStringBuffer();")
+ val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex")
+ val termPattern = ctx.addMutableState(classNamePattern, "pattern")
+ val termLastReplacement = ctx.addMutableState("String", "lastReplacement")
+ val termLastReplacementInUTF8 = ctx.addMutableState("UTF8String", "lastReplacementInUTF8")
val setEvNotNull = if (nullable) {
s"${ev.isNull} = false;"
@@ -345,24 +336,25 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
nullSafeCodeGen(ctx, ev, (subject, regexp, rep) => {
s"""
- if (!$regexp.equals(${termLastRegex})) {
+ if (!$regexp.equals($termLastRegex)) {
// regex value changed
- ${termLastRegex} = $regexp.clone();
- ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
+ $termLastRegex = $regexp.clone();
+ $termPattern = $classNamePattern.compile($termLastRegex.toString());
}
- if (!$rep.equals(${termLastReplacementInUTF8})) {
+ if (!$rep.equals($termLastReplacementInUTF8)) {
// replacement string changed
- ${termLastReplacementInUTF8} = $rep.clone();
- ${termLastReplacement} = ${termLastReplacementInUTF8}.toString();
+ $termLastReplacementInUTF8 = $rep.clone();
+ $termLastReplacement = $termLastReplacementInUTF8.toString();
}
- ${termResult}.delete(0, ${termResult}.length());
- java.util.regex.Matcher ${matcher} = ${termPattern}.matcher($subject.toString());
+ $classNameStringBuffer $termResult = new $classNameStringBuffer();
+ java.util.regex.Matcher $matcher = $termPattern.matcher($subject.toString());
- while (${matcher}.find()) {
- ${matcher}.appendReplacement(${termResult}, ${termLastReplacement});
+ while ($matcher.find()) {
+ $matcher.appendReplacement($termResult, $termLastReplacement);
}
- ${matcher}.appendTail(${termResult});
- ${ev.value} = UTF8String.fromString(${termResult}.toString());
+ $matcher.appendTail($termResult);
+ ${ev.value} = UTF8String.fromString($termResult.toString());
+ $termResult = null;
$setEvNotNull
"""
})
@@ -416,14 +408,12 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
override def prettyName: String = "regexp_extract"
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val termLastRegex = ctx.freshName("lastRegex")
- val termPattern = ctx.freshName("pattern")
val classNamePattern = classOf[Pattern].getCanonicalName
val matcher = ctx.freshName("matcher")
val matchResult = ctx.freshName("matchResult")
- ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;")
- ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;")
+ val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex")
+ val termPattern = ctx.addMutableState(classNamePattern, "pattern")
val setEvNotNull = if (nullable) {
s"${ev.isNull} = false;"
@@ -433,19 +423,19 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => {
s"""
- if (!$regexp.equals(${termLastRegex})) {
+ if (!$regexp.equals($termLastRegex)) {
// regex value changed
- ${termLastRegex} = $regexp.clone();
- ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
+ $termLastRegex = $regexp.clone();
+ $termPattern = $classNamePattern.compile($termLastRegex.toString());
}
- java.util.regex.Matcher ${matcher} =
- ${termPattern}.matcher($subject.toString());
- if (${matcher}.find()) {
- java.util.regex.MatchResult ${matchResult} = ${matcher}.toMatchResult();
- if (${matchResult}.group($idx) == null) {
+ java.util.regex.Matcher $matcher =
+ $termPattern.matcher($subject.toString());
+ if ($matcher.find()) {
+ java.util.regex.MatchResult $matchResult = $matcher.toMatchResult();
+ if ($matchResult.group($idx) == null) {
${ev.value} = UTF8String.EMPTY_UTF8;
} else {
- ${ev.value} = UTF8String.fromString(${matchResult}.group($idx));
+ ${ev.value} = UTF8String.fromString($matchResult.group($idx));
}
$setEvNotNull
} else {
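
The regexp expressions keep two pieces of per-expression mutable state, the last regex seen and its compiled Pattern, and recompile only when the regex argument changes. Below is a standalone sketch of that caching shape; the names are illustrative and this is not Spark's generated code.

    import java.util.regex.{Matcher, Pattern}

    // Cache the last regex and its compiled Pattern; recompile only when the regex changes.
    final class CachedPattern {
      private var lastRegex: String = null
      private var pattern: Pattern = null

      def matcherFor(regex: String, input: String): Matcher = {
        if (regex != lastRegex) {          // regex value changed
          lastRegex = regex
          pattern = Pattern.compile(lastRegex)
        }
        pattern.matcher(input)
      }
    }

    val cached = new CachedPattern
    assert(cached.matcherFor("a+", "aaa").matches())
    assert(!cached.matcherFor("a+", "bbb").matches())   // reuses the compiled pattern
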
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index ee5cf925d3cef..e004bfc6af473 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -24,11 +24,10 @@ import java.util.regex.Pattern
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
@@ -38,7 +37,8 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
/**
- * An expression that concatenates multiple input strings into a single string.
+ * An expression that concatenates multiple inputs into a single output.
+ * If all inputs are binary, concat returns the result as binary; otherwise, it returns it as a string.
* If any input is null, concat returns null.
*/
@ExpressionDescription(
@@ -48,17 +48,37 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
> SELECT _FUNC_('Spark', 'SQL');
SparkSQL
""")
-case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes {
+case class Concat(children: Seq[Expression]) extends Expression {
- override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType)
- override def dataType: DataType = StringType
+ private lazy val isBinaryMode: Boolean = dataType == BinaryType
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (children.isEmpty) {
+ TypeCheckResult.TypeCheckSuccess
+ } else {
+ val childTypes = children.map(_.dataType)
+ if (childTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) {
+ return TypeCheckResult.TypeCheckFailure(
+ s"input to function $prettyName should have StringType or BinaryType, but it's " +
+ childTypes.map(_.simpleString).mkString("[", ", ", "]"))
+ }
+ TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName")
+ }
+ }
+
+ override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType)
override def nullable: Boolean = children.exists(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)
override def eval(input: InternalRow): Any = {
- val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
- UTF8String.concat(inputs : _*)
+ if (isBinaryMode) {
+ val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]])
+ ByteArray.concat(inputs: _*)
+ } else {
+ val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
+ UTF8String.concat(inputs : _*)
+ }
}
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -73,21 +93,27 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas
}
"""
}
- val codes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) {
- ctx.splitExpressions(
- expressions = inputs,
- funcName = "valueConcat",
- arguments = ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil)
+
+ val (concatenator, initCode) = if (isBinaryMode) {
+ (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];")
} else {
- inputs.mkString("\n")
+ ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];")
}
+ val codes = ctx.splitExpressionsWithCurrentInputs(
+ expressions = inputs,
+ funcName = "valueConcat",
+ extraArguments = (s"${ctx.javaType(dataType)}[]", args) :: Nil)
ev.copy(s"""
- UTF8String[] $args = new UTF8String[${evals.length}];
+ $initCode
$codes
- UTF8String ${ev.value} = UTF8String.concat($args);
+ ${ctx.javaType(dataType)} ${ev.value} = $concatenator.concat($args);
boolean ${ev.isNull} = ${ev.value} == null;
""")
}
+
+ override def toString: String = s"concat(${children.mkString(", ")})"
+
+ override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})"
}
@@ -156,14 +182,10 @@ case class ConcatWs(children: Seq[Expression])
""
}
}
- val codes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) {
- ctx.splitExpressions(
+ val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = inputs,
funcName = "valueConcatWs",
- arguments = ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil)
- } else {
- inputs.mkString("\n")
- }
+ extraArguments = ("UTF8String[]", args) :: Nil)
ev.copy(s"""
UTF8String[] $args = new UTF8String[$numArgs];
${separator.code}
@@ -208,31 +230,32 @@ case class ConcatWs(children: Seq[Expression])
}
}.unzip
- val codes = ctx.splitExpressions(evals.map(_.code))
- val varargCounts = ctx.splitExpressions(
+ val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code))
+
+ val varargCounts = ctx.splitExpressionsWithCurrentInputs(
expressions = varargCount,
funcName = "varargCountsConcatWs",
- arguments = ("InternalRow", ctx.INPUT_ROW) :: Nil,
returnType = "int",
makeSplitFunction = body =>
s"""
- int $varargNum = 0;
- $body
- return $varargNum;
- """,
- foldFunctions = _.mkString(s"$varargNum += ", s";\n$varargNum += ", ";"))
- val varargBuilds = ctx.splitExpressions(
+ |int $varargNum = 0;
+ |$body
+ |return $varargNum;
+ """.stripMargin,
+ foldFunctions = _.map(funcCall => s"$varargNum += $funcCall;").mkString("\n"))
+
+ val varargBuilds = ctx.splitExpressionsWithCurrentInputs(
expressions = varargBuild,
funcName = "varargBuildsConcatWs",
- arguments =
- ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String []", array) :: ("int", idxInVararg) :: Nil,
+ extraArguments = ("UTF8String []", array) :: ("int", idxInVararg) :: Nil,
returnType = "int",
makeSplitFunction = body =>
s"""
- $body
- return $idxInVararg;
- """,
- foldFunctions = _.mkString(s"$idxInVararg = ", s";\n$idxInVararg = ", ";"))
+ |$body
+ |return $idxInVararg;
+ """.stripMargin,
+ foldFunctions = _.map(funcCall => s"$idxInVararg = $funcCall;").mkString("\n"))
+
ev.copy(
s"""
$codes
@@ -248,33 +271,45 @@ case class ConcatWs(children: Seq[Expression])
}
}
+/**
+ * An expression that returns the `n`-th input among the given inputs.
+ * If all inputs are binary, `elt` returns the result as binary; otherwise, it returns it as a string.
+ * If any input is null, `elt` returns null.
+ */
// scalastyle:off line.size.limit
@ExpressionDescription(
- usage = "_FUNC_(n, str1, str2, ...) - Returns the `n`-th string, e.g., returns `str2` when `n` is 2.",
+ usage = "_FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2.",
examples = """
Examples:
> SELECT _FUNC_(1, 'scala', 'java');
scala
""")
// scalastyle:on line.size.limit
-case class Elt(children: Seq[Expression])
- extends Expression with ImplicitCastInputTypes {
+case class Elt(children: Seq[Expression]) extends Expression {
private lazy val indexExpr = children.head
- private lazy val stringExprs = children.tail.toArray
+ private lazy val inputExprs = children.tail.toArray
/** This expression is always nullable because it returns null if index is out of range. */
override def nullable: Boolean = true
- override def dataType: DataType = StringType
-
- override def inputTypes: Seq[DataType] = IntegerType +: Seq.fill(children.size - 1)(StringType)
+ override def dataType: DataType = inputExprs.map(_.dataType).headOption.getOrElse(StringType)
override def checkInputDataTypes(): TypeCheckResult = {
if (children.size < 2) {
TypeCheckResult.TypeCheckFailure("elt function requires at least two arguments")
} else {
- super[ImplicitCastInputTypes].checkInputDataTypes()
+ val (indexType, inputTypes) = (indexExpr.dataType, inputExprs.map(_.dataType))
+ if (indexType != IntegerType) {
+ return TypeCheckResult.TypeCheckFailure(s"first input to function $prettyName should " +
+ s"have IntegerType, but it's $indexType")
+ }
+ if (inputTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) {
+ return TypeCheckResult.TypeCheckFailure(
+ s"input to function $prettyName should have StringType or BinaryType, but it's " +
+ inputTypes.map(_.simpleString).mkString("[", ", ", "]"))
+ }
+ TypeUtils.checkForSameTypeInputExpr(inputTypes, s"function $prettyName")
}
}
@@ -284,65 +319,67 @@ case class Elt(children: Seq[Expression])
null
} else {
val index = indexObj.asInstanceOf[Int]
- if (index <= 0 || index > stringExprs.length) {
+ if (index <= 0 || index > inputExprs.length) {
null
} else {
- stringExprs(index - 1).eval(input)
+ inputExprs(index - 1).eval(input)
}
}
}
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val index = indexExpr.genCode(ctx)
- val strings = stringExprs.map(_.genCode(ctx))
+ val inputs = inputExprs.map(_.genCode(ctx))
val indexVal = ctx.freshName("index")
- val stringVal = ctx.freshName("stringVal")
- val assignStringValue = strings.zipWithIndex.map { case (eval, index) =>
- s"""
- case ${index + 1}:
- ${eval.code}
- $stringVal = ${eval.isNull} ? null : ${eval.value};
- break;
- """
- }
+ val indexMatched = ctx.freshName("eltIndexMatched")
- val cases = ctx.buildCodeBlocks(assignStringValue)
- val codes = if (cases.length == 1) {
+ val inputVal = ctx.addMutableState(ctx.javaType(dataType), "inputVal")
+
+ val assignInputValue = inputs.zipWithIndex.map { case (eval, index) =>
s"""
- UTF8String $stringVal = null;
- switch ($indexVal) {
- ${cases.head}
- }
- """
- } else {
- var prevFunc = "null"
- for (c <- cases.reverse) {
- val funcName = ctx.freshName("eltFunc")
- val funcBody = s"""
- private UTF8String $funcName(InternalRow ${ctx.INPUT_ROW}, int $indexVal) {
- UTF8String $stringVal = null;
- switch ($indexVal) {
- $c
- default:
- return $prevFunc;
- }
- return $stringVal;
- }
- """
- val fullFuncName = ctx.addNewFunction(funcName, funcBody)
- prevFunc = s"$fullFuncName(${ctx.INPUT_ROW}, $indexVal)"
- }
- s"UTF8String $stringVal = $prevFunc;"
+ |if ($indexVal == ${index + 1}) {
+ | ${eval.code}
+ | $inputVal = ${eval.isNull} ? null : ${eval.value};
+ | $indexMatched = true;
+ | continue;
+ |}
+ """.stripMargin
}
+ val codes = ctx.splitExpressionsWithCurrentInputs(
+ expressions = assignInputValue,
+ funcName = "eltFunc",
+ extraArguments = ("int", indexVal) :: Nil,
+ returnType = ctx.JAVA_BOOLEAN,
+ makeSplitFunction = body =>
+ s"""
+ |${ctx.JAVA_BOOLEAN} $indexMatched = false;
+ |do {
+ | $body
+ |} while (false);
+ |return $indexMatched;
+ """.stripMargin,
+ foldFunctions = _.map { funcCall =>
+ s"""
+ |$indexMatched = $funcCall;
+ |if ($indexMatched) {
+ | continue;
+ |}
+ """.stripMargin
+ }.mkString)
+
ev.copy(
s"""
- ${index.code}
- final int $indexVal = ${index.value};
- $codes
- UTF8String ${ev.value} = $stringVal;
- final boolean ${ev.isNull} = ${ev.value} == null;
- """)
+ |${index.code}
+ |final int $indexVal = ${index.value};
+ |${ctx.JAVA_BOOLEAN} $indexMatched = false;
+ |$inputVal = null;
+ |do {
+ | $codes
+ |} while (false);
+ |final ${ctx.javaType(dataType)} ${ev.value} = $inputVal;
+ |final boolean ${ev.isNull} = ${ev.value} == null;
+ """.stripMargin)
}
}
@@ -536,14 +573,11 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val termLastMatching = ctx.freshName("lastMatching")
- val termLastReplace = ctx.freshName("lastReplace")
- val termDict = ctx.freshName("dict")
val classNameDict = classOf[JMap[Character, Character]].getCanonicalName
- ctx.addMutableState("UTF8String", termLastMatching, s"$termLastMatching = null;")
- ctx.addMutableState("UTF8String", termLastReplace, s"$termLastReplace = null;")
- ctx.addMutableState(classNameDict, termDict, s"$termDict = null;")
+ val termLastMatching = ctx.addMutableState("UTF8String", "lastMatching")
+ val termLastReplace = ctx.addMutableState("UTF8String", "lastReplace")
+ val termDict = ctx.addMutableState(classNameDict, "dict")
nullSafeCodeGen(ctx, ev, (src, matching, replace) => {
val check = if (matchingExpr.foldable && replaceExpr.foldable) {
@@ -1388,14 +1422,10 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
$argList[$index] = $value;
"""
}
- val argListCodes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) {
- ctx.splitExpressions(
- expressions = argListCode,
- funcName = "valueFormatString",
- arguments = ("InternalRow", ctx.INPUT_ROW) :: ("Object[]", argList) :: Nil)
- } else {
- argListCode.mkString("\n")
- }
+ val argListCodes = ctx.splitExpressionsWithCurrentInputs(
+ expressions = argListCode,
+ funcName = "valueFormatString",
+ extraArguments = ("Object[]", argList) :: Nil)
val form = ctx.freshName("formatter")
val formatter = classOf[java.util.Formatter].getName
@@ -2073,15 +2103,12 @@ case class FormatNumber(x: Expression, d: Expression)
// SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.')
// as a decimal separator.
val usLocale = "US"
- val lastDValue = ctx.freshName("lastDValue")
- val pattern = ctx.freshName("pattern")
- val numberFormat = ctx.freshName("numberFormat")
val i = ctx.freshName("i")
val dFormat = ctx.freshName("dFormat")
- ctx.addMutableState(ctx.JAVA_INT, lastDValue, s"$lastDValue = -100;")
- ctx.addMutableState(sb, pattern, s"$pattern = new $sb();")
- ctx.addMutableState(df, numberFormat,
- s"""$numberFormat = new $df("", new $dfs($l.$usLocale));""")
+ val lastDValue = ctx.addMutableState(ctx.JAVA_INT, "lastDValue", v => s"$v = -100;")
+ val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();")
+ val numberFormat = ctx.addMutableState(df, "numberFormat",
+ v => s"""$v = new $df("", new $dfs($l.$usLocale));""")
s"""
if ($d >= 0) {
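
After this change, Concat is binary-in/binary-out when every child is BinaryType and string-typed otherwise, keeping the usual null-in/null-out behavior. The helpers below are a plain-Scala sketch of those two evaluation modes (simple stand-ins for ByteArray.concat and UTF8String.concat), not the Spark implementation.

    // String mode: any null input yields null, otherwise plain concatenation.
    def concatStrings(inputs: Seq[String]): String =
      if (inputs.contains(null)) null else inputs.mkString

    // Binary mode: same null semantics, result is the byte-wise concatenation.
    def concatBinary(inputs: Seq[Array[Byte]]): Array[Byte] =
      if (inputs.contains(null)) null else inputs.flatten.toArray

    assert(concatStrings(Seq("Spark", "SQL")) == "SparkSQL")
    assert(concatStrings(Seq("Spark", null)) == null)
    assert(concatBinary(Seq(Array[Byte](1, 2), Array[Byte](3))).toSeq == Seq[Byte](1, 2, 3))
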
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index e11e3a105f597..dd13d9a3bba51 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -70,9 +70,9 @@ case class WindowSpecDefinition(
case f: SpecifiedWindowFrame if f.frameType == RangeFrame && f.isValueBound &&
!isValidFrameType(f.valueBoundary.head.dataType) =>
TypeCheckFailure(
- s"The data type '${orderSpec.head.dataType}' used in the order specification does " +
- s"not match the data type '${f.valueBoundary.head.dataType}' which is used in the " +
- "range frame.")
+ s"The data type '${orderSpec.head.dataType.simpleString}' used in the order " +
+ "specification does not match the data type " +
+ s"'${f.valueBoundary.head.dataType.simpleString}' which is used in the range frame.")
case _ => TypeCheckSuccess
}
}
@@ -251,8 +251,8 @@ case class SpecifiedWindowFrame(
TypeCheckFailure(s"Window frame $location bound '$e' is not a literal.")
case e: Expression if !frameType.inputType.acceptsType(e.dataType) =>
TypeCheckFailure(
- s"The data type of the $location bound '${e.dataType} does not match " +
- s"the expected data type '${frameType.inputType}'.")
+ s"The data type of the $location bound '${e.dataType.simpleString}' does not match " +
+ s"the expected data type '${frameType.inputType.simpleString}'.")
case _ => TypeCheckSuccess
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 0d961bf2e6e5e..c794ba8619322 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -47,7 +47,62 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
protected def fixedPoint = FixedPoint(SQLConf.get.optimizerMaxIterations)
def batches: Seq[Batch] = {
- Batch("Eliminate Distinct", Once, EliminateDistinct) ::
+ val operatorOptimizationRuleSet =
+ Seq(
+ // Operator push down
+ PushProjectionThroughUnion,
+ ReorderJoin,
+ EliminateOuterJoin,
+ PushPredicateThroughJoin,
+ PushDownPredicate,
+ LimitPushDown,
+ ColumnPruning,
+ InferFiltersFromConstraints,
+ // Operator combine
+ CollapseRepartition,
+ CollapseProject,
+ CollapseWindow,
+ CombineFilters,
+ CombineLimits,
+ CombineUnions,
+ // Constant folding and strength reduction
+ NullPropagation,
+ ConstantPropagation,
+ FoldablePropagation,
+ OptimizeIn,
+ ConstantFolding,
+ ReorderAssociativeOperator,
+ LikeSimplification,
+ BooleanSimplification,
+ SimplifyConditionals,
+ RemoveDispensableExpressions,
+ SimplifyBinaryComparison,
+ PruneFilters,
+ EliminateSorts,
+ SimplifyCasts,
+ SimplifyCaseConversionExpressions,
+ RewriteCorrelatedScalarSubquery,
+ EliminateSerialization,
+ RemoveRedundantAliases,
+ RemoveRedundantProject,
+ SimplifyCreateStructOps,
+ SimplifyCreateArrayOps,
+ SimplifyCreateMapOps,
+ CombineConcats) ++
+ extendedOperatorOptimizationRules
+
+ val operatorOptimizationBatch: Seq[Batch] = {
+ val rulesWithoutInferFiltersFromConstraints =
+ operatorOptimizationRuleSet.filterNot(_ == InferFiltersFromConstraints)
+ Batch("Operator Optimization before Inferring Filters", fixedPoint,
+ rulesWithoutInferFiltersFromConstraints: _*) ::
+ Batch("Infer Filters", Once,
+ InferFiltersFromConstraints) ::
+ Batch("Operator Optimization after Inferring Filters", fixedPoint,
+ rulesWithoutInferFiltersFromConstraints: _*) :: Nil
+ }
+
+ (Batch("Eliminate Distinct", Once, EliminateDistinct) ::
// Technically some of the rules in Finish Analysis are not optimizer rules and belong more
// in the analyzer, because they are needed for correctness (e.g. ComputeCurrentTime).
// However, because we also use the analyzer to canonicalized queries (for view definition),
@@ -81,66 +136,26 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
ReplaceDistinctWithAggregate) ::
Batch("Aggregate", fixedPoint,
RemoveLiteralFromGroupExpressions,
- RemoveRepetitionFromGroupExpressions) ::
- Batch("Operator Optimizations", fixedPoint, Seq(
- // Operator push down
- PushProjectionThroughUnion,
- ReorderJoin,
- EliminateOuterJoin,
- InferFiltersFromConstraints,
- BooleanSimplification,
- PushPredicateThroughJoin,
- PushDownPredicate,
- LimitPushDown,
- ColumnPruning,
- // Operator combine
- CollapseRepartition,
- CollapseProject,
- CollapseWindow,
- CombineFilters,
- CombineLimits,
- CombineUnions,
- // Constant folding and strength reduction
- NullPropagation,
- ConstantPropagation,
- FoldablePropagation,
- OptimizeIn,
- ConstantFolding,
- ReorderAssociativeOperator,
- LikeSimplification,
- BooleanSimplification,
- SimplifyConditionals,
- RemoveDispensableExpressions,
- SimplifyBinaryComparison,
- PruneFilters,
- EliminateSorts,
- SimplifyCasts,
- SimplifyCaseConversionExpressions,
- RewriteCorrelatedScalarSubquery,
- EliminateSerialization,
- RemoveRedundantAliases,
- RemoveRedundantProject,
- SimplifyCreateStructOps,
- SimplifyCreateArrayOps,
- SimplifyCreateMapOps,
- CombineConcats) ++
- extendedOperatorOptimizationRules: _*) ::
+ RemoveRepetitionFromGroupExpressions) :: Nil ++
+ operatorOptimizationBatch) :+
Batch("Join Reorder", Once,
- CostBasedJoinReorder) ::
+ CostBasedJoinReorder) :+
Batch("Decimal Optimizations", fixedPoint,
- DecimalAggregates) ::
+ DecimalAggregates) :+
Batch("Object Expressions Optimization", fixedPoint,
EliminateMapObjects,
- CombineTypedFilters) ::
+ CombineTypedFilters) :+
Batch("LocalRelation", fixedPoint,
ConvertToLocalRelation,
- PropagateEmptyRelation) ::
+ PropagateEmptyRelation) :+
// The following batch should be executed after batch "Join Reorder" and "LocalRelation".
Batch("Check Cartesian Products", Once,
- CheckCartesianProducts) ::
+ CheckCartesianProducts) :+
Batch("RewriteSubquery", Once,
RewritePredicateSubquery,
- CollapseProject) :: Nil
+ ColumnPruning,
+ CollapseProject,
+ RemoveRedundantProject)
}
/**
@@ -441,12 +456,15 @@ object ColumnPruning extends Rule[LogicalPlan] {
f.copy(child = prunedChild(child, f.references))
case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty =>
e.copy(child = prunedChild(child, e.references))
- case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty =>
- g.copy(child = prunedChild(g.child, g.references))
- // Turn off `join` for Generate if no column from it's child is used
- case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) =>
- p.copy(child = g.copy(join = false))
+ // prune unrequired references
+ case p @ Project(_, g: Generate) if p.references != g.outputSet =>
+ val requiredAttrs = p.references -- g.producedAttributes ++ g.generator.references
+ val newChild = prunedChild(g.child, requiredAttrs)
+ val unrequired = g.generator.references -- p.references
+ val unrequiredIndices = newChild.output.zipWithIndex.filter(t => unrequired.contains(t._1))
+ .map(_._2)
+ p.copy(child = g.copy(child = newChild, unrequiredChildIndex = unrequiredIndices))
// Eliminate unneeded attributes from right side of a Left Existence Join.
case j @ Join(_, right, LeftExistence(_), _) =>
@@ -692,7 +710,9 @@ object CombineUnions extends Rule[LogicalPlan] {
*/
object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case Filter(fc, nf @ Filter(nc, grandChild)) =>
+ // The query execution/optimization does not guarantee the expressions are evaluated in order.
+ // We can combine them only if both filter conditions are deterministic.
+ case Filter(fc, nf @ Filter(nc, grandChild)) if fc.deterministic && nc.deterministic =>
(ExpressionSet(splitConjunctivePredicates(fc)) --
ExpressionSet(splitConjunctivePredicates(nc))).reduceOption(And) match {
case Some(ac) =>
@@ -775,7 +795,8 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild))
case filter @ Filter(condition, aggregate: Aggregate)
- if aggregate.aggregateExpressions.forall(_.deterministic) =>
+ if aggregate.aggregateExpressions.forall(_.deterministic)
+ && aggregate.groupingExpressions.nonEmpty =>
// Find all the aliased expressions in the aggregate list that don't include any actual
// AggregateExpression, and create a map from the alias to the expression
val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect {
@@ -785,15 +806,15 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
// For each filter, expand the alias and check if the filter can be evaluated using
// attributes produced by the aggregate operator's child operator.
- val (candidates, containingNonDeterministic) =
- splitConjunctivePredicates(condition).span(_.deterministic)
+ val (candidates, nonDeterministic) =
+ splitConjunctivePredicates(condition).partition(_.deterministic)
val (pushDown, rest) = candidates.partition { cond =>
val replaced = replaceAlias(cond, aliasMap)
cond.references.nonEmpty && replaced.references.subsetOf(aggregate.child.outputSet)
}
- val stayUp = rest ++ containingNonDeterministic
+ val stayUp = rest ++ nonDeterministic
if (pushDown.nonEmpty) {
val pushDownPredicate = pushDown.reduce(And)
@@ -815,14 +836,14 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) =>
val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references))
- val (candidates, containingNonDeterministic) =
- splitConjunctivePredicates(condition).span(_.deterministic)
+ val (candidates, nonDeterministic) =
+ splitConjunctivePredicates(condition).partition(_.deterministic)
val (pushDown, rest) = candidates.partition { cond =>
cond.references.subsetOf(partitionAttrs)
}
- val stayUp = rest ++ containingNonDeterministic
+ val stayUp = rest ++ nonDeterministic
if (pushDown.nonEmpty) {
val pushDownPredicate = pushDown.reduce(And)
@@ -834,7 +855,7 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
case filter @ Filter(condition, union: Union) =>
// Union could change the rows, so non-deterministic predicate can't be pushed down
- val (pushDown, stayUp) = splitConjunctivePredicates(condition).span(_.deterministic)
+ val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition(_.deterministic)
if (pushDown.nonEmpty) {
val pushDownCond = pushDown.reduceLeft(And)
@@ -858,13 +879,9 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
}
case filter @ Filter(condition, watermark: EventTimeWatermark) =>
- // We can only push deterministic predicates which don't reference the watermark attribute.
- // We could in theory span() only on determinism and pull out deterministic predicates
- // on the watermark separately. But it seems unnecessary and a bit confusing to not simply
- // use the prefix as we do for nondeterminism in other cases.
-
- val (pushDown, stayUp) = splitConjunctivePredicates(condition).span(
- p => p.deterministic && !p.references.contains(watermark.eventTime))
+ val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { p =>
+ p.deterministic && !p.references.contains(watermark.eventTime)
+ }
if (pushDown.nonEmpty) {
val pushDownPredicate = pushDown.reduceLeft(And)
@@ -905,14 +922,14 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
// come from grandchild.
// TODO: non-deterministic predicates could be pushed through some operators that do not change
// the rows.
- val (candidates, containingNonDeterministic) =
- splitConjunctivePredicates(filter.condition).span(_.deterministic)
+ val (candidates, nonDeterministic) =
+ splitConjunctivePredicates(filter.condition).partition(_.deterministic)
val (pushDown, rest) = candidates.partition { cond =>
cond.references.subsetOf(grandchild.outputSet)
}
- val stayUp = rest ++ containingNonDeterministic
+ val stayUp = rest ++ nonDeterministic
if (pushDown.nonEmpty) {
val newChild = insertFilter(pushDown.reduceLeft(And))
@@ -955,23 +972,19 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
/**
* Splits join condition expressions or filter predicates (on a given join's output) into three
* categories based on the attributes required to evaluate them. Note that we explicitly exclude
- * on-deterministic (i.e., stateful) condition expressions in canEvaluateInLeft or
+ * non-deterministic (i.e., stateful) condition expressions in canEvaluateInLeft or
* canEvaluateInRight to prevent pushing these predicates on either side of the join.
*
* @return (canEvaluateInLeft, canEvaluateInRight, haveToEvaluateInBoth)
*/
private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = {
- // Note: In order to ensure correctness, it's important to not change the relative ordering of
- // any deterministic expression that follows a non-deterministic expression. To achieve this,
- // we only consider pushing down those expressions that precede the first non-deterministic
- // expression in the condition.
- val (pushDownCandidates, containingNonDeterministic) = condition.span(_.deterministic)
+ val (pushDownCandidates, nonDeterministic) = condition.partition(_.deterministic)
val (leftEvaluateCondition, rest) =
pushDownCandidates.partition(_.references.subsetOf(left.outputSet))
val (rightEvaluateCondition, commonCondition) =
rest.partition(expr => expr.references.subsetOf(right.outputSet))
- (leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ containingNonDeterministic)
+ (leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ nonDeterministic)
}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
@@ -1209,7 +1222,13 @@ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] {
Alias(new First(attr).toAggregateExpression(), attr.name)(attr.exprId)
}
}
- Aggregate(keys, aggCols, child)
+ // SPARK-22951: Physical aggregate operators distinguish global aggregations from grouping
+ // aggregations by checking the number of grouping keys. The key difference is that a
+ // global aggregation always returns at least one row even if there are no input rows. Here
+ // we append a literal when the grouping key list is empty so that the resulting aggregate
+ // operator is properly treated as a grouping aggregation.
+ val nonemptyKeys = if (keys.isEmpty) Literal(1) :: Nil else keys
+ Aggregate(nonemptyKeys, aggCols, child)
}
}
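
Several of the hunks above replace span(_.deterministic) with partition(_.deterministic). The difference matters: span keeps only the deterministic prefix of the predicate list, while partition collects every deterministic predicate regardless of position, so more predicates become candidates for push-down. A tiny plain-Scala illustration; the Pred type is hypothetical.

    final case class Pred(sql: String, deterministic: Boolean)

    val predicates = Seq(
      Pred("a > 1", deterministic = true),
      Pred("rand() < 0.5", deterministic = false),
      Pred("b = 2", deterministic = true))

    // span: only the deterministic prefix; "b = 2" is stuck behind the non-deterministic predicate.
    val (prefixOnly, _) = predicates.span(_.deterministic)

    // partition: every deterministic predicate; only "rand() < 0.5" stays up.
    val (allDeterministic, nonDeterministic) = predicates.partition(_.deterministic)

    assert(prefixOnly.map(_.sql) == Seq("a > 1"))
    assert(allDeterministic.map(_.sql) == Seq("a > 1", "b = 2"))
    assert(nonDeterministic.map(_.sql) == Seq("rand() < 0.5"))
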
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
index 52fbb4df2f58e..a6e5aa6daca65 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
@@ -41,6 +41,10 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper {
private def empty(plan: LogicalPlan) =
LocalRelation(plan.output, data = Seq.empty, isStreaming = plan.isStreaming)
+ // Construct a project list from the plan's output, with every value replaced by NULL.
+ private def nullValueProjectList(plan: LogicalPlan): Seq[NamedExpression] =
+ plan.output.map{ a => Alias(Literal(null), a.name)(a.exprId) }
+
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case p: Union if p.children.forall(isEmptyLocalRelation) =>
empty(p)
@@ -49,16 +53,28 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper {
// as stateful streaming joins need to perform other state management operations other than
// just processing the input data.
case p @ Join(_, _, joinType, _)
- if !p.children.exists(_.isStreaming) && p.children.exists(isEmptyLocalRelation) =>
- joinType match {
- case _: InnerLike => empty(p)
- // Intersect is handled as LeftSemi by `ReplaceIntersectWithSemiJoin` rule.
- // Except is handled as LeftAnti by `ReplaceExceptWithAntiJoin` rule.
- case LeftOuter | LeftSemi | LeftAnti if isEmptyLocalRelation(p.left) => empty(p)
- case RightOuter if isEmptyLocalRelation(p.right) => empty(p)
- case FullOuter if p.children.forall(isEmptyLocalRelation) => empty(p)
- case _ => p
- }
+ if !p.children.exists(_.isStreaming) =>
+ val isLeftEmpty = isEmptyLocalRelation(p.left)
+ val isRightEmpty = isEmptyLocalRelation(p.right)
+ if (isLeftEmpty || isRightEmpty) {
+ joinType match {
+ case _: InnerLike => empty(p)
+ // Intersect is handled as LeftSemi by `ReplaceIntersectWithSemiJoin` rule.
+ // Except is handled as LeftAnti by `ReplaceExceptWithAntiJoin` rule.
+ case LeftOuter | LeftSemi | LeftAnti if isLeftEmpty => empty(p)
+ case LeftSemi if isRightEmpty => empty(p)
+ case LeftAnti if isRightEmpty => p.left
+ case FullOuter if isLeftEmpty && isRightEmpty => empty(p)
+ case LeftOuter | FullOuter if isRightEmpty =>
+ Project(p.left.output ++ nullValueProjectList(p.right), p.left)
+ case RightOuter if isRightEmpty => empty(p)
+ case RightOuter | FullOuter if isLeftEmpty =>
+ Project(nullValueProjectList(p.left) ++ p.right.output, p.right)
+ case _ => p
+ }
+ } else {
+ p
+ }
case p: UnaryNode if p.children.nonEmpty && p.children.forall(isEmptyLocalRelation) => p match {
case _: Project => empty(p)
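
The expanded rule now also handles joins where only one side is empty, for example turning a left outer join against an empty right side into a projection that appends NULL for the right-side columns. A hedged usage sketch follows; it assumes a local Spark build that includes this change, and the exact optimized plan text may differ across versions.

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[1]").appName("empty-relation-demo").getOrCreate()

    val left  = spark.range(3).toDF("a")
    val right = spark.range(3).toDF("b").filter("false")   // optimizes to an empty LocalRelation

    // With an empty right side, the left outer join is expected to simplify to a Project over
    // the left side with NULL for the right-side column, instead of remaining a Join.
    left.join(right, left("a") === right("b"), "left_outer").explain(true)

    spark.stop()
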
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 785e815b41185..1c0b7bd806801 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -21,6 +21,7 @@ import scala.collection.immutable.HashSet
import scala.collection.mutable.{ArrayBuffer, Stack}
import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -64,49 +65,89 @@ object ConstantFolding extends Rule[LogicalPlan] {
* }}}
*
* Approach used:
- * - Start from AND operator as the root
- * - Get all the children conjunctive predicates which are EqualTo / EqualNullSafe such that they
- * don't have a `NOT` or `OR` operator in them
* - Populate a mapping of attribute => constant value by looking at all the equals predicates
* - Using this mapping, replace occurrence of the attributes with the corresponding constant values
* in the AND node.
*/
object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper {
- private def containsNonConjunctionPredicates(expression: Expression): Boolean = expression.find {
- case _: Not | _: Or => true
- case _ => false
- }.isDefined
-
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case f: Filter => f transformExpressionsUp {
- case and: And =>
- val conjunctivePredicates =
- splitConjunctivePredicates(and)
- .filter(expr => expr.isInstanceOf[EqualTo] || expr.isInstanceOf[EqualNullSafe])
- .filterNot(expr => containsNonConjunctionPredicates(expr))
-
- val equalityPredicates = conjunctivePredicates.collect {
- case e @ EqualTo(left: AttributeReference, right: Literal) => ((left, right), e)
- case e @ EqualTo(left: Literal, right: AttributeReference) => ((right, left), e)
- case e @ EqualNullSafe(left: AttributeReference, right: Literal) => ((left, right), e)
- case e @ EqualNullSafe(left: Literal, right: AttributeReference) => ((right, left), e)
- }
+ case f: Filter =>
+ val (newCondition, _) = traverse(f.condition, replaceChildren = true)
+ if (newCondition.isDefined) {
+ f.copy(condition = newCondition.get)
+ } else {
+ f
+ }
+ }
- val constantsMap = AttributeMap(equalityPredicates.map(_._1))
- val predicates = equalityPredicates.map(_._2).toSet
+ type EqualityPredicates = Seq[((AttributeReference, Literal), BinaryComparison)]
- def replaceConstants(expression: Expression) = expression transform {
- case a: AttributeReference =>
- constantsMap.get(a) match {
- case Some(literal) => literal
- case None => a
- }
+ /**
+ * Traverse a condition as a tree and replace attributes with constant values.
+ * - On matching [[And]], recursively traverse each child and get the propagated mappings.
+ * If the current node is not a child of another [[And]], replace all occurrences of the
+ * attributes with the corresponding constant values.
+ * - If a child of [[And]] is [[EqualTo]] or [[EqualNullSafe]], propagate the mapping
+ * of attribute => constant.
+ * - On matching [[Or]] or [[Not]], recursively traverse each child and propagate an empty mapping.
+ * - Otherwise, stop the traversal and propagate an empty mapping.
+ * @param condition condition to be traversed
+ * @param replaceChildren whether to replace attributes with constant values in children
+ * @return A tuple including:
+ * 1. Option[Expression]: optional changed condition after traversal
+ * 2. EqualityPredicates: propagated mapping of attribute => constant
+ */
+ private def traverse(condition: Expression, replaceChildren: Boolean)
+ : (Option[Expression], EqualityPredicates) =
+ condition match {
+ case e @ EqualTo(left: AttributeReference, right: Literal) => (None, Seq(((left, right), e)))
+ case e @ EqualTo(left: Literal, right: AttributeReference) => (None, Seq(((right, left), e)))
+ case e @ EqualNullSafe(left: AttributeReference, right: Literal) =>
+ (None, Seq(((left, right), e)))
+ case e @ EqualNullSafe(left: Literal, right: AttributeReference) =>
+ (None, Seq(((right, left), e)))
+ case a: And =>
+ val (newLeft, equalityPredicatesLeft) = traverse(a.left, replaceChildren = false)
+ val (newRight, equalityPredicatesRight) = traverse(a.right, replaceChildren = false)
+ val equalityPredicates = equalityPredicatesLeft ++ equalityPredicatesRight
+ val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) {
+ Some(And(replaceConstants(newLeft.getOrElse(a.left), equalityPredicates),
+ replaceConstants(newRight.getOrElse(a.right), equalityPredicates)))
+ } else {
+ if (newLeft.isDefined || newRight.isDefined) {
+ Some(And(newLeft.getOrElse(a.left), newRight.getOrElse(a.right)))
+ } else {
+ None
+ }
}
-
- and transform {
- case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants(e)
- case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants(e)
+ (newSelf, equalityPredicates)
+ case o: Or =>
+ // Ignore the EqualityPredicates from children since they are only propagated through And.
+ val (newLeft, _) = traverse(o.left, replaceChildren = true)
+ val (newRight, _) = traverse(o.right, replaceChildren = true)
+ val newSelf = if (newLeft.isDefined || newRight.isDefined) {
+ Some(Or(left = newLeft.getOrElse(o.left), right = newRight.getOrElse(o.right)))
+ } else {
+ None
}
+ (newSelf, Seq.empty)
+ case n: Not =>
+ // Ignore the EqualityPredicates from children since they are only propagated through And.
+ val (newChild, _) = traverse(n.child, replaceChildren = true)
+ (newChild.map(Not), Seq.empty)
+ case _ => (None, Seq.empty)
+ }
+
+ private def replaceConstants(condition: Expression, equalityPredicates: EqualityPredicates)
+ : Expression = {
+ val constantsMap = AttributeMap(equalityPredicates.map(_._1))
+ val predicates = equalityPredicates.map(_._2).toSet
+ def replaceConstants0(expression: Expression) = expression transform {
+ case a: AttributeReference => constantsMap.getOrElse(a, a)
+ }
+ condition transform {
+ case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants0(e)
+ case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants0(e)
}
}
}
@@ -465,18 +506,21 @@ object NullPropagation extends Rule[LogicalPlan] {
/**
- * Propagate foldable expressions:
* Replace attributes with aliases of the original foldable expressions if possible.
- * Other optimizations will take advantage of the propagated foldable expressions.
- *
+ * Other optimizations will take advantage of the propagated foldable expressions. For example,
+ * this rule can optimize
* {{{
* SELECT 1.0 x, 'abc' y, Now() z ORDER BY x, y, 3
- * ==> SELECT 1.0 x, 'abc' y, Now() z ORDER BY 1.0, 'abc', Now()
* }}}
+ * to
+ * {{{
+ * SELECT 1.0 x, 'abc' y, Now() z ORDER BY 1.0, 'abc', Now()
+ * }}}
+ * and other rules can further optimize it and remove the ORDER BY operator.
*/
object FoldablePropagation extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
- val foldableMap = AttributeMap(plan.flatMap {
+ var foldableMap = AttributeMap(plan.flatMap {
case Project(projectList, _) => projectList.collect {
case a: Alias if a.child.foldable => (a.toAttribute, a)
}
@@ -489,38 +533,44 @@ object FoldablePropagation extends Rule[LogicalPlan] {
if (foldableMap.isEmpty) {
plan
} else {
- var stop = false
CleanupAliases(plan.transformUp {
- // A leaf node should not stop the folding process (note that we are traversing up the
- // tree, starting at the leaf nodes); so we are allowing it.
- case l: LeafNode =>
- l
-
// We can only propagate foldables for a subset of unary nodes.
- case u: UnaryNode if !stop && canPropagateFoldables(u) =>
+ case u: UnaryNode if foldableMap.nonEmpty && canPropagateFoldables(u) =>
u.transformExpressions(replaceFoldable)
- // Allow inner joins. We do not allow outer join, although its output attributes are
- // derived from its children, they are actually different attributes: the output of outer
- // join is not always picked from its children, but can also be null.
+ // Join derives its output attributes from its children, but they are not always the
+ // same attributes: for example, the output of an outer join is not always picked from its
+ // children and can also be null. We should exclude these mis-derived attributes when
+ // propagating the foldable expressions.
// TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes
// of outer join.
- case j @ Join(_, _, Inner, _) if !stop =>
- j.transformExpressions(replaceFoldable)
-
- // We can fold the projections an expand holds. However expand changes the output columns
- // and often reuses the underlying attributes; so we cannot assume that a column is still
- // foldable after the expand has been applied.
- // TODO(hvanhovell): Expand should use new attributes as the output attributes.
- case expand: Expand if !stop =>
- val newExpand = expand.copy(projections = expand.projections.map { projection =>
+ case j @ Join(left, right, joinType, _) if foldableMap.nonEmpty =>
+ val newJoin = j.transformExpressions(replaceFoldable)
+ val missDerivedAttrsSet: AttributeSet = AttributeSet(joinType match {
+ case _: InnerLike | LeftExistence(_) => Nil
+ case LeftOuter => right.output
+ case RightOuter => left.output
+ case FullOuter => left.output ++ right.output
+ })
+ foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot {
+ case (attr, _) => missDerivedAttrsSet.contains(attr)
+ }.toSeq)
+ newJoin
+
+ // We cannot replace the attributes in `Expand.output`. If there are other non-leaf
+ // operators that have the `output` field, we should put them here too.
+ case expand: Expand if foldableMap.nonEmpty =>
+ expand.copy(projections = expand.projections.map { projection =>
projection.map(_.transform(replaceFoldable))
})
- stop = true
- newExpand
- case other =>
- stop = true
+ // Other plans are not safe for foldable propagation, and they should not propagate
+ // foldable expressions from their children.
+ case other if foldableMap.nonEmpty =>
+ val childrenOutputSet = AttributeSet(other.children.flatMap(_.output))
+ foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot {
+ case (attr, _) => childrenOutputSet.contains(attr)
+ }.toSeq)
other
})
}
@@ -574,7 +624,6 @@ object SimplifyCasts extends Rule[LogicalPlan] {
object RemoveDispensableExpressions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case UnaryPositive(child) => child
- case PromotePrecision(child) => child
}
}
@@ -606,6 +655,12 @@ object CombineConcats extends Rule[LogicalPlan] {
stack.pop() match {
case Concat(children) =>
stack.pushAll(children.reverse)
+ // If `spark.sql.function.concatBinaryAsString` is false, nested `Concat` exprs may
+ // have `Concat`s with binary output. Since `TypeCoercion` casts them into strings,
+ // we need to handle this case in order to combine all nested `Concat`s.
+ case c @ Cast(Concat(children), StringType, _) =>
+ val newChildren = children.map { e => c.copy(child = e) }
+ stack.pushAll(newChildren.reverse)
case child =>
flattened += child
}
@@ -613,8 +668,14 @@ object CombineConcats extends Rule[LogicalPlan] {
Concat(flattened)
}
+ private def hasNestedConcats(concat: Concat): Boolean = concat.children.exists {
+ case c: Concat => true
+ case c @ Cast(Concat(children), StringType, _) => true
+ case _ => false
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan.transformExpressionsDown {
- case concat: Concat if concat.children.exists(_.isInstanceOf[Concat]) =>
+ case concat: Concat if hasNestedConcats(concat) =>
flattenConcats(concat)
}
}
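
The new traverse/replaceConstants pair generalizes the old single-And rewrite to arbitrarily nested And/Or/Not trees. The core substitution step is: collect attribute = literal equalities from a conjunction, then substitute the literal for the attribute in the sibling predicates while leaving the defining equalities alone. Below is a plain-Scala sketch of that step over a hypothetical mini expression type, not Catalyst's.

    sealed trait Expr
    final case class Attr(name: String)    extends Expr
    final case class Lit(value: Int)       extends Expr
    final case class Eq(l: Expr, r: Expr)  extends Expr
    final case class Gt(l: Expr, r: Expr)  extends Expr
    final case class And(l: Expr, r: Expr) extends Expr

    def conjuncts(e: Expr): Seq[Expr] = e match {
      case And(l, r) => conjuncts(l) ++ conjuncts(r)
      case other     => Seq(other)
    }

    // Collect attr -> literal mappings from the equality conjuncts, then substitute them into
    // every other predicate; the defining equalities themselves are kept as-is.
    def propagate(cond: Expr): Expr = {
      val constants: Map[String, Lit] = conjuncts(cond).collect {
        case Eq(Attr(a), lit: Lit) => a -> lit
        case Eq(lit: Lit, Attr(a)) => a -> lit
      }.toMap

      def subst(e: Expr): Expr = e match {
        case eq: Eq    => eq
        case Attr(a)   => constants.getOrElse(a, Attr(a))
        case Gt(l, r)  => Gt(subst(l), subst(r))
        case And(l, r) => And(subst(l), subst(r))
        case other     => other
      }
      subst(cond)
    }

    // a = 1 AND b > a  ==>  a = 1 AND b > 1
    assert(propagate(And(Eq(Attr("a"), Lit(1)), Gt(Attr("b"), Attr("a"))))
      == And(Eq(Attr("a"), Lit(1)), Gt(Attr("b"), Lit(1))))
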
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 64b28565eb27c..2673bea648d09 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -270,7 +270,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
/**
* Pull up the correlated predicates and rewrite all subqueries in an operator tree..
*/
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case f @ Filter(_, a: Aggregate) =>
rewriteSubQueries(f, Seq(a, a.child))
// Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 7651d11ee65a8..bdc357d54a878 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -623,7 +623,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
val expressions = expressionList(ctx.expression)
Generate(
UnresolvedGenerator(visitFunctionName(ctx.qualifiedName), expressions),
- join = true,
+ unrequiredChildIndex = Nil,
outer = ctx.OUTER != null,
Some(ctx.tblName.getText.toLowerCase),
ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.apply),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
index 9b127f91648e6..89347f4b1f7bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
@@ -16,6 +16,8 @@
*/
package org.apache.spark.sql.catalyst.parser
+import java.util
+
import scala.collection.mutable.StringBuilder
import org.antlr.v4.runtime.{ParserRuleContext, Token}
@@ -39,6 +41,13 @@ object ParserUtils {
throw new ParseException(s"Operation not allowed: $message", ctx)
}
+ def checkDuplicateClauses[T](
+ nodes: util.List[T], clauseName: String, ctx: ParserRuleContext): Unit = {
+ if (nodes.size() > 1) {
+ throw new ParseException(s"Found duplicate clauses: $clauseName", ctx)
+ }
+ }
+
/** Check if duplicate keys exist in a set of key-value pairs. */
def checkDuplicateKeys[T](keyPairs: Seq[(String, T)], ctx: ParserRuleContext): Unit = {
keyPairs.groupBy(_._1).filter(_._2.size > 1).foreach { case (key, _) =>
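A simplified sketch of how the new checkDuplicateClauses helper is meant to be used: with a grammar rule such as `(PARTITIONED BY ...)*`, ANTLR hands the visitor a java.util.List of matched sub-contexts, and more than one entry means the clause was written twice. The stand-in below throws IllegalArgumentException instead of the real ParseException and takes no parser context.

object DuplicateClauseSketch {
  import java.util.{Arrays, List => JList}

  // Simplified stand-in: the real ParserUtils.checkDuplicateClauses throws a ParseException
  // that carries the parser context; here we throw IllegalArgumentException instead.
  def checkDuplicateClauses[T](nodes: JList[T], clauseName: String): Unit = {
    if (nodes.size() > 1) {
      throw new IllegalArgumentException(s"Found duplicate clauses: $clauseName")
    }
  }

  def main(args: Array[String]): Unit = {
    checkDuplicateClauses(Arrays.asList("COMMENT 'a'"), "COMMENT")  // passes
    checkDuplicateClauses(Arrays.asList("TBLPROPERTIES (...)", "TBLPROPERTIES (...)"),
      "TBLPROPERTIES")  // throws: Found duplicate clauses: TBLPROPERTIES
  }
}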
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala
index 06196b5afb031..7a927e1e083b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala
@@ -38,7 +38,7 @@ object EventTimeWatermark {
case class EventTimeWatermark(
eventTime: Attribute,
delay: CalendarInterval,
- child: LogicalPlan) extends LogicalPlan {
+ child: LogicalPlan) extends UnaryNode {
// Update the metadata on the eventTime column to include the desired delay.
override val output: Seq[Attribute] = child.output.map { a =>
@@ -60,6 +60,4 @@ case class EventTimeWatermark(
a
}
}
-
- override val children: Seq[LogicalPlan] = child :: Nil
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 14188829db2af..ff2a0ec588567 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -33,58 +33,9 @@ abstract class LogicalPlan
with QueryPlanConstraints
with Logging {
- private var _analyzed: Boolean = false
-
- /**
- * Marks this plan as already analyzed. This should only be called by [[CheckAnalysis]].
- */
- private[catalyst] def setAnalyzed(): Unit = { _analyzed = true }
-
- /**
- * Returns true if this node and its children have already been gone through analysis and
- * verification. Note that this is only an optimization used to avoid analyzing trees that
- * have already been analyzed, and can be reset by transformations.
- */
- def analyzed: Boolean = _analyzed
-
/** Returns true if this subtree has data from a streaming data source. */
def isStreaming: Boolean = children.exists(_.isStreaming == true)
- /**
- * Returns a copy of this node where `rule` has been recursively applied first to all of its
- * children and then itself (post-order). When `rule` does not apply to a given node, it is left
- * unchanged. This function is similar to `transformUp`, but skips sub-trees that have already
- * been marked as analyzed.
- *
- * @param rule the function use to transform this nodes children
- */
- def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = {
- if (!analyzed) {
- val afterRuleOnChildren = mapChildren(_.resolveOperators(rule))
- if (this fastEquals afterRuleOnChildren) {
- CurrentOrigin.withOrigin(origin) {
- rule.applyOrElse(this, identity[LogicalPlan])
- }
- } else {
- CurrentOrigin.withOrigin(origin) {
- rule.applyOrElse(afterRuleOnChildren, identity[LogicalPlan])
- }
- }
- } else {
- this
- }
- }
-
- /**
- * Recursively transforms the expressions of a tree, skipping nodes that have already
- * been analyzed.
- */
- def resolveExpressions(r: PartialFunction[Expression, Expression]): LogicalPlan = {
- this resolveOperators {
- case p => p.transformExpressions(r)
- }
- }
-
override def verboseStringWithSuffix: String = {
super.verboseString + statsCache.map(", " + _.toString).getOrElse("")
}
@@ -296,6 +247,8 @@ abstract class UnaryNode extends LogicalPlan {
protected def getAliasedConstraints(projectList: Seq[NamedExpression]): Set[Expression] = {
var allConstraints = child.constraints.asInstanceOf[Set[Expression]]
projectList.foreach {
+ case a @ Alias(l: Literal, _) =>
+ allConstraints += EqualTo(a.toAttribute, l)
case a @ Alias(e, _) =>
// For every alias in `projectList`, replace the reference in constraints by its attribute.
allConstraints ++= allConstraints.map(_ transform {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
index b0f611fd38dea..9c0a30a47f839 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
@@ -98,7 +98,7 @@ trait QueryPlanConstraints { self: LogicalPlan =>
// we may avoid producing recursive constraints.
private lazy val aliasMap: AttributeMap[Expression] = AttributeMap(
expressions.collect {
- case a: Alias => (a.toAttribute, a.child)
+ case a: Alias if !a.child.isInstanceOf[Literal] => (a.toAttribute, a.child)
} ++ children.flatMap(_.asInstanceOf[QueryPlanConstraints].aliasMap))
// Note: the explicit cast is necessary, since Scala compiler fails to infer the type.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index c2750c3079814..a4fca790dd086 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -22,7 +22,8 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning,
+ RangePartitioning, RoundRobinPartitioning}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.RandomSampler
@@ -72,8 +73,13 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
* their output.
*
* @param generator the generator expression
- * @param join when true, each output row is implicitly joined with the input tuple that produced
- * it.
+ * @param unrequiredChildIndex this parameter starts as Nil and gets filled by the Optimizer.
+ * It's used as an optimization to omit data generation that will
+ * be discarded by the next projection.
+ * A common use case is explode(array(..)) where we are interested
+ * only in the exploded data and not in the original array. Before
+ * this optimization the array was duplicated for each of its
+ * elements, causing O(n^2) memory consumption (see [SPARK-21657]).
* @param outer when true, each input row will be output at least once, even if the output of the
* given `generator` is empty.
* @param qualifier Qualifier for the attributes of generator(UDTF)
@@ -82,15 +88,17 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
*/
case class Generate(
generator: Generator,
- join: Boolean,
+ unrequiredChildIndex: Seq[Int],
outer: Boolean,
qualifier: Option[String],
generatorOutput: Seq[Attribute],
child: LogicalPlan)
extends UnaryNode {
- /** The set of all attributes produced by this node. */
- def generatedSet: AttributeSet = AttributeSet(generatorOutput)
+ lazy val requiredChildOutput: Seq[Attribute] = {
+ val unrequiredSet = unrequiredChildIndex.toSet
+ child.output.zipWithIndex.filterNot(t => unrequiredSet.contains(t._2)).map(_._1)
+ }
override lazy val resolved: Boolean = {
generator.resolved &&
@@ -113,9 +121,7 @@ case class Generate(
nullableOutput
}
- def output: Seq[Attribute] = {
- if (join) child.output ++ qualifiedGeneratorOutput else qualifiedGeneratorOutput
- }
+ def output: Seq[Attribute] = requiredChildOutput ++ qualifiedGeneratorOutput
}
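A minimal sketch of how unrequiredChildIndex prunes the child columns flowing through Generate, with hypothetical column names standing in for Catalyst attributes.

object UnrequiredChildIndexSketch {
  // Hypothetical child output of a Generate over explode(arr): just column names here.
  val childOutput = Seq("id", "arr")

  // The optimizer notices that nothing downstream needs the original array column, so it
  // marks index 1 as unrequired; previously the array was duplicated per exploded element.
  val unrequiredChildIndex = Seq(1)

  def requiredChildOutput(output: Seq[String], unrequired: Seq[Int]): Seq[String] = {
    val unrequiredSet = unrequired.toSet
    output.zipWithIndex.filterNot { case (_, i) => unrequiredSet.contains(i) }.map(_._1)
  }

  def main(args: Array[String]): Unit = {
    // Generate.output = requiredChildOutput ++ qualifiedGeneratorOutput
    println(requiredChildOutput(childOutput, unrequiredChildIndex) ++ Seq("col")) // List(id, col)
  }
}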
case class Filter(condition: Expression, child: LogicalPlan)
@@ -838,6 +844,27 @@ case class RepartitionByExpression(
require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")
+ val partitioning: Partitioning = {
+ val (sortOrder, nonSortOrder) = partitionExpressions.partition(_.isInstanceOf[SortOrder])
+
+ require(sortOrder.isEmpty || nonSortOrder.isEmpty,
+ s"${getClass.getSimpleName} expects that either all its `partitionExpressions` are of type " +
+ "`SortOrder`, which means `RangePartitioning`, or none of them are `SortOrder`, which " +
+ "means `HashPartitioning`. In this case we have:" +
+ s"""
+ |SortOrder: $sortOrder
+ |NonSortOrder: $nonSortOrder
+ """.stripMargin)
+
+ if (sortOrder.nonEmpty) {
+ RangePartitioning(sortOrder.map(_.asInstanceOf[SortOrder]), numPartitions)
+ } else if (nonSortOrder.nonEmpty) {
+ HashPartitioning(nonSortOrder, numPartitions)
+ } else {
+ RoundRobinPartitioning(numPartitions)
+ }
+ }
+
override def maxRows: Option[Long] = child.maxRows
override def shuffle: Boolean = true
}
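A standalone sketch of the partitioning choice above, with simplified stand-ins for Catalyst's SortOrder and the partitioning classes: all-SortOrder expressions give range partitioning, all-non-SortOrder expressions give hash partitioning, and no expressions give round-robin.

object RepartitionPartitioningSketch {
  sealed trait Expr
  case class SortOrder(col: String) extends Expr
  case class Col(name: String) extends Expr

  sealed trait Partitioning
  case class RangePartitioning(ordering: Seq[SortOrder], n: Int) extends Partitioning
  case class HashPartitioning(exprs: Seq[Expr], n: Int) extends Partitioning
  case class RoundRobinPartitioning(n: Int) extends Partitioning

  def choosePartitioning(exprs: Seq[Expr], numPartitions: Int): Partitioning = {
    val (sortOrder, nonSortOrder) = exprs.partition(_.isInstanceOf[SortOrder])
    require(sortOrder.isEmpty || nonSortOrder.isEmpty,
      "partition expressions must be all SortOrder (range) or none SortOrder (hash)")
    if (sortOrder.nonEmpty) {
      RangePartitioning(sortOrder.map(_.asInstanceOf[SortOrder]), numPartitions)
    } else if (nonSortOrder.nonEmpty) {
      HashPartitioning(nonSortOrder, numPartitions)
    } else {
      RoundRobinPartitioning(numPartitions) // df.repartition(n) with no expressions
    }
  }

  def main(args: Array[String]): Unit = {
    println(choosePartitioning(Seq(SortOrder("a")), 4))  // RangePartitioning
    println(choosePartitioning(Seq(Col("a")), 4))        // HashPartitioning
    println(choosePartitioning(Nil, 4))                  // RoundRobinPartitioning
  }
}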
@@ -861,3 +888,23 @@ case class Deduplicate(
override def output: Seq[Attribute] = child.output
}
+
+/**
+ * A logical plan for setting a barrier of analysis.
+ *
+ * The SQL Analyzer goes through a whole query plan even if most of it has already been analyzed.
+ * This increases the time spent on query analysis, especially for long pipelines in ML.
+ *
+ * This logical plan wraps an analyzed logical plan to prevent it from being analyzed again. The
+ * barrier is applied to the analyzed logical plan in Dataset. It won't change the output of the
+ * wrapped logical plan; it just acts as a wrapper to hide it from the analyzer. New operations on
+ * the dataset will be put on top of the barrier, so only the newly created nodes will be analyzed.
+ *
+ * This analysis barrier will be removed at the end of analysis stage.
+ */
+case class AnalysisBarrier(child: LogicalPlan) extends LeafNode {
+ override protected def innerChildren: Seq[LogicalPlan] = Seq(child)
+ override def output: Seq[Attribute] = child.output
+ override def isStreaming: Boolean = child.isStreaming
+ override def doCanonicalize(): LogicalPlan = child.canonicalized
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
index 9c34a9b7aa756..d793f77413d18 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
@@ -17,10 +17,11 @@
package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
+import scala.collection.mutable.ArrayBuffer
import scala.math.BigDecimal.RoundingMode
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.{DecimalType, _}
@@ -88,30 +89,296 @@ object EstimationUtils {
}
/**
- * For simplicity we use Decimal to unify operations for data types whose min/max values can be
+ * For simplicity we use Double to unify operations for data types whose min/max values can be
* represented as numbers, e.g. Boolean can be represented as 0 (false) or 1 (true).
* The two methods below are the contract of conversion.
*/
- def toDecimal(value: Any, dataType: DataType): Decimal = {
+ def toDouble(value: Any, dataType: DataType): Double = {
dataType match {
- case _: NumericType | DateType | TimestampType => Decimal(value.toString)
- case BooleanType => if (value.asInstanceOf[Boolean]) Decimal(1) else Decimal(0)
+ case _: NumericType | DateType | TimestampType => value.toString.toDouble
+ case BooleanType => if (value.asInstanceOf[Boolean]) 1 else 0
}
}
- def fromDecimal(dec: Decimal, dataType: DataType): Any = {
+ def fromDouble(double: Double, dataType: DataType): Any = {
dataType match {
- case BooleanType => dec.toLong == 1
- case DateType => dec.toInt
- case TimestampType => dec.toLong
- case ByteType => dec.toByte
- case ShortType => dec.toShort
- case IntegerType => dec.toInt
- case LongType => dec.toLong
- case FloatType => dec.toFloat
- case DoubleType => dec.toDouble
- case _: DecimalType => dec
+ case BooleanType => double.toInt == 1
+ case DateType => double.toInt
+ case TimestampType => double.toLong
+ case ByteType => double.toByte
+ case ShortType => double.toShort
+ case IntegerType => double.toInt
+ case LongType => double.toLong
+ case FloatType => double.toFloat
+ case DoubleType => double
+ case _: DecimalType => Decimal(double)
}
}
+ /**
+ * Returns the index of the first bin into which the given value falls for a specified
+ * numeric equi-height histogram.
+ */
+ private def findFirstBinForValue(value: Double, bins: Array[HistogramBin]): Int = {
+ var i = 0
+ while ((i < bins.length) && (value > bins(i).hi)) {
+ i += 1
+ }
+ i
+ }
+
+ /**
+ * Returns the index of the last bin into which the given value falls for a specified
+ * numeric equi-height histogram.
+ */
+ private def findLastBinForValue(value: Double, bins: Array[HistogramBin]): Int = {
+ var i = bins.length - 1
+ while ((i >= 0) && (value < bins(i).lo)) {
+ i -= 1
+ }
+ i
+ }
+
+ /**
+ * Returns the fraction of the given histogram bin that holds values within the given range
+ * [lowerBound, upperBound].
+ */
+ private def binHoldingRangePossibility(
+ upperBound: Double,
+ lowerBound: Double,
+ bin: HistogramBin): Double = {
+ assert(bin.lo <= lowerBound && lowerBound <= upperBound && upperBound <= bin.hi)
+ if (bin.hi == bin.lo) {
+ // the entire bin is covered in the range
+ 1.0
+ } else if (upperBound == lowerBound) {
+ // set percentage to 1/NDV
+ 1.0 / bin.ndv.toDouble
+ } else {
+ // Use proration since the range falls inside this bin.
+ math.min((upperBound - lowerBound) / (bin.hi - bin.lo), 1.0)
+ }
+ }
+
+ /**
+ * Returns the number of histogram bins holding values within the given range
+ * [lowerBound, upperBound].
+ *
+ * Note that the returned value is of double type, because the range boundaries usually occupy
+ * a portion of a bin. An extreme case is [value, value], generated by the equality predicate
+ * `col = value`; we get higher accuracy by allowing a fractional number of bins to be returned.
+ *
+ * @param upperBound the highest value of the given range
+ * @param upperBoundInclusive whether the upperBound is included in the range
+ * @param lowerBound the lowest value of the given range
+ * @param lowerBoundInclusive whether the lowerBound is included in the range
+ * @param bins an array of bins for a given numeric equi-height histogram
+ */
+ def numBinsHoldingRange(
+ upperBound: Double,
+ upperBoundInclusive: Boolean,
+ lowerBound: Double,
+ lowerBoundInclusive: Boolean,
+ bins: Array[HistogramBin]): Double = {
+ assert(bins.head.lo <= lowerBound && lowerBound <= upperBound && upperBound <= bins.last.hi,
+ "Given range does not fit in the given histogram.")
+ assert(upperBound != lowerBound || upperBoundInclusive || lowerBoundInclusive,
+ s"'$lowerBound < value < $upperBound' is an invalid range.")
+
+ val upperBinIndex = if (upperBoundInclusive) {
+ findLastBinForValue(upperBound, bins)
+ } else {
+ findFirstBinForValue(upperBound, bins)
+ }
+ val lowerBinIndex = if (lowerBoundInclusive) {
+ findFirstBinForValue(lowerBound, bins)
+ } else {
+ findLastBinForValue(lowerBound, bins)
+ }
+ assert(lowerBinIndex <= upperBinIndex, "Invalid histogram data.")
+
+ if (lowerBinIndex == upperBinIndex) {
+ binHoldingRangePossibility(upperBound, lowerBound, bins(lowerBinIndex))
+ } else {
+ // Computes the occupied portion of bins of the upperBound and lowerBound.
+ val lowerBin = bins(lowerBinIndex)
+ val lowerPart = binHoldingRangePossibility(lowerBin.hi, lowerBound, lowerBin)
+
+ val higherBin = bins(upperBinIndex)
+ val higherPart = binHoldingRangePossibility(upperBound, higherBin.lo, higherBin)
+
+ // The total number of bins is lowerPart + higherPart + bins between them
+ lowerPart + higherPart + upperBinIndex - lowerBinIndex - 1
+ }
+ }
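A self-contained sketch of the same bin counting (both bounds treated as inclusive for brevity), with a simplified Bin type and two worked values.

object NumBinsSketch {
  case class Bin(lo: Double, hi: Double, ndv: Long)

  private def findFirstBin(v: Double, bins: Array[Bin]): Int = {
    var i = 0
    while (i < bins.length && v > bins(i).hi) i += 1
    i
  }

  private def findLastBin(v: Double, bins: Array[Bin]): Int = {
    var i = bins.length - 1
    while (i >= 0 && v < bins(i).lo) i -= 1
    i
  }

  // Fraction of a single bin covered by [lb, ub], assuming bin.lo <= lb <= ub <= bin.hi.
  private def binPortion(ub: Double, lb: Double, bin: Bin): Double = {
    if (bin.hi == bin.lo) 1.0
    else if (ub == lb) 1.0 / bin.ndv
    else math.min((ub - lb) / (bin.hi - bin.lo), 1.0)
  }

  def numBinsHoldingRange(ub: Double, lb: Double, bins: Array[Bin]): Double = {
    val upper = findLastBin(ub, bins)   // inclusive upper bound
    val lower = findFirstBin(lb, bins)  // inclusive lower bound
    if (lower == upper) binPortion(ub, lb, bins(lower))
    else binPortion(bins(lower).hi, lb, bins(lower)) +
      binPortion(ub, bins(upper).lo, bins(upper)) + (upper - lower - 1)
  }

  def main(args: Array[String]): Unit = {
    val bins = Array(Bin(0, 10, 10), Bin(10, 20, 10), Bin(20, 30, 10))
    // Range [5, 25] covers half of the first bin, all of the second, half of the third.
    println(numBinsHoldingRange(25, 5, bins)) // 2.0
    // Point range [15, 15] occupies 1/ndv of the bin it falls into.
    println(numBinsHoldingRange(15, 15, bins)) // 0.1
  }
}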
+
+ /**
+ * Returns overlapped ranges between two histograms, in the given value range
+ * [lowerBound, upperBound].
+ */
+ def getOverlappedRanges(
+ leftHistogram: Histogram,
+ rightHistogram: Histogram,
+ lowerBound: Double,
+ upperBound: Double): Seq[OverlappedRange] = {
+ val overlappedRanges = new ArrayBuffer[OverlappedRange]()
+ // Only bins whose ranges intersect [lowerBound, upperBound] can contribute to the join.
+ val leftBins = leftHistogram.bins
+ .filter(b => b.lo <= upperBound && b.hi >= lowerBound)
+ val rightBins = rightHistogram.bins
+ .filter(b => b.lo <= upperBound && b.hi >= lowerBound)
+
+ leftBins.foreach { lb =>
+ rightBins.foreach { rb =>
+ val (left, leftHeight) = trimBin(lb, leftHistogram.height, lowerBound, upperBound)
+ val (right, rightHeight) = trimBin(rb, rightHistogram.height, lowerBound, upperBound)
+ // Only collect overlapped ranges.
+ if (left.lo <= right.hi && left.hi >= right.lo) {
+ // Determine how the two trimmed bins overlap and prorate their ndv and row counts.
+ val range = if (right.lo >= left.lo && right.hi >= left.hi) {
+ // Case1: the left bin is "smaller" than the right bin
+ // left.lo right.lo left.hi right.hi
+ // --------+------------------+------------+----------------+------->
+ if (left.hi == right.lo) {
+ // The overlapped range has only one value.
+ OverlappedRange(
+ lo = right.lo,
+ hi = right.lo,
+ leftNdv = 1,
+ rightNdv = 1,
+ leftNumRows = leftHeight / left.ndv,
+ rightNumRows = rightHeight / right.ndv
+ )
+ } else {
+ val leftRatio = (left.hi - right.lo) / (left.hi - left.lo)
+ val rightRatio = (left.hi - right.lo) / (right.hi - right.lo)
+ OverlappedRange(
+ lo = right.lo,
+ hi = left.hi,
+ leftNdv = left.ndv * leftRatio,
+ rightNdv = right.ndv * rightRatio,
+ leftNumRows = leftHeight * leftRatio,
+ rightNumRows = rightHeight * rightRatio
+ )
+ }
+ } else if (right.lo <= left.lo && right.hi <= left.hi) {
+ // Case2: the left bin is "larger" than the right bin
+ // right.lo left.lo right.hi left.hi
+ // --------+------------------+------------+----------------+------->
+ if (right.hi == left.lo) {
+ // The overlapped range has only one value.
+ OverlappedRange(
+ lo = right.hi,
+ hi = right.hi,
+ leftNdv = 1,
+ rightNdv = 1,
+ leftNumRows = leftHeight / left.ndv,
+ rightNumRows = rightHeight / right.ndv
+ )
+ } else {
+ val leftRatio = (right.hi - left.lo) / (left.hi - left.lo)
+ val rightRatio = (right.hi - left.lo) / (right.hi - right.lo)
+ OverlappedRange(
+ lo = left.lo,
+ hi = right.hi,
+ leftNdv = left.ndv * leftRatio,
+ rightNdv = right.ndv * rightRatio,
+ leftNumRows = leftHeight * leftRatio,
+ rightNumRows = rightHeight * rightRatio
+ )
+ }
+ } else if (right.lo >= left.lo && right.hi <= left.hi) {
+ // Case3: the left bin contains the right bin
+ // left.lo right.lo right.hi left.hi
+ // --------+------------------+------------+----------------+------->
+ val leftRatio = (right.hi - right.lo) / (left.hi - left.lo)
+ OverlappedRange(
+ lo = right.lo,
+ hi = right.hi,
+ leftNdv = left.ndv * leftRatio,
+ rightNdv = right.ndv,
+ leftNumRows = leftHeight * leftRatio,
+ rightNumRows = rightHeight
+ )
+ } else {
+ assert(right.lo <= left.lo && right.hi >= left.hi)
+ // Case4: the right bin contains the left bin
+ // right.lo left.lo left.hi right.hi
+ // --------+------------------+------------+----------------+------->
+ val rightRatio = (left.hi - left.lo) / (right.hi - right.lo)
+ OverlappedRange(
+ lo = left.lo,
+ hi = left.hi,
+ leftNdv = left.ndv,
+ rightNdv = right.ndv * rightRatio,
+ leftNumRows = leftHeight,
+ rightNumRows = rightHeight * rightRatio
+ )
+ }
+ overlappedRanges += range
+ }
+ }
+ }
+ overlappedRanges
+ }
+
+ /**
+ * Given an original bin and a value range [lowerBound, upperBound], returns the trimmed part
+ * of the bin in that range and its number of rows.
+ * @param bin the input histogram bin.
+ * @param height the number of rows of the given histogram bin inside an equi-height histogram.
+ * @param lowerBound lower bound of the given range.
+ * @param upperBound upper bound of the given range.
+ * @return trimmed part of the given bin and its number of rows.
+ */
+ def trimBin(bin: HistogramBin, height: Double, lowerBound: Double, upperBound: Double)
+ : (HistogramBin, Double) = {
+ val (lo, hi) = if (bin.lo <= lowerBound && bin.hi >= upperBound) {
+ // bin.lo lowerBound upperBound bin.hi
+ // --------+------------------+------------+-------------+------->
+ (lowerBound, upperBound)
+ } else if (bin.lo <= lowerBound && bin.hi >= lowerBound) {
+ // bin.lo lowerBound bin.hi upperBound
+ // --------+------------------+------------+-------------+------->
+ (lowerBound, bin.hi)
+ } else if (bin.lo <= upperBound && bin.hi >= upperBound) {
+ // lowerBound bin.lo upperBound bin.hi
+ // --------+------------------+------------+-------------+------->
+ (bin.lo, upperBound)
+ } else {
+ // lowerBound bin.lo bin.hi upperBound
+ // --------+------------------+------------+-------------+------->
+ assert(bin.lo >= lowerBound && bin.hi <= upperBound)
+ (bin.lo, bin.hi)
+ }
+
+ if (hi == lo) {
+ // Note that bin.hi == bin.lo also falls into this branch.
+ (HistogramBin(lo, hi, 1), height / bin.ndv)
+ } else {
+ assert(bin.hi != bin.lo)
+ val ratio = (hi - lo) / (bin.hi - bin.lo)
+ (HistogramBin(lo, hi, math.ceil(bin.ndv * ratio).toLong), height * ratio)
+ }
+ }
+
+ /**
+ * A join between two equi-height histograms may produce multiple overlapped ranges.
+ * Each overlapped range is produced by a part of one bin in the left histogram and a part of
+ * one bin in the right histogram.
+ * @param lo lower bound of this overlapped range.
+ * @param hi higher bound of this overlapped range.
+ * @param leftNdv ndv in the left part.
+ * @param rightNdv ndv in the right part.
+ * @param leftNumRows number of rows in the left part.
+ * @param rightNumRows number of rows in the right part.
+ */
+ case class OverlappedRange(
+ lo: Double,
+ hi: Double,
+ leftNdv: Double,
+ rightNdv: Double,
+ leftNumRows: Double,
+ rightNumRows: Double)
}
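A minimal sketch of trimBin using a simplified Bin type: the four positional cases above collapse into clipping with max/min once the bin is known to intersect the range, and ndv plus row count are scaled by the surviving fraction of the value range.

object TrimBinSketch {
  case class Bin(lo: Double, hi: Double, ndv: Long)

  // height is the number of rows per bin of the equi-height histogram.
  def trimBin(bin: Bin, height: Double, lowerBound: Double, upperBound: Double): (Bin, Double) = {
    val lo = math.max(bin.lo, lowerBound)
    val hi = math.min(bin.hi, upperBound)
    if (hi == lo) (Bin(lo, hi, 1), height / bin.ndv)
    else {
      val ratio = (hi - lo) / (bin.hi - bin.lo)
      (Bin(lo, hi, math.ceil(bin.ndv * ratio).toLong), height * ratio)
    }
  }

  def main(args: Array[String]): Unit = {
    // A bin covering [0, 10] with 10 distinct values and 100 rows, trimmed to [5, 20]:
    // half of its range survives, so roughly 5 distinct values and 50 rows remain.
    println(trimBin(Bin(0, 10, 10), height = 100.0, lowerBound = 5, upperBound = 20))
    // (Bin(5.0,10.0,5),50.0)
  }
}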
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
index 74820eb97d081..4cc32de2d32d7 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
@@ -23,7 +23,7 @@ import scala.collection.mutable
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
import org.apache.spark.sql.types._
@@ -31,7 +31,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
private val childStats = plan.child.stats
- private val colStatsMap = new ColumnStatsMap(childStats.attributeStats)
+ private val colStatsMap = ColumnStatsMap(childStats.attributeStats)
/**
* Returns an option of Statistics for a Filter logical plan node.
@@ -47,7 +47,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
// Estimate selectivity of this filter predicate, and update column stats if needed.
// For not-supported condition, set filter selectivity to a conservative estimate 100%
- val filterSelectivity = calculateFilterSelectivity(plan.condition).getOrElse(BigDecimal(1))
+ val filterSelectivity = calculateFilterSelectivity(plan.condition).getOrElse(1.0)
val filteredRowCount: BigInt = ceil(BigDecimal(childStats.rowCount.get) * filterSelectivity)
val newColStats = if (filteredRowCount == 0) {
@@ -79,17 +79,16 @@ case class FilterEstimation(plan: Filter) extends Logging {
* @return an optional double value to show the percentage of rows meeting a given condition.
* It returns None if the condition is not supported.
*/
- def calculateFilterSelectivity(condition: Expression, update: Boolean = true)
- : Option[BigDecimal] = {
+ def calculateFilterSelectivity(condition: Expression, update: Boolean = true): Option[Double] = {
condition match {
case And(cond1, cond2) =>
- val percent1 = calculateFilterSelectivity(cond1, update).getOrElse(BigDecimal(1))
- val percent2 = calculateFilterSelectivity(cond2, update).getOrElse(BigDecimal(1))
+ val percent1 = calculateFilterSelectivity(cond1, update).getOrElse(1.0)
+ val percent2 = calculateFilterSelectivity(cond2, update).getOrElse(1.0)
Some(percent1 * percent2)
case Or(cond1, cond2) =>
- val percent1 = calculateFilterSelectivity(cond1, update = false).getOrElse(BigDecimal(1))
- val percent2 = calculateFilterSelectivity(cond2, update = false).getOrElse(BigDecimal(1))
+ val percent1 = calculateFilterSelectivity(cond1, update = false).getOrElse(1.0)
+ val percent2 = calculateFilterSelectivity(cond2, update = false).getOrElse(1.0)
Some(percent1 + percent2 - (percent1 * percent2))
// Not-operator pushdown
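A tiny sketch of the selectivity combination above: conjunctions multiply, disjunctions use inclusion-exclusion under an independence assumption, and an unsupported predicate falls back to the conservative estimate 1.0.

object SelectivityCombineSketch {
  // Selectivities are probabilities in [0, 1]; None means the predicate is not supported.
  def andSelectivity(p1: Option[Double], p2: Option[Double]): Double =
    p1.getOrElse(1.0) * p2.getOrElse(1.0)

  // Inclusion-exclusion, assuming the two predicates are independent.
  def orSelectivity(p1: Option[Double], p2: Option[Double]): Double = {
    val a = p1.getOrElse(1.0)
    val b = p2.getOrElse(1.0)
    a + b - a * b
  }

  def main(args: Array[String]): Unit = {
    println(andSelectivity(Some(0.5), Some(0.2))) // 0.1
    println(orSelectivity(Some(0.5), Some(0.2)))  // 0.6
    // An unsupported side degrades to the conservative estimate 1.0.
    println(andSelectivity(Some(0.5), None))      // 0.5
  }
}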
@@ -131,7 +130,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
* @return an optional double value to show the percentage of rows meeting a given condition.
* It returns None if the condition is not supported.
*/
- def calculateSingleCondition(condition: Expression, update: Boolean): Option[BigDecimal] = {
+ def calculateSingleCondition(condition: Expression, update: Boolean): Option[Double] = {
condition match {
case l: Literal =>
evaluateLiteral(l)
@@ -225,17 +224,17 @@ case class FilterEstimation(plan: Filter) extends Logging {
def evaluateNullCheck(
attr: Attribute,
isNull: Boolean,
- update: Boolean): Option[BigDecimal] = {
+ update: Boolean): Option[Double] = {
if (!colStatsMap.contains(attr)) {
logDebug("[CBO] No statistics for " + attr)
return None
}
val colStat = colStatsMap(attr)
val rowCountValue = childStats.rowCount.get
- val nullPercent: BigDecimal = if (rowCountValue == 0) {
+ val nullPercent: Double = if (rowCountValue == 0) {
0
} else {
- BigDecimal(colStat.nullCount) / BigDecimal(rowCountValue)
+ (BigDecimal(colStat.nullCount) / BigDecimal(rowCountValue)).toDouble
}
if (update) {
@@ -265,13 +264,13 @@ case class FilterEstimation(plan: Filter) extends Logging {
* @param update a boolean flag to specify if we need to update ColumnStat of a given column
* for subsequent conditions
* @return an optional double value to show the percentage of rows meeting a given condition
- * It returns None if no statistics exists for a given column or wrong value.
+ * It returns None if no statistics exist for a given column or the value is invalid.
*/
def evaluateBinary(
op: BinaryComparison,
attr: Attribute,
literal: Literal,
- update: Boolean): Option[BigDecimal] = {
+ update: Boolean): Option[Double] = {
if (!colStatsMap.contains(attr)) {
logDebug("[CBO] No statistics for " + attr)
return None
@@ -305,13 +304,12 @@ case class FilterEstimation(plan: Filter) extends Logging {
def evaluateEquality(
attr: Attribute,
literal: Literal,
- update: Boolean): Option[BigDecimal] = {
+ update: Boolean): Option[Double] = {
if (!colStatsMap.contains(attr)) {
logDebug("[CBO] No statistics for " + attr)
return None
}
val colStat = colStatsMap(attr)
- val ndv = colStat.distinctCount
// decide if the value is in [min, max] of the column.
// We currently don't store min/max for binary/string type.
@@ -332,11 +330,16 @@ case class FilterEstimation(plan: Filter) extends Logging {
colStatsMap.update(attr, newStats)
}
- Some(1.0 / BigDecimal(ndv))
- } else {
+ if (colStat.histogram.isEmpty) {
+ // returns 1/ndv if there is no histogram
+ Some(1.0 / colStat.distinctCount.toDouble)
+ } else {
+ Some(computeEqualityPossibilityByHistogram(literal, colStat))
+ }
+
+ } else { // not in interval
Some(0.0)
}
-
}
/**
@@ -349,7 +352,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
* @param literal a literal value (or constant)
* @return an optional double value to show the percentage of rows meeting a given condition
*/
- def evaluateLiteral(literal: Literal): Option[BigDecimal] = {
+ def evaluateLiteral(literal: Literal): Option[Double] = {
literal match {
case Literal(null, _) => Some(0.0)
case FalseLiteral => Some(0.0)
@@ -374,7 +377,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
def evaluateInSet(
attr: Attribute,
hSet: Set[Any],
- update: Boolean): Option[BigDecimal] = {
+ update: Boolean): Option[Double] = {
if (!colStatsMap.contains(attr)) {
logDebug("[CBO] No statistics for " + attr)
return None
@@ -398,8 +401,8 @@ case class FilterEstimation(plan: Filter) extends Logging {
return Some(0.0)
}
- val newMax = validQuerySet.maxBy(EstimationUtils.toDecimal(_, dataType))
- val newMin = validQuerySet.minBy(EstimationUtils.toDecimal(_, dataType))
+ val newMax = validQuerySet.maxBy(EstimationUtils.toDouble(_, dataType))
+ val newMin = validQuerySet.minBy(EstimationUtils.toDouble(_, dataType))
// newNdv should not be greater than the old ndv. For example, column has only 2 values
// 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5.
newNdv = ndv.min(BigInt(validQuerySet.size))
@@ -420,7 +423,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
// return the filter selectivity. Without advanced statistics such as histograms,
// we have to assume uniform distribution.
- Some((BigDecimal(newNdv) / BigDecimal(ndv)).min(1.0))
+ Some(math.min(newNdv.toDouble / ndv.toDouble, 1.0))
}
/**
@@ -438,21 +441,17 @@ case class FilterEstimation(plan: Filter) extends Logging {
op: BinaryComparison,
attr: Attribute,
literal: Literal,
- update: Boolean): Option[BigDecimal] = {
+ update: Boolean): Option[Double] = {
val colStat = colStatsMap(attr)
val statsInterval =
ValueInterval(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericValueInterval]
- val max = statsInterval.max.toBigDecimal
- val min = statsInterval.min.toBigDecimal
- val ndv = BigDecimal(colStat.distinctCount)
+ val max = statsInterval.max
+ val min = statsInterval.min
+ val ndv = colStat.distinctCount.toDouble
// determine the overlapping degree between predicate interval and column's interval
- val numericLiteral = if (literal.dataType == BooleanType) {
- if (literal.value.asInstanceOf[Boolean]) BigDecimal(1) else BigDecimal(0)
- } else {
- BigDecimal(literal.value.toString)
- }
+ val numericLiteral = EstimationUtils.toDouble(literal.value, literal.dataType)
val (noOverlap: Boolean, completeOverlap: Boolean) = op match {
case _: LessThan =>
(numericLiteral <= min, numericLiteral > max)
@@ -464,63 +463,65 @@ case class FilterEstimation(plan: Filter) extends Logging {
(numericLiteral > max, numericLiteral <= min)
}
- var percent = BigDecimal(1)
+ var percent = 1.0
if (noOverlap) {
percent = 0.0
} else if (completeOverlap) {
percent = 1.0
} else {
// This is the partial overlap case:
- // Without advanced statistics like histogram, we assume uniform data distribution.
- // We just prorate the adjusted range over the initial range to compute filter selectivity.
- assert(max > min)
- percent = op match {
- case _: LessThan =>
- if (numericLiteral == max) {
- // If the literal value is right on the boundary, we can minus the part of the
- // boundary value (1/ndv).
- 1.0 - 1.0 / ndv
- } else {
- (numericLiteral - min) / (max - min)
- }
- case _: LessThanOrEqual =>
- if (numericLiteral == min) {
- // The boundary value is the only satisfying value.
- 1.0 / ndv
- } else {
- (numericLiteral - min) / (max - min)
- }
- case _: GreaterThan =>
- if (numericLiteral == min) {
- 1.0 - 1.0 / ndv
- } else {
- (max - numericLiteral) / (max - min)
- }
- case _: GreaterThanOrEqual =>
- if (numericLiteral == max) {
- 1.0 / ndv
- } else {
- (max - numericLiteral) / (max - min)
- }
+
+ if (colStat.histogram.isEmpty) {
+ // Without advanced statistics like histogram, we assume uniform data distribution.
+ // We just prorate the adjusted range over the initial range to compute filter selectivity.
+ assert(max > min)
+ percent = op match {
+ case _: LessThan =>
+ if (numericLiteral == max) {
+ // If the literal value is right on the boundary, we can minus the part of the
+ // boundary value (1/ndv).
+ 1.0 - 1.0 / ndv
+ } else {
+ (numericLiteral - min) / (max - min)
+ }
+ case _: LessThanOrEqual =>
+ if (numericLiteral == min) {
+ // The boundary value is the only satisfying value.
+ 1.0 / ndv
+ } else {
+ (numericLiteral - min) / (max - min)
+ }
+ case _: GreaterThan =>
+ if (numericLiteral == min) {
+ 1.0 - 1.0 / ndv
+ } else {
+ (max - numericLiteral) / (max - min)
+ }
+ case _: GreaterThanOrEqual =>
+ if (numericLiteral == max) {
+ 1.0 / ndv
+ } else {
+ (max - numericLiteral) / (max - min)
+ }
+ }
+ } else {
+ percent = computeComparisonPossibilityByHistogram(op, literal, colStat)
}
if (update) {
val newValue = Some(literal.value)
var newMax = colStat.max
var newMin = colStat.min
- var newNdv = ceil(ndv * percent)
- if (newNdv < 1) newNdv = 1
op match {
case _: GreaterThan | _: GreaterThanOrEqual =>
- // If new ndv is 1, then new max must be equal to new min.
- newMin = if (newNdv == 1) newMax else newValue
+ newMin = newValue
case _: LessThan | _: LessThanOrEqual =>
- newMax = if (newNdv == 1) newMin else newValue
+ newMax = newValue
}
- val newStats =
- colStat.copy(distinctCount = newNdv, min = newMin, max = newMax, nullCount = 0)
+ val newStats = colStat.copy(distinctCount = ceil(ndv * percent),
+ min = newMin, max = newMax, nullCount = 0)
colStatsMap.update(attr, newStats)
}
@@ -529,6 +530,93 @@ case class FilterEstimation(plan: Filter) extends Logging {
Some(percent)
}
+ /**
+ * Computes the selectivity of an equality predicate using the column's histogram.
+ */
+ private def computeEqualityPossibilityByHistogram(
+ literal: Literal, colStat: ColumnStat): Double = {
+ val datum = EstimationUtils.toDouble(literal.value, literal.dataType)
+ val histogram = colStat.histogram.get
+
+ // find bins where column's current min and max locate. Note that a column's [min, max]
+ // range may change due to another condition applied earlier.
+ val min = EstimationUtils.toDouble(colStat.min.get, literal.dataType)
+ val max = EstimationUtils.toDouble(colStat.max.get, literal.dataType)
+
+ // compute how many bins the column's current valid range [min, max] occupies.
+ val numBinsHoldingEntireRange = EstimationUtils.numBinsHoldingRange(
+ upperBound = max,
+ upperBoundInclusive = true,
+ lowerBound = min,
+ lowerBoundInclusive = true,
+ histogram.bins)
+
+ val numBinsHoldingDatum = EstimationUtils.numBinsHoldingRange(
+ upperBound = datum,
+ upperBoundInclusive = true,
+ lowerBound = datum,
+ lowerBoundInclusive = true,
+ histogram.bins)
+
+ numBinsHoldingDatum / numBinsHoldingEntireRange
+ }
+
+ /**
+ * Computes the selectivity of a comparison predicate using the column's histogram.
+ */
+ private def computeComparisonPossibilityByHistogram(
+ op: BinaryComparison, literal: Literal, colStat: ColumnStat): Double = {
+ val datum = EstimationUtils.toDouble(literal.value, literal.dataType)
+ val histogram = colStat.histogram.get
+
+ // find bins where column's current min and max locate. Note that a column's [min, max]
+ // range may change due to another condition applied earlier.
+ val min = EstimationUtils.toDouble(colStat.min.get, literal.dataType)
+ val max = EstimationUtils.toDouble(colStat.max.get, literal.dataType)
+
+ // compute how many bins the column's current valid range [min, max] occupies.
+ val numBinsHoldingEntireRange = EstimationUtils.numBinsHoldingRange(
+ max, upperBoundInclusive = true, min, lowerBoundInclusive = true, histogram.bins)
+
+ val numBinsHoldingRange = op match {
+ // LessThan and LessThanOrEqual share the same logic, the only difference is whether to
+ // include the upperBound in the range.
+ case _: LessThan =>
+ EstimationUtils.numBinsHoldingRange(
+ upperBound = datum,
+ upperBoundInclusive = false,
+ lowerBound = min,
+ lowerBoundInclusive = true,
+ histogram.bins)
+ case _: LessThanOrEqual =>
+ EstimationUtils.numBinsHoldingRange(
+ upperBound = datum,
+ upperBoundInclusive = true,
+ lowerBound = min,
+ lowerBoundInclusive = true,
+ histogram.bins)
+
+ // GreaterThan and GreaterThanOrEqual share the same logic, the only difference is whether to
+ // include the lowerBound in the range.
+ case _: GreaterThan =>
+ EstimationUtils.numBinsHoldingRange(
+ upperBound = max,
+ upperBoundInclusive = true,
+ lowerBound = datum,
+ lowerBoundInclusive = false,
+ histogram.bins)
+ case _: GreaterThanOrEqual =>
+ EstimationUtils.numBinsHoldingRange(
+ upperBound = max,
+ upperBoundInclusive = true,
+ lowerBound = datum,
+ lowerBoundInclusive = true,
+ histogram.bins)
+ }
+
+ numBinsHoldingRange / numBinsHoldingEntireRange
+ }
+
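As a worked example (under an assumed equi-height histogram with three bins [0, 10], [10, 20], [20, 30], each with ndv = 10, and the column's current valid range being [0, 30]): the whole range occupies 3.0 bins; `col = 15` occupies 1/ndv = 0.1 bins, so its selectivity is 0.1 / 3.0 ≈ 0.033; `col < 15` occupies 1.0 + 0.5 = 1.5 bins, so its selectivity is 1.5 / 3.0 = 0.5.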
/**
* Returns a percentage of rows meeting a binary comparison expression containing two columns.
* In SQL queries, we also see predicate expressions involving two columns
@@ -547,7 +635,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
op: BinaryComparison,
attrLeft: Attribute,
attrRight: Attribute,
- update: Boolean): Option[BigDecimal] = {
+ update: Boolean): Option[Double] = {
if (!colStatsMap.contains(attrLeft)) {
logDebug("[CBO] No statistics for " + attrLeft)
@@ -630,7 +718,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
)
}
- var percent = BigDecimal(1)
+ var percent = 1.0
if (noOverlap) {
percent = 0.0
} else if (completeOverlap) {
@@ -644,11 +732,9 @@ case class FilterEstimation(plan: Filter) extends Logging {
// Need to adjust new min/max after the filter condition is applied
val ndvLeft = BigDecimal(colStatLeft.distinctCount)
- var newNdvLeft = ceil(ndvLeft * percent)
- if (newNdvLeft < 1) newNdvLeft = 1
+ val newNdvLeft = ceil(ndvLeft * percent)
val ndvRight = BigDecimal(colStatRight.distinctCount)
- var newNdvRight = ceil(ndvRight * percent)
- if (newNdvRight < 1) newNdvRight = 1
+ val newNdvRight = ceil(ndvRight * percent)
var newMaxLeft = colStatLeft.max
var newMinLeft = colStatLeft.min
@@ -784,11 +870,16 @@ case class ColumnStatsMap(originalMap: AttributeMap[ColumnStat]) {
def outputColumnStats(rowsBeforeFilter: BigInt, rowsAfterFilter: BigInt)
: AttributeMap[ColumnStat] = {
val newColumnStats = originalMap.map { case (attr, oriColStat) =>
- // Update ndv based on the overall filter selectivity: scale down ndv if the number of rows
- // decreases; otherwise keep it unchanged.
- val newNdv = EstimationUtils.updateNdv(oldNumRows = rowsBeforeFilter,
- newNumRows = rowsAfterFilter, oldNdv = oriColStat.distinctCount)
val colStat = updatedMap.get(attr.exprId).map(_._2).getOrElse(oriColStat)
+ val newNdv = if (colStat.distinctCount > 1) {
+ // Update ndv based on the overall filter selectivity: scale down ndv if the number of rows
+ // decreases; otherwise keep it unchanged.
+ EstimationUtils.updateNdv(oldNumRows = rowsBeforeFilter,
+ newNumRows = rowsAfterFilter, oldNdv = oriColStat.distinctCount)
+ } else {
+ // no need to scale down since it is already down to 1 (for skewed distribution case)
+ colStat.distinctCount
+ }
attr -> colStat.copy(distinctCount = newNdv)
}
AttributeMap(newColumnStats.toSeq)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
index b073108c26ee5..f0294a4246703 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
@@ -24,7 +24,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Histogram, Join, Statistics}
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
@@ -191,8 +191,19 @@ case class JoinEstimation(join: Join) extends Logging {
val rInterval = ValueInterval(rightKeyStat.min, rightKeyStat.max, rightKey.dataType)
if (ValueInterval.isIntersected(lInterval, rInterval)) {
val (newMin, newMax) = ValueInterval.intersect(lInterval, rInterval, leftKey.dataType)
- val (card, joinStat) = computeByNdv(leftKey, rightKey, newMin, newMax)
- keyStatsAfterJoin += (leftKey -> joinStat, rightKey -> joinStat)
+ val (card, joinStat) = (leftKeyStat.histogram, rightKeyStat.histogram) match {
+ case (Some(l: Histogram), Some(r: Histogram)) =>
+ computeByHistogram(leftKey, rightKey, l, r, newMin, newMax)
+ case _ =>
+ computeByNdv(leftKey, rightKey, newMin, newMax)
+ }
+ keyStatsAfterJoin += (
+ // Histograms are propagated unchanged. During future estimations, they should be
+ // truncated by the updated max/min. In this way, only pointers to the histograms are
+ // propagated, thus reducing memory consumption.
+ leftKey -> joinStat.copy(histogram = leftKeyStat.histogram),
+ rightKey -> joinStat.copy(histogram = rightKeyStat.histogram)
+ )
// Return cardinality estimated from the most selective join keys.
if (card < joinCard) joinCard = card
} else {
@@ -225,6 +236,43 @@ case class JoinEstimation(join: Join) extends Logging {
(ceil(card), newStats)
}
+ /** Compute join cardinality using equi-height histograms. */
+ private def computeByHistogram(
+ leftKey: AttributeReference,
+ rightKey: AttributeReference,
+ leftHistogram: Histogram,
+ rightHistogram: Histogram,
+ newMin: Option[Any],
+ newMax: Option[Any]): (BigInt, ColumnStat) = {
+ val overlappedRanges = getOverlappedRanges(
+ leftHistogram = leftHistogram,
+ rightHistogram = rightHistogram,
+ // Only numeric values have equi-height histograms.
+ lowerBound = newMin.get.toString.toDouble,
+ upperBound = newMax.get.toString.toDouble)
+
+ var card: BigDecimal = 0
+ var totalNdv: Double = 0
+ for (i <- overlappedRanges.indices) {
+ val range = overlappedRanges(i)
+ if (i == 0 || range.hi != overlappedRanges(i - 1).hi) {
+ // If range.hi == overlappedRanges(i - 1).hi, that means the current range has only one
+ // value, and this value is already counted in the previous range. So there is no need to
+ // count it in this range.
+ totalNdv += math.min(range.leftNdv, range.rightNdv)
+ }
+ // Apply the formula in this overlapped range.
+ card += range.leftNumRows * range.rightNumRows / math.max(range.leftNdv, range.rightNdv)
+ }
+
+ val leftKeyStat = leftStats.attributeStats(leftKey)
+ val rightKeyStat = rightStats.attributeStats(rightKey)
+ val newMaxLen = math.min(leftKeyStat.maxLen, rightKeyStat.maxLen)
+ val newAvgLen = (leftKeyStat.avgLen + rightKeyStat.avgLen) / 2
+ val newStats = ColumnStat(ceil(totalNdv), newMin, newMax, 0, newAvgLen, newMaxLen)
+ (ceil(card), newStats)
+ }
+
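A minimal sketch of the per-range accumulation above, with a simplified OverlappedRange: each range contributes rows_left * rows_right / max(ndv_left, ndv_right), and ndv is added only once per distinct upper bound (single-value ranges share the previous range's boundary).

object HistogramJoinCardinalitySketch {
  case class Range(hi: Double, leftNdv: Double, rightNdv: Double,
      leftNumRows: Double, rightNumRows: Double)

  def estimate(ranges: Seq[Range]): (Double, Double) = {
    var card = 0.0
    var totalNdv = 0.0
    ranges.zipWithIndex.foreach { case (r, i) =>
      // Skip ndv accumulation for a single-value range already counted by the previous range.
      if (i == 0 || r.hi != ranges(i - 1).hi) totalNdv += math.min(r.leftNdv, r.rightNdv)
      // Classic equi-join formula applied per overlapped range.
      card += r.leftNumRows * r.rightNumRows / math.max(r.leftNdv, r.rightNdv)
    }
    (card, totalNdv)
  }

  def main(args: Array[String]): Unit = {
    val ranges = Seq(
      Range(hi = 10, leftNdv = 5, rightNdv = 4, leftNumRows = 50, rightNumRows = 40),
      Range(hi = 20, leftNdv = 10, rightNdv = 2, leftNumRows = 100, rightNumRows = 20))
    println(estimate(ranges)) // (600.0,6.0): 50*40/5 + 100*20/10 = 400 + 200; ndv = 4 + 2
  }
}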
/**
* Propagate or update column stats for output attributes.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala
index 0caaf796a3b68..f46b4ed764e27 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala
@@ -26,10 +26,10 @@ trait ValueInterval {
def contains(l: Literal): Boolean
}
-/** For simplicity we use decimal to unify operations of numeric intervals. */
-case class NumericValueInterval(min: Decimal, max: Decimal) extends ValueInterval {
+/** For simplicity we use double to unify operations of numeric intervals. */
+case class NumericValueInterval(min: Double, max: Double) extends ValueInterval {
override def contains(l: Literal): Boolean = {
- val lit = EstimationUtils.toDecimal(l.value, l.dataType)
+ val lit = EstimationUtils.toDouble(l.value, l.dataType)
min <= lit && max >= lit
}
}
@@ -56,8 +56,8 @@ object ValueInterval {
case _ if min.isEmpty || max.isEmpty => new NullValueInterval()
case _ =>
NumericValueInterval(
- min = EstimationUtils.toDecimal(min.get, dataType),
- max = EstimationUtils.toDecimal(max.get, dataType))
+ min = EstimationUtils.toDouble(min.get, dataType),
+ max = EstimationUtils.toDouble(max.get, dataType))
}
def isIntersected(r1: ValueInterval, r2: ValueInterval): Boolean = (r1, r2) match {
@@ -84,8 +84,8 @@ object ValueInterval {
// Choose the maximum of two min values, and the minimum of two max values.
val newMin = if (n1.min <= n2.min) n2.min else n1.min
val newMax = if (n1.max <= n2.max) n1.max else n2.max
- (Some(EstimationUtils.fromDecimal(newMin, dt)),
- Some(EstimationUtils.fromDecimal(newMax, dt)))
+ (Some(EstimationUtils.fromDouble(newMin, dt)),
+ Some(EstimationUtils.fromDouble(newMax, dt)))
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index e57c842ce2a36..0189bd73c56bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -30,18 +30,43 @@ import org.apache.spark.sql.types.{DataType, IntegerType}
* - Intra-partition ordering of data: In this case the distribution describes guarantees made
* about how tuples are distributed within a single partition.
*/
-sealed trait Distribution
+sealed trait Distribution {
+ /**
+ * The required number of partitions for this distribution. If it's None, then any number of
+ * partitions is allowed for this distribution.
+ */
+ def requiredNumPartitions: Option[Int]
+
+ /**
+ * Creates a default partitioning for this distribution that satisfies it while matching the
+ * given number of partitions.
+ */
+ def createPartitioning(numPartitions: Int): Partitioning
+}
/**
* Represents a distribution where no promises are made about co-location of data.
*/
-case object UnspecifiedDistribution extends Distribution
+case object UnspecifiedDistribution extends Distribution {
+ override def requiredNumPartitions: Option[Int] = None
+
+ override def createPartitioning(numPartitions: Int): Partitioning = {
+ throw new IllegalStateException("UnspecifiedDistribution does not have default partitioning.")
+ }
+}
/**
* Represents a distribution that only has a single partition and all tuples of the dataset
* are co-located.
*/
-case object AllTuples extends Distribution
+case object AllTuples extends Distribution {
+ override def requiredNumPartitions: Option[Int] = Some(1)
+
+ override def createPartitioning(numPartitions: Int): Partitioning = {
+ assert(numPartitions == 1, "The default partitioning of AllTuples can only have 1 partition.")
+ SinglePartition
+ }
+}
/**
* Represents data where tuples that share the same values for the `clustering`
@@ -51,12 +76,41 @@ case object AllTuples extends Distribution
*/
case class ClusteredDistribution(
clustering: Seq[Expression],
- numPartitions: Option[Int] = None) extends Distribution {
+ requiredNumPartitions: Option[Int] = None) extends Distribution {
require(
clustering != Nil,
"The clustering expressions of a ClusteredDistribution should not be Nil. " +
"An AllTuples should be used to represent a distribution that only has " +
"a single partition.")
+
+ override def createPartitioning(numPartitions: Int): Partitioning = {
+ assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions,
+ s"This ClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " +
+ s"the actual number of partitions is $numPartitions.")
+ HashPartitioning(clustering, numPartitions)
+ }
+}
+
+/**
+ * Represents data where tuples have been clustered according to the hash of the given
+ * `expressions`. The hash function is defined as `HashPartitioning.partitionIdExpression`, so only
+ * [[HashPartitioning]] can satisfy this distribution.
+ *
+ * This is a strictly stronger guarantee than [[ClusteredDistribution]]. Given a tuple and the
+ * number of partitions, this distribution strictly determines which partition the tuple must be in.
+ */
+case class HashClusteredDistribution(expressions: Seq[Expression]) extends Distribution {
+ require(
+ expressions != Nil,
+ "The expressions for hash of a HashPartitionedDistribution should not be Nil. " +
+ "An AllTuples should be used to represent a distribution that only has " +
+ "a single partition.")
+
+ override def requiredNumPartitions: Option[Int] = None
+
+ override def createPartitioning(numPartitions: Int): Partitioning = {
+ HashPartitioning(expressions, numPartitions)
+ }
}
/**
@@ -73,46 +127,31 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
"An AllTuples should be used to represent a distribution that only has " +
"a single partition.")
- // TODO: This is not really valid...
- def clustering: Set[Expression] = ordering.map(_.child).toSet
+ override def requiredNumPartitions: Option[Int] = None
+
+ override def createPartitioning(numPartitions: Int): Partitioning = {
+ RangePartitioning(ordering, numPartitions)
+ }
}
/**
* Represents data where tuples are broadcasted to every node. It is quite common that the
* entire set of tuples is transformed into different data structure.
*/
-case class BroadcastDistribution(mode: BroadcastMode) extends Distribution
+case class BroadcastDistribution(mode: BroadcastMode) extends Distribution {
+ override def requiredNumPartitions: Option[Int] = Some(1)
+
+ override def createPartitioning(numPartitions: Int): Partitioning = {
+ assert(numPartitions == 1,
+ "The default partitioning of BroadcastDistribution can only have 1 partition.")
+ BroadcastPartitioning(mode)
+ }
+}
/**
- * Describes how an operator's output is split across partitions. The `compatibleWith`,
- * `guarantees`, and `satisfies` methods describe relationships between child partitionings,
- * target partitionings, and [[Distribution]]s. These relations are described more precisely in
- * their individual method docs, but at a high level:
- *
- * - `satisfies` is a relationship between partitionings and distributions.
- * - `compatibleWith` is relationships between an operator's child output partitionings.
- * - `guarantees` is a relationship between a child's existing output partitioning and a target
- * output partitioning.
- *
- * Diagrammatically:
- *
- * +--------------+
- * | Distribution |
- * +--------------+
- * ^
- * |
- * satisfies
- * |
- * +--------------+ +--------------+
- * | Child | | Target |
- * +----| Partitioning |----guarantees--->| Partitioning |
- * | +--------------+ +--------------+
- * | ^
- * | |
- * | compatibleWith
- * | |
- * +------------+
- *
+ * Describes how an operator's output is split across partitions. It has two major properties:
+ *   1. the number of partitions.
+ *   2. whether it can satisfy a given distribution.
*/
sealed trait Partitioning {
/** Returns the number of partitions that the data is split across */
@@ -123,113 +162,35 @@ sealed trait Partitioning {
* to satisfy the partitioning scheme mandated by the `required` [[Distribution]],
* i.e. the current dataset does not need to be re-partitioned for the `required`
* Distribution (it is possible that tuples within a partition need to be reorganized).
- */
- def satisfies(required: Distribution): Boolean
-
- /**
- * Returns true iff we can say that the partitioning scheme of this [[Partitioning]]
- * guarantees the same partitioning scheme described by `other`.
- *
- * Compatibility of partitionings is only checked for operators that have multiple children
- * and that require a specific child output [[Distribution]], such as joins.
- *
- * Intuitively, partitionings are compatible if they route the same partitioning key to the same
- * partition. For instance, two hash partitionings are only compatible if they produce the same
- * number of output partitionings and hash records according to the same hash function and
- * same partitioning key schema.
- *
- * Put another way, two partitionings are compatible with each other if they satisfy all of the
- * same distribution guarantees.
- */
- def compatibleWith(other: Partitioning): Boolean
-
- /**
- * Returns true iff we can say that the partitioning scheme of this [[Partitioning]] guarantees
- * the same partitioning scheme described by `other`. If a `A.guarantees(B)`, then repartitioning
- * the child's output according to `B` will be unnecessary. `guarantees` is used as a performance
- * optimization to allow the exchange planner to avoid redundant repartitionings. By default,
- * a partitioning only guarantees partitionings that are equal to itself (i.e. the same number
- * of partitions, same strategy (range or hash), etc).
- *
- * In order to enable more aggressive optimization, this strict equality check can be relaxed.
- * For example, say that the planner needs to repartition all of an operator's children so that
- * they satisfy the [[AllTuples]] distribution. One way to do this is to repartition all children
- * to have the [[SinglePartition]] partitioning. If one of the operator's children already happens
- * to be hash-partitioned with a single partition then we do not need to re-shuffle this child;
- * this repartitioning can be avoided if a single-partition [[HashPartitioning]] `guarantees`
- * [[SinglePartition]].
- *
- * The SinglePartition example given above is not particularly interesting; guarantees' real
- * value occurs for more advanced partitioning strategies. SPARK-7871 will introduce a notion
- * of null-safe partitionings, under which partitionings can specify whether rows whose
- * partitioning keys contain null values will be grouped into the same partition or whether they
- * will have an unknown / random distribution. If a partitioning does not require nulls to be
- * clustered then a partitioning which _does_ cluster nulls will guarantee the null clustered
- * partitioning. The converse is not true, however: a partitioning which clusters nulls cannot
- * be guaranteed by one which does not cluster them. Thus, in general `guarantees` is not a
- * symmetric relation.
*
- * Another way to think about `guarantees`: if `A.guarantees(B)`, then any partitioning of rows
- * produced by `A` could have also been produced by `B`.
+ * By default a [[Partitioning]] can satisfy [[UnspecifiedDistribution]], and [[AllTuples]] if
+ * the [[Partitioning]] only has one partition. Implementations can override this method with
+ * special logic.
*/
- def guarantees(other: Partitioning): Boolean = this == other
-}
-
-object Partitioning {
- def allCompatible(partitionings: Seq[Partitioning]): Boolean = {
- // Note: this assumes transitivity
- partitionings.sliding(2).map {
- case Seq(a) => true
- case Seq(a, b) =>
- if (a.numPartitions != b.numPartitions) {
- assert(!a.compatibleWith(b) && !b.compatibleWith(a))
- false
- } else {
- a.compatibleWith(b) && b.compatibleWith(a)
- }
- }.forall(_ == true)
- }
-}
-
-case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
- override def satisfies(required: Distribution): Boolean = required match {
+ def satisfies(required: Distribution): Boolean = required match {
case UnspecifiedDistribution => true
+ case AllTuples => numPartitions == 1
case _ => false
}
-
- override def compatibleWith(other: Partitioning): Boolean = false
-
- override def guarantees(other: Partitioning): Boolean = false
}
+case class UnknownPartitioning(numPartitions: Int) extends Partitioning
+
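As an editorial aside (not part of the patch), a minimal sketch of what the new default buys, assuming the catalyst classes above are on the classpath:

    import org.apache.spark.sql.catalyst.plans.physical._

    // With the base-class satisfies above, trivial partitionings need no override at all.
    UnknownPartitioning(4).satisfies(UnspecifiedDistribution)  // true: any partitioning qualifies
    UnknownPartitioning(1).satisfies(AllTuples)                // true: one partition holds all tuples
    UnknownPartitioning(4).satisfies(AllTuples)                // false: tuples span 4 partitions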
/**
* Represents a partitioning where rows are distributed evenly across output partitions
* by starting from a random target partition number and distributing rows in a round-robin
* fashion. This partitioning is used when implementing the DataFrame.repartition() operator.
*/
-case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning {
- override def satisfies(required: Distribution): Boolean = required match {
- case UnspecifiedDistribution => true
- case _ => false
- }
-
- override def compatibleWith(other: Partitioning): Boolean = false
-
- override def guarantees(other: Partitioning): Boolean = false
-}
+case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning
case object SinglePartition extends Partitioning {
val numPartitions = 1
override def satisfies(required: Distribution): Boolean = required match {
case _: BroadcastDistribution => false
- case ClusteredDistribution(_, desiredPartitions) => desiredPartitions.forall(_ == 1)
+ case ClusteredDistribution(_, Some(requiredNumPartitions)) => requiredNumPartitions == 1
case _ => true
}
-
- override def compatibleWith(other: Partitioning): Boolean = other.numPartitions == 1
-
- override def guarantees(other: Partitioning): Boolean = other.numPartitions == 1
}
/**
@@ -244,22 +205,19 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
override def nullable: Boolean = false
override def dataType: DataType = IntegerType
- override def satisfies(required: Distribution): Boolean = required match {
- case UnspecifiedDistribution => true
- case ClusteredDistribution(requiredClustering, desiredPartitions) =>
- expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) &&
- desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = None, returns true
- case _ => false
- }
-
- override def compatibleWith(other: Partitioning): Boolean = other match {
- case o: HashPartitioning => this.semanticEquals(o)
- case _ => false
- }
-
- override def guarantees(other: Partitioning): Boolean = other match {
- case o: HashPartitioning => this.semanticEquals(o)
- case _ => false
+ override def satisfies(required: Distribution): Boolean = {
+ super.satisfies(required) || {
+ required match {
+ case h: HashClusteredDistribution =>
+ expressions.length == h.expressions.length && expressions.zip(h.expressions).forall {
+ case (l, r) => l.semanticEquals(r)
+ }
+ case ClusteredDistribution(requiredClustering, requiredNumPartitions) =>
+ expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) &&
+ (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions)
+ case _ => false
+ }
+ }
}
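A hedged sketch (editorial, not from the diff) of how the two required distributions differ for a hash partitioning; it assumes the catalyst classes above and uses literal expressions purely for illustration:

    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.plans.physical._

    val p = HashPartitioning(Seq(Literal(1), Literal(2)), numPartitions = 10)

    // ClusteredDistribution only needs the hash keys to appear in the required clustering,
    // so a reordered clustering is still satisfied.
    p.satisfies(ClusteredDistribution(Seq(Literal(2), Literal(1))))      // true

    // HashClusteredDistribution demands the same expressions in the same order, which is
    // what co-partitioned joins rely on.
    p.satisfies(HashClusteredDistribution(Seq(Literal(2), Literal(1))))  // false
    p.satisfies(HashClusteredDistribution(Seq(Literal(1), Literal(2))))  // true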
/**
@@ -288,25 +246,18 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
override def nullable: Boolean = false
override def dataType: DataType = IntegerType
- override def satisfies(required: Distribution): Boolean = required match {
- case UnspecifiedDistribution => true
- case OrderedDistribution(requiredOrdering) =>
- val minSize = Seq(requiredOrdering.size, ordering.size).min
- requiredOrdering.take(minSize) == ordering.take(minSize)
- case ClusteredDistribution(requiredClustering, desiredPartitions) =>
- ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) &&
- desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = None, returns true
- case _ => false
- }
-
- override def compatibleWith(other: Partitioning): Boolean = other match {
- case o: RangePartitioning => this.semanticEquals(o)
- case _ => false
- }
-
- override def guarantees(other: Partitioning): Boolean = other match {
- case o: RangePartitioning => this.semanticEquals(o)
- case _ => false
+ override def satisfies(required: Distribution): Boolean = {
+ super.satisfies(required) || {
+ required match {
+ case OrderedDistribution(requiredOrdering) =>
+ val minSize = Seq(requiredOrdering.size, ordering.size).min
+ requiredOrdering.take(minSize) == ordering.take(minSize)
+ case ClusteredDistribution(requiredClustering, requiredNumPartitions) =>
+ ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) &&
+ (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions)
+ case _ => false
+ }
+ }
}
}
@@ -347,20 +298,6 @@ case class PartitioningCollection(partitionings: Seq[Partitioning])
override def satisfies(required: Distribution): Boolean =
partitionings.exists(_.satisfies(required))
- /**
- * Returns true if any `partitioning` of this collection is compatible with
- * the given [[Partitioning]].
- */
- override def compatibleWith(other: Partitioning): Boolean =
- partitionings.exists(_.compatibleWith(other))
-
- /**
- * Returns true if any `partitioning` of this collection guarantees
- * the given [[Partitioning]].
- */
- override def guarantees(other: Partitioning): Boolean =
- partitionings.exists(_.guarantees(other))
-
override def toString: String = {
partitionings.map(_.toString).mkString("(", " or ", ")")
}
@@ -377,9 +314,4 @@ case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning {
case BroadcastDistribution(m) if m == mode => true
case _ => false
}
-
- override def compatibleWith(other: Partitioning): Boolean = other match {
- case BroadcastPartitioning(m) if m == mode => true
- case _ => false
- }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index 746c3e8950f7b..fa69b8af62c85 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -45,7 +45,8 @@ object DateTimeUtils {
// it's 2440587.5, rounding up to be compatible with Hive
final val JULIAN_DAY_OF_EPOCH = 2440588
final val SECONDS_PER_DAY = 60 * 60 * 24L
- final val MICROS_PER_SECOND = 1000L * 1000L
+ final val MICROS_PER_MILLIS = 1000L
+ final val MICROS_PER_SECOND = MICROS_PER_MILLIS * MILLIS_PER_SECOND
final val MILLIS_PER_SECOND = 1000L
final val NANOS_PER_SECOND = MICROS_PER_SECOND * 1000L
final val MICROS_PER_DAY = MICROS_PER_SECOND * SECONDS_PER_DAY
@@ -61,6 +62,7 @@ object DateTimeUtils {
final val YearZero = -17999
final val toYearZero = to2001 + 7304850
final val TimeZoneGMT = TimeZone.getTimeZone("GMT")
+ final val TimeZoneUTC = TimeZone.getTimeZone("UTC")
final val MonthOf31Days = Set(1, 3, 5, 7, 8, 10, 12)
val TIMEZONE_OPTION = "timeZone"
@@ -908,6 +910,15 @@ object DateTimeUtils {
math.round(diff * 1e8) / 1e8
}
+ // Day-of-week codes start from Thursday = 0, because 1970-01-01 was a Thursday
+ private val SUNDAY = 3
+ private val MONDAY = 4
+ private val TUESDAY = 5
+ private val WEDNESDAY = 6
+ private val THURSDAY = 0
+ private val FRIDAY = 1
+ private val SATURDAY = 2
+
/*
* Returns day of week from String. Starting from Thursday, marked as 0.
* (Because 1970-01-01 is Thursday).
@@ -915,13 +926,13 @@ object DateTimeUtils {
def getDayOfWeekFromString(string: UTF8String): Int = {
val dowString = string.toString.toUpperCase(Locale.ROOT)
dowString match {
- case "SU" | "SUN" | "SUNDAY" => 3
- case "MO" | "MON" | "MONDAY" => 4
- case "TU" | "TUE" | "TUESDAY" => 5
- case "WE" | "WED" | "WEDNESDAY" => 6
- case "TH" | "THU" | "THURSDAY" => 0
- case "FR" | "FRI" | "FRIDAY" => 1
- case "SA" | "SAT" | "SATURDAY" => 2
+ case "SU" | "SUN" | "SUNDAY" => SUNDAY
+ case "MO" | "MON" | "MONDAY" => MONDAY
+ case "TU" | "TUE" | "TUESDAY" => TUESDAY
+ case "WE" | "WED" | "WEDNESDAY" => WEDNESDAY
+ case "TH" | "THU" | "THURSDAY" => THURSDAY
+ case "FR" | "FRI" | "FRIDAY" => FRIDAY
+ case "SA" | "SAT" | "SATURDAY" => SATURDAY
case _ => -1
}
}
@@ -943,9 +954,16 @@ object DateTimeUtils {
date + daysToMonthEnd
}
- private val TRUNC_TO_YEAR = 1
- private val TRUNC_TO_MONTH = 2
- private val TRUNC_INVALID = -1
+ // Visible for testing.
+ private[sql] val TRUNC_TO_YEAR = 1
+ private[sql] val TRUNC_TO_MONTH = 2
+ private[sql] val TRUNC_TO_QUARTER = 3
+ private[sql] val TRUNC_TO_WEEK = 4
+ private[sql] val TRUNC_TO_DAY = 5
+ private[sql] val TRUNC_TO_HOUR = 6
+ private[sql] val TRUNC_TO_MINUTE = 7
+ private[sql] val TRUNC_TO_SECOND = 8
+ private[sql] val TRUNC_INVALID = -1
/**
* Returns the trunc date from original date and trunc level.
@@ -963,7 +981,62 @@ object DateTimeUtils {
}
/**
- * Returns the truncate level, could be TRUNC_YEAR, TRUNC_MONTH, or TRUNC_INVALID,
+ * Returns the truncated timestamp computed from the original timestamp and the trunc level.
+ * The trunc level should be generated using `parseTruncLevel()` and must be between 1 and 8.
+ */
+ def truncTimestamp(t: SQLTimestamp, level: Int, timeZone: TimeZone): SQLTimestamp = {
+ var millis = t / MICROS_PER_MILLIS
+ val truncated = level match {
+ case TRUNC_TO_YEAR =>
+ val dDays = millisToDays(millis, timeZone)
+ daysToMillis(truncDate(dDays, level), timeZone)
+ case TRUNC_TO_MONTH =>
+ val dDays = millisToDays(millis, timeZone)
+ daysToMillis(truncDate(dDays, level), timeZone)
+ case TRUNC_TO_DAY =>
+ val offset = timeZone.getOffset(millis)
+ millis += offset
+ millis - millis % (MILLIS_PER_SECOND * SECONDS_PER_DAY) - offset
+ case TRUNC_TO_HOUR =>
+ val offset = timeZone.getOffset(millis)
+ millis += offset
+ millis - millis % (60 * 60 * MILLIS_PER_SECOND) - offset
+ case TRUNC_TO_MINUTE =>
+ millis - millis % (60 * MILLIS_PER_SECOND)
+ case TRUNC_TO_SECOND =>
+ millis - millis % MILLIS_PER_SECOND
+ case TRUNC_TO_WEEK =>
+ val dDays = millisToDays(millis, timeZone)
+ val prevMonday = getNextDateForDayOfWeek(dDays - 7, MONDAY)
+ daysToMillis(prevMonday, timeZone)
+ case TRUNC_TO_QUARTER =>
+ val dDays = millisToDays(millis, timeZone)
+ millis = daysToMillis(truncDate(dDays, TRUNC_TO_MONTH), timeZone)
+ val cal = Calendar.getInstance()
+ cal.setTimeInMillis(millis)
+ val quarter = getQuarter(dDays)
+ val month = quarter match {
+ case 1 => Calendar.JANUARY
+ case 2 => Calendar.APRIL
+ case 3 => Calendar.JULY
+ case 4 => Calendar.OCTOBER
+ }
+ cal.set(Calendar.MONTH, month)
+ cal.getTimeInMillis()
+ case _ =>
+ // Callers must make sure that this branch is never reached
+ sys.error(s"Invalid trunc level: $level")
+ }
+ truncated * MICROS_PER_MILLIS
+ }
+
+ def truncTimestamp(d: SQLTimestamp, level: Int): SQLTimestamp = {
+ truncTimestamp(d, level, defaultTimeZone())
+ }
+
+ /**
+ * Returns the truncate level, which can be TRUNC_TO_YEAR, TRUNC_TO_MONTH, TRUNC_TO_DAY,
+ * TRUNC_TO_HOUR, TRUNC_TO_MINUTE, TRUNC_TO_SECOND, TRUNC_TO_WEEK, TRUNC_TO_QUARTER or TRUNC_INVALID,
* TRUNC_INVALID means unsupported truncate level.
*/
def parseTruncLevel(format: UTF8String): Int = {
@@ -973,6 +1046,12 @@ object DateTimeUtils {
format.toString.toUpperCase(Locale.ROOT) match {
case "YEAR" | "YYYY" | "YY" => TRUNC_TO_YEAR
case "MON" | "MONTH" | "MM" => TRUNC_TO_MONTH
+ case "DAY" | "DD" => TRUNC_TO_DAY
+ case "HOUR" => TRUNC_TO_HOUR
+ case "MINUTE" => TRUNC_TO_MINUTE
+ case "SECOND" => TRUNC_TO_SECOND
+ case "WEEK" => TRUNC_TO_WEEK
+ case "QUARTER" => TRUNC_TO_QUARTER
case _ => TRUNC_INVALID
}
}
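A usage sketch tying `parseTruncLevel` and `truncTimestamp` together (editorial; timestamps are microseconds since the epoch, and `stringToTimestamp` is assumed to be the existing DateTimeUtils parser):

    import java.util.TimeZone
    import org.apache.spark.sql.catalyst.util.DateTimeUtils
    import org.apache.spark.unsafe.types.UTF8String

    val utc = TimeZone.getTimeZone("UTC")
    // 2015-07-22 05:30:06 UTC expressed in microseconds.
    val micros = DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-07-22 05:30:06"), utc).get

    val level = DateTimeUtils.parseTruncLevel(UTF8String.fromString("HOUR"))  // TRUNC_TO_HOUR
    val truncated = DateTimeUtils.truncTimestamp(micros, level, utc)
    // truncated corresponds to 2015-07-22 05:00:00 UTC.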
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala
index eb7941cf9e6af..b013add9c9778 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala
@@ -105,7 +105,7 @@ class QuantileSummaries(
if (newSamples.isEmpty || (sampleIdx == sampled.length && opsIdx == sorted.length - 1)) {
0
} else {
- math.floor(2 * relativeError * currentCount).toInt
+ math.floor(2 * relativeError * currentCount).toLong
}
val tuple = Stats(currentSample, 1, delta)
@@ -192,10 +192,10 @@ class QuantileSummaries(
}
// Target rank
- val rank = math.ceil(quantile * count).toInt
+ val rank = math.ceil(quantile * count).toLong
val targetError = relativeError * count
// Minimum rank at current sample
- var minRank = 0
+ var minRank = 0L
var i = 0
while (i < sampled.length - 1) {
val curSample = sampled(i)
@@ -235,7 +235,7 @@ object QuantileSummaries {
* @param g the minimum rank jump from the previous value's minimum rank
* @param delta the maximum span of the rank.
*/
- case class Stats(value: Double, g: Int, delta: Int)
+ case class Stats(value: Double, g: Long, delta: Long)
private def compressImmut(
currentSamples: IndexedSeq[Stats],
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 551cad70b3919..39387349755a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -23,14 +23,17 @@ import java.util.concurrent.atomic.AtomicReference
import scala.collection.JavaConverters._
import scala.collection.immutable
+import scala.util.matching.Regex
import org.apache.hadoop.fs.Path
+import org.apache.spark.{SparkContext, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator
+import org.apache.spark.util.Utils
////////////////////////////////////////////////////////////////////////////////////////////////////
// This file defines the configuration options for Spark SQL.
@@ -69,7 +72,7 @@ object SQLConf {
* Default config. Only used when there is no active SparkSession for the thread.
* See [[get]] for more information.
*/
- private val fallbackConf = new ThreadLocal[SQLConf] {
+ private lazy val fallbackConf = new ThreadLocal[SQLConf] {
override def initialValue: SQLConf = new SQLConf
}
@@ -246,7 +249,7 @@ object SQLConf {
val CONSTRAINT_PROPAGATION_ENABLED = buildConf("spark.sql.constraintPropagation.enabled")
.internal()
.doc("When true, the query optimizer will infer and propagate data constraints in the query " +
- "plan to optimize them. Constraint propagation can sometimes be computationally expensive" +
+ "plan to optimize them. Constraint propagation can sometimes be computationally expensive " +
"for certain kinds of query plans (such as those with a large number of predicates and " +
"aliases) which might negatively impact overall runtime.")
.booleanConf
@@ -260,6 +263,15 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
+ val FILE_COMRESSION_FACTOR = buildConf("spark.sql.sources.fileCompressionFactor")
+ .internal()
+ .doc("When estimating the output data size of a table scan, multiply the file size with this " +
+ "factor as the estimated data size, in case the data is compressed in the file and lead to" +
+ " a heavily underestimated result.")
+ .doubleConf
+ .checkValue(_ > 0, "the value of fileCompressionFactor must be larger than 0")
+ .createWithDefault(1.0)
+
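The factor is a plain multiplier on the on-disk size; a rough sketch of the estimate (the helper name here is illustrative, not Spark API):

    // Hypothetical helper mirroring how the factor is meant to be applied.
    def estimatedScanSizeInBytes(fileSizeInBytes: Long, fileCompressionFactor: Double): Double =
      fileSizeInBytes * fileCompressionFactor

    // A 100 MB gzip file treated as ~300 MB of data for planning purposes.
    estimatedScanSizeInBytes(100L * 1024 * 1024, fileCompressionFactor = 3.0)  // ~3.1e8 bytes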
val PARQUET_SCHEMA_MERGING_ENABLED = buildConf("spark.sql.parquet.mergeSchema")
.doc("When true, the Parquet data source merges schemas collected from all data files, " +
"otherwise the schema is picked from the summary file or a random data file " +
@@ -291,6 +303,13 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val PARQUET_INT96_TIMESTAMP_CONVERSION = buildConf("spark.sql.parquet.int96TimestampConversion")
+ .doc("This controls whether timestamp adjustments should be applied to INT96 data when " +
+ "converting to timestamps, for data written by Impala. This is necessary because Impala " +
+ "stores INT96 data with a different timezone offset than Hive & Spark.")
+ .booleanConf
+ .createWithDefault(false)
+
object ParquetOutputTimestampType extends Enumeration {
val INT96, TIMESTAMP_MICROS, TIMESTAMP_MILLIS = Value
}
@@ -321,11 +340,14 @@ object SQLConf {
.createWithDefault(false)
val PARQUET_COMPRESSION = buildConf("spark.sql.parquet.compression.codec")
- .doc("Sets the compression codec use when writing Parquet files. Acceptable values include: " +
- "uncompressed, snappy, gzip, lzo.")
+ .doc("Sets the compression codec used when writing Parquet files. If either `compression` or " +
+ "`parquet.compression` is specified in the table-specific options/properties, the " +
+ "precedence would be `compression`, `parquet.compression`, " +
+ "`spark.sql.parquet.compression.codec`. Acceptable values include: none, uncompressed, " +
+ "snappy, gzip, lzo.")
.stringConf
.transform(_.toLowerCase(Locale.ROOT))
- .checkValues(Set("uncompressed", "snappy", "gzip", "lzo"))
+ .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo"))
.createWithDefault("snappy")
val PARQUET_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.filterPushdown")
@@ -334,8 +356,8 @@ object SQLConf {
.createWithDefault(true)
val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat")
- .doc("Whether to follow Parquet's format specification when converting Parquet schema to " +
- "Spark SQL schema and vice versa.")
+ .doc("Whether to be compatible with the legacy Parquet format adopted by Spark 1.4 and prior " +
+ "versions, when converting Parquet schema to Spark SQL schema and vice versa.")
.booleanConf
.createWithDefault(false)
@@ -367,13 +389,35 @@ object SQLConf {
.createWithDefault(true)
val ORC_COMPRESSION = buildConf("spark.sql.orc.compression.codec")
- .doc("Sets the compression codec use when writing ORC files. Acceptable values include: " +
- "none, uncompressed, snappy, zlib, lzo.")
+ .doc("Sets the compression codec used when writing ORC files. If either `compression` or " +
+ "`orc.compress` is specified in the table-specific options/properties, the precedence " +
+ "would be `compression`, `orc.compress`, `spark.sql.orc.compression.codec`." +
+ "Acceptable values include: none, uncompressed, snappy, zlib, lzo.")
.stringConf
.transform(_.toLowerCase(Locale.ROOT))
.checkValues(Set("none", "uncompressed", "snappy", "zlib", "lzo"))
.createWithDefault("snappy")
+ val ORC_IMPLEMENTATION = buildConf("spark.sql.orc.impl")
+ .doc("When native, use the native version of ORC support instead of the ORC library in Hive " +
+ "1.2.1. It is 'hive' by default prior to Spark 2.3.")
+ .internal()
+ .stringConf
+ .checkValues(Set("hive", "native"))
+ .createWithDefault("native")
+
+ val ORC_VECTORIZED_READER_ENABLED = buildConf("spark.sql.orc.enableVectorizedReader")
+ .doc("Enables vectorized orc decoding.")
+ .booleanConf
+ .createWithDefault(true)
+
+ val ORC_COPY_BATCH_TO_SPARK = buildConf("spark.sql.orc.copyBatchToSpark")
+ .doc("Whether or not to copy the ORC columnar batch to Spark columnar batch in the " +
+ "vectorized ORC reader.")
+ .internal()
+ .booleanConf
+ .createWithDefault(false)
+
val ORC_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.filterPushdown")
.doc("When true, enable filter pushdown for ORC files.")
.booleanConf
@@ -1031,6 +1075,60 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val SQL_STRING_REDACTION_PATTERN =
+ ConfigBuilder("spark.sql.redaction.string.regex")
+ .doc("Regex to decide which parts of strings produced by Spark contain sensitive " +
+ "information. When this regex matches a string part, that string part is replaced by a " +
+ "dummy value. This is currently used to redact the output of SQL explain commands. " +
+ "When this conf is not set, the value from `spark.redaction.string.regex` is used.")
+ .fallbackConf(org.apache.spark.internal.config.STRING_REDACTION_PATTERN)
+
+ val CONCAT_BINARY_AS_STRING = buildConf("spark.sql.function.concatBinaryAsString")
+ .doc("When this option is set to false and all inputs are binary, `functions.concat` returns " +
+ "an output as binary. Otherwise, it returns as a string. ")
+ .booleanConf
+ .createWithDefault(false)
+
+ val ELT_OUTPUT_AS_STRING = buildConf("spark.sql.function.eltOutputAsString")
+ .doc("When this option is set to false and all inputs are binary, `elt` returns " +
+ "an output as binary. Otherwise, it returns as a string. ")
+ .booleanConf
+ .createWithDefault(false)
+
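A sketch of the behaviour the concat flag controls (assumes a SparkSession `spark` and a DataFrame `df` with two BinaryType columns `a` and `b`; editorial, not from the patch):

    import org.apache.spark.sql.functions.concat

    // With the default (false), concatenating two binary columns keeps the result binary.
    val asBinary = df.select(concat(df("a"), df("b")).as("ab"))  // ab: BinaryType

    // Flipping the flag restores the pre-2.3 behaviour of coercing binary inputs to string.
    spark.conf.set("spark.sql.function.concatBinaryAsString", "true")
    val asString = df.select(concat(df("a"), df("b")).as("ab"))  // ab: StringType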
+ val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE =
+ buildConf("spark.sql.streaming.continuous.executorQueueSize")
+ .internal()
+ .doc("The size (measured in number of rows) of the queue used in continuous execution to" +
+ " buffer the results of a ContinuousDataReader.")
+ .intConf
+ .createWithDefault(1024)
+
+ val CONTINUOUS_STREAMING_EXECUTOR_POLL_INTERVAL_MS =
+ buildConf("spark.sql.streaming.continuous.executorPollIntervalMs")
+ .internal()
+ .doc("The interval at which continuous execution readers will poll to check whether" +
+ " the epoch has advanced on the driver.")
+ .timeConf(TimeUnit.MILLISECONDS)
+ .createWithDefault(100)
+
+ object PartitionOverwriteMode extends Enumeration {
+ val STATIC, DYNAMIC = Value
+ }
+
+ val PARTITION_OVERWRITE_MODE =
+ buildConf("spark.sql.sources.partitionOverwriteMode")
+ .doc("When INSERT OVERWRITE a partitioned data source table, we currently support 2 modes: " +
+ "static and dynamic. In static mode, Spark deletes all the partitions that match the " +
+ "partition specification(e.g. PARTITION(a=1,b)) in the INSERT statement, before " +
+ "overwriting. In dynamic mode, Spark doesn't delete partitions ahead, and only overwrite " +
+ "those partitions that have data written into it at runtime. By default we use static " +
+ "mode to keep the same behavior of Spark prior to 2.3. Note that this config doesn't " +
+ "affect Hive serde tables, as they are always overwritten with dynamic mode.")
+ .stringConf
+ .transform(_.toUpperCase(Locale.ROOT))
+ .checkValues(PartitionOverwriteMode.values.map(_.toString))
+ .createWithDefault(PartitionOverwriteMode.STATIC.toString)
+
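A minimal usage sketch of the dynamic mode (assumes a SparkSession `spark`, a DataFrame `df`, and an existing partitioned table `logs`; not taken from the patch):

    // Only the partitions that df actually writes are overwritten;
    // every other partition of `logs` is left untouched.
    spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
    df.write.mode("overwrite").insertInto("logs")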
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
}
@@ -1052,6 +1150,12 @@ object SQLConf {
class SQLConf extends Serializable with Logging {
import SQLConf._
+ if (Utils.isTesting && SparkEnv.get != null) {
+ // assert that we're only accessing it on the driver.
+ assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER,
+ "SQLConf should only be created and accessed on the driver.")
+ }
+
/** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */
@transient protected[spark] val settings = java.util.Collections.synchronizedMap(
new java.util.HashMap[String, String]())
@@ -1111,6 +1215,8 @@ class SQLConf extends Serializable with Logging {
def orcCompressionCodec: String = getConf(ORC_COMPRESSION)
+ def orcVectorizedReaderEnabled: Boolean = getConf(ORC_VECTORIZED_READER_ENABLED)
+
def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION)
def parquetVectorizedReaderEnabled: Boolean = getConf(PARQUET_VECTORIZED_READER_ENABLED)
@@ -1171,6 +1277,10 @@ class SQLConf extends Serializable with Logging {
def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS)
+ def fileCompressionFactor: Double = getConf(FILE_COMRESSION_FACTOR)
+
+ def stringRedationPattern: Option[Regex] = SQL_STRING_REDACTION_PATTERN.readFrom(reader)
+
/**
* Returns the [[Resolver]] for the current configuration, which can be used to determine if two
* identifiers are equal.
@@ -1211,6 +1321,8 @@ class SQLConf extends Serializable with Logging {
def isParquetINT96AsTimestamp: Boolean = getConf(PARQUET_INT96_AS_TIMESTAMP)
+ def isParquetINT96TimestampConversion: Boolean = getConf(PARQUET_INT96_TIMESTAMP_CONVERSION)
+
def isParquetINT64AsTimestampMillis: Boolean = getConf(PARQUET_INT64_AS_TIMESTAMP_MILLIS)
def parquetOutputTimestampType: ParquetOutputTimestampType.Value = {
@@ -1352,6 +1464,18 @@ class SQLConf extends Serializable with Logging {
def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER)
+ def continuousStreamingExecutorQueueSize: Int = getConf(CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE)
+
+ def continuousStreamingExecutorPollIntervalMs: Long =
+ getConf(CONTINUOUS_STREAMING_EXECUTOR_POLL_INTERVAL_MS)
+
+ def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING)
+
+ def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING)
+
+ def partitionOverwriteMode: PartitionOverwriteMode.Value =
+ PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE))
+
/** ********************** SQLConf functionality methods ************ */
/** Set Spark SQL configuration properties. */
@@ -1385,7 +1509,7 @@ class SQLConf extends Serializable with Logging {
Option(settings.get(key)).
orElse {
// Try to use the default value
- Option(sqlConfEntries.get(key)).map(_.defaultValueString)
+ Option(sqlConfEntries.get(key)).map { e => e.stringConverter(e.readFrom(reader)) }
}.
getOrElse(throw new NoSuchElementException(key))
}
@@ -1423,14 +1547,21 @@ class SQLConf extends Serializable with Logging {
* not set yet, return `defaultValue`.
*/
def getConfString(key: String, defaultValue: String): String = {
- if (defaultValue != null && defaultValue != "") {
+ if (defaultValue != null && defaultValue != ConfigEntry.UNDEFINED) {
val entry = sqlConfEntries.get(key)
if (entry != null) {
// Only verify configs in the SQLConf object
entry.valueConverter(defaultValue)
}
}
- Option(settings.get(key)).getOrElse(defaultValue)
+ Option(settings.get(key)).getOrElse {
+ // If the key is not set, we need to check whether the config entry is registered and is
+ // a fallback conf, so that we can check its parent.
+ sqlConfEntries.get(key) match {
+ case e: FallbackConfigEntry[_] => getConfString(e.fallback.key, defaultValue)
+ case _ => defaultValue
+ }
+ }
}
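The lookup order the new branch implements, modelled as a standalone sketch (simplified; the real code resolves fallbacks through the registered FallbackConfigEntry instances):

    // Resolution order for a fallback conf such as
    // spark.sql.redaction.string.regex -> spark.redaction.string.regex.
    def resolve(
        settings: Map[String, String],
        fallbacks: Map[String, String],  // child key -> parent key
        key: String,
        default: String): String =
      settings.get(key) match {
        case Some(explicit) => explicit                // 1. explicitly set value wins
        case None => fallbacks.get(key) match {
          case Some(parent) => resolve(settings, fallbacks, parent, default)  // 2. parent conf
          case None => default                         // 3. registered default
        }
      }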
/**
@@ -1446,7 +1577,8 @@ class SQLConf extends Serializable with Logging {
*/
def getAllDefinedConfs: Seq[(String, String, String)] = sqlConfEntries.synchronized {
sqlConfEntries.values.asScala.filter(_.isPublic).map { entry =>
- (entry.key, getConfString(entry.key, entry.defaultValueString), entry.doc)
+ val displayValue = Option(getConfString(entry.key, null)).getOrElse(entry.defaultValueString)
+ (entry.key, displayValue, entry.doc)
}.toSeq
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
index af2bb44332dd4..a7f594837d3cf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
@@ -99,4 +99,11 @@ object StaticSQLConf {
.stringConf
.toSequence
.createOptional
+
+ val UI_RETAINED_EXECUTIONS =
+ buildStaticConf("spark.sql.ui.retainedExecutions")
+ .doc("Number of executions to retain in the Spark UI.")
+ .intConf
+ .createWithDefault(1000)
+
}
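Since this is a static conf, it has to be fixed before the session starts; a sketch (local mode, editorial only):

    import org.apache.spark.sql.SparkSession

    // spark.conf.set() on a live session cannot change static SQL confs.
    val spark = SparkSession.builder()
      .master("local[*]")
      .config("spark.sql.ui.retainedExecutions", "500")
      .getOrCreate()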
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala
deleted file mode 100644
index 5b802ccc637dd..0000000000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala
+++ /dev/null
@@ -1,55 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, Literal}
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning}
-
-class PartitioningSuite extends SparkFunSuite {
- test("HashPartitioning compatibility should be sensitive to expression ordering (SPARK-9785)") {
- val expressions = Seq(Literal(2), Literal(3))
- // Consider two HashPartitionings that have the same _set_ of hash expressions but which are
- // created with different orderings of those expressions:
- val partitioningA = HashPartitioning(expressions, 100)
- val partitioningB = HashPartitioning(expressions.reverse, 100)
- // These partitionings are not considered equal:
- assert(partitioningA != partitioningB)
- // However, they both satisfy the same clustered distribution:
- val distribution = ClusteredDistribution(expressions)
- assert(partitioningA.satisfies(distribution))
- assert(partitioningB.satisfies(distribution))
- // These partitionings compute different hashcodes for the same input row:
- def computeHashCode(partitioning: HashPartitioning): Int = {
- val hashExprProj = new InterpretedMutableProjection(partitioning.expressions, Seq.empty)
- hashExprProj.apply(InternalRow.empty).hashCode()
- }
- assert(computeHashCode(partitioningA) != computeHashCode(partitioningB))
- // Thus, these partitionings are incompatible:
- assert(!partitioningA.compatibleWith(partitioningB))
- assert(!partitioningB.compatibleWith(partitioningA))
- assert(!partitioningA.guarantees(partitioningB))
- assert(!partitioningB.guarantees(partitioningA))
-
- // Just to be sure that we haven't cheated by having these methods always return false,
- // check that identical partitionings are still compatible with and guarantee each other:
- assert(partitioningA === partitioningA)
- assert(partitioningA.guarantees(partitioningA))
- assert(partitioningA.compatibleWith(partitioningA))
- }
-}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 23e866cdf4917..8c3db48a01f12 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -356,4 +356,13 @@ class ScalaReflectionSuite extends SparkFunSuite {
assert(deserializerFor[Int].isInstanceOf[AssertNotNull])
assert(!deserializerFor[String].isInstanceOf[AssertNotNull])
}
+
+ test("SPARK-23025: schemaFor should support Null type") {
+ val schema = schemaFor[(Int, Null)]
+ assert(schema === Schema(
+ StructType(Seq(
+ StructField("_1", IntegerType, nullable = false),
+ StructField("_2", NullType, nullable = true))),
+ nullable = true))
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index e56a5d6368318..f4514205d3ae0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -27,6 +27,8 @@ import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{Cross, Inner}
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning,
+ RangePartitioning, RoundRobinPartitioning}
import org.apache.spark.sql.types._
@@ -514,4 +516,45 @@ class AnalysisSuite extends AnalysisTest with Matchers {
Seq("Number of column aliases does not match number of columns. " +
"Number of column aliases: 5; number of columns: 4."))
}
+
+ test("SPARK-22614 RepartitionByExpression partitioning") {
+ def checkPartitioning[T <: Partitioning](numPartitions: Int, exprs: Expression*): Unit = {
+ val partitioning = RepartitionByExpression(exprs, testRelation2, numPartitions).partitioning
+ assert(partitioning.isInstanceOf[T])
+ }
+
+ checkPartitioning[HashPartitioning](numPartitions = 10, exprs = Literal(20))
+ checkPartitioning[HashPartitioning](numPartitions = 10, exprs = 'a.attr, 'b.attr)
+
+ checkPartitioning[RangePartitioning](numPartitions = 10,
+ exprs = SortOrder(Literal(10), Ascending))
+ checkPartitioning[RangePartitioning](numPartitions = 10,
+ exprs = SortOrder('a.attr, Ascending), SortOrder('b.attr, Descending))
+
+ checkPartitioning[RoundRobinPartitioning](numPartitions = 10, exprs = Seq.empty: _*)
+
+ intercept[IllegalArgumentException] {
+ checkPartitioning(numPartitions = 0, exprs = Literal(20))
+ }
+ intercept[IllegalArgumentException] {
+ checkPartitioning(numPartitions = -1, exprs = Literal(20))
+ }
+ intercept[IllegalArgumentException] {
+ checkPartitioning(numPartitions = 10, exprs = SortOrder('a.attr, Ascending), 'b.attr)
+ }
+ }
+
+ test("SPARK-20392: analysis barrier") {
+ // [[AnalysisBarrier]] will be removed after analysis
+ checkAnalysis(
+ Project(Seq(UnresolvedAttribute("tbl.a")),
+ AnalysisBarrier(SubqueryAlias("tbl", testRelation))),
+ Project(testRelation.output, SubqueryAlias("tbl", testRelation)))
+
+ // Verify we won't go through a plan wrapped in a barrier.
+ // Since we wrap an unresolved plan and the analyzer won't go through it, it remains unresolved.
+ val barrier = AnalysisBarrier(Project(Seq(UnresolvedAttribute("tbl.b")),
+ SubqueryAlias("tbl", testRelation)))
+ assertAnalysisError(barrier, Seq("cannot resolve '`tbl.b`'"))
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 5dcd653e9b341..52a7ebdafd7c7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -869,6 +869,114 @@ class TypeCoercionSuite extends AnalysisTest {
Literal.create(null, IntegerType), Literal.create(null, StringType))))
}
+ test("type coercion for Concat") {
+ val rule = TypeCoercion.ConcatCoercion(conf)
+
+ ruleTest(rule,
+ Concat(Seq(Literal("ab"), Literal("cde"))),
+ Concat(Seq(Literal("ab"), Literal("cde"))))
+ ruleTest(rule,
+ Concat(Seq(Literal(null), Literal("abc"))),
+ Concat(Seq(Cast(Literal(null), StringType), Literal("abc"))))
+ ruleTest(rule,
+ Concat(Seq(Literal(1), Literal("234"))),
+ Concat(Seq(Cast(Literal(1), StringType), Literal("234"))))
+ ruleTest(rule,
+ Concat(Seq(Literal("1"), Literal("234".getBytes()))),
+ Concat(Seq(Literal("1"), Cast(Literal("234".getBytes()), StringType))))
+ ruleTest(rule,
+ Concat(Seq(Literal(1L), Literal(2.toByte), Literal(0.1))),
+ Concat(Seq(Cast(Literal(1L), StringType), Cast(Literal(2.toByte), StringType),
+ Cast(Literal(0.1), StringType))))
+ ruleTest(rule,
+ Concat(Seq(Literal(true), Literal(0.1f), Literal(3.toShort))),
+ Concat(Seq(Cast(Literal(true), StringType), Cast(Literal(0.1f), StringType),
+ Cast(Literal(3.toShort), StringType))))
+ ruleTest(rule,
+ Concat(Seq(Literal(1L), Literal(0.1))),
+ Concat(Seq(Cast(Literal(1L), StringType), Cast(Literal(0.1), StringType))))
+ ruleTest(rule,
+ Concat(Seq(Literal(Decimal(10)))),
+ Concat(Seq(Cast(Literal(Decimal(10)), StringType))))
+ ruleTest(rule,
+ Concat(Seq(Literal(BigDecimal.valueOf(10)))),
+ Concat(Seq(Cast(Literal(BigDecimal.valueOf(10)), StringType))))
+ ruleTest(rule,
+ Concat(Seq(Literal(java.math.BigDecimal.valueOf(10)))),
+ Concat(Seq(Cast(Literal(java.math.BigDecimal.valueOf(10)), StringType))))
+ ruleTest(rule,
+ Concat(Seq(Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))),
+ Concat(Seq(Cast(Literal(new java.sql.Date(0)), StringType),
+ Cast(Literal(new Timestamp(0)), StringType))))
+
+ withSQLConf("spark.sql.function.concatBinaryAsString" -> "true") {
+ ruleTest(rule,
+ Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))),
+ Concat(Seq(Cast(Literal("123".getBytes), StringType),
+ Cast(Literal("456".getBytes), StringType))))
+ }
+
+ withSQLConf("spark.sql.function.concatBinaryAsString" -> "false") {
+ ruleTest(rule,
+ Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))),
+ Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))))
+ }
+ }
+
+ test("type coercion for Elt") {
+ val rule = TypeCoercion.EltCoercion(conf)
+
+ ruleTest(rule,
+ Elt(Seq(Literal(1), Literal("ab"), Literal("cde"))),
+ Elt(Seq(Literal(1), Literal("ab"), Literal("cde"))))
+ ruleTest(rule,
+ Elt(Seq(Literal(1.toShort), Literal("ab"), Literal("cde"))),
+ Elt(Seq(Cast(Literal(1.toShort), IntegerType), Literal("ab"), Literal("cde"))))
+ ruleTest(rule,
+ Elt(Seq(Literal(2), Literal(null), Literal("abc"))),
+ Elt(Seq(Literal(2), Cast(Literal(null), StringType), Literal("abc"))))
+ ruleTest(rule,
+ Elt(Seq(Literal(2), Literal(1), Literal("234"))),
+ Elt(Seq(Literal(2), Cast(Literal(1), StringType), Literal("234"))))
+ ruleTest(rule,
+ Elt(Seq(Literal(3), Literal(1L), Literal(2.toByte), Literal(0.1))),
+ Elt(Seq(Literal(3), Cast(Literal(1L), StringType), Cast(Literal(2.toByte), StringType),
+ Cast(Literal(0.1), StringType))))
+ ruleTest(rule,
+ Elt(Seq(Literal(2), Literal(true), Literal(0.1f), Literal(3.toShort))),
+ Elt(Seq(Literal(2), Cast(Literal(true), StringType), Cast(Literal(0.1f), StringType),
+ Cast(Literal(3.toShort), StringType))))
+ ruleTest(rule,
+ Elt(Seq(Literal(1), Literal(1L), Literal(0.1))),
+ Elt(Seq(Literal(1), Cast(Literal(1L), StringType), Cast(Literal(0.1), StringType))))
+ ruleTest(rule,
+ Elt(Seq(Literal(1), Literal(Decimal(10)))),
+ Elt(Seq(Literal(1), Cast(Literal(Decimal(10)), StringType))))
+ ruleTest(rule,
+ Elt(Seq(Literal(1), Literal(BigDecimal.valueOf(10)))),
+ Elt(Seq(Literal(1), Cast(Literal(BigDecimal.valueOf(10)), StringType))))
+ ruleTest(rule,
+ Elt(Seq(Literal(1), Literal(java.math.BigDecimal.valueOf(10)))),
+ Elt(Seq(Literal(1), Cast(Literal(java.math.BigDecimal.valueOf(10)), StringType))))
+ ruleTest(rule,
+ Elt(Seq(Literal(2), Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))),
+ Elt(Seq(Literal(2), Cast(Literal(new java.sql.Date(0)), StringType),
+ Cast(Literal(new Timestamp(0)), StringType))))
+
+ withSQLConf("spark.sql.function.eltOutputAsString" -> "true") {
+ ruleTest(rule,
+ Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))),
+ Elt(Seq(Literal(1), Cast(Literal("123".getBytes), StringType),
+ Cast(Literal("456".getBytes), StringType))))
+ }
+
+ withSQLConf("spark.sql.function.eltOutputAsString" -> "false") {
+ ruleTest(rule,
+ Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))),
+ Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))))
+ }
+ }
+
test("BooleanEquality type cast") {
val be = TypeCoercion.BooleanEquality
// Use something more than a literal to avoid triggering the simplification rules.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
index 95c87ffa20cb7..6abab0073cca3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
@@ -279,7 +279,7 @@ abstract class SessionCatalogSuite extends AnalysisTest {
}
}
- test("create temp table") {
+ test("create temp view") {
withBasicCatalog { catalog =>
val tempTable1 = Range(1, 10, 1, 10)
val tempTable2 = Range(1, 20, 2, 10)
@@ -288,11 +288,11 @@ abstract class SessionCatalogSuite extends AnalysisTest {
assert(catalog.getTempView("tbl1") == Option(tempTable1))
assert(catalog.getTempView("tbl2") == Option(tempTable2))
assert(catalog.getTempView("tbl3").isEmpty)
- // Temporary table already exists
+ // Temporary view already exists
intercept[TempTableAlreadyExistsException] {
catalog.createTempView("tbl1", tempTable1, overrideIfExists = false)
}
- // Temporary table already exists but we override it
+ // Temporary view already exists but we override it
catalog.createTempView("tbl1", tempTable2, overrideIfExists = true)
assert(catalog.getTempView("tbl1") == Option(tempTable2))
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index fb759eba6a9e2..6edb4348f8309 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types._
class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -343,4 +344,14 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Least(inputsExpr), "s" * 1, EmptyRow)
checkEvaluation(Greatest(inputsExpr), "s" * N, EmptyRow)
}
+
+ test("SPARK-22704: Least and greatest use less global variables") {
+ val ctx1 = new CodegenContext()
+ Least(Seq(Literal(1), Literal(1))).genCode(ctx1)
+ assert(ctx1.inlinedMutableStates.size == 1)
+
+ val ctx2 = new CodegenContext()
+ Greatest(Seq(Literal(1), Literal(1))).genCode(ctx2)
+ assert(ctx2.inlinedMutableStates.size == 1)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 7837d6529d12b..5b25bdf907c3a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -23,6 +23,7 @@ import java.util.{Calendar, Locale, TimeZone}
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.DateTimeUtils.TimeZoneGMT
@@ -845,4 +846,80 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
val outputOuter = Row.fromSeq((1 to N).map(_ => outputInner))
checkEvaluation(cast(Literal.create(inputOuter, fromOuter), toOuter), outputOuter)
}
+
+ test("SPARK-22570: Cast should not create a lot of global variables") {
+ val ctx = new CodegenContext
+ cast("1", IntegerType).genCode(ctx)
+ cast("2", LongType).genCode(ctx)
+ assert(ctx.inlinedMutableStates.length == 0)
+ }
+
+ test("SPARK-22825 Cast array to string") {
+ val ret1 = cast(Literal.create(Array(1, 2, 3, 4, 5)), StringType)
+ checkEvaluation(ret1, "[1, 2, 3, 4, 5]")
+ val ret2 = cast(Literal.create(Array("ab", "cde", "f")), StringType)
+ checkEvaluation(ret2, "[ab, cde, f]")
+ val ret3 = cast(Literal.create(Array("ab", null, "c")), StringType)
+ checkEvaluation(ret3, "[ab,, c]")
+ val ret4 = cast(Literal.create(Array("ab".getBytes, "cde".getBytes, "f".getBytes)), StringType)
+ checkEvaluation(ret4, "[ab, cde, f]")
+ val ret5 = cast(
+ Literal.create(Array("2014-12-03", "2014-12-04", "2014-12-06").map(Date.valueOf)),
+ StringType)
+ checkEvaluation(ret5, "[2014-12-03, 2014-12-04, 2014-12-06]")
+ val ret6 = cast(
+ Literal.create(Array("2014-12-03 13:01:00", "2014-12-04 15:05:00").map(Timestamp.valueOf)),
+ StringType)
+ checkEvaluation(ret6, "[2014-12-03 13:01:00, 2014-12-04 15:05:00]")
+ val ret7 = cast(Literal.create(Array(Array(1, 2, 3), Array(4, 5))), StringType)
+ checkEvaluation(ret7, "[[1, 2, 3], [4, 5]]")
+ val ret8 = cast(
+ Literal.create(Array(Array(Array("a"), Array("b", "c")), Array(Array("d")))),
+ StringType)
+ checkEvaluation(ret8, "[[[a], [b, c]], [[d]]]")
+ }
+
+ test("SPARK-22973 Cast map to string") {
+ val ret1 = cast(Literal.create(Map(1 -> "a", 2 -> "b", 3 -> "c")), StringType)
+ checkEvaluation(ret1, "[1 -> a, 2 -> b, 3 -> c]")
+ val ret2 = cast(
+ Literal.create(Map("1" -> "a".getBytes, "2" -> null, "3" -> "c".getBytes)),
+ StringType)
+ checkEvaluation(ret2, "[1 -> a, 2 ->, 3 -> c]")
+ val ret3 = cast(
+ Literal.create(Map(
+ 1 -> Date.valueOf("2014-12-03"),
+ 2 -> Date.valueOf("2014-12-04"),
+ 3 -> Date.valueOf("2014-12-05"))),
+ StringType)
+ checkEvaluation(ret3, "[1 -> 2014-12-03, 2 -> 2014-12-04, 3 -> 2014-12-05]")
+ val ret4 = cast(
+ Literal.create(Map(
+ 1 -> Timestamp.valueOf("2014-12-03 13:01:00"),
+ 2 -> Timestamp.valueOf("2014-12-04 15:05:00"))),
+ StringType)
+ checkEvaluation(ret4, "[1 -> 2014-12-03 13:01:00, 2 -> 2014-12-04 15:05:00]")
+ val ret5 = cast(
+ Literal.create(Map(
+ 1 -> Array(1, 2, 3),
+ 2 -> Array(4, 5, 6))),
+ StringType)
+ checkEvaluation(ret5, "[1 -> [1, 2, 3], 2 -> [4, 5, 6]]")
+ }
+
+ test("SPARK-22981 Cast struct to string") {
+ val ret1 = cast(Literal.create((1, "a", 0.1)), StringType)
+ checkEvaluation(ret1, "[1, a, 0.1]")
+ val ret2 = cast(Literal.create(Tuple3[Int, String, String](1, null, "a")), StringType)
+ checkEvaluation(ret2, "[1,, a]")
+ val ret3 = cast(Literal.create(
+ (Date.valueOf("2014-12-03"), Timestamp.valueOf("2014-12-03 15:05:00"))), StringType)
+ checkEvaluation(ret3, "[2014-12-03, 2014-12-03 15:05:00]")
+ val ret4 = cast(Literal.create(((1, "a"), 5, 0.1)), StringType)
+ checkEvaluation(ret4, "[[1, a], 5, 0.1]")
+ val ret5 = cast(Literal.create((Seq(1, 2, 3), "a", 0.1)), StringType)
+ checkEvaluation(ret5, "[[1, 2, 3], a, 0.1]")
+ val ret6 = cast(Literal.create((1, Map(1 -> "a", 2 -> "b", 3 -> "c"))), StringType)
+ checkEvaluation(ret6, "[1, [1 -> a, 2 -> b, 3 -> c]]")
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index a4198f826cedb..676ba3956ddc8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, GetExternalRowField, ValidateExternalType}
+import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -380,4 +380,60 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
s"Incorrect Evaluation: expressions: $exprAnd, actual: $actualAnd, expected: $expectedAnd")
}
}
+
+ test("SPARK-22696: CreateExternalRow should not use global variables") {
+ val ctx = new CodegenContext
+ val schema = new StructType().add("a", IntegerType).add("b", StringType)
+ CreateExternalRow(Seq(Literal(1), Literal("x")), schema).genCode(ctx)
+ assert(ctx.inlinedMutableStates.isEmpty)
+ }
+
+ test("SPARK-22696: InitializeJavaBean should not use global variables") {
+ val ctx = new CodegenContext
+ InitializeJavaBean(Literal.fromObject(new java.util.LinkedList[Int]),
+ Map("add" -> Literal(1))).genCode(ctx)
+ assert(ctx.inlinedMutableStates.isEmpty)
+ }
+
+ test("SPARK-22716: addReferenceObj should not add mutable states") {
+ val ctx = new CodegenContext
+ val foo = new Object()
+ ctx.addReferenceObj("foo", foo)
+ assert(ctx.inlinedMutableStates.isEmpty)
+ }
+
+ test("SPARK-18016: define mutable states by using an array") {
+ val ctx1 = new CodegenContext
+ for (i <- 1 to CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10) {
+ ctx1.addMutableState(ctx1.JAVA_INT, "i", v => s"$v = $i;")
+ }
+ assert(ctx1.inlinedMutableStates.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD)
+ // When the number of primitive type mutable states is over the threshold, others are
+ // allocated into an array
+ assert(ctx1.arrayCompactedMutableStates.get(ctx1.JAVA_INT).get.arrayNames.size == 1)
+ assert(ctx1.mutableStateInitCode.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10)
+
+ val ctx2 = new CodegenContext
+ for (i <- 1 to CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + 10) {
+ ctx2.addMutableState("InternalRow[]", "r", v => s"$v = new InternalRow[$i];")
+ }
+ // When the number of non-primitive type mutable states is over the threshold, others are
+ // allocated into a new array
+ assert(ctx2.inlinedMutableStates.isEmpty)
+ assert(ctx2.arrayCompactedMutableStates.get("InternalRow[]").get.arrayNames.size == 2)
+ assert(ctx2.arrayCompactedMutableStates("InternalRow[]").getCurrentIndex == 10)
+ assert(ctx2.mutableStateInitCode.size == CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + 10)
+ }
+
+ test("SPARK-22750: addImmutableStateIfNotExists") {
+ val ctx = new CodegenContext
+ val mutableState1 = "field1"
+ val mutableState2 = "field2"
+ ctx.addImmutableStateIfNotExists("int", mutableState1)
+ ctx.addImmutableStateIfNotExists("int", mutableState1)
+ ctx.addImmutableStateIfNotExists("String", mutableState2)
+ ctx.addImmutableStateIfNotExists("int", mutableState1)
+ ctx.addImmutableStateIfNotExists("String", mutableState2)
+ assert(ctx.inlinedMutableStates.length == 2)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index b0eaad1c80f89..84190f0bd5f7d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -299,4 +300,10 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
new StringToMap(Literal("a=1_b=2_c=3"), Literal("_"), NonFoldableLiteral("="))
.checkInputDataTypes().isFailure)
}
+
+ test("SPARK-22693: CreateNamedStruct should not use global variables") {
+ val ctx = new CodegenContext
+ CreateNamedStruct(Seq("a", "x", "b", 2.0)).genCode(ctx)
+ assert(ctx.inlinedMutableStates.isEmpty)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
index 3e11c3d2d4fe3..a099119732e25 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types._
class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -145,4 +146,10 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
IndexedSeq((Literal(12) === Literal(1), Literal(42)),
(Literal(12) === Literal(42), Literal(1))))
}
+
+ test("SPARK-22705: case when should use less global variables") {
+ val ctx = new CodegenContext()
+ CaseWhen(Seq((Literal.create(false, BooleanType), Literal(1))), Literal(-1)).genCode(ctx)
+ assert(ctx.inlinedMutableStates.size == 1)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
index 89d99f9678cda..786266a2c13c0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
@@ -22,6 +22,7 @@ import java.text.SimpleDateFormat
import java.util.{Calendar, Locale, TimeZone}
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.DateTimeUtils.TimeZoneGMT
@@ -527,7 +528,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
NextDay(Literal(Date.valueOf("2015-07-23")), Literal.create(null, StringType)), null)
}
- test("function trunc") {
+ test("TruncDate") {
def testTrunc(input: Date, fmt: String, expected: Date): Unit = {
checkEvaluation(TruncDate(Literal.create(input, DateType), Literal.create(fmt, StringType)),
expected)
@@ -543,11 +544,82 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
testTrunc(date, fmt, Date.valueOf("2015-07-01"))
}
testTrunc(date, "DD", null)
+ testTrunc(date, "SECOND", null)
+ testTrunc(date, "HOUR", null)
testTrunc(date, null, null)
testTrunc(null, "MON", null)
testTrunc(null, null, null)
}
+ test("TruncTimestamp") {
+ def testTrunc(input: Timestamp, fmt: String, expected: Timestamp): Unit = {
+ checkEvaluation(
+ TruncTimestamp(Literal.create(fmt, StringType), Literal.create(input, TimestampType)),
+ expected)
+ checkEvaluation(
+ TruncTimestamp(
+ NonFoldableLiteral.create(fmt, StringType), Literal.create(input, TimestampType)),
+ expected)
+ }
+
+ withDefaultTimeZone(TimeZoneGMT) {
+ val inputDate = Timestamp.valueOf("2015-07-22 05:30:06")
+
+ Seq("yyyy", "YYYY", "year", "YEAR", "yy", "YY").foreach { fmt =>
+ testTrunc(
+ inputDate, fmt,
+ Timestamp.valueOf("2015-01-01 00:00:00"))
+ }
+
+ Seq("month", "MONTH", "mon", "MON", "mm", "MM").foreach { fmt =>
+ testTrunc(
+ inputDate, fmt,
+ Timestamp.valueOf("2015-07-01 00:00:00"))
+ }
+
+ Seq("DAY", "day", "DD", "dd").foreach { fmt =>
+ testTrunc(
+ inputDate, fmt,
+ Timestamp.valueOf("2015-07-22 00:00:00"))
+ }
+
+ Seq("HOUR", "hour").foreach { fmt =>
+ testTrunc(
+ inputDate, fmt,
+ Timestamp.valueOf("2015-07-22 05:00:00"))
+ }
+
+ Seq("MINUTE", "minute").foreach { fmt =>
+ testTrunc(
+ inputDate, fmt,
+ Timestamp.valueOf("2015-07-22 05:30:00"))
+ }
+
+ Seq("SECOND", "second").foreach { fmt =>
+ testTrunc(
+ inputDate, fmt,
+ Timestamp.valueOf("2015-07-22 05:30:06"))
+ }
+
+ Seq("WEEK", "week").foreach { fmt =>
+ testTrunc(
+ inputDate, fmt,
+ Timestamp.valueOf("2015-07-20 00:00:00"))
+ }
+
+ Seq("QUARTER", "quarter").foreach { fmt =>
+ testTrunc(
+ inputDate, fmt,
+ Timestamp.valueOf("2015-07-01 00:00:00"))
+ }
+
+ testTrunc(inputDate, "INVALID", null)
+ testTrunc(inputDate, null, null)
+ testTrunc(null, "MON", null)
+ testTrunc(null, null, null)
+ }
+ }
+
test("from_unixtime") {
val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US)
val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS"
@@ -720,6 +792,9 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test(null, "UTC", null)
test("2015-07-24 00:00:00", null, null)
test(null, null, null)
+ // Test escaping of timezone
+ GenerateUnsafeProjection.generate(
+ ToUTCTimestamp(Literal(Timestamp.valueOf("2015-07-24 00:00:00")), Literal("\"quote")) :: Nil)
}
test("from_utc_timestamp") {
@@ -740,5 +815,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test(null, "UTC", null)
test("2015-07-24 00:00:00", null, null)
test(null, null, null)
+ // Test escaping of timezone
+ GenerateUnsafeProjection.generate(FromUTCTimestamp(Literal(0), Literal("\"quote")) :: Nil)
}
}
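The format strings exercised by the TruncTimestamp tests above use the format-first argument order of the date_trunc SQL function, so (assuming that SQL mapping, which this diff does not itself show, and a running SparkSession named spark) a quick end-to-end check might look like the sketch below.

  // Minimal sketch, not part of this patch; assumes TruncTimestamp is exposed in SQL as
  // date_trunc(fmt, ts) and that a SparkSession `spark` is available.
  import java.sql.Timestamp
  import spark.implicits._

  val df = Seq(Timestamp.valueOf("2015-07-22 05:30:06")).toDF("ts")
  df.selectExpr(
    "date_trunc('YEAR', ts)",   // expected: 2015-01-01 00:00:00, as asserted above
    "date_trunc('WEEK', ts)",   // expected: 2015-07-20 00:00:00 (the preceding Monday)
    "date_trunc('HOUR', ts)"    // expected: 2015-07-22 05:00:00
  ).show(false)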
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
index 112a4a09728ae..4281c89ac475d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
@@ -27,6 +27,7 @@ import org.scalatest.exceptions.TestFailedException
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{RandomDataGenerator, Row}
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
@@ -620,23 +621,30 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
test("SPARK-18207: Compute hash for a lot of expressions") {
+ def checkResult(schema: StructType, input: InternalRow): Unit = {
+ val exprs = schema.fields.zipWithIndex.map { case (f, i) =>
+ BoundReference(i, f.dataType, true)
+ }
+ val murmur3HashExpr = Murmur3Hash(exprs, 42)
+ val murmur3HashPlan = GenerateMutableProjection.generate(Seq(murmur3HashExpr))
+ val murmursHashEval = Murmur3Hash(exprs, 42).eval(input)
+ assert(murmur3HashPlan(input).getInt(0) == murmursHashEval)
+
+ val hiveHashExpr = HiveHash(exprs)
+ val hiveHashPlan = GenerateMutableProjection.generate(Seq(hiveHashExpr))
+ val hiveHashEval = HiveHash(exprs).eval(input)
+ assert(hiveHashPlan(input).getInt(0) == hiveHashEval)
+ }
+
val N = 1000
val wideRow = new GenericInternalRow(
Seq.tabulate(N)(i => UTF8String.fromString(i.toString)).toArray[Any])
- val schema = StructType((1 to N).map(i => StructField("", StringType)))
-
- val exprs = schema.fields.zipWithIndex.map { case (f, i) =>
- BoundReference(i, f.dataType, true)
- }
- val murmur3HashExpr = Murmur3Hash(exprs, 42)
- val murmur3HashPlan = GenerateMutableProjection.generate(Seq(murmur3HashExpr))
- val murmursHashEval = Murmur3Hash(exprs, 42).eval(wideRow)
- assert(murmur3HashPlan(wideRow).getInt(0) == murmursHashEval)
+ val schema = StructType((1 to N).map(i => StructField(i.toString, StringType)))
+ checkResult(schema, wideRow)
- val hiveHashExpr = HiveHash(exprs)
- val hiveHashPlan = GenerateMutableProjection.generate(Seq(hiveHashExpr))
- val hiveHashEval = HiveHash(exprs).eval(wideRow)
- assert(hiveHashPlan(wideRow).getInt(0) == hiveHashEval)
+ val nestedRow = InternalRow(wideRow)
+ val nestedSchema = new StructType().add("nested", schema)
+ checkResult(nestedSchema, nestedRow)
}
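For context, Murmur3Hash(exprs, 42) as built in this refactored test is, as far as I know, the same expression (with the same fixed seed) that backs the public hash() function, so the wide-row behaviour can also be observed from the DataFrame API. A hedged sketch, assuming a SparkSession named spark:

  // Usage sketch only; `spark` and the hash()-to-Murmur3Hash(seed = 42) mapping are
  // assumptions, not something this diff establishes.
  import org.apache.spark.sql.functions.{col, hash}

  spark.range(3)
    .select(col("id"), hash(col("id")).as("murmur3"))
    .show()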
test("SPARK-22284: Compute hash for nested structs") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala
index 4fe7b436982b1..facc863081303 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala
@@ -43,5 +43,4 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Length(Uuid()), 36)
assert(evaluate(Uuid()) !== evaluate(Uuid()))
}
-
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
index 40ef7770da33f..cc6c15cb2c909 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.types._
@@ -155,6 +156,12 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Coalesce(inputs), "x_1")
}
+ test("SPARK-22705: Coalesce should use less global variables") {
+ val ctx = new CodegenContext()
+ Coalesce(Seq(Literal("a"), Literal("b"))).genCode(ctx)
+ assert(ctx.inlinedMutableStates.size == 1)
+ }
+
test("AtLeastNNonNulls should not throw 64kb exception") {
val inputs = (1 to 4000).map(x => Literal(s"x_$x"))
checkEvaluation(AtLeastNNonNulls(1, inputs), true)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index 865092a659f26..8a8f8e10225fa 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.RandomDataGenerator
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._
@@ -245,6 +246,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(In(Literal(1.0D), sets), true)
}
+ test("SPARK-22705: In should use less global variables") {
+ val ctx = new CodegenContext()
+ In(Literal(1.0D), Seq(Literal(1.0D), Literal(2.0D))).genCode(ctx)
+ assert(ctx.inlinedMutableStates.isEmpty)
+ }
+
test("INSET") {
val hS = HashSet[Any]() + 1 + 2
val nS = HashSet[Any]() + 1 + 2 + null
@@ -291,26 +298,26 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
private val udt = new ExamplePointUDT
private val smallValues =
- Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), new Date(2000, 1, 1),
+ Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), Date.valueOf("2000-01-01"),
new Timestamp(1), "a", 1f, 1d, 0f, 0d, false, Array(1L, 2L))
.map(Literal(_)) ++ Seq(Literal.create(MyStruct(1L, "b")),
Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))),
Literal.create(ArrayData.toArrayData(Array(1.0, 2.0)), udt))
private val largeValues =
- Seq(2.toByte, 2.toShort, 2, 2L, Decimal(2), Array(2.toByte), new Date(2000, 1, 2),
+ Seq(2.toByte, 2.toShort, 2, 2L, Decimal(2), Array(2.toByte), Date.valueOf("2000-01-02"),
new Timestamp(2), "b", 2f, 2d, Float.NaN, Double.NaN, true, Array(2L, 1L))
.map(Literal(_)) ++ Seq(Literal.create(MyStruct(2L, "b")),
Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 2))),
Literal.create(ArrayData.toArrayData(Array(1.0, 3.0)), udt))
private val equalValues1 =
- Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), new Date(2000, 1, 1),
+ Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), Date.valueOf("2000-01-01"),
new Timestamp(1), "a", 1f, 1d, Float.NaN, Double.NaN, true, Array(1L, 2L))
.map(Literal(_)) ++ Seq(Literal.create(MyStruct(1L, "b")),
Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))),
Literal.create(ArrayData.toArrayData(Array(1.0, 2.0)), udt))
private val equalValues2 =
- Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), new Date(2000, 1, 1),
+ Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), Date.valueOf("2000-01-01"),
new Timestamp(1), "a", 1f, 1d, Float.NaN, Double.NaN, true, Array(1L, 2L))
.map(Literal(_)) ++ Seq(Literal.create(MyStruct(1L, "b")),
Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))),
@@ -429,4 +436,10 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
val infinity = Literal(Double.PositiveInfinity)
checkEvaluation(EqualTo(infinity, infinity), true)
}
+
+ test("SPARK-22693: InSet should not use global variables") {
+ val ctx = new CodegenContext
+ InSet(Literal(1), Set(1, 2, 3, 4)).genCode(ctx)
+ assert(ctx.inlinedMutableStates.isEmpty)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
index 1ce150e091981..2a0a42c65b086 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.{IntegerType, StringType}
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
+import org.apache.spark.sql.types.StringType
/**
* Unit tests for regular expression (regexp) related SQL expressions.
@@ -178,6 +179,15 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(nonNullExpr, "num-num", row1)
}
+ test("SPARK-22570: RegExpReplace should not create a lot of global variables") {
+ val ctx = new CodegenContext
+ RegExpReplace(Literal("100"), Literal("(\\d+)"), Literal("num")).genCode(ctx)
+ // four global variables (lastRegex, pattern, lastReplacement, and lastReplacementInUTF8)
+ // are always required; they are allocated in the type-based global array
+ assert(ctx.inlinedMutableStates.length == 0)
+ assert(ctx.mutableStateInitCode.length == 4)
+ }
+
test("RegexExtract") {
val row1 = create_row("100-200", "(\\d+)-(\\d+)", 1)
val row2 = create_row("100-200", "(\\d+)-(\\d+)", 2)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
index 13bd363c8b692..10e3ffd0dff97 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.util.Locale
import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types.{IntegerType, StringType}
class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -47,4 +48,9 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(e2.getMessage.contains("Failed to execute user defined function"))
}
+ test("SPARK-22695: ScalaUDF should not use global variables") {
+ val ctx = new CodegenContext
+ ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil).genCode(ctx)
+ assert(ctx.inlinedMutableStates.isEmpty)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 54cde77176e27..97ddbeba2c5ca 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -51,6 +51,18 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Concat(strs.map(Literal.create(_, StringType))), strs.mkString, EmptyRow)
}
+ test("SPARK-22771 Check Concat.checkInputDataTypes results") {
+ assert(Concat(Seq.empty[Expression]).checkInputDataTypes().isSuccess)
+ assert(Concat(Literal.create("a") :: Literal.create("b") :: Nil)
+ .checkInputDataTypes().isSuccess)
+ assert(Concat(Literal.create("a".getBytes) :: Literal.create("b".getBytes) :: Nil)
+ .checkInputDataTypes().isSuccess)
+ assert(Concat(Literal.create(1) :: Literal.create(2) :: Nil)
+ .checkInputDataTypes().isFailure)
+ assert(Concat(Literal.create("a") :: Literal.create("b".getBytes) :: Nil)
+ .checkInputDataTypes().isFailure)
+ }
+
test("concat_ws") {
def testConcatWs(expected: String, sep: String, inputs: Any*): Unit = {
val inputExprs = inputs.map {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala
index 10479630f3f99..30e3bc9fb5779 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.sketch.CountMinSketch
/**
- * Unit test suite for the count-min sketch SQL aggregate funciton [[CountMinSketchAgg]].
+ * Unit test suite for the count-min sketch SQL aggregate function [[CountMinSketchAgg]].
*/
class CountMinSketchAggSuite extends SparkFunSuite {
private val childExpression = BoundReference(0, IntegerType, nullable = true)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
index f203f25ad10d4..75c6beeb32150 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
@@ -22,8 +22,10 @@ import scala.util.Random
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.RandomDataGenerator
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.{JoinedRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
/**
* Test suite for [[GenerateUnsafeRowJoiner]].
@@ -45,6 +47,32 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite {
testConcat(64, 64, fixed)
}
+ test("rows with all empty strings") {
+ val schema = StructType(Seq(
+ StructField("f1", StringType), StructField("f2", StringType)))
+ val row: UnsafeRow = UnsafeProjection.create(schema).apply(
+ InternalRow(UTF8String.EMPTY_UTF8, UTF8String.EMPTY_UTF8))
+ testConcat(schema, row, schema, row)
+ }
+
+ test("rows with all empty int arrays") {
+ val schema = StructType(Seq(
+ StructField("f1", ArrayType(IntegerType)), StructField("f2", ArrayType(IntegerType))))
+ val emptyIntArray =
+ ExpressionEncoder[Array[Int]]().resolveAndBind().toRow(Array.emptyIntArray).getArray(0)
+ val row: UnsafeRow = UnsafeProjection.create(schema).apply(
+ InternalRow(emptyIntArray, emptyIntArray))
+ testConcat(schema, row, schema, row)
+ }
+
+ test("alternating empty and non-empty strings") {
+ val schema = StructType(Seq(
+ StructField("f1", StringType), StructField("f2", StringType)))
+ val row: UnsafeRow = UnsafeProjection.create(schema).apply(
+ InternalRow(UTF8String.EMPTY_UTF8, UTF8String.fromString("foo")))
+ testConcat(schema, row, schema, row)
+ }
+
test("randomized fix width types") {
for (i <- 0 until 20) {
testConcatOnce(Random.nextInt(100), Random.nextInt(100), fixed)
@@ -94,27 +122,84 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite {
val extRow2 = RandomDataGenerator.forType(schema2, nullable = false).get.apply()
val row1 = converter1.apply(internalConverter1.apply(extRow1).asInstanceOf[InternalRow])
val row2 = converter2.apply(internalConverter2.apply(extRow2).asInstanceOf[InternalRow])
+ testConcat(schema1, row1, schema2, row2)
+ }
+
+ private def testConcat(
+ schema1: StructType,
+ row1: UnsafeRow,
+ schema2: StructType,
+ row2: UnsafeRow) {
// Run the joiner.
val mergedSchema = StructType(schema1 ++ schema2)
val concater = GenerateUnsafeRowJoiner.create(schema1, schema2)
- val output = concater.join(row1, row2)
+ val output: UnsafeRow = concater.join(row1, row2)
+
+ // We'll also compare to an UnsafeRow produced with JoinedRow + UnsafeProjection. This ensures
+ // that unused space in the row (e.g. leftover bits in the null-tracking bitmap) is written
+ // correctly.
+ val expectedOutput: UnsafeRow = {
+ val joinedRowProjection = UnsafeProjection.create(mergedSchema)
+ val joined = new JoinedRow()
+ joinedRowProjection.apply(joined.apply(row1, row2))
+ }
// Test everything equals ...
for (i <- mergedSchema.indices) {
+ val dataType = mergedSchema(i).dataType
if (i < schema1.size) {
assert(output.isNullAt(i) === row1.isNullAt(i))
if (!output.isNullAt(i)) {
- assert(output.get(i, mergedSchema(i).dataType) === row1.get(i, mergedSchema(i).dataType))
+ assert(output.get(i, dataType) === row1.get(i, dataType))
+ assert(output.get(i, dataType) === expectedOutput.get(i, dataType))
}
} else {
assert(output.isNullAt(i) === row2.isNullAt(i - schema1.size))
if (!output.isNullAt(i)) {
- assert(output.get(i, mergedSchema(i).dataType) ===
- row2.get(i - schema1.size, mergedSchema(i).dataType))
+ assert(output.get(i, dataType) === row2.get(i - schema1.size, dataType))
+ assert(output.get(i, dataType) === expectedOutput.get(i, dataType))
}
}
}
+
+ assert(
+ expectedOutput.getSizeInBytes == output.getSizeInBytes,
+ "output isn't same size in bytes as slow path")
+
+ // Compare the UnsafeRows byte-by-byte so that we can print more useful debug information in
+ // case this assertion fails:
+ val actualBytes = output.getBaseObject.asInstanceOf[Array[Byte]]
+ .take(output.getSizeInBytes)
+ val expectedBytes = expectedOutput.getBaseObject.asInstanceOf[Array[Byte]]
+ .take(expectedOutput.getSizeInBytes)
+
+ val bitsetWidth = UnsafeRow.calculateBitSetWidthInBytes(expectedOutput.numFields())
+ val actualBitset = actualBytes.take(bitsetWidth)
+ val expectedBitset = expectedBytes.take(bitsetWidth)
+ assert(actualBitset === expectedBitset, "bitsets were not equal")
+
+ val fixedLengthSize = expectedOutput.numFields() * 8
+ val actualFixedLength = actualBytes.slice(bitsetWidth, bitsetWidth + fixedLengthSize)
+ val expectedFixedLength = expectedBytes.slice(bitsetWidth, bitsetWidth + fixedLengthSize)
+ if (actualFixedLength !== expectedFixedLength) {
+ actualFixedLength.grouped(8)
+ .zip(expectedFixedLength.grouped(8))
+ .zip(mergedSchema.fields.toIterator)
+ .foreach {
+ case ((actual, expected), field) =>
+ assert(actual === expected, s"Fixed length sections are not equal for field $field")
+ }
+ fail("Fixed length sections were not equal")
+ }
+
+ val variableLengthStart = bitsetWidth + fixedLengthSize
+ val actualVariableLength = actualBytes.drop(variableLengthStart)
+ val expectedVariableLength = expectedBytes.drop(variableLengthStart)
+ assert(actualVariableLength === expectedVariableLength, "variable length sections were not equal")
+
+ assert(output.hashCode() == expectedOutput.hashCode(), "hash codes were not equal")
}
}
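The byte-by-byte comparison above slices the row into a null-tracking bitset, a fixed-length region, and a variable-length tail. A small self-contained sketch of that layout arithmetic (my reading of UnsafeRow.calculateBitSetWidthInBytes; offsets only, nothing Spark-specific is imported):

  // Layout sketch: null bitset (one 8-byte word per 64 fields), then 8 bytes per field,
  // then variable-length data. Variable-length fields keep their (offset, length) packed
  // into the fixed-length slot, which is why the suite can compare the regions separately.
  def bitSetWidthInBytes(numFields: Int): Int = ((numFields + 63) / 64) * 8

  val numFields  = 5
  val bitsetSize = bitSetWidthInBytes(numFields)   // 8
  val fixedSize  = numFields * 8                   // 40
  val varStart   = bitsetSize + fixedSize          // 48: offset of the variable-length tail
  println(s"bitset=$bitsetSize fixed=$fixedSize variable starts at $varStart")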
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
index 0cd0d8859145f..2c45b3b0c73d1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
@@ -208,4 +208,42 @@ class GeneratedProjectionSuite extends SparkFunSuite {
unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("b"))))
assert(row.getStruct(0, 1).getString(0).toString == "a")
}
+
+ test("SPARK-22699: GenerateSafeProjection should not use global variables for struct") {
+ val safeProj = GenerateSafeProjection.generate(
+ Seq(BoundReference(0, new StructType().add("i", IntegerType), true)))
+ val globalVariables = safeProj.getClass.getDeclaredFields
+ // We always need 3 variables:
+ // - one is a reference to this
+ // - one is the references object
+ // - one is the mutableRow
+ assert(globalVariables.length == 3)
+ }
+
+ test("SPARK-18016: generated projections on wider table requiring state compaction") {
+ val N = 6000
+ val wideRow1 = new GenericInternalRow(new Array[Any](N))
+ val schema1 = StructType((1 to N).map(i => StructField("", IntegerType)))
+ val wideRow2 = new GenericInternalRow(
+ Array.tabulate[Any](N)(i => UTF8String.fromString(i.toString)))
+ val schema2 = StructType((1 to N).map(i => StructField("", StringType)))
+ val joined = new JoinedRow(wideRow1, wideRow2)
+ val joinedSchema = StructType(schema1 ++ schema2)
+ val nested = new JoinedRow(InternalRow(joined, joined), joined)
+ val nestedSchema = StructType(
+ Seq(StructField("", joinedSchema), StructField("", joinedSchema)) ++ joinedSchema)
+
+ val safeProj = FromUnsafeProjection(nestedSchema)
+ val result = safeProj(nested)
+
+ // test generated MutableProjection
+ val exprs = nestedSchema.fields.zipWithIndex.map { case (f, i) =>
+ BoundReference(i, f.dataType, true)
+ }
+ val mutableProj = GenerateMutableProjection.generate(exprs)
+ val row1 = mutableProj(result)
+ assert(result === row1)
+ val row2 = mutableProj(result)
+ assert(result === row2)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 77e4eff26c69b..3f41f4b144096 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -38,54 +38,64 @@ class ColumnPruningSuite extends PlanTest {
CollapseProject) :: Nil
}
- test("Column pruning for Generate when Generate.join = false") {
- val input = LocalRelation('a.int, 'b.array(StringType))
+ test("Column pruning for Generate when Generate.unrequiredChildIndex = child.output") {
+ val input = LocalRelation('a.int, 'b.int, 'c.array(StringType))
- val query = input.generate(Explode('b), join = false).analyze
+ val query =
+ input
+ .generate(Explode('c), outputNames = "explode" :: Nil)
+ .select('c, 'explode)
+ .analyze
val optimized = Optimize.execute(query)
- val correctAnswer = input.select('b).generate(Explode('b), join = false).analyze
+ val correctAnswer =
+ input
+ .select('c)
+ .generate(Explode('c), outputNames = "explode" :: Nil)
+ .analyze
comparePlans(optimized, correctAnswer)
}
- test("Column pruning for Generate when Generate.join = true") {
- val input = LocalRelation('a.int, 'b.int, 'c.array(StringType))
+ test("Fill Generate.unrequiredChildIndex if possible") {
+ val input = LocalRelation('b.array(StringType))
val query =
input
- .generate(Explode('c), join = true, outputNames = "explode" :: Nil)
- .select('a, 'explode)
+ .generate(Explode('b), outputNames = "explode" :: Nil)
+ .select(('explode + 1).as("result"))
.analyze
val optimized = Optimize.execute(query)
val correctAnswer =
input
- .select('a, 'c)
- .generate(Explode('c), join = true, outputNames = "explode" :: Nil)
- .select('a, 'explode)
+ .generate(Explode('b), unrequiredChildIndex = input.output.zipWithIndex.map(_._2),
+ outputNames = "explode" :: Nil)
+ .select(('explode + 1).as("result"))
.analyze
comparePlans(optimized, correctAnswer)
}
- test("Turn Generate.join to false if possible") {
- val input = LocalRelation('b.array(StringType))
+ test("Another fill Generate.unrequiredChildIndex if possible") {
+ val input = LocalRelation('a.int, 'b.int, 'c1.string, 'c2.string)
val query =
input
- .generate(Explode('b), join = true, outputNames = "explode" :: Nil)
- .select(('explode + 1).as("result"))
+ .generate(Explode(CreateArray(Seq('c1, 'c2))), outputNames = "explode" :: Nil)
+ .select('a, 'c1, 'explode)
.analyze
val optimized = Optimize.execute(query)
val correctAnswer =
input
- .generate(Explode('b), join = false, outputNames = "explode" :: Nil)
- .select(('explode + 1).as("result"))
+ .select('a, 'c1, 'c2)
+ .generate(Explode(CreateArray(Seq('c1, 'c2))),
+ unrequiredChildIndex = Seq(2),
+ outputNames = "explode" :: Nil)
.analyze
comparePlans(optimized, correctAnswer)
@@ -246,7 +256,7 @@ class ColumnPruningSuite extends PlanTest {
x.select('a)
.sortBy(SortOrder('a, Ascending)).analyze
- comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer))
+ comparePlans(optimized, correctAnswer)
// push down invalid
val originalQuery1 = {
@@ -261,7 +271,7 @@ class ColumnPruningSuite extends PlanTest {
.sortBy(SortOrder('a, Ascending))
.select('b).analyze
- comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1))
+ comparePlans(optimized1, correctAnswer1)
}
test("Column pruning on Window with useless aggregate functions") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala
index 412e199dfaae3..441c15340a778 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala
@@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.types.StringType
class CombineConcatsSuite extends PlanTest {
@@ -37,8 +36,10 @@ class CombineConcatsSuite extends PlanTest {
comparePlans(actual, correctAnswer)
}
+ def str(s: String): Literal = Literal(s)
+ def binary(s: String): Literal = Literal(s.getBytes)
+
test("combine nested Concat exprs") {
- def str(s: String): Literal = Literal(s, StringType)
assertEquivalent(
Concat(
Concat(str("a") :: str("b") :: Nil) ::
@@ -72,4 +73,13 @@ class CombineConcatsSuite extends PlanTest {
Nil),
Concat(str("a") :: str("b") :: str("c") :: str("d") :: Nil))
}
+
+ test("combine string and binary exprs") {
+ assertEquivalent(
+ Concat(
+ Concat(str("a") :: str("b") :: Nil) ::
+ Concat(binary("c") :: binary("d") :: Nil) ::
+ Nil),
+ Concat(str("a") :: str("b") :: binary("c") :: binary("d") :: Nil))
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index de0e7c7ee49ac..82a10254d846d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -94,19 +94,15 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
- test("combine redundant deterministic filters") {
+ test("do not combine non-deterministic filters even if they are identical") {
val originalQuery =
testRelation
.where(Rand(0) > 0.1 && 'a === 1)
- .where(Rand(0) > 0.1 && 'a === 1)
+ .where(Rand(0) > 0.1 && 'a === 1).analyze
- val optimized = Optimize.execute(originalQuery.analyze)
- val correctAnswer =
- testRelation
- .where(Rand(0) > 0.1 && 'a === 1 && Rand(0) > 0.1)
- .analyze
+ val optimized = Optimize.execute(originalQuery)
- comparePlans(optimized, correctAnswer)
+ comparePlans(optimized, originalQuery)
}
test("SPARK-16164: Filter pushdown should keep the ordering in the logical plan") {
@@ -508,7 +504,7 @@ class FilterPushdownSuite extends PlanTest {
}
val optimized = Optimize.execute(originalQuery.analyze)
- comparePlans(analysis.EliminateSubqueryAliases(originalQuery.analyze), optimized)
+ comparePlans(originalQuery.analyze, optimized)
}
test("joins: conjunctive predicates") {
@@ -527,7 +523,7 @@ class FilterPushdownSuite extends PlanTest {
left.join(right, condition = Some("x.b".attr === "y.b".attr))
.analyze
- comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer))
+ comparePlans(optimized, correctAnswer)
}
test("joins: conjunctive predicates #2") {
@@ -546,7 +542,7 @@ class FilterPushdownSuite extends PlanTest {
left.join(right, condition = Some("x.b".attr === "y.b".attr))
.analyze
- comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer))
+ comparePlans(optimized, correctAnswer)
}
test("joins: conjunctive predicates #3") {
@@ -570,7 +566,7 @@ class FilterPushdownSuite extends PlanTest {
condition = Some("z.a".attr === "x.b".attr))
.analyze
- comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer))
+ comparePlans(optimized, correctAnswer)
}
test("joins: push down where clause into left anti join") {
@@ -585,7 +581,7 @@ class FilterPushdownSuite extends PlanTest {
x.where("x.a".attr > 10)
.join(y, LeftAnti, Some("x.b".attr === "y.b".attr))
.analyze
- comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer))
+ comparePlans(optimized, correctAnswer)
}
test("joins: only push down join conditions to the right of a left anti join") {
@@ -602,7 +598,7 @@ class FilterPushdownSuite extends PlanTest {
LeftAnti,
Some("x.b".attr === "y.b".attr && "x.a".attr > 10))
.analyze
- comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer))
+ comparePlans(optimized, correctAnswer)
}
test("joins: only push down join conditions to the right of an existence join") {
@@ -620,7 +616,7 @@ class FilterPushdownSuite extends PlanTest {
ExistenceJoin(fillerVal),
Some("x.a".attr > 1))
.analyze
- comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer))
+ comparePlans(optimized, correctAnswer)
}
val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType))
@@ -628,14 +624,14 @@ class FilterPushdownSuite extends PlanTest {
test("generate: predicate referenced no generated column") {
val originalQuery = {
testRelationWithArrayType
- .generate(Explode('c_arr), true, false, Some("arr"))
+ .generate(Explode('c_arr), alias = Some("arr"))
.where(('b >= 5) && ('a > 6))
}
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = {
testRelationWithArrayType
.where(('b >= 5) && ('a > 6))
- .generate(Explode('c_arr), true, false, Some("arr")).analyze
+ .generate(Explode('c_arr), alias = Some("arr")).analyze
}
comparePlans(optimized, correctAnswer)
@@ -644,14 +640,14 @@ class FilterPushdownSuite extends PlanTest {
test("generate: non-deterministic predicate referenced no generated column") {
val originalQuery = {
testRelationWithArrayType
- .generate(Explode('c_arr), true, false, Some("arr"))
+ .generate(Explode('c_arr), alias = Some("arr"))
.where(('b >= 5) && ('a + Rand(10).as("rnd") > 6) && ('col > 6))
}
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = {
testRelationWithArrayType
.where('b >= 5)
- .generate(Explode('c_arr), true, false, Some("arr"))
+ .generate(Explode('c_arr), alias = Some("arr"))
.where('a + Rand(10).as("rnd") > 6 && 'col > 6)
.analyze
}
@@ -663,14 +659,14 @@ class FilterPushdownSuite extends PlanTest {
val generator = Explode('c_arr)
val originalQuery = {
testRelationWithArrayType
- .generate(generator, true, false, Some("arr"))
+ .generate(generator, alias = Some("arr"))
.where(('b >= 5) && ('c > 6))
}
val optimized = Optimize.execute(originalQuery.analyze)
val referenceResult = {
testRelationWithArrayType
.where('b >= 5)
- .generate(generator, true, false, Some("arr"))
+ .generate(generator, alias = Some("arr"))
.where('c > 6).analyze
}
@@ -691,7 +687,7 @@ class FilterPushdownSuite extends PlanTest {
test("generate: all conjuncts referenced generated column") {
val originalQuery = {
testRelationWithArrayType
- .generate(Explode('c_arr), true, false, Some("arr"))
+ .generate(Explode('c_arr), alias = Some("arr"))
.where(('col > 6) || ('b > 5)).analyze
}
val optimized = Optimize.execute(originalQuery)
@@ -813,6 +809,19 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+ test("aggregate: don't push filters if the aggregate has no grouping expressions") {
+ val originalQuery = LocalRelation.apply(testRelation.output, Seq.empty)
+ .select('a, 'b)
+ .groupBy()(count(1))
+ .where(false)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+
+ val correctAnswer = originalQuery.analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
test("broadcast hint") {
val originalQuery = ResolvedHint(testRelation)
.where('a === 2L && 'b + Rand(10).as("rnd") === 3)
@@ -835,9 +844,9 @@ class FilterPushdownSuite extends PlanTest {
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = Union(Seq(
- testRelation.where('a === 2L),
- testRelation2.where('d === 2L)))
- .where('b + Rand(10).as("rnd") === 3 && 'c > 5L)
+ testRelation.where('a === 2L && 'c > 5L),
+ testRelation2.where('d === 2L && 'f > 5L)))
+ .where('b + Rand(10).as("rnd") === 3)
.analyze
comparePlans(optimized, correctAnswer)
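The new expected plan above pushes the deterministic predicate into both Union children, rewriting 'c > 5L into 'f > 5L for the second child. A self-contained sketch of that attribute substitution (hypothetical, simplified types; not the optimizer's actual classes):

  // Map the Union's output attributes to the second child's attributes, then rewrite
  // the predicate before pushing it below the Union.
  case class Attr(name: String)
  case class GreaterThanLit(attr: Attr, value: Long)

  val unionOutput  = Seq(Attr("a"), Attr("b"), Attr("c"))
  val child2Output = Seq(Attr("d"), Attr("e"), Attr("f"))
  val rewrite      = unionOutput.zip(child2Output).toMap

  val pred          = GreaterThanLit(Attr("c"), 5L)
  val predForChild2 = pred.copy(attr = rewrite(pred.attr))   // GreaterThanLit(Attr(f), 5)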
@@ -1138,12 +1147,13 @@ class FilterPushdownSuite extends PlanTest {
val x = testRelation.subquery('x)
val y = testRelation.subquery('y)
- // Verify that all conditions preceding the first non-deterministic condition are pushed down
+ // Verify that all conditions except the non-deterministic condition are pushed down
// by the optimizer and others are not.
val originalQuery = x.join(y, condition = Some("x.a".attr === 5 && "y.a".attr === 5 &&
"x.a".attr === Rand(10) && "y.b".attr === 5))
- val correctAnswer = x.where("x.a".attr === 5).join(y.where("y.a".attr === 5),
- condition = Some("x.a".attr === Rand(10) && "y.b".attr === 5))
+ val correctAnswer =
+ x.where("x.a".attr === 5).join(y.where("y.a".attr === 5 && "y.b".attr === 5),
+ condition = Some("x.a".attr === Rand(10)))
// CheckAnalysis will ensure nondeterministic expressions not appear in join condition.
// TODO support nondeterministic expressions in join condition.
@@ -1151,16 +1161,16 @@ class FilterPushdownSuite extends PlanTest {
checkAnalysis = false)
}
- test("watermark pushdown: no pushdown on watermark attribute") {
+ test("watermark pushdown: no pushdown on watermark attribute #1") {
val interval = new CalendarInterval(2, 2000L)
- // Verify that all conditions preceding the first watermark touching condition are pushed down
+ // Verify that all conditions except the watermark touching condition are pushed down
// by the optimizer and others are not.
val originalQuery = EventTimeWatermark('b, interval, testRelation)
.where('a === 5 && 'b === 10 && 'c === 5)
val correctAnswer = EventTimeWatermark(
- 'b, interval, testRelation.where('a === 5))
- .where('b === 10 && 'c === 5)
+ 'b, interval, testRelation.where('a === 5 && 'c === 5))
+ .where('b === 10)
comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze,
checkAnalysis = false)
@@ -1169,7 +1179,7 @@ class FilterPushdownSuite extends PlanTest {
test("watermark pushdown: no pushdown for nondeterministic filter") {
val interval = new CalendarInterval(2, 2000L)
- // Verify that all conditions preceding the first watermark touching condition are pushed down
+ // Verify that all conditions except the watermark touching condition are pushed down
// by the optimizer and others are not.
val originalQuery = EventTimeWatermark('c, interval, testRelation)
.where('a === 5 && 'b === Rand(10) && 'c === 5)
@@ -1184,7 +1194,7 @@ class FilterPushdownSuite extends PlanTest {
test("watermark pushdown: full pushdown") {
val interval = new CalendarInterval(2, 2000L)
- // Verify that all conditions preceding the first watermark touching condition are pushed down
+ // Verify that all conditions except the watermark touching condition are pushed down
// by the optimizer and others are not.
val originalQuery = EventTimeWatermark('c, interval, testRelation)
.where('a === 5 && 'b === 10)
@@ -1195,15 +1205,15 @@ class FilterPushdownSuite extends PlanTest {
checkAnalysis = false)
}
- test("watermark pushdown: empty pushdown") {
+ test("watermark pushdown: no pushdown on watermark attribute #2") {
val interval = new CalendarInterval(2, 2000L)
- // Verify that all conditions preceding the first watermark touching condition are pushed down
- // by the optimizer and others are not.
val originalQuery = EventTimeWatermark('a, interval, testRelation)
.where('a === 5 && 'b === 10)
+ val correctAnswer = EventTimeWatermark(
+ 'a, interval, testRelation.where('b === 10)).where('a === 5)
- comparePlans(Optimize.execute(originalQuery.analyze), originalQuery.analyze,
+ comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze,
checkAnalysis = false)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala
index dccb32f0379a8..c28844642aed0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala
@@ -147,8 +147,8 @@ class FoldablePropagationSuite extends PlanTest {
test("Propagate in expand") {
val c1 = Literal(1).as('a)
val c2 = Literal(2).as('b)
- val a1 = c1.toAttribute.withNullability(true)
- val a2 = c2.toAttribute.withNullability(true)
+ val a1 = c1.toAttribute.newInstance().withNullability(true)
+ val a2 = c2.toAttribute.newInstance().withNullability(true)
val expand = Expand(
Seq(Seq(Literal(null), 'b), Seq('a, Literal(null))),
Seq(a1, a2),
@@ -161,4 +161,23 @@ class FoldablePropagationSuite extends PlanTest {
val correctAnswer = correctExpand.where(a1.isNotNull).select(a1, a2).analyze
comparePlans(optimized, correctAnswer)
}
+
+ test("Propagate above outer join") {
+ val left = LocalRelation('a.int).select('a, Literal(1).as('b))
+ val right = LocalRelation('c.int).select('c, Literal(1).as('d))
+
+ val join = left.join(
+ right,
+ joinType = LeftOuter,
+ condition = Some('a === 'c && 'b === 'd))
+ val query = join.select(('b + 3).as('res)).analyze
+ val optimized = Optimize.execute(query)
+
+ val correctAnswer = left.join(
+ right,
+ joinType = LeftOuter,
+ condition = Some('a === 'c && Literal(1) === Literal(1)))
+ .select((Literal(1) + 3).as('res)).analyze
+ comparePlans(optimized, correctAnswer)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
index 5580f8604ec72..a0708bf7eee9a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
@@ -236,4 +236,17 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
comparePlans(optimized, originalQuery)
}
}
+
+ test("constraints should be inferred from aliased literals") {
+ val originalLeft = testRelation.subquery('left).as("left")
+ val optimizedLeft = testRelation.subquery('left).where(IsNotNull('a) && 'a === 2).as("left")
+
+ val right = Project(Seq(Literal(2).as("two")), testRelation.subquery('right)).as("right")
+ val condition = Some("left.a".attr === "right.two".attr)
+
+ val original = originalLeft.join(right, Inner, condition)
+ val correct = optimizedLeft.join(right, Inner, condition)
+
+ comparePlans(Optimize.execute(original.analyze), correct.analyze)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
index 97733a754ccc2..ccd9d8dd4d213 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.optimizer
-import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
@@ -118,7 +117,7 @@ class JoinOptimizationSuite extends PlanTest {
queryAnswers foreach { queryAnswerPair =>
val optimized = Optimize.execute(queryAnswerPair._1.analyze)
- comparePlans(optimized, analysis.EliminateSubqueryAliases(queryAnswerPair._2.analyze))
+ comparePlans(optimized, queryAnswerPair._2.analyze)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
index d7acd139225cd..478118ed709f7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
@@ -169,7 +169,7 @@ class OptimizeInSuite extends PlanTest {
val optimizedPlan = OptimizeIn(plan)
optimizedPlan match {
case Filter(cond, _)
- if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getSet().size == 3 =>
+ if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].set.size == 3 =>
// pass
case _ => fail("Unexpected result for OptimizedIn")
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
index bc1c48b99c295..3964508e3a55e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
@@ -21,8 +21,9 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.StructType
@@ -78,17 +79,18 @@ class PropagateEmptyRelationSuite extends PlanTest {
(true, false, Inner, Some(LocalRelation('a.int, 'b.int))),
(true, false, Cross, Some(LocalRelation('a.int, 'b.int))),
- (true, false, LeftOuter, None),
+ (true, false, LeftOuter, Some(Project(Seq('a, Literal(null).as('b)), testRelation1).analyze)),
(true, false, RightOuter, Some(LocalRelation('a.int, 'b.int))),
- (true, false, FullOuter, None),
- (true, false, LeftAnti, None),
- (true, false, LeftSemi, None),
+ (true, false, FullOuter, Some(Project(Seq('a, Literal(null).as('b)), testRelation1).analyze)),
+ (true, false, LeftAnti, Some(testRelation1)),
+ (true, false, LeftSemi, Some(LocalRelation('a.int))),
(false, true, Inner, Some(LocalRelation('a.int, 'b.int))),
(false, true, Cross, Some(LocalRelation('a.int, 'b.int))),
(false, true, LeftOuter, Some(LocalRelation('a.int, 'b.int))),
- (false, true, RightOuter, None),
- (false, true, FullOuter, None),
+ (false, true, RightOuter,
+ Some(Project(Seq(Literal(null).as('a), 'b), testRelation2).analyze)),
+ (false, true, FullOuter, Some(Project(Seq(Literal(null).as('a), 'b), testRelation2).analyze)),
(false, true, LeftAnti, Some(LocalRelation('a.int))),
(false, true, LeftSemi, Some(LocalRelation('a.int))),
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
index 0fa1aaeb9e164..e9701ffd2c54b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, Not}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, Not}
import org.apache.spark.sql.catalyst.expressions.aggregate.First
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical._
@@ -198,6 +198,14 @@ class ReplaceOperatorSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+ test("add one grouping key if necessary when replacing Deduplicate with Aggregate") {
+ val input = LocalRelation()
+ val query = Deduplicate(Seq.empty, input) // dropDuplicates()
+ val optimized = Optimize.execute(query.analyze)
+ val correctAnswer = Aggregate(Seq(Literal(1)), input.output, input)
+ comparePlans(optimized, correctAnswer)
+ }
+
test("don't replace streaming Deduplicate") {
val input = LocalRelation(Seq('a.int, 'b.int), isStreaming = true)
val attrA = input.output(0)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala
new file mode 100644
index 0000000000000..6b3739c372c3a
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.ListQuery
+import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+
+class RewriteSubquerySuite extends PlanTest {
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Column Pruning", FixedPoint(100), ColumnPruning) ::
+ Batch("Rewrite Subquery", FixedPoint(1),
+ RewritePredicateSubquery,
+ ColumnPruning,
+ CollapseProject,
+ RemoveRedundantProject) :: Nil
+ }
+
+ test("Column pruning after rewriting predicate subquery") {
+ val relation = LocalRelation('a.int, 'b.int)
+ val relInSubquery = LocalRelation('x.int, 'y.int, 'z.int)
+
+ val query = relation.where('a.in(ListQuery(relInSubquery.select('x)))).select('a)
+
+ val optimized = Optimize.execute(query.analyze)
+ val correctAnswer = relation
+ .select('a)
+ .join(relInSubquery.select('x), LeftSemi, Some('a === 'x))
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+}
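In SQL terms, the plan built in this new suite roughly corresponds to an IN-subquery that gets rewritten into a left semi join and then column-pruned. A hedged end-to-end sketch, assuming a SparkSession spark with tables t(a, b) and s(x, y, z) already registered (none of which come from this diff):

  // SELECT a FROM t WHERE a IN (SELECT x FROM s)
  // should optimize to: Project(a) over a LeftSemi join on a = x, with pruned children.
  val q = spark.sql("SELECT a FROM t WHERE a IN (SELECT x FROM s)")
  q.explain(true)   // the optimized plan is expected to show the LeftSemi join and no b/y/z columns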
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
index 3634accf1ec21..0d11958876ce9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -164,6 +165,12 @@ class ComplexTypesSuite extends PlanTest{
comparePlans(Optimizer execute query, expected)
}
+ test("SPARK-22570: CreateArray should not create a lot of global variables") {
+ val ctx = new CodegenContext
+ CreateArray(Seq(Literal(1))).genCode(ctx)
+ assert(ctx.inlinedMutableStates.length == 0)
+ }
+
test("simplify map ops") {
val rel = relation
.select(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index d34a83c42c67e..812bfdd7bb885 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -276,7 +276,7 @@ class PlanParserSuite extends AnalysisTest {
assertEqual(
"select * from t lateral view explode(x) expl as x",
table("t")
- .generate(explode, join = true, outer = false, Some("expl"), Seq("x"))
+ .generate(explode, alias = Some("expl"), outputNames = Seq("x"))
.select(star()))
// Multiple lateral views
@@ -286,12 +286,12 @@ class PlanParserSuite extends AnalysisTest {
|lateral view explode(x) expl
|lateral view outer json_tuple(x, y) jtup q, z""".stripMargin,
table("t")
- .generate(explode, join = true, outer = false, Some("expl"), Seq.empty)
- .generate(jsonTuple, join = true, outer = true, Some("jtup"), Seq("q", "z"))
+ .generate(explode, alias = Some("expl"))
+ .generate(jsonTuple, outer = true, alias = Some("jtup"), outputNames = Seq("q", "z"))
.select(star()))
// Multi-Insert lateral views.
- val from = table("t1").generate(explode, join = true, outer = false, Some("expl"), Seq("x"))
+ val from = table("t1").generate(explode, alias = Some("expl"), outputNames = Seq("x"))
assertEqual(
"""from t1
|lateral view explode(x) expl as x
@@ -303,7 +303,7 @@ class PlanParserSuite extends AnalysisTest {
|where s < 10
""".stripMargin,
Union(from
- .generate(jsonTuple, join = true, outer = false, Some("jtup"), Seq("q", "z"))
+ .generate(jsonTuple, alias = Some("jtup"), outputNames = Seq("q", "z"))
.select(star())
.insertInto("t2"),
from.where('s < 10).select(star()).insertInto("t3")))
@@ -312,10 +312,8 @@ class PlanParserSuite extends AnalysisTest {
val expected = table("t")
.generate(
UnresolvedGenerator(FunctionIdentifier("posexplode"), Seq('x)),
- join = true,
- outer = false,
- Some("posexpl"),
- Seq("x", "y"))
+ alias = Some("posexpl"),
+ outputNames = Seq("x", "y"))
.select(star())
assertEqual(
"select * from t lateral view posexplode(x) posexpl as x, y",
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
index cdf912df7c76a..14041747fd20e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
@@ -23,8 +23,8 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.IntegerType
/**
- * This suite is used to test [[LogicalPlan]]'s `resolveOperators` and make sure it can correctly
- * skips sub-trees that have already been marked as analyzed.
+ * This suite is used to test [[LogicalPlan]]'s `transformUp/transformDown` plus analysis barrier
+ * and make sure it can correctly skip sub-trees that have already been analyzed.
*/
class LogicalPlanSuite extends SparkFunSuite {
private var invocationCount = 0
@@ -36,39 +36,53 @@ class LogicalPlanSuite extends SparkFunSuite {
private val testRelation = LocalRelation()
- test("resolveOperator runs on operators") {
+ test("transformUp runs on operators") {
invocationCount = 0
val plan = Project(Nil, testRelation)
- plan resolveOperators function
+ plan transformUp function
assert(invocationCount === 1)
+
+ invocationCount = 0
+ plan transformDown function
+ assert(invocationCount === 1)
}
- test("resolveOperator runs on operators recursively") {
+ test("transformUp runs on operators recursively") {
invocationCount = 0
val plan = Project(Nil, Project(Nil, testRelation))
- plan resolveOperators function
+ plan transformUp function
assert(invocationCount === 2)
+
+ invocationCount = 0
+ plan transformDown function
+ assert(invocationCount === 2)
}
- test("resolveOperator skips all ready resolved plans") {
+ test("transformUp skips already resolved plans wrapped in analysis barrier") {
invocationCount = 0
- val plan = Project(Nil, Project(Nil, testRelation))
- plan.foreach(_.setAnalyzed())
- plan resolveOperators function
+ val plan = AnalysisBarrier(Project(Nil, Project(Nil, testRelation)))
+ plan transformUp function
assert(invocationCount === 0)
+
+ invocationCount = 0
+ plan transformDown function
+ assert(invocationCount === 0)
}
- test("resolveOperator skips partially resolved plans") {
+ test("transformUp skips partially resolved plans wrapped in analysis barrier") {
invocationCount = 0
- val plan1 = Project(Nil, testRelation)
+ val plan1 = AnalysisBarrier(Project(Nil, testRelation))
val plan2 = Project(Nil, plan1)
- plan1.foreach(_.setAnalyzed())
- plan2 resolveOperators function
+ plan2 transformUp function
assert(invocationCount === 1)
+
+ invocationCount = 0
+ plan2 transformDown function
+ assert(invocationCount === 1)
}
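The rewritten tests above lean on the fact that, as I understand it, AnalysisBarrier behaves as a leaf node, so transformUp/transformDown never descend into the already-analyzed child. A small sketch of that pattern, mirroring the suite's own setup rather than adding new behaviour:

  import org.apache.spark.sql.catalyst.plans.logical.{AnalysisBarrier, LocalRelation, LogicalPlan, Project}

  val analyzed: LogicalPlan = Project(Nil, LocalRelation())
  val barrier = AnalysisBarrier(analyzed)
  // The rule matches the barrier itself but never visits the wrapped `analyzed` subtree.
  val result = barrier transformDown { case p => p }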
test("isStreaming") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
index 455037e6c9952..2b1fe987a7960 100755
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
@@ -22,7 +22,7 @@ import java.sql.Date
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.LeftOuter
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Join, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
@@ -33,7 +33,7 @@ import org.apache.spark.sql.types._
*/
class FilterEstimationSuite extends StatsEstimationTestBase {
- // Suppose our test table has 10 rows and 6 columns.
+ // Suppose our test table has 10 rows and 10 columns.
// column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
// Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4
val attrInt = AttributeReference("cint", IntegerType)()
@@ -91,6 +91,26 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
val colStatInt4 = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4)
+ // column cintHgm has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 with histogram.
+ // Note that cintHgm has an even distribution with histogram information built.
+ // Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4
+ val attrIntHgm = AttributeReference("cintHgm", IntegerType)()
+ val hgmInt = Histogram(2.0, Array(HistogramBin(1.0, 2.0, 2),
+ HistogramBin(2.0, 4.0, 2), HistogramBin(4.0, 6.0, 2),
+ HistogramBin(6.0, 8.0, 2), HistogramBin(8.0, 10.0, 2)))
+ val colStatIntHgm = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))
+
+ // column cintSkewHgm has values: 1, 4, 4, 5, 5, 5, 5, 6, 6, 10 with histogram.
+ // Note that cintSkewHgm has a skewed distribution with histogram information built.
+ // distinctCount:5, min:1, max:10, nullCount:0, avgLen:4, maxLen:4
+ val attrIntSkewHgm = AttributeReference("cintSkewHgm", IntegerType)()
+ val hgmIntSkew = Histogram(2.0, Array(HistogramBin(1.0, 4.0, 2),
+ HistogramBin(4.0, 5.0, 2), HistogramBin(5.0, 5.0, 1),
+ HistogramBin(5.0, 6.0, 2), HistogramBin(6.0, 10.0, 2)))
+ val colStatIntSkewHgm = ColumnStat(distinctCount = 5, min = Some(1), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))
+
val attributeMap = AttributeMap(Seq(
attrInt -> colStatInt,
attrBool -> colStatBool,
@@ -100,7 +120,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
attrString -> colStatString,
attrInt2 -> colStatInt2,
attrInt3 -> colStatInt3,
- attrInt4 -> colStatInt4
+ attrInt4 -> colStatInt4,
+ attrIntHgm -> colStatIntHgm,
+ attrIntSkewHgm -> colStatIntSkewHgm
))
test("true") {
@@ -359,7 +381,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
test("cbool > false") {
validateEstimatedStats(
Filter(GreaterThan(attrBool, Literal(false)), childStatsTestPlan(Seq(attrBool), 10L)),
- Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true),
+ Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(false), max = Some(true),
nullCount = 0, avgLen = 1, maxLen = 1)),
expectedRowCount = 5)
}
@@ -578,6 +600,193 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
expectedRowCount = 5)
}
+ // The following test cases have histogram information collected for the test column with
+ // an even distribution.
+ test("Not(cintHgm < 3 AND null)") {
+ val condition = Not(And(LessThan(attrIntHgm, Literal(3)), Literal(null, IntegerType)))
+ validateEstimatedStats(
+ Filter(condition, childStatsTestPlan(Seq(attrIntHgm), 10L)),
+ Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = 7)),
+ expectedRowCount = 7)
+ }
+
+ test("cintHgm = 5") {
+ validateEstimatedStats(
+ Filter(EqualTo(attrIntHgm, Literal(5)), childStatsTestPlan(Seq(attrIntHgm), 10L)),
+ Seq(attrIntHgm -> ColumnStat(distinctCount = 1, min = Some(5), max = Some(5),
+ nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))),
+ expectedRowCount = 1)
+ }
+
+ test("cintHgm = 0") {
+ // This is an out-of-range case since 0 is outside the range [min, max]
+ validateEstimatedStats(
+ Filter(EqualTo(attrIntHgm, Literal(0)), childStatsTestPlan(Seq(attrIntHgm), 10L)),
+ Nil,
+ expectedRowCount = 0)
+ }
+
+ test("cintHgm < 3") {
+ validateEstimatedStats(
+ Filter(LessThan(attrIntHgm, Literal(3)), childStatsTestPlan(Seq(attrIntHgm), 10L)),
+ Seq(attrIntHgm -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3),
+ nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))),
+ expectedRowCount = 3)
+ }
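+
A minimal Scala sketch of the arithmetic behind the even-distribution expectations above, using `cintHgm < 3` as the worked case: each equi-height bin holds `height` rows and contributes in proportion to how much of its range falls under the bound. This is only an illustration of why `expectedRowCount = 3`, not the actual `FilterEstimation` implementation, which also handles null counts and bin boundaries more carefully.

```scala
// Simplified model of hgmInt above: each bin covers [lo, hi] and holds `height` rows.
case class Bin(lo: Double, hi: Double)

val height = 2.0
val bins = Seq(Bin(1, 2), Bin(2, 4), Bin(4, 6), Bin(6, 8), Bin(8, 10))

// Fraction of a bin that falls below the upper bound `x` of a predicate `col < x`.
def fractionBelow(b: Bin, x: Double): Double =
  if (x <= b.lo) 0.0
  else if (x >= b.hi) 1.0
  else (x - b.lo) / (b.hi - b.lo)

// cintHgm < 3: bin [1,2] contributes 1.0 * 2 rows, bin [2,4] contributes 0.5 * 2 rows,
// so the estimate is 3 rows, matching expectedRowCount = 3 in the test above.
val estimatedRows = bins.map(b => fractionBelow(b, 3.0) * height).sum
```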
+
+ test("cintHgm < 0") {
+ // This is a corner case since literal 0 is smaller than min.
+ validateEstimatedStats(
+ Filter(LessThan(attrIntHgm, Literal(0)), childStatsTestPlan(Seq(attrIntHgm), 10L)),
+ Nil,
+ expectedRowCount = 0)
+ }
+
+ test("cintHgm <= 3") {
+ validateEstimatedStats(
+ Filter(LessThanOrEqual(attrIntHgm, Literal(3)), childStatsTestPlan(Seq(attrIntHgm), 10L)),
+ Seq(attrIntHgm -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3),
+ nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))),
+ expectedRowCount = 3)
+ }
+
+ test("cintHgm > 6") {
+ validateEstimatedStats(
+ Filter(GreaterThan(attrIntHgm, Literal(6)), childStatsTestPlan(Seq(attrIntHgm), 10L)),
+ Seq(attrIntHgm -> ColumnStat(distinctCount = 4, min = Some(6), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))),
+ expectedRowCount = 4)
+ }
+
+ test("cintHgm > 10") {
+ // This is a corner case since max value is 10.
+ validateEstimatedStats(
+ Filter(GreaterThan(attrIntHgm, Literal(10)), childStatsTestPlan(Seq(attrIntHgm), 10L)),
+ Nil,
+ expectedRowCount = 0)
+ }
+
+ test("cintHgm >= 6") {
+ validateEstimatedStats(
+ Filter(GreaterThanOrEqual(attrIntHgm, Literal(6)), childStatsTestPlan(Seq(attrIntHgm), 10L)),
+ Seq(attrIntHgm -> ColumnStat(distinctCount = 5, min = Some(6), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))),
+ expectedRowCount = 5)
+ }
+
+ test("cintHgm > 3 AND cintHgm <= 6") {
+ val condition = And(GreaterThan(attrIntHgm,
+ Literal(3)), LessThanOrEqual(attrIntHgm, Literal(6)))
+ validateEstimatedStats(
+ Filter(condition, childStatsTestPlan(Seq(attrIntHgm), 10L)),
+ Seq(attrIntHgm -> ColumnStat(distinctCount = 4, min = Some(3), max = Some(6),
+ nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))),
+ expectedRowCount = 4)
+ }
+
+ test("cintHgm = 3 OR cintHgm = 6") {
+ val condition = Or(EqualTo(attrIntHgm, Literal(3)), EqualTo(attrIntHgm, Literal(6)))
+ validateEstimatedStats(
+ Filter(condition, childStatsTestPlan(Seq(attrIntHgm), 10L)),
+ Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = 3)),
+ expectedRowCount = 3)
+ }
+
+ // The following test cases have histogram information collected for the test column with
+ // a skewed distribution.
+ test("Not(cintSkewHgm < 3 AND null)") {
+ val condition = Not(And(LessThan(attrIntSkewHgm, Literal(3)), Literal(null, IntegerType)))
+ validateEstimatedStats(
+ Filter(condition, childStatsTestPlan(Seq(attrIntSkewHgm), 10L)),
+ Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = 5)),
+ expectedRowCount = 9)
+ }
+
+ test("cintSkewHgm = 5") {
+ validateEstimatedStats(
+ Filter(EqualTo(attrIntSkewHgm, Literal(5)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)),
+ Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(5), max = Some(5),
+ nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))),
+ expectedRowCount = 4)
+ }
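+
The skewed-column equality estimate can be reproduced by hand from the bins of `hgmIntSkew`. The sketch below is a simplification rather than the real estimator: every bin whose range contains the literal contributes roughly `height / ndv` rows, which is why `cintSkewHgm = 5` is expected to return 4 rows.

```scala
// Simplified model of hgmIntSkew above: equi-height bins with possibly skewed value ranges.
case class Bin(lo: Double, hi: Double, ndv: Long)

val height = 2.0
val bins = Seq(Bin(1, 4, 2), Bin(4, 5, 2), Bin(5, 5, 1), Bin(5, 6, 2), Bin(6, 10, 2))

// Each bin containing the literal contributes about height / ndv rows for that value.
def equalityRows(value: Double): Double =
  bins.filter(b => b.lo <= value && value <= b.hi).map(b => height / b.ndv).sum

// For cintSkewHgm = 5: bins [4,5], [5,5] and [5,6] contribute 1 + 2 + 1 = 4 rows,
// matching expectedRowCount = 4 in the test above.
val estimatedRows = equalityRows(5.0)
```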
+
+ test("cintSkewHgm = 0") {
+ // This is an out-of-range case since 0 is outside the range [min, max]
+ validateEstimatedStats(
+ Filter(EqualTo(attrIntSkewHgm, Literal(0)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)),
+ Nil,
+ expectedRowCount = 0)
+ }
+
+ test("cintSkewHgm < 3") {
+ validateEstimatedStats(
+ Filter(LessThan(attrIntSkewHgm, Literal(3)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)),
+ Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(3),
+ nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))),
+ expectedRowCount = 2)
+ }
+
+ test("cintSkewHgm < 0") {
+ // This is a corner case since literal 0 is smaller than min.
+ validateEstimatedStats(
+ Filter(LessThan(attrIntSkewHgm, Literal(0)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)),
+ Nil,
+ expectedRowCount = 0)
+ }
+
+ test("cintSkewHgm <= 3") {
+ validateEstimatedStats(
+ Filter(LessThanOrEqual(attrIntSkewHgm, Literal(3)),
+ childStatsTestPlan(Seq(attrIntSkewHgm), 10L)),
+ Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(3),
+ nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))),
+ expectedRowCount = 2)
+ }
+
+ test("cintSkewHgm > 6") {
+ validateEstimatedStats(
+ Filter(GreaterThan(attrIntSkewHgm, Literal(6)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)),
+ Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(6), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))),
+ expectedRowCount = 2)
+ }
+
+ test("cintSkewHgm > 10") {
+ // This is a corner case since max value is 10.
+ validateEstimatedStats(
+ Filter(GreaterThan(attrIntSkewHgm, Literal(10)),
+ childStatsTestPlan(Seq(attrIntSkewHgm), 10L)),
+ Nil,
+ expectedRowCount = 0)
+ }
+
+ test("cintSkewHgm >= 6") {
+ validateEstimatedStats(
+ Filter(GreaterThanOrEqual(attrIntSkewHgm, Literal(6)),
+ childStatsTestPlan(Seq(attrIntSkewHgm), 10L)),
+ Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 2, min = Some(6), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))),
+ expectedRowCount = 3)
+ }
+
+ test("cintSkewHgm > 3 AND cintSkewHgm <= 6") {
+ val condition = And(GreaterThan(attrIntSkewHgm,
+ Literal(3)), LessThanOrEqual(attrIntSkewHgm, Literal(6)))
+ validateEstimatedStats(
+ Filter(condition, childStatsTestPlan(Seq(attrIntSkewHgm), 10L)),
+ Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 4, min = Some(3), max = Some(6),
+ nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))),
+ expectedRowCount = 8)
+ }
+
+ test("cintSkewHgm = 3 OR cintSkewHgm = 6") {
+ val condition = Or(EqualTo(attrIntSkewHgm, Literal(3)), EqualTo(attrIntSkewHgm, Literal(6)))
+ validateEstimatedStats(
+ Filter(condition, childStatsTestPlan(Seq(attrIntSkewHgm), 10L)),
+ Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = 2)),
+ expectedRowCount = 3)
+ }
+
private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = {
StatsTestPlan(
outputList = outList,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
index 097c78eb27fca..26139d85d25fb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
@@ -23,7 +23,7 @@ import scala.collection.mutable
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, AttributeReference, EqualTo}
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Project, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types.{DateType, TimestampType, _}
@@ -67,6 +67,213 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
rowCount = 2,
attributeStats = AttributeMap(Seq("key-1-2", "key-2-3").map(nameToColInfo)))
+ private def estimateByHistogram(
+ leftHistogram: Histogram,
+ rightHistogram: Histogram,
+ expectedMin: Double,
+ expectedMax: Double,
+ expectedNdv: Long,
+ expectedRows: Long): Unit = {
+ val col1 = attr("key1")
+ val col2 = attr("key2")
+ val c1 = generateJoinChild(col1, leftHistogram, expectedMin, expectedMax)
+ val c2 = generateJoinChild(col2, rightHistogram, expectedMin, expectedMax)
+
+ val c1JoinC2 = Join(c1, c2, Inner, Some(EqualTo(col1, col2)))
+ val c2JoinC1 = Join(c2, c1, Inner, Some(EqualTo(col2, col1)))
+ val expectedStatsAfterJoin = Statistics(
+ sizeInBytes = expectedRows * (8 + 2 * 4),
+ rowCount = Some(expectedRows),
+ attributeStats = AttributeMap(Seq(
+ col1 -> c1.stats.attributeStats(col1).copy(
+ distinctCount = expectedNdv, min = Some(expectedMin), max = Some(expectedMax)),
+ col2 -> c2.stats.attributeStats(col2).copy(
+ distinctCount = expectedNdv, min = Some(expectedMin), max = Some(expectedMax))))
+ )
+
+ // Join order should not affect estimation result.
+ Seq(c1JoinC2, c2JoinC1).foreach { join =>
+ assert(join.stats == expectedStatsAfterJoin)
+ }
+ }
+
+ private def generateJoinChild(
+ col: Attribute,
+ histogram: Histogram,
+ expectedMin: Double,
+ expectedMax: Double): LogicalPlan = {
+ val colStat = inferColumnStat(histogram)
+ StatsTestPlan(
+ outputList = Seq(col),
+ rowCount = (histogram.height * histogram.bins.length).toLong,
+ attributeStats = AttributeMap(Seq(col -> colStat)))
+ }
+
+ /** Column statistics should be consistent with histograms in tests. */
+ private def inferColumnStat(histogram: Histogram): ColumnStat = {
+ var ndv = 0L
+ for (i <- histogram.bins.indices) {
+ val bin = histogram.bins(i)
+ if (i == 0 || bin.hi != histogram.bins(i - 1).hi) {
+ ndv += bin.ndv
+ }
+ }
+ ColumnStat(distinctCount = ndv, min = Some(histogram.bins.head.lo),
+ max = Some(histogram.bins.last.hi), nullCount = 0, avgLen = 4, maxLen = 4,
+ histogram = Some(histogram))
+ }
+
+ test("equi-height histograms: a bin is contained by another one") {
+ val histogram1 = Histogram(height = 300, Array(
+ HistogramBin(lo = 10, hi = 30, ndv = 10), HistogramBin(lo = 30, hi = 60, ndv = 30)))
+ val histogram2 = Histogram(height = 100, Array(
+ HistogramBin(lo = 0, hi = 50, ndv = 50), HistogramBin(lo = 50, hi = 100, ndv = 40)))
+ // test bin trimming
+ val (t0, h0) = trimBin(histogram2.bins(0), height = 100, lowerBound = 10, upperBound = 60)
+ assert(t0 == HistogramBin(lo = 10, hi = 50, ndv = 40) && h0 == 80)
+ val (t1, h1) = trimBin(histogram2.bins(1), height = 100, lowerBound = 10, upperBound = 60)
+ assert(t1 == HistogramBin(lo = 50, hi = 60, ndv = 8) && h1 == 20)
+
+ val expectedRanges = Seq(
+ // histogram1.bins(0) overlaps t0
+ OverlappedRange(10, 30, 10, 40 * 1 / 2, 300, 80 * 1 / 2),
+ // histogram1.bins(1) overlaps t0
+ OverlappedRange(30, 50, 30 * 2 / 3, 40 * 1 / 2, 300 * 2 / 3, 80 * 1 / 2),
+ // histogram1.bins(1) overlaps t1
+ OverlappedRange(50, 60, 30 * 1 / 3, 8, 300 * 1 / 3, 20)
+ )
+ assert(expectedRanges.equals(
+ getOverlappedRanges(histogram1, histogram2, lowerBound = 10, upperBound = 60)))
+
+ estimateByHistogram(
+ leftHistogram = histogram1,
+ rightHistogram = histogram2,
+ expectedMin = 10,
+ expectedMax = 60,
+ expectedNdv = 10 + 20 + 8,
+ expectedRows = 300 * 40 / 20 + 200 * 40 / 20 + 100 * 20 / 10)
+ }
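+
The `expectedRows` and `expectedNdv` expressions above follow from `expectedRanges` under a simple per-range model: assuming each overlapped range produces about `leftRows * rightRows / max(leftNdv, rightNdv)` matching rows and `min(leftNdv, rightNdv)` distinct join keys, summed over ranges. The Scala sketch below reproduces the numbers used in this test; it illustrates the arithmetic only and is not the estimator itself.

```scala
// (lo, hi, leftNdv, rightNdv, leftRows, rightRows) for each overlapped range,
// copied from expectedRanges in the test above.
val ranges = Seq(
  (10.0, 30.0, 10L, 20L, 300.0, 40.0),
  (30.0, 50.0, 20L, 20L, 200.0, 40.0),
  (50.0, 60.0, 10L, 8L, 100.0, 20.0))

// Assumed per-range model: rows = leftRows * rightRows / max(leftNdv, rightNdv),
// distinct keys = min(leftNdv, rightNdv).
val estimatedRows = ranges.map { case (_, _, lNdv, rNdv, lRows, rRows) =>
  lRows * rRows / math.max(lNdv, rNdv)
}.sum                                   // 600 + 400 + 200 = 1200, as in expectedRows

val estimatedNdv = ranges.map { case (_, _, lNdv, rNdv, _, _) =>
  math.min(lNdv, rNdv)
}.sum                                   // 10 + 20 + 8 = 38, as in expectedNdv
```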
+
+ test("equi-height histograms: a bin has only one value after trimming") {
+ val histogram1 = Histogram(height = 300, Array(
+ HistogramBin(lo = 50, hi = 60, ndv = 10), HistogramBin(lo = 60, hi = 75, ndv = 3)))
+ val histogram2 = Histogram(height = 100, Array(
+ HistogramBin(lo = 0, hi = 50, ndv = 50), HistogramBin(lo = 50, hi = 100, ndv = 40)))
+ // test bin trimming
+ val (t0, h0) = trimBin(histogram2.bins(0), height = 100, lowerBound = 50, upperBound = 75)
+ assert(t0 == HistogramBin(lo = 50, hi = 50, ndv = 1) && h0 == 2)
+ val (t1, h1) = trimBin(histogram2.bins(1), height = 100, lowerBound = 50, upperBound = 75)
+ assert(t1 == HistogramBin(lo = 50, hi = 75, ndv = 20) && h1 == 50)
+
+ val expectedRanges = Seq(
+ // histogram1.bins(0) overlaps t0
+ OverlappedRange(50, 50, 1, 1, 300 / 10, 2),
+ // histogram1.bins(0) overlaps t1
+ OverlappedRange(50, 60, 10, 20 * 10 / 25, 300, 50 * 10 / 25),
+ // histogram1.bins(1) overlaps t1
+ OverlappedRange(60, 75, 3, 20 * 15 / 25, 300, 50 * 15 / 25)
+ )
+ assert(expectedRanges.equals(
+ getOverlappedRanges(histogram1, histogram2, lowerBound = 50, upperBound = 75)))
+
+ estimateByHistogram(
+ leftHistogram = histogram1,
+ rightHistogram = histogram2,
+ expectedMin = 50,
+ expectedMax = 75,
+ expectedNdv = 1 + 8 + 3,
+ expectedRows = 30 * 2 / 1 + 300 * 20 / 10 + 300 * 30 / 12)
+ }
+
+ test("equi-height histograms: skew distribution (some bins have only one value)") {
+ val histogram1 = Histogram(height = 300, Array(
+ HistogramBin(lo = 30, hi = 30, ndv = 1),
+ HistogramBin(lo = 30, hi = 30, ndv = 1),
+ HistogramBin(lo = 30, hi = 60, ndv = 30)))
+ val histogram2 = Histogram(height = 100, Array(
+ HistogramBin(lo = 0, hi = 50, ndv = 50), HistogramBin(lo = 50, hi = 100, ndv = 40)))
+ // test bin trimming
+ val (t0, h0) = trimBin(histogram2.bins(0), height = 100, lowerBound = 30, upperBound = 60)
+ assert(t0 == HistogramBin(lo = 30, hi = 50, ndv = 20) && h0 == 40)
+ val (t1, h1) = trimBin(histogram2.bins(1), height = 100, lowerBound = 30, upperBound = 60)
+ assert(t1 == HistogramBin(lo = 50, hi = 60, ndv = 8) && h1 == 20)
+
+ val expectedRanges = Seq(
+ OverlappedRange(30, 30, 1, 1, 300, 40 / 20),
+ OverlappedRange(30, 30, 1, 1, 300, 40 / 20),
+ OverlappedRange(30, 50, 30 * 2 / 3, 20, 300 * 2 / 3, 40),
+ OverlappedRange(50, 60, 30 * 1 / 3, 8, 300 * 1 / 3, 20)
+ )
+ assert(expectedRanges.equals(
+ getOverlappedRanges(histogram1, histogram2, lowerBound = 30, upperBound = 60)))
+
+ estimateByHistogram(
+ leftHistogram = histogram1,
+ rightHistogram = histogram2,
+ expectedMin = 30,
+ expectedMax = 60,
+ expectedNdv = 1 + 20 + 8,
+ expectedRows = 300 * 2 / 1 + 300 * 2 / 1 + 200 * 40 / 20 + 100 * 20 / 10)
+ }
+
+ test("equi-height histograms: skew distribution (histograms have different skewed values") {
+ val histogram1 = Histogram(height = 300, Array(
+ HistogramBin(lo = 30, hi = 30, ndv = 1), HistogramBin(lo = 30, hi = 60, ndv = 30)))
+ val histogram2 = Histogram(height = 100, Array(
+ HistogramBin(lo = 0, hi = 50, ndv = 50), HistogramBin(lo = 50, hi = 50, ndv = 1)))
+ // test bin trimming
+ val (t0, h0) = trimBin(histogram1.bins(1), height = 300, lowerBound = 30, upperBound = 50)
+ assert(t0 == HistogramBin(lo = 30, hi = 50, ndv = 20) && h0 == 200)
+ val (t1, h1) = trimBin(histogram2.bins(0), height = 100, lowerBound = 30, upperBound = 50)
+ assert(t1 == HistogramBin(lo = 30, hi = 50, ndv = 20) && h1 == 40)
+
+ val expectedRanges = Seq(
+ OverlappedRange(30, 30, 1, 1, 300, 40 / 20),
+ OverlappedRange(30, 50, 20, 20, 200, 40),
+ OverlappedRange(50, 50, 1, 1, 200 / 20, 100)
+ )
+ assert(expectedRanges.equals(
+ getOverlappedRanges(histogram1, histogram2, lowerBound = 30, upperBound = 50)))
+
+ estimateByHistogram(
+ leftHistogram = histogram1,
+ rightHistogram = histogram2,
+ expectedMin = 30,
+ expectedMax = 50,
+ expectedNdv = 1 + 20,
+ expectedRows = 300 * 2 / 1 + 200 * 40 / 20 + 10 * 100 / 1)
+ }
+
+ test("equi-height histograms: skew distribution (both histograms have the same skewed value") {
+ val histogram1 = Histogram(height = 300, Array(
+ HistogramBin(lo = 30, hi = 30, ndv = 1), HistogramBin(lo = 30, hi = 60, ndv = 30)))
+ val histogram2 = Histogram(height = 150, Array(
+ HistogramBin(lo = 0, hi = 30, ndv = 30), HistogramBin(lo = 30, hi = 30, ndv = 1)))
+ // test bin trimming
+ val (t0, h0) = trimBin(histogram1.bins(1), height = 300, lowerBound = 30, upperBound = 30)
+ assert(t0 == HistogramBin(lo = 30, hi = 30, ndv = 1) && h0 == 10)
+ val (t1, h1) = trimBin(histogram2.bins(0), height = 150, lowerBound = 30, upperBound = 30)
+ assert(t1 == HistogramBin(lo = 30, hi = 30, ndv = 1) && h1 == 5)
+
+ val expectedRanges = Seq(
+ OverlappedRange(30, 30, 1, 1, 300, 5),
+ OverlappedRange(30, 30, 1, 1, 300, 150),
+ OverlappedRange(30, 30, 1, 1, 10, 5),
+ OverlappedRange(30, 30, 1, 1, 10, 150)
+ )
+ assert(expectedRanges.equals(
+ getOverlappedRanges(histogram1, histogram2, lowerBound = 30, upperBound = 30)))
+
+ estimateByHistogram(
+ leftHistogram = histogram1,
+ rightHistogram = histogram2,
+ expectedMin = 30,
+ expectedMax = 30,
+ // only one value: 30
+ expectedNdv = 1,
+ expectedRows = 300 * 5 / 1 + 300 * 150 / 1 + 10 * 5 / 1 + 10 * 150 / 1)
+ }
+
test("cross join") {
// table1 (key-1-5 int, key-5-9 int): (1, 9), (2, 8), (3, 7), (4, 6), (5, 5)
// table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
index c8cf16d937352..625ff38943fa3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
@@ -563,6 +563,76 @@ class DateTimeUtilsSuite extends SparkFunSuite {
}
}
+ test("truncTimestamp") {
+ def testTrunc(
+ level: Int,
+ expected: String,
+ inputTS: SQLTimestamp,
+ timezone: TimeZone = DateTimeUtils.defaultTimeZone()): Unit = {
+ val truncated =
+ DateTimeUtils.truncTimestamp(inputTS, level, timezone)
+ val expectedTS =
+ DateTimeUtils.stringToTimestamp(UTF8String.fromString(expected))
+ assert(truncated === expectedTS.get)
+ }
+
+ val defaultInputTS =
+ DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-05T09:32:05.359"))
+ val defaultInputTS1 =
+ DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-31T20:32:05.359"))
+ val defaultInputTS2 =
+ DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-04-01T02:32:05.359"))
+ val defaultInputTS3 =
+ DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-30T02:32:05.359"))
+ val defaultInputTS4 =
+ DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-29T02:32:05.359"))
+
+ testTrunc(DateTimeUtils.TRUNC_TO_YEAR, "2015-01-01T00:00:00", defaultInputTS.get)
+ testTrunc(DateTimeUtils.TRUNC_TO_MONTH, "2015-03-01T00:00:00", defaultInputTS.get)
+ testTrunc(DateTimeUtils.TRUNC_TO_DAY, "2015-03-05T00:00:00", defaultInputTS.get)
+ testTrunc(DateTimeUtils.TRUNC_TO_HOUR, "2015-03-05T09:00:00", defaultInputTS.get)
+ testTrunc(DateTimeUtils.TRUNC_TO_MINUTE, "2015-03-05T09:32:00", defaultInputTS.get)
+ testTrunc(DateTimeUtils.TRUNC_TO_SECOND, "2015-03-05T09:32:05", defaultInputTS.get)
+ testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-02T00:00:00", defaultInputTS.get)
+ testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", defaultInputTS1.get)
+ testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", defaultInputTS2.get)
+ testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", defaultInputTS3.get)
+ testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-23T00:00:00", defaultInputTS4.get)
+ testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-01-01T00:00:00", defaultInputTS.get)
+ testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-01-01T00:00:00", defaultInputTS1.get)
+ testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-04-01T00:00:00", defaultInputTS2.get)
+
+ for (tz <- DateTimeTestUtils.ALL_TIMEZONES) {
+ DateTimeTestUtils.withDefaultTimeZone(tz) {
+ val inputTS =
+ DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-05T09:32:05.359"))
+ val inputTS1 =
+ DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-31T20:32:05.359"))
+ val inputTS2 =
+ DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-04-01T02:32:05.359"))
+ val inputTS3 =
+ DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-30T02:32:05.359"))
+ val inputTS4 =
+ DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-29T02:32:05.359"))
+
+ testTrunc(DateTimeUtils.TRUNC_TO_YEAR, "2015-01-01T00:00:00", inputTS.get, tz)
+ testTrunc(DateTimeUtils.TRUNC_TO_MONTH, "2015-03-01T00:00:00", inputTS.get, tz)
+ testTrunc(DateTimeUtils.TRUNC_TO_DAY, "2015-03-05T00:00:00", inputTS.get, tz)
+ testTrunc(DateTimeUtils.TRUNC_TO_HOUR, "2015-03-05T09:00:00", inputTS.get, tz)
+ testTrunc(DateTimeUtils.TRUNC_TO_MINUTE, "2015-03-05T09:32:00", inputTS.get, tz)
+ testTrunc(DateTimeUtils.TRUNC_TO_SECOND, "2015-03-05T09:32:05", inputTS.get, tz)
+ testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-02T00:00:00", inputTS.get, tz)
+ testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", inputTS1.get, tz)
+ testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", inputTS2.get, tz)
+ testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", inputTS3.get, tz)
+ testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-23T00:00:00", inputTS4.get, tz)
+ testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-01-01T00:00:00", inputTS.get, tz)
+ testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-01-01T00:00:00", inputTS1.get, tz)
+ testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-04-01T00:00:00", inputTS2.get, tz)
+ }
+ }
+ }
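+
As a quick standalone illustration of the new `TRUNC_TO_WEEK` behaviour exercised above (weeks start on Monday), this Scala sketch mirrors one of the test's own assertions; the timestamp value is taken directly from the test.

```scala
import java.util.TimeZone

import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.unsafe.types.UTF8String

// 2015-03-05 is a Thursday, so truncating to week lands on Monday 2015-03-02.
val ts = DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-05T09:32:05.359")).get
val truncated =
  DateTimeUtils.truncTimestamp(ts, DateTimeUtils.TRUNC_TO_WEEK, TimeZone.getDefault)
val expected =
  DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-02T00:00:00")).get
assert(truncated == expected)
```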
+
test("daysToMillis and millisToDays") {
val c = Calendar.getInstance(TimeZonePST)
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 4db3fea008ee9..ef41837f89d68 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.11
- 2.3.0-SNAPSHOT
+ 2.4.0-SNAPSHOT
../../pom.xml
@@ -38,7 +38,7 @@
com.univocity
univocity-parsers
- 2.5.4
+ 2.5.9
jar
@@ -195,7 +195,7 @@
org.scalatest
scalatest-maven-plugin
- -ea -Xmx4g -Xss4m -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m
+ -ea -Xmx4g -Xss4m -XX:ReservedCodeCacheSize=${CodeCacheSize}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java
new file mode 100644
index 0000000000000..b6e792274da11
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc;
+
+import java.math.BigDecimal;
+
+import org.apache.orc.storage.ql.exec.vector.*;
+
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.sql.types.TimestampType;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A column vector class wrapping Hive's ColumnVector. Because Spark ColumnarBatch only accepts
+ * Spark's vectorized.ColumnVector, this column vector is used to adapt Hive's ColumnVector to
+ * Spark's ColumnVector.
+ */
+public class OrcColumnVector extends org.apache.spark.sql.vectorized.ColumnVector {
+ private ColumnVector baseData;
+ private LongColumnVector longData;
+ private DoubleColumnVector doubleData;
+ private BytesColumnVector bytesData;
+ private DecimalColumnVector decimalData;
+ private TimestampColumnVector timestampData;
+ private final boolean isTimestamp;
+
+ private int batchSize;
+
+ OrcColumnVector(DataType type, ColumnVector vector) {
+ super(type);
+
+ if (type instanceof TimestampType) {
+ isTimestamp = true;
+ } else {
+ isTimestamp = false;
+ }
+
+ baseData = vector;
+ if (vector instanceof LongColumnVector) {
+ longData = (LongColumnVector) vector;
+ } else if (vector instanceof DoubleColumnVector) {
+ doubleData = (DoubleColumnVector) vector;
+ } else if (vector instanceof BytesColumnVector) {
+ bytesData = (BytesColumnVector) vector;
+ } else if (vector instanceof DecimalColumnVector) {
+ decimalData = (DecimalColumnVector) vector;
+ } else if (vector instanceof TimestampColumnVector) {
+ timestampData = (TimestampColumnVector) vector;
+ } else {
+ throw new UnsupportedOperationException();
+ }
+ }
+
+ public void setBatchSize(int batchSize) {
+ this.batchSize = batchSize;
+ }
+
+ @Override
+ public void close() {
+
+ }
+
+ @Override
+ public int numNulls() {
+ if (baseData.isRepeating) {
+ if (baseData.isNull[0]) {
+ return batchSize;
+ } else {
+ return 0;
+ }
+ } else if (baseData.noNulls) {
+ return 0;
+ } else {
+ int count = 0;
+ for (int i = 0; i < batchSize; i++) {
+ if (baseData.isNull[i]) count++;
+ }
+ return count;
+ }
+ }
+
+ /* A helper method to get the row index in a column. */
+ private int getRowIndex(int rowId) {
+ return baseData.isRepeating ? 0 : rowId;
+ }
+
+ @Override
+ public boolean isNullAt(int rowId) {
+ return baseData.isNull[getRowIndex(rowId)];
+ }
+
+ @Override
+ public boolean getBoolean(int rowId) {
+ return longData.vector[getRowIndex(rowId)] == 1;
+ }
+
+ @Override
+ public boolean[] getBooleans(int rowId, int count) {
+ boolean[] res = new boolean[count];
+ for (int i = 0; i < count; i++) {
+ res[i] = getBoolean(rowId + i);
+ }
+ return res;
+ }
+
+ @Override
+ public byte getByte(int rowId) {
+ return (byte) longData.vector[getRowIndex(rowId)];
+ }
+
+ @Override
+ public byte[] getBytes(int rowId, int count) {
+ byte[] res = new byte[count];
+ for (int i = 0; i < count; i++) {
+ res[i] = getByte(rowId + i);
+ }
+ return res;
+ }
+
+ @Override
+ public short getShort(int rowId) {
+ return (short) longData.vector[getRowIndex(rowId)];
+ }
+
+ @Override
+ public short[] getShorts(int rowId, int count) {
+ short[] res = new short[count];
+ for (int i = 0; i < count; i++) {
+ res[i] = getShort(rowId + i);
+ }
+ return res;
+ }
+
+ @Override
+ public int getInt(int rowId) {
+ return (int) longData.vector[getRowIndex(rowId)];
+ }
+
+ @Override
+ public int[] getInts(int rowId, int count) {
+ int[] res = new int[count];
+ for (int i = 0; i < count; i++) {
+ res[i] = getInt(rowId + i);
+ }
+ return res;
+ }
+
+ @Override
+ public long getLong(int rowId) {
+ int index = getRowIndex(rowId);
+ if (isTimestamp) {
+ return timestampData.time[index] * 1000 + timestampData.nanos[index] / 1000;
+ } else {
+ return longData.vector[index];
+ }
+ }
+
+ @Override
+ public long[] getLongs(int rowId, int count) {
+ long[] res = new long[count];
+ for (int i = 0; i < count; i++) {
+ res[i] = getLong(rowId + i);
+ }
+ return res;
+ }
+
+ @Override
+ public float getFloat(int rowId) {
+ return (float) doubleData.vector[getRowIndex(rowId)];
+ }
+
+ @Override
+ public float[] getFloats(int rowId, int count) {
+ float[] res = new float[count];
+ for (int i = 0; i < count; i++) {
+ res[i] = getFloat(rowId + i);
+ }
+ return res;
+ }
+
+ @Override
+ public double getDouble(int rowId) {
+ return doubleData.vector[getRowIndex(rowId)];
+ }
+
+ @Override
+ public double[] getDoubles(int rowId, int count) {
+ double[] res = new double[count];
+ for (int i = 0; i < count; i++) {
+ res[i] = getDouble(rowId + i);
+ }
+ return res;
+ }
+
+ @Override
+ public int getArrayLength(int rowId) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int getArrayOffset(int rowId) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Decimal getDecimal(int rowId, int precision, int scale) {
+ BigDecimal data = decimalData.vector[getRowIndex(rowId)].getHiveDecimal().bigDecimalValue();
+ return Decimal.apply(data, precision, scale);
+ }
+
+ @Override
+ public UTF8String getUTF8String(int rowId) {
+ int index = getRowIndex(rowId);
+ BytesColumnVector col = bytesData;
+ return UTF8String.fromBytes(col.vector[index], col.start[index], col.length[index]);
+ }
+
+ @Override
+ public byte[] getBinary(int rowId) {
+ int index = getRowIndex(rowId);
+ byte[] binary = new byte[bytesData.length[index]];
+ System.arraycopy(bytesData.vector[index], bytesData.start[index], binary, 0, binary.length);
+ return binary;
+ }
+
+ @Override
+ public org.apache.spark.sql.vectorized.ColumnVector arrayData() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public org.apache.spark.sql.vectorized.ColumnVector getChildColumn(int ordinal) {
+ throw new UnsupportedOperationException();
+ }
+}
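+
For `TimestampType`, `getLong` above combines the millisecond and nanosecond components of Hive's `TimestampColumnVector` into microseconds. A small illustrative Scala sketch with made-up values (here the millisecond field is aligned to a whole second so the arithmetic is easy to follow):

```scala
// Made-up cell values standing in for TimestampColumnVector.time and .nanos.
val timeMillis = 1425547925000L   // 2015-03-05 09:32:05 UTC, whole second for clarity
val nanos = 359123456L            // fractional-second component in nanoseconds

// Same arithmetic as OrcColumnVector.getLong for TimestampType:
// microseconds = millis * 1000 + nanos / 1000
val micros = timeMillis * 1000L + nanos / 1000L   // 1425547925359123L
```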
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java
new file mode 100644
index 0000000000000..36fdf2bdf84d2
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java
@@ -0,0 +1,571 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc;
+
+import java.io.IOException;
+import java.util.stream.IntStream;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.mapreduce.InputSplit;
+import org.apache.hadoop.mapreduce.RecordReader;
+import org.apache.hadoop.mapreduce.TaskAttemptContext;
+import org.apache.hadoop.mapreduce.lib.input.FileSplit;
+import org.apache.orc.OrcConf;
+import org.apache.orc.OrcFile;
+import org.apache.orc.Reader;
+import org.apache.orc.TypeDescription;
+import org.apache.orc.mapred.OrcInputFormat;
+import org.apache.orc.storage.common.type.HiveDecimal;
+import org.apache.orc.storage.ql.exec.vector.*;
+import org.apache.orc.storage.serde2.io.HiveDecimalWritable;
+
+import org.apache.spark.memory.MemoryMode;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils;
+import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector;
+import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
+import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
+import org.apache.spark.sql.types.*;
+import org.apache.spark.sql.vectorized.ColumnarBatch;
+
+
+/**
+ * To support vectorization in WholeStageCodeGen, this reader returns ColumnarBatch.
+ * After creation, `initialize` and `initBatch` should be called sequentially.
+ */
+public class OrcColumnarBatchReader extends RecordReader<Void, ColumnarBatch> {
+
+ /**
+ * The default batch size. We use this value for the ORC reader to make it consistent with
+ * Spark's columnar batch, because their default batch sizes differ as follows:
+ *
+ * - ORC's VectorizedRowBatch.DEFAULT_SIZE = 1024
+ * - Spark's ColumnarBatch.DEFAULT_BATCH_SIZE = 4 * 1024
+ */
+ private static final int DEFAULT_SIZE = 4 * 1024;
+
+ // ORC File Reader
+ private Reader reader;
+
+ // Vectorized ORC Row Batch
+ private VectorizedRowBatch batch;
+
+ /**
+ * The column IDs of the physical ORC file schema which are required by this reader.
+ * -1 means this required column doesn't exist in the ORC file.
+ */
+ private int[] requestedColIds;
+
+ // Record reader from ORC row batch.
+ private org.apache.orc.RecordReader recordReader;
+
+ private StructField[] requiredFields;
+
+ // The result columnar batch for vectorized execution by whole-stage codegen.
+ private ColumnarBatch columnarBatch;
+
+ // Writable column vectors of the result columnar batch.
+ private WritableColumnVector[] columnVectors;
+
+ // The wrapped ORC column vectors. It should be null if `copyToSpark` is true.
+ private org.apache.spark.sql.vectorized.ColumnVector[] orcVectorWrappers;
+
+ // The memory mode of the columnarBatch
+ private final MemoryMode MEMORY_MODE;
+
+ // Whether or not to copy the ORC columnar batch to Spark columnar batch.
+ private final boolean copyToSpark;
+
+ public OrcColumnarBatchReader(boolean useOffHeap, boolean copyToSpark) {
+ MEMORY_MODE = useOffHeap ? MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP;
+ this.copyToSpark = copyToSpark;
+ }
+
+
+ @Override
+ public Void getCurrentKey() throws IOException, InterruptedException {
+ return null;
+ }
+
+ @Override
+ public ColumnarBatch getCurrentValue() throws IOException, InterruptedException {
+ return columnarBatch;
+ }
+
+ @Override
+ public float getProgress() throws IOException, InterruptedException {
+ return recordReader.getProgress();
+ }
+
+ @Override
+ public boolean nextKeyValue() throws IOException, InterruptedException {
+ return nextBatch();
+ }
+
+ @Override
+ public void close() throws IOException {
+ if (columnarBatch != null) {
+ columnarBatch.close();
+ columnarBatch = null;
+ }
+ if (recordReader != null) {
+ recordReader.close();
+ recordReader = null;
+ }
+ }
+
+ /**
+ * Initializes the ORC file reader and the batch record reader.
+ * Note that `initBatch` needs to be called after this.
+ */
+ @Override
+ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext)
+ throws IOException, InterruptedException {
+ FileSplit fileSplit = (FileSplit)inputSplit;
+ Configuration conf = taskAttemptContext.getConfiguration();
+ reader = OrcFile.createReader(
+ fileSplit.getPath(),
+ OrcFile.readerOptions(conf)
+ .maxLength(OrcConf.MAX_FILE_LENGTH.getLong(conf))
+ .filesystem(fileSplit.getPath().getFileSystem(conf)));
+
+ Reader.Options options =
+ OrcInputFormat.buildOptions(conf, reader, fileSplit.getStart(), fileSplit.getLength());
+ recordReader = reader.rows(options);
+ }
+
+ /**
+ * Initializes the columnar batch by setting the required schema and partition information.
+ * With this information, this method creates a ColumnarBatch with the full schema.
+ */
+ public void initBatch(
+ TypeDescription orcSchema,
+ int[] requestedColIds,
+ StructField[] requiredFields,
+ StructType partitionSchema,
+ InternalRow partitionValues) {
+ batch = orcSchema.createRowBatch(DEFAULT_SIZE);
+ assert(!batch.selectedInUse); // `selectedInUse` should be initialized with `false`.
+
+ this.requiredFields = requiredFields;
+ this.requestedColIds = requestedColIds;
+ assert(requiredFields.length == requestedColIds.length);
+
+ StructType resultSchema = new StructType(requiredFields);
+ for (StructField f : partitionSchema.fields()) {
+ resultSchema = resultSchema.add(f);
+ }
+
+ int capacity = DEFAULT_SIZE;
+
+ if (copyToSpark) {
+ if (MEMORY_MODE == MemoryMode.OFF_HEAP) {
+ columnVectors = OffHeapColumnVector.allocateColumns(capacity, resultSchema);
+ } else {
+ columnVectors = OnHeapColumnVector.allocateColumns(capacity, resultSchema);
+ }
+
+ // Initialize the missing columns once.
+ for (int i = 0; i < requiredFields.length; i++) {
+ if (requestedColIds[i] == -1) {
+ columnVectors[i].putNulls(0, capacity);
+ columnVectors[i].setIsConstant();
+ }
+ }
+
+ if (partitionValues.numFields() > 0) {
+ int partitionIdx = requiredFields.length;
+ for (int i = 0; i < partitionValues.numFields(); i++) {
+ ColumnVectorUtils.populate(columnVectors[i + partitionIdx], partitionValues, i);
+ columnVectors[i + partitionIdx].setIsConstant();
+ }
+ }
+
+ columnarBatch = new ColumnarBatch(resultSchema, columnVectors, capacity);
+ } else {
+ // Just wrap the ORC column vector instead of copying it to Spark column vector.
+ orcVectorWrappers = new org.apache.spark.sql.vectorized.ColumnVector[resultSchema.length()];
+
+ for (int i = 0; i < requiredFields.length; i++) {
+ DataType dt = requiredFields[i].dataType();
+ int colId = requestedColIds[i];
+ // Initialize the missing columns once.
+ if (colId == -1) {
+ OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt);
+ missingCol.putNulls(0, capacity);
+ missingCol.setIsConstant();
+ orcVectorWrappers[i] = missingCol;
+ } else {
+ orcVectorWrappers[i] = new OrcColumnVector(dt, batch.cols[colId]);
+ }
+ }
+
+ if (partitionValues.numFields() > 0) {
+ int partitionIdx = requiredFields.length;
+ for (int i = 0; i < partitionValues.numFields(); i++) {
+ DataType dt = partitionSchema.fields()[i].dataType();
+ OnHeapColumnVector partitionCol = new OnHeapColumnVector(capacity, dt);
+ ColumnVectorUtils.populate(partitionCol, partitionValues, i);
+ partitionCol.setIsConstant();
+ orcVectorWrappers[partitionIdx + i] = partitionCol;
+ }
+ }
+
+ columnarBatch = new ColumnarBatch(resultSchema, orcVectorWrappers, capacity);
+ }
+ }
+
+ /**
+ * Returns true if there is more data in the next batch. If so, prepares the next batch by
+ * copying from ORC VectorizedRowBatch columns to Spark ColumnarBatch columns.
+ */
+ private boolean nextBatch() throws IOException {
+ recordReader.nextBatch(batch);
+ int batchSize = batch.size;
+ if (batchSize == 0) {
+ return false;
+ }
+ columnarBatch.setNumRows(batchSize);
+
+ if (!copyToSpark) {
+ for (int i = 0; i < requiredFields.length; i++) {
+ if (requestedColIds[i] != -1) {
+ ((OrcColumnVector) orcVectorWrappers[i]).setBatchSize(batchSize);
+ }
+ }
+ return true;
+ }
+
+ for (WritableColumnVector vector : columnVectors) {
+ vector.reset();
+ }
+
+ for (int i = 0; i < requiredFields.length; i++) {
+ StructField field = requiredFields[i];
+ WritableColumnVector toColumn = columnVectors[i];
+
+ if (requestedColIds[i] >= 0) {
+ ColumnVector fromColumn = batch.cols[requestedColIds[i]];
+
+ if (fromColumn.isRepeating) {
+ putRepeatingValues(batchSize, field, fromColumn, toColumn);
+ } else if (fromColumn.noNulls) {
+ putNonNullValues(batchSize, field, fromColumn, toColumn);
+ } else {
+ putValues(batchSize, field, fromColumn, toColumn);
+ }
+ }
+ }
+ return true;
+ }
+
+ private void putRepeatingValues(
+ int batchSize,
+ StructField field,
+ ColumnVector fromColumn,
+ WritableColumnVector toColumn) {
+ if (fromColumn.isNull[0]) {
+ toColumn.putNulls(0, batchSize);
+ } else {
+ DataType type = field.dataType();
+ if (type instanceof BooleanType) {
+ toColumn.putBooleans(0, batchSize, ((LongColumnVector)fromColumn).vector[0] == 1);
+ } else if (type instanceof ByteType) {
+ toColumn.putBytes(0, batchSize, (byte)((LongColumnVector)fromColumn).vector[0]);
+ } else if (type instanceof ShortType) {
+ toColumn.putShorts(0, batchSize, (short)((LongColumnVector)fromColumn).vector[0]);
+ } else if (type instanceof IntegerType || type instanceof DateType) {
+ toColumn.putInts(0, batchSize, (int)((LongColumnVector)fromColumn).vector[0]);
+ } else if (type instanceof LongType) {
+ toColumn.putLongs(0, batchSize, ((LongColumnVector)fromColumn).vector[0]);
+ } else if (type instanceof TimestampType) {
+ toColumn.putLongs(0, batchSize,
+ fromTimestampColumnVector((TimestampColumnVector)fromColumn, 0));
+ } else if (type instanceof FloatType) {
+ toColumn.putFloats(0, batchSize, (float)((DoubleColumnVector)fromColumn).vector[0]);
+ } else if (type instanceof DoubleType) {
+ toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector[0]);
+ } else if (type instanceof StringType || type instanceof BinaryType) {
+ BytesColumnVector data = (BytesColumnVector)fromColumn;
+ WritableColumnVector arrayData = toColumn.getChildColumn(0);
+ int size = data.vector[0].length;
+ arrayData.reserve(size);
+ arrayData.putBytes(0, size, data.vector[0], 0);
+ for (int index = 0; index < batchSize; index++) {
+ toColumn.putArray(index, 0, size);
+ }
+ } else if (type instanceof DecimalType) {
+ DecimalType decimalType = (DecimalType)type;
+ putDecimalWritables(
+ toColumn,
+ batchSize,
+ decimalType.precision(),
+ decimalType.scale(),
+ ((DecimalColumnVector)fromColumn).vector[0]);
+ } else {
+ throw new UnsupportedOperationException("Unsupported Data Type: " + type);
+ }
+ }
+ }
+
+ private void putNonNullValues(
+ int batchSize,
+ StructField field,
+ ColumnVector fromColumn,
+ WritableColumnVector toColumn) {
+ DataType type = field.dataType();
+ if (type instanceof BooleanType) {
+ long[] data = ((LongColumnVector)fromColumn).vector;
+ for (int index = 0; index < batchSize; index++) {
+ toColumn.putBoolean(index, data[index] == 1);
+ }
+ } else if (type instanceof ByteType) {
+ long[] data = ((LongColumnVector)fromColumn).vector;
+ for (int index = 0; index < batchSize; index++) {
+ toColumn.putByte(index, (byte)data[index]);
+ }
+ } else if (type instanceof ShortType) {
+ long[] data = ((LongColumnVector)fromColumn).vector;
+ for (int index = 0; index < batchSize; index++) {
+ toColumn.putShort(index, (short)data[index]);
+ }
+ } else if (type instanceof IntegerType || type instanceof DateType) {
+ long[] data = ((LongColumnVector)fromColumn).vector;
+ for (int index = 0; index < batchSize; index++) {
+ toColumn.putInt(index, (int)data[index]);
+ }
+ } else if (type instanceof LongType) {
+ toColumn.putLongs(0, batchSize, ((LongColumnVector)fromColumn).vector, 0);
+ } else if (type instanceof TimestampType) {
+ TimestampColumnVector data = ((TimestampColumnVector)fromColumn);
+ for (int index = 0; index < batchSize; index++) {
+ toColumn.putLong(index, fromTimestampColumnVector(data, index));
+ }
+ } else if (type instanceof FloatType) {
+ double[] data = ((DoubleColumnVector)fromColumn).vector;
+ for (int index = 0; index < batchSize; index++) {
+ toColumn.putFloat(index, (float)data[index]);
+ }
+ } else if (type instanceof DoubleType) {
+ toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector, 0);
+ } else if (type instanceof StringType || type instanceof BinaryType) {
+ BytesColumnVector data = ((BytesColumnVector)fromColumn);
+ WritableColumnVector arrayData = toColumn.getChildColumn(0);
+ int totalNumBytes = IntStream.of(data.length).sum();
+ arrayData.reserve(totalNumBytes);
+ for (int index = 0, pos = 0; index < batchSize; pos += data.length[index], index++) {
+ arrayData.putBytes(pos, data.length[index], data.vector[index], data.start[index]);
+ toColumn.putArray(index, pos, data.length[index]);
+ }
+ } else if (type instanceof DecimalType) {
+ DecimalType decimalType = (DecimalType)type;
+ DecimalColumnVector data = ((DecimalColumnVector)fromColumn);
+ if (decimalType.precision() > Decimal.MAX_LONG_DIGITS()) {
+ WritableColumnVector arrayData = toColumn.getChildColumn(0);
+ arrayData.reserve(batchSize * 16);
+ }
+ for (int index = 0; index < batchSize; index++) {
+ putDecimalWritable(
+ toColumn,
+ index,
+ decimalType.precision(),
+ decimalType.scale(),
+ data.vector[index]);
+ }
+ } else {
+ throw new UnsupportedOperationException("Unsupported Data Type: " + type);
+ }
+ }
+
+ private void putValues(
+ int batchSize,
+ StructField field,
+ ColumnVector fromColumn,
+ WritableColumnVector toColumn) {
+ DataType type = field.dataType();
+ if (type instanceof BooleanType) {
+ long[] vector = ((LongColumnVector)fromColumn).vector;
+ for (int index = 0; index < batchSize; index++) {
+ if (fromColumn.isNull[index]) {
+ toColumn.putNull(index);
+ } else {
+ toColumn.putBoolean(index, vector[index] == 1);
+ }
+ }
+ } else if (type instanceof ByteType) {
+ long[] vector = ((LongColumnVector)fromColumn).vector;
+ for (int index = 0; index < batchSize; index++) {
+ if (fromColumn.isNull[index]) {
+ toColumn.putNull(index);
+ } else {
+ toColumn.putByte(index, (byte)vector[index]);
+ }
+ }
+ } else if (type instanceof ShortType) {
+ long[] vector = ((LongColumnVector)fromColumn).vector;
+ for (int index = 0; index < batchSize; index++) {
+ if (fromColumn.isNull[index]) {
+ toColumn.putNull(index);
+ } else {
+ toColumn.putShort(index, (short)vector[index]);
+ }
+ }
+ } else if (type instanceof IntegerType || type instanceof DateType) {
+ long[] vector = ((LongColumnVector)fromColumn).vector;
+ for (int index = 0; index < batchSize; index++) {
+ if (fromColumn.isNull[index]) {
+ toColumn.putNull(index);
+ } else {
+ toColumn.putInt(index, (int)vector[index]);
+ }
+ }
+ } else if (type instanceof LongType) {
+ long[] vector = ((LongColumnVector)fromColumn).vector;
+ for (int index = 0; index < batchSize; index++) {
+ if (fromColumn.isNull[index]) {
+ toColumn.putNull(index);
+ } else {
+ toColumn.putLong(index, vector[index]);
+ }
+ }
+ } else if (type instanceof TimestampType) {
+ TimestampColumnVector vector = ((TimestampColumnVector)fromColumn);
+ for (int index = 0; index < batchSize; index++) {
+ if (fromColumn.isNull[index]) {
+ toColumn.putNull(index);
+ } else {
+ toColumn.putLong(index, fromTimestampColumnVector(vector, index));
+ }
+ }
+ } else if (type instanceof FloatType) {
+ double[] vector = ((DoubleColumnVector)fromColumn).vector;
+ for (int index = 0; index < batchSize; index++) {
+ if (fromColumn.isNull[index]) {
+ toColumn.putNull(index);
+ } else {
+ toColumn.putFloat(index, (float)vector[index]);
+ }
+ }
+ } else if (type instanceof DoubleType) {
+ double[] vector = ((DoubleColumnVector)fromColumn).vector;
+ for (int index = 0; index < batchSize; index++) {
+ if (fromColumn.isNull[index]) {
+ toColumn.putNull(index);
+ } else {
+ toColumn.putDouble(index, vector[index]);
+ }
+ }
+ } else if (type instanceof StringType || type instanceof BinaryType) {
+ BytesColumnVector vector = (BytesColumnVector)fromColumn;
+ WritableColumnVector arrayData = toColumn.getChildColumn(0);
+ int totalNumBytes = IntStream.of(vector.length).sum();
+ arrayData.reserve(totalNumBytes);
+ for (int index = 0, pos = 0; index < batchSize; pos += vector.length[index], index++) {
+ if (fromColumn.isNull[index]) {
+ toColumn.putNull(index);
+ } else {
+ arrayData.putBytes(pos, vector.length[index], vector.vector[index], vector.start[index]);
+ toColumn.putArray(index, pos, vector.length[index]);
+ }
+ }
+ } else if (type instanceof DecimalType) {
+ DecimalType decimalType = (DecimalType)type;
+ HiveDecimalWritable[] vector = ((DecimalColumnVector)fromColumn).vector;
+ if (decimalType.precision() > Decimal.MAX_LONG_DIGITS()) {
+ WritableColumnVector arrayData = toColumn.getChildColumn(0);
+ arrayData.reserve(batchSize * 16);
+ }
+ for (int index = 0; index < batchSize; index++) {
+ if (fromColumn.isNull[index]) {
+ toColumn.putNull(index);
+ } else {
+ putDecimalWritable(
+ toColumn,
+ index,
+ decimalType.precision(),
+ decimalType.scale(),
+ vector[index]);
+ }
+ }
+ } else {
+ throw new UnsupportedOperationException("Unsupported Data Type: " + type);
+ }
+ }
+
+ /**
+ * Returns the number of micros since epoch from an element of TimestampColumnVector.
+ */
+ private static long fromTimestampColumnVector(TimestampColumnVector vector, int index) {
+ return vector.time[index] * 1000L + vector.nanos[index] / 1000L;
+ }
+
+ /**
+ * Put a `HiveDecimalWritable` to a `WritableColumnVector`.
+ */
+ private static void putDecimalWritable(
+ WritableColumnVector toColumn,
+ int index,
+ int precision,
+ int scale,
+ HiveDecimalWritable decimalWritable) {
+ HiveDecimal decimal = decimalWritable.getHiveDecimal();
+ Decimal value =
+ Decimal.apply(decimal.bigDecimalValue(), decimal.precision(), decimal.scale());
+ value.changePrecision(precision, scale);
+
+ if (precision <= Decimal.MAX_INT_DIGITS()) {
+ toColumn.putInt(index, (int) value.toUnscaledLong());
+ } else if (precision <= Decimal.MAX_LONG_DIGITS()) {
+ toColumn.putLong(index, value.toUnscaledLong());
+ } else {
+ byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray();
+ WritableColumnVector arrayData = toColumn.getChildColumn(0);
+ arrayData.putBytes(index * 16, bytes.length, bytes, 0);
+ toColumn.putArray(index, index * 16, bytes.length);
+ }
+ }
+
+ /**
+ * Put `HiveDecimalWritable`s to a `WritableColumnVector`.
+ */
+ private static void putDecimalWritables(
+ WritableColumnVector toColumn,
+ int size,
+ int precision,
+ int scale,
+ HiveDecimalWritable decimalWritable) {
+ HiveDecimal decimal = decimalWritable.getHiveDecimal();
+ Decimal value =
+ Decimal.apply(decimal.bigDecimalValue(), decimal.precision(), decimal.scale());
+ value.changePrecision(precision, scale);
+
+ if (precision <= Decimal.MAX_INT_DIGITS()) {
+ toColumn.putInts(0, size, (int) value.toUnscaledLong());
+ } else if (precision <= Decimal.MAX_LONG_DIGITS()) {
+ toColumn.putLongs(0, size, value.toUnscaledLong());
+ } else {
+ byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray();
+ WritableColumnVector arrayData = toColumn.getChildColumn(0);
+ arrayData.reserve(bytes.length);
+ arrayData.putBytes(0, bytes.length, bytes, 0);
+ for (int index = 0; index < size; index++) {
+ toColumn.putArray(index, 0, bytes.length);
+ }
+ }
+ }
+}
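+
A minimal Scala sketch of the documented call order for this reader (`initialize`, then `initBatch`, then the `nextKeyValue`/`getCurrentValue` loop). The helper below is hypothetical and not part of the patch; it assumes the caller already has an `InputSplit`, a `TaskAttemptContext`, the ORC schema, the requested column ids, and the required fields, and it reads a split with no partition columns.

```scala
import org.apache.hadoop.mapreduce.{InputSplit, TaskAttemptContext}
import org.apache.orc.TypeDescription

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.orc.OrcColumnarBatchReader
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch

object OrcReadSketch {
  // Hypothetical driver: counts the rows in one ORC split using the batch reader.
  def countRows(
      split: InputSplit,
      context: TaskAttemptContext,
      orcSchema: TypeDescription,
      requestedColIds: Array[Int],
      requiredFields: Array[StructField]): Long = {
    // useOffHeap = false, copyToSpark = false: wrap the ORC vectors instead of copying them.
    val reader = new OrcColumnarBatchReader(false, false)
    reader.initialize(split, context)
    reader.initBatch(orcSchema, requestedColIds, requiredFields,
      new StructType(), InternalRow.empty)
    var rows = 0L
    try {
      while (reader.nextKeyValue()) {
        val batch: ColumnarBatch = reader.getCurrentValue
        rows += batch.numRows()
      }
    } finally {
      reader.close()
    }
    rows
  }
}
```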
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
index 0f9903360e0d8..e511d3da7e160 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
@@ -22,6 +22,7 @@
import static org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase.createRLEIterator;
import java.io.IOException;
+import java.util.TimeZone;
import org.apache.parquet.bytes.BytesUtils;
import org.apache.parquet.column.ColumnDescriptor;
import org.apache.parquet.column.Dictionary;
@@ -36,7 +37,6 @@
import org.apache.parquet.schema.OriginalType;
import org.apache.parquet.schema.PrimitiveType;
import org.apache.spark.sql.catalyst.util.DateTimeUtils;
-import org.apache.spark.sql.execution.vectorized.ColumnVector;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.DecimalType;
@@ -95,13 +95,18 @@ public class VectorizedColumnReader {
private final PageReader pageReader;
private final ColumnDescriptor descriptor;
private final OriginalType originalType;
+ // The timezone conversion to apply to int96 timestamps. Null if no conversion.
+ private final TimeZone convertTz;
+ private static final TimeZone UTC = DateTimeUtils.TimeZoneUTC();
public VectorizedColumnReader(
ColumnDescriptor descriptor,
OriginalType originalType,
- PageReader pageReader) throws IOException {
+ PageReader pageReader,
+ TimeZone convertTz) throws IOException {
this.descriptor = descriptor;
this.pageReader = pageReader;
+ this.convertTz = convertTz;
this.originalType = originalType;
this.maxDefLevel = descriptor.getMaxDefinitionLevel();
@@ -224,6 +229,10 @@ void readBatch(int total, WritableColumnVector column) throws IOException {
}
}
+ private boolean shouldConvertTimestamps() {
+ return convertTz != null && !convertTz.equals(UTC);
+ }
+
/**
* Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`.
*/
@@ -231,7 +240,7 @@ private void decodeDictionaryIds(
int rowId,
int num,
WritableColumnVector column,
- ColumnVector dictionaryIds) {
+ WritableColumnVector dictionaryIds) {
switch (descriptor.getType()) {
case INT32:
if (column.dataType() == DataTypes.IntegerType ||
@@ -296,11 +305,21 @@ private void decodeDictionaryIds(
break;
case INT96:
if (column.dataType() == DataTypes.TimestampType) {
- for (int i = rowId; i < rowId + num; ++i) {
- // TODO: Convert dictionary of Binaries to dictionary of Longs
- if (!column.isNullAt(i)) {
- Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(i));
- column.putLong(i, ParquetRowConverter.binaryToSQLTimestamp(v));
+ if (!shouldConvertTimestamps()) {
+ for (int i = rowId; i < rowId + num; ++i) {
+ if (!column.isNullAt(i)) {
+ Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(i));
+ column.putLong(i, ParquetRowConverter.binaryToSQLTimestamp(v));
+ }
+ }
+ } else {
+ for (int i = rowId; i < rowId + num; ++i) {
+ if (!column.isNullAt(i)) {
+ Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(i));
+ long rawTime = ParquetRowConverter.binaryToSQLTimestamp(v);
+ long adjTime = DateTimeUtils.convertTz(rawTime, convertTz, UTC);
+ column.putLong(i, adjTime);
+ }
}
}
} else {
@@ -427,16 +446,29 @@ private void readBinaryBatch(int rowId, int num, WritableColumnVector column) {
// This is where we implement support for the valid type conversions.
// TODO: implement remaining type conversions
VectorizedValuesReader data = (VectorizedValuesReader) dataColumn;
- if (column.isArray()) {
+ if (column.dataType() == DataTypes.StringType || column.dataType() == DataTypes.BinaryType) {
defColumn.readBinarys(num, column, rowId, maxDefLevel, data);
} else if (column.dataType() == DataTypes.TimestampType) {
- for (int i = 0; i < num; i++) {
- if (defColumn.readInteger() == maxDefLevel) {
- column.putLong(rowId + i,
- // Read 12 bytes for INT96
- ParquetRowConverter.binaryToSQLTimestamp(data.readBinary(12)));
- } else {
- column.putNull(rowId + i);
+ if (!shouldConvertTimestamps()) {
+ for (int i = 0; i < num; i++) {
+ if (defColumn.readInteger() == maxDefLevel) {
+ // Read 12 bytes for INT96
+ long rawTime = ParquetRowConverter.binaryToSQLTimestamp(data.readBinary(12));
+ column.putLong(rowId + i, rawTime);
+ } else {
+ column.putNull(rowId + i);
+ }
+ }
+ } else {
+ for (int i = 0; i < num; i++) {
+ if (defColumn.readInteger() == maxDefLevel) {
+ // Read 12 bytes for INT96
+ long rawTime = ParquetRowConverter.binaryToSQLTimestamp(data.readBinary(12));
+ long adjTime = DateTimeUtils.convertTz(rawTime, convertTz, UTC);
+ column.putLong(rowId + i, adjTime);
+ } else {
+ column.putNull(rowId + i);
+ }
}
}
} else {
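
The INT96 handling above reduces to one adjustment: when `shouldConvertTimestamps()` is true, the raw microsecond value decoded from the binary is shifted with `DateTimeUtils.convertTz` from the configured timezone to UTC. A short Scala sketch with made-up values:

```scala
import java.util.TimeZone

import org.apache.spark.sql.catalyst.util.DateTimeUtils

// Made-up raw INT96 timestamp in microseconds, as produced by binaryToSQLTimestamp.
val rawMicros = 1425547925359000L
// Hypothetical timezone the writing engine used for its INT96 values.
val writerTz = TimeZone.getTimeZone("America/Los_Angeles")

// Same adjustment the reader applies when shouldConvertTimestamps() is true.
val adjustedMicros = DateTimeUtils.convertTz(rawMicros, writerTz, DateTimeUtils.TimeZoneUTC)
```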
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
index 669d71e60779d..cd745b1f0e4e3 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
@@ -20,6 +20,7 @@
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
+import java.util.TimeZone;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
@@ -30,10 +31,10 @@
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils;
-import org.apache.spark.sql.execution.vectorized.ColumnarBatch;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector;
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
+import org.apache.spark.sql.vectorized.ColumnarBatch;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
@@ -77,6 +78,12 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa
*/
private boolean[] missingColumns;
+ /**
+ * The timezone that timestamp INT96 values should be converted to. Null if no conversion. This is
+ * here to work around incompatibilities between different engines when writing timestamp values.
+ */
+ private TimeZone convertTz = null;
+
/**
* columnBatch object that is used for batch decoding. This is created on first use and triggers
* batched decoding. It is not valid to interleave calls to the batched interface with the row
@@ -105,10 +112,15 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa
*/
private final MemoryMode MEMORY_MODE;
- public VectorizedParquetRecordReader(boolean useOffHeap) {
+ public VectorizedParquetRecordReader(TimeZone convertTz, boolean useOffHeap) {
+ this.convertTz = convertTz;
MEMORY_MODE = useOffHeap ? MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP;
}
+ public VectorizedParquetRecordReader(boolean useOffHeap) {
+ this(null, useOffHeap);
+ }
+
/**
* Implementation of RecordReader API.
*/
@@ -236,7 +248,10 @@ public void enableReturningBatches() {
* Advances to the next batch of rows. Returns false if there are no more.
*/
public boolean nextBatch() throws IOException {
- columnarBatch.reset();
+ for (WritableColumnVector vector : columnVectors) {
+ vector.reset();
+ }
+ columnarBatch.setNumRows(0);
if (rowsReturned >= totalRowCount) return false;
checkEndOfRowGroup();
@@ -291,8 +306,8 @@ private void checkEndOfRowGroup() throws IOException {
columnReaders = new VectorizedColumnReader[columns.size()];
for (int i = 0; i < columns.size(); ++i) {
if (missingColumns[i]) continue;
- columnReaders[i] = new VectorizedColumnReader(
- columns.get(i), types.get(i).getOriginalType(), pages.getPageReader(columns.get(i)));
+ columnReaders[i] = new VectorizedColumnReader(columns.get(i), types.get(i).getOriginalType(),
+ pages.getPageReader(columns.get(i)), convertTz);
}
totalCountLoadedSoFar += pages.getRowCount();
}
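As a rough usage sketch of the new constructor (the timezone, split and context variables below are illustrative assumptions, not part of this patch), a caller that wants INT96 timestamps rebased from the writer's timezone could do something like:

    // Pass a timezone to have INT96 timestamps converted while decoding; pass null
    // (or use the old single-argument constructor) to skip conversion entirely.
    TimeZone convertTz = TimeZone.getTimeZone("America/Los_Angeles");
    VectorizedParquetRecordReader reader =
        new VectorizedParquetRecordReader(convertTz, /* useOffHeap = */ false);
    // reader.initialize(split, taskAttemptContext);
    // reader.enableReturningBatches();
    // while (reader.nextBatch()) { ... }   // batches now carry adjusted timestamps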
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java
index 9467435435d1f..24260b05194a7 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java
@@ -41,7 +41,7 @@
public class AggregateHashMap {
private OnHeapColumnVector[] columnVectors;
- private ColumnarBatch batch;
+ private MutableColumnarRow aggBufferRow;
private int[] buckets;
private int numBuckets;
private int numRows = 0;
@@ -63,7 +63,7 @@ public AggregateHashMap(StructType schema, int capacity, double loadFactor, int
this.maxSteps = maxSteps;
numBuckets = (int) (capacity / loadFactor);
columnVectors = OnHeapColumnVector.allocateColumns(capacity, schema);
- batch = new ColumnarBatch(schema, columnVectors, capacity);
+ aggBufferRow = new MutableColumnarRow(columnVectors);
buckets = new int[numBuckets];
Arrays.fill(buckets, -1);
}
@@ -72,14 +72,15 @@ public AggregateHashMap(StructType schema) {
this(schema, DEFAULT_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_MAX_STEPS);
}
- public ColumnarRow findOrInsert(long key) {
+ public MutableColumnarRow findOrInsert(long key) {
int idx = find(key);
if (idx != -1 && buckets[idx] == -1) {
columnVectors[0].putLong(numRows, key);
columnVectors[1].putLong(numRows, 0);
buckets[idx] = numRows++;
}
- return batch.getRow(buckets[idx]);
+ aggBufferRow.rowId = buckets[idx];
+ return aggBufferRow;
}
@VisibleForTesting
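A minimal sketch of how the reused row returned by findOrInsert might be consumed (the schema and key are assumptions for illustration):

    // Hypothetical aggregation buffer schema: (key: long, count: long).
    StructType schema = new StructType()
        .add("key", DataTypes.LongType)
        .add("count", DataTypes.LongType);
    AggregateHashMap map = new AggregateHashMap(schema);

    // findOrInsert now returns one reused MutableColumnarRow whose rowId points at the
    // bucket; callers must read or update it immediately rather than holding on to it.
    MutableColumnarRow row = map.findOrInsert(42L);
    row.setLong(1, row.getLong(1) + 1);   // bump the count column in place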
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
index b4b5f0a265934..b5cbe8e2839ba 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
@@ -28,6 +28,8 @@
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.util.DateTimeUtils;
import org.apache.spark.sql.types.*;
+import org.apache.spark.sql.vectorized.ColumnarArray;
+import org.apache.spark.sql.vectorized.ColumnarBatch;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
@@ -98,21 +100,13 @@ public static void populate(WritableColumnVector col, InternalRow row, int field
* For example, an array of IntegerType will return an int[].
* Throws exceptions for unhandled schemas.
*/
- public static Object toPrimitiveJavaArray(ColumnarArray array) {
- DataType dt = array.data.dataType();
- if (dt instanceof IntegerType) {
- int[] result = new int[array.length];
- ColumnVector data = array.data;
- for (int i = 0; i < result.length; i++) {
- if (data.isNullAt(array.offset + i)) {
- throw new RuntimeException("Cannot handle NULL values.");
- }
- result[i] = data.getInt(array.offset + i);
+ public static int[] toJavaIntArray(ColumnarArray array) {
+ for (int i = 0; i < array.numElements(); i++) {
+ if (array.isNullAt(i)) {
+ throw new RuntimeException("Cannot handle NULL values.");
}
- return result;
- } else {
- throw new UnsupportedOperationException();
}
+ return array.toIntArray();
}
private static void appendValue(WritableColumnVector dst, DataType t, Object o) {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
similarity index 79%
rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java
rename to sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
index 98a907322713b..70057a9def6c0 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
@@ -14,6 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package org.apache.spark.sql.execution.vectorized;
import java.math.BigDecimal;
@@ -22,37 +23,38 @@
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.catalyst.util.MapData;
import org.apache.spark.sql.types.*;
+import org.apache.spark.sql.vectorized.ColumnarArray;
+import org.apache.spark.sql.vectorized.ColumnarBatch;
+import org.apache.spark.sql.vectorized.ColumnarRow;
+import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
/**
- * Row abstraction in {@link ColumnVector}. The instance of this class is intended
- * to be reused, callers should copy the data out if it needs to be stored.
+ * A mutable version of {@link ColumnarRow}, used in the vectorized hash map for hash aggregate
+ * and in {@link ColumnarBatch} to avoid per-row object creation.
+ *
+ * Note that this class intentionally duplicates a lot of code from {@link ColumnarRow}, so that
+ * both {@link ColumnarRow} and this class can stay final and avoid Java polymorphism overhead.
*/
-public final class ColumnarRow extends InternalRow {
- protected int rowId;
+public final class MutableColumnarRow extends InternalRow {
+ public int rowId;
private final ColumnVector[] columns;
private final WritableColumnVector[] writableColumns;
- // Ctor used if this is a struct.
- ColumnarRow(ColumnVector[] columns) {
+ public MutableColumnarRow(ColumnVector[] columns) {
this.columns = columns;
- this.writableColumns = new WritableColumnVector[this.columns.length];
- for (int i = 0; i < this.columns.length; i++) {
- if (this.columns[i] instanceof WritableColumnVector) {
- this.writableColumns[i] = (WritableColumnVector) this.columns[i];
- }
- }
+ this.writableColumns = null;
}
- public ColumnVector[] columns() { return columns; }
+ public MutableColumnarRow(WritableColumnVector[] writableColumns) {
+ this.columns = writableColumns;
+ this.writableColumns = writableColumns;
+ }
@Override
public int numFields() { return columns.length; }
- /**
- * Revisit this. This is expensive. This is currently only used in test paths.
- */
@Override
public InternalRow copy() {
GenericInternalRow row = new GenericInternalRow(columns.length);
@@ -224,8 +226,8 @@ public void update(int ordinal, Object value) {
setDouble(ordinal, (double) value);
} else if (dt instanceof DecimalType) {
DecimalType t = (DecimalType) dt;
- setDecimal(ordinal, Decimal.apply((BigDecimal) value, t.precision(), t.scale()),
- t.precision());
+ Decimal d = Decimal.apply((BigDecimal) value, t.precision(), t.scale());
+ setDecimal(ordinal, d, t.precision());
} else {
throw new UnsupportedOperationException("Datatype not supported " + dt);
}
@@ -234,68 +236,54 @@ public void update(int ordinal, Object value) {
@Override
public void setNullAt(int ordinal) {
- getWritableColumn(ordinal).putNull(rowId);
+ writableColumns[ordinal].putNull(rowId);
}
@Override
public void setBoolean(int ordinal, boolean value) {
- WritableColumnVector column = getWritableColumn(ordinal);
- column.putNotNull(rowId);
- column.putBoolean(rowId, value);
+ writableColumns[ordinal].putNotNull(rowId);
+ writableColumns[ordinal].putBoolean(rowId, value);
}
@Override
public void setByte(int ordinal, byte value) {
- WritableColumnVector column = getWritableColumn(ordinal);
- column.putNotNull(rowId);
- column.putByte(rowId, value);
+ writableColumns[ordinal].putNotNull(rowId);
+ writableColumns[ordinal].putByte(rowId, value);
}
@Override
public void setShort(int ordinal, short value) {
- WritableColumnVector column = getWritableColumn(ordinal);
- column.putNotNull(rowId);
- column.putShort(rowId, value);
+ writableColumns[ordinal].putNotNull(rowId);
+ writableColumns[ordinal].putShort(rowId, value);
}
@Override
public void setInt(int ordinal, int value) {
- WritableColumnVector column = getWritableColumn(ordinal);
- column.putNotNull(rowId);
- column.putInt(rowId, value);
+ writableColumns[ordinal].putNotNull(rowId);
+ writableColumns[ordinal].putInt(rowId, value);
}
@Override
public void setLong(int ordinal, long value) {
- WritableColumnVector column = getWritableColumn(ordinal);
- column.putNotNull(rowId);
- column.putLong(rowId, value);
+ writableColumns[ordinal].putNotNull(rowId);
+ writableColumns[ordinal].putLong(rowId, value);
}
@Override
public void setFloat(int ordinal, float value) {
- WritableColumnVector column = getWritableColumn(ordinal);
- column.putNotNull(rowId);
- column.putFloat(rowId, value);
+ writableColumns[ordinal].putNotNull(rowId);
+ writableColumns[ordinal].putFloat(rowId, value);
}
@Override
public void setDouble(int ordinal, double value) {
- WritableColumnVector column = getWritableColumn(ordinal);
- column.putNotNull(rowId);
- column.putDouble(rowId, value);
+ writableColumns[ordinal].putNotNull(rowId);
+ writableColumns[ordinal].putDouble(rowId, value);
}
@Override
public void setDecimal(int ordinal, Decimal value, int precision) {
- WritableColumnVector column = getWritableColumn(ordinal);
- column.putNotNull(rowId);
- column.putDecimal(rowId, value, precision);
- }
-
- private WritableColumnVector getWritableColumn(int ordinal) {
- WritableColumnVector column = writableColumns[ordinal];
- assert (!column.isConstant);
- return column;
+ writableColumns[ordinal].putNotNull(rowId);
+ writableColumns[ordinal].putDecimal(rowId, value, precision);
}
}
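As a hedged illustration of the new writable-columns constructor (the capacity and types below are assumptions):

    // Two on-heap vectors backing an (int, double) row; when constructed from
    // WritableColumnVector[], MutableColumnarRow writes straight through to them.
    WritableColumnVector[] cols = new WritableColumnVector[] {
        new OnHeapColumnVector(16, DataTypes.IntegerType),
        new OnHeapColumnVector(16, DataTypes.DoubleType)
    };
    MutableColumnarRow row = new MutableColumnarRow(cols);
    row.rowId = 0;             // the row is just a movable cursor over the vectors
    row.setInt(0, 7);
    row.setDouble(1, 3.14);
    row.rowId = 1;             // reuse the same object for the next row
    row.setInt(0, 8);
    row.setDouble(1, 2.71);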
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
index 1cbaf08569334..1c45b846790b6 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
@@ -110,7 +110,6 @@ public void putNotNull(int rowId) {
public void putNull(int rowId) {
Platform.putByte(null, nulls + rowId, (byte) 1);
++numNulls;
- anyNullsSet = true;
}
@Override
@@ -119,13 +118,12 @@ public void putNulls(int rowId, int count) {
for (int i = 0; i < count; ++i, ++offset) {
Platform.putByte(null, offset, (byte) 1);
}
- anyNullsSet = true;
numNulls += count;
}
@Override
public void putNotNulls(int rowId, int count) {
- if (!anyNullsSet) return;
+ if (numNulls == 0) return;
long offset = nulls + rowId;
for (int i = 0; i < count; ++i, ++offset) {
Platform.putByte(null, offset, (byte) 0);
@@ -532,7 +530,7 @@ public int putByteArray(int rowId, byte[] value, int offset, int length) {
@Override
protected void reserveInternal(int newCapacity) {
int oldCapacity = (nulls == 0L) ? 0 : capacity;
- if (this.resultArray != null) {
+ if (isArray()) {
this.lengthData =
Platform.reallocateMemory(lengthData, oldCapacity * 4, newCapacity * 4);
this.offsetData =
@@ -547,7 +545,7 @@ protected void reserveInternal(int newCapacity) {
} else if (type instanceof LongType || type instanceof DoubleType ||
DecimalType.is64BitDecimalType(type) || type instanceof TimestampType) {
this.data = Platform.reallocateMemory(data, oldCapacity * 8, newCapacity * 8);
- } else if (resultStruct != null) {
+ } else if (childColumns != null) {
// Nothing to store.
} else {
throw new RuntimeException("Unhandled " + type);
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
index 85d72295ab9b8..1d538fe4181b7 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
@@ -107,7 +107,6 @@ public void putNotNull(int rowId) {
public void putNull(int rowId) {
nulls[rowId] = (byte)1;
++numNulls;
- anyNullsSet = true;
}
@Override
@@ -115,13 +114,12 @@ public void putNulls(int rowId, int count) {
for (int i = 0; i < count; ++i) {
nulls[rowId + i] = (byte)1;
}
- anyNullsSet = true;
numNulls += count;
}
@Override
public void putNotNulls(int rowId, int count) {
- if (!anyNullsSet) return;
+ if (numNulls == 0) return;
for (int i = 0; i < count; ++i) {
nulls[rowId + i] = (byte)0;
}
@@ -505,7 +503,7 @@ public int putByteArray(int rowId, byte[] value, int offset, int length) {
// Split this function out since it is the slow path.
@Override
protected void reserveInternal(int newCapacity) {
- if (this.resultArray != null || DecimalType.isByteArrayDecimalType(type)) {
+ if (isArray()) {
int[] newLengths = new int[newCapacity];
int[] newOffsets = new int[newCapacity];
if (this.arrayLengths != null) {
@@ -558,7 +556,7 @@ protected void reserveInternal(int newCapacity) {
if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, capacity);
doubleData = newData;
}
- } else if (resultStruct != null) {
+ } else if (childColumns != null) {
// Nothing to store.
} else {
throw new RuntimeException("Unhandled " + type);
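A small sketch of the invariant this change relies on, assuming an on-heap vector of illustrative size: with anyNullsSet gone, numNulls alone decides whether putNotNulls and reset() need to clear null bytes.

    OnHeapColumnVector col = new OnHeapColumnVector(8, DataTypes.IntegerType);
    col.putNull(0);            // numNulls == 1
    col.putNulls(1, 2);        // numNulls == 3
    assert col.numNulls() == 3;
    col.putNotNulls(0, 3);     // executes because numNulls > 0
    col.reset();               // clears null bytes only when numNulls > 0, then zeroes the counter
    assert col.numNulls() == 0;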
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
index e7653f0c00b9a..d2ae32b06f83b 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
@@ -23,6 +23,7 @@
import org.apache.spark.sql.internal.SQLConf;
import org.apache.spark.sql.types.*;
+import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.types.UTF8String;
@@ -36,8 +37,10 @@
* elements. This means that the put() APIs do not check as in common cases (i.e. flat schemas),
* the lengths are known up front.
*
- * A ColumnVector should be considered immutable once originally created. In other words, it is not
- * valid to call put APIs after reads until reset() is called.
+ * A WritableColumnVector should be considered immutable once originally created. In other words,
+ * it is not valid to call put APIs after reads until reset() is called.
+ *
+ * WritableColumnVectors are intended to be reused.
*/
public abstract class WritableColumnVector extends ColumnVector {
@@ -52,12 +55,11 @@ public void reset() {
((WritableColumnVector) c).reset();
}
}
- numNulls = 0;
elementsAppended = 0;
- if (anyNullsSet) {
+ if (numNulls > 0) {
putNotNulls(0, capacity);
- anyNullsSet = false;
}
+ numNulls = 0;
}
@Override
@@ -74,8 +76,6 @@ public void close() {
dictionaryIds = null;
}
dictionary = null;
- resultStruct = null;
- resultArray = null;
}
public void reserve(int requiredCapacity) {
@@ -104,8 +104,57 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) {
@Override
public int numNulls() { return numNulls; }
- @Override
- public boolean anyNullsSet() { return anyNullsSet; }
+ /**
+ * Returns the dictionary Id for rowId.
+ *
+ * This should only be called when this `WritableColumnVector` represents dictionaryIds.
+ * We have this separate method for dictionaryIds as per SPARK-16928.
+ */
+ public abstract int getDictId(int rowId);
+
+ /**
+ * The Dictionary for this column.
+ *
+ * If it's not null, will be used to decode the value in getXXX().
+ */
+ protected Dictionary dictionary;
+
+ /**
+ * Reusable column for dictionary ids.
+ */
+ protected WritableColumnVector dictionaryIds;
+
+ /**
+ * Returns true if this column has a dictionary.
+ */
+ public boolean hasDictionary() { return this.dictionary != null; }
+
+ /**
+ * Returns the underlying integer column for ids of dictionary.
+ */
+ public WritableColumnVector getDictionaryIds() {
+ return dictionaryIds;
+ }
+
+ /**
+ * Update the dictionary.
+ */
+ public void setDictionary(Dictionary dictionary) {
+ this.dictionary = dictionary;
+ }
+
+ /**
+ * Reserve an integer column for dictionary ids.
+ */
+ public WritableColumnVector reserveDictionaryIds(int capacity) {
+ if (dictionaryIds == null) {
+ dictionaryIds = reserveNewColumn(capacity, DataTypes.IntegerType);
+ } else {
+ dictionaryIds.reset();
+ dictionaryIds.reserve(capacity);
+ }
+ return dictionaryIds;
+ }
/**
* Ensures that there is enough storage to store capacity elements. That is, the put() APIs
@@ -537,11 +586,11 @@ public final int appendArray(int length) {
public final int appendStruct(boolean isNull) {
if (isNull) {
appendNull();
- for (ColumnVector c: childColumns) {
+ for (WritableColumnVector c: childColumns) {
if (c.type instanceof StructType) {
- ((WritableColumnVector) c).appendStruct(true);
+ c.appendStruct(true);
} else {
- ((WritableColumnVector) c).appendNull();
+ c.appendNull();
}
}
} else {
@@ -588,12 +637,6 @@ public final int appendStruct(boolean isNull) {
*/
protected int numNulls;
- /**
- * True if there is at least one NULL byte set. This is an optimization for the writer, to skip
- * having to clear NULL bits.
- */
- protected boolean anyNullsSet;
-
/**
* True if this column's values are fixed. This means the column values never change, even
* across resets.
@@ -615,41 +658,16 @@ public final int appendStruct(boolean isNull) {
*/
protected WritableColumnVector[] childColumns;
- /**
- * Update the dictionary.
- */
- public void setDictionary(Dictionary dictionary) {
- this.dictionary = dictionary;
- }
-
- /**
- * Reserve a integer column for ids of dictionary.
- */
- public WritableColumnVector reserveDictionaryIds(int capacity) {
- WritableColumnVector dictionaryIds = (WritableColumnVector) this.dictionaryIds;
- if (dictionaryIds == null) {
- dictionaryIds = reserveNewColumn(capacity, DataTypes.IntegerType);
- this.dictionaryIds = dictionaryIds;
- } else {
- dictionaryIds.reset();
- dictionaryIds.reserve(capacity);
- }
- return dictionaryIds;
- }
-
- /**
- * Returns the underlying integer column for ids of dictionary.
- */
- @Override
- public WritableColumnVector getDictionaryIds() {
- return (WritableColumnVector) dictionaryIds;
- }
-
/**
* Reserve a new column.
*/
protected abstract WritableColumnVector reserveNewColumn(int capacity, DataType type);
+ protected boolean isArray() {
+ return type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType ||
+ DecimalType.isByteArrayDecimalType(type);
+ }
+
/**
* Sets up the common state and also handles creating the child columns if this is a nested
* type.
@@ -658,8 +676,7 @@ protected WritableColumnVector(int capacity, DataType type) {
super(type);
this.capacity = capacity;
- if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType
- || DecimalType.isByteArrayDecimalType(type)) {
+ if (isArray()) {
DataType childType;
int childCapacity = capacity;
if (type instanceof ArrayType) {
@@ -670,27 +687,19 @@ protected WritableColumnVector(int capacity, DataType type) {
}
this.childColumns = new WritableColumnVector[1];
this.childColumns[0] = reserveNewColumn(childCapacity, childType);
- this.resultArray = new ColumnarArray(this.childColumns[0]);
- this.resultStruct = null;
} else if (type instanceof StructType) {
StructType st = (StructType)type;
this.childColumns = new WritableColumnVector[st.fields().length];
for (int i = 0; i < childColumns.length; ++i) {
this.childColumns[i] = reserveNewColumn(capacity, st.fields()[i].dataType());
}
- this.resultArray = null;
- this.resultStruct = new ColumnarRow(this.childColumns);
} else if (type instanceof CalendarIntervalType) {
// Two columns. Months as int. Microseconds as Long.
this.childColumns = new WritableColumnVector[2];
this.childColumns[0] = reserveNewColumn(capacity, DataTypes.IntegerType);
this.childColumns[1] = reserveNewColumn(capacity, DataTypes.LongType);
- this.resultArray = null;
- this.resultStruct = new ColumnarRow(this.childColumns);
} else {
this.childColumns = null;
- this.resultArray = null;
- this.resultStruct = null;
}
}
}
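A hedged sketch of the dictionary-id plumbing that now lives on WritableColumnVector (the sizes are assumptions; constructing a real Parquet Dictionary is out of scope here):

    OnHeapColumnVector col = new OnHeapColumnVector(16, DataTypes.StringType);
    // The reader reserves a reusable integer vector holding one dictionary id per row.
    WritableColumnVector ids = col.reserveDictionaryIds(16);
    ids.putInt(0, 3);                        // row 0 points at dictionary entry 3
    int dictId = ids.getDictId(0);           // dedicated accessor, see SPARK-16928
    boolean decoded = !col.hasDictionary();  // true until setDictionary(...) is called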
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java
index 9a89c8193dd6e..ddc2acca693ac 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java
@@ -36,6 +36,10 @@ private String toLowerCase(String key) {
return key.toLowerCase(Locale.ROOT);
}
+ public static DataSourceV2Options empty() {
+ return new DataSourceV2Options(new HashMap<>());
+ }
+
public DataSourceV2Options(Map<String, String> originalMap) {
keyLowerCasedMap = new HashMap<>(originalMap.size());
for (Map.Entry<String, String> entry : originalMap.entrySet()) {
@@ -43,10 +47,54 @@ public DataSourceV2Options(Map originalMap) {
}
}
+ public Map<String, String> asMap() {
+ return new HashMap<>(keyLowerCasedMap);
+ }
+
/**
* Returns the option value to which the specified key is mapped, case-insensitively.
*/
public Optional<String> get(String key) {
return Optional.ofNullable(keyLowerCasedMap.get(toLowerCase(key)));
}
+
+ /**
+ * Returns the boolean value to which the specified key is mapped,
+ * or defaultValue if there is no mapping for the key. The key match is case-insensitive.
+ */
+ public boolean getBoolean(String key, boolean defaultValue) {
+ String lcaseKey = toLowerCase(key);
+ return keyLowerCasedMap.containsKey(lcaseKey) ?
+ Boolean.parseBoolean(keyLowerCasedMap.get(lcaseKey)) : defaultValue;
+ }
+
+ /**
+ * Returns the integer value to which the specified key is mapped,
+ * or defaultValue if there is no mapping for the key. The key match is case-insensitive.
+ */
+ public int getInt(String key, int defaultValue) {
+ String lcaseKey = toLowerCase(key);
+ return keyLowerCasedMap.containsKey(lcaseKey) ?
+ Integer.parseInt(keyLowerCasedMap.get(lcaseKey)) : defaultValue;
+ }
+
+ /**
+ * Returns the long value to which the specified key is mapped,
+ * or defaultValue if there is no mapping for the key. The key match is case-insensitive.
+ */
+ public long getLong(String key, long defaultValue) {
+ String lcaseKey = toLowerCase(key);
+ return keyLowerCasedMap.containsKey(lcaseKey) ?
+ Long.parseLong(keyLowerCasedMap.get(lcaseKey)) : defaultValue;
+ }
+
+ /**
+ * Returns the double value to which the specified key is mapped,
+ * or defaultValue if there is no mapping for the key. The key match is case-insensitive.
+ */
+ public double getDouble(String key, double defaultValue) {
+ String lcaseKey = toLowerCase(key);
+ return keyLowerCasedMap.containsKey(lcaseKey) ?
+ Double.parseDouble(keyLowerCasedMap.get(lcaseKey)) : defaultValue;
+ }
}
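A quick usage sketch of the new typed getters (the keys and defaults are made up for illustration):

    Map<String, String> raw = new HashMap<>();
    raw.put("Path", "/tmp/data");           // keys are matched case-insensitively
    raw.put("useCache", "true");
    raw.put("batchSize", "4096");
    DataSourceV2Options options = new DataSourceV2Options(raw);

    String path = options.get("path").orElse("/tmp/default");   // "/tmp/data"
    boolean useCache = options.getBoolean("usecache", false);   // true
    int batchSize = options.getInt("BATCHSIZE", 1024);          // 4096
    long limit = options.getLong("limit", -1L);                 // -1, no mapping
    DataSourceV2Options none = DataSourceV2Options.empty();     // no options at all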
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java
new file mode 100644
index 0000000000000..3cb020d2e0836
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2;
+
+import org.apache.spark.annotation.InterfaceStability;
+
+/**
+ * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to
+ * propagate session configs with the specified key-prefix to all data source operations in this
+ * session.
+ */
+@InterfaceStability.Evolving
+public interface SessionConfigSupport {
+
+ /**
+ * Key prefix of the session configs to propagate. Spark will extract all session configs that
+ * start with `spark.datasource.$keyPrefix`, turn `spark.datasource.$keyPrefix.xxx -> yyy`
+ * into `xxx -> yyy`, and propagate them to all data source operations in this session.
+ */
+ String keyPrefix();
+}
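A minimal sketch of a source opting into session-config propagation (the class name and prefix are hypothetical):

    // With keyPrefix() == "myds", a session config such as
    // spark.datasource.myds.endpoint=host:9000 reaches the source as option "endpoint".
    public class MyDataSource implements DataSourceV2, SessionConfigSupport {
      @Override
      public String keyPrefix() {
        return "myds";
      }
    }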
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java
new file mode 100644
index 0000000000000..3136cee1f655f
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.streaming;
+
+import java.util.Optional;
+
+import org.apache.spark.sql.sources.v2.DataSourceV2;
+import org.apache.spark.sql.sources.v2.DataSourceV2Options;
+import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader;
+import org.apache.spark.sql.types.StructType;
+
+/**
+ * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to
+ * provide data reading ability for continuous stream processing.
+ */
+public interface ContinuousReadSupport extends DataSourceV2 {
+ /**
+ * Creates a {@link ContinuousReader} to scan the data from this data source.
+ *
+ * @param schema the user provided schema, or empty() if none was provided
+ * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure
+ * recovery. Readers for the same logical source in the same query
+ * will be given the same checkpointLocation.
+ * @param options the options for the returned data source reader, which is an immutable
+ * case-insensitive string-to-string map.
+ */
+ ContinuousReader createContinuousReader(
+ Optional<StructType> schema,
+ String checkpointLocation,
+ DataSourceV2Options options);
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java
new file mode 100644
index 0000000000000..dee493cadb71e
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.streaming;
+
+import java.util.Optional;
+
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.execution.streaming.BaseStreamingSink;
+import org.apache.spark.sql.sources.v2.DataSourceV2;
+import org.apache.spark.sql.sources.v2.DataSourceV2Options;
+import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter;
+import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer;
+import org.apache.spark.sql.streaming.OutputMode;
+import org.apache.spark.sql.types.StructType;
+
+/**
+ * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to
+ * provide data writing ability for continuous stream processing.
+ */
+@InterfaceStability.Evolving
+public interface ContinuousWriteSupport extends BaseStreamingSink {
+
+ /**
+ * Creates an optional {@link ContinuousWriter} to save the data to this data source. Data
+ * sources can return None if no writing needs to be done.
+ *
+ * @param queryId A unique string for the writing query. It's possible that there are many
+ * writing queries running at the same time, and the returned
+ * {@link DataSourceV2Writer} can use this id to distinguish itself from others.
+ * @param schema the schema of the data to be written.
+ * @param mode the output mode which determines what successive epoch output means to this
+ * sink, please refer to {@link OutputMode} for more details.
+ * @param options the options for the returned data source writer, which is an immutable
+ * case-insensitive string-to-string map.
+ */
+ Optional<ContinuousWriter> createContinuousWriter(
+ String queryId,
+ StructType schema,
+ OutputMode mode,
+ DataSourceV2Options options);
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java
new file mode 100644
index 0000000000000..3c87a3db68243
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.streaming;
+
+import java.util.Optional;
+
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.sources.v2.DataSourceV2;
+import org.apache.spark.sql.sources.v2.DataSourceV2Options;
+import org.apache.spark.sql.sources.v2.streaming.reader.MicroBatchReader;
+import org.apache.spark.sql.types.StructType;
+
+/**
+ * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to
+ * provide streaming micro-batch data reading ability.
+ */
+@InterfaceStability.Evolving
+public interface MicroBatchReadSupport extends DataSourceV2 {
+ /**
+ * Creates a {@link MicroBatchReader} to read batches of data from this data source in a
+ * streaming query.
+ *
+ * The execution engine will create a micro-batch reader at the start of a streaming query,
+ * alternate calls to setOffsetRange and createReadTasks for each batch to process, and then
+ * call stop() when the execution is complete. Note that a single query may have multiple
+ * executions due to restart or failure recovery.
+ *
+ * @param schema the user provided schema, or empty() if none was provided
+ * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure
+ * recovery. Readers for the same logical source in the same query
+ * will be given the same checkpointLocation.
+ * @param options the options for the returned data source reader, which is an immutable
+ * case-insensitive string-to-string map.
+ */
+ MicroBatchReader createMicroBatchReader(
+ Optional<StructType> schema,
+ String checkpointLocation,
+ DataSourceV2Options options);
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java
new file mode 100644
index 0000000000000..53ffa95ae0f4c
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.streaming;
+
+import java.util.Optional;
+
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.execution.streaming.BaseStreamingSink;
+import org.apache.spark.sql.sources.v2.DataSourceV2;
+import org.apache.spark.sql.sources.v2.DataSourceV2Options;
+import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer;
+import org.apache.spark.sql.streaming.OutputMode;
+import org.apache.spark.sql.types.StructType;
+
+/**
+ * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to
+ * provide data writing ability and save the data from a microbatch to the data source.
+ */
+@InterfaceStability.Evolving
+public interface MicroBatchWriteSupport extends BaseStreamingSink {
+
+ /**
+ * Creates an optional {@link DataSourceV2Writer} to save the data to this data source. Data
+ * sources can return None if no writing needs to be done.
+ *
+ * @param queryId A unique string for the writing query. It's possible that there are many writing
+ * queries running at the same time, and the returned {@link DataSourceV2Writer}
+ * can use this id to distinguish itself from others.
+ * @param epochId The unique numeric ID of the batch within this writing query. This is an
+ * incrementing counter representing a consistent set of data; the same batch may
+ * be started multiple times in failure recovery scenarios, but it will always
+ * contain the same records.
+ * @param schema the schema of the data to be written.
+ * @param mode the output mode which determines what successive batch output means to this
+ * sink, please refer to {@link OutputMode} for more details.
+ * @param options the options for the returned data source writer, which is an immutable
+ * case-insensitive string-to-string map.
+ */
+ Optional<DataSourceV2Writer> createMicroBatchWriter(
+ String queryId,
+ long epochId,
+ StructType schema,
+ OutputMode mode,
+ DataSourceV2Options options);
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java
new file mode 100644
index 0000000000000..ca9a290e97a02
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.streaming.reader;
+
+import org.apache.spark.sql.sources.v2.reader.DataReader;
+
+/**
+ * A variation on {@link DataReader} for use with streaming in continuous processing mode.
+ */
+public interface ContinuousDataReader<T> extends DataReader<T> {
+ /**
+ * Get the offset of the current record, or the start offset if no records have been read.
+ *
+ * The execution engine will call this method along with get() to keep track of the current
+ * offset. When an epoch ends, the offset of the previous record in each partition will be saved
+ * as a restart checkpoint.
+ */
+ PartitionOffset getOffset();
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java
new file mode 100644
index 0000000000000..f0b205869ed6c
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.streaming.reader;
+
+import org.apache.spark.sql.execution.streaming.BaseStreamingSource;
+import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader;
+
+import java.util.Optional;
+
+/**
+ * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this
+ * interface to allow reading in a continuous processing mode stream.
+ *
+ * Implementations must ensure each read task output is a {@link ContinuousDataReader}.
+ */
+public interface ContinuousReader extends BaseStreamingSource, DataSourceV2Reader {
+ /**
+ * Merge offsets coming from {@link ContinuousDataReader} instances in each partition to
+ * a single global offset.
+ */
+ Offset mergeOffsets(PartitionOffset[] offsets);
+
+ /**
+ * Deserialize a JSON string into an Offset of the implementation-defined offset type.
+ * @throws IllegalArgumentException if the JSON does not encode a valid offset for this reader
+ */
+ Offset deserializeOffset(String json);
+
+ /**
+ * Set the desired start offset for read tasks created from this reader. The scan will start
+ * from the first record after the provided offset, or from an implementation-defined inferred
+ * starting point if no offset is provided.
+ */
+ void setOffset(Optional<Offset> start);
+
+ /**
+ * Return the specified or inferred start offset for this reader.
+ *
+ * @throws IllegalStateException if setOffset has not been called
+ */
+ Offset getStartOffset();
+
+ /**
+ * The execution engine will call this method in every epoch to determine if new read tasks need
+ * to be generated, which may be required if for example the underlying source system has had
+ * partitions added or removed.
+ *
+ * If true, the query will be shut down and restarted with a new reader.
+ */
+ default boolean needsReconfiguration() {
+ return false;
+ }
+
+ /**
+ * Informs the source that Spark has completed processing all data for offsets less than or
+ * equal to `end` and will only request offsets greater than `end` in the future.
+ */
+ void commit(Offset end);
+}
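As a hedged sketch of how a source might satisfy mergeOffsets (the partition/position fields and the MyGlobalOffset class are assumptions, not part of this API):

    // Hypothetical per-partition offset: which partition, and the position within it.
    class MyPartitionOffset implements PartitionOffset {
      final int partition;
      final long position;
      MyPartitionOffset(int partition, long position) {
        this.partition = partition;
        this.position = position;
      }
    }

    // Inside a hypothetical ContinuousReader implementation:
    @Override
    public Offset mergeOffsets(PartitionOffset[] offsets) {
      java.util.Map<Integer, Long> positions = new java.util.HashMap<>();
      for (PartitionOffset po : offsets) {
        MyPartitionOffset mine = (MyPartitionOffset) po;
        positions.put(mine.partition, mine.position);
      }
      // Hypothetical Offset subclass whose json() serializes the partition -> position map.
      return new MyGlobalOffset(positions);
    }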
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java
new file mode 100644
index 0000000000000..70ff756806032
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.streaming.reader;
+
+import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader;
+import org.apache.spark.sql.execution.streaming.BaseStreamingSource;
+
+import java.util.Optional;
+
+/**
+ * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this
+ * interface to indicate they allow micro-batch streaming reads.
+ */
+public interface MicroBatchReader extends DataSourceV2Reader, BaseStreamingSource {
+ /**
+ * Set the desired offset range for read tasks created from this reader. Read tasks will
+ * generate only data within (`start`, `end`]; that is, from the first record after `start` to
+ * the record with offset `end`.
+ *
+ * @param start The initial offset to scan from. If not specified, scan from an
+ * implementation-specified start point, such as the earliest available record.
+ * @param end The last offset to include in the scan. If not specified, scan up to an
+ * implementation-defined endpoint, such as the last available offset
+ * or the start offset plus a target batch size.
+ */
+ void setOffsetRange(Optional<Offset> start, Optional<Offset> end);
+
+ /**
+ * Returns the specified (if explicitly set through setOffsetRange) or inferred start offset
+ * for this reader.
+ *
+ * @throws IllegalStateException if setOffsetRange has not been called
+ */
+ Offset getStartOffset();
+
+ /**
+ * Return the specified (if explicitly set through setOffsetRange) or inferred end offset
+ * for this reader.
+ *
+ * @throws IllegalStateException if setOffsetRange has not been called
+ */
+ Offset getEndOffset();
+
+ /**
+ * Deserialize a JSON string into an Offset of the implementation-defined offset type.
+ * @throws IllegalArgumentException if the JSON does not encode a valid offset for this reader
+ */
+ Offset deserializeOffset(String json);
+
+ /**
+ * Informs the source that Spark has completed processing all data for offsets less than or
+ * equal to `end` and will only request offsets greater than `end` in the future.
+ */
+ void commit(Offset end);
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java
new file mode 100644
index 0000000000000..60b87f2ac0756
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.streaming.reader;
+
+/**
+ * An abstract representation of progress through a {@link MicroBatchReader} or {@link ContinuousReader}.
+ * During execution, Offsets provided by the data source implementation will be logged and used as
+ * restart checkpoints. Sources should provide an Offset implementation which they can use to
+ * reconstruct the stream position where the offset was taken.
+ */
+public abstract class Offset extends org.apache.spark.sql.execution.streaming.Offset {
+ /**
+ * A JSON-serialized representation of an Offset that is
+ * used for saving offsets to the offset log.
+ * Note: We assume that equivalent/equal offsets serialize to
+ * identical JSON strings.
+ *
+ * @return JSON string encoding
+ */
+ public abstract String json();
+
+ /**
+ * Equality based on JSON string representation. We leverage the
+ * JSON representation for normalization between an Offset's
+ * in-memory and on-disk representations.
+ */
+ @Override
+ public boolean equals(Object obj) {
+ if (obj instanceof org.apache.spark.sql.execution.streaming.Offset) {
+ return this.json()
+ .equals(((org.apache.spark.sql.execution.streaming.Offset) obj).json());
+ } else {
+ return false;
+ }
+ }
+
+ @Override
+ public int hashCode() {
+ return this.json().hashCode();
+ }
+
+ @Override
+ public String toString() {
+ return this.json();
+ }
+}
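A minimal sketch of a concrete offset for a source tracked by a single long position (the class is illustrative, not from this patch); equals, hashCode and toString all fall out of the json() defined here:

    public class MyLongOffset extends Offset {
      private final long position;

      public MyLongOffset(long position) {
        this.position = position;
      }

      @Override
      public String json() {
        // Must be stable: equivalent offsets are assumed to serialize to identical JSON.
        return Long.toString(position);
      }
    }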
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java
new file mode 100644
index 0000000000000..eca0085c8a8ce
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.streaming.reader;
+
+import java.io.Serializable;
+
+/**
+ * Used for per-partition offsets in continuous processing. ContinuousReader implementations will
+ * provide a method to merge these into a global Offset.
+ *
+ * These offsets must be serializable.
+ */
+public interface PartitionOffset extends Serializable {
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/ContinuousWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/ContinuousWriter.java
new file mode 100644
index 0000000000000..723395bd1e963
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/ContinuousWriter.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.streaming.writer;
+
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer;
+import org.apache.spark.sql.sources.v2.writer.DataWriter;
+import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage;
+
+/**
+ * A {@link DataSourceV2Writer} for use with continuous stream processing.
+ */
+@InterfaceStability.Evolving
+public interface ContinuousWriter extends DataSourceV2Writer {
+ /**
+ * Commits this writing job for the specified epoch with a list of commit messages. The commit
+ * messages are collected from successful data writers and are produced by
+ * {@link DataWriter#commit()}.
+ *
+ * If this method fails (by throwing an exception), this writing job is considered to have been
+ * failed, and the execution engine will attempt to call {@link #abort(WriterCommitMessage[])}.
+ */
+ void commit(long epochId, WriterCommitMessage[] messages);
+
+ default void commit(WriterCommitMessage[] messages) {
+ throw new UnsupportedOperationException(
+ "Commit without epoch should not be called with ContinuousWriter");
+ }
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java b/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java
index d31790a285687..33ae9a9e87668 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java
@@ -22,6 +22,7 @@
import scala.concurrent.duration.Duration;
import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger;
import org.apache.spark.sql.execution.streaming.OneTimeTrigger$;
/**
@@ -95,4 +96,57 @@ public static Trigger ProcessingTime(String interval) {
public static Trigger Once() {
return OneTimeTrigger$.MODULE$;
}
+
+ /**
+ * A trigger that continuously processes streaming data, asynchronously checkpointing at
+ * the specified interval.
+ *
+ * @since 2.3.0
+ */
+ public static Trigger Continuous(long intervalMs) {
+ return ContinuousTrigger.apply(intervalMs);
+ }
+
+ /**
+ * A trigger that continuously processes streaming data, asynchronously checkpointing at
+ * the specified interval.
+ *
+ * {{{
+ * import java.util.concurrent.TimeUnit
+ * df.writeStream.trigger(Trigger.Continuous(10, TimeUnit.SECONDS))
+ * }}}
+ *
+ * @since 2.3.0
+ */
+ public static Trigger Continuous(long interval, TimeUnit timeUnit) {
+ return ContinuousTrigger.create(interval, timeUnit);
+ }
+
+ /**
+ * (Scala-friendly)
+ * A trigger that continuously processes streaming data, asynchronously checkpointing at
+ * the specified interval.
+ *
+ * {{{
+ * import scala.concurrent.duration._
+ * df.writeStream.trigger(Trigger.Continuous(10.seconds))
+ * }}}
+ * @since 2.3.0
+ */
+ public static Trigger Continuous(Duration interval) {
+ return ContinuousTrigger.apply(interval);
+ }
+
+ /**
+ * A trigger that continuously processes streaming data, asynchronously checkpointing at
+ * the specified interval.
+ *
+ * {{{
+ * df.writeStream.trigger(Trigger.Continuous("10 seconds"))
+ * }}}
+ * @since 2.3.0
+ */
+ public static Trigger Continuous(String interval) {
+ return ContinuousTrigger.apply(interval);
+ }
}
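A usage sketch of the new continuous trigger from Java (the rate source, console sink and checkpoint path are illustrative assumptions):

    Dataset<Row> df = spark.readStream().format("rate").load();
    df.writeStream()
      .format("console")
      .option("checkpointLocation", "/tmp/checkpoints/rate-console")
      .trigger(Trigger.Continuous("1 second"))   // checkpoint asynchronously every second
      .start();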
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
similarity index 69%
rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java
rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
index 5c502c9d91be4..708333213f3f1 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql.execution.vectorized;
+package org.apache.spark.sql.vectorized;
import org.apache.arrow.vector.*;
import org.apache.arrow.vector.complex.*;
@@ -34,11 +34,7 @@ public final class ArrowColumnVector extends ColumnVector {
private ArrowColumnVector[] childColumns;
private void ensureAccessible(int index) {
- int valueCount = accessor.getValueCount();
- if (index < 0 || index >= valueCount) {
- throw new IndexOutOfBoundsException(
- String.format("index: %d, valueCount: %d", index, valueCount));
- }
+ ensureAccessible(index, 1);
}
private void ensureAccessible(int index, int count) {
@@ -54,11 +50,6 @@ public int numNulls() {
return accessor.getNullCount();
}
- @Override
- public boolean anyNullsSet() {
- return numNulls() > 0;
- }
-
@Override
public void close() {
if (childColumns != null) {
@@ -69,20 +60,12 @@ public void close() {
accessor.close();
}
- //
- // APIs dealing with nulls
- //
-
@Override
public boolean isNullAt(int rowId) {
ensureAccessible(rowId);
return accessor.isNullAt(rowId);
}
- //
- // APIs dealing with Booleans
- //
-
@Override
public boolean getBoolean(int rowId) {
ensureAccessible(rowId);
@@ -99,10 +82,6 @@ public boolean[] getBooleans(int rowId, int count) {
return array;
}
- //
- // APIs dealing with Bytes
- //
-
@Override
public byte getByte(int rowId) {
ensureAccessible(rowId);
@@ -119,10 +98,6 @@ public byte[] getBytes(int rowId, int count) {
return array;
}
- //
- // APIs dealing with Shorts
- //
-
@Override
public short getShort(int rowId) {
ensureAccessible(rowId);
@@ -139,10 +114,6 @@ public short[] getShorts(int rowId, int count) {
return array;
}
- //
- // APIs dealing with Ints
- //
-
@Override
public int getInt(int rowId) {
ensureAccessible(rowId);
@@ -159,15 +130,6 @@ public int[] getInts(int rowId, int count) {
return array;
}
- @Override
- public int getDictId(int rowId) {
- throw new UnsupportedOperationException();
- }
-
- //
- // APIs dealing with Longs
- //
-
@Override
public long getLong(int rowId) {
ensureAccessible(rowId);
@@ -184,10 +146,6 @@ public long[] getLongs(int rowId, int count) {
return array;
}
- //
- // APIs dealing with floats
- //
-
@Override
public float getFloat(int rowId) {
ensureAccessible(rowId);
@@ -204,10 +162,6 @@ public float[] getFloats(int rowId, int count) {
return array;
}
- //
- // APIs dealing with doubles
- //
-
@Override
public double getDouble(int rowId) {
ensureAccessible(rowId);
@@ -224,10 +178,6 @@ public double[] getDoubles(int rowId, int count) {
return array;
}
- //
- // APIs dealing with Arrays
- //
-
@Override
public int getArrayLength(int rowId) {
ensureAccessible(rowId);
@@ -240,82 +190,63 @@ public int getArrayOffset(int rowId) {
return accessor.getArrayOffset(rowId);
}
- //
- // APIs dealing with Decimals
- //
-
@Override
public Decimal getDecimal(int rowId, int precision, int scale) {
ensureAccessible(rowId);
return accessor.getDecimal(rowId, precision, scale);
}
- //
- // APIs dealing with UTF8Strings
- //
-
@Override
public UTF8String getUTF8String(int rowId) {
ensureAccessible(rowId);
return accessor.getUTF8String(rowId);
}
- //
- // APIs dealing with Binaries
- //
-
@Override
public byte[] getBinary(int rowId) {
ensureAccessible(rowId);
return accessor.getBinary(rowId);
}
- /**
- * Returns the data for the underlying array.
- */
@Override
public ArrowColumnVector arrayData() { return childColumns[0]; }
- /**
- * Returns the ordinal's child data column.
- */
@Override
public ArrowColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; }
public ArrowColumnVector(ValueVector vector) {
super(ArrowUtils.fromArrowField(vector.getField()));
- if (vector instanceof NullableBitVector) {
- accessor = new BooleanAccessor((NullableBitVector) vector);
- } else if (vector instanceof NullableTinyIntVector) {
- accessor = new ByteAccessor((NullableTinyIntVector) vector);
- } else if (vector instanceof NullableSmallIntVector) {
- accessor = new ShortAccessor((NullableSmallIntVector) vector);
- } else if (vector instanceof NullableIntVector) {
- accessor = new IntAccessor((NullableIntVector) vector);
- } else if (vector instanceof NullableBigIntVector) {
- accessor = new LongAccessor((NullableBigIntVector) vector);
- } else if (vector instanceof NullableFloat4Vector) {
- accessor = new FloatAccessor((NullableFloat4Vector) vector);
- } else if (vector instanceof NullableFloat8Vector) {
- accessor = new DoubleAccessor((NullableFloat8Vector) vector);
- } else if (vector instanceof NullableDecimalVector) {
- accessor = new DecimalAccessor((NullableDecimalVector) vector);
- } else if (vector instanceof NullableVarCharVector) {
- accessor = new StringAccessor((NullableVarCharVector) vector);
- } else if (vector instanceof NullableVarBinaryVector) {
- accessor = new BinaryAccessor((NullableVarBinaryVector) vector);
- } else if (vector instanceof NullableDateDayVector) {
- accessor = new DateAccessor((NullableDateDayVector) vector);
- } else if (vector instanceof NullableTimeStampMicroTZVector) {
- accessor = new TimestampAccessor((NullableTimeStampMicroTZVector) vector);
+ if (vector instanceof BitVector) {
+ accessor = new BooleanAccessor((BitVector) vector);
+ } else if (vector instanceof TinyIntVector) {
+ accessor = new ByteAccessor((TinyIntVector) vector);
+ } else if (vector instanceof SmallIntVector) {
+ accessor = new ShortAccessor((SmallIntVector) vector);
+ } else if (vector instanceof IntVector) {
+ accessor = new IntAccessor((IntVector) vector);
+ } else if (vector instanceof BigIntVector) {
+ accessor = new LongAccessor((BigIntVector) vector);
+ } else if (vector instanceof Float4Vector) {
+ accessor = new FloatAccessor((Float4Vector) vector);
+ } else if (vector instanceof Float8Vector) {
+ accessor = new DoubleAccessor((Float8Vector) vector);
+ } else if (vector instanceof DecimalVector) {
+ accessor = new DecimalAccessor((DecimalVector) vector);
+ } else if (vector instanceof VarCharVector) {
+ accessor = new StringAccessor((VarCharVector) vector);
+ } else if (vector instanceof VarBinaryVector) {
+ accessor = new BinaryAccessor((VarBinaryVector) vector);
+ } else if (vector instanceof DateDayVector) {
+ accessor = new DateAccessor((DateDayVector) vector);
+ } else if (vector instanceof TimeStampMicroTZVector) {
+ accessor = new TimestampAccessor((TimeStampMicroTZVector) vector);
} else if (vector instanceof ListVector) {
ListVector listVector = (ListVector) vector;
accessor = new ArrayAccessor(listVector);
childColumns = new ArrowColumnVector[1];
childColumns[0] = new ArrowColumnVector(listVector.getDataVector());
- resultArray = new ColumnarArray(childColumns[0]);
} else if (vector instanceof MapVector) {
MapVector mapVector = (MapVector) vector;
accessor = new StructAccessor(mapVector);
@@ -324,7 +255,6 @@ public ArrowColumnVector(ValueVector vector) {
for (int i = 0; i < childColumns.length; ++i) {
childColumns[i] = new ArrowColumnVector(mapVector.getVectorById(i));
}
- resultStruct = new ColumnarRow(childColumns);
} else {
throw new UnsupportedOperationException();
}
@@ -333,23 +263,22 @@ public ArrowColumnVector(ValueVector vector) {
private abstract static class ArrowVectorAccessor {
private final ValueVector vector;
- private final ValueVector.Accessor nulls;
ArrowVectorAccessor(ValueVector vector) {
this.vector = vector;
- this.nulls = vector.getAccessor();
}
- final boolean isNullAt(int rowId) {
- return nulls.isNull(rowId);
+ // TODO: should be final after removing ArrayAccessor workaround
+ boolean isNullAt(int rowId) {
+ return vector.isNull(rowId);
}
final int getValueCount() {
- return nulls.getValueCount();
+ return vector.getValueCount();
}
final int getNullCount() {
- return nulls.getNullCount();
+ return vector.getNullCount();
}
final void close() {
@@ -407,11 +336,11 @@ int getArrayOffset(int rowId) {
private static class BooleanAccessor extends ArrowVectorAccessor {
- private final NullableBitVector.Accessor accessor;
+ private final BitVector accessor;
- BooleanAccessor(NullableBitVector vector) {
+ BooleanAccessor(BitVector vector) {
super(vector);
- this.accessor = vector.getAccessor();
+ this.accessor = vector;
}
@Override
@@ -422,11 +351,11 @@ final boolean getBoolean(int rowId) {
private static class ByteAccessor extends ArrowVectorAccessor {
- private final NullableTinyIntVector.Accessor accessor;
+ private final TinyIntVector accessor;
- ByteAccessor(NullableTinyIntVector vector) {
+ ByteAccessor(TinyIntVector vector) {
super(vector);
- this.accessor = vector.getAccessor();
+ this.accessor = vector;
}
@Override
@@ -437,11 +366,11 @@ final byte getByte(int rowId) {
private static class ShortAccessor extends ArrowVectorAccessor {
- private final NullableSmallIntVector.Accessor accessor;
+ private final SmallIntVector accessor;
- ShortAccessor(NullableSmallIntVector vector) {
+ ShortAccessor(SmallIntVector vector) {
super(vector);
- this.accessor = vector.getAccessor();
+ this.accessor = vector;
}
@Override
@@ -452,11 +381,11 @@ final short getShort(int rowId) {
private static class IntAccessor extends ArrowVectorAccessor {
- private final NullableIntVector.Accessor accessor;
+ private final IntVector accessor;
- IntAccessor(NullableIntVector vector) {
+ IntAccessor(IntVector vector) {
super(vector);
- this.accessor = vector.getAccessor();
+ this.accessor = vector;
}
@Override
@@ -467,11 +396,11 @@ final int getInt(int rowId) {
private static class LongAccessor extends ArrowVectorAccessor {
- private final NullableBigIntVector.Accessor accessor;
+ private final BigIntVector accessor;
- LongAccessor(NullableBigIntVector vector) {
+ LongAccessor(BigIntVector vector) {
super(vector);
- this.accessor = vector.getAccessor();
+ this.accessor = vector;
}
@Override
@@ -482,11 +411,11 @@ final long getLong(int rowId) {
private static class FloatAccessor extends ArrowVectorAccessor {
- private final NullableFloat4Vector.Accessor accessor;
+ private final Float4Vector accessor;
- FloatAccessor(NullableFloat4Vector vector) {
+ FloatAccessor(Float4Vector vector) {
super(vector);
- this.accessor = vector.getAccessor();
+ this.accessor = vector;
}
@Override
@@ -497,11 +426,11 @@ final float getFloat(int rowId) {
private static class DoubleAccessor extends ArrowVectorAccessor {
- private final NullableFloat8Vector.Accessor accessor;
+ private final Float8Vector accessor;
- DoubleAccessor(NullableFloat8Vector vector) {
+ DoubleAccessor(Float8Vector vector) {
super(vector);
- this.accessor = vector.getAccessor();
+ this.accessor = vector;
}
@Override
@@ -512,11 +441,11 @@ final double getDouble(int rowId) {
private static class DecimalAccessor extends ArrowVectorAccessor {
- private final NullableDecimalVector.Accessor accessor;
+ private final DecimalVector accessor;
- DecimalAccessor(NullableDecimalVector vector) {
+ DecimalAccessor(DecimalVector vector) {
super(vector);
- this.accessor = vector.getAccessor();
+ this.accessor = vector;
}
@Override
@@ -528,12 +457,12 @@ final Decimal getDecimal(int rowId, int precision, int scale) {
private static class StringAccessor extends ArrowVectorAccessor {
- private final NullableVarCharVector.Accessor accessor;
+ private final VarCharVector accessor;
private final NullableVarCharHolder stringResult = new NullableVarCharHolder();
- StringAccessor(NullableVarCharVector vector) {
+ StringAccessor(VarCharVector vector) {
super(vector);
- this.accessor = vector.getAccessor();
+ this.accessor = vector;
}
@Override
@@ -551,11 +480,11 @@ final UTF8String getUTF8String(int rowId) {
private static class BinaryAccessor extends ArrowVectorAccessor {
- private final NullableVarBinaryVector.Accessor accessor;
+ private final VarBinaryVector accessor;
- BinaryAccessor(NullableVarBinaryVector vector) {
+ BinaryAccessor(VarBinaryVector vector) {
super(vector);
- this.accessor = vector.getAccessor();
+ this.accessor = vector;
}
@Override
@@ -566,11 +495,11 @@ final byte[] getBinary(int rowId) {
private static class DateAccessor extends ArrowVectorAccessor {
- private final NullableDateDayVector.Accessor accessor;
+ private final DateDayVector accessor;
- DateAccessor(NullableDateDayVector vector) {
+ DateAccessor(DateDayVector vector) {
super(vector);
- this.accessor = vector.getAccessor();
+ this.accessor = vector;
}
@Override
@@ -581,11 +510,11 @@ final int getInt(int rowId) {
private static class TimestampAccessor extends ArrowVectorAccessor {
- private final NullableTimeStampMicroTZVector.Accessor accessor;
+ private final TimeStampMicroTZVector accessor;
- TimestampAccessor(NullableTimeStampMicroTZVector vector) {
+ TimestampAccessor(TimeStampMicroTZVector vector) {
super(vector);
- this.accessor = vector.getAccessor();
+ this.accessor = vector;
}
@Override
@@ -596,21 +525,31 @@ final long getLong(int rowId) {
private static class ArrayAccessor extends ArrowVectorAccessor {
- private final UInt4Vector.Accessor accessor;
+ private final ListVector accessor;
ArrayAccessor(ListVector vector) {
super(vector);
- this.accessor = vector.getOffsetVector().getAccessor();
+ this.accessor = vector;
+ }
+
+ @Override
+ final boolean isNullAt(int rowId) {
+ // TODO: Workaround if vector has all non-null values, see ARROW-1948
+ if (accessor.getValueCount() > 0 && accessor.getValidityBuffer().capacity() == 0) {
+ return false;
+ } else {
+ return super.isNullAt(rowId);
+ }
}
@Override
final int getArrayLength(int rowId) {
- return accessor.get(rowId + 1) - accessor.get(rowId);
+ return accessor.getInnerValueCountAt(rowId);
}
@Override
final int getArrayOffset(int rowId) {
- return accessor.get(rowId);
+ return accessor.getOffsetBuffer().getInt(rowId * accessor.OFFSET_WIDTH);
}
}
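
A hedged sketch of the rewritten wrapper in use, assuming the Arrow 0.8 vector API this patch migrates to (the allocator, vector name and values are illustrative):

{{{
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.IntVector
import org.apache.spark.sql.vectorized.ArrowColumnVector

val allocator = new RootAllocator(Long.MaxValue)
val vector = new IntVector("ints", allocator)  // plain IntVector, no Nullable* wrapper
vector.allocateNew(2)
vector.setSafe(0, 1)
vector.setSafe(1, 2)
vector.setValueCount(2)

val column = new ArrowColumnVector(vector)     // dispatches to IntAccessor internally
assert(column.getInt(0) == 1 && !column.isNullAt(1))
column.close()                                 // also closes the underlying Arrow vector
}}}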
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
similarity index 59%
rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
index 940457f2e3363..d1196e1299fee 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
@@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.sql.execution.vectorized;
+package org.apache.spark.sql.vectorized;
import org.apache.spark.sql.catalyst.util.MapData;
import org.apache.spark.sql.types.DataType;
@@ -22,32 +22,36 @@
import org.apache.spark.unsafe.types.UTF8String;
/**
- * This class represents a column of values and provides the main APIs to access the data
- * values. It supports all the types and contains get APIs as well as their batched versions.
- * The batched versions are preferable whenever possible.
- *
- * To handle nested schemas, ColumnVector has two types: Arrays and Structs. In both cases these
- * columns have child columns. All of the data is stored in the child columns and the parent column
- * contains nullability, and in the case of Arrays, the lengths and offsets into the child column.
- * Lengths and offsets are encoded identically to INTs.
- * Maps are just a special case of a two field struct.
+ * An interface representing in-memory columnar data in Spark. This interface defines the main APIs
+ * to access the data, as well as their batched versions. The batched versions are generally
+ * faster and should be preferred whenever possible.
*
* Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values
- * in the current RowBatch.
+ * in this ColumnVector.
+ *
+ * ColumnVector supports all the data types including nested types. To handle nested types,
+ * a ColumnVector can have children, forming a tree structure. For struct type, it stores the actual
+ * data of each field in the corresponding child ColumnVector, and only stores null information in
+ * the parent ColumnVector. For array type, it stores the actual array elements in the child
+ * ColumnVector, and stores null information, array offsets and lengths in the parent ColumnVector.
*
- * A ColumnVector should be considered immutable once originally created.
+ * ColumnVector is expected to be reused during the entire data loading process, to avoid allocating
+ * memory again and again.
*
- * ColumnVectors are intended to be reused.
+ * ColumnVector is meant to maximize CPU efficiency but not to minimize storage footprint.
+ * Implementations should prefer computing efficiency over storage efficiency when designing the
+ * format. Since the ColumnVector instance is expected to be reused while loading data, the
+ * storage footprint is negligible.
*/
public abstract class ColumnVector implements AutoCloseable {
+
/**
- * Returns the data type of this column.
+ * Returns the data type of this column vector.
*/
public final DataType dataType() { return type; }
/**
* Cleans up memory for this column. The column is not usable after this.
- * TODO: this should probably have ref-counted semantics.
*/
public abstract void close();
@@ -56,12 +60,6 @@ public abstract class ColumnVector implements AutoCloseable {
*/
public abstract int numNulls();
- /**
- * Returns true if any of the nulls indicator are set for this column. This can be used
- * as an optimization to prevent setting nulls.
- */
- public abstract boolean anyNullsSet();
-
/**
* Returns whether the value at rowId is NULL.
*/
@@ -107,13 +105,6 @@ public abstract class ColumnVector implements AutoCloseable {
*/
public abstract int[] getInts(int rowId, int count);
- /**
- * Returns the dictionary Id for rowId.
- * This should only be called when the ColumnVector is dictionaryIds.
- * We have this separate method for dictionaryIds as per SPARK-16928.
- */
- public abstract int getDictId(int rowId);
-
/**
* Returns the value for rowId.
*/
@@ -145,43 +136,39 @@ public abstract class ColumnVector implements AutoCloseable {
public abstract double[] getDoubles(int rowId, int count);
/**
- * Returns the length of the array at rowid.
+ * Returns the length of the array for rowId.
*/
public abstract int getArrayLength(int rowId);
/**
- * Returns the offset of the array at rowid.
+ * Returns the offset of the array for rowId.
*/
public abstract int getArrayOffset(int rowId);
/**
- * Returns a utility object to get structs.
+ * Returns the struct for rowId.
*/
- public ColumnarRow getStruct(int rowId) {
- resultStruct.rowId = rowId;
- return resultStruct;
+ public final ColumnarRow getStruct(int rowId) {
+ return new ColumnarRow(this, rowId);
}
/**
- * Returns a utility object to get structs.
- * provided to keep API compatibility with InternalRow for code generation
+ * A special version of {@link #getStruct(int)} that is only used as an adapter for the Spark
+ * codegen framework; the second parameter is ignored.
*/
- public ColumnarRow getStruct(int rowId, int size) {
- resultStruct.rowId = rowId;
- return resultStruct;
+ public final ColumnarRow getStruct(int rowId, int size) {
+ return getStruct(rowId);
}
/**
- * Returns the array at rowid.
+ * Returns the array for rowId.
*/
public final ColumnarArray getArray(int rowId) {
- resultArray.length = getArrayLength(rowId);
- resultArray.offset = getArrayOffset(rowId);
- return resultArray;
+ return new ColumnarArray(arrayData(), getArrayOffset(rowId), getArrayLength(rowId));
}
/**
- * Returns the value for rowId.
+ * Returns the map for rowId.
*/
public MapData getMap(int ordinal) {
throw new UnsupportedOperationException();
@@ -213,50 +200,11 @@ public MapData getMap(int ordinal) {
*/
public abstract ColumnVector getChildColumn(int ordinal);
- /**
- * Returns true if this column is an array.
- */
- public final boolean isArray() { return resultArray != null; }
-
/**
* Data type for this column.
*/
protected DataType type;
- /**
- * Reusable Array holder for getArray().
- */
- protected ColumnarArray resultArray;
-
- /**
- * Reusable Struct holder for getStruct().
- */
- protected ColumnarRow resultStruct;
-
- /**
- * The Dictionary for this column.
- *
- * If it's not null, will be used to decode the value in getXXX().
- */
- protected Dictionary dictionary;
-
- /**
- * Reusable column for ids of dictionary.
- */
- protected ColumnVector dictionaryIds;
-
- /**
- * Returns true if this column has a dictionary.
- */
- public boolean hasDictionary() { return this.dictionary != null; }
-
- /**
- * Returns the underlying integer column for ids of dictionary.
- */
- public ColumnVector getDictionaryIds() {
- return dictionaryIds;
- }
-
/**
* Sets up the common state and also handles creating the child columns if this is a nested
* type.
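
To make the tree-structure description above concrete, a hedged sketch of how a caller reads array and struct columns through child vectors (null handling omitted; `arrayCol` and `structCol` stand for arbitrary vectors of the right type):

{{{
import org.apache.spark.sql.vectorized.ColumnVector

// Array column: offsets and lengths live in the parent, elements in arrayData().
def sumIntArray(arrayCol: ColumnVector, rowId: Int): Long = {
  val arr = arrayCol.getArray(rowId)      // ColumnarArray over arrayData()
  (0 until arr.numElements()).map(i => arr.getInt(i).toLong).sum
}

// Struct column: each field is stored in the corresponding child vector.
def firstIntField(structCol: ColumnVector, rowId: Int): Int =
  structCol.getStruct(rowId).getInt(0)    // ColumnarRow backed by the child columns
}}}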
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java
similarity index 94%
rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java
rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java
index b9da641fc66c8..0d89a52e7a4fe 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java
@@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.sql.execution.vectorized;
+package org.apache.spark.sql.vectorized;
import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.catalyst.util.MapData;
@@ -23,18 +23,19 @@
import org.apache.spark.unsafe.types.UTF8String;
/**
- * Array abstraction in {@link ColumnVector}. The instance of this class is intended
- * to be reused, callers should copy the data out if it needs to be stored.
+ * Array abstraction in {@link ColumnVector}.
*/
public final class ColumnarArray extends ArrayData {
// The data for this array. This array contains elements from
// data[offset] to data[offset + length).
- public final ColumnVector data;
- public int length;
- public int offset;
+ private final ColumnVector data;
+ private final int offset;
+ private final int length;
- ColumnarArray(ColumnVector data) {
+ public ColumnarArray(ColumnVector data, int offset, int length) {
this.data = data;
+ this.offset = offset;
+ this.length = length;
}
@Override
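
With the constructor made public and the fields final, a ColumnarArray is now an immutable (vector, offset, length) view rather than a reusable holder, so callers no longer need to copy it before the next getArray call. A small illustrative sketch (`elements` is an assumed element vector):

{{{
import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarArray}

// A read-only view over elements[5] .. elements[5 + 3).
def sliceOfThree(elements: ColumnVector): ColumnarArray =
  new ColumnarArray(elements, 5, 3)
}}}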
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java
similarity index 63%
rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java
index 2f5fb360b226f..9ae1c6d9993f0 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java
@@ -14,25 +14,18 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.sql.execution.vectorized;
+package org.apache.spark.sql.vectorized;
import java.util.*;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.execution.vectorized.MutableColumnarRow;
import org.apache.spark.sql.types.StructType;
/**
- * This class is the in memory representation of rows as they are streamed through operators. It
- * is designed to maximize CPU efficiency and not storage footprint. Since it is expected that
- * each operator allocates one of these objects, the storage footprint on the task is negligible.
- *
- * The layout is a columnar with values encoded in their native format. Each RowBatch contains
- * a horizontal partitioning of the data, split into columns.
- *
- * The ColumnarBatch supports either on heap or offheap modes with (mostly) the identical API.
- *
- * TODO:
- * - There are many TODOs for the existing APIs. They should throw a not implemented exception.
- * - Compaction: The batch and columns should be able to compact based on a selection vector.
+ * This class wraps multiple ColumnVectors as a row-wise table. It provides a row view of this
+ * batch so that Spark can access the data row by row. Instances of this class are meant to be
+ * reused during the entire data loading process.
*/
public final class ColumnarBatch {
public static final int DEFAULT_BATCH_SIZE = 4 * 1024;
@@ -40,10 +33,10 @@ public final class ColumnarBatch {
private final StructType schema;
private final int capacity;
private int numRows;
- final ColumnVector[] columns;
+ private final ColumnVector[] columns;
- // Staging row returned from getRow.
- final ColumnarRow row;
+ // Staging row returned from `getRow`.
+ private final MutableColumnarRow row;
/**
* Called to close all the columns in this batch. It is not valid to access the data after
@@ -56,12 +49,12 @@ public void close() {
}
/**
- * Returns an iterator over the rows in this batch. This skips rows that are filtered out.
+ * Returns an iterator over the rows in this batch.
*/
- public Iterator<ColumnarRow> rowIterator() {
+ public Iterator<InternalRow> rowIterator() {
final int maxRows = numRows;
- final ColumnarRow row = new ColumnarRow(columns);
- return new Iterator<ColumnarRow>() {
+ final MutableColumnarRow row = new MutableColumnarRow(columns);
+ return new Iterator<InternalRow>() {
int rowId = 0;
@Override
@@ -70,7 +63,7 @@ public boolean hasNext() {
}
@Override
- public ColumnarRow next() {
+ public InternalRow next() {
if (rowId >= maxRows) {
throw new NoSuchElementException();
}
@@ -86,19 +79,7 @@ public void remove() {
}
/**
- * Resets the batch for writing.
- */
- public void reset() {
- for (int i = 0; i < numCols(); ++i) {
- if (columns[i] instanceof WritableColumnVector) {
- ((WritableColumnVector) columns[i]).reset();
- }
- }
- this.numRows = 0;
- }
-
- /**
- * Sets the number of rows that are valid.
+ * Sets the number of rows in this batch.
*/
public void setNumRows(int numRows) {
assert(numRows <= this.capacity);
@@ -133,9 +114,8 @@ public void setNumRows(int numRows) {
/**
* Returns the row in this batch at `rowId`. Returned row is reused across calls.
*/
- public ColumnarRow getRow(int rowId) {
- assert(rowId >= 0);
- assert(rowId < numRows);
+ public InternalRow getRow(int rowId) {
+ assert(rowId >= 0 && rowId < numRows);
row.rowId = rowId;
return row;
}
@@ -144,6 +124,6 @@ public ColumnarBatch(StructType schema, ColumnVector[] columns, int capacity) {
this.schema = schema;
this.columns = columns;
this.capacity = capacity;
- this.row = new ColumnarRow(columns);
+ this.row = new MutableColumnarRow(columns);
}
}
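
A hedged sketch of consuming the reworked batch API (the schema, column vectors and row count are assumed to be populated elsewhere; the first column is assumed to hold ints):

{{{
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}

def firstColumnToSeq(schema: StructType, columns: Array[ColumnVector], numRows: Int): Seq[Int] = {
  val batch = new ColumnarBatch(schema, columns, numRows)
  batch.setNumRows(numRows)
  val out = scala.collection.mutable.ArrayBuffer.empty[Int]
  val it = batch.rowIterator()
  while (it.hasNext) {
    val row: InternalRow = it.next()      // staging row, reused across calls
    out += row.getInt(0)
  }
  out
}
}}}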
diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
new file mode 100644
index 0000000000000..3c6656dec77cd
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.vectorized;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
+import org.apache.spark.sql.catalyst.util.MapData;
+import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * Row abstraction in {@link ColumnVector}.
+ */
+public final class ColumnarRow extends InternalRow {
+ // The data for this row.
+ // E.g. the value of 3rd int field is `data.getChildColumn(3).getInt(rowId)`.
+ private final ColumnVector data;
+ private final int rowId;
+ private final int numFields;
+
+ public ColumnarRow(ColumnVector data, int rowId) {
+ assert (data.dataType() instanceof StructType);
+ this.data = data;
+ this.rowId = rowId;
+ this.numFields = ((StructType) data.dataType()).size();
+ }
+
+ @Override
+ public int numFields() { return numFields; }
+
+ /**
+ * Revisit this: copying is expensive, and it is currently only used in test paths.
+ */
+ @Override
+ public InternalRow copy() {
+ GenericInternalRow row = new GenericInternalRow(numFields);
+ for (int i = 0; i < numFields(); i++) {
+ if (isNullAt(i)) {
+ row.setNullAt(i);
+ } else {
+ DataType dt = data.getChildColumn(i).dataType();
+ if (dt instanceof BooleanType) {
+ row.setBoolean(i, getBoolean(i));
+ } else if (dt instanceof ByteType) {
+ row.setByte(i, getByte(i));
+ } else if (dt instanceof ShortType) {
+ row.setShort(i, getShort(i));
+ } else if (dt instanceof IntegerType) {
+ row.setInt(i, getInt(i));
+ } else if (dt instanceof LongType) {
+ row.setLong(i, getLong(i));
+ } else if (dt instanceof FloatType) {
+ row.setFloat(i, getFloat(i));
+ } else if (dt instanceof DoubleType) {
+ row.setDouble(i, getDouble(i));
+ } else if (dt instanceof StringType) {
+ row.update(i, getUTF8String(i).copy());
+ } else if (dt instanceof BinaryType) {
+ row.update(i, getBinary(i));
+ } else if (dt instanceof DecimalType) {
+ DecimalType t = (DecimalType)dt;
+ row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision());
+ } else if (dt instanceof DateType) {
+ row.setInt(i, getInt(i));
+ } else if (dt instanceof TimestampType) {
+ row.setLong(i, getLong(i));
+ } else {
+ throw new RuntimeException("Not implemented. " + dt);
+ }
+ }
+ }
+ return row;
+ }
+
+ @Override
+ public boolean anyNull() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean isNullAt(int ordinal) { return data.getChildColumn(ordinal).isNullAt(rowId); }
+
+ @Override
+ public boolean getBoolean(int ordinal) { return data.getChildColumn(ordinal).getBoolean(rowId); }
+
+ @Override
+ public byte getByte(int ordinal) { return data.getChildColumn(ordinal).getByte(rowId); }
+
+ @Override
+ public short getShort(int ordinal) { return data.getChildColumn(ordinal).getShort(rowId); }
+
+ @Override
+ public int getInt(int ordinal) { return data.getChildColumn(ordinal).getInt(rowId); }
+
+ @Override
+ public long getLong(int ordinal) { return data.getChildColumn(ordinal).getLong(rowId); }
+
+ @Override
+ public float getFloat(int ordinal) { return data.getChildColumn(ordinal).getFloat(rowId); }
+
+ @Override
+ public double getDouble(int ordinal) { return data.getChildColumn(ordinal).getDouble(rowId); }
+
+ @Override
+ public Decimal getDecimal(int ordinal, int precision, int scale) {
+ if (data.getChildColumn(ordinal).isNullAt(rowId)) return null;
+ return data.getChildColumn(ordinal).getDecimal(rowId, precision, scale);
+ }
+
+ @Override
+ public UTF8String getUTF8String(int ordinal) {
+ if (data.getChildColumn(ordinal).isNullAt(rowId)) return null;
+ return data.getChildColumn(ordinal).getUTF8String(rowId);
+ }
+
+ @Override
+ public byte[] getBinary(int ordinal) {
+ if (data.getChildColumn(ordinal).isNullAt(rowId)) return null;
+ return data.getChildColumn(ordinal).getBinary(rowId);
+ }
+
+ @Override
+ public CalendarInterval getInterval(int ordinal) {
+ if (data.getChildColumn(ordinal).isNullAt(rowId)) return null;
+ final int months = data.getChildColumn(ordinal).getChildColumn(0).getInt(rowId);
+ final long microseconds = data.getChildColumn(ordinal).getChildColumn(1).getLong(rowId);
+ return new CalendarInterval(months, microseconds);
+ }
+
+ @Override
+ public ColumnarRow getStruct(int ordinal, int numFields) {
+ if (data.getChildColumn(ordinal).isNullAt(rowId)) return null;
+ return data.getChildColumn(ordinal).getStruct(rowId);
+ }
+
+ @Override
+ public ColumnarArray getArray(int ordinal) {
+ if (data.getChildColumn(ordinal).isNullAt(rowId)) return null;
+ return data.getChildColumn(ordinal).getArray(rowId);
+ }
+
+ @Override
+ public MapData getMap(int ordinal) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Object get(int ordinal, DataType dataType) {
+ if (dataType instanceof BooleanType) {
+ return getBoolean(ordinal);
+ } else if (dataType instanceof ByteType) {
+ return getByte(ordinal);
+ } else if (dataType instanceof ShortType) {
+ return getShort(ordinal);
+ } else if (dataType instanceof IntegerType) {
+ return getInt(ordinal);
+ } else if (dataType instanceof LongType) {
+ return getLong(ordinal);
+ } else if (dataType instanceof FloatType) {
+ return getFloat(ordinal);
+ } else if (dataType instanceof DoubleType) {
+ return getDouble(ordinal);
+ } else if (dataType instanceof StringType) {
+ return getUTF8String(ordinal);
+ } else if (dataType instanceof BinaryType) {
+ return getBinary(ordinal);
+ } else if (dataType instanceof DecimalType) {
+ DecimalType t = (DecimalType) dataType;
+ return getDecimal(ordinal, t.precision(), t.scale());
+ } else if (dataType instanceof DateType) {
+ return getInt(ordinal);
+ } else if (dataType instanceof TimestampType) {
+ return getLong(ordinal);
+ } else if (dataType instanceof ArrayType) {
+ return getArray(ordinal);
+ } else if (dataType instanceof StructType) {
+ return getStruct(ordinal, ((StructType)dataType).fields().length);
+ } else if (dataType instanceof MapType) {
+ return getMap(ordinal);
+ } else {
+ throw new UnsupportedOperationException("Datatype not supported " + dataType);
+ }
+ }
+
+ @Override
+ public void update(int ordinal, Object value) { throw new UnsupportedOperationException(); }
+
+ @Override
+ public void setNullAt(int ordinal) { throw new UnsupportedOperationException(); }
+}
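
Since rows handed out by the batch are reused, a caller that wants to keep them must detach them from the backing vectors; copy() above does that by building a GenericInternalRow (note it only covers the atomic field types listed, nested fields hit the RuntimeException branch). A hedged sketch with an assumed struct-typed column:

{{{
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.vectorized.ColumnVector

// Collect struct values so they can outlive the batch that produced them.
def collectStructs(structCol: ColumnVector, numRows: Int): Seq[InternalRow] =
  (0 until numRows).map(i => structCol.getStruct(i).copy())
}}}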
diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
index 0c5f3f22e31e8..0259c774bbf4a 100644
--- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
+++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -1,8 +1,10 @@
org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider
org.apache.spark.sql.execution.datasources.json.JsonFileFormat
+org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
org.apache.spark.sql.execution.datasources.text.TextFileFormat
org.apache.spark.sql.execution.streaming.ConsoleSinkProvider
org.apache.spark.sql.execution.streaming.TextSocketSourceProvider
org.apache.spark.sql.execution.streaming.RateSourceProvider
+org.apache.spark.sql.execution.streaming.sources.RateSourceProviderV2
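
Registering OrcFileFormat here lets the usual short-name lookup resolve to the sql/core ORC reader; which implementation the short name maps to can depend on session configuration (see the lookupDataSource(source, conf) change further down). An illustrative sketch (the paths and the `spark` session are assumed):

{{{
// Short name lookup, resolved through this service registration.
val orcDf = spark.read.format("orc").load("/path/to/data.orc")

// The fully qualified class name can also be used explicitly.
spark.read
  .format("org.apache.spark.sql.execution.datasources.orc.OrcFileFormat")
  .load("/path/to/data.orc")
}}}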
diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.status.AppHistoryServerPlugin b/sql/core/src/main/resources/META-INF/services/org.apache.spark.status.AppHistoryServerPlugin
new file mode 100644
index 0000000000000..0bba2f88b92a5
--- /dev/null
+++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.status.AppHistoryServerPlugin
@@ -0,0 +1 @@
+org.apache.spark.sql.execution.ui.SQLHistoryServerPlugin
diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.status.AppStatusPlugin b/sql/core/src/main/resources/META-INF/services/org.apache.spark.status.AppStatusPlugin
deleted file mode 100644
index ac6d7f6962f85..0000000000000
--- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.status.AppStatusPlugin
+++ /dev/null
@@ -1 +0,0 @@
-org.apache.spark.sql.execution.ui.SQLAppStatusPlugin
diff --git a/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css
index 594e747a8d3a5..b13850c301490 100644
--- a/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css
+++ b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css
@@ -32,7 +32,7 @@
stroke-width: 1px;
}
-/* Hightlight the SparkPlan node name */
+/* Highlight the SparkPlan node name */
#plan-viz-graph svg text :first-child {
font-weight: bold;
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 17966eecfc051..e8d683a578f35 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -33,7 +33,8 @@ import org.apache.spark.sql.execution.datasources.csv._
import org.apache.spark.sql.execution.datasources.jdbc._
import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
-import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options, ReadSupport, ReadSupportWithSchema}
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils
+import org.apache.spark.sql.sources.v2._
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String
@@ -182,11 +183,15 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
"read files of Hive data source directly.")
}
- val cls = DataSource.lookupDataSource(source)
+ val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf)
if (classOf[DataSourceV2].isAssignableFrom(cls)) {
- val options = new DataSourceV2Options(extraOptions.asJava)
+ val ds = cls.newInstance()
+ val options = new DataSourceV2Options((extraOptions ++
+ DataSourceV2Utils.extractSessionConfigs(
+ ds = ds.asInstanceOf[DataSourceV2],
+ conf = sparkSession.sessionState.conf)).asJava)
- val reader = (cls.newInstance(), userSpecifiedSchema) match {
+ val reader = (ds, userSpecifiedSchema) match {
case (ds: ReadSupportWithSchema, Some(schema)) =>
ds.createReader(schema, options)
@@ -512,17 +517,20 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
*
* You can set the following CSV-specific options to deal with CSV files:
*
- * `sep` (default `,`): sets the single character as a separator for each
+ * `sep` (default `,`): sets a single character as a separator for each
* field and value.
* `encoding` (default `UTF-8`): decodes the CSV files by the given encoding
* type.
- * `quote` (default `"`): sets the single character used for escaping quoted values where
+ * `quote` (default `"`): sets a single character used for escaping quoted values where
* the separator can be part of the value. If you would like to turn off quotations, you need to
* set not `null` but an empty string. This behaviour is different from
* `com.databricks.spark.csv`.
- * `escape` (default `\`): sets the single character used for escaping quotes inside
+ * `escape` (default `\`): sets a single character used for escaping quotes inside
* an already quoted value.
- * `comment` (default empty string): sets the single character used for skipping lines
+ * `charToEscapeQuoteEscaping` (default `escape` or `\0`): sets a single character used for
+ * escaping the escape for the quote character. The default value is the escape character when
+ * the escape and quote characters are different, `\0` otherwise.
+ * `comment` (default empty string): sets a single character used for skipping lines
* beginning with this character. By default, it is disabled.
* `header` (default `false`): uses the first line as names of columns.
* `inferSchema` (default `false`): infers the input schema automatically from data. It
@@ -646,7 +654,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* Loads text files and returns a `DataFrame` whose schema starts with a string column named
* "value", and followed by partitioned columns if there are any.
*
- * Each line in the text files is a new row in the resulting DataFrame. For example:
+ * You can set the following text-specific option(s) for reading text files:
+ *
+ * `wholetext` (default `false`): If true, read each file as a single row rather than splitting it on "\n".
+ *
+ *
+ * By default, each line in the text files is a new row in the resulting DataFrame.
+ *
+ * Usage example:
* {{{
* // Scala:
* spark.read.text("/path/to/spark/README.md")
@@ -678,7 +693,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* If the directory structure of the text files contains partitioning information, those are
* ignored in the resulting Dataset. To include partitioning information as columns, use `text`.
*
- * Each line in the text files is a new element in the resulting Dataset. For example:
+ * You can set the following textFile-specific option(s) for reading text files:
+ *
+ * `wholetext` (default `false`): If true, read each file as a single row rather than splitting it on "\n".
+ *
+ *
+ * By default, each line in the text files is a new element in the resulting Dataset. For example:
* {{{
* // Scala:
* spark.read.textFile("/path/to/spark/README.md")
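
A hedged sketch of the new reader options documented above (`spark` and the paths are illustrative):

{{{
// wholetext: read each file as a single row instead of one row per line.
val logs = spark.read.option("wholetext", "true").text("/path/to/logs")

// charToEscapeQuoteEscaping: control the escape of the quote-escape character.
val csv = spark.read
  .option("quote", "\"")
  .option("escape", "\\")
  .option("charToEscapeQuoteEscaping", "\\")
  .csv("/path/to/data.csv")
}}}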
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 35abeccfd514a..3304f368e1050 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -26,13 +26,14 @@ import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation}
import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{AnalysisBarrier, InsertIntoTable, LogicalPlan}
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils
import org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2
import org.apache.spark.sql.sources.BaseRelation
-import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options, WriteSupport}
+import org.apache.spark.sql.sources.v2._
import org.apache.spark.sql.types.StructType
/**
@@ -234,16 +235,20 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
assertNotBucketed("save")
- val cls = DataSource.lookupDataSource(source)
+ val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
if (classOf[DataSourceV2].isAssignableFrom(cls)) {
- cls.newInstance() match {
- case ds: WriteSupport =>
- val options = new DataSourceV2Options(extraOptions.asJava)
+ val ds = cls.newInstance()
+ ds match {
+ case ws: WriteSupport =>
+ val options = new DataSourceV2Options((extraOptions ++
+ DataSourceV2Utils.extractSessionConfigs(
+ ds = ds.asInstanceOf[DataSourceV2],
+ conf = df.sparkSession.sessionState.conf)).asJava)
// Using a timestamp and a random UUID to distinguish different writing jobs. This is good
// enough as there won't be tons of writing jobs created at the same second.
val jobId = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US)
.format(new Date()) + "-" + UUID.randomUUID()
- val writer = ds.createWriter(jobId, df.logicalPlan.schema, mode, options)
+ val writer = ws.createWriter(jobId, df.logicalPlan.schema, mode, options)
if (writer.isPresent) {
runCommand(df.sparkSession, "save") {
WriteToDataSourceV2(writer.get(), df.logicalPlan)
@@ -259,7 +264,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
sparkSession = df.sparkSession,
className = source,
partitionColumns = partitioningColumns.getOrElse(Nil),
- options = extraOptions.toMap).planForWriting(mode, df.logicalPlan)
+ options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan))
}
}
}
@@ -589,13 +594,16 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
*
* You can set the following CSV-specific option(s) for writing CSV files:
*
- * `sep` (default `,`): sets the single character as a separator for each
+ * `sep` (default `,`): sets a single character as a separator for each
* field and value.
- * `quote` (default `"`): sets the single character used for escaping quoted values where
+ * `quote` (default `"`): sets a single character used for escaping quoted values where
* the separator can be part of the value. If an empty string is set, it uses `u0000`
* (null character).
- * `escape` (default `\`): sets the single character used for escaping quotes inside
+ * `escape` (default `\`): sets a single character used for escaping quotes inside
* an already quoted value.
+ * `charToEscapeQuoteEscaping` (default `escape` or `\0`): sets a single character used for
+ * escaping the escape for the quote character. The default value is the escape character when
+ * the escape and quote characters are different, `\0` otherwise.
* `escapeQuotes` (default `true`): a flag indicating whether values containing
* quotes should always be enclosed in quotes. Default is to escape all values containing
* a quote character.
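
The writer side mirrors the new CSV option; a hedged sketch (`df` and the output path are assumed):

{{{
df.write
  .option("sep", ";")
  .option("escapeQuotes", "true")
  .option("charToEscapeQuoteEscaping", "\\")
  .csv("/tmp/out-csv")
}}}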
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 1620ab3aa2094..34f0ab5aa6699 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -191,6 +191,9 @@ class Dataset[T] private[sql](
}
}
+ // Wraps analyzed logical plans with an analysis barrier so we won't traverse/resolve it again.
+ @transient private val planWithBarrier = AnalysisBarrier(logicalPlan)
+
/**
* Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], here we turn the
* passed in encoder to [[ExpressionEncoder]] explicitly, and mark it implicit so that we can use
@@ -234,13 +237,20 @@ class Dataset[T] private[sql](
private[sql] def showString(
_numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = {
val numRows = _numRows.max(0).min(Int.MaxValue - 1)
- val takeResult = toDF().take(numRows + 1)
+ val newDf = toDF()
+ val castCols = newDf.logicalPlan.output.map { col =>
+ // Binary types in top-level schema fields have a specific format to print,
+ // so we do not cast them to strings here.
+ if (col.dataType == BinaryType) {
+ Column(col)
+ } else {
+ Column(col).cast(StringType)
+ }
+ }
+ val takeResult = newDf.select(castCols: _*).take(numRows + 1)
val hasMoreData = takeResult.length > numRows
val data = takeResult.take(numRows)
- lazy val timeZone =
- DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone)
-
// For array values, replace Seq and Array with square brackets
// For cells that are beyond `truncate` characters, replace it with the
// first `truncate-3` and "..."
@@ -249,12 +259,6 @@ class Dataset[T] private[sql](
val str = cell match {
case null => "null"
case binary: Array[Byte] => binary.map("%02X".format(_)).mkString("[", " ", "]")
- case array: Array[_] => array.mkString("[", ", ", "]")
- case seq: Seq[_] => seq.mkString("[", ", ", "]")
- case d: Date =>
- DateTimeUtils.dateToString(DateTimeUtils.fromJavaDate(d))
- case ts: Timestamp =>
- DateTimeUtils.timestampToString(DateTimeUtils.fromJavaTimestamp(ts), timeZone)
case _ => cell.toString
}
if (truncate > 0 && str.length > truncate) {
@@ -398,12 +402,16 @@ class Dataset[T] private[sql](
* If the schema of the Dataset does not match the desired `U` type, you can use `select`
* along with `alias` or `as` to rearrange or rename as required.
*
+ * Note that `as[]` only changes the view of the data that is passed into typed operations,
+ * such as `map()`, and does not eagerly project away any columns that are not present in
+ * the specified class.
+ *
* @group basic
* @since 1.6.0
*/
@Experimental
@InterfaceStability.Evolving
- def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, logicalPlan)
+ def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, planWithBarrier)
/**
* Converts this strongly typed collection of data to generic `DataFrame` with columns renamed.
@@ -524,7 +532,7 @@ class Dataset[T] private[sql](
*/
@Experimental
@InterfaceStability.Evolving
- def checkpoint(): Dataset[T] = checkpoint(eager = true)
+ def checkpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = true)
/**
* Returns a checkpointed version of this Dataset. Checkpointing can be used to truncate the
@@ -537,9 +545,52 @@ class Dataset[T] private[sql](
*/
@Experimental
@InterfaceStability.Evolving
- def checkpoint(eager: Boolean): Dataset[T] = {
+ def checkpoint(eager: Boolean): Dataset[T] = checkpoint(eager = eager, reliableCheckpoint = true)
+
+ /**
+ * Eagerly locally checkpoints a Dataset and returns the new Dataset. Checkpointing can be
+ * used to truncate the logical plan of this Dataset, which is especially useful in iterative
+ * algorithms where the plan may grow exponentially. Local checkpoints are written to executor
+ * storage and, while potentially faster, they are unreliable and may compromise job completion.
+ *
+ * @group basic
+ * @since 2.3.0
+ */
+ @Experimental
+ @InterfaceStability.Evolving
+ def localCheckpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = false)
+
+ /**
+ * Locally checkpoints a Dataset and returns the new Dataset. Checkpointing can be used to
+ * truncate the logical plan of this Dataset, which is especially useful in iterative algorithms
+ * where the plan may grow exponentially. Local checkpoints are written to executor storage and,
+ * while potentially faster, they are unreliable and may compromise job completion.
+ *
+ * @group basic
+ * @since 2.3.0
+ */
+ @Experimental
+ @InterfaceStability.Evolving
+ def localCheckpoint(eager: Boolean): Dataset[T] = checkpoint(
+ eager = eager,
+ reliableCheckpoint = false
+ )
+
+ /**
+ * Returns a checkpointed version of this Dataset.
+ *
+ * @param eager Whether to checkpoint this Dataset immediately
+ * @param reliableCheckpoint Whether to create a reliable checkpoint saved to files inside the
+ * checkpoint directory. If false, creates a local checkpoint using
+ * the caching subsystem.
+ */
+ private def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T] = {
val internalRdd = queryExecution.toRdd.map(_.copy())
- internalRdd.checkpoint()
+ if (reliableCheckpoint) {
+ internalRdd.checkpoint()
+ } else {
+ internalRdd.localCheckpoint()
+ }
if (eager) {
internalRdd.count()
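
A hedged sketch of the checkpoint variants side by side (`df` is assumed; the reliable variant additionally requires a checkpoint directory to be set):

{{{
spark.sparkContext.setCheckpointDir("/tmp/checkpoints")

val reliable  = df.checkpoint()                    // written to the checkpoint directory
val local     = df.localCheckpoint()               // eager, stored via the caching subsystem
val lazyLocal = df.localCheckpoint(eager = false)  // materialized on first action
}}}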
@@ -604,7 +655,7 @@ class Dataset[T] private[sql](
require(parsedDelay.milliseconds >= 0 && parsedDelay.months >= 0,
s"delay threshold ($delayThreshold) should not be negative.")
EliminateEventTimeWatermark(
- EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan))
+ EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, planWithBarrier))
}
/**
@@ -777,7 +828,7 @@ class Dataset[T] private[sql](
* @since 2.0.0
*/
def join(right: Dataset[_]): DataFrame = withPlan {
- Join(logicalPlan, right.logicalPlan, joinType = Inner, None)
+ Join(planWithBarrier, right.planWithBarrier, joinType = Inner, None)
}
/**
@@ -855,7 +906,7 @@ class Dataset[T] private[sql](
// Analyze the self join. The assumption is that the analyzer will disambiguate left vs right
// by creating a new instance for one of the branch.
val joined = sparkSession.sessionState.executePlan(
- Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None))
+ Join(planWithBarrier, right.planWithBarrier, joinType = JoinType(joinType), None))
.analyzed.asInstanceOf[Join]
withPlan {
@@ -916,7 +967,7 @@ class Dataset[T] private[sql](
// Trigger analysis so in the case of self-join, the analyzer will clone the plan.
// After the cloning, left and right side will have distinct expression ids.
val plan = withPlan(
- Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr)))
+ Join(planWithBarrier, right.planWithBarrier, JoinType(joinType), Some(joinExprs.expr)))
.queryExecution.analyzed.asInstanceOf[Join]
// If auto self join alias is disabled, return the plan.
@@ -925,8 +976,8 @@ class Dataset[T] private[sql](
}
// If left/right have no output set intersection, return the plan.
- val lanalyzed = withPlan(this.logicalPlan).queryExecution.analyzed
- val ranalyzed = withPlan(right.logicalPlan).queryExecution.analyzed
+ val lanalyzed = withPlan(this.planWithBarrier).queryExecution.analyzed
+ val ranalyzed = withPlan(right.planWithBarrier).queryExecution.analyzed
if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) {
return withPlan(plan)
}
@@ -958,7 +1009,7 @@ class Dataset[T] private[sql](
* @since 2.1.0
*/
def crossJoin(right: Dataset[_]): DataFrame = withPlan {
- Join(logicalPlan, right.logicalPlan, joinType = Cross, None)
+ Join(planWithBarrier, right.planWithBarrier, joinType = Cross, None)
}
/**
@@ -990,8 +1041,8 @@ class Dataset[T] private[sql](
// etc.
val joined = sparkSession.sessionState.executePlan(
Join(
- this.logicalPlan,
- other.logicalPlan,
+ this.planWithBarrier,
+ other.planWithBarrier,
JoinType(joinType),
Some(condition.expr))).analyzed.asInstanceOf[Join]
@@ -1212,7 +1263,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def as(alias: String): Dataset[T] = withTypedPlan {
- SubqueryAlias(alias, logicalPlan)
+ SubqueryAlias(alias, planWithBarrier)
}
/**
@@ -1250,7 +1301,7 @@ class Dataset[T] private[sql](
*/
@scala.annotation.varargs
def select(cols: Column*): DataFrame = withPlan {
- Project(cols.map(_.named), logicalPlan)
+ Project(cols.map(_.named), planWithBarrier)
}
/**
@@ -1305,8 +1356,8 @@ class Dataset[T] private[sql](
@InterfaceStability.Evolving
def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = {
implicit val encoder = c1.encoder
- val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil,
- logicalPlan)
+ val project = Project(c1.withInputType(exprEnc, planWithBarrier.output).named :: Nil,
+ planWithBarrier)
if (encoder.flat) {
new Dataset[U1](sparkSession, project, encoder)
@@ -1324,8 +1375,8 @@ class Dataset[T] private[sql](
protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
val encoders = columns.map(_.encoder)
val namedColumns =
- columns.map(_.withInputType(exprEnc, logicalPlan.output).named)
- val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan))
+ columns.map(_.withInputType(exprEnc, planWithBarrier.output).named)
+ val execution = new QueryExecution(sparkSession, Project(namedColumns, planWithBarrier))
new Dataset(sparkSession, execution, ExpressionEncoder.tuple(encoders))
}
@@ -1401,7 +1452,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def filter(condition: Column): Dataset[T] = withTypedPlan {
- Filter(condition.expr, logicalPlan)
+ Filter(condition.expr, planWithBarrier)
}
/**
@@ -1578,7 +1629,7 @@ class Dataset[T] private[sql](
@Experimental
@InterfaceStability.Evolving
def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = {
- val inputPlan = logicalPlan
+ val inputPlan = planWithBarrier
val withGroupingKey = AppendColumns(func, inputPlan)
val executed = sparkSession.sessionState.executePlan(withGroupingKey)
@@ -1724,7 +1775,7 @@ class Dataset[T] private[sql](
* @since 2.0.0
*/
def limit(n: Int): Dataset[T] = withTypedPlan {
- Limit(Literal(n), logicalPlan)
+ Limit(Literal(n), planWithBarrier)
}
/**
@@ -1774,7 +1825,7 @@ class Dataset[T] private[sql](
def union(other: Dataset[T]): Dataset[T] = withSetOperator {
// This breaks caching, but it's usually ok because it addresses a very specific use case:
// using union to union many files or partitions.
- CombineUnions(Union(logicalPlan, other.logicalPlan))
+ CombineUnions(Union(logicalPlan, other.logicalPlan)).mapChildren(AnalysisBarrier)
}
/**
@@ -1833,7 +1884,7 @@ class Dataset[T] private[sql](
// This breaks caching, but it's usually ok because it addresses a very specific use case:
// using union to union many files or partitions.
- CombineUnions(Union(logicalPlan, rightChild))
+ CombineUnions(Union(logicalPlan, rightChild)).mapChildren(AnalysisBarrier)
}
/**
@@ -1847,7 +1898,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def intersect(other: Dataset[T]): Dataset[T] = withSetOperator {
- Intersect(logicalPlan, other.logicalPlan)
+ Intersect(planWithBarrier, other.planWithBarrier)
}
/**
@@ -1861,7 +1912,7 @@ class Dataset[T] private[sql](
* @since 2.0.0
*/
def except(other: Dataset[T]): Dataset[T] = withSetOperator {
- Except(logicalPlan, other.logicalPlan)
+ Except(planWithBarrier, other.planWithBarrier)
}
/**
@@ -1912,7 +1963,7 @@ class Dataset[T] private[sql](
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = {
withTypedPlan {
- Sample(0.0, fraction, withReplacement, seed, logicalPlan)
+ Sample(0.0, fraction, withReplacement, seed, planWithBarrier)
}
}
@@ -1954,15 +2005,15 @@ class Dataset[T] private[sql](
// overlapping splits. To prevent this, we explicitly sort each input partition to make the
// ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out
// from the sort order.
- val sortOrder = logicalPlan.output
+ val sortOrder = planWithBarrier.output
.filter(attr => RowOrdering.isOrderable(attr.dataType))
.map(SortOrder(_, Ascending))
val plan = if (sortOrder.nonEmpty) {
- Sort(sortOrder, global = false, logicalPlan)
+ Sort(sortOrder, global = false, planWithBarrier)
} else {
// SPARK-12662: If sort order is empty, we materialize the dataset to guarantee determinism
cache()
- logicalPlan
+ planWithBarrier
}
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
@@ -2045,8 +2096,8 @@ class Dataset[T] private[sql](
val generator = UserDefinedGenerator(elementSchema, rowFunction, input.map(_.expr))
withPlan {
- Generate(generator, join = true, outer = false,
- qualifier = None, generatorOutput = Nil, logicalPlan)
+ Generate(generator, unrequiredChildIndex = Nil, outer = false,
+ qualifier = None, generatorOutput = Nil, planWithBarrier)
}
}
@@ -2086,8 +2137,8 @@ class Dataset[T] private[sql](
val generator = UserDefinedGenerator(elementSchema, rowFunction, apply(inputColumn).expr :: Nil)
withPlan {
- Generate(generator, join = true, outer = false,
- qualifier = None, generatorOutput = Nil, logicalPlan)
+ Generate(generator, unrequiredChildIndex = Nil, outer = false,
+ qualifier = None, generatorOutput = Nil, planWithBarrier)
}
}
@@ -2235,7 +2286,7 @@ class Dataset[T] private[sql](
u.name, sparkSession.sessionState.analyzer.resolver).getOrElse(u)
case Column(expr: Expression) => expr
}
- val attrs = this.logicalPlan.output
+ val attrs = this.planWithBarrier.output
val colsAfterDrop = attrs.filter { attr =>
attr != expression
}.map(attr => Column(attr))
@@ -2283,7 +2334,7 @@ class Dataset[T] private[sql](
}
cols
}
- Deduplicate(groupCols, logicalPlan)
+ Deduplicate(groupCols, planWithBarrier)
}
/**
@@ -2465,7 +2516,7 @@ class Dataset[T] private[sql](
@Experimental
@InterfaceStability.Evolving
def filter(func: T => Boolean): Dataset[T] = {
- withTypedPlan(TypedFilter(func, logicalPlan))
+ withTypedPlan(TypedFilter(func, planWithBarrier))
}
/**
@@ -2479,7 +2530,7 @@ class Dataset[T] private[sql](
@Experimental
@InterfaceStability.Evolving
def filter(func: FilterFunction[T]): Dataset[T] = {
- withTypedPlan(TypedFilter(func, logicalPlan))
+ withTypedPlan(TypedFilter(func, planWithBarrier))
}
/**
@@ -2493,7 +2544,7 @@ class Dataset[T] private[sql](
@Experimental
@InterfaceStability.Evolving
def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan {
- MapElements[T, U](func, logicalPlan)
+ MapElements[T, U](func, planWithBarrier)
}
/**
@@ -2508,7 +2559,7 @@ class Dataset[T] private[sql](
@InterfaceStability.Evolving
def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
implicit val uEnc = encoder
- withTypedPlan(MapElements[T, U](func, logicalPlan))
+ withTypedPlan(MapElements[T, U](func, planWithBarrier))
}
/**
@@ -2524,7 +2575,7 @@ class Dataset[T] private[sql](
def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
new Dataset[U](
sparkSession,
- MapPartitions[T, U](func, logicalPlan),
+ MapPartitions[T, U](func, planWithBarrier),
implicitly[Encoder[U]])
}
@@ -2555,7 +2606,7 @@ class Dataset[T] private[sql](
val rowEncoder = encoder.asInstanceOf[ExpressionEncoder[Row]]
Dataset.ofRows(
sparkSession,
- MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, logicalPlan))
+ MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, planWithBarrier))
}
/**
@@ -2719,7 +2770,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def repartition(numPartitions: Int): Dataset[T] = withTypedPlan {
- Repartition(numPartitions, shuffle = true, logicalPlan)
+ Repartition(numPartitions, shuffle = true, planWithBarrier)
}
/**
@@ -2732,8 +2783,18 @@ class Dataset[T] private[sql](
* @since 2.0.0
*/
@scala.annotation.varargs
- def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = withTypedPlan {
- RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions)
+ def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = {
+ // The underlying `LogicalPlan` operator special-cases all-`SortOrder` arguments.
+ // However, we don't want to complicate the semantics of this API method.
+ // Instead, let's give users a friendly error message, pointing them to the new method.
+ val sortOrders = partitionExprs.filter(_.expr.isInstanceOf[SortOrder])
+ if (sortOrders.nonEmpty) throw new IllegalArgumentException(
+ s"""Invalid partitionExprs specified: $sortOrders
+ |For range partitioning use repartitionByRange(...) instead.
+ """.stripMargin)
+ withTypedPlan {
+ RepartitionByExpression(partitionExprs.map(_.expr), planWithBarrier, numPartitions)
+ }
}
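A minimal usage sketch of the new behavior (the session, DataFrame, and column names are illustrative, not part of this patch): passing a sort expression to `repartition` now fails fast, and the range-partitioning variant introduced below is the intended replacement.

    // Hedged sketch, assuming a SparkSession `spark` and a DataFrame `df` with a column "ts".
    import org.apache.spark.sql.functions.col

    // Now throws IllegalArgumentException: sort expressions are no longer accepted here.
    // df.repartition(8, col("ts").desc)

    // Use the new range-partitioning API instead (added just below).
    val ranged = df.repartitionByRange(8, col("ts").desc)
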
/**
@@ -2747,9 +2808,46 @@ class Dataset[T] private[sql](
* @since 2.0.0
*/
@scala.annotation.varargs
- def repartition(partitionExprs: Column*): Dataset[T] = withTypedPlan {
- RepartitionByExpression(
- partitionExprs.map(_.expr), logicalPlan, sparkSession.sessionState.conf.numShufflePartitions)
+ def repartition(partitionExprs: Column*): Dataset[T] = {
+ repartition(sparkSession.sessionState.conf.numShufflePartitions, partitionExprs: _*)
+ }
+
+ /**
+ * Returns a new Dataset partitioned by the given partitioning expressions into
+ * `numPartitions`. The resulting Dataset is range partitioned.
+ *
+ * At least one partition-by expression must be specified.
+ * When no explicit sort order is specified, "ascending nulls first" is assumed.
+ *
+ * @group typedrel
+ * @since 2.3.0
+ */
+ @scala.annotation.varargs
+ def repartitionByRange(numPartitions: Int, partitionExprs: Column*): Dataset[T] = {
+ require(partitionExprs.nonEmpty, "At least one partition-by expression must be specified.")
+ val sortOrder: Seq[SortOrder] = partitionExprs.map(_.expr match {
+ case expr: SortOrder => expr
+ case expr: Expression => SortOrder(expr, Ascending)
+ })
+ withTypedPlan {
+ RepartitionByExpression(sortOrder, planWithBarrier, numPartitions)
+ }
+ }
+
+ /**
+ * Returns a new Dataset partitioned by the given partitioning expressions, using
+ * `spark.sql.shuffle.partitions` as number of partitions.
+ * The resulting Dataset is range partitioned.
+ *
+ * At least one partition-by expression must be specified.
+ * When no explicit sort order is specified, "ascending nulls first" is assumed.
+ *
+ * @group typedrel
+ * @since 2.3.0
+ */
+ @scala.annotation.varargs
+ def repartitionByRange(partitionExprs: Column*): Dataset[T] = {
+ repartitionByRange(sparkSession.sessionState.conf.numShufflePartitions, partitionExprs: _*)
}
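As a usage sketch of the two new overloads (column and DataFrame names are illustrative):

    // Hedged example, assuming a DataFrame `events` with columns "userId" and "ts".
    import org.apache.spark.sql.functions.col

    // Range partition into 16 partitions; bare columns default to "ascending nulls first".
    val byUser = events.repartitionByRange(16, col("userId"))

    // Explicit sort order per expression; without a partition count this overload falls back
    // to spark.sql.shuffle.partitions.
    val byTime = events.repartitionByRange(col("ts").desc)
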
/**
@@ -2770,7 +2868,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def coalesce(numPartitions: Int): Dataset[T] = withTypedPlan {
- Repartition(numPartitions, shuffle = false, logicalPlan)
+ Repartition(numPartitions, shuffle = false, planWithBarrier)
}
/**
@@ -2853,7 +2951,7 @@ class Dataset[T] private[sql](
// Represents the `QueryExecution` used to produce the content of the Dataset as an `RDD`.
@transient private lazy val rddQueryExecution: QueryExecution = {
- val deserialized = CatalystSerde.deserialize[T](logicalPlan)
+ val deserialized = CatalystSerde.deserialize[T](planWithBarrier)
sparkSession.sessionState.executePlan(deserialized)
}
@@ -2979,7 +3077,7 @@ class Dataset[T] private[sql](
comment = None,
properties = Map.empty,
originalText = None,
- child = logicalPlan,
+ child = planWithBarrier,
allowExisting = false,
replace = replace,
viewType = viewType)
@@ -3179,7 +3277,7 @@ class Dataset[T] private[sql](
}
}
withTypedPlan {
- Sort(sortOrder, global = global, logicalPlan)
+ Sort(sortOrder, global = global, planWithBarrier)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 69054fa8a9648..11bfaa0a726a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -742,7 +742,10 @@ class SparkSession private(
private[sql] def applySchemaToPythonRDD(
rdd: RDD[Array[Any]],
schema: StructType): DataFrame = {
- val rowRdd = rdd.map(r => python.EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow])
+ val rowRdd = rdd.mapPartitions { iter =>
+ val fromJava = python.EvaluatePython.makeFromJava(schema)
+ iter.map(r => fromJava(r).asInstanceOf[InternalRow])
+ }
internalCreateDataFrame(rowRdd, schema)
}
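The switch from `rdd.map` to `rdd.mapPartitions` lets the Java-to-Catalyst converter be built once per partition and reused for every record. A generic, hedged sketch of that pattern (the converter here is a stand-in, not a Spark internal):

    // Hedged, generic sketch: build a costly converter once per partition, then reuse it.
    import org.apache.spark.sql.SparkSession

    object PerPartitionConverter {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()
        val raw = spark.sparkContext.parallelize(Seq("1", "2", "3", "4"), numSlices = 2)

        val parsed = raw.mapPartitions { iter =>
          // Stand-in for an expensive, schema-driven converter.
          val convert: String => Int = _.toInt
          iter.map(convert)              // applied per record, constructed once per partition
        }

        parsed.collect().foreach(println)
        spark.stop()
      }
    }
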
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index 3ff476147b8b7..f94baef39dfad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -17,12 +17,13 @@
package org.apache.spark.sql
-import java.lang.reflect.{ParameterizedType, Type}
+import java.lang.reflect.ParameterizedType
import scala.reflect.runtime.universe.TypeTag
import scala.util.Try
import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.internal.Logging
import org.apache.spark.sql.api.java._
import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
@@ -41,8 +42,6 @@ import org.apache.spark.util.Utils
* spark.udf
* }}}
*
- * @note The user-defined functions must be deterministic.
- *
* @since 1.3.0
*/
@InterfaceStability.Stable
@@ -58,6 +57,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
| pythonIncludes: ${udf.func.pythonIncludes}
| pythonExec: ${udf.func.pythonExec}
| dataType: ${udf.dataType}
+ | pythonEvalType: ${PythonEvalType.toString(udf.pythonEvalType)}
+ | udfDeterministic: ${udf.udfDeterministic}
""".stripMargin)
functionRegistry.createOrReplaceTempFunction(name, udf.builder)
@@ -109,29 +110,29 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
/* register 0-22 were generated by this script
- (0 to 22).map { x =>
+ (0 to 22).foreach { x =>
val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"})
- val typeTags = (1 to x).map(i => s"A${i}: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _)
+ val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _)
val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor[A$i].dataType :: $s"})
println(s"""
- /**
- * Registers a deterministic Scala closure of ${x} arguments as user-defined function (UDF).
- * @tparam RT return type of UDF.
- * @since 1.3.0
- */
- def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = {
- val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
- val inputTypes = Try($inputTypes).toOption
- def builder(e: Seq[Expression]) = if (e.length == $x) {
- ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true)
- } else {
- throw new AnalysisException("Invalid number of arguments for function " + name +
- ". Expected: $x; Found: " + e.length)
- }
- functionRegistry.createOrReplaceTempFunction(name, builder)
- val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name)
- if (nullable) udf else udf.asNonNullable()
- }""")
+ |/**
+ | * Registers a deterministic Scala closure of $x arguments as user-defined function (UDF).
+ | * @tparam RT return type of UDF.
+ | * @since 1.3.0
+ | */
+ |def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = {
+ | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
+ | val inputTypes = Try($inputTypes).toOption
+ | def builder(e: Seq[Expression]) = if (e.length == $x) {
+ | ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true)
+ | } else {
+ | throw new AnalysisException("Invalid number of arguments for function " + name +
+ | ". Expected: $x; Found: " + e.length)
+ | }
+ | functionRegistry.createOrReplaceTempFunction(name, builder)
+ | val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name)
+ | if (nullable) udf else udf.asNonNullable()
+ |}""".stripMargin)
}
(0 to 22).foreach { i =>
@@ -143,7 +144,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
val funcCall = if (i == 0) "() => func" else "func"
println(s"""
|/**
- | * Register a user-defined function with ${i} arguments.
+ | * Register a deterministic Java UDF$i instance as user-defined function (UDF).
| * @since $version
| */
|def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = {
@@ -688,7 +689,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}
/**
- * Register a user-defined function with 0 arguments.
+ * Register a deterministic Java UDF0 instance as user-defined function (UDF).
* @since 2.3.0
*/
def register(name: String, f: UDF0[_], returnType: DataType): Unit = {
@@ -703,7 +704,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}
/**
- * Register a user-defined function with 1 arguments.
+ * Register a deterministic Java UDF1 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = {
@@ -718,7 +719,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}
/**
- * Register a user-defined function with 2 arguments.
+ * Register a deterministic Java UDF2 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = {
@@ -733,7 +734,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}
/**
- * Register a user-defined function with 3 arguments.
+ * Register a deterministic Java UDF3 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = {
@@ -748,7 +749,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}
/**
- * Register a user-defined function with 4 arguments.
+ * Register a deterministic Java UDF4 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = {
@@ -763,7 +764,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}
/**
- * Register a user-defined function with 5 arguments.
+ * Register a deterministic Java UDF5 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = {
@@ -778,7 +779,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}
/**
- * Register a user-defined function with 6 arguments.
+ * Register a deterministic Java UDF6 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = {
@@ -793,7 +794,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}
/**
- * Register a user-defined function with 7 arguments.
+ * Register a deterministic Java UDF7 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = {
@@ -808,7 +809,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}
/**
- * Register a user-defined function with 8 arguments.
+ * Register a deterministic Java UDF8 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
@@ -823,7 +824,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}
/**
- * Register a user-defined function with 9 arguments.
+ * Register a deterministic Java UDF9 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
@@ -838,7 +839,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}
/**
- * Register a user-defined function with 10 arguments.
+ * Register a deterministic Java UDF10 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
@@ -853,7 +854,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}
/**
- * Register a user-defined function with 11 arguments.
+ * Register a deterministic Java UDF11 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
@@ -868,7 +869,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}
/**
- * Register a user-defined function with 12 arguments.
+ * Register a deterministic Java UDF12 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
@@ -883,7 +884,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}
/**
- * Register a user-defined function with 13 arguments.
+ * Register a deterministic Java UDF13 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
@@ -898,7 +899,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}
/**
- * Register a user-defined function with 14 arguments.
+ * Register a deterministic Java UDF14 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
@@ -913,7 +914,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}
/**
- * Register a user-defined function with 15 arguments.
+ * Register a deterministic Java UDF15 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
@@ -928,7 +929,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}
/**
- * Register a user-defined function with 16 arguments.
+ * Register a deterministic Java UDF16 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
@@ -943,7 +944,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}
/**
- * Register a user-defined function with 17 arguments.
+ * Register a deterministic Java UDF17 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
@@ -958,7 +959,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}
/**
- * Register a user-defined function with 18 arguments.
+ * Register a deterministic Java UDF18 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
@@ -973,7 +974,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}
/**
- * Register a user-defined function with 19 arguments.
+ * Register a deterministic Java UDF19 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
@@ -988,7 +989,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}
/**
- * Register a user-defined function with 20 arguments.
+ * Register a deterministic Java UDF20 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
@@ -1003,7 +1004,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}
/**
- * Register a user-defined function with 21 arguments.
+ * Register a deterministic Java UDF21 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
@@ -1018,7 +1019,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}
/**
- * Register a user-defined function with 22 arguments.
+ * Register a deterministic Java UDF22 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index 9fd76865b5a95..c95212aaae00c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -26,7 +26,7 @@ import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
import org.apache.spark.storage.StorageLevel
@@ -94,14 +94,13 @@ class CacheManager extends Logging {
logWarning("Asked to cache already cached data.")
} else {
val sparkSession = query.sparkSession
- cachedData.add(CachedData(
- planToCache,
- InMemoryRelation(
- sparkSession.sessionState.conf.useCompression,
- sparkSession.sessionState.conf.columnBatchSize,
- storageLevel,
- sparkSession.sessionState.executePlan(planToCache).executedPlan,
- tableName)))
+ val inMemoryRelation = InMemoryRelation(
+ sparkSession.sessionState.conf.useCompression,
+ sparkSession.sessionState.conf.columnBatchSize, storageLevel,
+ sparkSession.sessionState.executePlan(planToCache).executedPlan,
+ tableName,
+ planToCache.stats)
+ cachedData.add(CachedData(planToCache, inMemoryRelation))
}
}
@@ -148,7 +147,8 @@ class CacheManager extends Logging {
batchSize = cd.cachedRepresentation.batchSize,
storageLevel = cd.cachedRepresentation.storageLevel,
child = spark.sessionState.executePlan(cd.plan).executedPlan,
- tableName = cd.cachedRepresentation.tableName)
+ tableName = cd.cachedRepresentation.tableName,
+ statsOfPlanToCache = cd.plan.stats)
needToRecache += cd.copy(cachedRepresentation = newCache)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
index a9bfb634fbdea..5617046e1396e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
@@ -20,13 +20,13 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnVector}
import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
/**
* Helper trait for abstracting scan functionality using
- * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]es.
+ * [[ColumnarBatch]]es.
*/
private[sql] trait ColumnarBatchScan extends CodegenSupport {
@@ -68,30 +68,26 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
*/
// TODO: return ColumnarBatch.Rows instead
override protected def doProduce(ctx: CodegenContext): String = {
- val input = ctx.freshName("input")
// PhysicalRDD always just has one input
- ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
+ val input = ctx.addMutableState("scala.collection.Iterator", "input",
+ v => s"$v = inputs[0];")
// metrics
val numOutputRows = metricTerm(ctx, "numOutputRows")
val scanTimeMetric = metricTerm(ctx, "scanTime")
- val scanTimeTotalNs = ctx.freshName("scanTime")
- ctx.addMutableState(ctx.JAVA_LONG, scanTimeTotalNs, s"$scanTimeTotalNs = 0;")
+ val scanTimeTotalNs = ctx.addMutableState(ctx.JAVA_LONG, "scanTime") // init as scanTime = 0
val columnarBatchClz = classOf[ColumnarBatch].getName
- val batch = ctx.freshName("batch")
- ctx.addMutableState(columnarBatchClz, batch, s"$batch = null;")
+ val batch = ctx.addMutableState(columnarBatchClz, "batch")
- val idx = ctx.freshName("batchIdx")
- ctx.addMutableState(ctx.JAVA_INT, idx, s"$idx = 0;")
- val colVars = output.indices.map(i => ctx.freshName("colInstance" + i))
+ val idx = ctx.addMutableState(ctx.JAVA_INT, "batchIdx") // init as batchIdx = 0
val columnVectorClzs = vectorTypes.getOrElse(
- Seq.fill(colVars.size)(classOf[ColumnVector].getName))
- val columnAssigns = colVars.zip(columnVectorClzs).zipWithIndex.map {
- case ((name, columnVectorClz), i) =>
- ctx.addMutableState(columnVectorClz, name, s"$name = null;")
- s"$name = ($columnVectorClz) $batch.column($i);"
- }
+ Seq.fill(output.indices.size)(classOf[ColumnVector].getName))
+ val (colVars, columnAssigns) = columnVectorClzs.zipWithIndex.map {
+ case (columnVectorClz, i) =>
+ val name = ctx.addMutableState(columnVectorClz, s"colInstance$i")
+ (name, s"$name = ($columnVectorClz) $batch.column($i);")
+ }.unzip
val nextBatch = ctx.freshName("nextBatch")
val nextBatchFuncName = ctx.addNewFunction(nextBatch,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 1297126b9f47e..bd286abd7b027 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -69,7 +69,7 @@ trait DataSourceScanExec extends LeafExecNode with CodegenSupport {
* Shorthand for calling redactString() without specifying redacting rules
*/
private def redact(text: String): String = {
- Utils.redact(SparkSession.getActiveSession.map(_.sparkContext.conf).orNull, text)
+ Utils.redact(sqlContext.sessionState.conf.stringRedationPattern, text)
}
}
@@ -110,8 +110,7 @@ case class RowDataSourceScanExec(
override protected def doProduce(ctx: CodegenContext): String = {
val numOutputRows = metricTerm(ctx, "numOutputRows")
// PhysicalRDD always just has one input
- val input = ctx.freshName("input")
- ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
+ val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];")
val exprRows = output.zipWithIndex.map{ case (a, i) =>
BoundReference(i, a.dataType, a.nullable)
}
@@ -353,8 +352,7 @@ case class FileSourceScanExec(
}
val numOutputRows = metricTerm(ctx, "numOutputRows")
// PhysicalRDD always just has one input
- val input = ctx.freshName("input")
- ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
+ val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];")
val row = ctx.freshName("row")
ctx.INPUT_ROW = row
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index e1562befe14f9..0c2c4a1a9100d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -47,8 +47,7 @@ private[execution] sealed case class LazyIterator(func: () => TraversableOnce[In
* terminate().
*
* @param generator the generator expression
- * @param join when true, each output row is implicitly joined with the input tuple that produced
- * it.
+ * @param requiredChildOutput required attributes from child's output
* @param outer when true, each input row will be output at least once, even if the output of the
* given `generator` is empty.
* @param generatorOutput the qualified output attributes of the generator of this node, which
@@ -57,19 +56,13 @@ private[execution] sealed case class LazyIterator(func: () => TraversableOnce[In
*/
case class GenerateExec(
generator: Generator,
- join: Boolean,
+ requiredChildOutput: Seq[Attribute],
outer: Boolean,
generatorOutput: Seq[Attribute],
child: SparkPlan)
extends UnaryExecNode with CodegenSupport {
- override def output: Seq[Attribute] = {
- if (join) {
- child.output ++ generatorOutput
- } else {
- generatorOutput
- }
- }
+ override def output: Seq[Attribute] = requiredChildOutput ++ generatorOutput
override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -85,11 +78,19 @@ case class GenerateExec(
val numOutputRows = longMetric("numOutputRows")
child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
val generatorNullRow = new GenericInternalRow(generator.elementSchema.length)
- val rows = if (join) {
+ val rows = if (requiredChildOutput.nonEmpty) {
+
+ val pruneChildForResult: InternalRow => InternalRow =
+ if (child.outputSet == AttributeSet(requiredChildOutput)) {
+ identity
+ } else {
+ UnsafeProjection.create(requiredChildOutput, child.output)
+ }
+
val joinedRow = new JoinedRow
iter.flatMap { row =>
- // we should always set the left (child output)
- joinedRow.withLeft(row)
+ // we should always set the left (required child output)
+ joinedRow.withLeft(pruneChildForResult(row))
val outputRows = boundGenerator.eval(row)
if (outer && outputRows.isEmpty) {
joinedRow.withRight(generatorNullRow) :: Nil
@@ -136,7 +137,7 @@ case class GenerateExec(
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
// Add input rows to the values when we are joining
- val values = if (join) {
+ val values = if (requiredChildOutput.nonEmpty) {
input
} else {
Seq.empty
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index f404621399cea..8bfe3eff0c3b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand}
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
-import org.apache.spark.sql.execution.joins.ReorderJoinPredicates
import org.apache.spark.sql.types.{BinaryType, DateType, DecimalType, TimestampType, _}
import org.apache.spark.util.Utils
@@ -104,7 +103,6 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
protected def preparations: Seq[Rule[SparkPlan]] = Seq(
python.ExtractPythonUDFs,
PlanSubqueries(sparkSession),
- new ReorderJoinPredicates,
EnsureRequirements(sparkSession.sessionState.conf),
CollapseCodegenStages(sparkSession.sessionState.conf),
ReuseExchange(sparkSession.sessionState.conf),
@@ -196,13 +194,13 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
}
}
- def simpleString: String = {
+ def simpleString: String = withRedaction {
s"""== Physical Plan ==
|${stringOrError(executedPlan.treeString(verbose = false))}
""".stripMargin.trim
}
- override def toString: String = {
+ override def toString: String = withRedaction {
def output = Utils.truncatedString(
analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}"), ", ")
val analyzedPlan = Seq(
@@ -221,7 +219,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
""".stripMargin.trim
}
- def stringWithStats: String = {
+ def stringWithStats: String = withRedaction {
// trigger to compute stats for logical plans
optimizedPlan.stats
@@ -233,6 +231,13 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
""".stripMargin.trim
}
+ /**
+ * Redact the sensitive information in the given string.
+ */
+ private def withRedaction(message: String): String = {
+ Utils.redact(sparkSession.sessionState.conf.stringRedationPattern, message)
+ }
+
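With `simpleString`, `toString`, and `stringWithStats` wrapped in `withRedaction`, plan strings are filtered through the configured redaction regex. A hedged sketch (assuming `spark.sql.redaction.string.regex` is the config consulted here; the column name is illustrative):

    // Hedged sketch: fragments of the plan string matching the regex are expected to be masked.
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .master("local[*]")
      .config("spark.sql.redaction.string.regex", "(?i)secret[^,\\s]*")
      .getOrCreate()

    val df = spark.range(10).selectExpr("id AS secret_id")
    // The plan output should show the matching fragment ("secret_id") redacted.
    println(df.queryExecution.toString)
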
/** A special namespace for commands that can be used to debug query execution. */
// scalastyle:off
object debug {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index c0e21343ae623..ef1bb1c2a4468 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -133,20 +133,18 @@ case class SortExec(
override def needStopCheck: Boolean = false
override protected def doProduce(ctx: CodegenContext): String = {
- val needToSort = ctx.freshName("needToSort")
- ctx.addMutableState(ctx.JAVA_BOOLEAN, needToSort, s"$needToSort = true;")
+ val needToSort = ctx.addMutableState(ctx.JAVA_BOOLEAN, "needToSort", v => s"$v = true;")
// Initialize the class member variables. This includes the instance of the Sorter and
// the iterator to return sorted rows.
val thisPlan = ctx.addReferenceObj("plan", this)
- sorterVariable = ctx.freshName("sorter")
- ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, sorterVariable,
- s"$sorterVariable = $thisPlan.createSorter();")
- val metrics = ctx.freshName("metrics")
- ctx.addMutableState(classOf[TaskMetrics].getName, metrics,
- s"$metrics = org.apache.spark.TaskContext.get().taskMetrics();")
- val sortedIterator = ctx.freshName("sortedIter")
- ctx.addMutableState("scala.collection.Iterator", sortedIterator, "")
+ // Inline mutable state since there are not many Sort operations in a task
+ sorterVariable = ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, "sorter",
+ v => s"$v = $thisPlan.createSorter();", forceInline = true)
+ val metrics = ctx.addMutableState(classOf[TaskMetrics].getName, "metrics",
+ v => s"$v = org.apache.spark.TaskContext.get().taskMetrics();", forceInline = true)
+ val sortedIterator = ctx.addMutableState("scala.collection.Iterator", "sortedIter",
+ forceInline = true)
val addToSorter = ctx.freshName("addToSorter")
val addToSorterFuncName = ctx.addNewFunction(addToSorter,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 657b265260135..398758a3331b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -23,7 +23,7 @@ import scala.collection.mutable.ArrayBuffer
import scala.concurrent.ExecutionContext
import org.codehaus.commons.compiler.CompileException
-import org.codehaus.janino.JaninoRuntimeException
+import org.codehaus.janino.InternalCompilerException
import org.apache.spark.{broadcast, SparkEnv}
import org.apache.spark.internal.Logging
@@ -94,7 +94,21 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/** Specifies how data is partitioned across different nodes in the cluster. */
def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH!
- /** Specifies any partition requirements on the input data for this operator. */
+ /**
+ * Specifies the data distribution requirements of all the children for this operator. By default
+ * it's [[UnspecifiedDistribution]] for each child, which means each child can have any
+ * distribution.
+ *
+ * If an operator overrides this method and specifies distribution requirements (excluding
+ * [[UnspecifiedDistribution]] and [[BroadcastDistribution]]) for more than one child, Spark
+ * guarantees that the outputs of these children will have the same number of partitions, so
+ * that the operator can safely zip partitions of these children's result RDDs. Some operators
+ * can leverage this guarantee to satisfy useful requirements, e.g., a non-broadcast join can
+ * specify HashClusteredDistribution(a,b) for its left child and HashClusteredDistribution(c,d)
+ * for its right child; it is then guaranteed that the left and right children are
+ * co-partitioned by a,b/c,d, i.e., tuples with the same key values sit in partitions with the
+ * same index, e.g., (a=1,b=2) and (c=1,d=2) both land in the second partition of each child.
+ */
def requiredChildDistribution: Seq[Distribution] =
Seq.fill(children.size)(UnspecifiedDistribution)
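For a binary operator that wants its children co-partitioned as described above, the override could look like the following sketch (a hypothetical operator, assuming the HashClusteredDistribution mentioned in the doc; not part of this patch):

    // Hedged sketch: require both children to be hash-clustered on compatible keys so that
    // their output partitions can be zipped index-by-index.
    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
    import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClusteredDistribution}
    import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}

    case class ZipLikeExec(
        leftKeys: Seq[Expression],
        rightKeys: Seq[Expression],
        left: SparkPlan,
        right: SparkPlan) extends BinaryExecNode {

      override def output: Seq[Attribute] = left.output ++ right.output

      override def requiredChildDistribution: Seq[Distribution] =
        HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil

      override protected def doExecute(): RDD[InternalRow] =
        // Placeholder logic: simply concatenate co-indexed partitions.
        left.execute().zipPartitions(right.execute()) { (l, r) => l ++ r }
    }
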
@@ -337,8 +351,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
if (buf.isEmpty) {
numPartsToTry = partsScanned * limitScaleUpFactor
} else {
- // the left side of max is >=1 whenever partsScanned >= 2
- numPartsToTry = Math.max((1.5 * n * partsScanned / buf.size).toInt - partsScanned, 1)
+ val left = n - buf.size
+ // As left > 0, numPartsToTry is always >= 1
+ numPartsToTry = Math.ceil(1.5 * left * partsScanned / buf.size).toInt
numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor)
}
}
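The revised estimate scales by the rows still missing rather than by the full limit. A worked example with illustrative numbers:

    // Hedged, illustrative arithmetic for the new estimate in executeTake:
    // take(100) has scanned 4 partitions and collected 40 rows, so 60 rows are still needed.
    val n = 100
    val partsScanned = 4
    val bufSize = 40
    val left = n - bufSize                                               // 60
    val estimate = Math.ceil(1.5 * left * partsScanned / bufSize).toInt  // ceil(9.0) = 9
    // The previous formula, max((1.5 * n * partsScanned / buf.size).toInt - partsScanned, 1),
    // gave max(15 - 4, 1) = 11 for the same inputs, scaling by the full limit rather than
    // by the rows still to be fetched.
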
@@ -385,7 +400,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
try {
GeneratePredicate.generate(expression, inputSchema)
} catch {
- case _ @ (_: JaninoRuntimeException | _: CompileException) if codeGenFallBack =>
+ case _ @ (_: InternalCompilerException | _: CompileException) if codeGenFallBack =>
genInterpretedPredicate(expression, inputSchema)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index 29b584b55972c..d3cfd2a1ffbf2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -383,16 +383,19 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
* {{{
* CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name
* USING table_provider
- * [OPTIONS table_property_list]
- * [PARTITIONED BY (col_name, col_name, ...)]
- * [CLUSTERED BY (col_name, col_name, ...)
- * [SORTED BY (col_name [ASC|DESC], ...)]
- * INTO num_buckets BUCKETS
- * ]
- * [LOCATION path]
- * [COMMENT table_comment]
- * [TBLPROPERTIES (property_name=property_value, ...)]
+ * create_table_clauses
* [[AS] select_statement];
+ *
+ * create_table_clauses (order insensitive):
+ * [OPTIONS table_property_list]
+ * [PARTITIONED BY (col_name, col_name, ...)]
+ * [CLUSTERED BY (col_name, col_name, ...)
+ * [SORTED BY (col_name [ASC|DESC], ...)]
+ * INTO num_buckets BUCKETS
+ * ]
+ * [LOCATION path]
+ * [COMMENT table_comment]
+ * [TBLPROPERTIES (property_name=property_value, ...)]
* }}}
*/
override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) {
@@ -400,6 +403,14 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
if (external) {
operationNotAllowed("CREATE EXTERNAL TABLE ... USING", ctx)
}
+
+ checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx)
+ checkDuplicateClauses(ctx.OPTIONS, "OPTIONS", ctx)
+ checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx)
+ checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx)
+ checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx)
+ checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx)
+
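With the clause list now order-insensitive and guarded by `checkDuplicateClauses`, clauses of `CREATE TABLE ... USING` may appear in any order but at most once each. A hedged sketch (table names and properties are illustrative):

    // Hedged sketch, assuming a SparkSession `spark`.
    // Clauses can now come in any order after USING:
    spark.sql("""
      CREATE TABLE t1 (id INT, name STRING)
      USING parquet
      COMMENT 'order-insensitive clauses'
      TBLPROPERTIES ('k' = 'v')
      PARTITIONED BY (id)
    """)

    // Repeating a clause is now rejected with a duplicate-clause parse error:
    // spark.sql("CREATE TABLE t2 (id INT) USING parquet " +
    //   "TBLPROPERTIES ('a' = '1') TBLPROPERTIES ('b' = '2')")
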
val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty)
val provider = ctx.tableProvider.qualifiedName.getText
val schema = Option(ctx.colTypeList()).map(createSchema)
@@ -408,9 +419,9 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
.map(visitIdentifierList(_).toArray)
.getOrElse(Array.empty[String])
val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty)
- val bucketSpec = Option(ctx.bucketSpec()).map(visitBucketSpec)
+ val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec)
- val location = Option(ctx.locationSpec).map(visitLocationSpec)
+ val location = ctx.locationSpec.asScala.headOption.map(visitLocationSpec)
val storage = DataSource.buildStorageFormatFromOptions(options)
if (location.isDefined && storage.locationUri.isDefined) {
@@ -1087,13 +1098,16 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
* {{{
* CREATE [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name
* [(col1[:] data_type [COMMENT col_comment], ...)]
- * [COMMENT table_comment]
- * [PARTITIONED BY (col2[:] data_type [COMMENT col_comment], ...)]
- * [ROW FORMAT row_format]
- * [STORED AS file_format]
- * [LOCATION path]
- * [TBLPROPERTIES (property_name=property_value, ...)]
+ * create_table_clauses
* [AS select_statement];
+ *
+ * create_table_clauses (order insensitive):
+ * [COMMENT table_comment]
+ * [PARTITIONED BY (col2[:] data_type [COMMENT col_comment], ...)]
+ * [ROW FORMAT row_format]
+ * [STORED AS file_format]
+ * [LOCATION path]
+ * [TBLPROPERTIES (property_name=property_value, ...)]
* }}}
*/
override def visitCreateHiveTable(ctx: CreateHiveTableContext): LogicalPlan = withOrigin(ctx) {
@@ -1104,15 +1118,23 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
"CREATE TEMPORARY TABLE is not supported yet. " +
"Please use CREATE TEMPORARY VIEW as an alternative.", ctx)
}
- if (ctx.skewSpec != null) {
+ if (ctx.skewSpec.size > 0) {
operationNotAllowed("CREATE TABLE ... SKEWED BY", ctx)
}
+ checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx)
+ checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx)
+ checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx)
+ checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx)
+ checkDuplicateClauses(ctx.createFileFormat, "STORED AS/BY", ctx)
+ checkDuplicateClauses(ctx.rowFormat, "ROW FORMAT", ctx)
+ checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx)
+
val dataCols = Option(ctx.columns).map(visitColTypeList).getOrElse(Nil)
val partitionCols = Option(ctx.partitionColumns).map(visitColTypeList).getOrElse(Nil)
- val properties = Option(ctx.tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty)
+ val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty)
val selectQuery = Option(ctx.query).map(plan)
- val bucketSpec = Option(ctx.bucketSpec()).map(visitBucketSpec)
+ val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec)
// Note: Hive requires partition columns to be distinct from the schema, so we need
// to include the partition columns here explicitly
@@ -1120,12 +1142,12 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
// Storage format
val defaultStorage = HiveSerDe.getDefaultStorage(conf)
- validateRowFormatFileFormat(ctx.rowFormat, ctx.createFileFormat, ctx)
- val fileStorage = Option(ctx.createFileFormat).map(visitCreateFileFormat)
+ validateRowFormatFileFormat(ctx.rowFormat.asScala, ctx.createFileFormat.asScala, ctx)
+ val fileStorage = ctx.createFileFormat.asScala.headOption.map(visitCreateFileFormat)
.getOrElse(CatalogStorageFormat.empty)
- val rowStorage = Option(ctx.rowFormat).map(visitRowFormat)
+ val rowStorage = ctx.rowFormat.asScala.headOption.map(visitRowFormat)
.getOrElse(CatalogStorageFormat.empty)
- val location = Option(ctx.locationSpec).map(visitLocationSpec)
+ val location = ctx.locationSpec.asScala.headOption.map(visitLocationSpec)
// If we are creating an EXTERNAL table, then the LOCATION field is required
if (external && location.isEmpty) {
operationNotAllowed("CREATE EXTERNAL TABLE must be accompanied by LOCATION", ctx)
@@ -1180,7 +1202,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
ctx)
}
- val hasStorageProperties = (ctx.createFileFormat != null) || (ctx.rowFormat != null)
+ val hasStorageProperties = (ctx.createFileFormat.size != 0) || (ctx.rowFormat.size != 0)
if (conf.convertCTAS && !hasStorageProperties) {
// At here, both rowStorage.serdeProperties and fileStorage.serdeProperties
// are empty Maps.
@@ -1366,6 +1388,15 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
}
}
+ private def validateRowFormatFileFormat(
+ rowFormatCtx: Seq[RowFormatContext],
+ createFileFormatCtx: Seq[CreateFileFormatContext],
+ parentCtx: ParserRuleContext): Unit = {
+ if (rowFormatCtx.size == 1 && createFileFormatCtx.size == 1) {
+ validateRowFormatFileFormat(rowFormatCtx.head, createFileFormatCtx.head, parentCtx)
+ }
+ }
+
/**
* Create or replace a view. This creates a [[CreateViewCommand]] command.
*
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 19b858faba6ea..910294853c318 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -29,10 +29,12 @@ import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
-import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
+import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.sources.MemoryPlanV2
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.StreamingQuery
+import org.apache.spark.sql.types.StructType
/**
* Converts a logical plan into zero or more SparkPlans. This API is exposed for experimenting
@@ -91,12 +93,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
* predicates can be evaluated by matching join keys. If found, Join implementations are chosen
* with the following precedence:
*
- * - Broadcast: if one side of the join has an estimated physical size that is smaller than the
- * user-configurable [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold
- * or if that side has an explicit broadcast hint (e.g. the user applied the
- * [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame), then that side
- * of the join will be broadcasted and the other side will be streamed, with no shuffling
- * performed. If both sides of the join are eligible to be broadcasted then the
+ * - Broadcast: We prefer to broadcast the join side with an explicit broadcast hint (e.g. the
+ * user applied the [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame).
+ * If both sides have the broadcast hint, we prefer to broadcast the side with the smaller
+ * estimated physical size. If neither side has the broadcast hint, we only broadcast a
+ * join side if its estimated physical size is smaller than the user-configurable
+ * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold.
* - Shuffle hash join: if the average size of a single partition is small enough to build a hash
* table.
* - Sort merge: if the matching join keys are sortable.
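In user code, the precedence described above can be exercised through the explicit hint and the size threshold. A hedged sketch (DataFrame names and the threshold value are illustrative):

    // Hedged sketch, assuming a SparkSession `spark` and DataFrames `orders` (large) and `dim` (small).
    import org.apache.spark.sql.functions.broadcast

    // Explicit hint: `dim` is broadcast regardless of its estimated size.
    val hinted = orders.join(broadcast(dim), Seq("key"))

    // No hint: a side is broadcast only if its estimated size is under the threshold.
    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 10L * 1024 * 1024)
    val sized = orders.join(dim, Seq("key"))
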
@@ -112,9 +114,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
* Matches a plan whose output should be small enough to be used in broadcast join.
*/
private def canBroadcast(plan: LogicalPlan): Boolean = {
- plan.stats.hints.broadcast ||
- (plan.stats.sizeInBytes >= 0 &&
- plan.stats.sizeInBytes <= conf.autoBroadcastJoinThreshold)
+ plan.stats.sizeInBytes >= 0 && plan.stats.sizeInBytes <= conf.autoBroadcastJoinThreshold
}
/**
@@ -149,19 +149,74 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case _ => false
}
+ private def broadcastSide(
+ canBuildLeft: Boolean,
+ canBuildRight: Boolean,
+ left: LogicalPlan,
+ right: LogicalPlan): BuildSide = {
+
+ def smallerSide =
+ if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft
+
+ if (canBuildRight && canBuildLeft) {
+ // Broadcast the smaller side, based on its estimated physical size,
+ // if both sides have the broadcast hint
+ smallerSide
+ } else if (canBuildRight) {
+ BuildRight
+ } else if (canBuildLeft) {
+ BuildLeft
+ } else {
+ // for the last default broadcast nested loop join
+ smallerSide
+ }
+ }
+
+ private def canBroadcastByHints(joinType: JoinType, left: LogicalPlan, right: LogicalPlan)
+ : Boolean = {
+ val buildLeft = canBuildLeft(joinType) && left.stats.hints.broadcast
+ val buildRight = canBuildRight(joinType) && right.stats.hints.broadcast
+ buildLeft || buildRight
+ }
+
+ private def broadcastSideByHints(joinType: JoinType, left: LogicalPlan, right: LogicalPlan)
+ : BuildSide = {
+ val buildLeft = canBuildLeft(joinType) && left.stats.hints.broadcast
+ val buildRight = canBuildRight(joinType) && right.stats.hints.broadcast
+ broadcastSide(buildLeft, buildRight, left, right)
+ }
+
+ private def canBroadcastBySizes(joinType: JoinType, left: LogicalPlan, right: LogicalPlan)
+ : Boolean = {
+ val buildLeft = canBuildLeft(joinType) && canBroadcast(left)
+ val buildRight = canBuildRight(joinType) && canBroadcast(right)
+ buildLeft || buildRight
+ }
+
+ private def broadcastSideBySizes(joinType: JoinType, left: LogicalPlan, right: LogicalPlan)
+ : BuildSide = {
+ val buildLeft = canBuildLeft(joinType) && canBroadcast(left)
+ val buildRight = canBuildRight(joinType) && canBroadcast(right)
+ broadcastSide(buildLeft, buildRight, left, right)
+ }
+
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
// --- BroadcastHashJoin --------------------------------------------------------------------
+ // broadcast hints were specified
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
- if canBuildRight(joinType) && canBroadcast(right) =>
+ if canBroadcastByHints(joinType, left, right) =>
+ val buildSide = broadcastSideByHints(joinType, left, right)
Seq(joins.BroadcastHashJoinExec(
- leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right)))
+ leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right)))
+ // broadcast hints were not specified, so we need to infer the build side from sizes and configuration.
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
- if canBuildLeft(joinType) && canBroadcast(left) =>
+ if canBroadcastBySizes(joinType, left, right) =>
+ val buildSide = broadcastSideBySizes(joinType, left, right)
Seq(joins.BroadcastHashJoinExec(
- leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right)))
+ leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right)))
// --- ShuffledHashJoin ---------------------------------------------------------------------
@@ -190,25 +245,24 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// Pick BroadcastNestedLoopJoin if one side could be broadcasted
case j @ logical.Join(left, right, joinType, condition)
- if canBuildRight(joinType) && canBroadcast(right) =>
+ if canBroadcastByHints(joinType, left, right) =>
+ val buildSide = broadcastSideByHints(joinType, left, right)
joins.BroadcastNestedLoopJoinExec(
- planLater(left), planLater(right), BuildRight, joinType, condition) :: Nil
+ planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
+
case j @ logical.Join(left, right, joinType, condition)
- if canBuildLeft(joinType) && canBroadcast(left) =>
+ if canBroadcastBySizes(joinType, left, right) =>
+ val buildSide = broadcastSideBySizes(joinType, left, right)
joins.BroadcastNestedLoopJoinExec(
- planLater(left), planLater(right), BuildLeft, joinType, condition) :: Nil
+ planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
// Pick CartesianProduct for InnerJoin
case logical.Join(left, right, _: InnerLike, condition) =>
joins.CartesianProductExec(planLater(left), planLater(right), condition) :: Nil
case logical.Join(left, right, joinType, condition) =>
- val buildSide =
- if (right.stats.sizeInBytes <= left.stats.sizeInBytes) {
- BuildRight
- } else {
- BuildLeft
- }
+ val buildSide = broadcastSide(
+ left.stats.hints.broadcast, right.stats.hints.broadcast, left, right)
// This join could be very slow or OOM
joins.BroadcastNestedLoopJoinExec(
planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
@@ -339,6 +393,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
StreamingRelationExec(s.sourceName, s.output) :: Nil
case s: StreamingExecutionRelation =>
StreamingRelationExec(s.toString, s.output) :: Nil
+ case s: StreamingRelationV2 =>
+ StreamingRelationExec(s.sourceName, s.output) :: Nil
case _ => Nil
}
}
@@ -364,11 +420,15 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// Can we automate these 'pass through' operations?
object BasicOperators extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case d: DataWritingCommand => DataWritingCommandExec(d, planLater(d.query)) :: Nil
case r: RunnableCommand => ExecutedCommandExec(r) :: Nil
case MemoryPlan(sink, output) =>
val encoder = RowEncoder(sink.schema)
LocalTableScanExec(output, sink.allData.map(r => encoder.toRow(r).copy())) :: Nil
+ case MemoryPlanV2(sink, output) =>
+ val encoder = RowEncoder(StructType.fromAttributes(output))
+ LocalTableScanExec(output, sink.allData.map(r => encoder.toRow(r).copy())) :: Nil
case logical.Distinct(child) =>
throw new IllegalStateException(
@@ -439,17 +499,16 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.GlobalLimitExec(limit, planLater(child)) :: Nil
case logical.Union(unionChildren) =>
execution.UnionExec(unionChildren.map(planLater)) :: Nil
- case g @ logical.Generate(generator, join, outer, _, _, child) =>
+ case g @ logical.Generate(generator, _, outer, _, _, child) =>
execution.GenerateExec(
- generator, join = join, outer = outer, g.qualifiedGeneratorOutput,
- planLater(child)) :: Nil
+ generator, g.requiredChildOutput, outer,
+ g.qualifiedGeneratorOutput, planLater(child)) :: Nil
case _: logical.OneRowRelation =>
execution.RDDScanExec(Nil, singleRowRdd, "OneRowRelation") :: Nil
case r: logical.Range =>
execution.RangeExec(r) :: Nil
- case logical.RepartitionByExpression(expressions, child, numPartitions) =>
- exchange.ShuffleExchangeExec(HashPartitioning(
- expressions, numPartitions), planLater(child)) :: Nil
+ case r: logical.RepartitionByExpression =>
+ exchange.ShuffleExchangeExec(r.partitioning, planLater(r.child)) :: Nil
case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
case r: LogicalRDD =>
RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 7166b7771e4db..065954559e487 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -282,9 +282,10 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp
}
override def doProduce(ctx: CodegenContext): String = {
- val input = ctx.freshName("input")
// Right now, InputAdapter is only used when there is one input RDD.
- ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
+ // Inline mutable state since an InputAdapter is used once in a task for WholeStageCodegen
+ val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];",
+ forceInline = true)
val row = ctx.freshName("row")
s"""
| while ($input.hasNext() && !stopEarly()) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index dc8aecf185a96..ce3c68810f3b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.vectorized.MutableColumnarRow
import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
import org.apache.spark.unsafe.KVIterator
import org.apache.spark.util.Utils
@@ -177,8 +178,7 @@ case class HashAggregateExec(
private var bufVars: Seq[ExprCode] = _
private def doProduceWithoutKeys(ctx: CodegenContext): String = {
- val initAgg = ctx.freshName("initAgg")
- ctx.addMutableState(ctx.JAVA_BOOLEAN, initAgg, s"$initAgg = false;")
+ val initAgg = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg")
// The generated function doesn't have input row in the code context.
ctx.INPUT_ROW = null
@@ -186,10 +186,8 @@ case class HashAggregateExec(
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
val initExpr = functions.flatMap(f => f.initialValues)
bufVars = initExpr.map { e =>
- val isNull = ctx.freshName("bufIsNull")
- val value = ctx.freshName("bufValue")
- ctx.addMutableState(ctx.JAVA_BOOLEAN, isNull)
- ctx.addMutableState(ctx.javaType(e.dataType), value)
+ val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "bufIsNull")
+ val value = ctx.addMutableState(ctx.javaType(e.dataType), "bufValue")
// The initial expression should not access any column
val ev = e.genCode(ctx)
val initVars = s"""
@@ -443,6 +441,7 @@ case class HashAggregateExec(
val funcName = ctx.freshName("doAggregateWithKeysOutput")
val keyTerm = ctx.freshName("keyTerm")
val bufferTerm = ctx.freshName("bufferTerm")
+ val numOutput = metricTerm(ctx, "numOutputRows")
val body =
if (modes.contains(Final) || modes.contains(Complete)) {
@@ -519,6 +518,7 @@ case class HashAggregateExec(
s"""
private void $funcName(UnsafeRow $keyTerm, UnsafeRow $bufferTerm)
throws java.io.IOException {
+ $numOutput.add(1);
$body
}
""")
@@ -548,7 +548,7 @@ case class HashAggregateExec(
isSupported && isNotByteArrayDecimalType
}
- private def enableTwoLevelHashMap(ctx: CodegenContext) = {
+ private def enableTwoLevelHashMap(ctx: CodegenContext): Unit = {
if (!checkIfFastHashMapSupported(ctx)) {
if (modes.forall(mode => mode == Partial || mode == PartialMerge) && !Utils.isTesting) {
logInfo("spark.sql.codegen.aggregate.map.twolevel.enabled is set to true, but"
@@ -559,107 +559,96 @@ case class HashAggregateExec(
// This is for testing/benchmarking only.
// We enforce to first level to be a vectorized hashmap, instead of the default row-based one.
- sqlContext.getConf("spark.sql.codegen.aggregate.map.vectorized.enable", null) match {
- case "true" => isVectorizedHashMapEnabled = true
- case null | "" | "false" => None }
+ isVectorizedHashMapEnabled = sqlContext.getConf(
+ "spark.sql.codegen.aggregate.map.vectorized.enable", "false") == "true"
}
}
private def doProduceWithKeys(ctx: CodegenContext): String = {
- val initAgg = ctx.freshName("initAgg")
- ctx.addMutableState(ctx.JAVA_BOOLEAN, initAgg, s"$initAgg = false;")
+ val initAgg = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg")
if (sqlContext.conf.enableTwoLevelAggMap) {
enableTwoLevelHashMap(ctx)
} else {
sqlContext.getConf("spark.sql.codegen.aggregate.map.vectorized.enable", null) match {
- case "true" => logWarning("Two level hashmap is disabled but vectorized hashmap is " +
- "enabled.")
- case null | "" | "false" => None
+ case "true" =>
+ logWarning("Two level hashmap is disabled but vectorized hashmap is enabled.")
+ case _ =>
}
}
- fastHashMapTerm = ctx.freshName("fastHashMap")
- val fastHashMapClassName = ctx.freshName("FastHashMap")
- val fastHashMapGenerator =
- if (isVectorizedHashMapEnabled) {
- new VectorizedHashMapGenerator(ctx, aggregateExpressions,
- fastHashMapClassName, groupingKeySchema, bufferSchema)
- } else {
- new RowBasedHashMapGenerator(ctx, aggregateExpressions,
- fastHashMapClassName, groupingKeySchema, bufferSchema)
- }
val thisPlan = ctx.addReferenceObj("plan", this)
- // Create a name for iterator from vectorized HashMap
- val iterTermForFastHashMap = ctx.freshName("fastHashMapIter")
- if (isFastHashMapEnabled) {
+ // Create a name for the iterator from the fast hash map.
+ val iterTermForFastHashMap = if (isFastHashMapEnabled) {
+ // Generates the fast hash map class and creates the fast hash map term.
+ val fastHashMapClassName = ctx.freshName("FastHashMap")
if (isVectorizedHashMapEnabled) {
- ctx.addMutableState(fastHashMapClassName, fastHashMapTerm,
- s"$fastHashMapTerm = new $fastHashMapClassName();")
- ctx.addMutableState(
- "java.util.Iterator",
- iterTermForFastHashMap)
+ val generatedMap = new VectorizedHashMapGenerator(ctx, aggregateExpressions,
+ fastHashMapClassName, groupingKeySchema, bufferSchema).generate()
+ ctx.addInnerClass(generatedMap)
+
+ // Inline mutable state since there are not many aggregation operations in a task
+ fastHashMapTerm = ctx.addMutableState(fastHashMapClassName, "vectorizedFastHashMap",
+ v => s"$v = new $fastHashMapClassName();", forceInline = true)
+ ctx.addMutableState(s"java.util.Iterator", "vectorizedFastHashMapIter",
+ forceInline = true)
} else {
- ctx.addMutableState(fastHashMapClassName, fastHashMapTerm,
- s"$fastHashMapTerm = new $fastHashMapClassName(" +
- s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());")
+ val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions,
+ fastHashMapClassName, groupingKeySchema, bufferSchema).generate()
+ ctx.addInnerClass(generatedMap)
+
+ // Inline mutable state since there are not many aggregation operations in a task
+ fastHashMapTerm = ctx.addMutableState(fastHashMapClassName, "fastHashMap",
+ v => s"$v = new $fastHashMapClassName(" +
+ s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());",
+ forceInline = true)
ctx.addMutableState(
- "org.apache.spark.unsafe.KVIterator",
- iterTermForFastHashMap)
+ "org.apache.spark.unsafe.KVIterator",
+ "fastHashMapIter", forceInline = true)
}
}
+ // Create a name for the iterator from the regular hash map.
+ // Inline mutable state since there are not many aggregation operations in a task
+ val iterTerm = ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName,
+ "mapIter", forceInline = true)
// create hashMap
- hashMapTerm = ctx.freshName("hashMap")
val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
- ctx.addMutableState(hashMapClassName, hashMapTerm)
- sorterTerm = ctx.freshName("sorter")
- ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm)
-
- // Create a name for iterator from HashMap
- val iterTerm = ctx.freshName("mapIter")
- ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm)
-
- def generateGenerateCode(): String = {
- if (isFastHashMapEnabled) {
- if (isVectorizedHashMapEnabled) {
- s"""
- | ${fastHashMapGenerator.asInstanceOf[VectorizedHashMapGenerator].generate()}
- """.stripMargin
- } else {
- s"""
- | ${fastHashMapGenerator.asInstanceOf[RowBasedHashMapGenerator].generate()}
- """.stripMargin
- }
- } else ""
- }
- ctx.addInnerClass(generateGenerateCode())
+ hashMapTerm = ctx.addMutableState(hashMapClassName, "hashMap",
+ v => s"$v = $thisPlan.createHashMap();", forceInline = true)
+ sorterTerm = ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, "sorter",
+ forceInline = true)
val doAgg = ctx.freshName("doAggregateWithKeys")
val peakMemory = metricTerm(ctx, "peakMemory")
val spillSize = metricTerm(ctx, "spillSize")
val avgHashProbe = metricTerm(ctx, "avgHashProbe")
- val doAggFuncName = ctx.addNewFunction(doAgg,
- s"""
- private void $doAgg() throws java.io.IOException {
- $hashMapTerm = $thisPlan.createHashMap();
- ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
- ${if (isFastHashMapEnabled) {
- s"$iterTermForFastHashMap = $fastHashMapTerm.rowIterator();"} else ""}
+ val finishRegularHashMap = s"$iterTerm = $thisPlan.finishAggregate(" +
+ s"$hashMapTerm, $sorterTerm, $peakMemory, $spillSize, $avgHashProbe);"
+ val finishHashMap = if (isFastHashMapEnabled) {
+ s"""
+ |$iterTermForFastHashMap = $fastHashMapTerm.rowIterator();
+ |$finishRegularHashMap
+ """.stripMargin
+ } else {
+ finishRegularHashMap
+ }
- $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm, $peakMemory, $spillSize,
- $avgHashProbe);
- }
- """)
+ val doAggFuncName = ctx.addNewFunction(doAgg,
+ s"""
+ |private void $doAgg() throws java.io.IOException {
+ | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
+ | $finishHashMap
+ |}
+ """.stripMargin)
// generate code for output
val keyTerm = ctx.freshName("aggKey")
val bufferTerm = ctx.freshName("aggBuffer")
val outputFunc = generateResultFunction(ctx)
- val numOutput = metricTerm(ctx, "numOutputRows")
- def outputFromGeneratedMap: String = {
+ def outputFromFastHashMap: String = {
if (isFastHashMapEnabled) {
if (isVectorizedHashMapEnabled) {
outputFromVectorizedMap
@@ -671,48 +660,55 @@ case class HashAggregateExec(
def outputFromRowBasedMap: String = {
s"""
- while ($iterTermForFastHashMap.next()) {
- $numOutput.add(1);
- UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey();
- UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue();
- $outputFunc($keyTerm, $bufferTerm);
-
- if (shouldStop()) return;
- }
- $fastHashMapTerm.close();
- """
+ |while ($iterTermForFastHashMap.next()) {
+ | UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey();
+ | UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue();
+ | $outputFunc($keyTerm, $bufferTerm);
+ |
+ | if (shouldStop()) return;
+ |}
+ |$fastHashMapTerm.close();
+ """.stripMargin
}
- // Iterate over the aggregate rows and convert them from ColumnarRow to UnsafeRow
+ // Iterate over the aggregate rows and convert them from InternalRow to UnsafeRow
def outputFromVectorizedMap: String = {
- val row = ctx.freshName("fastHashMapRow")
- ctx.currentVars = null
- ctx.INPUT_ROW = row
- val generateKeyRow = GenerateUnsafeProjection.createCode(ctx,
- groupingKeySchema.toAttributes.zipWithIndex
+ val row = ctx.freshName("fastHashMapRow")
+ ctx.currentVars = null
+ ctx.INPUT_ROW = row
+ val generateKeyRow = GenerateUnsafeProjection.createCode(ctx,
+ groupingKeySchema.toAttributes.zipWithIndex
.map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable) }
- )
- val generateBufferRow = GenerateUnsafeProjection.createCode(ctx,
- bufferSchema.toAttributes.zipWithIndex
- .map { case (attr, i) =>
- BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable) })
- s"""
- | while ($iterTermForFastHashMap.hasNext()) {
- | $numOutput.add(1);
- | org.apache.spark.sql.execution.vectorized.ColumnarRow $row =
- | (org.apache.spark.sql.execution.vectorized.ColumnarRow)
- | $iterTermForFastHashMap.next();
- | ${generateKeyRow.code}
- | ${generateBufferRow.code}
- | $outputFunc(${generateKeyRow.value}, ${generateBufferRow.value});
- |
- | if (shouldStop()) return;
- | }
- |
- | $fastHashMapTerm.close();
- """.stripMargin
+ )
+ val generateBufferRow = GenerateUnsafeProjection.createCode(ctx,
+ bufferSchema.toAttributes.zipWithIndex.map { case (attr, i) =>
+ BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable)
+ })
+ s"""
+ |while ($iterTermForFastHashMap.hasNext()) {
+ | InternalRow $row = (InternalRow) $iterTermForFastHashMap.next();
+ | ${generateKeyRow.code}
+ | ${generateBufferRow.code}
+ | $outputFunc(${generateKeyRow.value}, ${generateBufferRow.value});
+ |
+ | if (shouldStop()) return;
+ |}
+ |
+ |$fastHashMapTerm.close();
+ """.stripMargin
}
+ def outputFromRegularHashMap: String = {
+ s"""
+ |while ($iterTerm.next()) {
+ | UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
+ | UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
+ | $outputFunc($keyTerm, $bufferTerm);
+ |
+ | if (shouldStop()) return;
+ |}
+ """.stripMargin
+ }
val aggTime = metricTerm(ctx, "aggTime")
val beforeAgg = ctx.freshName("beforeAgg")
@@ -725,16 +721,8 @@ case class HashAggregateExec(
}
// output the result
- ${outputFromGeneratedMap}
-
- while ($iterTerm.next()) {
- $numOutput.add(1);
- UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
- UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
- $outputFunc($keyTerm, $bufferTerm);
-
- if (shouldStop()) return;
- }
+ $outputFromFastHashMap
+ $outputFromRegularHashMap
$iterTerm.close();
if ($sorterTerm == null) {
@@ -744,13 +732,11 @@ case class HashAggregateExec(
}
private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
-
// create grouping key
- ctx.currentVars = input
val unsafeRowKeyCode = GenerateUnsafeProjection.createCode(
ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
val fastRowKeys = ctx.generateExpressions(
- groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
+ groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
val unsafeRowKeys = unsafeRowKeyCode.value
val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer")
val fastRowBuffer = ctx.freshName("fastAggBuffer")
@@ -767,102 +753,76 @@ case class HashAggregateExec(
// generate hash code for key
val hashExpr = Murmur3Hash(groupingExpressions, 42)
- ctx.currentVars = input
val hashEval = BindReferences.bindReference(hashExpr, child.output).genCode(ctx)
- val inputAttr = aggregateBufferAttributes ++ child.output
- ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input
-
val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter,
incCounter) = if (testFallbackStartsAt.isDefined) {
- val countTerm = ctx.freshName("fallbackCounter")
- ctx.addMutableState(ctx.JAVA_INT, countTerm, s"$countTerm = 0;")
+ val countTerm = ctx.addMutableState(ctx.JAVA_INT, "fallbackCounter")
(s"$countTerm < ${testFallbackStartsAt.get._1}",
s"$countTerm < ${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;")
} else {
("true", "true", "", "")
}
- // We first generate code to probe and update the fast hash map. If the probe is
- // successful the corresponding fast row buffer will hold the mutable row
- val findOrInsertFastHashMap: Option[String] = {
+ val findOrInsertRegularHashMap: String =
+ s"""
+ |// generate grouping key
+ |${unsafeRowKeyCode.code.trim}
+ |${hashEval.code.trim}
+ |if ($checkFallbackForBytesToBytesMap) {
+ | // try to get the buffer from hash map
+ | $unsafeRowBuffer =
+ | $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, ${hashEval.value});
+ |}
+ |// Can't allocate buffer from the hash map. Spill the map and fallback to sort-based
+ |// aggregation after processing all input rows.
+ |if ($unsafeRowBuffer == null) {
+ | if ($sorterTerm == null) {
+ | $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
+ | } else {
+ | $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
+ | }
+ | $resetCounter
+ | // the hash map has been spilled, it should have enough memory now,
+ | // try to allocate buffer again.
+ | $unsafeRowBuffer = $hashMapTerm.getAggregationBufferFromUnsafeRow(
+ | $unsafeRowKeys, ${hashEval.value});
+ | if ($unsafeRowBuffer == null) {
+ | // failed to allocate the first page
+ | throw new OutOfMemoryError("No enough memory for aggregation");
+ | }
+ |}
+ """.stripMargin
+
+ val findOrInsertHashMap: String = {
if (isFastHashMapEnabled) {
- Option(
- s"""
- |
- |if ($checkFallbackForGeneratedHashMap) {
- | ${fastRowKeys.map(_.code).mkString("\n")}
- | if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) {
- | $fastRowBuffer = $fastHashMapTerm.findOrInsert(
- | ${fastRowKeys.map(_.value).mkString(", ")});
- | }
- |}
- """.stripMargin)
+ // If fast hash map is on, we first generate code to probe and update the fast hash map.
+ // If the probe is successful, the corresponding fast row buffer will hold the mutable row.
+ s"""
+ |if ($checkFallbackForGeneratedHashMap) {
+ | ${fastRowKeys.map(_.code).mkString("\n")}
+ | if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) {
+ | $fastRowBuffer = $fastHashMapTerm.findOrInsert(
+ | ${fastRowKeys.map(_.value).mkString(", ")});
+ | }
+ |}
+ |// Cannot find the key in fast hash map, try regular hash map.
+ |if ($fastRowBuffer == null) {
+ | $findOrInsertRegularHashMap
+ |}
+ """.stripMargin
} else {
- None
+ findOrInsertRegularHashMap
}
}
+ val inputAttr = aggregateBufferAttributes ++ child.output
+ // Here we set `currentVars(0)` through `currentVars(aggregateBufferAttributes.length - 1)`
+ // to null, so that when generating code for the buffer columns we use `INPUT_ROW` (which
+ // will be the buffer row), while for the input columns we use `currentVars`.
+ ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input
- def updateRowInFastHashMap(isVectorized: Boolean): Option[String] = {
- ctx.INPUT_ROW = fastRowBuffer
- val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
- val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
- val effectiveCodes = subExprs.codes.mkString("\n")
- val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
- boundUpdateExpr.map(_.genCode(ctx))
- }
- val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) =>
- val dt = updateExpr(i).dataType
- ctx.updateColumn(fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorized)
- }
- Option(
- s"""
- |// common sub-expressions
- |$effectiveCodes
- |// evaluate aggregate function
- |${evaluateVariables(fastRowEvals)}
- |// update fast row
- |${updateFastRow.mkString("\n").trim}
- |
- """.stripMargin)
- }
-
- // Next, we generate code to probe and update the unsafe row hash map.
- val findOrInsertInUnsafeRowMap: String = {
- s"""
- | if ($fastRowBuffer == null) {
- | // generate grouping key
- | ${unsafeRowKeyCode.code.trim}
- | ${hashEval.code.trim}
- | if ($checkFallbackForBytesToBytesMap) {
- | // try to get the buffer from hash map
- | $unsafeRowBuffer =
- | $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, ${hashEval.value});
- | }
- | // Can't allocate buffer from the hash map. Spill the map and fallback to sort-based
- | // aggregation after processing all input rows.
- | if ($unsafeRowBuffer == null) {
- | if ($sorterTerm == null) {
- | $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
- | } else {
- | $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
- | }
- | $resetCounter
- | // the hash map had be spilled, it should have enough memory now,
- | // try to allocate buffer again.
- | $unsafeRowBuffer =
- | $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, ${hashEval.value});
- | if ($unsafeRowBuffer == null) {
- | // failed to allocate the first page
- | throw new OutOfMemoryError("No enough memory for aggregation");
- | }
- | }
- | }
- """.stripMargin
- }
-
- val updateRowInUnsafeRowMap: String = {
+ val updateRowInRegularHashMap: String = {
ctx.INPUT_ROW = unsafeRowBuffer
val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
@@ -881,45 +841,69 @@ case class HashAggregateExec(
|${evaluateVariables(unsafeRowBufferEvals)}
|// update unsafe row buffer
|${updateUnsafeRowBuffer.mkString("\n").trim}
- """.stripMargin
+ """.stripMargin
}
+ val updateRowInHashMap: String = {
+ if (isFastHashMapEnabled) {
+ ctx.INPUT_ROW = fastRowBuffer
+ val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
+ val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
+ val effectiveCodes = subExprs.codes.mkString("\n")
+ val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
+ boundUpdateExpr.map(_.genCode(ctx))
+ }
+ val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) =>
+ val dt = updateExpr(i).dataType
+ ctx.updateColumn(
+ fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorizedHashMapEnabled)
+ }
+
+ // If the fast hash map is on, we first generate code to update the row in the fast hash
+ // map if the previous lookup hit the fast hash map. Otherwise, update the row in the
+ // regular hash map.
+ s"""
+ |if ($fastRowBuffer != null) {
+ | // common sub-expressions
+ | $effectiveCodes
+ | // evaluate aggregate function
+ | ${evaluateVariables(fastRowEvals)}
+ | // update fast row
+ | ${updateFastRow.mkString("\n").trim}
+ |} else {
+ | $updateRowInRegularHashMap
+ |}
+ """.stripMargin
+ } else {
+ updateRowInRegularHashMap
+ }
+ }
+
+ val declareRowBuffer: String = if (isFastHashMapEnabled) {
+ val fastRowType = if (isVectorizedHashMapEnabled) {
+ classOf[MutableColumnarRow].getName
+ } else {
+ "UnsafeRow"
+ }
+ s"""
+ |UnsafeRow $unsafeRowBuffer = null;
+ |$fastRowType $fastRowBuffer = null;
+ """.stripMargin
+ } else {
+ s"UnsafeRow $unsafeRowBuffer = null;"
+ }
// We try to do hash map based in-memory aggregation first. If there is not enough memory (the
// hash map will return null for new key), we spill the hash map to disk to free memory, then
// continue to do in-memory aggregation and spilling until all the rows had been processed.
// Finally, sort the spilled aggregate buffers by key, and merge them together for same key.
s"""
- UnsafeRow $unsafeRowBuffer = null;
- ${
- if (isVectorizedHashMapEnabled) {
- s"""
- | org.apache.spark.sql.execution.vectorized.ColumnarRow $fastRowBuffer = null;
- """.stripMargin
- } else {
- s"""
- | UnsafeRow $fastRowBuffer = null;
- """.stripMargin
- }
- }
-
- ${findOrInsertFastHashMap.getOrElse("")}
+ $declareRowBuffer
- $findOrInsertInUnsafeRowMap
+ $findOrInsertHashMap
$incCounter
- if ($fastRowBuffer != null) {
- // update fast row
- ${
- if (isFastHashMapEnabled) {
- updateRowInFastHashMap(isVectorizedHashMapEnabled).getOrElse("")
- } else ""
- }
- } else {
- // update unsafe row
- $updateRowInUnsafeRowMap
- }
+ $updateRowInHashMap
"""
}
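The restructured `doConsumeWithKeys` above folds the fast-map probe, the regular-map fallback and the spill-and-retry path into the `findOrInsertHashMap` and `updateRowInHashMap` snippets. The plain-Scala sketch below mirrors that control flow under toy assumptions; `FastMap`, `RegularMap`, `AggBuffer` and the page accounting are illustrative stand-ins, not Spark classes.

import scala.collection.mutable

final case class AggBuffer(key: Long, var sum: Long = 0L)

// Fixed-capacity first-level map: returns null for a new key once it is full.
final class FastMap(capacity: Int) {
  private val rows = mutable.LinkedHashMap.empty[Long, AggBuffer]
  def findOrInsert(key: Long): AggBuffer = rows.get(key).getOrElse {
    if (rows.size >= capacity) null else { val b = AggBuffer(key); rows.put(key, b); b }
  }
}

// Second-level map: may run out of "pages"; the caller then spills and retries once.
final class RegularMap(private var freePages: Int) {
  private val rows = mutable.LinkedHashMap.empty[Long, AggBuffer]
  def getAggregationBuffer(key: Long): AggBuffer = rows.get(key).getOrElse {
    if (freePages == 0) null
    else { freePages -= 1; val b = AggBuffer(key); rows.put(key, b); b }
  }
  def destructAndCreateExternalSorter(): Unit = { freePages += rows.size; rows.clear() }
}

object TwoLevelAggregationSketch {
  def consume(key: Long, value: Long, fast: FastMap, regular: RegularMap): Unit = {
    // 1. Probe the fast map (the fast branch of findOrInsertHashMap).
    val fastBuffer = fast.findOrInsert(key)
    var unsafeBuffer: AggBuffer = null
    if (fastBuffer == null) {
      // 2. Miss: fall back to the regular map (findOrInsertRegularHashMap).
      unsafeBuffer = regular.getAggregationBuffer(key)
      if (unsafeBuffer == null) {
        // 3. Cannot allocate: spill, then retry once before giving up.
        regular.destructAndCreateExternalSorter()
        unsafeBuffer = regular.getAggregationBuffer(key)
        if (unsafeBuffer == null) throw new RuntimeException("no memory left for aggregation")
      }
    }
    // 4. Update whichever buffer the lookup produced (updateRowInHashMap).
    if (fastBuffer != null) fastBuffer.sum += value else unsafeBuffer.sum += value
  }
}

A task would call `consume` once per input row and, after the input is exhausted, sort-merge any spilled runs together with the surviving in-memory buffers.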
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
index 85b4529501ea8..1c613b19c4ab1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
@@ -46,10 +46,8 @@ abstract class HashMapGenerator(
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
val initExpr = functions.flatMap(f => f.initialValues)
initExpr.map { e =>
- val isNull = ctx.freshName("bufIsNull")
- val value = ctx.freshName("bufValue")
- ctx.addMutableState(ctx.JAVA_BOOLEAN, isNull)
- ctx.addMutableState(ctx.javaType(e.dataType), value)
+ val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "bufIsNull")
+ val value = ctx.addMutableState(ctx.javaType(e.dataType), "bufValue")
val ev = e.genCode(ctx)
val initVars =
s"""
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
index 3718424931b40..fd25707dd4ca6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
@@ -47,7 +47,7 @@ class RowBasedHashMapGenerator(
val generatedKeySchema: String =
s"new org.apache.spark.sql.types.StructType()" +
groupingKeySchema.map { key =>
- val keyName = ctx.addReferenceMinorObj(key.name)
+ val keyName = ctx.addReferenceObj("keyName", key.name)
key.dataType match {
case d: DecimalType =>
s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType(
@@ -60,7 +60,7 @@ class RowBasedHashMapGenerator(
val generatedValueSchema: String =
s"new org.apache.spark.sql.types.StructType()" +
bufferSchema.map { key =>
- val keyName = ctx.addReferenceMinorObj(key.name)
+ val keyName = ctx.addReferenceObj("keyName", key.name)
key.dataType match {
case d: DecimalType =>
s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 756eeb642e2d0..9dc334c1ead3c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.aggregate
import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
+import org.apache.spark.memory.SparkOutOfMemoryError
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -205,7 +206,7 @@ class TungstenAggregationIterator(
buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
if (buffer == null) {
// failed to allocate the first page
- throw new OutOfMemoryError("No enough memory for aggregation")
+ throw new SparkOutOfMemoryError("No enough memory for aggregation")
}
}
processRow(buffer, newInput)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
index fd783d905b776..0cf9b53ce1d5d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
@@ -17,9 +17,12 @@
package org.apache.spark.sql.execution.aggregate
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
+import org.apache.spark.sql.execution.vectorized.{MutableColumnarRow, OnHeapColumnVector}
import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized.ColumnarBatch
/**
* This is a helper class to generate an append-only vectorized hash map that can act as a 'cache'
@@ -52,7 +55,7 @@ class VectorizedHashMapGenerator(
val generatedSchema: String =
s"new org.apache.spark.sql.types.StructType()" +
(groupingKeySchema ++ bufferSchema).map { key =>
- val keyName = ctx.addReferenceMinorObj(key.name)
+ val keyName = ctx.addReferenceObj("keyName", key.name)
key.dataType match {
case d: DecimalType =>
s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType(
@@ -65,7 +68,7 @@ class VectorizedHashMapGenerator(
val generatedAggBufferSchema: String =
s"new org.apache.spark.sql.types.StructType()" +
bufferSchema.map { key =>
- val keyName = ctx.addReferenceMinorObj(key.name)
+ val keyName = ctx.addReferenceObj("keyName", key.name)
key.dataType match {
case d: DecimalType =>
s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType(
@@ -76,10 +79,9 @@ class VectorizedHashMapGenerator(
}.mkString("\n").concat(";")
s"""
- | private org.apache.spark.sql.execution.vectorized.OnHeapColumnVector[] batchVectors;
- | private org.apache.spark.sql.execution.vectorized.OnHeapColumnVector[] bufferVectors;
- | private org.apache.spark.sql.execution.vectorized.ColumnarBatch batch;
- | private org.apache.spark.sql.execution.vectorized.ColumnarBatch aggregateBufferBatch;
+ | private ${classOf[OnHeapColumnVector].getName}[] vectors;
+ | private ${classOf[ColumnarBatch].getName} batch;
+ | private ${classOf[MutableColumnarRow].getName} aggBufferRow;
| private int[] buckets;
| private int capacity = 1 << 16;
| private double loadFactor = 0.5;
@@ -91,19 +93,16 @@ class VectorizedHashMapGenerator(
| $generatedAggBufferSchema
|
| public $generatedClassName() {
- | batchVectors = org.apache.spark.sql.execution.vectorized
- | .OnHeapColumnVector.allocateColumns(capacity, schema);
- | batch = new org.apache.spark.sql.execution.vectorized.ColumnarBatch(
- | schema, batchVectors, capacity);
+ | vectors = ${classOf[OnHeapColumnVector].getName}.allocateColumns(capacity, schema);
+ | batch = new ${classOf[ColumnarBatch].getName}(schema, vectors, capacity);
|
- | bufferVectors = new org.apache.spark.sql.execution.vectorized
- | .OnHeapColumnVector[aggregateBufferSchema.fields().length];
+ | // Generates a projection to return the aggregate buffer only.
+ | ${classOf[OnHeapColumnVector].getName}[] aggBufferVectors =
+ | new ${classOf[OnHeapColumnVector].getName}[aggregateBufferSchema.fields().length];
| for (int i = 0; i < aggregateBufferSchema.fields().length; i++) {
- | bufferVectors[i] = batchVectors[i + ${groupingKeys.length}];
+ | aggBufferVectors[i] = vectors[i + ${groupingKeys.length}];
| }
- | // TODO: Possibly generate this projection in HashAggregate directly
- | aggregateBufferBatch = new org.apache.spark.sql.execution.vectorized.ColumnarBatch(
- | aggregateBufferSchema, bufferVectors, capacity);
+ | aggBufferRow = new ${classOf[MutableColumnarRow].getName}(aggBufferVectors);
|
| buckets = new int[numBuckets];
| java.util.Arrays.fill(buckets, -1);
@@ -114,13 +113,13 @@ class VectorizedHashMapGenerator(
/**
* Generates a method that returns true if the group-by keys exist at a given index in the
- * associated [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we
- * have 2 long group-by keys, the generated function would be of the form:
+ * associated [[org.apache.spark.sql.execution.vectorized.OnHeapColumnVector]]. For instance,
+ * if we have 2 long group-by keys, the generated function would be of the form:
*
* {{{
* private boolean equals(int idx, long agg_key, long agg_key1) {
- * return batchVectors[0].getLong(buckets[idx]) == agg_key &&
- * batchVectors[1].getLong(buckets[idx]) == agg_key1;
+ * return vectors[0].getLong(buckets[idx]) == agg_key &&
+ * vectors[1].getLong(buckets[idx]) == agg_key1;
* }
* }}}
*/
@@ -128,7 +127,7 @@ class VectorizedHashMapGenerator(
def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = {
groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
- s"""(${ctx.genEqual(key.dataType, ctx.getValue(s"batchVectors[$ordinal]", "buckets[idx]",
+ s"""(${ctx.genEqual(key.dataType, ctx.getValue(s"vectors[$ordinal]", "buckets[idx]",
key.dataType), key.name)})"""
}.mkString(" && ")
}
@@ -141,29 +140,35 @@ class VectorizedHashMapGenerator(
}
/**
- * Generates a method that returns a mutable
- * [[org.apache.spark.sql.execution.vectorized.ColumnarRow]] which keeps track of the
+ * Generates a method that returns a
+ * [[org.apache.spark.sql.execution.vectorized.MutableColumnarRow]] which keeps track of the
* aggregate value(s) for a given set of keys. If the corresponding row doesn't exist, the
* generated method adds the corresponding row in the associated
- * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we
+ * [[org.apache.spark.sql.execution.vectorized.OnHeapColumnVector]]. For instance, if we
* have 2 long group-by keys, the generated function would be of the form:
*
* {{{
- * public org.apache.spark.sql.execution.vectorized.ColumnarRow findOrInsert(
- * long agg_key, long agg_key1) {
+ * public MutableColumnarRow findOrInsert(long agg_key, long agg_key1) {
* long h = hash(agg_key, agg_key1);
* int step = 0;
* int idx = (int) h & (numBuckets - 1);
* while (step < maxSteps) {
* // Return bucket index if it's either an empty slot or already contains the key
* if (buckets[idx] == -1) {
- * batchVectors[0].putLong(numRows, agg_key);
- * batchVectors[1].putLong(numRows, agg_key1);
- * batchVectors[2].putLong(numRows, 0);
- * buckets[idx] = numRows++;
- * return batch.getRow(buckets[idx]);
+ * if (numRows < capacity) {
+ * vectors[0].putLong(numRows, agg_key);
+ * vectors[1].putLong(numRows, agg_key1);
+ * vectors[2].putLong(numRows, 0);
+ * buckets[idx] = numRows++;
+ * aggBufferRow.rowId = buckets[idx];
+ * return aggBufferRow;
+ * } else {
+ * // No more space
+ * return null;
+ * }
* } else if (equals(idx, agg_key, agg_key1)) {
- * return batch.getRow(buckets[idx]);
+ * aggBufferRow.rowId = buckets[idx];
+ * return aggBufferRow;
* }
* idx = (idx + 1) & (numBuckets - 1);
* step++;
@@ -177,20 +182,19 @@ class VectorizedHashMapGenerator(
def genCodeToSetKeys(groupingKeys: Seq[Buffer]): Seq[String] = {
groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
- ctx.setValue(s"batchVectors[$ordinal]", "numRows", key.dataType, key.name)
+ ctx.setValue(s"vectors[$ordinal]", "numRows", key.dataType, key.name)
}
}
def genCodeToSetAggBuffers(bufferValues: Seq[Buffer]): Seq[String] = {
bufferValues.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
- ctx.updateColumn(s"batchVectors[${groupingKeys.length + ordinal}]", "numRows", key.dataType,
+ ctx.updateColumn(s"vectors[${groupingKeys.length + ordinal}]", "numRows", key.dataType,
buffVars(ordinal), nullable = true)
}
}
s"""
- |public org.apache.spark.sql.execution.vectorized.ColumnarRow findOrInsert(${
- groupingKeySignature}) {
+ |public ${classOf[MutableColumnarRow].getName} findOrInsert($groupingKeySignature) {
| long h = hash(${groupingKeys.map(_.name).mkString(", ")});
| int step = 0;
| int idx = (int) h & (numBuckets - 1);
@@ -208,15 +212,15 @@ class VectorizedHashMapGenerator(
| ${genCodeToSetAggBuffers(bufferValues).mkString("\n")}
|
| buckets[idx] = numRows++;
- | batch.setNumRows(numRows);
- | aggregateBufferBatch.setNumRows(numRows);
- | return aggregateBufferBatch.getRow(buckets[idx]);
+ | aggBufferRow.rowId = buckets[idx];
+ | return aggBufferRow;
| } else {
| // No more space
| return null;
| }
| } else if (equals(idx, ${groupingKeys.map(_.name).mkString(", ")})) {
- | return aggregateBufferBatch.getRow(buckets[idx]);
+ | aggBufferRow.rowId = buckets[idx];
+ | return aggBufferRow;
| }
| idx = (idx + 1) & (numBuckets - 1);
| step++;
@@ -229,8 +233,8 @@ class VectorizedHashMapGenerator(
protected def generateRowIterator(): String = {
s"""
- |public java.util.Iterator
- | rowIterator() {
+ |public java.util.Iterator<${classOf[InternalRow].getName}> rowIterator() {
+ | batch.setNumRows(numRows);
| return batch.rowIterator();
|}
""".stripMargin
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index 3cafb344ef553..bcd1aa0890ba3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -24,16 +24,16 @@ import scala.collection.JavaConverters._
import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.vector._
-import org.apache.arrow.vector.file._
-import org.apache.arrow.vector.schema.ArrowRecordBatch
+import org.apache.arrow.vector.ipc.{ArrowFileReader, ArrowFileWriter}
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel
import org.apache.spark.TaskContext
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
import org.apache.spark.util.Utils
@@ -86,13 +86,9 @@ private[sql] object ArrowConverters {
val root = VectorSchemaRoot.create(arrowSchema, allocator)
val arrowWriter = ArrowWriter.create(root)
- var closed = false
-
context.addTaskCompletionListener { _ =>
- if (!closed) {
- root.close()
- allocator.close()
- }
+ root.close()
+ allocator.close()
}
new Iterator[ArrowPayload] {
@@ -100,7 +96,6 @@ private[sql] object ArrowConverters {
override def hasNext: Boolean = rowIter.hasNext || {
root.close()
allocator.close()
- closed = true
false
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
index e4af4f65da127..22b63513548fe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
@@ -46,17 +46,19 @@ object ArrowWriter {
private def createFieldWriter(vector: ValueVector): ArrowFieldWriter = {
val field = vector.getField()
(ArrowUtils.fromArrowField(field), vector) match {
- case (BooleanType, vector: NullableBitVector) => new BooleanWriter(vector)
- case (ByteType, vector: NullableTinyIntVector) => new ByteWriter(vector)
- case (ShortType, vector: NullableSmallIntVector) => new ShortWriter(vector)
- case (IntegerType, vector: NullableIntVector) => new IntegerWriter(vector)
- case (LongType, vector: NullableBigIntVector) => new LongWriter(vector)
- case (FloatType, vector: NullableFloat4Vector) => new FloatWriter(vector)
- case (DoubleType, vector: NullableFloat8Vector) => new DoubleWriter(vector)
- case (StringType, vector: NullableVarCharVector) => new StringWriter(vector)
- case (BinaryType, vector: NullableVarBinaryVector) => new BinaryWriter(vector)
- case (DateType, vector: NullableDateDayVector) => new DateWriter(vector)
- case (TimestampType, vector: NullableTimeStampMicroTZVector) => new TimestampWriter(vector)
+ case (BooleanType, vector: BitVector) => new BooleanWriter(vector)
+ case (ByteType, vector: TinyIntVector) => new ByteWriter(vector)
+ case (ShortType, vector: SmallIntVector) => new ShortWriter(vector)
+ case (IntegerType, vector: IntVector) => new IntegerWriter(vector)
+ case (LongType, vector: BigIntVector) => new LongWriter(vector)
+ case (FloatType, vector: Float4Vector) => new FloatWriter(vector)
+ case (DoubleType, vector: Float8Vector) => new DoubleWriter(vector)
+ case (DecimalType.Fixed(precision, scale), vector: DecimalVector) =>
+ new DecimalWriter(vector, precision, scale)
+ case (StringType, vector: VarCharVector) => new StringWriter(vector)
+ case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector)
+ case (DateType, vector: DateDayVector) => new DateWriter(vector)
+ case (TimestampType, vector: TimeStampMicroTZVector) => new TimestampWriter(vector)
case (ArrayType(_, _), vector: ListVector) =>
val elementVector = createFieldWriter(vector.getDataVector())
new ArrayWriter(vector, elementVector)
@@ -103,7 +105,6 @@ class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) {
private[arrow] abstract class ArrowFieldWriter {
def valueVector: ValueVector
- def valueMutator: ValueVector.Mutator
def name: String = valueVector.getField().getName()
def dataType: DataType = ArrowUtils.fromArrowField(valueVector.getField())
@@ -124,161 +125,163 @@ private[arrow] abstract class ArrowFieldWriter {
}
def finish(): Unit = {
- valueMutator.setValueCount(count)
+ valueVector.setValueCount(count)
}
def reset(): Unit = {
- valueMutator.reset()
+ // TODO: reset() should be in a common interface
+ valueVector match {
+ case fixedWidthVector: BaseFixedWidthVector => fixedWidthVector.reset()
+ case variableWidthVector: BaseVariableWidthVector => variableWidthVector.reset()
+ case _ =>
+ }
count = 0
}
}
-private[arrow] class BooleanWriter(val valueVector: NullableBitVector) extends ArrowFieldWriter {
-
- override def valueMutator: NullableBitVector#Mutator = valueVector.getMutator()
+private[arrow] class BooleanWriter(val valueVector: BitVector) extends ArrowFieldWriter {
override def setNull(): Unit = {
- valueMutator.setNull(count)
+ valueVector.setNull(count)
}
override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
- valueMutator.setSafe(count, if (input.getBoolean(ordinal)) 1 else 0)
+ valueVector.setSafe(count, if (input.getBoolean(ordinal)) 1 else 0)
}
}
-private[arrow] class ByteWriter(val valueVector: NullableTinyIntVector) extends ArrowFieldWriter {
-
- override def valueMutator: NullableTinyIntVector#Mutator = valueVector.getMutator()
+private[arrow] class ByteWriter(val valueVector: TinyIntVector) extends ArrowFieldWriter {
override def setNull(): Unit = {
- valueMutator.setNull(count)
+ valueVector.setNull(count)
}
override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
- valueMutator.setSafe(count, input.getByte(ordinal))
+ valueVector.setSafe(count, input.getByte(ordinal))
}
}
-private[arrow] class ShortWriter(val valueVector: NullableSmallIntVector) extends ArrowFieldWriter {
-
- override def valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator()
+private[arrow] class ShortWriter(val valueVector: SmallIntVector) extends ArrowFieldWriter {
override def setNull(): Unit = {
- valueMutator.setNull(count)
+ valueVector.setNull(count)
}
override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
- valueMutator.setSafe(count, input.getShort(ordinal))
+ valueVector.setSafe(count, input.getShort(ordinal))
}
}
-private[arrow] class IntegerWriter(val valueVector: NullableIntVector) extends ArrowFieldWriter {
-
- override def valueMutator: NullableIntVector#Mutator = valueVector.getMutator()
+private[arrow] class IntegerWriter(val valueVector: IntVector) extends ArrowFieldWriter {
override def setNull(): Unit = {
- valueMutator.setNull(count)
+ valueVector.setNull(count)
}
override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
- valueMutator.setSafe(count, input.getInt(ordinal))
+ valueVector.setSafe(count, input.getInt(ordinal))
}
}
-private[arrow] class LongWriter(val valueVector: NullableBigIntVector) extends ArrowFieldWriter {
-
- override def valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator()
+private[arrow] class LongWriter(val valueVector: BigIntVector) extends ArrowFieldWriter {
override def setNull(): Unit = {
- valueMutator.setNull(count)
+ valueVector.setNull(count)
}
override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
- valueMutator.setSafe(count, input.getLong(ordinal))
+ valueVector.setSafe(count, input.getLong(ordinal))
}
}
-private[arrow] class FloatWriter(val valueVector: NullableFloat4Vector) extends ArrowFieldWriter {
-
- override def valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator()
+private[arrow] class FloatWriter(val valueVector: Float4Vector) extends ArrowFieldWriter {
override def setNull(): Unit = {
- valueMutator.setNull(count)
+ valueVector.setNull(count)
}
override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
- valueMutator.setSafe(count, input.getFloat(ordinal))
+ valueVector.setSafe(count, input.getFloat(ordinal))
}
}
-private[arrow] class DoubleWriter(val valueVector: NullableFloat8Vector) extends ArrowFieldWriter {
-
- override def valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator()
+private[arrow] class DoubleWriter(val valueVector: Float8Vector) extends ArrowFieldWriter {
override def setNull(): Unit = {
- valueMutator.setNull(count)
+ valueVector.setNull(count)
}
override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
- valueMutator.setSafe(count, input.getDouble(ordinal))
+ valueVector.setSafe(count, input.getDouble(ordinal))
}
}
-private[arrow] class StringWriter(val valueVector: NullableVarCharVector) extends ArrowFieldWriter {
+private[arrow] class DecimalWriter(
+ val valueVector: DecimalVector,
+ precision: Int,
+ scale: Int) extends ArrowFieldWriter {
- override def valueMutator: NullableVarCharVector#Mutator = valueVector.getMutator()
+ override def setNull(): Unit = {
+ valueVector.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ val decimal = input.getDecimal(ordinal, precision, scale)
+ if (decimal.changePrecision(precision, scale)) {
+ valueVector.setSafe(count, decimal.toJavaBigDecimal)
+ } else {
+ setNull()
+ }
+ }
+}
+
+private[arrow] class StringWriter(val valueVector: VarCharVector) extends ArrowFieldWriter {
override def setNull(): Unit = {
- valueMutator.setNull(count)
+ valueVector.setNull(count)
}
override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
val utf8 = input.getUTF8String(ordinal)
val utf8ByteBuffer = utf8.getByteBuffer
// todo: for off-heap UTF8String, how to pass in to arrow without copy?
- valueMutator.setSafe(count, utf8ByteBuffer, utf8ByteBuffer.position(), utf8.numBytes())
+ valueVector.setSafe(count, utf8ByteBuffer, utf8ByteBuffer.position(), utf8.numBytes())
}
}
private[arrow] class BinaryWriter(
- val valueVector: NullableVarBinaryVector) extends ArrowFieldWriter {
-
- override def valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator()
+ val valueVector: VarBinaryVector) extends ArrowFieldWriter {
override def setNull(): Unit = {
- valueMutator.setNull(count)
+ valueVector.setNull(count)
}
override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
val bytes = input.getBinary(ordinal)
- valueMutator.setSafe(count, bytes, 0, bytes.length)
+ valueVector.setSafe(count, bytes, 0, bytes.length)
}
}
-private[arrow] class DateWriter(val valueVector: NullableDateDayVector) extends ArrowFieldWriter {
-
- override def valueMutator: NullableDateDayVector#Mutator = valueVector.getMutator()
+private[arrow] class DateWriter(val valueVector: DateDayVector) extends ArrowFieldWriter {
override def setNull(): Unit = {
- valueMutator.setNull(count)
+ valueVector.setNull(count)
}
override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
- valueMutator.setSafe(count, input.getInt(ordinal))
+ valueVector.setSafe(count, input.getInt(ordinal))
}
}
private[arrow] class TimestampWriter(
- val valueVector: NullableTimeStampMicroTZVector) extends ArrowFieldWriter {
-
- override def valueMutator: NullableTimeStampMicroTZVector#Mutator = valueVector.getMutator()
+ val valueVector: TimeStampMicroTZVector) extends ArrowFieldWriter {
override def setNull(): Unit = {
- valueMutator.setNull(count)
+ valueVector.setNull(count)
}
override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
- valueMutator.setSafe(count, input.getLong(ordinal))
+ valueVector.setSafe(count, input.getLong(ordinal))
}
}
@@ -286,20 +289,18 @@ private[arrow] class ArrayWriter(
val valueVector: ListVector,
val elementWriter: ArrowFieldWriter) extends ArrowFieldWriter {
- override def valueMutator: ListVector#Mutator = valueVector.getMutator()
-
override def setNull(): Unit = {
}
override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
val array = input.getArray(ordinal)
var i = 0
- valueMutator.startNewValue(count)
+ valueVector.startNewValue(count)
while (i < array.numElements()) {
elementWriter.write(array, i)
i += 1
}
- valueMutator.endValue(count, array.numElements())
+ valueVector.endValue(count, array.numElements())
}
override def finish(): Unit = {
@@ -317,8 +318,6 @@ private[arrow] class StructWriter(
val valueVector: NullableMapVector,
children: Array[ArrowFieldWriter]) extends ArrowFieldWriter {
- override def valueMutator: NullableMapVector#Mutator = valueVector.getMutator()
-
override def setNull(): Unit = {
var i = 0
while (i < children.length) {
@@ -326,7 +325,7 @@ private[arrow] class StructWriter(
children(i).count += 1
i += 1
}
- valueMutator.setNull(count)
+ valueVector.setNull(count)
}
override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
@@ -336,7 +335,7 @@ private[arrow] class StructWriter(
children(i).write(struct, i)
i += 1
}
- valueMutator.setIndexDefined(count)
+ valueVector.setIndexDefined(count)
}
override def finish(): Unit = {
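The `ArrowWriter` changes above track the Arrow 0.8 API, where the separate `ValueVector.Mutator` disappears and the vector itself exposes `setSafe`, `setNull` and `setValueCount`. The sketch below is a dependency-free illustration of that writer shape; `ToyIntVector` and `ToyIntWriter` are hypothetical stand-ins, not Arrow's or Spark's classes.

import scala.collection.mutable

// Toy vector: owns both the values and the validity ("null") information, like Arrow 0.8 vectors.
final class ToyIntVector {
  private val values = mutable.ArrayBuffer.empty[Option[Int]]
  private var valueCount = 0
  private def ensure(index: Int): Unit = while (values.size <= index) values += None
  def setSafe(index: Int, value: Int): Unit = { ensure(index); values(index) = Some(value) }
  def setNull(index: Int): Unit = { ensure(index); values(index) = None }
  def setValueCount(count: Int): Unit = valueCount = count
  def reset(): Unit = { values.clear(); valueCount = 0 }
  override def toString: String =
    values.take(valueCount).map(_.map(_.toString).getOrElse("null")).mkString("[", ", ", "]")
}

// Field writer in the post-upgrade shape: it talks to the vector directly, no mutator in between.
abstract class ToyFieldWriter {
  protected var count = 0
  def setNull(): Unit
  def setValue(value: Any): Unit
  final def write(value: Option[Any]): Unit = {
    value match {
      case Some(v) => setValue(v)
      case None    => setNull()
    }
    count += 1
  }
  def finish(): Unit
}

final class ToyIntWriter(vector: ToyIntVector) extends ToyFieldWriter {
  override def setNull(): Unit = vector.setNull(count)
  override def setValue(value: Any): Unit = vector.setSafe(count, value.asInstanceOf[Int])
  override def finish(): Unit = vector.setValueCount(count)
}

object ToyWriterDemo extends App {
  val vector = new ToyIntVector
  val writer = new ToyIntWriter(vector)
  Seq(Some(1), None, Some(3)).foreach(writer.write)
  writer.finish()
  println(vector)  // [1, null, 3]
}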
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index c9a15147e30d0..a15a8d11aa2a0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -279,29 +279,30 @@ case class SampleExec(
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val numOutput = metricTerm(ctx, "numOutputRows")
- val sampler = ctx.freshName("sampler")
if (withReplacement) {
val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName
val initSampler = ctx.freshName("initSampler")
- val initSamplerFuncName = ctx.addNewFunction(initSampler,
- s"""
- | private void $initSampler() {
- | $sampler = new $samplerClass($upperBound - $lowerBound, false);
- | java.util.Random random = new java.util.Random(${seed}L);
- | long randomSeed = random.nextLong();
- | int loopCount = 0;
- | while (loopCount < partitionIndex) {
- | randomSeed = random.nextLong();
- | loopCount += 1;
- | }
- | $sampler.setSeed(randomSeed);
- | }
- """.stripMargin.trim)
-
- ctx.addMutableState(s"$samplerClass", sampler,
- s"$initSamplerFuncName();")
+ // Inline mutable state since there are not many Sample operations in a task
+ val sampler = ctx.addMutableState(s"$samplerClass", "sampleReplace",
+ v => {
+ val initSamplerFuncName = ctx.addNewFunction(initSampler,
+ s"""
+ | private void $initSampler() {
+ | $v = new $samplerClass($upperBound - $lowerBound, false);
+ | java.util.Random random = new java.util.Random(${seed}L);
+ | long randomSeed = random.nextLong();
+ | int loopCount = 0;
+ | while (loopCount < partitionIndex) {
+ | randomSeed = random.nextLong();
+ | loopCount += 1;
+ | }
+ | $v.setSeed(randomSeed);
+ | }
+ """.stripMargin.trim)
+ s"$initSamplerFuncName();"
+ }, forceInline = true)
val samplingCount = ctx.freshName("samplingCount")
s"""
@@ -313,10 +314,10 @@ case class SampleExec(
""".stripMargin.trim
} else {
val samplerClass = classOf[BernoulliCellSampler[UnsafeRow]].getName
- ctx.addMutableState(s"$samplerClass", sampler,
- s"""
- | $sampler = new $samplerClass($lowerBound, $upperBound, false);
- | $sampler.setSeed(${seed}L + partitionIndex);
+ val sampler = ctx.addMutableState(s"$samplerClass", "sampler",
+ v => s"""
+ | $v = new $samplerClass($lowerBound, $upperBound, false);
+ | $v.setSeed(${seed}L + partitionIndex);
""".stripMargin.trim)
s"""
@@ -363,20 +364,18 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
protected override def doProduce(ctx: CodegenContext): String = {
val numOutput = metricTerm(ctx, "numOutputRows")
- val initTerm = ctx.freshName("initRange")
- ctx.addMutableState(ctx.JAVA_BOOLEAN, initTerm, s"$initTerm = false;")
- val number = ctx.freshName("number")
- ctx.addMutableState(ctx.JAVA_LONG, number, s"$number = 0L;")
+ val initTerm = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initRange")
+ val number = ctx.addMutableState(ctx.JAVA_LONG, "number")
val value = ctx.freshName("value")
val ev = ExprCode("", "false", value)
val BigInt = classOf[java.math.BigInteger].getName
- val taskContext = ctx.freshName("taskContext")
- ctx.addMutableState("TaskContext", taskContext, s"$taskContext = TaskContext.get();")
- val inputMetrics = ctx.freshName("inputMetrics")
- ctx.addMutableState("InputMetrics", inputMetrics,
- s"$inputMetrics = $taskContext.taskMetrics().inputMetrics();")
+ // Inline mutable state since there are not many Range operations in a task
+ val taskContext = ctx.addMutableState("TaskContext", "taskContext",
+ v => s"$v = TaskContext.get();", forceInline = true)
+ val inputMetrics = ctx.addMutableState("InputMetrics", "inputMetrics",
+ v => s"$v = $taskContext.taskMetrics().inputMetrics();", forceInline = true)
// In order to periodically update the metrics without inflicting performance penalty, this
// operator produces elements in batches. After a batch is complete, the metrics are updated
@@ -386,12 +385,10 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
// the metrics.
// Once number == batchEnd, it's time to progress to the next batch.
- val batchEnd = ctx.freshName("batchEnd")
- ctx.addMutableState(ctx.JAVA_LONG, batchEnd, s"$batchEnd = 0;")
+ val batchEnd = ctx.addMutableState(ctx.JAVA_LONG, "batchEnd")
// How many values should still be generated by this range operator.
- val numElementsTodo = ctx.freshName("numElementsTodo")
- ctx.addMutableState(ctx.JAVA_LONG, numElementsTodo, s"$numElementsTodo = 0L;")
+ val numElementsTodo = ctx.addMutableState(ctx.JAVA_LONG, "numElementsTodo")
// How many values should be generated in the next batch.
val nextBatchTodo = ctx.freshName("nextBatchTodo")
@@ -440,10 +437,6 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
| }
""".stripMargin)
- val input = ctx.freshName("input")
- // Right now, Range is only used when there is one upstream.
- ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
-
val localIdx = ctx.freshName("localIdx")
val localEnd = ctx.freshName("localEnd")
val range = ctx.freshName("range")
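The `RangeExec` hunk above keeps the batching scheme its comment describes: values are produced in batches and the metrics are touched once per completed batch rather than per row. A small self-contained sketch of that idea, using hypothetical names (`ToyMetric`, `produceRange`):

final class ToyMetric {
  private var value = 0L
  def add(v: Long): Unit = value += v
  def get: Long = value
}

object RangeBatchingSketch {
  def produceRange(start: Long, end: Long, batchSize: Long, numOutputRows: ToyMetric): Seq[Long] = {
    val out = Seq.newBuilder[Long]
    var number = start
    var batchEnd = start
    while (number < end) {
      // Decide how many values the next batch should produce.
      val nextBatchTodo = math.min(batchSize, end - batchEnd)
      batchEnd += nextBatchTodo
      while (number < batchEnd) {       // tight inner loop: no metric updates per row
        out += number
        number += 1
      }
      numOutputRows.add(nextBatchTodo)  // metrics updated once per completed batch
    }
    out.result()
  }
}

// Usage: RangeBatchingSketch.produceRange(0L, 10L, 4L, new ToyMetric) yields 0..9
// while updating the metric three times (batches of 4, 4 and 2).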
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
index ff5dd707f0b38..4f28eeb725cbb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
@@ -70,7 +70,6 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
val ctx = newCodeGenContext()
val numFields = columnTypes.size
val (initializeAccessors, extractors) = columnTypes.zipWithIndex.map { case (dt, index) =>
- val accessorName = ctx.freshName("accessor")
val accessorCls = dt match {
case NullType => classOf[NullColumnAccessor].getName
case BooleanType => classOf[BooleanColumnAccessor].getName
@@ -89,7 +88,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
case array: ArrayType => classOf[ArrayColumnAccessor].getName
case t: MapType => classOf[MapColumnAccessor].getName
}
- ctx.addMutableState(accessorCls, accessorName)
+ val accessorName = ctx.addMutableState(accessorCls, "accessor")
val createCode = dt match {
case t if ctx.isPrimitiveType(dt) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index a1c62a729900e..51928d914841e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -37,8 +37,10 @@ object InMemoryRelation {
batchSize: Int,
storageLevel: StorageLevel,
child: SparkPlan,
- tableName: Option[String]): InMemoryRelation =
- new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)()
+ tableName: Option[String],
+ statsOfPlanToCache: Statistics): InMemoryRelation =
+ new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)(
+ statsOfPlanToCache = statsOfPlanToCache)
}
@@ -60,7 +62,8 @@ case class InMemoryRelation(
@transient child: SparkPlan,
tableName: Option[String])(
@transient var _cachedColumnBuffers: RDD[CachedBatch] = null,
- val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator)
+ val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator,
+ statsOfPlanToCache: Statistics = null)
extends logical.LeafNode with MultiInstanceRelation {
override protected def innerChildren: Seq[SparkPlan] = Seq(child)
@@ -71,9 +74,8 @@ case class InMemoryRelation(
override def computeStats(): Statistics = {
if (batchStats.value == 0L) {
- // Underlying columnar RDD hasn't been materialized, no useful statistics information
- // available, return the default statistics.
- Statistics(sizeInBytes = child.sqlContext.conf.defaultSizeInBytes)
+ // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache
+ statsOfPlanToCache
} else {
Statistics(sizeInBytes = batchStats.value.longValue)
}
@@ -142,7 +144,7 @@ case class InMemoryRelation(
def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
InMemoryRelation(
newOutput, useCompression, batchSize, storageLevel, child, tableName)(
- _cachedColumnBuffers, batchStats)
+ _cachedColumnBuffers, batchStats, statsOfPlanToCache)
}
override def newInstance(): this.type = {
@@ -154,11 +156,12 @@ case class InMemoryRelation(
child,
tableName)(
_cachedColumnBuffers,
- batchStats).asInstanceOf[this.type]
+ batchStats,
+ statsOfPlanToCache).asInstanceOf[this.type]
}
def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers
override protected def otherCopyArgs: Seq[AnyRef] =
- Seq(_cachedColumnBuffers, batchStats)
+ Seq(_cachedColumnBuffers, batchStats, statsOfPlanToCache)
}
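The `InMemoryRelation` change above threads the statistics of the plan being cached into the relation, so that before the cached batches are materialized `computeStats` can fall back to those statistics instead of the blanket `defaultSizeInBytes`. A minimal sketch of that fallback, with illustrative names only:

final case class ToyStats(sizeInBytes: BigInt, rowCount: Option[BigInt] = None)

final class ToyCachedRelation(statsOfPlanToCache: ToyStats) {
  // Filled in by the first action that materializes the cached batches.
  private var materializedBytes: Long = 0L

  def onBatchesMaterialized(bytes: Long): Unit = materializedBytes = bytes

  def computeStats(): ToyStats =
    if (materializedBytes == 0L) statsOfPlanToCache      // not materialized yet: reuse plan stats
    else ToyStats(sizeInBytes = BigInt(materializedBytes))
}

object ToyCachedRelationDemo extends App {
  val relation = new ToyCachedRelation(ToyStats(sizeInBytes = BigInt(1024), rowCount = Some(100)))
  println(relation.computeStats())      // plan-to-cache stats before materialization
  relation.onBatchesMaterialized(4096)
  println(relation.computeStats())      // accumulator-based size afterwards
}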
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 3e73393b12850..933b9753faa61 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partition
import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec}
import org.apache.spark.sql.execution.vectorized._
import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
case class InMemoryTableScanExec(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala
index 2f09757aa341c..341ade1a5c613 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala
@@ -35,7 +35,7 @@ private[columnar] trait NullableColumnAccessor extends ColumnAccessor {
nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1
pos = 0
- underlyingBuffer.position(underlyingBuffer.position + 4 + nullCount * 4)
+ underlyingBuffer.position(underlyingBuffer.position() + 4 + nullCount * 4)
super.initialize()
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala
index bf00ad997c76e..79dcf3a6105ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala
@@ -112,7 +112,7 @@ private[columnar] case object PassThrough extends CompressionScheme {
var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else capacity
var pos = 0
var seenNulls = 0
- var bufferPos = buffer.position
+ var bufferPos = buffer.position()
while (pos < capacity) {
if (pos != nextNullIndex) {
val len = nextNullIndex - pos
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
index e3bb4d357b395..1122522ccb4cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
@@ -143,7 +143,12 @@ case class AnalyzeColumnCommand(
val percentilesRow = new QueryExecution(sparkSession, Aggregate(Nil, namedExprs, relation))
.executedPlan.executeTake(1).head
attrsToGenHistogram.zipWithIndex.foreach { case (attr, i) =>
- attributePercentiles += attr -> percentilesRow.getArray(i)
+ val percentiles = percentilesRow.getArray(i)
+ // When there are no non-null values, `percentiles` is null. In that case, there is no
+ // need to generate a histogram.
+ if (percentiles != null) {
+ attributePercentiles += attr -> percentiles
+ }
}
}
AttributeMap(attributePercentiles.toSeq)
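Reviewer note: the new guard matters when histogram generation is enabled and a column holds only NULLs, since the percentile aggregate then returns null. A hedged end-to-end illustration (the config key `spark.sql.statistics.histogram.enabled` and the `DESCRIBE EXTENDED table column` syntax are assumptions about the surrounding feature, not part of this diff):

```scala
import org.apache.spark.sql.SparkSession

object AnalyzeAllNullColumn {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("analyze-nulls").getOrCreate()

    // Assumed config key for the histogram feature this command supports.
    spark.conf.set("spark.sql.statistics.histogram.enabled", "true")

    spark.sql("CREATE TABLE IF NOT EXISTS t_nulls (c INT) USING parquet")
    spark.sql("INSERT OVERWRITE TABLE t_nulls SELECT CAST(NULL AS INT) FROM range(10)")

    // The percentile aggregate over an all-NULL column returns null; with the guard
    // above the command simply skips the histogram for `c` instead of failing.
    spark.sql("ANALYZE TABLE t_nulls COMPUTE STATISTICS FOR COLUMNS c")
    spark.sql("DESCRIBE EXTENDED t_nulls c").show(truncate = false)

    spark.stop()
  }
}
```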
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala
index 1a0d67fc71fbc..c27048626c8eb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala
@@ -116,8 +116,8 @@ object CommandUtils extends Logging {
oldStats: Option[CatalogStatistics],
newTotalSize: BigInt,
newRowCount: Option[BigInt]): Option[CatalogStatistics] = {
- val oldTotalSize = oldStats.map(_.sizeInBytes.toLong).getOrElse(-1L)
- val oldRowCount = oldStats.flatMap(_.rowCount.map(_.toLong)).getOrElse(-1L)
+ val oldTotalSize = oldStats.map(_.sizeInBytes).getOrElse(BigInt(-1))
+ val oldRowCount = oldStats.flatMap(_.rowCount).getOrElse(BigInt(-1))
var newStats: Option[CatalogStatistics] = None
if (newTotalSize >= 0 && newTotalSize != oldTotalSize) {
newStats = Some(CatalogStatistics(sizeInBytes = newTotalSize))
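Reviewer note: keeping the old statistics as `BigInt` instead of converting to `Long` avoids silent wrap-around for values that exceed `Long.MaxValue`, which would make the comparisons here unreliable. A tiny self-contained illustration of the failure mode the old code was exposed to:

```scala
// Why compare as BigInt: converting an oversized BigInt to Long wraps around,
// so two genuinely different sizes could compare as equal (or vice versa).
object BigIntVsLong {
  def main(args: Array[String]): Unit = {
    val hugeSize: BigInt = BigInt(Long.MaxValue) + 1

    // Lossy path (old code): toLong keeps only the low 64 bits and wraps to Long.MinValue.
    println(hugeSize.toLong) // -9223372036854775808

    // Lossless path (new code): BigInt comparison preserves the real value.
    println(hugeSize > BigInt(Long.MaxValue)) // true
  }
}
```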
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala
index 2cf06982e25f6..e56f8105fc9a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala
@@ -20,30 +20,32 @@ package org.apache.spark.sql.execution.command
import org.apache.hadoop.conf.Configuration
import org.apache.spark.SparkContext
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
+import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker
+import org.apache.spark.sql.execution.datasources.FileFormatWriter
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.util.SerializableConfiguration
-
/**
- * A special `RunnableCommand` which writes data out and updates metrics.
+ * A special `Command` which writes data out and updates metrics.
*/
-trait DataWritingCommand extends RunnableCommand {
-
+trait DataWritingCommand extends Command {
/**
* The input query plan that produces the data to be written.
+ * IMPORTANT: the input query plan MUST be analyzed, so that we can carry its output columns
+ * to [[FileFormatWriter]].
*/
def query: LogicalPlan
- // We make the input `query` an inner child instead of a child in order to hide it from the
- // optimizer. This is because optimizer may not preserve the output schema names' case, and we
- // have to keep the original analyzed plan here so that we can pass the corrected schema to the
- // writer. The schema of analyzed plan is what user expects(or specifies), so we should respect
- // it when writing.
- override protected def innerChildren: Seq[LogicalPlan] = query :: Nil
+ override final def children: Seq[LogicalPlan] = query :: Nil
- override lazy val metrics: Map[String, SQLMetric] = {
+ // Output columns of the analyzed input query plan
+ def outputColumns: Seq[Attribute]
+
+ lazy val metrics: Map[String, SQLMetric] = {
val sparkContext = SparkContext.getActive.get
Map(
"numFiles" -> SQLMetrics.createMetric(sparkContext, "number of written files"),
@@ -57,4 +59,6 @@ trait DataWritingCommand extends RunnableCommand {
val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
new BasicWriteJobStatsTracker(serializableHadoopConf, metrics)
}
+
+ def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row]
}
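Reviewer note: with the trait now extending `Command` and exposing `run(SparkSession, SparkPlan)` plus `outputColumns`, a concrete write command carries the analyzed query and hands its output attributes to the writer. A hypothetical minimal implementation, only to show the shape of the contract (the class name and its behaviour are invented; it compiles only against Spark's internal APIs):

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.DataWritingCommand

// Hypothetical command: counts the rows produced by the physical child instead of
// actually writing them anywhere.
case class CountOnlyWriteCommand(
    query: LogicalPlan,            // must be the analyzed plan, per the contract above
    outputColumns: Seq[Attribute]) // output columns of that analyzed plan
  extends DataWritingCommand {

  override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
    // `child` is the planned version of `query`; a real command would hand it to
    // FileFormatWriter.write(...) together with `outputColumns`.
    val numRows = child.execute().count()
    println(s"would have written $numRows rows for columns ${outputColumns.map(_.name)}")
    Seq.empty[Row]
  }
}
```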
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala
index 9e3519073303c..1dc24b3d221cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala
@@ -67,8 +67,7 @@ case class InsertIntoDataSourceDirCommand(
val saveMode = if (overwrite) SaveMode.Overwrite else SaveMode.ErrorIfExists
try {
- sparkSession.sessionState.executePlan(dataSource.planForWriting(saveMode, query))
- dataSource.writeAndRead(saveMode, query)
+ sparkSession.sessionState.executePlan(dataSource.planForWriting(saveMode, query)).toRdd
} catch {
case ex: AnalysisException =>
logError(s"Failed to write to directory " + storage.locationUri.toString, ex)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
index e28b5eb2e2a2b..2cc0e38adc2ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
-import org.apache.spark.sql.execution.LeafExecNode
+import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan}
import org.apache.spark.sql.execution.debug._
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata}
@@ -87,6 +87,42 @@ case class ExecutedCommandExec(cmd: RunnableCommand) extends LeafExecNode {
}
}
+/**
+ * A physical operator that executes the run method of a `DataWritingCommand` and
+ * saves the result to prevent multiple executions.
+ *
+ * @param cmd the `DataWritingCommand` this operator will run.
+ * @param child the physical plan child run by the `DataWritingCommand`.
+ */
+case class DataWritingCommandExec(cmd: DataWritingCommand, child: SparkPlan)
+ extends SparkPlan {
+
+ override lazy val metrics: Map[String, SQLMetric] = cmd.metrics
+
+ protected[sql] lazy val sideEffectResult: Seq[InternalRow] = {
+ val converter = CatalystTypeConverters.createToCatalystConverter(schema)
+ val rows = cmd.run(sqlContext.sparkSession, child)
+
+ rows.map(converter(_).asInstanceOf[InternalRow])
+ }
+
+ override def children: Seq[SparkPlan] = child :: Nil
+
+ override def output: Seq[Attribute] = cmd.output
+
+ override def nodeName: String = "Execute " + cmd.nodeName
+
+ override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray
+
+ override def executeToIterator: Iterator[InternalRow] = sideEffectResult.toIterator
+
+ override def executeTake(limit: Int): Array[InternalRow] = sideEffectResult.take(limit).toArray
+
+ protected override def doExecute(): RDD[InternalRow] = {
+ sqlContext.sparkContext.parallelize(sideEffectResult, 1)
+ }
+}
+
/**
* An explain command for users to see how a command will be executed.
*
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index 568567aa8ea88..0142f17ce62e2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -203,14 +203,20 @@ case class DropTableCommand(
case _ =>
}
}
- try {
- sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName))
- } catch {
- case _: NoSuchTableException if ifExists =>
- case NonFatal(e) => log.warn(e.toString, e)
+
+ if (catalog.isTemporaryTable(tableName) || catalog.tableExists(tableName)) {
+ try {
+ sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName))
+ } catch {
+ case NonFatal(e) => log.warn(e.toString, e)
+ }
+ catalog.refreshTable(tableName)
+ catalog.dropTable(tableName, ifExists, purge)
+ } else if (ifExists) {
+ // no-op
+ } else {
+ throw new AnalysisException(s"Table or view not found: ${tableName.identifier}")
}
- catalog.refreshTable(tableName)
- catalog.dropTable(tableName, ifExists, purge)
Seq.empty[Row]
}
}
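Reviewer note: the rewritten `DropTableCommand` only touches the cache when the table or temporary view actually exists, so a missing table with `IF EXISTS` is now a clean no-op rather than relying on a swallowed `NoSuchTableException`. A quick behavioral check (the table name is arbitrary):

```scala
import org.apache.spark.sql.{AnalysisException, SparkSession}

object DropTableBehavior {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("drop-table").getOrCreate()

    // No table named `missing` exists:
    spark.sql("DROP TABLE IF EXISTS missing") // no-op, no exception

    try {
      spark.sql("DROP TABLE missing") // still fails, now with an explicit message
    } catch {
      case e: AnalysisException =>
        println(e.getMessage) // "Table or view not found: missing"
    }

    spark.stop()
  }
}
```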
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index c9f6e571ddab3..e400975f19708 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -40,6 +40,7 @@ import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils
import org.apache.spark.util.Utils
@@ -190,7 +191,7 @@ case class AlterTableAddColumnsCommand(
colsToAdd: Seq[StructField]) extends RunnableCommand {
override def run(sparkSession: SparkSession): Seq[Row] = {
val catalog = sparkSession.sessionState.catalog
- val catalogTable = verifyAlterTableAddColumn(catalog, table)
+ val catalogTable = verifyAlterTableAddColumn(sparkSession.sessionState.conf, catalog, table)
try {
sparkSession.catalog.uncacheTable(table.quotedString)
@@ -216,6 +217,7 @@ case class AlterTableAddColumnsCommand(
* For datasource table, it currently only supports parquet, json, csv.
*/
private def verifyAlterTableAddColumn(
+ conf: SQLConf,
catalog: SessionCatalog,
table: TableIdentifier): CatalogTable = {
val catalogTable = catalog.getTempViewOrPermanentTableMetadata(table)
@@ -229,7 +231,7 @@ case class AlterTableAddColumnsCommand(
}
if (DDLUtils.isDatasourceTable(catalogTable)) {
- DataSource.lookupDataSource(catalogTable.provider.get).newInstance() match {
+ DataSource.lookupDataSource(catalogTable.provider.get, conf).newInstance() match {
// For datasource table, this command can only support the following File format.
// TextFileFormat only default to one column "value"
// Hive type is already considered as hive serde table, so the logic will not
@@ -340,7 +342,7 @@ case class LoadDataCommand(
uri
} else {
val uri = new URI(path)
- if (uri.getScheme() != null && uri.getAuthority() != null) {
+ val hdfsUri = if (uri.getScheme() != null && uri.getAuthority() != null) {
uri
} else {
// Follow Hive's behavior:
@@ -380,6 +382,13 @@ case class LoadDataCommand(
}
new URI(scheme, authority, absolutePath, uri.getQuery(), uri.getFragment())
}
+ val hadoopConf = sparkSession.sessionState.newHadoopConf()
+ val srcPath = new Path(hdfsUri)
+ val fs = srcPath.getFileSystem(hadoopConf)
+ if (!fs.exists(srcPath)) {
+ throw new AnalysisException(s"LOAD DATA input path does not exist: $path")
+ }
+ hdfsUri
}
if (partition.nonEmpty) {
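Reviewer note: the added filesystem check makes `LOAD DATA` fail fast with an `AnalysisException` when the input path does not exist, instead of failing later during the move. A hedged usage example (the target table is illustrative and the path is deliberately nonexistent; `LOAD DATA` assumes a Hive-enabled session):

```scala
import org.apache.spark.sql.{AnalysisException, SparkSession}

object LoadDataMissingPath {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]").appName("load-data").enableHiveSupport().getOrCreate()

    spark.sql("CREATE TABLE IF NOT EXISTS load_target (value STRING)")
    try {
      // The path does not exist, so the command now fails during its validation step.
      spark.sql("LOAD DATA INPATH '/definitely/not/there.txt' INTO TABLE load_target")
    } catch {
      case e: AnalysisException =>
        println(e.getMessage) // e.g. "LOAD DATA input path does not exist: ..."
    }

    spark.stop()
  }
}
```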
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala
index 11af0aaa7b206..9dbbe9946ee99 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala
@@ -22,7 +22,7 @@ import java.io.FileNotFoundException
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.SQLExecution
@@ -44,7 +44,6 @@ case class BasicWriteTaskStats(
/**
* Simple [[WriteTaskStatsTracker]] implementation that produces [[BasicWriteTaskStats]].
- * @param hadoopConf
*/
class BasicWriteTaskStatsTracker(hadoopConf: Configuration)
extends WriteTaskStatsTracker with Logging {
@@ -106,6 +105,13 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration)
override def getFinalStats(): WriteTaskStats = {
statCurrentFile()
+
+ // Reports bytesWritten and recordsWritten to the Spark output metrics.
+ Option(TaskContext.get()).map(_.taskMetrics().outputMetrics).foreach { outputMetrics =>
+ outputMetrics.setBytesWritten(numBytes)
+ outputMetrics.setRecordsWritten(numRows)
+ }
+
if (submittedFiles != numFiles) {
logInfo(s"Expected $submittedFiles files, but only saw $numFiles. " +
"This could be due to the output format not writing empty files, " +
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index b43d282bd434c..25e1210504273 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -36,8 +36,10 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
+import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources._
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{CalendarIntervalType, StructType}
@@ -85,7 +87,8 @@ case class DataSource(
case class SourceInfo(name: String, schema: StructType, partitionColumns: Seq[String])
- lazy val providingClass: Class[_] = DataSource.lookupDataSource(className)
+ lazy val providingClass: Class[_] =
+ DataSource.lookupDataSource(className, sparkSession.sessionState.conf)
lazy val sourceInfo: SourceInfo = sourceSchema()
private val caseInsensitiveOptions = CaseInsensitiveMap(options)
private val equality = sparkSession.sessionState.conf.resolver
@@ -453,17 +456,6 @@ case class DataSource(
val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
PartitioningUtils.validatePartitionColumn(data.schema, partitionColumns, caseSensitive)
-
- // SPARK-17230: Resolve the partition columns so InsertIntoHadoopFsRelationCommand does
- // not need to have the query as child, to avoid to analyze an optimized query,
- // because InsertIntoHadoopFsRelationCommand will be optimized first.
- val partitionAttributes = partitionColumns.map { name =>
- data.output.find(a => equality(a.name, name)).getOrElse {
- throw new AnalysisException(
- s"Unable to resolve $name given [${data.output.map(_.name).mkString(", ")}]")
- }
- }
-
val fileIndex = catalogTable.map(_.identifier).map { tableIdent =>
sparkSession.table(tableIdent).queryExecution.analyzed.collect {
case LogicalRelation(t: HadoopFsRelation, _, _, _) => t.location
@@ -476,14 +468,15 @@ case class DataSource(
outputPath = outputPath,
staticPartitions = Map.empty,
ifPartitionNotExists = false,
- partitionColumns = partitionAttributes,
+ partitionColumns = partitionColumns.map(UnresolvedAttribute.quoted),
bucketSpec = bucketSpec,
fileFormat = format,
options = options,
query = data,
mode = mode,
catalogTable = catalogTable,
- fileIndex = fileIndex)
+ fileIndex = fileIndex,
+ outputColumns = data.output)
}
/**
@@ -537,6 +530,7 @@ object DataSource extends Logging {
val csv = classOf[CSVFileFormat].getCanonicalName
val libsvm = "org.apache.spark.ml.source.libsvm.LibSVMFileFormat"
val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat"
+ val nativeOrc = classOf[OrcFileFormat].getCanonicalName
Map(
"org.apache.spark.sql.jdbc" -> jdbc,
@@ -553,6 +547,8 @@ object DataSource extends Logging {
"org.apache.spark.sql.execution.datasources.parquet.DefaultSource" -> parquet,
"org.apache.spark.sql.hive.orc.DefaultSource" -> orc,
"org.apache.spark.sql.hive.orc" -> orc,
+ "org.apache.spark.sql.execution.datasources.orc.DefaultSource" -> nativeOrc,
+ "org.apache.spark.sql.execution.datasources.orc" -> nativeOrc,
"org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm,
"org.apache.spark.ml.source.libsvm" -> libsvm,
"com.databricks.spark.csv" -> csv
@@ -568,8 +564,16 @@ object DataSource extends Logging {
"org.apache.spark.Logging")
/** Given a provider name, look up the data source class definition. */
- def lookupDataSource(provider: String): Class[_] = {
- val provider1 = backwardCompatibilityMap.getOrElse(provider, provider)
+ def lookupDataSource(provider: String, conf: SQLConf): Class[_] = {
+ val provider1 = backwardCompatibilityMap.getOrElse(provider, provider) match {
+ case name if name.equalsIgnoreCase("orc") &&
+ conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "native" =>
+ classOf[OrcFileFormat].getCanonicalName
+ case name if name.equalsIgnoreCase("orc") &&
+ conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "hive" =>
+ "org.apache.spark.sql.hive.orc.OrcFileFormat"
+ case name => name
+ }
val provider2 = s"$provider1.DefaultSource"
val loader = Utils.getContextOrSparkClassLoader
val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader)
@@ -584,10 +588,11 @@ object DataSource extends Logging {
// Found the data source using fully qualified path
dataSource
case Failure(error) =>
- if (provider1.toLowerCase(Locale.ROOT) == "orc" ||
- provider1.startsWith("org.apache.spark.sql.hive.orc")) {
+ if (provider1.startsWith("org.apache.spark.sql.hive.orc")) {
throw new AnalysisException(
- "The ORC data source must be used with Hive support enabled")
+ "Hive built-in ORC data source must be used with Hive support enabled. " +
+ "Please use the native ORC data source by setting 'spark.sql.orc.impl' to " +
+ "'native'")
} else if (provider1.toLowerCase(Locale.ROOT) == "avro" ||
provider1 == "com.databricks.spark.avro") {
throw new AnalysisException(
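Reviewer note: with the provider lookup now dispatching on the ORC implementation setting, the short name "orc" can resolve to the new native reader without Hive support. A usage sketch (the config key `spark.sql.orc.impl` is taken from the error message above; the path is illustrative):

```scala
import org.apache.spark.sql.SparkSession

object NativeOrcExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("native-orc").getOrCreate()

    // "native" maps the "orc" short name to
    // org.apache.spark.sql.execution.datasources.orc.OrcFileFormat;
    // "hive" keeps the old org.apache.spark.sql.hive.orc.OrcFileFormat.
    spark.conf.set("spark.sql.orc.impl", "native")

    val path = "/tmp/native-orc-demo"
    spark.range(0, 100).selectExpr("id", "id * 2 AS doubled").write.mode("overwrite").orc(path)
    spark.read.orc(path).show(5)

    spark.stop()
  }
}
```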
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 400f2e03165b2..d94c5bbccdd84 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -208,7 +208,8 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast
actualQuery,
mode,
table,
- Some(t.location))
+ Some(t.location),
+ actualQuery.output)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
index dfe752ce796b1..2fd722e5639db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
@@ -78,7 +78,7 @@ trait FileFormat {
}
/**
- * Returns whether a file with `path` could be splitted or not.
+ * Returns whether a file with `path` could be split or not.
*/
def isSplitable(
sparkSession: SparkSession,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 1fac01a2c26c6..1d80a69bc5a1d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, _}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
-import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution}
+import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution}
import org.apache.spark.sql.types.StringType
import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -56,7 +56,9 @@ object FileFormatWriter extends Logging {
/** Describes how output files should be placed in the filesystem. */
case class OutputSpec(
- outputPath: String, customPartitionLocations: Map[TablePartitionSpec, String])
+ outputPath: String,
+ customPartitionLocations: Map[TablePartitionSpec, String],
+ outputColumns: Seq[Attribute])
/** A shared job description for all the write tasks. */
private class WriteJobDescription(
@@ -101,7 +103,7 @@ object FileFormatWriter extends Logging {
*/
def write(
sparkSession: SparkSession,
- queryExecution: QueryExecution,
+ plan: SparkPlan,
fileFormat: FileFormat,
committer: FileCommitProtocol,
outputSpec: OutputSpec,
@@ -117,11 +119,8 @@ object FileFormatWriter extends Logging {
job.setOutputValueClass(classOf[InternalRow])
FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath))
- // Pick the attributes from analyzed plan, as optimizer may not preserve the output schema
- // names' case.
- val allColumns = queryExecution.analyzed.output
val partitionSet = AttributeSet(partitionColumns)
- val dataColumns = allColumns.filterNot(partitionSet.contains)
+ val dataColumns = outputSpec.outputColumns.filterNot(partitionSet.contains)
val bucketIdExpression = bucketSpec.map { spec =>
val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
@@ -144,7 +143,7 @@ object FileFormatWriter extends Logging {
uuid = UUID.randomUUID().toString,
serializableHadoopConf = new SerializableConfiguration(job.getConfiguration),
outputWriterFactory = outputWriterFactory,
- allColumns = allColumns,
+ allColumns = outputSpec.outputColumns,
dataColumns = dataColumns,
partitionColumns = partitionColumns,
bucketIdExpression = bucketIdExpression,
@@ -160,7 +159,7 @@ object FileFormatWriter extends Logging {
// We should first sort by partition columns, then bucket id, and finally sorting columns.
val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns
// the sort order doesn't matter
- val actualOrdering = queryExecution.executedPlan.outputOrdering.map(_.child)
+ val actualOrdering = plan.outputOrdering.map(_.child)
val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
false
} else {
@@ -178,17 +177,18 @@ object FileFormatWriter extends Logging {
try {
val rdd = if (orderingMatched) {
- queryExecution.toRdd
+ plan.execute()
} else {
// SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and
// the physical plan may have different attribute ids due to optimizer removing some
// aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch.
val orderingExpr = requiredOrdering
- .map(SortOrder(_, Ascending)).map(BindReferences.bindReference(_, allColumns))
+ .map(SortOrder(_, Ascending))
+ .map(BindReferences.bindReference(_, outputSpec.outputColumns))
SortExec(
orderingExpr,
global = false,
- child = queryExecution.executedPlan).execute()
+ child = plan).execute()
}
val ret = new Array[WriteTaskResult](rdd.partitions.length)
sparkSession.sparkContext.runJob(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
index 8731ee88f87f2..835ce98462477 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
@@ -26,7 +26,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.{InputFileBlockHolder, RDD}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.vectorized.ColumnarBatch
+import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.NextIterator
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileWholeTextReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileWholeTextReader.scala
new file mode 100644
index 0000000000000..c61a89e6e8c3f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileWholeTextReader.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import java.io.Closeable
+import java.net.URI
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.Text
+import org.apache.hadoop.mapreduce._
+import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit
+import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
+
+import org.apache.spark.input.WholeTextFileRecordReader
+
+/**
+ * An adaptor from a [[PartitionedFile]] to an [[Iterator]] of [[Text]], which yields the whole
+ * content of that file as a single record.
+ */
+class HadoopFileWholeTextReader(file: PartitionedFile, conf: Configuration)
+ extends Iterator[Text] with Closeable {
+ private val iterator = {
+ val fileSplit = new CombineFileSplit(
+ Array(new Path(new URI(file.filePath))),
+ Array(file.start),
+ Array(file.length),
+ // TODO: Implement Locality
+ Array.empty[String])
+ val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
+ val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
+ val reader = new WholeTextFileRecordReader(fileSplit, hadoopAttemptContext, 0)
+ reader.initialize(fileSplit, hadoopAttemptContext)
+ new RecordReaderIterator(reader)
+ }
+
+ override def hasNext: Boolean = iterator.hasNext
+
+ override def next(): Text = iterator.next()
+
+ override def close(): Unit = iterator.close()
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala
index 89d8a85a9cbd2..6b34638529770 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala
@@ -82,7 +82,11 @@ case class HadoopFsRelation(
}
}
- override def sizeInBytes: Long = location.sizeInBytes
+ override def sizeInBytes: Long = {
+ val compressionFactor = sqlContext.conf.fileCompressionFactor
+ (location.sizeInBytes * compressionFactor).toLong
+ }
+
override def inputFiles: Array[String] = location.inputFiles
}
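Reviewer note: scaling `location.sizeInBytes` by a compression factor lets the planner account for compressed on-disk sizes underestimating the in-memory size used for join-strategy decisions. A hedged sketch, assuming the factor is exposed as `spark.sql.sources.fileCompressionFactor` (the conf key is an assumption; only `fileCompressionFactor` appears in this diff):

```scala
import org.apache.spark.sql.SparkSession

object CompressionFactorEstimate {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("size-estimate").getOrCreate()

    val path = "/tmp/compression-factor-demo"
    spark.range(0, 100000).write.mode("overwrite").parquet(path)

    // Assumed config key; a factor of 3.0 tells the planner the data is roughly
    // 3x larger once decompressed/decoded than the raw Parquet files suggest.
    spark.conf.set("spark.sql.sources.fileCompressionFactor", "3.0")

    val estimated = spark.read.parquet(path).queryExecution.optimizedPlan.stats.sizeInBytes
    println(s"estimated size in bytes: $estimated") // ~3x the on-disk size

    spark.stop()
  }
}
```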
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
index 675bee85bf61e..dd7ef0d15c140 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
@@ -27,7 +27,9 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogT
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command._
+import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
import org.apache.spark.sql.util.SchemaUtils
/**
@@ -52,11 +54,12 @@ case class InsertIntoHadoopFsRelationCommand(
query: LogicalPlan,
mode: SaveMode,
catalogTable: Option[CatalogTable],
- fileIndex: Option[FileIndex])
+ fileIndex: Option[FileIndex],
+ outputColumns: Seq[Attribute])
extends DataWritingCommand {
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName
- override def run(sparkSession: SparkSession): Seq[Row] = {
+ override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
// Most formats don't do well with duplicate columns, so lets not allow that
SchemaUtils.checkSchemaColumnNameDuplication(
query.schema,
@@ -87,13 +90,19 @@ case class InsertIntoHadoopFsRelationCommand(
}
val pathExists = fs.exists(qualifiedOutputPath)
- // If we are appending data to an existing dir.
- val isAppend = pathExists && (mode == SaveMode.Append)
+
+ val enableDynamicOverwrite =
+ sparkSession.sessionState.conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC
+ // This config only makes sense when we are overwriting a partitioned dataset with dynamic
+ // partition columns.
+ val dynamicPartitionOverwrite = enableDynamicOverwrite && mode == SaveMode.Overwrite &&
+ staticPartitions.size < partitionColumns.length
val committer = FileCommitProtocol.instantiate(
sparkSession.sessionState.conf.fileCommitProtocolClass,
jobId = java.util.UUID.randomUUID().toString,
- outputPath = outputPath.toString)
+ outputPath = outputPath.toString,
+ dynamicPartitionOverwrite = dynamicPartitionOverwrite)
val doInsertion = (mode, pathExists) match {
case (SaveMode.ErrorIfExists, true) =>
@@ -101,6 +110,9 @@ case class InsertIntoHadoopFsRelationCommand(
case (SaveMode.Overwrite, true) =>
if (ifPartitionNotExists && matchingPartitions.nonEmpty) {
false
+ } else if (dynamicPartitionOverwrite) {
+ // For dynamic partition overwrite, do not delete partition directories ahead of time.
+ true
} else {
deleteMatchingPartitions(fs, qualifiedOutputPath, customPartitionLocations, committer)
true
@@ -124,7 +136,9 @@ case class InsertIntoHadoopFsRelationCommand(
catalogTable.get.identifier, newPartitions.toSeq.map(p => (p, None)),
ifNotExists = true).run(sparkSession)
}
- if (mode == SaveMode.Overwrite) {
+ // For dynamic partition overwrite, we never remove partitions but only update existing
+ // ones.
+ if (mode == SaveMode.Overwrite && !dynamicPartitionOverwrite) {
val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions
if (deletedPartitions.nonEmpty) {
AlterTableDropPartitionCommand(
@@ -139,11 +153,11 @@ case class InsertIntoHadoopFsRelationCommand(
val updatedPartitionPaths =
FileFormatWriter.write(
sparkSession = sparkSession,
- queryExecution = Dataset.ofRows(sparkSession, query).queryExecution,
+ plan = child,
fileFormat = fileFormat,
committer = committer,
outputSpec = FileFormatWriter.OutputSpec(
- qualifiedOutputPath.toString, customPartitionLocations),
+ qualifiedOutputPath.toString, customPartitionLocations, outputColumns),
hadoopConf = hadoopConf,
partitionColumns = partitionColumns,
bucketSpec = bucketSpec,
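Reviewer note: `dynamicPartitionOverwrite` changes what Overwrite means for a partitioned target: only the partitions produced by the query are replaced, nothing is deleted up front, and untouched partitions survive. A hedged usage example (`spark.sql.sources.partitionOverwriteMode` is the key behind the `PartitionOverwriteMode` enum imported above; paths are illustrative):

```scala
import org.apache.spark.sql.SparkSession

object DynamicPartitionOverwrite {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("dynamic-overwrite").getOrCreate()
    import spark.implicits._

    spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")

    val path = "/tmp/dynamic-overwrite-demo"
    // Initial write: partitions p=1 and p=2.
    Seq((1, "a"), (2, "b")).toDF("p", "v").write.partitionBy("p").mode("overwrite").parquet(path)

    // Second overwrite only touches p=2; with "static" mode (the default) the whole
    // output path, including p=1, would have been wiped first.
    Seq((2, "b2")).toDF("p", "v").write.partitionBy("p").mode("overwrite").parquet(path)

    spark.read.parquet(path).orderBy("p").show() // the p=1 row survives

    spark.stop()
  }
}
```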
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala
index 40825a1f724b1..39c594a9bc618 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala
@@ -29,11 +29,15 @@ import org.apache.spark.sql.internal.SQLConf
* A variant of [[HadoopMapReduceCommitProtocol]] that allows specifying the actual
* Hadoop output committer using an option specified in SQLConf.
*/
-class SQLHadoopMapReduceCommitProtocol(jobId: String, path: String)
- extends HadoopMapReduceCommitProtocol(jobId, path) with Serializable with Logging {
+class SQLHadoopMapReduceCommitProtocol(
+ jobId: String,
+ path: String,
+ dynamicPartitionOverwrite: Boolean = false)
+ extends HadoopMapReduceCommitProtocol(jobId, path, dynamicPartitionOverwrite)
+ with Serializable with Logging {
override protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = {
- var committer = context.getOutputFormatClass.newInstance().getOutputCommitter(context)
+ var committer = super.setupCommitter(context)
val configuration = context.getConfiguration
val clazz =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index b64d71bb4eef2..a585cbed2551b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -150,7 +150,7 @@ private[csv] object CSVInferSchema {
if ((allCatch opt options.timestampFormat.parse(field)).isDefined) {
TimestampType
} else if ((allCatch opt DateTimeUtils.stringToTime(field)).isDefined) {
- // We keep this for backwords competibility.
+ // We keep this for backwards compatibility.
TimestampType
} else {
tryParseBoolean(field, options)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index a13a5a34b4a84..c16790630ce17 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -89,6 +89,14 @@ class CSVOptions(
val quote = getChar("quote", '\"')
val escape = getChar("escape", '\\')
+ val charToEscapeQuoteEscaping = parameters.get("charToEscapeQuoteEscaping") match {
+ case None => None
+ case Some(null) => None
+ case Some(value) if value.length == 0 => None
+ case Some(value) if value.length == 1 => Some(value.charAt(0))
+ case _ =>
+ throw new RuntimeException("charToEscapeQuoteEscaping cannot be more than one character")
+ }
val comment = getChar("comment", '\u0000')
val headerFlag = getBool("header")
@@ -148,6 +156,7 @@ class CSVOptions(
format.setDelimiter(delimiter)
format.setQuote(quote)
format.setQuoteEscape(escape)
+ charToEscapeQuoteEscaping.foreach(format.setCharToEscapeQuoteEscaping)
format.setComment(comment)
writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite)
writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite)
@@ -165,6 +174,7 @@ class CSVOptions(
format.setDelimiter(delimiter)
format.setQuote(quote)
format.setQuoteEscape(escape)
+ charToEscapeQuoteEscaping.foreach(format.setCharToEscapeQuoteEscaping)
format.setComment(comment)
settings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceInRead)
settings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceInRead)
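Reviewer note: the new `charToEscapeQuoteEscaping` option is forwarded straight to univocity's writer and parser settings; it controls which character escapes the quote-escape character itself. A hedged round-trip example (the data and path are illustrative):

```scala
import org.apache.spark.sql.SparkSession

object CharToEscapeQuoteEscapingExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("csv-escape").getOrCreate()
    import spark.implicits._

    val path = "/tmp/csv-escape-demo"
    Seq("""he said "hi" \ bye""").toDF("text")
      .write.mode("overwrite")
      .option("escape", "\"")
      // The character that escapes the escape character itself; per CSVOptions it
      // must be empty, null, or exactly one character, otherwise an error is thrown.
      .option("charToEscapeQuoteEscaping", "\\")
      .csv(path)

    spark.read
      .option("escape", "\"")
      .option("charToEscapeQuoteEscaping", "\\")
      .csv(path)
      .show(truncate = false)

    spark.stop()
  }
}
```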
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
index 37e7bb0a59bb6..cc506e51bd0c6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
@@ -68,7 +68,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
case SaveMode.Overwrite =>
if (options.isTruncate && isCascadingTruncateTable(options.url) == Some(false)) {
// In this case, we should truncate table and then load.
- truncateTable(conn, options.table)
+ truncateTable(conn, options)
val tableSchema = JdbcUtils.getSchemaOption(conn, options)
saveTable(df, tableSchema, isCaseSensitive, options)
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 75c94fc486493..e6dc2fda4eb1b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -96,12 +96,13 @@ object JdbcUtils extends Logging {
}
/**
- * Truncates a table from the JDBC database.
+ * Truncates a table from the JDBC database without side effects.
*/
- def truncateTable(conn: Connection, table: String): Unit = {
+ def truncateTable(conn: Connection, options: JDBCOptions): Unit = {
+ val dialect = JdbcDialects.get(options.url)
val statement = conn.createStatement
try {
- statement.executeUpdate(s"TRUNCATE TABLE $table")
+ statement.executeUpdate(dialect.getTruncateQuery(options.table))
} finally {
statement.close()
}
@@ -226,10 +227,10 @@ object JdbcUtils extends Logging {
case java.sql.Types.STRUCT => StringType
case java.sql.Types.TIME => TimestampType
case java.sql.Types.TIME_WITH_TIMEZONE
- => TimestampType
+ => null
case java.sql.Types.TIMESTAMP => TimestampType
case java.sql.Types.TIMESTAMP_WITH_TIMEZONE
- => TimestampType
+ => null
case java.sql.Types.TINYINT => IntegerType
case java.sql.Types.VARBINARY => BinaryType
case java.sql.Types.VARCHAR => StringType
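Reviewer note: routing TRUNCATE through `dialect.getTruncateQuery` lets each JDBC dialect emit a variant without unwanted side effects (e.g. cascading behavior). From the user side, this path is reached via the `truncate` option on an Overwrite JDBC write; a hedged example with placeholder connection details:

```scala
import org.apache.spark.sql.{SaveMode, SparkSession}

object JdbcTruncateOverwrite {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("jdbc-truncate").getOrCreate()

    val df = spark.range(0, 100).selectExpr("id", "id * 10 AS value")

    df.write
      .mode(SaveMode.Overwrite)
      // With truncate=true (and a dialect whose truncate is non-cascading), the
      // existing table is emptied via dialect.getTruncateQuery(table) instead of
      // being dropped and recreated.
      .option("truncate", "true")
      .format("jdbc")
      .option("url", "jdbc:postgresql://localhost:5432/demo") // placeholder
      .option("dbtable", "public.demo_table")                 // placeholder
      .option("user", "demo").option("password", "secret")    // placeholders
      .save()

    spark.stop()
  }
}
```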
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala
new file mode 100644
index 0000000000000..4ecc54bd2fd96
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc
+
+import org.apache.hadoop.io._
+import org.apache.orc.mapred.{OrcList, OrcMap, OrcStruct, OrcTimestamp}
+import org.apache.orc.storage.serde2.io.{DateWritable, HiveDecimalWritable}
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData}
+import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * A deserializer to deserialize ORC structs to Spark rows.
+ */
+class OrcDeserializer(
+ dataSchema: StructType,
+ requiredSchema: StructType,
+ requestedColIds: Array[Int]) {
+
+ private val resultRow = new SpecificInternalRow(requiredSchema.map(_.dataType))
+
+ private val fieldWriters: Array[WritableComparable[_] => Unit] = {
+ requiredSchema.zipWithIndex
+ // The values of missing columns are always null, so they do not need writers.
+ .filterNot { case (_, index) => requestedColIds(index) == -1 }
+ .map { case (f, index) =>
+ val writer = newWriter(f.dataType, new RowUpdater(resultRow))
+ (value: WritableComparable[_]) => writer(index, value)
+ }.toArray
+ }
+
+ private val validColIds = requestedColIds.filterNot(_ == -1)
+
+ def deserialize(orcStruct: OrcStruct): InternalRow = {
+ var i = 0
+ while (i < validColIds.length) {
+ val value = orcStruct.getFieldValue(validColIds(i))
+ if (value == null) {
+ resultRow.setNullAt(i)
+ } else {
+ fieldWriters(i)(value)
+ }
+ i += 1
+ }
+ resultRow
+ }
+
+ /**
+ * Creates a writer to write ORC values to Catalyst data structure at the given ordinal.
+ */
+ private def newWriter(
+ dataType: DataType, updater: CatalystDataUpdater): (Int, WritableComparable[_]) => Unit =
+ dataType match {
+ case NullType => (ordinal, _) =>
+ updater.setNullAt(ordinal)
+
+ case BooleanType => (ordinal, value) =>
+ updater.setBoolean(ordinal, value.asInstanceOf[BooleanWritable].get)
+
+ case ByteType => (ordinal, value) =>
+ updater.setByte(ordinal, value.asInstanceOf[ByteWritable].get)
+
+ case ShortType => (ordinal, value) =>
+ updater.setShort(ordinal, value.asInstanceOf[ShortWritable].get)
+
+ case IntegerType => (ordinal, value) =>
+ updater.setInt(ordinal, value.asInstanceOf[IntWritable].get)
+
+ case LongType => (ordinal, value) =>
+ updater.setLong(ordinal, value.asInstanceOf[LongWritable].get)
+
+ case FloatType => (ordinal, value) =>
+ updater.setFloat(ordinal, value.asInstanceOf[FloatWritable].get)
+
+ case DoubleType => (ordinal, value) =>
+ updater.setDouble(ordinal, value.asInstanceOf[DoubleWritable].get)
+
+ case StringType => (ordinal, value) =>
+ updater.set(ordinal, UTF8String.fromBytes(value.asInstanceOf[Text].copyBytes))
+
+ case BinaryType => (ordinal, value) =>
+ val binary = value.asInstanceOf[BytesWritable]
+ val bytes = new Array[Byte](binary.getLength)
+ System.arraycopy(binary.getBytes, 0, bytes, 0, binary.getLength)
+ updater.set(ordinal, bytes)
+
+ case DateType => (ordinal, value) =>
+ updater.setInt(ordinal, DateTimeUtils.fromJavaDate(value.asInstanceOf[DateWritable].get))
+
+ case TimestampType => (ordinal, value) =>
+ updater.setLong(ordinal, DateTimeUtils.fromJavaTimestamp(value.asInstanceOf[OrcTimestamp]))
+
+ case DecimalType.Fixed(precision, scale) => (ordinal, value) =>
+ val decimal = value.asInstanceOf[HiveDecimalWritable].getHiveDecimal()
+ val v = Decimal(decimal.bigDecimalValue, decimal.precision(), decimal.scale())
+ v.changePrecision(precision, scale)
+ updater.set(ordinal, v)
+
+ case st: StructType => (ordinal, value) =>
+ val result = new SpecificInternalRow(st)
+ val fieldUpdater = new RowUpdater(result)
+ val fieldConverters = st.map(_.dataType).map { dt =>
+ newWriter(dt, fieldUpdater)
+ }.toArray
+ val orcStruct = value.asInstanceOf[OrcStruct]
+
+ var i = 0
+ while (i < st.length) {
+ val value = orcStruct.getFieldValue(i)
+ if (value == null) {
+ result.setNullAt(i)
+ } else {
+ fieldConverters(i)(i, value)
+ }
+ i += 1
+ }
+
+ updater.set(ordinal, result)
+
+ case ArrayType(elementType, _) => (ordinal, value) =>
+ val orcArray = value.asInstanceOf[OrcList[WritableComparable[_]]]
+ val length = orcArray.size()
+ val result = createArrayData(elementType, length)
+ val elementUpdater = new ArrayDataUpdater(result)
+ val elementConverter = newWriter(elementType, elementUpdater)
+
+ var i = 0
+ while (i < length) {
+ val value = orcArray.get(i)
+ if (value == null) {
+ result.setNullAt(i)
+ } else {
+ elementConverter(i, value)
+ }
+ i += 1
+ }
+
+ updater.set(ordinal, result)
+
+ case MapType(keyType, valueType, _) => (ordinal, value) =>
+ val orcMap = value.asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]]
+ val length = orcMap.size()
+ val keyArray = createArrayData(keyType, length)
+ val keyUpdater = new ArrayDataUpdater(keyArray)
+ val keyConverter = newWriter(keyType, keyUpdater)
+ val valueArray = createArrayData(valueType, length)
+ val valueUpdater = new ArrayDataUpdater(valueArray)
+ val valueConverter = newWriter(valueType, valueUpdater)
+
+ var i = 0
+ val it = orcMap.entrySet().iterator()
+ while (it.hasNext) {
+ val entry = it.next()
+ keyConverter(i, entry.getKey)
+ val value = entry.getValue
+ if (value == null) {
+ valueArray.setNullAt(i)
+ } else {
+ valueConverter(i, value)
+ }
+ i += 1
+ }
+
+ updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray))
+
+ case udt: UserDefinedType[_] => newWriter(udt.sqlType, updater)
+
+ case _ =>
+ throw new UnsupportedOperationException(s"$dataType is not supported yet.")
+ }
+
+ private def createArrayData(elementType: DataType, length: Int): ArrayData = elementType match {
+ case BooleanType => UnsafeArrayData.fromPrimitiveArray(new Array[Boolean](length))
+ case ByteType => UnsafeArrayData.fromPrimitiveArray(new Array[Byte](length))
+ case ShortType => UnsafeArrayData.fromPrimitiveArray(new Array[Short](length))
+ case IntegerType => UnsafeArrayData.fromPrimitiveArray(new Array[Int](length))
+ case LongType => UnsafeArrayData.fromPrimitiveArray(new Array[Long](length))
+ case FloatType => UnsafeArrayData.fromPrimitiveArray(new Array[Float](length))
+ case DoubleType => UnsafeArrayData.fromPrimitiveArray(new Array[Double](length))
+ case _ => new GenericArrayData(new Array[Any](length))
+ }
+
+ /**
+ * A base interface for updating values inside catalyst data structure like `InternalRow` and
+ * `ArrayData`.
+ */
+ sealed trait CatalystDataUpdater {
+ def set(ordinal: Int, value: Any): Unit
+
+ def setNullAt(ordinal: Int): Unit = set(ordinal, null)
+ def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, value)
+ def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value)
+ def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value)
+ def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value)
+ def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value)
+ def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value)
+ def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value)
+ }
+
+ final class RowUpdater(row: InternalRow) extends CatalystDataUpdater {
+ override def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal)
+ override def set(ordinal: Int, value: Any): Unit = row.update(ordinal, value)
+
+ override def setBoolean(ordinal: Int, value: Boolean): Unit = row.setBoolean(ordinal, value)
+ override def setByte(ordinal: Int, value: Byte): Unit = row.setByte(ordinal, value)
+ override def setShort(ordinal: Int, value: Short): Unit = row.setShort(ordinal, value)
+ override def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, value)
+ override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value)
+ override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value)
+ override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value)
+ }
+
+ final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater {
+ override def setNullAt(ordinal: Int): Unit = array.setNullAt(ordinal)
+ override def set(ordinal: Int, value: Any): Unit = array.update(ordinal, value)
+
+ override def setBoolean(ordinal: Int, value: Boolean): Unit = array.setBoolean(ordinal, value)
+ override def setByte(ordinal: Int, value: Byte): Unit = array.setByte(ordinal, value)
+ override def setShort(ordinal: Int, value: Short): Unit = array.setShort(ordinal, value)
+ override def setInt(ordinal: Int, value: Int): Unit = array.setInt(ordinal, value)
+ override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value)
+ override def setDouble(ordinal: Int, value: Double): Unit = array.setDouble(ordinal, value)
+ override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
index 215740e90fe84..2dd314d165348 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -17,10 +17,31 @@
package org.apache.spark.sql.execution.datasources.orc
-import org.apache.orc.TypeDescription
+import java.io._
+import java.net.URI
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.mapreduce._
+import org.apache.hadoop.mapreduce.lib.input.FileSplit
+import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
+import org.apache.orc._
+import org.apache.orc.OrcConf.{COMPRESS, MAPRED_OUTPUT_SCHEMA}
+import org.apache.orc.mapred.OrcStruct
+import org.apache.orc.mapreduce._
+
+import org.apache.spark.TaskContext
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types._
+import org.apache.spark.util.SerializableConfiguration
private[sql] object OrcFileFormat {
private def checkFieldName(name: String): Unit = {
@@ -39,3 +60,163 @@ private[sql] object OrcFileFormat {
names.foreach(checkFieldName)
}
}
+
+/**
+ * New ORC File Format based on Apache ORC.
+ */
+class OrcFileFormat
+ extends FileFormat
+ with DataSourceRegister
+ with Serializable {
+
+ override def shortName(): String = "orc"
+
+ override def toString: String = "ORC"
+
+ override def hashCode(): Int = getClass.hashCode()
+
+ override def equals(other: Any): Boolean = other.isInstanceOf[OrcFileFormat]
+
+ override def inferSchema(
+ sparkSession: SparkSession,
+ options: Map[String, String],
+ files: Seq[FileStatus]): Option[StructType] = {
+ OrcUtils.readSchema(sparkSession, files)
+ }
+
+ override def prepareWrite(
+ sparkSession: SparkSession,
+ job: Job,
+ options: Map[String, String],
+ dataSchema: StructType): OutputWriterFactory = {
+ val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf)
+
+ val conf = job.getConfiguration
+
+ conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, dataSchema.catalogString)
+
+ conf.set(COMPRESS.getAttribute, orcOptions.compressionCodec)
+
+ conf.asInstanceOf[JobConf]
+ .setOutputFormat(classOf[org.apache.orc.mapred.OrcOutputFormat[OrcStruct]])
+
+ new OutputWriterFactory {
+ override def newInstance(
+ path: String,
+ dataSchema: StructType,
+ context: TaskAttemptContext): OutputWriter = {
+ new OrcOutputWriter(path, dataSchema, context)
+ }
+
+ override def getFileExtension(context: TaskAttemptContext): String = {
+ val compressionExtension: String = {
+ val name = context.getConfiguration.get(COMPRESS.getAttribute)
+ OrcUtils.extensionsForCompressionCodecNames.getOrElse(name, "")
+ }
+
+ compressionExtension + ".orc"
+ }
+ }
+ }
+
+ override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = {
+ val conf = sparkSession.sessionState.conf
+ conf.orcVectorizedReaderEnabled && conf.wholeStageEnabled &&
+ schema.length <= conf.wholeStageMaxNumFields &&
+ schema.forall(_.dataType.isInstanceOf[AtomicType])
+ }
+
+ override def isSplitable(
+ sparkSession: SparkSession,
+ options: Map[String, String],
+ path: Path): Boolean = {
+ true
+ }
+
+ override def buildReaderWithPartitionValues(
+ sparkSession: SparkSession,
+ dataSchema: StructType,
+ partitionSchema: StructType,
+ requiredSchema: StructType,
+ filters: Seq[Filter],
+ options: Map[String, String],
+ hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
+ if (sparkSession.sessionState.conf.orcFilterPushDown) {
+ OrcFilters.createFilter(dataSchema, filters).foreach { f =>
+ OrcInputFormat.setSearchArgument(hadoopConf, f, dataSchema.fieldNames)
+ }
+ }
+
+ val resultSchema = StructType(requiredSchema.fields ++ partitionSchema.fields)
+ val sqlConf = sparkSession.sessionState.conf
+ val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled
+ val enableVectorizedReader = supportBatch(sparkSession, resultSchema)
+ val copyToSpark = sparkSession.sessionState.conf.getConf(SQLConf.ORC_COPY_BATCH_TO_SPARK)
+
+ val broadcastedConf =
+ sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
+ val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+
+ (file: PartitionedFile) => {
+ val conf = broadcastedConf.value.value
+
+ val filePath = new Path(new URI(file.filePath))
+
+ val fs = filePath.getFileSystem(conf)
+ val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
+ val reader = OrcFile.createReader(filePath, readerOptions)
+
+ val requestedColIdsOrEmptyFile = OrcUtils.requestedColumnIds(
+ isCaseSensitive, dataSchema, requiredSchema, reader, conf)
+
+ if (requestedColIdsOrEmptyFile.isEmpty) {
+ Iterator.empty
+ } else {
+ val requestedColIds = requestedColIdsOrEmptyFile.get
+ assert(requestedColIds.length == requiredSchema.length,
+ "[BUG] requested column IDs do not match required schema")
+ val taskConf = new Configuration(conf)
+ taskConf.set(OrcConf.INCLUDE_COLUMNS.getAttribute,
+ requestedColIds.filter(_ != -1).sorted.mkString(","))
+
+ val fileSplit = new FileSplit(filePath, file.start, file.length, Array.empty)
+ val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
+ val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId)
+
+ val taskContext = Option(TaskContext.get())
+ if (enableVectorizedReader) {
+ val batchReader = new OrcColumnarBatchReader(
+ enableOffHeapColumnVector && taskContext.isDefined, copyToSpark)
+ batchReader.initialize(fileSplit, taskAttemptContext)
+ batchReader.initBatch(
+ reader.getSchema,
+ requestedColIds,
+ requiredSchema.fields,
+ partitionSchema,
+ file.partitionValues)
+
+ val iter = new RecordReaderIterator(batchReader)
+ Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close()))
+ iter.asInstanceOf[Iterator[InternalRow]]
+ } else {
+ val orcRecordReader = new OrcInputFormat[OrcStruct]
+ .createRecordReader(fileSplit, taskAttemptContext)
+ val iter = new RecordReaderIterator[OrcStruct](orcRecordReader)
+ Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close()))
+
+ val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes
+ val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema)
+ val deserializer = new OrcDeserializer(dataSchema, requiredSchema, requestedColIds)
+
+ if (partitionSchema.length == 0) {
+ iter.map(value => unsafeProjection(deserializer.deserialize(value)))
+ } else {
+ val joinedRow = new JoinedRow()
+ iter.map(value =>
+ unsafeProjection(joinedRow(deserializer.deserialize(value), file.partitionValues)))
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
new file mode 100644
index 0000000000000..4f44ae4fa1d71
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc
+
+import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument, SearchArgumentFactory}
+import org.apache.orc.storage.ql.io.sarg.SearchArgument.Builder
+import org.apache.orc.storage.serde2.io.HiveDecimalWritable
+
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types._
+
+/**
+ * Helper object for building ORC `SearchArgument`s, which are used for ORC predicate push-down.
+ *
+ * Due to a limitation of the ORC `SearchArgument` builder, we have to use a somewhat awkward
+ * double-checking pattern when converting `And`/`Or`/`Not` filters.
+ *
+ * An ORC `SearchArgument` must be built in one pass using a single builder. For example, you can't
+ * build `a = 1` and `b = 2` first, and then combine them into `a = 1 AND b = 2`. This is quite
+ * different from the cases in Spark SQL or Parquet, where complex filters can be easily built using
+ * existing simpler ones.
+ *
+ * The annoying part is that `SearchArgument` builder methods like `startAnd()`, `startOr()`, and
+ * `startNot()` mutate the internal state of the builder instance. This forces us to translate all
+ * convertible filters with a single builder instance. However, before actually converting a filter,
+ * we have no way of knowing whether ORC can recognize it. Thus, when an inconvertible filter is
+ * found, we may already have a builder whose internal state is inconsistent.
+ *
+ * For example, to convert an `And` filter with builder `b`, we call `b.startAnd()` first, and then
+ * try to convert its children. Say we convert `left` child successfully, but find that `right`
+ * child is inconvertible. Alas, the `b.startAnd()` call can't be rolled back, and `b` is now in an
+ * inconsistent state.
+ *
+ * The workaround employed here is that, for `And`/`Or`/`Not`, we first try to convert their
+ * children with brand new builders, and only do the actual conversion with the right builder
+ * instance when the children are proven to be convertible.
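+ *
+ * As a rough sketch (the helper names below are illustrative, not actual code), converting
+ * `And(left, right)` with the real builder `b` looks like:
+ * {{{
+ *   for {
+ *     _   <- convert(left, newThrowAwayBuilder())   // prove `left` is convertible
+ *     _   <- convert(right, newThrowAwayBuilder())  // prove `right` is convertible
+ *     lhs <- convert(left, b.startAnd())            // only now mutate the real builder
+ *     rhs <- convert(right, lhs)
+ *   } yield rhs.end()
+ * }}}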
+ *
+ * P.S.: Hive seems to use `SearchArgument` together with `ExprNodeGenericFuncDesc` only. Usage of
+ * builder methods mentioned above can only be found in test code, where all tested filters are
+ * known to be convertible.
+ */
+private[orc] object OrcFilters {
+
+ /**
+   * Create an ORC filter as a `SearchArgument` instance, or return `None` if no filter is convertible.
+ */
+ def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = {
+ val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap
+
+    // First, try to convert each filter individually to see whether it's convertible, and then
+    // collect all convertible ones to build the final `SearchArgument`.
+ val convertibleFilters = for {
+ filter <- filters
+ _ <- buildSearchArgument(dataTypeMap, filter, SearchArgumentFactory.newBuilder())
+ } yield filter
+
+ for {
+ // Combines all convertible filters using `And` to produce a single conjunction
+ conjunction <- convertibleFilters.reduceOption(org.apache.spark.sql.sources.And)
+ // Then tries to build a single ORC `SearchArgument` for the conjunction predicate
+ builder <- buildSearchArgument(dataTypeMap, conjunction, SearchArgumentFactory.newBuilder())
+ } yield builder.build()
+ }
+
+ /**
+   * Return true if this is a type that ORC can build search predicates on.
+   * Both CharType and VarcharType have already been cleaned up by the AstBuilder.
+ */
+ private def isSearchableType(dataType: DataType) = dataType match {
+ case BinaryType => false
+ case _: AtomicType => true
+ case _ => false
+ }
+
+ /**
+   * Get the `PredicateLeaf.Type` corresponding to the given DataType.
+ */
+ private def getPredicateLeafType(dataType: DataType) = dataType match {
+ case BooleanType => PredicateLeaf.Type.BOOLEAN
+ case ByteType | ShortType | IntegerType | LongType => PredicateLeaf.Type.LONG
+ case FloatType | DoubleType => PredicateLeaf.Type.FLOAT
+ case StringType => PredicateLeaf.Type.STRING
+ case DateType => PredicateLeaf.Type.DATE
+ case TimestampType => PredicateLeaf.Type.TIMESTAMP
+ case _: DecimalType => PredicateLeaf.Type.DECIMAL
+ case _ => throw new UnsupportedOperationException(s"DataType: $dataType")
+ }
+
+ /**
+ * Cast literal values for filters.
+ *
+   * We need to cast integral values to Long (and adapt other types similarly) because ORC
+   * otherwise raises exceptions in 'checkLiteralType' of SearchArgumentImpl.java.
+ */
+ private def castLiteralValue(value: Any, dataType: DataType): Any = dataType match {
+ case ByteType | ShortType | IntegerType | LongType =>
+ value.asInstanceOf[Number].longValue
+ case FloatType | DoubleType =>
+ value.asInstanceOf[Number].doubleValue()
+ case _: DecimalType =>
+ val decimal = value.asInstanceOf[java.math.BigDecimal]
+ val decimalWritable = new HiveDecimalWritable(decimal.longValue)
+ decimalWritable.mutateEnforcePrecisionScale(decimal.precision, decimal.scale)
+ decimalWritable
+ case _ => value
+ }
+
+ /**
+   * Build a SearchArgument for the given filter and return the builder so far, or None if the
+   * filter is not convertible.
+ */
+ private def buildSearchArgument(
+ dataTypeMap: Map[String, DataType],
+ expression: Filter,
+ builder: Builder): Option[Builder] = {
+ def newBuilder = SearchArgumentFactory.newBuilder()
+
+ def getType(attribute: String): PredicateLeaf.Type =
+ getPredicateLeafType(dataTypeMap(attribute))
+
+ import org.apache.spark.sql.sources._
+
+ expression match {
+ case And(left, right) =>
+        // It is not safe here to just convert one side if we do not understand the other side.
+        // For example:
+ // Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to
+ // convert b in ('1'). If we only convert a = 2, we will end up with a filter
+ // NOT(a = 2), which will generate wrong results.
+ // Pushing one side of AND down is only safe to do at the top level.
+ // You can see ParquetRelation's initializeLocalJobFunc method as an example.
+ for {
+ _ <- buildSearchArgument(dataTypeMap, left, newBuilder)
+ _ <- buildSearchArgument(dataTypeMap, right, newBuilder)
+ lhs <- buildSearchArgument(dataTypeMap, left, builder.startAnd())
+ rhs <- buildSearchArgument(dataTypeMap, right, lhs)
+ } yield rhs.end()
+
+ case Or(left, right) =>
+ for {
+ _ <- buildSearchArgument(dataTypeMap, left, newBuilder)
+ _ <- buildSearchArgument(dataTypeMap, right, newBuilder)
+ lhs <- buildSearchArgument(dataTypeMap, left, builder.startOr())
+ rhs <- buildSearchArgument(dataTypeMap, right, lhs)
+ } yield rhs.end()
+
+ case Not(child) =>
+ for {
+ _ <- buildSearchArgument(dataTypeMap, child, newBuilder)
+ negate <- buildSearchArgument(dataTypeMap, child, builder.startNot())
+ } yield negate.end()
+
+ // NOTE: For all case branches dealing with leaf predicates below, the additional `startAnd()`
+      // call is mandatory. The ORC `SearchArgument` builder requires that all leaf predicates be
+      // wrapped in a "parent" predicate (`And`, `Or`, or `Not`).
+
+ case EqualTo(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(attribute))
+ Some(builder.startAnd().equals(attribute, getType(attribute), castedValue).end())
+
+ case EqualNullSafe(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(attribute))
+ Some(builder.startAnd().nullSafeEquals(attribute, getType(attribute), castedValue).end())
+
+ case LessThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(attribute))
+ Some(builder.startAnd().lessThan(attribute, getType(attribute), castedValue).end())
+
+ case LessThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(attribute))
+ Some(builder.startAnd().lessThanEquals(attribute, getType(attribute), castedValue).end())
+
+ case GreaterThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(attribute))
+ Some(builder.startNot().lessThanEquals(attribute, getType(attribute), castedValue).end())
+
+ case GreaterThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(attribute))
+ Some(builder.startNot().lessThan(attribute, getType(attribute), castedValue).end())
+
+ case IsNull(attribute) if isSearchableType(dataTypeMap(attribute)) =>
+ Some(builder.startAnd().isNull(attribute, getType(attribute)).end())
+
+ case IsNotNull(attribute) if isSearchableType(dataTypeMap(attribute)) =>
+ Some(builder.startNot().isNull(attribute, getType(attribute)).end())
+
+ case In(attribute, values) if isSearchableType(dataTypeMap(attribute)) =>
+ val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(attribute)))
+ Some(builder.startAnd().in(attribute, getType(attribute),
+ castedValues.map(_.asInstanceOf[AnyRef]): _*).end())
+
+ case _ => None
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala
new file mode 100644
index 0000000000000..84755bfa301f0
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.NullWritable
+import org.apache.hadoop.mapreduce.TaskAttemptContext
+import org.apache.orc.mapred.OrcStruct
+import org.apache.orc.mapreduce.OrcOutputFormat
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources.OutputWriter
+import org.apache.spark.sql.types._
+
+private[orc] class OrcOutputWriter(
+ path: String,
+ dataSchema: StructType,
+ context: TaskAttemptContext)
+ extends OutputWriter {
+
+ private[this] val serializer = new OrcSerializer(dataSchema)
+
+ private val recordWriter = {
+ new OrcOutputFormat[OrcStruct]() {
+ override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
+ new Path(path)
+ }
+ }.getRecordWriter(context)
+ }
+
+ override def write(row: InternalRow): Unit = {
+ recordWriter.write(NullWritable.get(), serializer.serialize(row))
+ }
+
+ override def close(): Unit = {
+ recordWriter.close(context)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala
new file mode 100644
index 0000000000000..899af0750cadf
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc
+
+import org.apache.hadoop.io._
+import org.apache.orc.TypeDescription
+import org.apache.orc.mapred.{OrcList, OrcMap, OrcStruct, OrcTimestamp}
+import org.apache.orc.storage.common.`type`.HiveDecimal
+import org.apache.orc.storage.serde2.io.{DateWritable, HiveDecimalWritable}
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.types._
+
+/**
+ * A serializer to serialize Spark rows to ORC structs.
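+ *
+ * A minimal usage sketch (the schema and values are made up for illustration):
+ * {{{
+ *   import org.apache.spark.sql.catalyst.InternalRow
+ *   import org.apache.spark.unsafe.types.UTF8String
+ *
+ *   val schema = StructType(Seq(StructField("i", IntegerType), StructField("s", StringType)))
+ *   val serializer = new OrcSerializer(schema)
+ *   // Note: the returned OrcStruct is reused across calls to serialize().
+ *   val struct: OrcStruct = serializer.serialize(InternalRow(1, UTF8String.fromString("a")))
+ * }}}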
+ */
+class OrcSerializer(dataSchema: StructType) {
+
+ private val result = createOrcValue(dataSchema).asInstanceOf[OrcStruct]
+ private val converters = dataSchema.map(_.dataType).map(newConverter(_)).toArray
+
+ def serialize(row: InternalRow): OrcStruct = {
+ var i = 0
+ while (i < converters.length) {
+ if (row.isNullAt(i)) {
+ result.setFieldValue(i, null)
+ } else {
+ result.setFieldValue(i, converters(i)(row, i))
+ }
+ i += 1
+ }
+ result
+ }
+
+ private type Converter = (SpecializedGetters, Int) => WritableComparable[_]
+
+ /**
+ * Creates a converter to convert Catalyst data at the given ordinal to ORC values.
+ */
+ private def newConverter(
+ dataType: DataType,
+ reuseObj: Boolean = true): Converter = dataType match {
+ case NullType => (getter, ordinal) => null
+
+ case BooleanType =>
+ if (reuseObj) {
+ val result = new BooleanWritable()
+ (getter, ordinal) =>
+ result.set(getter.getBoolean(ordinal))
+ result
+ } else {
+ (getter, ordinal) => new BooleanWritable(getter.getBoolean(ordinal))
+ }
+
+ case ByteType =>
+ if (reuseObj) {
+ val result = new ByteWritable()
+ (getter, ordinal) =>
+ result.set(getter.getByte(ordinal))
+ result
+ } else {
+ (getter, ordinal) => new ByteWritable(getter.getByte(ordinal))
+ }
+
+ case ShortType =>
+ if (reuseObj) {
+ val result = new ShortWritable()
+ (getter, ordinal) =>
+ result.set(getter.getShort(ordinal))
+ result
+ } else {
+ (getter, ordinal) => new ShortWritable(getter.getShort(ordinal))
+ }
+
+ case IntegerType =>
+ if (reuseObj) {
+ val result = new IntWritable()
+ (getter, ordinal) =>
+ result.set(getter.getInt(ordinal))
+ result
+ } else {
+ (getter, ordinal) => new IntWritable(getter.getInt(ordinal))
+ }
+
+
+ case LongType =>
+ if (reuseObj) {
+ val result = new LongWritable()
+ (getter, ordinal) =>
+ result.set(getter.getLong(ordinal))
+ result
+ } else {
+ (getter, ordinal) => new LongWritable(getter.getLong(ordinal))
+ }
+
+ case FloatType =>
+ if (reuseObj) {
+ val result = new FloatWritable()
+ (getter, ordinal) =>
+ result.set(getter.getFloat(ordinal))
+ result
+ } else {
+ (getter, ordinal) => new FloatWritable(getter.getFloat(ordinal))
+ }
+
+ case DoubleType =>
+ if (reuseObj) {
+ val result = new DoubleWritable()
+ (getter, ordinal) =>
+ result.set(getter.getDouble(ordinal))
+ result
+ } else {
+ (getter, ordinal) => new DoubleWritable(getter.getDouble(ordinal))
+ }
+
+
+    // Don't reuse the result object for string and binary types, as reuse would require an extra data copy.
+ case StringType => (getter, ordinal) =>
+ new Text(getter.getUTF8String(ordinal).getBytes)
+
+ case BinaryType => (getter, ordinal) =>
+ new BytesWritable(getter.getBinary(ordinal))
+
+ case DateType =>
+ if (reuseObj) {
+ val result = new DateWritable()
+ (getter, ordinal) =>
+ result.set(getter.getInt(ordinal))
+ result
+ } else {
+ (getter, ordinal) => new DateWritable(getter.getInt(ordinal))
+ }
+
+ // The following cases are already expensive, reusing object or not doesn't matter.
+
+ case TimestampType => (getter, ordinal) =>
+ val ts = DateTimeUtils.toJavaTimestamp(getter.getLong(ordinal))
+ val result = new OrcTimestamp(ts.getTime)
+ result.setNanos(ts.getNanos)
+ result
+
+ case DecimalType.Fixed(precision, scale) => (getter, ordinal) =>
+ val d = getter.getDecimal(ordinal, precision, scale)
+ new HiveDecimalWritable(HiveDecimal.create(d.toJavaBigDecimal))
+
+ case st: StructType => (getter, ordinal) =>
+ val result = createOrcValue(st).asInstanceOf[OrcStruct]
+ val fieldConverters = st.map(_.dataType).map(newConverter(_))
+ val numFields = st.length
+ val struct = getter.getStruct(ordinal, numFields)
+ var i = 0
+ while (i < numFields) {
+ if (struct.isNullAt(i)) {
+ result.setFieldValue(i, null)
+ } else {
+ result.setFieldValue(i, fieldConverters(i)(struct, i))
+ }
+ i += 1
+ }
+ result
+
+ case ArrayType(elementType, _) => (getter, ordinal) =>
+ val result = createOrcValue(dataType).asInstanceOf[OrcList[WritableComparable[_]]]
+      // All converted values must be added to the ORC list, so the element objects can't be reused.
+ val elementConverter = newConverter(elementType, reuseObj = false)
+ val array = getter.getArray(ordinal)
+ var i = 0
+ while (i < array.numElements()) {
+ if (array.isNullAt(i)) {
+ result.add(null)
+ } else {
+ result.add(elementConverter(array, i))
+ }
+ i += 1
+ }
+ result
+
+ case MapType(keyType, valueType, _) => (getter, ordinal) =>
+ val result = createOrcValue(dataType)
+ .asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]]
+      // All converted keys and values must be added to the ORC map, so converter results can't be reused.
+ val keyConverter = newConverter(keyType, reuseObj = false)
+ val valueConverter = newConverter(valueType, reuseObj = false)
+ val map = getter.getMap(ordinal)
+ val keyArray = map.keyArray()
+ val valueArray = map.valueArray()
+ var i = 0
+ while (i < map.numElements()) {
+ val key = keyConverter(keyArray, i)
+ if (valueArray.isNullAt(i)) {
+ result.put(key, null)
+ } else {
+ result.put(key, valueConverter(valueArray, i))
+ }
+ i += 1
+ }
+ result
+
+ case udt: UserDefinedType[_] => newConverter(udt.sqlType)
+
+ case _ =>
+ throw new UnsupportedOperationException(s"$dataType is not supported yet.")
+ }
+
+ /**
+   * Return an ORC value object for the given Spark schema.
+ */
+ private def createOrcValue(dataType: DataType) = {
+ OrcStruct.createValue(TypeDescription.fromString(dataType.catalogString))
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
new file mode 100644
index 0000000000000..460194ba61c8b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc
+
+import scala.collection.JavaConverters._
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.orc.{OrcFile, Reader, TypeDescription}
+
+import org.apache.spark.SparkException
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution}
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.types._
+
+object OrcUtils extends Logging {
+
+ // The extensions for ORC compression codecs
+ val extensionsForCompressionCodecNames = Map(
+ "NONE" -> "",
+ "SNAPPY" -> ".snappy",
+ "ZLIB" -> ".zlib",
+ "LZO" -> ".lzo")
+
+ def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = {
+ val origPath = new Path(pathStr)
+ val fs = origPath.getFileSystem(conf)
+ val paths = SparkHadoopUtil.get.listLeafStatuses(fs, origPath)
+ .filterNot(_.isDirectory)
+ .map(_.getPath)
+ .filterNot(_.getName.startsWith("_"))
+ .filterNot(_.getName.startsWith("."))
+ paths
+ }
+
+ def readSchema(file: Path, conf: Configuration, ignoreCorruptFiles: Boolean)
+ : Option[TypeDescription] = {
+ val fs = file.getFileSystem(conf)
+ val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
+ try {
+ val reader = OrcFile.createReader(file, readerOptions)
+ val schema = reader.getSchema
+ if (schema.getFieldNames.size == 0) {
+ None
+ } else {
+ Some(schema)
+ }
+ } catch {
+ case e: org.apache.orc.FileFormatException =>
+ if (ignoreCorruptFiles) {
+ logWarning(s"Skipped the footer in the corrupted file: $file", e)
+ None
+ } else {
+ throw new SparkException(s"Could not read footer for file: $file", e)
+ }
+ }
+ }
+
+ def readSchema(sparkSession: SparkSession, files: Seq[FileStatus])
+ : Option[StructType] = {
+ val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles
+ val conf = sparkSession.sessionState.newHadoopConf()
+    // TODO: We need to support schema merging. Please see SPARK-11412.
+ files.map(_.getPath).flatMap(readSchema(_, conf, ignoreCorruptFiles)).headOption.map { schema =>
+ logDebug(s"Reading schema from file $files, got Hive schema string: $schema")
+ CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType]
+ }
+ }
+
+ /**
+ * Returns the requested column ids from the given ORC file. Column id can be -1, which means the
+ * requested column doesn't exist in the ORC file. Returns None if the given ORC file is empty.
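+   *
+   * For example, if the ORC file physically contains fields `a` and `b` (in that order) and the
+   * required schema is `<b, c>`, this returns `Some(Array(1, -1))`: `b` is the second physical
+   * column and `c` does not exist in the file.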
+ */
+ def requestedColumnIds(
+ isCaseSensitive: Boolean,
+ dataSchema: StructType,
+ requiredSchema: StructType,
+ reader: Reader,
+ conf: Configuration): Option[Array[Int]] = {
+ val orcFieldNames = reader.getSchema.getFieldNames.asScala
+ if (orcFieldNames.isEmpty) {
+ // SPARK-8501: Some old empty ORC files always have an empty schema stored in their footer.
+ None
+ } else {
+ if (orcFieldNames.forall(_.startsWith("_col"))) {
+        // This is an ORC file written by Hive: the physical schema has no field names, so assume
+        // it maps to the data schema by index.
+        assert(orcFieldNames.length <= dataSchema.length, "The given data schema " +
+          s"${dataSchema.simpleString} has fewer fields than the actual ORC physical schema; " +
+          "we cannot tell which columns were dropped, failing the read.")
+ Some(requiredSchema.fieldNames.map { name =>
+ val index = dataSchema.fieldIndex(name)
+ if (index < orcFieldNames.length) {
+ index
+ } else {
+ -1
+ }
+ })
+ } else {
+ val resolver = if (isCaseSensitive) caseSensitiveResolution else caseInsensitiveResolution
+ Some(requiredSchema.fieldNames.map { name => orcFieldNames.indexWhere(resolver(_, name)) })
+ }
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index b1d49af3de465..8e2bc1033c948 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -49,6 +49,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.parser.LegacyTypeStringParser
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector}
import org.apache.spark.sql.internal.SQLConf
@@ -377,6 +378,9 @@ class ParquetFileFormat
hadoopConf.set(
ParquetWriteSupport.SPARK_ROW_SCHEMA,
requiredSchema.json)
+ hadoopConf.set(
+ SQLConf.SESSION_LOCAL_TIMEZONE.key,
+ sparkSession.sessionState.conf.sessionLocalTimeZone)
ParquetWriteSupport.setSchema(requiredSchema, hadoopConf)
@@ -424,6 +428,8 @@ class ParquetFileFormat
resultSchema.forall(_.dataType.isInstanceOf[AtomicType])
val enableRecordFilter: Boolean =
sparkSession.sessionState.conf.parquetRecordFilterEnabled
+ val timestampConversion: Boolean =
+ sparkSession.sessionState.conf.isParquetINT96TimestampConversion
// Whole stage codegen (PhysicalRDD) is able to deal with batches directly
val returningBatch = supportBatch(sparkSession, resultSchema)
@@ -442,6 +448,22 @@ class ParquetFileFormat
fileSplit.getLocations,
null)
+ val sharedConf = broadcastedHadoopConf.value.value
+      // PARQUET_INT96_TIMESTAMP_CONVERSION says to apply timezone conversions to INT96 timestamps
+      // *only* if the file was created by something other than "parquet-mr", so check the actual
+ // writer here for this file. We have to do this per-file, as each file in the table may
+ // have different writers.
+ def isCreatedByParquetMr(): Boolean = {
+ val footer = ParquetFileReader.readFooter(sharedConf, fileSplit.getPath, SKIP_ROW_GROUPS)
+ footer.getFileMetaData().getCreatedBy().startsWith("parquet-mr")
+ }
+ val convertTz =
+ if (timestampConversion && !isCreatedByParquetMr()) {
+ Some(DateTimeUtils.getTimeZone(sharedConf.get(SQLConf.SESSION_LOCAL_TIMEZONE.key)))
+ } else {
+ None
+ }
+
val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
val hadoopAttemptContext =
new TaskAttemptContextImpl(broadcastedHadoopConf.value.value, attemptId)
@@ -453,8 +475,8 @@ class ParquetFileFormat
}
val taskContext = Option(TaskContext.get())
val parquetReader = if (enableVectorizedReader) {
- val vectorizedReader =
- new VectorizedParquetRecordReader(enableOffHeapColumnVector && taskContext.isDefined)
+ val vectorizedReader = new VectorizedParquetRecordReader(
+ convertTz.orNull, enableOffHeapColumnVector && taskContext.isDefined)
vectorizedReader.initialize(split, hadoopAttemptContext)
logDebug(s"Appending $partitionSchema ${file.partitionValues}")
vectorizedReader.initBatch(partitionSchema, file.partitionValues)
@@ -467,9 +489,9 @@ class ParquetFileFormat
// ParquetRecordReader returns UnsafeRow
val reader = if (pushed.isDefined && enableRecordFilter) {
val parquetFilter = FilterCompat.get(pushed.get, null)
- new ParquetRecordReader[UnsafeRow](new ParquetReadSupport, parquetFilter)
+ new ParquetRecordReader[UnsafeRow](new ParquetReadSupport(convertTz), parquetFilter)
} else {
- new ParquetRecordReader[UnsafeRow](new ParquetReadSupport)
+ new ParquetRecordReader[UnsafeRow](new ParquetReadSupport(convertTz))
}
reader.initialize(split, hadoopAttemptContext)
reader
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala
index 772d4565de548..ef67ea7d17cea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.parquet
import java.util.Locale
+import org.apache.parquet.hadoop.ParquetOutputFormat
import org.apache.parquet.hadoop.metadata.CompressionCodecName
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
@@ -42,8 +43,15 @@ private[parquet] class ParquetOptions(
* Acceptable values are defined in [[shortParquetCompressionCodecNames]].
*/
val compressionCodecClassName: String = {
- val codecName = parameters.getOrElse("compression",
- sqlConf.parquetCompressionCodec).toLowerCase(Locale.ROOT)
+    // `compression`, `parquet.compression` (i.e. ParquetOutputFormat.COMPRESSION), and
+    // `spark.sql.parquet.compression.codec` are listed in order of precedence, from highest to
+    // lowest.
+ val parquetCompressionConf = parameters.get(ParquetOutputFormat.COMPRESSION)
+ val codecName = parameters
+ .get("compression")
+ .orElse(parquetCompressionConf)
+ .getOrElse(sqlConf.parquetCompressionCodec)
+ .toLowerCase(Locale.ROOT)
if (!shortParquetCompressionCodecNames.contains(codecName)) {
val availableCodecs =
shortParquetCompressionCodecNames.keys.map(_.toLowerCase(Locale.ROOT))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala
index 2854cb1bc0c25..40ce5d5e0564e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.datasources.parquet
-import java.util.{Map => JMap}
+import java.util.{Map => JMap, TimeZone}
import scala.collection.JavaConverters._
@@ -48,9 +48,17 @@ import org.apache.spark.sql.types._
* Due to this reason, we no longer rely on [[ReadContext]] to pass requested schema from [[init()]]
* to [[prepareForRead()]], but use a private `var` for simplicity.
*/
-private[parquet] class ParquetReadSupport extends ReadSupport[UnsafeRow] with Logging {
+private[parquet] class ParquetReadSupport(val convertTz: Option[TimeZone])
+ extends ReadSupport[UnsafeRow] with Logging {
private var catalystRequestedSchema: StructType = _
+ def this() {
+ // We need a zero-arg constructor for SpecificParquetRecordReaderBase. But that is only
+ // used in the vectorized reader, where we get the convertTz value directly, and the value here
+ // is ignored.
+ this(None)
+ }
+
/**
* Called on executor side before [[prepareForRead()]] and instantiating actual Parquet record
* readers. Responsible for figuring out Parquet requested schema used for column pruning.
@@ -95,7 +103,8 @@ private[parquet] class ParquetReadSupport extends ReadSupport[UnsafeRow] with Lo
new ParquetRecordMaterializer(
parquetRequestedSchema,
ParquetReadSupport.expandUDT(catalystRequestedSchema),
- new ParquetToSparkSchemaConverter(conf))
+ new ParquetToSparkSchemaConverter(conf),
+ convertTz)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala
index 793755e9aaeb5..b2459dd0e8bba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.datasources.parquet
+import java.util.TimeZone
+
import org.apache.parquet.io.api.{GroupConverter, RecordMaterializer}
import org.apache.parquet.schema.MessageType
@@ -33,11 +35,12 @@ import org.apache.spark.sql.types.StructType
private[parquet] class ParquetRecordMaterializer(
parquetSchema: MessageType,
catalystSchema: StructType,
- schemaConverter: ParquetToSparkSchemaConverter)
+ schemaConverter: ParquetToSparkSchemaConverter,
+ convertTz: Option[TimeZone])
extends RecordMaterializer[UnsafeRow] {
private val rootConverter =
- new ParquetRowConverter(schemaConverter, parquetSchema, catalystSchema, NoopUpdater)
+ new ParquetRowConverter(schemaConverter, parquetSchema, catalystSchema, convertTz, NoopUpdater)
override def getCurrentRecord: UnsafeRow = rootConverter.currentRecord
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
index 10f6c3b4f15e3..1199725941842 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.parquet
import java.math.{BigDecimal, BigInteger}
import java.nio.ByteOrder
+import java.util.TimeZone
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
@@ -117,12 +118,14 @@ private[parquet] class ParquetPrimitiveConverter(val updater: ParentContainerUpd
* @param parquetType Parquet schema of Parquet records
* @param catalystType Spark SQL schema that corresponds to the Parquet record type. User-defined
* types should have been expanded.
+ * @param convertTz the optional time zone to convert to for int96 data
* @param updater An updater which propagates converted field values to the parent container
*/
private[parquet] class ParquetRowConverter(
schemaConverter: ParquetToSparkSchemaConverter,
parquetType: GroupType,
catalystType: StructType,
+ convertTz: Option[TimeZone],
updater: ParentContainerUpdater)
extends ParquetGroupConverter(updater) with Logging {
@@ -151,6 +154,8 @@ private[parquet] class ParquetRowConverter(
|${catalystType.prettyJson}
""".stripMargin)
+ private val UTC = DateTimeUtils.TimeZoneUTC
+
/**
* Updater used together with field converters within a [[ParquetRowConverter]]. It propagates
* converted filed values to the `ordinal`-th cell in `currentRow`.
@@ -279,7 +284,9 @@ private[parquet] class ParquetRowConverter(
val buf = value.toByteBuffer.order(ByteOrder.LITTLE_ENDIAN)
val timeOfDayNanos = buf.getLong
val julianDay = buf.getInt
- updater.setLong(DateTimeUtils.fromJulianDay(julianDay, timeOfDayNanos))
+ val rawTime = DateTimeUtils.fromJulianDay(julianDay, timeOfDayNanos)
+ val adjTime = convertTz.map(DateTimeUtils.convertTz(rawTime, _, UTC)).getOrElse(rawTime)
+ updater.setLong(adjTime)
}
}
@@ -309,7 +316,7 @@ private[parquet] class ParquetRowConverter(
case t: StructType =>
new ParquetRowConverter(
- schemaConverter, parquetType.asGroupType(), t, new ParentContainerUpdater {
+ schemaConverter, parquetType.asGroupType(), t, convertTz, new ParentContainerUpdater {
override def set(value: Any): Unit = updater.set(value.asInstanceOf[InternalRow].copy())
})
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 60c430bcfece2..f64e079539c4f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -39,7 +39,7 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] {
sparkSession.sessionState.conf.runSQLonFile && u.tableIdentifier.database.isDefined
}
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case u: UnresolvedRelation if maybeSQLFile(u) =>
try {
val dataSource = DataSource(
@@ -108,8 +108,9 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi
}
// Check if the specified data source match the data source of the existing table.
- val existingProvider = DataSource.lookupDataSource(existingTable.provider.get)
- val specifiedProvider = DataSource.lookupDataSource(tableDesc.provider.get)
+ val conf = sparkSession.sessionState.conf
+ val existingProvider = DataSource.lookupDataSource(existingTable.provider.get, conf)
+ val specifiedProvider = DataSource.lookupDataSource(tableDesc.provider.get, conf)
// TODO: Check that options from the resolved relation match the relation that we are
// inserting into (i.e. using the same compression).
if (existingProvider != specifiedProvider) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
index d0690445d7672..c661e9bd3b94c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
@@ -17,12 +17,16 @@
package org.apache.spark.sql.execution.datasources.text
+import java.io.Closeable
+
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.io.Text
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
import org.apache.spark.TaskContext
-import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
@@ -53,6 +57,14 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister {
}
}
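+  // In wholetext mode each file is read as a single record, so the file cannot be split.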
+ override def isSplitable(
+ sparkSession: SparkSession,
+ options: Map[String, String],
+ path: Path): Boolean = {
+ val textOptions = new TextOptions(options)
+ super.isSplitable(sparkSession, options, path) && !textOptions.wholeText
+ }
+
override def inferSchema(
sparkSession: SparkSession,
options: Map[String, String],
@@ -97,14 +109,26 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister {
assert(
requiredSchema.length <= 1,
"Text data source only produces a single data column named \"value\".")
-
+ val textOptions = new TextOptions(options)
val broadcastedHadoopConf =
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
+ readToUnsafeMem(broadcastedHadoopConf, requiredSchema, textOptions.wholeText)
+ }
+
+ private def readToUnsafeMem(
+ conf: Broadcast[SerializableConfiguration],
+ requiredSchema: StructType,
+ wholeTextMode: Boolean): (PartitionedFile) => Iterator[UnsafeRow] = {
+
(file: PartitionedFile) => {
- val reader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value)
+ val confValue = conf.value.value
+ val reader = if (!wholeTextMode) {
+ new HadoopFileLinesReader(file, confValue)
+ } else {
+ new HadoopFileWholeTextReader(file, confValue)
+ }
Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => reader.close()))
-
if (requiredSchema.isEmpty) {
val emptyUnsafeRow = new UnsafeRow(0)
reader.map(_ => emptyUnsafeRow)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala
index 49bd7382f9cf3..2a661561ab51e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala
@@ -33,8 +33,15 @@ private[text] class TextOptions(@transient private val parameters: CaseInsensiti
* Compression codec to use.
*/
val compressionCodec = parameters.get(COMPRESSION).map(CompressionCodecs.getCodecClassName)
+
+ /**
+ * wholetext - If true, read a file as a single row and not split by "\n".
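+   * For example, `spark.read.option("wholetext", "true").text(path)` yields one row per file
+   * instead of one row per line.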
+ */
+ val wholeText = parameters.getOrElse(WHOLETEXT, "false").toBoolean
+
}
private[text] object TextOptions {
val COMPRESSION = "compression"
+ val WHOLETEXT = "wholetext"
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
index 7eb99a645001a..cba20dd902007 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
@@ -35,6 +35,16 @@ case class DataSourceV2Relation(
}
}
+/**
+ * A specialization of DataSourceV2Relation with the streaming bit set to true. Otherwise identical
+ * to the non-streaming relation.
+ */
+class StreamingDataSourceV2Relation(
+ fullOutput: Seq[AttributeReference],
+ reader: DataSourceV2Reader) extends DataSourceV2Relation(fullOutput, reader) {
+ override def isStreaming: Boolean = true
+}
+
object DataSourceV2Relation {
def apply(reader: DataSourceV2Reader): DataSourceV2Relation = {
new DataSourceV2Relation(reader.readSchema().toAttributes, reader)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
index 3f243dc44e043..49c506bc560cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
@@ -26,7 +26,10 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.LeafExecNode
import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.execution.streaming.StreamExecution
+import org.apache.spark.sql.execution.streaming.continuous.{ContinuousDataSourceRDD, ContinuousExecution, EpochCoordinatorRef, SetReaderPartitions}
import org.apache.spark.sql.sources.v2.reader._
+import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader
import org.apache.spark.sql.types.StructType
/**
@@ -52,10 +55,20 @@ case class DataSourceV2ScanExec(
}.asJava
}
- val inputRDD = new DataSourceRDD(sparkContext, readTasks)
- .asInstanceOf[RDD[InternalRow]]
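+    // For continuous processing, first tell the epoch coordinator how many reader partitions to
+    // expect per epoch, then build a continuous RDD instead of a plain DataSourceRDD.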
+ val inputRDD = reader match {
+ case _: ContinuousReader =>
+ EpochCoordinatorRef.get(
+ sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env)
+ .askSync[Unit](SetReaderPartitions(readTasks.size()))
+
+ new ContinuousDataSourceRDD(sparkContext, sqlContext, readTasks)
+
+ case _ =>
+ new DataSourceRDD(sparkContext, readTasks)
+ }
+
val numOutputRows = longMetric("numOutputRows")
- inputRDD.map { r =>
+ inputRDD.asInstanceOf[RDD[InternalRow]].map { r =>
numOutputRows += 1
r
}
@@ -73,7 +86,7 @@ class RowToUnsafeRowReadTask(rowReadTask: ReadTask[Row], schema: StructType)
}
}
-class RowToUnsafeDataReader(rowReader: DataReader[Row], encoder: ExpressionEncoder[Row])
+class RowToUnsafeDataReader(val rowReader: DataReader[Row], encoder: ExpressionEncoder[Row])
extends DataReader[UnsafeRow] {
override def next: Boolean = rowReader.next
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala
new file mode 100644
index 0000000000000..5267f5f1580c3
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import java.util.regex.Pattern
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources.v2.{DataSourceV2, SessionConfigSupport}
+
+private[sql] object DataSourceV2Utils extends Logging {
+
+ /**
+   * Helper method that extracts and transforms session configs into k/v pairs; these k/v pairs
+   * will be used to create data source options.
+   * Configs are only extracted when `ds` implements [[SessionConfigSupport]]. In that case we
+   * fetch the specified key prefix from `ds` and extract the session configs whose keys start
+   * with `spark.datasource.$keyPrefix`. A session config `spark.datasource.$keyPrefix.xxx -> yyy`
+   * is transformed into `xxx -> yyy`.
+ *
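+   * For example, if `keyPrefix()` returns "myds", a session config
+   * `spark.datasource.myds.path -> /tmp/data` is extracted as `path -> /tmp/data`.
+   *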
+ * @param ds a [[DataSourceV2]] object
+ * @param conf the session conf
+ * @return an immutable map that contains all the extracted and transformed k/v pairs.
+ */
+ def extractSessionConfigs(ds: DataSourceV2, conf: SQLConf): Map[String, String] = ds match {
+ case cs: SessionConfigSupport =>
+ val keyPrefix = cs.keyPrefix()
+ require(keyPrefix != null, "The data source config key prefix can't be null.")
+
+ val pattern = Pattern.compile(s"^spark\\.datasource\\.$keyPrefix\\.(.+)")
+
+ conf.getAllConfs.flatMap { case (key, value) =>
+ val m = pattern.matcher(key)
+ if (m.matches() && m.groupCount() > 0) {
+ Seq((m.group(1), value))
+ } else {
+ Seq.empty
+ }
+ }
+
+ case _ => Map.empty
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala
index 0c1708131ae46..df034adf1e7d6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala
@@ -40,12 +40,8 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel
// top-down, then we can simplify the logic here and only collect target operators.
val filterPushed = plan transformUp {
case FilterAndProject(fields, condition, r @ DataSourceV2Relation(_, reader)) =>
- // Non-deterministic expressions are stateful and we must keep the input sequence unchanged
- // to avoid changing the result. This means, we can't evaluate the filter conditions that
- // are after the first non-deterministic condition ahead. Here we only try to push down
- // deterministic conditions that are before the first non-deterministic condition.
- val (candidates, containingNonDeterministic) =
- splitConjunctivePredicates(condition).span(_.deterministic)
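+        // Split the condition: only deterministic predicates are candidates for push-down; the
+        // non-deterministic ones always stay in a Spark-side Filter.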
+ val (candidates, nonDeterministic) =
+ splitConjunctivePredicates(condition).partition(_.deterministic)
val stayUpFilters: Seq[Expression] = reader match {
case r: SupportsPushDownCatalystFilters =>
@@ -74,7 +70,7 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel
case _ => candidates
}
- val filterCondition = (stayUpFilters ++ containingNonDeterministic).reduceLeftOption(And)
+ val filterCondition = (stayUpFilters ++ nonDeterministic).reduceLeftOption(And)
val withFilter = filterCondition.map(Filter(_, r)).getOrElse(r)
if (withFilter.output == fields) {
withFilter
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
index b72d15ed15aed..f0bdf84bb7a84 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.datasources.v2
-import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.{SparkEnv, SparkException, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
@@ -26,6 +26,9 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.streaming.StreamExecution
+import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions}
+import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter
import org.apache.spark.sql.sources.v2.writer._
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
@@ -58,10 +61,22 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan)
s"The input RDD has ${messages.length} partitions.")
try {
+ val runTask = writer match {
+ case w: ContinuousWriter =>
+ EpochCoordinatorRef.get(
+ sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env)
+ .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions))
+
+ (context: TaskContext, iter: Iterator[InternalRow]) =>
+ DataWritingSparkTask.runContinuous(writeTask, context, iter)
+ case _ =>
+ (context: TaskContext, iter: Iterator[InternalRow]) =>
+ DataWritingSparkTask.run(writeTask, context, iter)
+ }
+
sparkContext.runJob(
rdd,
- (context: TaskContext, iter: Iterator[InternalRow]) =>
- DataWritingSparkTask.run(writeTask, context, iter),
+ runTask,
rdd.partitions.indices,
(index, message: WriterCommitMessage) => messages(index) = message
)
@@ -70,6 +85,8 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan)
writer.commit(messages)
logInfo(s"Data source writer $writer committed.")
} catch {
+ case _: InterruptedException if writer.isInstanceOf[ContinuousWriter] =>
+ // Interruption is how continuous queries are ended, so accept and ignore the exception.
case cause: Throwable =>
logError(s"Data source writer $writer is aborting.")
try {
@@ -109,6 +126,44 @@ object DataWritingSparkTask extends Logging {
logError(s"Writer for partition ${context.partitionId()} aborted.")
})
}
+
+ def runContinuous(
+ writeTask: DataWriterFactory[InternalRow],
+ context: TaskContext,
+ iter: Iterator[InternalRow]): WriterCommitMessage = {
+ val dataWriter = writeTask.createDataWriter(context.partitionId(), context.attemptNumber())
+ val epochCoordinator = EpochCoordinatorRef.get(
+ context.getLocalProperty(ContinuousExecution.RUN_ID_KEY),
+ SparkEnv.get)
+ val currentMsg: WriterCommitMessage = null
+ var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
+
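+    // Loop over epochs: each pass writes the (epoch-aware) input iterator, commits the data
+    // writer, and reports the commit message for the current epoch to the coordinator, until
+    // the task is interrupted by continuous-query shutdown.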
+ do {
+ // write the data and commit this writer.
+ Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
+ try {
+ iter.foreach(dataWriter.write)
+ logInfo(s"Writer for partition ${context.partitionId()} is committing.")
+ val msg = dataWriter.commit()
+ logInfo(s"Writer for partition ${context.partitionId()} committed.")
+ epochCoordinator.send(
+ CommitPartitionEpoch(context.partitionId(), currentEpoch, msg)
+ )
+ currentEpoch += 1
+ } catch {
+ case _: InterruptedException =>
+ // Continuous shutdown always involves an interrupt. Just finish the task.
+ }
+ })(catchBlock = {
+ // If there is an error, abort this writer
+ logError(s"Writer for partition ${context.partitionId()} is aborting.")
+ dataWriter.abort()
+ logError(s"Writer for partition ${context.partitionId()} aborted.")
+ })
+ } while (!context.isInterrupted())
+
+ currentMsg
+ }
}
class InternalRowDataWriterFactory(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 4e2ca37bc1a59..e3d28388c5470 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -17,10 +17,14 @@
package org.apache.spark.sql.execution.exchange
+import scala.collection.mutable.ArrayBuffer
+
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec,
+ SortMergeJoinExec}
import org.apache.spark.sql.internal.SQLConf
/**
@@ -42,23 +46,6 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None
}
- /**
- * Given a required distribution, returns a partitioning that satisfies that distribution.
- * @param requiredDistribution The distribution that is required by the operator
- * @param numPartitions Used when the distribution doesn't require a specific number of partitions
- */
- private def createPartitioning(
- requiredDistribution: Distribution,
- numPartitions: Int): Partitioning = {
- requiredDistribution match {
- case AllTuples => SinglePartition
- case ClusteredDistribution(clustering, desiredPartitions) =>
- HashPartitioning(clustering, desiredPartitions.getOrElse(numPartitions))
- case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions)
- case dist => sys.error(s"Do not know how to satisfy distribution $dist")
- }
- }
-
/**
* Adds [[ExchangeCoordinator]] to [[ShuffleExchangeExec]]s if adaptive query execution is enabled
* and partitioning schemes of these [[ShuffleExchangeExec]]s support [[ExchangeCoordinator]].
@@ -84,8 +71,9 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
// shuffle data when we have more than one children because data generated by
// these children may not be partitioned in the same way.
// Please see the comment in withCoordinator for more details.
- val supportsDistribution =
- requiredChildDistributions.forall(_.isInstanceOf[ClusteredDistribution])
+ val supportsDistribution = requiredChildDistributions.forall { dist =>
+ dist.isInstanceOf[ClusteredDistribution] || dist.isInstanceOf[HashClusteredDistribution]
+ }
children.length > 1 && supportsDistribution
}
@@ -138,8 +126,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
//
// It will be great to introduce a new Partitioning to represent the post-shuffle
// partitions when one post-shuffle partition includes multiple pre-shuffle partitions.
- val targetPartitioning =
- createPartitioning(distribution, defaultNumPreShufflePartitions)
+ val targetPartitioning = distribution.createPartitioning(defaultNumPreShufflePartitions)
assert(targetPartitioning.isInstanceOf[HashPartitioning])
ShuffleExchangeExec(targetPartitioning, child, Some(coordinator))
}
@@ -158,71 +145,56 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
assert(requiredChildDistributions.length == children.length)
assert(requiredChildOrderings.length == children.length)
- // Ensure that the operator's children satisfy their output distribution requirements:
+ // Ensure that the operator's children satisfy their output distribution requirements.
children = children.zip(requiredChildDistributions).map {
case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
child
case (child, BroadcastDistribution(mode)) =>
BroadcastExchangeExec(mode, child)
case (child, distribution) =>
- ShuffleExchangeExec(createPartitioning(distribution, defaultNumPreShufflePartitions), child)
+ val numPartitions = distribution.requiredNumPartitions
+ .getOrElse(defaultNumPreShufflePartitions)
+ ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child)
}
- // If the operator has multiple children and specifies child output distributions (e.g. join),
- // then the children's output partitionings must be compatible:
- def requireCompatiblePartitioning(distribution: Distribution): Boolean = distribution match {
- case UnspecifiedDistribution => false
- case BroadcastDistribution(_) => false
+ // Get the indexes of children which have specified distribution requirements and need to have
+ // the same number of partitions.
+ val childrenIndexes = requiredChildDistributions.zipWithIndex.filter {
+ case (UnspecifiedDistribution, _) => false
+ case (_: BroadcastDistribution, _) => false
case _ => true
- }
- if (children.length > 1
- && requiredChildDistributions.exists(requireCompatiblePartitioning)
- && !Partitioning.allCompatible(children.map(_.outputPartitioning))) {
+ }.map(_._2)
+
+ val childrenNumPartitions =
+ childrenIndexes.map(children(_).outputPartitioning.numPartitions).toSet
- // First check if the existing partitions of the children all match. This means they are
- // partitioned by the same partitioning into the same number of partitions. In that case,
- // don't try to make them match `defaultPartitions`, just use the existing partitioning.
- val maxChildrenNumPartitions = children.map(_.outputPartitioning.numPartitions).max
- val useExistingPartitioning = children.zip(requiredChildDistributions).forall {
- case (child, distribution) =>
- child.outputPartitioning.guarantees(
- createPartitioning(distribution, maxChildrenNumPartitions))
+ if (childrenNumPartitions.size > 1) {
+ // Get the number of partitions which is explicitly required by the distributions.
+ val requiredNumPartitions = {
+ val numPartitionsSet = childrenIndexes.flatMap {
+ index => requiredChildDistributions(index).requiredNumPartitions
+ }.toSet
+ assert(numPartitionsSet.size <= 1,
+ s"$operator have incompatible requirements of the number of partitions for its children")
+ numPartitionsSet.headOption
}
- children = if (useExistingPartitioning) {
- // We do not need to shuffle any child's output.
- children
- } else {
- // We need to shuffle at least one child's output.
- // Now, we will determine the number of partitions that will be used by created
- // partitioning schemes.
- val numPartitions = {
- // Let's see if we need to shuffle all child's outputs when we use
- // maxChildrenNumPartitions.
- val shufflesAllChildren = children.zip(requiredChildDistributions).forall {
- case (child, distribution) =>
- !child.outputPartitioning.guarantees(
- createPartitioning(distribution, maxChildrenNumPartitions))
- }
- // If we need to shuffle all children, we use defaultNumPreShufflePartitions as the
- // number of partitions. Otherwise, we use maxChildrenNumPartitions.
- if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions
- }
+ val targetNumPartitions = requiredNumPartitions.getOrElse(childrenNumPartitions.max)
- children.zip(requiredChildDistributions).map {
- case (child, distribution) =>
- val targetPartitioning = createPartitioning(distribution, numPartitions)
- if (child.outputPartitioning.guarantees(targetPartitioning)) {
- child
- } else {
- child match {
- // If child is an exchange, we replace it with
- // a new one having targetPartitioning.
- case ShuffleExchangeExec(_, c, _) => ShuffleExchangeExec(targetPartitioning, c)
- case _ => ShuffleExchangeExec(targetPartitioning, child)
- }
+ children = children.zip(requiredChildDistributions).zipWithIndex.map {
+ case ((child, distribution), index) if childrenIndexes.contains(index) =>
+ if (child.outputPartitioning.numPartitions == targetNumPartitions) {
+ child
+ } else {
+ val defaultPartitioning = distribution.createPartitioning(targetNumPartitions)
+ child match {
+ // If child is an exchange, we replace it with a new one having defaultPartitioning.
+ case ShuffleExchangeExec(_, c, _) => ShuffleExchangeExec(defaultPartitioning, c)
+ case _ => ShuffleExchangeExec(defaultPartitioning, child)
+ }
}
- }
+
+ case ((child, _), _) => child
}
}
@@ -248,13 +220,85 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
operator.withNewChildren(children)
}
+ private def reorder(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ expectedOrderOfKeys: Seq[Expression],
+ currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
+ val leftKeysBuffer = ArrayBuffer[Expression]()
+ val rightKeysBuffer = ArrayBuffer[Expression]()
+
+ expectedOrderOfKeys.foreach(expression => {
+ val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
+ leftKeysBuffer.append(leftKeys(index))
+ rightKeysBuffer.append(rightKeys(index))
+ })
+ (leftKeysBuffer, rightKeysBuffer)
+ }
+
+ private def reorderJoinKeys(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ leftPartitioning: Partitioning,
+ rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
+ if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
+ leftPartitioning match {
+ case HashPartitioning(leftExpressions, _)
+ if leftExpressions.length == leftKeys.length &&
+ leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
+ reorder(leftKeys, rightKeys, leftExpressions, leftKeys)
+
+ case _ => rightPartitioning match {
+ case HashPartitioning(rightExpressions, _)
+ if rightExpressions.length == rightKeys.length &&
+ rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
+ reorder(leftKeys, rightKeys, rightExpressions, rightKeys)
+
+ case _ => (leftKeys, rightKeys)
+ }
+ }
+ } else {
+ (leftKeys, rightKeys)
+ }
+ }
+
+ /**
+ * When the physical operators are created for JOIN, the ordering of the join keys is based on
+ * the order in which they appear in the user query. That might not match the output
+ * partitioning of the join node's children (thus leading to an extra sort / shuffle being
+ * introduced). This rule changes the ordering of the join keys to match the partitioning of the
+ * join nodes' children.
+ */
+ private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = {
+ plan.transformUp {
+ case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left,
+ right) =>
+ val (reorderedLeftKeys, reorderedRightKeys) =
+ reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
+ BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
+ left, right)
+
+ case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) =>
+ val (reorderedLeftKeys, reorderedRightKeys) =
+ reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
+ ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
+ left, right)
+
+ case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) =>
+ val (reorderedLeftKeys, reorderedRightKeys) =
+ reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
+ SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right)
+ }
+ }
+
def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
- case operator @ ShuffleExchangeExec(partitioning, child, _) =>
- child.children match {
- case ShuffleExchangeExec(childPartitioning, baseChild, _)::Nil =>
- if (childPartitioning.guarantees(partitioning)) child else operator
+ // TODO: remove this after we create a physical operator for `RepartitionByExpression`.
+ case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, _) =>
+ child.outputPartitioning match {
+ case lower: HashPartitioning if upper.semanticEquals(lower) => child
case _ => operator
}
- case operator: SparkPlan => ensureDistributionAndOrdering(operator)
+ case operator: SparkPlan =>
+ ensureDistributionAndOrdering(reorderJoinPredicates(operator))
}
}
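
The heart of the new reorder / reorderJoinKeys helpers is a permutation: when one side of the join is already hash-partitioned on the same set of keys but in a different order, both key lists are permuted to that order so the existing partitioning can be reused and no extra shuffle is added. A simplified, self-contained sketch of that permutation, with join keys modeled as plain strings instead of Catalyst expressions:

object ReorderKeysSketch {
  def reorder(
      leftKeys: Seq[String],
      rightKeys: Seq[String],
      expectedOrder: Seq[String]): (Seq[String], Seq[String]) = {
    // For each key in the expected (partitioning) order, pick the matching
    // pair from the current join-key order.
    val pairs = expectedOrder.map { key =>
      val i = leftKeys.indexOf(key)
      (leftKeys(i), rightKeys(i))
    }
    pairs.unzip
  }

  def main(args: Array[String]): Unit = {
    val leftKeys = Seq("a", "b")             // join condition: t1.a = t2.x AND t1.b = t2.y
    val rightKeys = Seq("x", "y")
    val existingPartitioning = Seq("b", "a") // left child is already hashed on (b, a)
    println(reorder(leftKeys, rightKeys, existingPartitioning))
    // (List(b, a), List(y, x)): both sides now follow the existing order
  }
}

In the real rule the lookup goes through semanticEquals and only fires when the partitioning expressions cover exactly the join keys, as the guards in reorderJoinKeys above show.
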
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index c96ed6ef41016..1918fcc5482db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -134,19 +134,18 @@ case class BroadcastHashJoinExec(
// create a name for HashedRelation
val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()
val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation)
- val relationTerm = ctx.freshName("relation")
val clsName = broadcastRelation.value.getClass.getName
// At the end of the task, we update the avg hash probe.
val avgHashProbe = metricTerm(ctx, "avgHashProbe")
- val addTaskListener = genTaskListener(avgHashProbe, relationTerm)
- ctx.addMutableState(clsName, relationTerm,
- s"""
- | $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy();
- | incPeakExecutionMemory($relationTerm.estimatedSize());
- | $addTaskListener
- """.stripMargin)
+ // Inline mutable state since there are not many join operations in a task
+ val relationTerm = ctx.addMutableState(clsName, "relation",
+ v => s"""
+ | $v = (($clsName) $broadcast.value()).asReadOnlyCopy();
+ | incPeakExecutionMemory($v.estimatedSize());
+ | ${genTaskListener(avgHashProbe, v)}
+ """.stripMargin, forceInline = true)
(broadcastRelation, relationTerm)
}
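
The rewritten call assumes an addMutableState that returns the final variable name and takes the initialization code as a function of that name, with forceInline as a hint that the field should stay in the outer class. A rough sketch of that registry pattern (not the actual CodegenContext API):

import scala.collection.mutable

class MutableStateRegistrySketch {
  // (javaType, variable name, initialization code)
  private val states = mutable.ArrayBuffer[(String, String, String)]()
  private var counter = 0

  def addMutableState(javaType: String, name: String, init: String => String): String = {
    val varName = s"${name}_$counter"   // the registry picks the final name
    counter += 1
    states += ((javaType, varName, init(varName)))
    varName
  }

  def declarations: String =
    states.map { case (t, n, _) => s"private $t $n;" }.mkString("\n")

  def initializations: String =
    states.map(_._3).mkString("\n")
}

object MutableStateRegistrySketch {
  def main(args: Array[String]): Unit = {
    val ctx = new MutableStateRegistrySketch
    val relation = ctx.addMutableState("LongHashedRelation", "relation",
      v => s"$v = (LongHashedRelation) broadcast.value();")
    println(ctx.declarations)
    println(ctx.initializations)
    println(s"generated code can refer to: $relation")
  }
}
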
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index d98cf852a1b48..1465346eb802d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -368,7 +368,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
// The minimum key
private var minKey = Long.MaxValue
- // The maxinum key
+ // The maximum key
private var maxKey = Long.MinValue
// The array to store the key and offset of UnsafeRow in the page.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala
deleted file mode 100644
index 534d8c5689c27..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala
+++ /dev/null
@@ -1,94 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.SparkPlan
-
-/**
- * When the physical operators are created for JOIN, the ordering of join keys is based on order
- * in which the join keys appear in the user query. That might not match with the output
- * partitioning of the join node's children (thus leading to extra sort / shuffle being
- * introduced). This rule will change the ordering of the join keys to match with the
- * partitioning of the join nodes' children.
- */
-class ReorderJoinPredicates extends Rule[SparkPlan] {
- private def reorderJoinKeys(
- leftKeys: Seq[Expression],
- rightKeys: Seq[Expression],
- leftPartitioning: Partitioning,
- rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
-
- def reorder(
- expectedOrderOfKeys: Seq[Expression],
- currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
- val leftKeysBuffer = ArrayBuffer[Expression]()
- val rightKeysBuffer = ArrayBuffer[Expression]()
-
- expectedOrderOfKeys.foreach(expression => {
- val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
- leftKeysBuffer.append(leftKeys(index))
- rightKeysBuffer.append(rightKeys(index))
- })
- (leftKeysBuffer, rightKeysBuffer)
- }
-
- if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
- leftPartitioning match {
- case HashPartitioning(leftExpressions, _)
- if leftExpressions.length == leftKeys.length &&
- leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
- reorder(leftExpressions, leftKeys)
-
- case _ => rightPartitioning match {
- case HashPartitioning(rightExpressions, _)
- if rightExpressions.length == rightKeys.length &&
- rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
- reorder(rightExpressions, rightKeys)
-
- case _ => (leftKeys, rightKeys)
- }
- }
- } else {
- (leftKeys, rightKeys)
- }
- }
-
- def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
- case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) =>
- val (reorderedLeftKeys, reorderedRightKeys) =
- reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
- BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
- left, right)
-
- case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) =>
- val (reorderedLeftKeys, reorderedRightKeys) =
- reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
- ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
- left, right)
-
- case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) =>
- val (reorderedLeftKeys, reorderedRightKeys) =
- reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
- SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right)
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index 66e8031bb5191..897a4dae39f32 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -46,7 +46,7 @@ case class ShuffledHashJoinExec(
"avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe"))
override def requiredChildDistribution: Seq[Distribution] =
- ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+ HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
val buildDataSize = longMetric("buildDataSize")
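
Roughly speaking, ClusteredDistribution only requires that rows sharing the same clustering values land in the same partition, so each join side could satisfy it with a different hash scheme (for example, hashing only a subset of the keys); HashClusteredDistribution pins down the exact expressions and their order, which is what a co-partitioned join needs. A toy illustration of why mismatched hashing breaks co-partitioning, with a simplified hash in place of Spark's partitioner:

object CoPartitioningSketch {
  def partition(keys: Seq[Any], numPartitions: Int): Int =
    Math.floorMod(keys.hashCode(), numPartitions)

  def main(args: Array[String]): Unit = {
    val numPartitions = 4
    val joinKeys = (1, 2)   // a = 1, b = 2 on both sides of the join

    val leftPart = partition(Seq(joinKeys._1), numPartitions)                // hashed on (a) only
    val rightPart = partition(Seq(joinKeys._1, joinKeys._2), numPartitions)  // hashed on (a, b)

    println(s"left lands in partition $leftPart, right lands in partition $rightPart")
    // Hashing different key sets yields unrelated partition ids, so matching
    // rows are not guaranteed to meet in the same partition.
  }
}
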
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 9c08ec71c1fde..2de2f30eb05d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -78,7 +78,7 @@ case class SortMergeJoinExec(
}
override def requiredChildDistribution: Seq[Distribution] =
- ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+ HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
override def outputOrdering: Seq[SortOrder] = joinType match {
// For inner join, orders of both sides keys should be kept.
@@ -422,10 +422,9 @@ case class SortMergeJoinExec(
*/
private def genScanner(ctx: CodegenContext): (String, String) = {
// Create class member for next row from both sides.
- val leftRow = ctx.freshName("leftRow")
- ctx.addMutableState("InternalRow", leftRow)
- val rightRow = ctx.freshName("rightRow")
- ctx.addMutableState("InternalRow", rightRow, s"$rightRow = null;")
+ // Inline mutable state since there are not many join operations in a task
+ val leftRow = ctx.addMutableState("InternalRow", "leftRow", forceInline = true)
+ val rightRow = ctx.addMutableState("InternalRow", "rightRow", forceInline = true)
// Create variables for join keys from both sides.
val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output)
@@ -436,14 +435,14 @@ case class SortMergeJoinExec(
val rightKeyVars = copyKeys(ctx, rightKeyTmpVars)
// A list to hold all matched rows from right side.
- val matches = ctx.freshName("matches")
val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName
val spillThreshold = getSpillThreshold
val inMemoryThreshold = getInMemoryThreshold
- ctx.addMutableState(clsName, matches,
- s"$matches = new $clsName($inMemoryThreshold, $spillThreshold);")
+ // Inline mutable state since there are not many join operations in a task
+ val matches = ctx.addMutableState(clsName, "matches",
+ v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true)
// Copy the left keys as class members so they could be used in next function call.
val matchedKeyVars = copyKeys(ctx, leftKeyVars)
@@ -507,32 +506,38 @@ case class SortMergeJoinExec(
}
/**
- * Creates variables for left part of result row.
+ * Creates variables and declarations for the left part of the result row.
*
* In order to defer the access after condition and also only access once in the loop,
* the variables should be declared separately from accessing the columns, we can't use the
* codegen of BoundReference here.
*/
- private def createLeftVars(ctx: CodegenContext, leftRow: String): Seq[ExprCode] = {
+ private def createLeftVars(ctx: CodegenContext, leftRow: String): (Seq[ExprCode], Seq[String]) = {
ctx.INPUT_ROW = leftRow
left.output.zipWithIndex.map { case (a, i) =>
val value = ctx.freshName("value")
val valueCode = ctx.getValue(leftRow, a.dataType, i.toString)
- // declare it as class member, so we can access the column before or in the loop.
- ctx.addMutableState(ctx.javaType(a.dataType), value)
+ val javaType = ctx.javaType(a.dataType)
+ val defaultValue = ctx.defaultValue(a.dataType)
if (a.nullable) {
val isNull = ctx.freshName("isNull")
- ctx.addMutableState(ctx.JAVA_BOOLEAN, isNull)
val code =
s"""
|$isNull = $leftRow.isNullAt($i);
- |$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode);
+ |$value = $isNull ? $defaultValue : ($valueCode);
+ """.stripMargin
+ val leftVarsDecl =
+ s"""
+ |boolean $isNull = false;
+ |$javaType $value = $defaultValue;
""".stripMargin
- ExprCode(code, isNull, value)
+ (ExprCode(code, isNull, value), leftVarsDecl)
} else {
- ExprCode(s"$value = $valueCode;", "false", value)
+ val code = s"$value = $valueCode;"
+ val leftVarsDecl = s"""$javaType $value = $defaultValue;"""
+ (ExprCode(code, "false", value), leftVarsDecl)
}
- }
+ }.unzip
}
/**
@@ -572,15 +577,16 @@ case class SortMergeJoinExec(
override def needCopyResult: Boolean = true
override def doProduce(ctx: CodegenContext): String = {
- val leftInput = ctx.freshName("leftInput")
- ctx.addMutableState("scala.collection.Iterator", leftInput, s"$leftInput = inputs[0];")
- val rightInput = ctx.freshName("rightInput")
- ctx.addMutableState("scala.collection.Iterator", rightInput, s"$rightInput = inputs[1];")
+ // Inline mutable state since there are not many join operations in a task
+ val leftInput = ctx.addMutableState("scala.collection.Iterator", "leftInput",
+ v => s"$v = inputs[0];", forceInline = true)
+ val rightInput = ctx.addMutableState("scala.collection.Iterator", "rightInput",
+ v => s"$v = inputs[1];", forceInline = true)
val (leftRow, matches) = genScanner(ctx)
// Create variables for row from both sides.
- val leftVars = createLeftVars(ctx, leftRow)
+ val (leftVars, leftVarDecl) = createLeftVars(ctx, leftRow)
val rightRow = ctx.freshName("rightRow")
val rightVars = createRightVar(ctx, rightRow)
@@ -617,6 +623,7 @@ case class SortMergeJoinExec(
s"""
|while (findNextInnerJoinRows($leftInput, $rightInput)) {
+ | ${leftVarDecl.mkString("\n")}
| ${beforeLoop.trim}
| scala.collection.Iterator $iterator = $matches.generateIterator();
| while ($iterator.hasNext()) {
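
The createLeftVars change moves the buffered columns of the streamed (left) row from generated class fields to locals declared at the top of the join loop, which is why doProduce now splices the leftVarDecl strings into the while body. A toy sketch of the resulting shape of the generated code; the strings below are hypothetical fragments, not real codegen output:

object LeftVarsSketch {
  def main(args: Array[String]): Unit = {
    // (java type, variable name) for the buffered left-side columns
    val columns = Seq(("int", "value_0"), ("boolean", "isNull_1"), ("long", "value_1"))

    val leftVarsDecl = columns.map { case (javaType, name) =>
      s"$javaType $name = ${defaultValue(javaType)};"
    }

    val generated =
      s"""while (findNextInnerJoinRows(leftInput, rightInput)) {
         |  ${leftVarsDecl.mkString("\n  ")}
         |  // ... assign from leftRow, then iterate over the matched right rows ...
         |}""".stripMargin
    println(generated)
  }

  private def defaultValue(javaType: String): String = javaType match {
    case "boolean"      => "false"
    case "int" | "long" => "0"
    case _              => "null"
  }
}
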
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index 443b6898aaf75..8b24cad6cc597 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -74,8 +74,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
}
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
- val stopEarly = ctx.freshName("stopEarly")
- ctx.addMutableState(ctx.JAVA_BOOLEAN, stopEarly, s"$stopEarly = false;")
+ val stopEarly = ctx.addMutableState(ctx.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false
ctx.addNewFunction("stopEarly", s"""
@Override
@@ -83,8 +82,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
return $stopEarly;
}
""", inlineToOuterClass = true)
- val countTerm = ctx.freshName("count")
- ctx.addMutableState(ctx.JAVA_INT, countTerm, s"$countTerm = 0;")
+ val countTerm = ctx.addMutableState(ctx.JAVA_INT, "count") // init as count = 0
s"""
| if ($countTerm < $limit) {
| $countTerm += 1;
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 32d9ea1871c3d..351acb202861b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -462,7 +462,7 @@ case class CoGroupExec(
right: SparkPlan) extends BinaryExecNode with ObjectProducerExec {
override def requiredChildDistribution: Seq[Distribution] =
- ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil
+ HashClusteredDistribution(leftGroup) :: HashClusteredDistribution(rightGroup) :: Nil
override def requiredChildOrdering: Seq[Seq[SortOrder]] =
leftGroup.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 9a94d771a01b0..dc5ba96e69aec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -24,14 +24,14 @@ import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.JavaConverters._
import org.apache.arrow.vector.VectorSchemaRoot
-import org.apache.arrow.vector.stream.{ArrowStreamReader, ArrowStreamWriter}
+import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter}
import org.apache.spark._
import org.apache.spark.api.python._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.{ArrowUtils, ArrowWriter}
-import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
import org.apache.spark.util.Utils
/**
@@ -74,13 +74,9 @@ class ArrowPythonRunner(
val root = VectorSchemaRoot.create(arrowSchema, allocator)
val arrowWriter = ArrowWriter.create(root)
- var closed = false
-
context.addTaskCompletionListener { _ =>
- if (!closed) {
- root.close()
- allocator.close()
- }
+ root.close()
+ allocator.close()
}
val writer = new ArrowStreamWriter(root, null, dataOut)
@@ -102,7 +98,6 @@ class ArrowPythonRunner(
writer.end()
root.close()
allocator.close()
- closed = true
}
}
}
@@ -126,18 +121,11 @@ class ArrowPythonRunner(
private var schema: StructType = _
private var vectors: Array[ColumnVector] = _
- private var closed = false
-
context.addTaskCompletionListener { _ =>
- // todo: we need something like `reader.end()`, which release all the resources, but leave
- // the input stream open. `reader.close()` will close the socket and we can't reuse worker.
- // So here we simply not close the reader, which is problematic.
- if (!closed) {
- if (root != null) {
- root.close()
- }
- allocator.close()
+ if (reader != null) {
+ reader.close(false)
}
+ allocator.close()
}
private var batchLoaded = true
@@ -154,9 +142,8 @@ class ArrowPythonRunner(
batch.setNumRows(root.getRowCount)
batch
} else {
- root.close()
+ reader.close(false)
allocator.close()
- closed = true
// Reach end of stream. Call `read()` again to read control data.
read()
}
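
The cleanup above relies on a close call that releases the reader's own allocations while leaving the underlying stream open when asked, so the Python worker's socket can be reused for the next batch. A self-contained sketch of that split using a hypothetical reader class (this is not the Arrow reader API):

import java.io.InputStream

class BatchReaderSketch(in: InputStream) {
  private var buffers: List[Array[Byte]] = Nil

  def loadNextBatch(): Boolean = {
    val chunk = new Array[Byte](16)
    val read = in.read(chunk)
    if (read > 0) { buffers ::= chunk; true } else false
  }

  /** Release this reader's buffers; close the input stream only if asked to. */
  def close(closeUnderlying: Boolean): Unit = {
    buffers = Nil                      // drop per-batch allocations
    if (closeUnderlying) in.close()    // otherwise the stream stays usable
  }
}

object BatchReaderSketch {
  def main(args: Array[String]): Unit = {
    val in = new java.io.ByteArrayInputStream(Array.fill[Byte](32)(1))
    val reader = new BatchReaderSketch(in)
    while (reader.loadNextBatch()) {}
    reader.close(false)                // end of this batch; keep the stream for reuse
    println("buffers released, stream left open")
  }
}
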
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index 26ee25f633ea4..f4d83e8dc7c2b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -79,16 +79,19 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
} else {
StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))
}
+
+ val fromJava = EvaluatePython.makeFromJava(resultType)
+
outputIterator.flatMap { pickedResult =>
val unpickledBatch = unpickle.loads(pickedResult)
unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
}.map { result =>
if (udfs.length == 1) {
// fast path for single UDF
- mutableRow(0) = EvaluatePython.fromJava(result, resultType)
+ mutableRow(0) = fromJava(result)
mutableRow
} else {
- EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
+ fromJava(result).asInstanceOf[InternalRow]
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index 9bbfa6018ba77..520afad287648 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -83,82 +83,134 @@ object EvaluatePython {
}
/**
- * Converts `obj` to the type specified by the data type, or returns null if the type of obj is
- * unexpected. Because Python doesn't enforce the type.
+ * Makes a converter that converts `obj` to the type specified by the data type, or returns
+ * null if the type of `obj` is unexpected, because Python doesn't enforce the type.
*/
- def fromJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
- case (null, _) => null
-
- case (c: Boolean, BooleanType) => c
+ def makeFromJava(dataType: DataType): Any => Any = dataType match {
+ case BooleanType => (obj: Any) => nullSafeConvert(obj) {
+ case b: Boolean => b
+ }
- case (c: Byte, ByteType) => c
- case (c: Short, ByteType) => c.toByte
- case (c: Int, ByteType) => c.toByte
- case (c: Long, ByteType) => c.toByte
+ case ByteType => (obj: Any) => nullSafeConvert(obj) {
+ case c: Byte => c
+ case c: Short => c.toByte
+ case c: Int => c.toByte
+ case c: Long => c.toByte
+ }
- case (c: Byte, ShortType) => c.toShort
- case (c: Short, ShortType) => c
- case (c: Int, ShortType) => c.toShort
- case (c: Long, ShortType) => c.toShort
+ case ShortType => (obj: Any) => nullSafeConvert(obj) {
+ case c: Byte => c.toShort
+ case c: Short => c
+ case c: Int => c.toShort
+ case c: Long => c.toShort
+ }
- case (c: Byte, IntegerType) => c.toInt
- case (c: Short, IntegerType) => c.toInt
- case (c: Int, IntegerType) => c
- case (c: Long, IntegerType) => c.toInt
+ case IntegerType => (obj: Any) => nullSafeConvert(obj) {
+ case c: Byte => c.toInt
+ case c: Short => c.toInt
+ case c: Int => c
+ case c: Long => c.toInt
+ }
- case (c: Byte, LongType) => c.toLong
- case (c: Short, LongType) => c.toLong
- case (c: Int, LongType) => c.toLong
- case (c: Long, LongType) => c
+ case LongType => (obj: Any) => nullSafeConvert(obj) {
+ case c: Byte => c.toLong
+ case c: Short => c.toLong
+ case c: Int => c.toLong
+ case c: Long => c
+ }
- case (c: Float, FloatType) => c
- case (c: Double, FloatType) => c.toFloat
+ case FloatType => (obj: Any) => nullSafeConvert(obj) {
+ case c: Float => c
+ case c: Double => c.toFloat
+ }
- case (c: Float, DoubleType) => c.toDouble
- case (c: Double, DoubleType) => c
+ case DoubleType => (obj: Any) => nullSafeConvert(obj) {
+ case c: Float => c.toDouble
+ case c: Double => c
+ }
- case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c, dt.precision, dt.scale)
+ case dt: DecimalType => (obj: Any) => nullSafeConvert(obj) {
+ case c: java.math.BigDecimal => Decimal(c, dt.precision, dt.scale)
+ }
- case (c: Int, DateType) => c
+ case DateType => (obj: Any) => nullSafeConvert(obj) {
+ case c: Int => c
+ }
- case (c: Long, TimestampType) => c
- // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs
- case (c: Int, TimestampType) => c.toLong
+ case TimestampType => (obj: Any) => nullSafeConvert(obj) {
+ case c: Long => c
+ // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs
+ case c: Int => c.toLong
+ }
- case (c, StringType) => UTF8String.fromString(c.toString)
+ case StringType => (obj: Any) => nullSafeConvert(obj) {
+ case _ => UTF8String.fromString(obj.toString)
+ }
- case (c: String, BinaryType) => c.getBytes(StandardCharsets.UTF_8)
- case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c
+ case BinaryType => (obj: Any) => nullSafeConvert(obj) {
+ case c: String => c.getBytes(StandardCharsets.UTF_8)
+ case c if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c
+ }
- case (c: java.util.List[_], ArrayType(elementType, _)) =>
- new GenericArrayData(c.asScala.map { e => fromJava(e, elementType)}.toArray)
+ case ArrayType(elementType, _) =>
+ val elementFromJava = makeFromJava(elementType)
- case (c, ArrayType(elementType, _)) if c.getClass.isArray =>
- new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)))
+ (obj: Any) => nullSafeConvert(obj) {
+ case c: java.util.List[_] =>
+ new GenericArrayData(c.asScala.map { e => elementFromJava(e) }.toArray)
+ case c if c.getClass.isArray =>
+ new GenericArrayData(c.asInstanceOf[Array[_]].map(e => elementFromJava(e)))
+ }
- case (javaMap: java.util.Map[_, _], MapType(keyType, valueType, _)) =>
- ArrayBasedMapData(
- javaMap,
- (key: Any) => fromJava(key, keyType),
- (value: Any) => fromJava(value, valueType))
+ case MapType(keyType, valueType, _) =>
+ val keyFromJava = makeFromJava(keyType)
+ val valueFromJava = makeFromJava(valueType)
+
+ (obj: Any) => nullSafeConvert(obj) {
+ case javaMap: java.util.Map[_, _] =>
+ ArrayBasedMapData(
+ javaMap,
+ (key: Any) => keyFromJava(key),
+ (value: Any) => valueFromJava(value))
+ }
- case (c, StructType(fields)) if c.getClass.isArray =>
- val array = c.asInstanceOf[Array[_]]
- if (array.length != fields.length) {
- throw new IllegalStateException(
- s"Input row doesn't have expected number of values required by the schema. " +
- s"${fields.length} fields are required while ${array.length} values are provided."
- )
+ case StructType(fields) =>
+ val fieldsFromJava = fields.map(f => makeFromJava(f.dataType)).toArray
+
+ (obj: Any) => nullSafeConvert(obj) {
+ case c if c.getClass.isArray =>
+ val array = c.asInstanceOf[Array[_]]
+ if (array.length != fields.length) {
+ throw new IllegalStateException(
+ s"Input row doesn't have expected number of values required by the schema. " +
+ s"${fields.length} fields are required while ${array.length} values are provided."
+ )
+ }
+
+ val row = new GenericInternalRow(fields.length)
+ var i = 0
+ while (i < fields.length) {
+ row(i) = fieldsFromJava(i)(array(i))
+ i += 1
+ }
+ row
}
- new GenericInternalRow(array.zip(fields).map {
- case (e, f) => fromJava(e, f.dataType)
- })
- case (_, udt: UserDefinedType[_]) => fromJava(obj, udt.sqlType)
+ case udt: UserDefinedType[_] => makeFromJava(udt.sqlType)
+
+ case other => (obj: Any) => nullSafeConvert(other)(PartialFunction.empty)
+ }
- // all other unexpected type should be null, or we will have runtime exception
- // TODO(davies): we could improve this by try to cast the object to expected type
- case (c, _) => null
+ private def nullSafeConvert(input: Any)(f: PartialFunction[Any, Any]): Any = {
+ if (input == null) {
+ null
+ } else {
+ f.applyOrElse(input, {
+ // all other unexpected types are converted to null, otherwise we will have a runtime exception
+ // TODO(davies): we could improve this by trying to cast the object to the expected type
+ _: Any => null
+ })
+ }
}
private val module = "pyspark.sql.types"
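
makeFromJava turns a per-value pattern match on (obj, dataType) into a converter that is built once per data type and applied per value, with nested types composing their element converters up front. A condensed, self-contained version of that factory pattern over a toy type model (not Spark's DataType):

object ConverterFactorySketch {
  sealed trait DType
  case object IntT extends DType
  case object StringT extends DType
  final case class ArrayT(element: DType) extends DType

  // Wrap a partial conversion so that null stays null and unexpected types become null.
  def nullSafe(f: PartialFunction[Any, Any]): Any => Any = {
    case null  => null
    case other => f.applyOrElse(other, (_: Any) => null)
  }

  def makeConverter(dt: DType): Any => Any = dt match {
    case IntT => nullSafe {
      case i: Int  => i
      case l: Long => l.toInt
    }
    case StringT => nullSafe {
      case s => s.toString
    }
    case ArrayT(elementType) =>
      val elementConverter = makeConverter(elementType)  // built once, reused per element
      nullSafe {
        case xs: Seq[_] => xs.map(elementConverter)
      }
  }

  def main(args: Array[String]): Unit = {
    val convert = makeConverter(ArrayT(IntT))
    println(convert(Seq(1L, 2, null)))   // List(1, 2, null)
    println(convert("not an array"))     // null
  }
}
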
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index f5a4cbc4793e3..2f53fe788c7d0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -202,12 +202,12 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
private def trySplitFilter(plan: SparkPlan): SparkPlan = {
plan match {
case filter: FilterExec =>
- val (candidates, containingNonDeterministic) =
- splitConjunctivePredicates(filter.condition).span(_.deterministic)
+ val (candidates, nonDeterministic) =
+ splitConjunctivePredicates(filter.condition).partition(_.deterministic)
val (pushDown, rest) = candidates.partition(!hasPythonUDF(_))
if (pushDown.nonEmpty) {
val newChild = FilterExec(pushDown.reduceLeft(And), filter.child)
- FilterExec((rest ++ containingNonDeterministic).reduceLeft(And), newChild)
+ FilterExec((rest ++ nonDeterministic).reduceLeft(And), newChild)
} else {
filter
}
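
The switch from span to partition matters because span keeps only the deterministic prefix of the predicate list, so a non-deterministic predicate in the middle used to stop later deterministic predicates from being considered for pushdown. A minimal sketch of the difference on a toy predicate list:

object SpanVsPartitionSketch {
  final case class Pred(name: String, deterministic: Boolean)

  def main(args: Array[String]): Unit = {
    val predicates =
      Seq(Pred("a = 1", true), Pred("rand() > 0.5", false), Pred("b = 2", true))

    val (spanCandidates, _) = predicates.span(_.deterministic)
    println(spanCandidates.map(_.name))      // List(a = 1)

    val (partitionCandidates, _) = predicates.partition(_.deterministic)
    println(partitionCandidates.map(_.name)) // List(a = 1, b = 2)
  }
}
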
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
index ef27fbc2db7d9..d3f743d9eb61e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
@@ -29,9 +29,12 @@ case class PythonUDF(
func: PythonFunction,
dataType: DataType,
children: Seq[Expression],
- evalType: Int)
+ evalType: Int,
+ udfDeterministic: Boolean)
extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression {
+ override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)
+
override def toString: String = s"$name(${children.mkString(", ")})"
override def nullable: Boolean = true
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index 348e49e473ed3..50dca32cb7861 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -29,10 +29,11 @@ case class UserDefinedPythonFunction(
name: String,
func: PythonFunction,
dataType: DataType,
- pythonEvalType: Int) {
+ pythonEvalType: Int,
+ udfDeterministic: Boolean) {
def builder(e: Seq[Expression]): PythonUDF = {
- PythonUDF(name, func, dataType, e, pythonEvalType)
+ PythonUDF(name, func, dataType, e, pythonEvalType, udfDeterministic)
}
/** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSink.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSink.java
new file mode 100644
index 0000000000000..ac96c2765368f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSink.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming;
+
+/**
+ * The shared interface between V1 and V2 streaming sinks.
+ *
+ * This is a temporary interface for compatibility during migration. It should not be implemented
+ * directly, and will be removed in future versions.
+ */
+public interface BaseStreamingSink {
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java
new file mode 100644
index 0000000000000..c44b8af2552f0
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming;
+
+/**
+ * The shared interface between V1 streaming sources and V2 streaming readers.
+ *
+ * This is a temporary interface for compatibility during migration. It should not be implemented
+ * directly, and will be removed in future versions.
+ */
+public interface BaseStreamingSource {
+ /** Stop this source and free any resources it has allocated. */
+ void stop();
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala
similarity index 93%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala
index 5e24e8fc4e3cc..5b114242558dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala
@@ -42,10 +42,10 @@ import org.apache.spark.sql.SparkSession
* line 1: version
* line 2: metadata (optional json string)
*/
-class BatchCommitLog(sparkSession: SparkSession, path: String)
+class CommitLog(sparkSession: SparkSession, path: String)
extends HDFSMetadataLog[String](sparkSession, path) {
- import BatchCommitLog._
+ import CommitLog._
def add(batchId: Long): Unit = {
super.add(batchId, EMPTY_JSON)
@@ -53,7 +53,7 @@ class BatchCommitLog(sparkSession: SparkSession, path: String)
override def add(batchId: Long, metadata: String): Boolean = {
throw new UnsupportedOperationException(
- "BatchCommitLog does not take any metadata, use 'add(batchId)' instead")
+ "CommitLog does not take any metadata, use 'add(batchId)' instead")
}
override protected def deserialize(in: InputStream): String = {
@@ -76,7 +76,7 @@ class BatchCommitLog(sparkSession: SparkSession, path: String)
}
}
-object BatchCommitLog {
+object CommitLog {
private val VERSION = 1
private val EMPTY_JSON = "{}"
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
index 6bd0696622005..2715fa93d0e98 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
@@ -118,13 +118,14 @@ class FileStreamSink(
throw new RuntimeException(s"Partition column $col not found in schema ${data.schema}")
}
}
+ val qe = data.queryExecution
FileFormatWriter.write(
sparkSession = sparkSession,
- queryExecution = data.queryExecution,
+ plan = qe.executedPlan,
fileFormat = fileFormat,
committer = committer,
- outputSpec = FileFormatWriter.OutputSpec(path, Map.empty),
+ outputSpec = FileFormatWriter.OutputSpec(path, Map.empty, qe.analyzed.output),
hadoopConf = hadoopConf,
partitionColumns = partitionColumns,
bucketSpec = None,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceOffset.scala
index 06d0fe6c18c1e..a2b49d944a688 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceOffset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceOffset.scala
@@ -24,6 +24,7 @@ import org.json4s.jackson.Serialization
/**
* Offset for the [[FileStreamSource]].
+ *
* @param logOffset Position in the [[FileStreamSourceLog]]
*/
case class FileStreamSourceOffset(logOffset: Long) extends Offset {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index 29f38fab3f896..80769d728b8f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -23,8 +23,10 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Attribut
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution}
import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
+import org.apache.spark.sql.types.IntegerType
import org.apache.spark.util.CompletionIterator
/**
@@ -60,8 +62,27 @@ case class FlatMapGroupsWithStateExec(
import GroupStateImpl._
private val isTimeoutEnabled = timeoutConf != NoTimeout
- val stateManager = new FlatMapGroupsWithState_StateManager(stateEncoder, isTimeoutEnabled)
- val watermarkPresent = child.output.exists {
+ private val timestampTimeoutAttribute =
+ AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)()
+ private val stateAttributes: Seq[Attribute] = {
+ val encSchemaAttribs = stateEncoder.schema.toAttributes
+ if (isTimeoutEnabled) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs
+ }
+ // Get the serializer for the state, taking into account whether we need to save timestamps
+ private val stateSerializer = {
+ val encoderSerializer = stateEncoder.namedExpressions
+ if (isTimeoutEnabled) {
+ encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP)
+ } else {
+ encoderSerializer
+ }
+ }
+ // Get the deserializer for the state. Note that this must be done in the driver, as
+ // resolving and binding of deserializer expressions to the encoded type can be safely done
+ // only in the driver.
+ private val stateDeserializer = stateEncoder.resolveAndBind().deserializer
+
+ private val watermarkPresent = child.output.exists {
case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true
case _ => false
}
@@ -92,11 +113,11 @@ case class FlatMapGroupsWithStateExec(
child.execute().mapPartitionsWithStateStore[InternalRow](
getStateInfo,
groupingAttributes.toStructType,
- stateManager.stateSchema,
+ stateAttributes.toStructType,
indexOrdinal = None,
sqlContext.sessionState,
Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
- val processor = new InputProcessor(store)
+ val updater = new StateStoreUpdater(store)
// If timeout is based on event time, then filter late data based on watermark
val filteredIter = watermarkPredicateForData match {
@@ -111,7 +132,7 @@ case class FlatMapGroupsWithStateExec(
// all the data has been processed. This is to ensure that the timeout information of all
// the keys with data is updated before they are processed for timeouts.
val outputIterator =
- processor.processNewData(filteredIter) ++ processor.processTimedOutState()
+ updater.updateStateForKeysWithData(filteredIter) ++ updater.updateStateForTimedOutKeys()
// Return an iterator of all the rows generated by all the keys, such that when fully
// consumed, all the state updates will be committed by the state store
@@ -126,7 +147,7 @@ case class FlatMapGroupsWithStateExec(
}
/** Helper class to update the state store */
- class InputProcessor(store: StateStore) {
+ class StateStoreUpdater(store: StateStore) {
// Converters for translating input keys, values, output data between rows and Java objects
private val getKeyObj =
@@ -135,6 +156,14 @@ case class FlatMapGroupsWithStateExec(
ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
+ // Converters for translating state between rows and Java objects
+ private val getStateObjFromRow = ObjectOperator.deserializeRowToObject(
+ stateDeserializer, stateAttributes)
+ private val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer)
+
+ // Index of the additional metadata fields in the state row
+ private val timeoutTimestampIndex = stateAttributes.indexOf(timestampTimeoutAttribute)
+
// Metrics
private val numUpdatedStateRows = longMetric("numUpdatedStateRows")
private val numOutputRows = longMetric("numOutputRows")
@@ -143,19 +172,20 @@ case class FlatMapGroupsWithStateExec(
* For every group, get the key, values and corresponding state and call the function,
* and return an iterator of rows
*/
- def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = {
+ def updateStateForKeysWithData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = {
val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output)
groupedIter.flatMap { case (keyRow, valueRowIter) =>
val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow]
callFunctionAndUpdateState(
- stateManager.getState(store, keyUnsafeRow),
+ keyUnsafeRow,
valueRowIter,
+ store.get(keyUnsafeRow),
hasTimedOut = false)
}
}
/** Find the groups that have timeout set and are timing out right now, and call the function */
- def processTimedOutState(): Iterator[InternalRow] = {
+ def updateStateForTimedOutKeys(): Iterator[InternalRow] = {
if (isTimeoutEnabled) {
val timeoutThreshold = timeoutConf match {
case ProcessingTimeTimeout => batchTimestampMs.get
@@ -164,11 +194,12 @@ case class FlatMapGroupsWithStateExec(
throw new IllegalStateException(
s"Cannot filter timed out keys for $timeoutConf")
}
- val timingOutKeys = stateManager.getAllState(store).filter { state =>
- state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold
+ val timingOutKeys = store.getRange(None, None).filter { rowPair =>
+ val timeoutTimestamp = getTimeoutTimestamp(rowPair.value)
+ timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold
}
- timingOutKeys.flatMap { stateData =>
- callFunctionAndUpdateState(stateData, Iterator.empty, hasTimedOut = true)
+ timingOutKeys.flatMap { rowPair =>
+ callFunctionAndUpdateState(rowPair.key, Iterator.empty, rowPair.value, hasTimedOut = true)
}
} else Iterator.empty
}
@@ -178,19 +209,22 @@ case class FlatMapGroupsWithStateExec(
* iterator. Note that the store updating is lazy, that is, the store will be updated only
* after the returned iterator is fully consumed.
*
- * @param stateData All the data related to the state to be updated
+ * @param keyRow Row representing the key, cannot be null
* @param valueRowIter Iterator of values as rows, cannot be null, but can be empty
+ * @param prevStateRow Row representing the previous state, can be null
* @param hasTimedOut Whether this function is being called for a key timeout
*/
private def callFunctionAndUpdateState(
- stateData: FlatMapGroupsWithState_StateData,
+ keyRow: UnsafeRow,
valueRowIter: Iterator[InternalRow],
+ prevStateRow: UnsafeRow,
hasTimedOut: Boolean): Iterator[InternalRow] = {
- val keyObj = getKeyObj(stateData.keyRow) // convert key to objects
+ val keyObj = getKeyObj(keyRow) // convert key to objects
val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects
- val groupState = GroupStateImpl.createForStreaming(
- Option(stateData.stateObj),
+ val stateObj = getStateObj(prevStateRow)
+ val keyedState = GroupStateImpl.createForStreaming(
+ Option(stateObj),
batchTimestampMs.getOrElse(NO_TIMESTAMP),
eventTimeWatermark.getOrElse(NO_TIMESTAMP),
timeoutConf,
@@ -198,24 +232,50 @@ case class FlatMapGroupsWithStateExec(
watermarkPresent)
// Call function, get the returned objects and convert them to rows
- val mappedIterator = func(keyObj, valueObjIter, groupState).map { obj =>
+ val mappedIterator = func(keyObj, valueObjIter, keyedState).map { obj =>
numOutputRows += 1
getOutputRow(obj)
}
// When the iterator is consumed, then write changes to state
def onIteratorCompletion: Unit = {
- if (groupState.hasRemoved && groupState.getTimeoutTimestamp == NO_TIMESTAMP) {
- stateManager.removeState(store, stateData.keyRow)
+
+ val currentTimeoutTimestamp = keyedState.getTimeoutTimestamp
+ // If the state has not yet been set but timeout has been set, then
+ // we have to generate a row to save the timeout. However, attempting to serialize
+ // null using a case class encoder throws -
+ // java.lang.NullPointerException: Null value appeared in non-nullable field:
+ // If the schema is inferred from a Scala tuple / case class, or a Java bean, please
+ // try to use scala.Option[_] or other nullable types.
+ if (!keyedState.exists && currentTimeoutTimestamp != NO_TIMESTAMP) {
+ throw new IllegalStateException(
+ "Cannot set timeout when state is not defined, that is, state has not been" +
+ "initialized or has been removed")
+ }
+
+ if (keyedState.hasRemoved) {
+ store.remove(keyRow)
numUpdatedStateRows += 1
+
} else {
- val currentTimeoutTimestamp = groupState.getTimeoutTimestamp
- val hasTimeoutChanged = currentTimeoutTimestamp != stateData.timeoutTimestamp
- val shouldWriteState = groupState.hasUpdated || groupState.hasRemoved || hasTimeoutChanged
+ val previousTimeoutTimestamp = getTimeoutTimestamp(prevStateRow)
+ val stateRowToWrite = if (keyedState.hasUpdated) {
+ getStateRow(keyedState.get)
+ } else {
+ prevStateRow
+ }
+
+ val hasTimeoutChanged = currentTimeoutTimestamp != previousTimeoutTimestamp
+ val shouldWriteState = keyedState.hasUpdated || hasTimeoutChanged
if (shouldWriteState) {
- val updatedStateObj = if (groupState.exists) groupState.get else null
- stateManager.putState(store, stateData.keyRow, updatedStateObj, currentTimeoutTimestamp)
+ if (stateRowToWrite == null) {
+ // This should never happen because checks in GroupStateImpl should avoid cases
+ // where empty state would need to be written
+ throw new IllegalStateException("Attempting to write empty state")
+ }
+ setTimeoutTimestamp(stateRowToWrite, currentTimeoutTimestamp)
+ store.put(keyRow, stateRowToWrite)
numUpdatedStateRows += 1
}
}
@@ -224,5 +284,28 @@ case class FlatMapGroupsWithStateExec(
// Return an iterator of rows such that fully consumed, the updated state value will be saved
CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion)
}
+
+ /** Returns the state as a Java object if defined */
+ def getStateObj(stateRow: UnsafeRow): Any = {
+ if (stateRow != null) getStateObjFromRow(stateRow) else null
+ }
+
+ /** Returns the row for an updated state */
+ def getStateRow(obj: Any): UnsafeRow = {
+ assert(obj != null)
+ getStateRowFromObj(obj)
+ }
+
+ /** Returns the timeout timestamp of a state row, if it is set */
+ def getTimeoutTimestamp(stateRow: UnsafeRow): Long = {
+ if (isTimeoutEnabled && stateRow != null) {
+ stateRow.getLong(timeoutTimestampIndex)
+ } else NO_TIMESTAMP
+ }
+
+ /** Set the timestamp in a state row */
+ def setTimeoutTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = {
+ if (isTimeoutEnabled) stateRow.setLong(timeoutTimestampIndex, timeoutTimestamps)
+ }
}
}
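
The state handling above keeps the user's state fields in a row with one extra column appended for the timeout timestamp when timeouts are enabled; timeoutTimestampIndex points at that column, and a key times out when its stored timestamp is set and falls below the current threshold. A stripped-down model of that layout, using an Array[Any] in place of UnsafeRow:

object StateWithTimeoutSketch {
  val NO_TIMESTAMP: Long = -1L

  final class StateRow(numStateFields: Int, timeoutEnabled: Boolean) {
    private val values = new Array[Any](numStateFields + (if (timeoutEnabled) 1 else 0))
    private val timeoutIndex = numStateFields   // the appended metadata column

    def setField(i: Int, v: Any): Unit = values(i) = v

    def setTimeoutTimestamp(ts: Long): Unit =
      if (timeoutEnabled) values(timeoutIndex) = ts

    def timeoutTimestamp: Long =
      if (timeoutEnabled && values(timeoutIndex) != null) {
        values(timeoutIndex).asInstanceOf[Long]
      } else NO_TIMESTAMP
  }

  def main(args: Array[String]): Unit = {
    val row = new StateRow(numStateFields = 2, timeoutEnabled = true)
    row.setField(0, "some-state")
    row.setField(1, 42)
    row.setTimeoutTimestamp(1500L)

    val batchTimestampMs = 2000L   // the current threshold
    val timedOut = row.timeoutTimestamp != NO_TIMESTAMP && row.timeoutTimestamp < batchTimestampMs
    println(s"timedOut = $timedOut")   // true
  }
}
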
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala
index 43cf0ef1da8ca..6e8154d58d4c6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala
@@ -266,6 +266,20 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path:
}
}
+ /**
+ * Removes all log entries later than thresholdBatchId (exclusive).
+ */
+ def purgeAfter(thresholdBatchId: Long): Unit = {
+ val batchIds = fileManager.list(metadataPath, batchFilesFilter)
+ .map(f => pathToBatchId(f.getPath))
+
+ for (batchId <- batchIds if batchId > thresholdBatchId) {
+ val path = batchIdToPath(batchId)
+ fileManager.delete(path)
+ logTrace(s"Removed metadata log file: $path")
+ }
+ }
+
private def createFileManager(): FileManager = {
val hadoopConf = sparkSession.sessionState.newHadoopConf()
try {
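
purgeAfter complements the log's existing purge of older entries: it deletes every entry strictly newer than thresholdBatchId and keeps the threshold batch itself. The same semantics on a plain sorted map instead of HDFS files:

import scala.collection.mutable

object PurgeAfterSketch {
  def main(args: Array[String]): Unit = {
    val log = mutable.SortedMap(0L -> "batch-0", 1L -> "batch-1", 2L -> "batch-2", 3L -> "batch-3")
    val thresholdBatchId = 1L

    for (batchId <- log.keys.toSeq if batchId > thresholdBatchId) {
      log.remove(batchId)            // analogous to fileManager.delete(batchIdToPath(batchId))
    }
    println(log.keys.toList)         // List(0, 1)
  }
}
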
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala
index b84e6ce64c611..66b11ecddf233 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala
@@ -17,15 +17,11 @@
package org.apache.spark.sql.execution.streaming
-import java.{util => ju}
-
-import scala.collection.mutable
-
import com.codahale.metrics.{Gauge, MetricRegistry}
import org.apache.spark.internal.Logging
import org.apache.spark.metrics.source.{Source => CodahaleSource}
-import org.apache.spark.util.Clock
+import org.apache.spark.sql.streaming.StreamingQueryProgress
/**
* Serves metrics from a [[org.apache.spark.sql.streaming.StreamingQuery]] to
@@ -39,14 +35,17 @@ class MetricsReporter(
// Metric names should not have . in them, so that all the metrics of a query are identified
// together in Ganglia as a single metric group
- registerGauge("inputRate-total", () => stream.lastProgress.inputRowsPerSecond)
- registerGauge("processingRate-total", () => stream.lastProgress.processedRowsPerSecond)
- registerGauge("latency", () => stream.lastProgress.durationMs.get("triggerExecution").longValue())
-
- private def registerGauge[T](name: String, f: () => T)(implicit num: Numeric[T]): Unit = {
+ registerGauge("inputRate-total", _.inputRowsPerSecond, 0.0)
+ registerGauge("processingRate-total", _.processedRowsPerSecond, 0.0)
+ registerGauge("latency", _.durationMs.get("triggerExecution").longValue(), 0L)
+
+ private def registerGauge[T](
+ name: String,
+ f: StreamingQueryProgress => T,
+ default: T): Unit = {
synchronized {
metricRegistry.register(name, new Gauge[T] {
- override def getValue: T = f()
+ override def getValue: T = Option(stream.lastProgress).map(f).getOrElse(default)
})
}
}
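A self-contained sketch of the null-safe gauge pattern this change introduces, with StreamingQueryProgress reduced to a hypothetical one-field case class; only the fallback-to-default behavior is being illustrated.

    import com.codahale.metrics.{Gauge, MetricRegistry}

    // Simplified stand-in for StreamingQueryProgress.
    case class Progress(inputRowsPerSecond: Double)

    class DemoReporter(registry: MetricRegistry, lastProgress: () => Progress) {
      private def registerGauge[T](name: String, f: Progress => T, default: T): Unit = {
        registry.register(name, new Gauge[T] {
          // lastProgress() can be null before the first trigger completes;
          // fall back to the default instead of throwing.
          override def getValue: T = Option(lastProgress()).map(f).getOrElse(default)
        })
      }
      registerGauge("inputRate-total", _.inputRowsPerSecond, 0.0)
    }

    val registry = new MetricRegistry()
    new DemoReporter(registry, () => null)   // no progress yet: gauge reports the 0.0 default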
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
new file mode 100644
index 0000000000000..42240eeb58d4b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
@@ -0,0 +1,492 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.Optional
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.{ArrayBuffer, Map => MutableMap}
+
+import org.apache.spark.sql.{Dataset, SparkSession}
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.execution.SQLExecution
+import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2}
+import org.apache.spark.sql.sources.v2.DataSourceV2Options
+import org.apache.spark.sql.sources.v2.streaming.{MicroBatchReadSupport, MicroBatchWriteSupport}
+import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset => OffsetV2}
+import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger}
+import org.apache.spark.util.{Clock, Utils}
+
+class MicroBatchExecution(
+ sparkSession: SparkSession,
+ name: String,
+ checkpointRoot: String,
+ analyzedPlan: LogicalPlan,
+ sink: BaseStreamingSink,
+ trigger: Trigger,
+ triggerClock: Clock,
+ outputMode: OutputMode,
+ extraOptions: Map[String, String],
+ deleteCheckpointOnStop: Boolean)
+ extends StreamExecution(
+ sparkSession, name, checkpointRoot, analyzedPlan, sink,
+ trigger, triggerClock, outputMode, deleteCheckpointOnStop) {
+
+ @volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty
+
+ private val triggerExecutor = trigger match {
+ case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock)
+ case OneTimeTrigger => OneTimeExecutor()
+ case _ => throw new IllegalStateException(s"Unknown type of trigger: $trigger")
+ }
+
+ override lazy val logicalPlan: LogicalPlan = {
+ assert(queryExecutionThread eq Thread.currentThread,
+ "logicalPlan must be initialized in QueryExecutionThread " +
+ s"but the current thread was ${Thread.currentThread}")
+ var nextSourceId = 0L
+ val toExecutionRelationMap = MutableMap[StreamingRelation, StreamingExecutionRelation]()
+ val v2ToExecutionRelationMap = MutableMap[StreamingRelationV2, StreamingExecutionRelation]()
+ // We transform each distinct streaming relation into a StreamingExecutionRelation, keeping a
+ // map as we go to ensure each identical relation gets the same StreamingExecutionRelation
+ // object. For each microbatch, the StreamingExecutionRelation will be replaced with a logical
+ // plan for the data within that batch.
+ // Note that we have to use the previous `output` as attributes in StreamingExecutionRelation,
+ // since the existing logical plan has already used those attributes. The per-microbatch
+ // transformation is responsible for replacing attributes with their final values.
+ val _logicalPlan = analyzedPlan.transform {
+ case streamingRelation@StreamingRelation(dataSource, _, output) =>
+ toExecutionRelationMap.getOrElseUpdate(streamingRelation, {
+ // Materialize source to avoid creating it in every batch
+ val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId"
+ val source = dataSource.createSource(metadataPath)
+ nextSourceId += 1
+ StreamingExecutionRelation(source, output)(sparkSession)
+ })
+ case s @ StreamingRelationV2(source: MicroBatchReadSupport, _, options, output, _) =>
+ v2ToExecutionRelationMap.getOrElseUpdate(s, {
+ // Materialize source to avoid creating it in every batch
+ val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId"
+ val reader = source.createMicroBatchReader(
+ Optional.empty(), // user specified schema
+ metadataPath,
+ new DataSourceV2Options(options.asJava))
+ nextSourceId += 1
+ StreamingExecutionRelation(reader, output)(sparkSession)
+ })
+ case s @ StreamingRelationV2(_, _, _, output, v1Relation) =>
+ v2ToExecutionRelationMap.getOrElseUpdate(s, {
+ // Materialize source to avoid creating it in every batch
+ val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId"
+ assert(v1Relation.isDefined, "v2 execution didn't match but v1 was unavailable")
+ val source = v1Relation.get.dataSource.createSource(metadataPath)
+ nextSourceId += 1
+ StreamingExecutionRelation(source, output)(sparkSession)
+ })
+ }
+ sources = _logicalPlan.collect { case s: StreamingExecutionRelation => s.source }
+ uniqueSources = sources.distinct
+ _logicalPlan
+ }
+
+ /**
+ * Repeatedly attempts to run batches as data arrives.
+ */
+ protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = {
+ triggerExecutor.execute(() => {
+ startTrigger()
+
+ if (isActive) {
+ reportTimeTaken("triggerExecution") {
+ if (currentBatchId < 0) {
+ // We'll do this initialization only once
+ populateStartOffsets(sparkSessionForStream)
+ sparkSession.sparkContext.setJobDescription(getBatchDescriptionString)
+ logDebug(s"Stream running from $committedOffsets to $availableOffsets")
+ } else {
+ constructNextBatch()
+ }
+ if (dataAvailable) {
+ currentStatus = currentStatus.copy(isDataAvailable = true)
+ updateStatusMessage("Processing new data")
+ runBatch(sparkSessionForStream)
+ }
+ }
+ // Report trigger as finished and construct progress object.
+ finishTrigger(dataAvailable)
+ if (dataAvailable) {
+ // Update committed offsets.
+ commitLog.add(currentBatchId)
+ committedOffsets ++= availableOffsets
+ logDebug(s"batch ${currentBatchId} committed")
+ // We'll increase currentBatchId after we complete processing current batch's data
+ currentBatchId += 1
+ sparkSession.sparkContext.setJobDescription(getBatchDescriptionString)
+ } else {
+ currentStatus = currentStatus.copy(isDataAvailable = false)
+ updateStatusMessage("Waiting for data to arrive")
+ Thread.sleep(pollingDelayMs)
+ }
+ }
+ updateStatusMessage("Waiting for next trigger")
+ isActive
+ })
+ }
+
+ /**
+ * Populate the start offsets to start the execution at the current offsets stored in the sink
+ * (i.e. avoid reprocessing data that we have already processed). This function must be called
+ * before any processing occurs and will populate the following fields:
+ * - currentBatchId
+ * - committedOffsets
+ * - availableOffsets
+ * The basic structure of this method is as follows:
+ *
+ * Identify (from the offset log) the offsets used to run the last batch
+ * IF last batch exists THEN
+ * Set the next batch to be executed as the last recovered batch
+ * Check the commit log to see which batch was committed last
+ * IF the last batch was committed THEN
+ * Call getBatch using the last batch start and end offsets
+ * // ^^^^ above line is needed since some sources assume last batch always re-executes
+ * Setup for a new batch i.e., start = last batch end, and identify new end
+ * DONE
+ * ELSE
+ * Identify a brand new batch
+ * DONE
+ */
+ private def populateStartOffsets(sparkSessionToRunBatches: SparkSession): Unit = {
+ offsetLog.getLatest() match {
+ case Some((latestBatchId, nextOffsets)) =>
+ /* First assume that we are re-executing the latest known batch
+ * in the offset log */
+ currentBatchId = latestBatchId
+ availableOffsets = nextOffsets.toStreamProgress(sources)
+ /* Initialize committed offsets to a committed batch, which at this point
+ * is the second latest batch id in the offset log. */
+ if (latestBatchId != 0) {
+ val secondLatestBatchId = offsetLog.get(latestBatchId - 1).getOrElse {
+ throw new IllegalStateException(s"batch ${latestBatchId - 1} doesn't exist")
+ }
+ committedOffsets = secondLatestBatchId.toStreamProgress(sources)
+ }
+
+ // update offset metadata
+ nextOffsets.metadata.foreach { metadata =>
+ OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.conf)
+ offsetSeqMetadata = OffsetSeqMetadata(
+ metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf)
+ }
+
+ /* identify the current batch id: if commit log indicates we successfully processed the
+ * latest batch id in the offset log, then we can safely move to the next batch
+ * i.e., committedBatchId + 1 */
+ commitLog.getLatest() match {
+ case Some((latestCommittedBatchId, _)) =>
+ if (latestBatchId == latestCommittedBatchId) {
+ /* The last batch was successfully committed, so we can safely process
+ * the next batch. But first, make a call to getBatch using the offsets
+ * from the previous batch, because certain sources (e.g., KafkaSource)
+ * assume that on restart the last batch will be executed before
+ * getOffset is called again. */
+ availableOffsets.foreach {
+ case (source: Source, end: Offset) =>
+ if (committedOffsets.get(source).map(_ != end).getOrElse(true)) {
+ val start = committedOffsets.get(source)
+ source.getBatch(start, end)
+ }
+ case nonV1Tuple =>
+ // The V2 API does not have the same edge case requiring getBatch to be
+ // called here, so nothing needs to be done.
+ }
+ currentBatchId = latestCommittedBatchId + 1
+ committedOffsets ++= availableOffsets
+ // Construct a new batch by recomputing availableOffsets
+ constructNextBatch()
+ } else if (latestCommittedBatchId < latestBatchId - 1) {
+ logWarning(s"Batch completion log latest batch id is " +
+ s"${latestCommittedBatchId}, which is not trailing " +
+ s"batchid $latestBatchId by one")
+ }
+ case None => logInfo("no commit log present")
+ }
+ logDebug(s"Resuming at batch $currentBatchId with committed offsets " +
+ s"$committedOffsets and available offsets $availableOffsets")
+ case None => // We are starting this stream for the first time.
+ logInfo(s"Starting new streaming query.")
+ currentBatchId = 0
+ constructNextBatch()
+ }
+ }
+
+ /**
+ * Returns true if there is any new data available to be processed.
+ */
+ private def dataAvailable: Boolean = {
+ availableOffsets.exists {
+ case (source, available) =>
+ committedOffsets
+ .get(source)
+ .map(committed => committed != available)
+ .getOrElse(true)
+ }
+ }
+
+ /**
+ * Queries all of the sources to see if any new data is available. When there is new data the
+ * batchId counter is incremented and a new log entry is written with the newest offsets.
+ */
+ private def constructNextBatch(): Unit = {
+ // Check to see what new data is available.
+ val hasNewData = {
+ awaitProgressLock.lock()
+ try {
+ // Generate a map from each unique source to the next available offset.
+ val latestOffsets: Map[BaseStreamingSource, Option[Offset]] = uniqueSources.map {
+ case s: Source =>
+ updateStatusMessage(s"Getting offsets from $s")
+ reportTimeTaken("getOffset") {
+ (s, s.getOffset)
+ }
+ case s: MicroBatchReader =>
+ updateStatusMessage(s"Getting offsets from $s")
+ reportTimeTaken("getOffset") {
+ // Once v1 streaming source execution is gone, we can refactor this away.
+ // For now, we set the range here to get the source to infer the available end offset,
+ // get that offset, and then set the range again when we later execute.
+ s.setOffsetRange(
+ toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))),
+ Optional.empty())
+
+ (s, Some(s.getEndOffset))
+ }
+ }.toMap
+ availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get)
+
+ if (dataAvailable) {
+ true
+ } else {
+ noNewData = true
+ false
+ }
+ } finally {
+ awaitProgressLock.unlock()
+ }
+ }
+ if (hasNewData) {
+ var batchWatermarkMs = offsetSeqMetadata.batchWatermarkMs
+ // Update the eventTime watermarks if we find any in the plan.
+ if (lastExecution != null) {
+ lastExecution.executedPlan.collect {
+ case e: EventTimeWatermarkExec => e
+ }.zipWithIndex.foreach {
+ case (e, index) if e.eventTimeStats.value.count > 0 =>
+ logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}")
+ val newWatermarkMs = e.eventTimeStats.value.max - e.delayMs
+ val prevWatermarkMs = watermarkMsMap.get(index)
+ if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) {
+ watermarkMsMap.put(index, newWatermarkMs)
+ }
+
+ // Populate 0 if we haven't seen any data yet for this watermark node.
+ case (_, index) =>
+ if (!watermarkMsMap.isDefinedAt(index)) {
+ watermarkMsMap.put(index, 0)
+ }
+ }
+
+ // Update the global watermark to the minimum of all watermark nodes.
+ // This is the safest option, because only the global watermark is fault-tolerant. Making
+ // it the minimum of all individual watermarks guarantees it will never advance past where
+ // any individual watermark operator would be if it were in a plan by itself.
+ if (watermarkMsMap.nonEmpty) {
+ val newWatermarkMs = watermarkMsMap.minBy(_._2)._2
+ if (newWatermarkMs > batchWatermarkMs) {
+ logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms")
+ batchWatermarkMs = newWatermarkMs
+ } else {
+ logDebug(
+ s"Event time didn't move: $newWatermarkMs < " +
+ s"$batchWatermarkMs")
+ }
+ }
+ }
+ offsetSeqMetadata = offsetSeqMetadata.copy(
+ batchWatermarkMs = batchWatermarkMs,
+ batchTimestampMs = triggerClock.getTimeMillis()) // Current batch timestamp in milliseconds
+
+ updateStatusMessage("Writing offsets to log")
+ reportTimeTaken("walCommit") {
+ assert(offsetLog.add(
+ currentBatchId,
+ availableOffsets.toOffsetSeq(sources, offsetSeqMetadata)),
+ s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId")
+ logInfo(s"Committed offsets for batch $currentBatchId. " +
+ s"Metadata ${offsetSeqMetadata.toString}")
+
+ // NOTE: The following code is correct because runStream() processes exactly one
+ // batch at a time. If we add pipeline parallelism (multiple batches in flight at
+ // the same time), this cleanup logic will need to change.
+
+ // Now that we've updated the scheduler's persistent checkpoint, it is safe for the
+ // sources to discard data from the previous batch.
+ if (currentBatchId != 0) {
+ val prevBatchOff = offsetLog.get(currentBatchId - 1)
+ if (prevBatchOff.isDefined) {
+ prevBatchOff.get.toStreamProgress(sources).foreach {
+ case (src: Source, off) => src.commit(off)
+ case (reader: MicroBatchReader, off) =>
+ reader.commit(reader.deserializeOffset(off.json))
+ }
+ } else {
+ throw new IllegalStateException(s"batch $currentBatchId doesn't exist")
+ }
+ }
+
+ // It is now safe to discard the metadata beyond the minimum number to retain.
+ // Note that purge is exclusive, i.e. it purges everything before the target ID.
+ if (minLogEntriesToMaintain < currentBatchId) {
+ offsetLog.purge(currentBatchId - minLogEntriesToMaintain)
+ commitLog.purge(currentBatchId - minLogEntriesToMaintain)
+ }
+ }
+ } else {
+ awaitProgressLock.lock()
+ try {
+ // Wake up any threads that are waiting for the stream to progress.
+ awaitProgressLockCondition.signalAll()
+ } finally {
+ awaitProgressLock.unlock()
+ }
+ }
+ }
+
+ /**
+ * Processes any data available between `availableOffsets` and `committedOffsets`.
+ * @param sparkSessionToRunBatch Isolated [[SparkSession]] to run this batch with.
+ */
+ private def runBatch(sparkSessionToRunBatch: SparkSession): Unit = {
+ // Request unprocessed data from all sources.
+ newData = reportTimeTaken("getBatch") {
+ availableOffsets.flatMap {
+ case (source: Source, available)
+ if committedOffsets.get(source).map(_ != available).getOrElse(true) =>
+ val current = committedOffsets.get(source)
+ val batch = source.getBatch(current, available)
+ assert(batch.isStreaming,
+ s"DataFrame returned by getBatch from $source did not have isStreaming=true\n" +
+ s"${batch.queryExecution.logical}")
+ logDebug(s"Retrieving data from $source: $current -> $available")
+ Some(source -> batch.logicalPlan)
+ case (reader: MicroBatchReader, available)
+ if committedOffsets.get(reader).map(_ != available).getOrElse(true) =>
+ val current = committedOffsets.get(reader).map(off => reader.deserializeOffset(off.json))
+ reader.setOffsetRange(
+ toJava(current),
+ Optional.of(available.asInstanceOf[OffsetV2]))
+ logDebug(s"Retrieving data from $reader: $current -> $available")
+ Some(reader ->
+ new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader))
+ case _ => None
+ }
+ }
+
+ // A list of attributes that will need to be updated.
+ val replacements = new ArrayBuffer[(Attribute, Attribute)]
+ // Replace sources in the logical plan with data that has arrived since the last batch.
+ val newBatchesPlan = logicalPlan transform {
+ case StreamingExecutionRelation(source, output) =>
+ newData.get(source).map { dataPlan =>
+ assert(output.size == dataPlan.output.size,
+ s"Invalid batch: ${Utils.truncatedString(output, ",")} != " +
+ s"${Utils.truncatedString(dataPlan.output, ",")}")
+ replacements ++= output.zip(dataPlan.output)
+ dataPlan
+ }.getOrElse {
+ LocalRelation(output, isStreaming = true)
+ }
+ }
+
+ // Rewire the plan to use the new attributes that were returned by the source.
+ val replacementMap = AttributeMap(replacements)
+ val newAttributePlan = newBatchesPlan transformAllExpressions {
+ case a: Attribute if replacementMap.contains(a) =>
+ replacementMap(a).withMetadata(a.metadata)
+ case ct: CurrentTimestamp =>
+ CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs,
+ ct.dataType)
+ case cd: CurrentDate =>
+ CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs,
+ cd.dataType, cd.timeZoneId)
+ }
+
+ val triggerLogicalPlan = sink match {
+ case _: Sink => newAttributePlan
+ case s: MicroBatchWriteSupport =>
+ val writer = s.createMicroBatchWriter(
+ s"$runId",
+ currentBatchId,
+ newAttributePlan.schema,
+ outputMode,
+ new DataSourceV2Options(extraOptions.asJava))
+ assert(writer.isPresent, "microbatch writer must always be present")
+ WriteToDataSourceV2(writer.get, newAttributePlan)
+ case _ => throw new IllegalArgumentException(s"unknown sink type for $sink")
+ }
+
+ reportTimeTaken("queryPlanning") {
+ lastExecution = new IncrementalExecution(
+ sparkSessionToRunBatch,
+ triggerLogicalPlan,
+ outputMode,
+ checkpointFile("state"),
+ runId,
+ currentBatchId,
+ offsetSeqMetadata)
+ lastExecution.executedPlan // Force the lazy generation of execution plan
+ }
+
+ val nextBatch =
+ new Dataset(sparkSessionToRunBatch, lastExecution, RowEncoder(lastExecution.analyzed.schema))
+
+ reportTimeTaken("addBatch") {
+ SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) {
+ sink match {
+ case s: Sink => s.addBatch(currentBatchId, nextBatch)
+ case s: MicroBatchWriteSupport =>
+ // This doesn't accumulate any data - it just forces execution of the microbatch writer.
+ nextBatch.collect()
+ }
+ }
+ }
+
+ awaitProgressLock.lock()
+ try {
+ // Wake up any threads that are waiting for the stream to progress.
+ awaitProgressLockCondition.signalAll()
+ } finally {
+ awaitProgressLock.unlock()
+ }
+ }
+
+ private def toJava(scalaOption: Option[OffsetV2]): Optional[OffsetV2] = {
+ Optional.ofNullable(scalaOption.orNull)
+ }
+}
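The interaction between the offset log and the commit log during recovery can be summarized with a small, purely illustrative sketch; it restates the rules documented in populateStartOffsets and is not code from the patch.

    // Given the latest entries of the two logs, decide which batch to run next.
    def nextBatchToRun(offsetLogLatest: Option[Long], commitLogLatest: Option[Long]): Long =
      (offsetLogLatest, commitLogLatest) match {
        case (None, _)                    => 0L      // fresh query: start at batch 0
        case (Some(o), Some(c)) if o == c => o + 1   // last batch fully committed: move on
        case (Some(o), _)                 => o       // offsets written but not committed: re-run it
      }

    assert(nextBatchToRun(None, None) == 0L)
    assert(nextBatchToRun(Some(5L), Some(5L)) == 6L)
    assert(nextBatchToRun(Some(5L), Some(4L)) == 5L)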
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java
new file mode 100644
index 0000000000000..80aa5505db991
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming;
+
+/**
+ * This is an internal, deprecated interface. New source implementations should use the
+ * org.apache.spark.sql.sources.v2.reader.Offset class, which is the one that will be supported
+ * in the long term.
+ *
+ * This class will be removed in a future release.
+ */
+public abstract class Offset {
+ /**
+ * A JSON-serialized representation of an Offset that is
+ * used for saving offsets to the offset log.
+ * Note: We assume that equivalent/equal offsets serialize to
+ * identical JSON strings.
+ *
+ * @return JSON string encoding
+ */
+ public abstract String json();
+
+ /**
+ * Equality based on JSON string representation. We leverage the
+ * JSON representation for normalization between an Offset's
+ * in-memory and on-disk representations.
+ */
+ @Override
+ public boolean equals(Object obj) {
+ if (obj instanceof Offset) {
+ return this.json().equals(((Offset) obj).json());
+ } else {
+ return false;
+ }
+ }
+
+ @Override
+ public int hashCode() {
+ return this.json().hashCode();
+ }
+
+ @Override
+ public String toString() {
+ return this.json();
+ }
+}
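A hedged sketch of how a concrete offset would satisfy the json()-based equality contract above; the subclass and its encoding are hypothetical.

    import org.apache.spark.sql.execution.streaming.Offset

    // Hypothetical offset backed by a single long, serialized as its decimal string.
    class DemoLongOffset(value: Long) extends Offset {
      override def json(): String = value.toString
    }

    // equals/hashCode come from the abstract class and compare only the JSON form,
    // so two independently constructed instances with the same value are equal.
    assert(new DemoLongOffset(7) == new DemoLongOffset(7))
    assert(new DemoLongOffset(7).hashCode == "7".hashCode)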
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
index 4e0a468b962a2..a1b63a6de3823 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
@@ -38,7 +38,7 @@ case class OffsetSeq(offsets: Seq[Option[Offset]], metadata: Option[OffsetSeqMet
* This method is typically used to associate a serialized offset with actual sources (which
* cannot be serialized).
*/
- def toStreamProgress(sources: Seq[Source]): StreamProgress = {
+ def toStreamProgress(sources: Seq[BaseStreamingSource]): StreamProgress = {
assert(sources.size == offsets.size)
new StreamProgress ++ sources.zip(offsets).collect { case (s, Some(o)) => (s, o) }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
index b1c3a8ab235ab..d1e5be9c12762 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
@@ -42,7 +42,7 @@ import org.apache.spark.util.Clock
trait ProgressReporter extends Logging {
case class ExecutionStats(
- inputRows: Map[Source, Long],
+ inputRows: Map[BaseStreamingSource, Long],
stateOperators: Seq[StateOperatorProgress],
eventTimeStats: Map[String, String])
@@ -53,11 +53,11 @@ trait ProgressReporter extends Logging {
protected def triggerClock: Clock
protected def logicalPlan: LogicalPlan
protected def lastExecution: QueryExecution
- protected def newData: Map[Source, DataFrame]
+ protected def newData: Map[BaseStreamingSource, LogicalPlan]
protected def availableOffsets: StreamProgress
protected def committedOffsets: StreamProgress
- protected def sources: Seq[Source]
- protected def sink: Sink
+ protected def sources: Seq[BaseStreamingSource]
+ protected def sink: BaseStreamingSink
protected def offsetSeqMetadata: OffsetSeqMetadata
protected def currentBatchId: Long
protected def sparkSession: SparkSession
@@ -225,12 +225,12 @@ trait ProgressReporter extends Logging {
//
// 3. For each source, we sum the metrics of the associated execution plan leaves.
//
- val logicalPlanLeafToSource = newData.flatMap { case (source, df) =>
- df.logicalPlan.collectLeaves().map { leaf => leaf -> source }
+ val logicalPlanLeafToSource = newData.flatMap { case (source, logicalPlan) =>
+ logicalPlan.collectLeaves().map { leaf => leaf -> source }
}
val allLogicalPlanLeaves = lastExecution.logical.collectLeaves() // includes non-streaming
val allExecPlanLeaves = lastExecution.executedPlan.collectLeaves()
- val numInputRows: Map[Source, Long] =
+ val numInputRows: Map[BaseStreamingSource, Long] =
if (allLogicalPlanLeaves.size == allExecPlanLeaves.size) {
val execLeafToSource = allLogicalPlanLeaves.zip(allExecPlanLeaves).flatMap {
case (lp, ep) => logicalPlanLeafToSource.get(lp).map { source => ep -> source }
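The leaf-matching heuristic used to compute per-source input rows can be illustrated with plan nodes reduced to plain strings; all of these names are hypothetical.

    val logicalLeafToSource = Map("scanA" -> "sourceA", "scanB" -> "sourceB")
    val allLogicalLeaves = Seq("scanA", "scanB", "localTable")   // includes non-streaming leaves
    val allExecLeaves    = Seq("execA", "execB", "execLocal")

    // Only when the two leaf lists line up one-to-one can exec leaves be attributed to
    // sources; otherwise per-source row counts are not reported.
    val execLeafToSource: Map[String, String] =
      if (allLogicalLeaves.size == allExecLeaves.size) {
        allLogicalLeaves.zip(allExecLeaves).flatMap {
          case (lp, ep) => logicalLeafToSource.get(lp).map(source => ep -> source)
        }.toMap
      } else Map.empty
    // execLeafToSource == Map("execA" -> "sourceA", "execB" -> "sourceB")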
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala
index 077a4778e34a8..66eb0169ac1ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming
import java.io._
import java.nio.charset.StandardCharsets
+import java.util.Optional
import java.util.concurrent.TimeUnit
import org.apache.commons.io.IOUtils
@@ -28,7 +29,12 @@ import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
+import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader
+import org.apache.spark.sql.execution.streaming.sources.RateStreamMicroBatchReader
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
+import org.apache.spark.sql.sources.v2._
+import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport}
+import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, MicroBatchReader}
import org.apache.spark.sql.types._
import org.apache.spark.util.{ManualClock, SystemClock}
@@ -46,7 +52,8 @@ import org.apache.spark.util.{ManualClock, SystemClock}
* generated rows. The source will try its best to reach `rowsPerSecond`, but the query may
* be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed.
*/
-class RateSourceProvider extends StreamSourceProvider with DataSourceRegister {
+class RateSourceProvider extends StreamSourceProvider with DataSourceRegister
+ with DataSourceV2 with ContinuousReadSupport {
override def sourceSchema(
sqlContext: SQLContext,
@@ -100,6 +107,14 @@ class RateSourceProvider extends StreamSourceProvider with DataSourceRegister {
params.get("useManualClock").map(_.toBoolean).getOrElse(false) // Only for testing
)
}
+
+ override def createContinuousReader(
+ schema: Optional[StructType],
+ checkpointLocation: String,
+ options: DataSourceV2Options): ContinuousReader = {
+ new RateStreamContinuousReader(options)
+ }
+
override def shortName(): String = "rate"
}
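For context, a usage sketch of the rate source this provider registers; the options shown are the documented ones, and which reader gets used depends on the trigger mode.

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[2]").appName("rate-demo").getOrCreate()

    // With a micro-batch trigger the StreamSourceProvider path is used; with a continuous
    // trigger the createContinuousReader method added above would be used instead.
    val rates = spark.readStream
      .format("rate")
      .option("rowsPerSecond", "10")
      .option("numPartitions", "2")
      .load()   // schema: timestamp TIMESTAMP, value LONG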
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala
new file mode 100644
index 0000000000000..261d69bbd9843
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.json4s.DefaultFormats
+import org.json4s.jackson.Serialization
+
+import org.apache.spark.sql.sources.v2
+
+case class RateStreamOffset(partitionToValueAndRunTimeMs: Map[Int, ValueRunTimeMsPair])
+ extends v2.streaming.reader.Offset {
+ implicit val defaultFormats: DefaultFormats = DefaultFormats
+ override val json = Serialization.write(partitionToValueAndRunTimeMs)
+}
+
+
+case class ValueRunTimeMsPair(value: Long, runTimeMs: Long)
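A rough illustration of the JSON this offset serializes to via json4s; the exact key ordering and formatting are an assumption.

    val offset = RateStreamOffset(Map(0 -> ValueRunTimeMsPair(value = 10L, runTimeMs = 1000L)))
    // offset.json is roughly: {"0":{"value":10,"runTimeMs":1000}}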
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SerializedOffset.scala
similarity index 55%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SerializedOffset.scala
index 4efcee0f8f9d6..129cfed860eb6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SerializedOffset.scala
@@ -17,39 +17,6 @@
package org.apache.spark.sql.execution.streaming
-/**
- * An offset is a monotonically increasing metric used to track progress in the computation of a
- * stream. Since offsets are retrieved from a [[Source]] by a single thread, we know the global
- * ordering of two [[Offset]] instances. We do assume that if two offsets are `equal` then no
- * new data has arrived.
- */
-abstract class Offset {
-
- /**
- * Equality based on JSON string representation. We leverage the
- * JSON representation for normalization between the Offset's
- * in memory and on disk representations.
- */
- override def equals(obj: Any): Boolean = obj match {
- case o: Offset => this.json == o.json
- case _ => false
- }
-
- override def hashCode(): Int = this.json.hashCode
-
- override def toString(): String = this.json.toString
-
- /**
- * A JSON-serialized representation of an Offset that is
- * used for saving offsets to the offset log.
- * Note: We assume that equivalent/equal offsets serialize to
- * identical JSON strings.
- *
- * @return JSON string encoding
- */
- def json: String
-}
-
/**
* Used when loading a JSON serialized offset from external storage.
* We are currently not responsible for converting JSON serialized
@@ -58,3 +25,5 @@ abstract class Offset {
* that accepts a [[SerializedOffset]] for doing the conversion.
*/
case class SerializedOffset(override val json: String) extends Offset
+
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala
index d10cd3044ecdf..34bc085d920c1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.DataFrame
* exactly once semantics a sink must be idempotent in the face of multiple attempts to add the same
* batch.
*/
-trait Sink {
+trait Sink extends BaseStreamingSink {
/**
* Adds a batch of data to this sink. The data for a given `batchId` is deterministic and if
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala
index 311942f6dbd84..dbbd59e06909c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.types.StructType
* monotonically increasing notion of progress that can be represented as an [[Offset]]. Spark
* will regularly query each [[Source]] to see if any more data is available.
*/
-trait Source {
+trait Source extends BaseStreamingSource {
/** Returns the schema of the data from this source */
def schema: StructType
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 406560c260f07..24a8b000df0c1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -22,10 +22,9 @@ import java.nio.channels.ClosedByInterruptException
import java.util.UUID
import java.util.concurrent.{CountDownLatch, ExecutionException, TimeUnit}
import java.util.concurrent.atomic.AtomicReference
-import java.util.concurrent.locks.ReentrantLock
+import java.util.concurrent.locks.{Condition, ReentrantLock}
import scala.collection.mutable.{Map => MutableMap}
-import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal
import com.google.common.util.concurrent.UncheckedExecutionException
@@ -33,10 +32,8 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
-import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.command.StreamingExplainCommand
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming._
@@ -47,6 +44,7 @@ trait State
case object INITIALIZING extends State
case object ACTIVE extends State
case object TERMINATED extends State
+case object RECONFIGURING extends State
/**
* Manages the execution of a streaming Spark SQL query that is occurring in a separate thread.
@@ -57,12 +55,12 @@ case object TERMINATED extends State
* @param deleteCheckpointOnStop whether to delete the checkpoint if the query is stopped without
* errors
*/
-class StreamExecution(
+abstract class StreamExecution(
override val sparkSession: SparkSession,
override val name: String,
private val checkpointRoot: String,
analyzedPlan: LogicalPlan,
- val sink: Sink,
+ val sink: BaseStreamingSink,
val trigger: Trigger,
val triggerClock: Clock,
val outputMode: OutputMode,
@@ -71,16 +69,16 @@ class StreamExecution(
import org.apache.spark.sql.streaming.StreamingQueryListener._
- private val pollingDelayMs = sparkSession.sessionState.conf.streamingPollingDelay
+ protected val pollingDelayMs: Long = sparkSession.sessionState.conf.streamingPollingDelay
- private val minBatchesToRetain = sparkSession.sessionState.conf.minBatchesToRetain
- require(minBatchesToRetain > 0, "minBatchesToRetain has to be positive")
+ protected val minLogEntriesToMaintain: Int = sparkSession.sessionState.conf.minBatchesToRetain
+ require(minLogEntriesToMaintain > 0, "minBatchesToRetain has to be positive")
/**
* A lock used to wait/notify when batches complete. Use a fair lock to avoid thread starvation.
*/
- private val awaitBatchLock = new ReentrantLock(true)
- private val awaitBatchLockCondition = awaitBatchLock.newCondition()
+ protected val awaitProgressLock = new ReentrantLock(true)
+ protected val awaitProgressLockCondition = awaitProgressLock.newCondition()
private val initializationLatch = new CountDownLatch(1)
private val startLatch = new CountDownLatch(1)
@@ -89,9 +87,11 @@ class StreamExecution(
val resolvedCheckpointRoot = {
val checkpointPath = new Path(checkpointRoot)
val fs = checkpointPath.getFileSystem(sparkSession.sessionState.newHadoopConf())
- checkpointPath.makeQualified(fs.getUri(), fs.getWorkingDirectory()).toUri.toString
+ checkpointPath.makeQualified(fs.getUri, fs.getWorkingDirectory).toUri.toString
}
+ def logicalPlan: LogicalPlan
+
/**
* Tracks how much data we have processed and committed to the sink or state store from each
* input source.
@@ -148,59 +148,25 @@ class StreamExecution(
* Pretty identifier string for printing in logs. Format is
* If name is set "queryName [id = xyz, runId = abc]" else "[id = xyz, runId = abc]"
*/
- private val prettyIdString =
+ protected val prettyIdString =
Option(name).map(_ + " ").getOrElse("") + s"[id = $id, runId = $runId]"
- /**
- * All stream sources present in the query plan. This will be set when generating logical plan.
- */
- @volatile protected var sources: Seq[Source] = Seq.empty
-
/**
* A list of unique sources in the query plan. This will be set when generating logical plan.
*/
- @volatile private var uniqueSources: Seq[Source] = Seq.empty
-
- override lazy val logicalPlan: LogicalPlan = {
- assert(microBatchThread eq Thread.currentThread,
- "logicalPlan must be initialized in StreamExecutionThread " +
- s"but the current thread was ${Thread.currentThread}")
- var nextSourceId = 0L
- val toExecutionRelationMap = MutableMap[StreamingRelation, StreamingExecutionRelation]()
- val _logicalPlan = analyzedPlan.transform {
- case streamingRelation@StreamingRelation(dataSource, _, output) =>
- toExecutionRelationMap.getOrElseUpdate(streamingRelation, {
- // Materialize source to avoid creating it in every batch
- val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId"
- val source = dataSource.createSource(metadataPath)
- nextSourceId += 1
- // We still need to use the previous `output` instead of `source.schema` as attributes in
- // "df.logicalPlan" has already used attributes of the previous `output`.
- StreamingExecutionRelation(source, output)(sparkSession)
- })
- }
- sources = _logicalPlan.collect { case s: StreamingExecutionRelation => s.source }
- uniqueSources = sources.distinct
- _logicalPlan
- }
-
- private val triggerExecutor = trigger match {
- case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock)
- case OneTimeTrigger => OneTimeExecutor()
- case _ => throw new IllegalStateException(s"Unknown type of trigger: $trigger")
- }
+ @volatile protected var uniqueSources: Seq[BaseStreamingSource] = Seq.empty
/** Defines the internal state of execution */
- private val state = new AtomicReference[State](INITIALIZING)
+ protected val state = new AtomicReference[State](INITIALIZING)
@volatile
var lastExecution: IncrementalExecution = _
/** Holds the most recent input data for each source. */
- protected var newData: Map[Source, DataFrame] = _
+ protected var newData: Map[BaseStreamingSource, LogicalPlan] = _
@volatile
- private var streamDeathCause: StreamingQueryException = null
+ protected var streamDeathCause: StreamingQueryException = null
/* Get the call site in the caller thread; will pass this into the micro batch thread */
private val callSite = Utils.getCallSite()
@@ -214,13 +180,13 @@ class StreamExecution(
* [[org.apache.spark.util.UninterruptibleThread]] to workaround KAFKA-1894: interrupting a
* running `KafkaConsumer` may cause endless loop.
*/
- val microBatchThread =
- new StreamExecutionThread(s"stream execution thread for $prettyIdString") {
+ val queryExecutionThread: QueryExecutionThread =
+ new QueryExecutionThread(s"stream execution thread for $prettyIdString") {
override def run(): Unit = {
// To fix call site like "run at <unknown>:0", we bridge the call site from the caller
// thread to this micro batch thread
sparkSession.sparkContext.setCallSite(callSite)
- runBatches()
+ runStream()
}
}
@@ -237,7 +203,7 @@ class StreamExecution(
* fully processed, and its output was committed to the sink, hence no need to process it again.
* This is used (for instance) during restart, to help identify which batch to run next.
*/
- val batchCommitLog = new BatchCommitLog(sparkSession, checkpointFile("commits"))
+ val commitLog = new CommitLog(sparkSession, checkpointFile("commits"))
/** Whether all fields of the query have been initialized */
private def isInitialized: Boolean = state.get != INITIALIZING
@@ -249,7 +215,7 @@ class StreamExecution(
override def exception: Option[StreamingQueryException] = Option(streamDeathCause)
/** Returns the path of a file with `name` in the checkpoint directory. */
- private def checkpointFile(name: String): String =
+ protected def checkpointFile(name: String): String =
new Path(new Path(resolvedCheckpointRoot), name).toUri.toString
/**
@@ -258,20 +224,25 @@ class StreamExecution(
*/
def start(): Unit = {
logInfo(s"Starting $prettyIdString. Use $resolvedCheckpointRoot to store the query checkpoint.")
- microBatchThread.setDaemon(true)
- microBatchThread.start()
+ queryExecutionThread.setDaemon(true)
+ queryExecutionThread.start()
startLatch.await() // Wait until thread started and QueryStart event has been posted
}
/**
- * Repeatedly attempts to run batches as data arrives.
+ * Run the activated stream until stopped.
+ */
+ protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit
+
+ /**
+ * Activate the stream and then wrap a callout to runActivatedStream, handling start and stop.
*
* Note that this method ensures that [[QueryStartedEvent]] and [[QueryTerminatedEvent]] are
* posted such that listeners are guaranteed to get a start event before a termination.
* Furthermore, this method also ensures that [[QueryStartedEvent]] event is posted before the
* `start()` method returns.
*/
- private def runBatches(): Unit = {
+ private def runStream(): Unit = {
try {
sparkSession.sparkContext.setJobGroup(runId.toString, getBatchDescriptionString,
interruptOnCancel = true)
@@ -294,56 +265,18 @@ class StreamExecution(
logicalPlan
// Isolated spark session to run the batches with.
- val sparkSessionToRunBatches = sparkSession.cloneSession()
+ val sparkSessionForStream = sparkSession.cloneSession()
// Adaptive execution can change num shuffle partitions, disallow
- sparkSessionToRunBatches.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "false")
+ sparkSessionForStream.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "false")
// Disable cost-based join optimization as we do not want stateful operations to be rearranged
- sparkSessionToRunBatches.conf.set(SQLConf.CBO_ENABLED.key, "false")
+ sparkSessionForStream.conf.set(SQLConf.CBO_ENABLED.key, "false")
offsetSeqMetadata = OffsetSeqMetadata(
- batchWatermarkMs = 0, batchTimestampMs = 0, sparkSessionToRunBatches.conf)
+ batchWatermarkMs = 0, batchTimestampMs = 0, sparkSessionForStream.conf)
if (state.compareAndSet(INITIALIZING, ACTIVE)) {
// Unblock `awaitInitialization`
initializationLatch.countDown()
-
- triggerExecutor.execute(() => {
- startTrigger()
-
- if (isActive) {
- reportTimeTaken("triggerExecution") {
- if (currentBatchId < 0) {
- // We'll do this initialization only once
- populateStartOffsets(sparkSessionToRunBatches)
- sparkSession.sparkContext.setJobDescription(getBatchDescriptionString)
- logDebug(s"Stream running from $committedOffsets to $availableOffsets")
- } else {
- constructNextBatch()
- }
- if (dataAvailable) {
- currentStatus = currentStatus.copy(isDataAvailable = true)
- updateStatusMessage("Processing new data")
- runBatch(sparkSessionToRunBatches)
- }
- }
- // Report trigger as finished and construct progress object.
- finishTrigger(dataAvailable)
- if (dataAvailable) {
- // Update committed offsets.
- batchCommitLog.add(currentBatchId)
- committedOffsets ++= availableOffsets
- logDebug(s"batch ${currentBatchId} committed")
- // We'll increase currentBatchId after we complete processing current batch's data
- currentBatchId += 1
- sparkSession.sparkContext.setJobDescription(getBatchDescriptionString)
- } else {
- currentStatus = currentStatus.copy(isDataAvailable = false)
- updateStatusMessage("Waiting for data to arrive")
- Thread.sleep(pollingDelayMs)
- }
- }
- updateStatusMessage("Waiting for next trigger")
- isActive
- })
+ runActivatedStream(sparkSessionForStream)
updateStatusMessage("Stopped")
} else {
// `stop()` is already called. Let `finally` finish the cleanup.
@@ -372,7 +305,7 @@ class StreamExecution(
if (!NonFatal(e)) {
throw e
}
- } finally microBatchThread.runUninterruptibly {
+ } finally queryExecutionThread.runUninterruptibly {
// The whole `finally` block must run inside `runUninterruptibly` to avoid being interrupted
// when a query is stopped by the user. We need to make sure the following codes finish
// otherwise it may throw `InterruptedException` to `UncaughtExceptionHandler` (SPARK-21248).
@@ -409,12 +342,12 @@ class StreamExecution(
}
}
} finally {
- awaitBatchLock.lock()
+ awaitProgressLock.lock()
try {
// Wake up any threads that are waiting for the stream to progress.
- awaitBatchLockCondition.signalAll()
+ awaitProgressLockCondition.signalAll()
} finally {
- awaitBatchLock.unlock()
+ awaitProgressLock.unlock()
}
terminationLatch.countDown()
}
@@ -447,302 +380,12 @@ class StreamExecution(
}
}
- /**
- * Populate the start offsets to start the execution at the current offsets stored in the sink
- * (i.e. avoid reprocessing data that we have already processed). This function must be called
- * before any processing occurs and will populate the following fields:
- * - currentBatchId
- * - committedOffsets
- * - availableOffsets
- * The basic structure of this method is as follows:
- *
- * Identify (from the offset log) the offsets used to run the last batch
- * IF last batch exists THEN
- * Set the next batch to be executed as the last recovered batch
- * Check the commit log to see which batch was committed last
- * IF the last batch was committed THEN
- * Call getBatch using the last batch start and end offsets
- * // ^^^^ above line is needed since some sources assume last batch always re-executes
- * Setup for a new batch i.e., start = last batch end, and identify new end
- * DONE
- * ELSE
- * Identify a brand new batch
- * DONE
- */
- private def populateStartOffsets(sparkSessionToRunBatches: SparkSession): Unit = {
- offsetLog.getLatest() match {
- case Some((latestBatchId, nextOffsets)) =>
- /* First assume that we are re-executing the latest known batch
- * in the offset log */
- currentBatchId = latestBatchId
- availableOffsets = nextOffsets.toStreamProgress(sources)
- /* Initialize committed offsets to a committed batch, which at this
- * is the second latest batch id in the offset log. */
- if (latestBatchId != 0) {
- val secondLatestBatchId = offsetLog.get(latestBatchId - 1).getOrElse {
- throw new IllegalStateException(s"batch ${latestBatchId - 1} doesn't exist")
- }
- committedOffsets = secondLatestBatchId.toStreamProgress(sources)
- }
-
- // update offset metadata
- nextOffsets.metadata.foreach { metadata =>
- OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.conf)
- offsetSeqMetadata = OffsetSeqMetadata(
- metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf)
- }
-
- /* identify the current batch id: if commit log indicates we successfully processed the
- * latest batch id in the offset log, then we can safely move to the next batch
- * i.e., committedBatchId + 1 */
- batchCommitLog.getLatest() match {
- case Some((latestCommittedBatchId, _)) =>
- if (latestBatchId == latestCommittedBatchId) {
- /* The last batch was successfully committed, so we can safely process a
- * new next batch but first:
- * Make a call to getBatch using the offsets from previous batch.
- * because certain sources (e.g., KafkaSource) assume on restart the last
- * batch will be executed before getOffset is called again. */
- availableOffsets.foreach { ao: (Source, Offset) =>
- val (source, end) = ao
- if (committedOffsets.get(source).map(_ != end).getOrElse(true)) {
- val start = committedOffsets.get(source)
- source.getBatch(start, end)
- }
- }
- currentBatchId = latestCommittedBatchId + 1
- committedOffsets ++= availableOffsets
- // Construct a new batch be recomputing availableOffsets
- constructNextBatch()
- } else if (latestCommittedBatchId < latestBatchId - 1) {
- logWarning(s"Batch completion log latest batch id is " +
- s"${latestCommittedBatchId}, which is not trailing " +
- s"batchid $latestBatchId by one")
- }
- case None => logInfo("no commit log present")
- }
- logDebug(s"Resuming at batch $currentBatchId with committed offsets " +
- s"$committedOffsets and available offsets $availableOffsets")
- case None => // We are starting this stream for the first time.
- logInfo(s"Starting new streaming query.")
- currentBatchId = 0
- constructNextBatch()
- }
- }
-
- /**
- * Returns true if there is any new data available to be processed.
- */
- private def dataAvailable: Boolean = {
- availableOffsets.exists {
- case (source, available) =>
- committedOffsets
- .get(source)
- .map(committed => committed != available)
- .getOrElse(true)
- }
- }
-
- /**
- * Queries all of the sources to see if any new data is available. When there is new data the
- * batchId counter is incremented and a new log entry is written with the newest offsets.
- */
- private def constructNextBatch(): Unit = {
- // Check to see what new data is available.
- val hasNewData = {
- awaitBatchLock.lock()
- try {
- val latestOffsets: Map[Source, Option[Offset]] = uniqueSources.map { s =>
- updateStatusMessage(s"Getting offsets from $s")
- reportTimeTaken("getOffset") {
- (s, s.getOffset)
- }
- }.toMap
- availableOffsets ++= latestOffsets.filter { case (s, o) => o.nonEmpty }.mapValues(_.get)
-
- if (dataAvailable) {
- true
- } else {
- noNewData = true
- false
- }
- } finally {
- awaitBatchLock.unlock()
- }
- }
- if (hasNewData) {
- var batchWatermarkMs = offsetSeqMetadata.batchWatermarkMs
- // Update the eventTime watermarks if we find any in the plan.
- if (lastExecution != null) {
- lastExecution.executedPlan.collect {
- case e: EventTimeWatermarkExec => e
- }.zipWithIndex.foreach {
- case (e, index) if e.eventTimeStats.value.count > 0 =>
- logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}")
- val newWatermarkMs = e.eventTimeStats.value.max - e.delayMs
- val prevWatermarkMs = watermarkMsMap.get(index)
- if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) {
- watermarkMsMap.put(index, newWatermarkMs)
- }
-
- // Populate 0 if we haven't seen any data yet for this watermark node.
- case (_, index) =>
- if (!watermarkMsMap.isDefinedAt(index)) {
- watermarkMsMap.put(index, 0)
- }
- }
-
- // Update the global watermark to the minimum of all watermark nodes.
- // This is the safest option, because only the global watermark is fault-tolerant. Making
- // it the minimum of all individual watermarks guarantees it will never advance past where
- // any individual watermark operator would be if it were in a plan by itself.
- if(!watermarkMsMap.isEmpty) {
- val newWatermarkMs = watermarkMsMap.minBy(_._2)._2
- if (newWatermarkMs > batchWatermarkMs) {
- logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms")
- batchWatermarkMs = newWatermarkMs
- } else {
- logDebug(
- s"Event time didn't move: $newWatermarkMs < " +
- s"$batchWatermarkMs")
- }
- }
- }
- offsetSeqMetadata = offsetSeqMetadata.copy(
- batchWatermarkMs = batchWatermarkMs,
- batchTimestampMs = triggerClock.getTimeMillis()) // Current batch timestamp in milliseconds
-
- updateStatusMessage("Writing offsets to log")
- reportTimeTaken("walCommit") {
- assert(offsetLog.add(
- currentBatchId,
- availableOffsets.toOffsetSeq(sources, offsetSeqMetadata)),
- s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId")
- logInfo(s"Committed offsets for batch $currentBatchId. " +
- s"Metadata ${offsetSeqMetadata.toString}")
-
- // NOTE: The following code is correct because runBatches() processes exactly one
- // batch at a time. If we add pipeline parallelism (multiple batches in flight at
- // the same time), this cleanup logic will need to change.
-
- // Now that we've updated the scheduler's persistent checkpoint, it is safe for the
- // sources to discard data from the previous batch.
- if (currentBatchId != 0) {
- val prevBatchOff = offsetLog.get(currentBatchId - 1)
- if (prevBatchOff.isDefined) {
- prevBatchOff.get.toStreamProgress(sources).foreach {
- case (src, off) => src.commit(off)
- }
- } else {
- throw new IllegalStateException(s"batch $currentBatchId doesn't exist")
- }
- }
-
- // It is now safe to discard the metadata beyond the minimum number to retain.
- // Note that purge is exclusive, i.e. it purges everything before the target ID.
- if (minBatchesToRetain < currentBatchId) {
- offsetLog.purge(currentBatchId - minBatchesToRetain)
- batchCommitLog.purge(currentBatchId - minBatchesToRetain)
- }
- }
- } else {
- awaitBatchLock.lock()
- try {
- // Wake up any threads that are waiting for the stream to progress.
- awaitBatchLockCondition.signalAll()
- } finally {
- awaitBatchLock.unlock()
- }
- }
- }
-
- /**
- * Processes any data available between `availableOffsets` and `committedOffsets`.
- * @param sparkSessionToRunBatch Isolated [[SparkSession]] to run this batch with.
- */
- private def runBatch(sparkSessionToRunBatch: SparkSession): Unit = {
- // Request unprocessed data from all sources.
- newData = reportTimeTaken("getBatch") {
- availableOffsets.flatMap {
- case (source, available)
- if committedOffsets.get(source).map(_ != available).getOrElse(true) =>
- val current = committedOffsets.get(source)
- val batch = source.getBatch(current, available)
- assert(batch.isStreaming,
- s"DataFrame returned by getBatch from $source did not have isStreaming=true\n" +
- s"${batch.queryExecution.logical}")
- logDebug(s"Retrieving data from $source: $current -> $available")
- Some(source -> batch)
- case _ => None
- }
- }
-
- // A list of attributes that will need to be updated.
- val replacements = new ArrayBuffer[(Attribute, Attribute)]
- // Replace sources in the logical plan with data that has arrived since the last batch.
- val withNewSources = logicalPlan transform {
- case StreamingExecutionRelation(source, output) =>
- newData.get(source).map { data =>
- val newPlan = data.logicalPlan
- assert(output.size == newPlan.output.size,
- s"Invalid batch: ${Utils.truncatedString(output, ",")} != " +
- s"${Utils.truncatedString(newPlan.output, ",")}")
- replacements ++= output.zip(newPlan.output)
- newPlan
- }.getOrElse {
- LocalRelation(output, isStreaming = true)
- }
- }
-
- // Rewire the plan to use the new attributes that were returned by the source.
- val replacementMap = AttributeMap(replacements)
- val triggerLogicalPlan = withNewSources transformAllExpressions {
- case a: Attribute if replacementMap.contains(a) =>
- replacementMap(a).withMetadata(a.metadata)
- case ct: CurrentTimestamp =>
- CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs,
- ct.dataType)
- case cd: CurrentDate =>
- CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs,
- cd.dataType, cd.timeZoneId)
- }
-
- reportTimeTaken("queryPlanning") {
- lastExecution = new IncrementalExecution(
- sparkSessionToRunBatch,
- triggerLogicalPlan,
- outputMode,
- checkpointFile("state"),
- runId,
- currentBatchId,
- offsetSeqMetadata)
- lastExecution.executedPlan // Force the lazy generation of execution plan
- }
-
- val nextBatch =
- new Dataset(sparkSessionToRunBatch, lastExecution, RowEncoder(lastExecution.analyzed.schema))
-
- reportTimeTaken("addBatch") {
- SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) {
- sink.addBatch(currentBatchId, nextBatch)
- }
- }
-
- awaitBatchLock.lock()
- try {
- // Wake up any threads that are waiting for the stream to progress.
- awaitBatchLockCondition.signalAll()
- } finally {
- awaitBatchLock.unlock()
- }
- }
-
override protected def postEvent(event: StreamingQueryListener.Event): Unit = {
sparkSession.streams.postListenerEvent(event)
}
/** Stops all streaming sources safely. */
- private def stopSources(): Unit = {
+ protected def stopSources(): Unit = {
uniqueSources.foreach { source =>
try {
source.stop()
@@ -761,10 +404,10 @@ class StreamExecution(
// Set the state to TERMINATED so that the batching thread knows that it was interrupted
// intentionally
state.set(TERMINATED)
- if (microBatchThread.isAlive) {
+ if (queryExecutionThread.isAlive) {
sparkSession.sparkContext.cancelJobGroup(runId.toString)
- microBatchThread.interrupt()
- microBatchThread.join()
+ queryExecutionThread.interrupt()
+ queryExecutionThread.join()
// microBatchThread may spawn new jobs, so we need to cancel again to prevent a leak
sparkSession.sparkContext.cancelJobGroup(runId.toString)
}
@@ -775,7 +418,7 @@ class StreamExecution(
* Blocks the current thread until processing for data from the given `source` has reached at
* least the given `Offset`. This method is intended for use primarily when writing tests.
*/
- private[sql] def awaitOffset(source: Source, newOffset: Offset): Unit = {
+ private[sql] def awaitOffset(source: BaseStreamingSource, newOffset: Offset): Unit = {
assertAwaitThread()
def notDone = {
val localCommittedOffsets = committedOffsets
@@ -783,21 +426,21 @@ class StreamExecution(
}
while (notDone) {
- awaitBatchLock.lock()
+ awaitProgressLock.lock()
try {
- awaitBatchLockCondition.await(100, TimeUnit.MILLISECONDS)
+ awaitProgressLockCondition.await(100, TimeUnit.MILLISECONDS)
if (streamDeathCause != null) {
throw streamDeathCause
}
} finally {
- awaitBatchLock.unlock()
+ awaitProgressLock.unlock()
}
}
logDebug(s"Unblocked at $newOffset for $source")
}
/** A flag to indicate that a batch has completed with no new data available. */
- @volatile private var noNewData = false
+ @volatile protected var noNewData = false
/**
* Assert that the await APIs should not be called in the stream thread. Otherwise, it may cause
@@ -805,7 +448,7 @@ class StreamExecution(
* the stream thread forever.
*/
private def assertAwaitThread(): Unit = {
- if (microBatchThread eq Thread.currentThread) {
+ if (queryExecutionThread eq Thread.currentThread) {
throw new IllegalStateException(
"Cannot wait for a query state from the same thread that is running the query")
}
@@ -832,11 +475,11 @@ class StreamExecution(
throw streamDeathCause
}
if (!isActive) return
- awaitBatchLock.lock()
+ awaitProgressLock.lock()
try {
noNewData = false
while (true) {
- awaitBatchLockCondition.await(10000, TimeUnit.MILLISECONDS)
+ awaitProgressLockCondition.await(10000, TimeUnit.MILLISECONDS)
if (streamDeathCause != null) {
throw streamDeathCause
}
@@ -845,7 +488,7 @@ class StreamExecution(
}
}
} finally {
- awaitBatchLock.unlock()
+ awaitProgressLock.unlock()
}
}
@@ -899,7 +542,7 @@ class StreamExecution(
|Current Available Offsets: $availableOffsets
|
|Current State: $state
- |Thread State: ${microBatchThread.getState}""".stripMargin
+ |Thread State: ${queryExecutionThread.getState}""".stripMargin
if (includeLogicalPlan) {
debugString + s"\n\nLogical Plan:\n$logicalPlan"
} else {
@@ -907,7 +550,7 @@ class StreamExecution(
}
}
- private def getBatchDescriptionString: String = {
+ protected def getBatchDescriptionString: String = {
val batchDescription = if (currentBatchId < 0) "init" else currentBatchId.toString
Option(name).map(_ + " ").getOrElse("") +
s"id = $id runId = $runId batch = $batchDescription"
@@ -919,7 +562,7 @@ object StreamExecution {
}
/**
- * A special thread to run the stream query. Some codes require to run in the StreamExecutionThread
- * and will use `classOf[StreamExecutionThread]` to check.
+ * A special thread to run the stream query. Some code requires running in the QueryExecutionThread
+ * and will use `classOf[QueryExecutionThread]` to check.
*/
-abstract class StreamExecutionThread(name: String) extends UninterruptibleThread(name)
+abstract class QueryExecutionThread(name: String) extends UninterruptibleThread(name)
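A minimal sketch (not part of the patch) of the `classOf[QueryExecutionThread]` check that the Scaladoc above refers to; the helper name is hypothetical and assumes it lives somewhere `QueryExecutionThread` is visible.

    import org.apache.spark.sql.execution.streaming.QueryExecutionThread

    // Hypothetical helper: callers verify they are on the stream's QueryExecutionThread
    // before touching thread-confined query state.
    def assertOnQueryExecutionThread(): Unit = {
      require(classOf[QueryExecutionThread].isInstance(Thread.currentThread()),
        s"expected to run on a QueryExecutionThread, but was ${Thread.currentThread()}")
    }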
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala
index a3f3662e6f4c9..8531070b1bc49 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala
@@ -23,25 +23,28 @@ import scala.collection.{immutable, GenTraversableOnce}
* A helper class that looks like a Map[Source, Offset].
*/
class StreamProgress(
- val baseMap: immutable.Map[Source, Offset] = new immutable.HashMap[Source, Offset])
- extends scala.collection.immutable.Map[Source, Offset] {
+ val baseMap: immutable.Map[BaseStreamingSource, Offset] =
+ new immutable.HashMap[BaseStreamingSource, Offset])
+ extends scala.collection.immutable.Map[BaseStreamingSource, Offset] {
- def toOffsetSeq(source: Seq[Source], metadata: OffsetSeqMetadata): OffsetSeq = {
+ def toOffsetSeq(source: Seq[BaseStreamingSource], metadata: OffsetSeqMetadata): OffsetSeq = {
OffsetSeq(source.map(get), Some(metadata))
}
override def toString: String =
baseMap.map { case (k, v) => s"$k: $v"}.mkString("{", ",", "}")
- override def +[B1 >: Offset](kv: (Source, B1)): Map[Source, B1] = baseMap + kv
+ override def +[B1 >: Offset](kv: (BaseStreamingSource, B1)): Map[BaseStreamingSource, B1] = {
+ baseMap + kv
+ }
- override def get(key: Source): Option[Offset] = baseMap.get(key)
+ override def get(key: BaseStreamingSource): Option[Offset] = baseMap.get(key)
- override def iterator: Iterator[(Source, Offset)] = baseMap.iterator
+ override def iterator: Iterator[(BaseStreamingSource, Offset)] = baseMap.iterator
- override def -(key: Source): Map[Source, Offset] = baseMap - key
+ override def -(key: BaseStreamingSource): Map[BaseStreamingSource, Offset] = baseMap - key
- def ++(updates: GenTraversableOnce[(Source, Offset)]): StreamProgress = {
+ def ++(updates: GenTraversableOnce[(BaseStreamingSource, Offset)]): StreamProgress = {
new StreamProgress(baseMap ++ updates)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala
index 07e39023c8366..7dd491ede9d05 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala
@@ -40,7 +40,7 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus)
import StreamingQueryListener._
- sparkListenerBus.addToSharedQueue(this)
+ sparkListenerBus.addToQueue(this, StreamingQueryListenerBus.STREAM_EVENT_QUERY)
/**
* RunIds of active queries whose events are supposed to be forwarded by this ListenerBus
@@ -130,3 +130,7 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus)
}
}
}
+
+object StreamingQueryListenerBus {
+ val STREAM_EVENT_QUERY = "streams"
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
index 6b82c78ea653d..a0ee683a895d8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
@@ -25,6 +25,8 @@ import org.apache.spark.sql.catalyst.plans.logical.LeafNode
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.execution.LeafExecNode
import org.apache.spark.sql.execution.datasources.DataSource
+import org.apache.spark.sql.sources.v2.DataSourceV2
+import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport
object StreamingRelation {
def apply(dataSource: DataSource): StreamingRelation = {
@@ -59,7 +61,53 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output:
* [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]].
*/
case class StreamingExecutionRelation(
- source: Source,
+ source: BaseStreamingSource,
+ output: Seq[Attribute])(session: SparkSession)
+ extends LeafNode {
+
+ override def isStreaming: Boolean = true
+ override def toString: String = source.toString
+
+ // There's no sensible value here. On the execution path, this relation will be
+ // swapped out with microbatches. But some dataframe operations (in particular explain) do lead
+ // to this node surviving analysis. So we satisfy the LeafNode contract with the session default
+ // value.
+ override def computeStats(): Statistics = Statistics(
+ sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes)
+ )
+}
+
+// We have to pack in the V1 data source as a shim, for the case when a source implements
+// continuous processing (which is always V2) but only has V1 microbatch support. We don't
+// know at read time whether the query is continuous or not, so we need to be able to
+// swap a V1 relation back in.
+/**
+ * Used to link a [[DataSourceV2]] into a streaming
+ * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. This is only used for creating
+ * a streaming [[org.apache.spark.sql.DataFrame]] from [[org.apache.spark.sql.DataFrameReader]],
+ * and should be converted before passing to [[StreamExecution]].
+ */
+case class StreamingRelationV2(
+ dataSource: DataSourceV2,
+ sourceName: String,
+ extraOptions: Map[String, String],
+ output: Seq[Attribute],
+ v1Relation: Option[StreamingRelation])(session: SparkSession)
+ extends LeafNode {
+ override def isStreaming: Boolean = true
+ override def toString: String = sourceName
+
+ override def computeStats(): Statistics = Statistics(
+ sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes)
+ )
+}
+
+/**
+ * Used to link a [[DataSourceV2]] into a continuous processing execution.
+ */
+case class ContinuousExecutionRelation(
+ source: ContinuousReadSupport,
+ extraOptions: Map[String, String],
output: Seq[Attribute])(session: SparkSession)
extends LeafNode {
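As a rough illustration of the V1-shim comment above (a sketch, not code from this patch; `fallBackToV1` is a hypothetical helper), a planner that decides to run the query as a micro-batch could swap the packed V1 relation back in like this:

    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.execution.streaming.StreamingRelationV2

    // Hypothetical sketch: replace each StreamingRelationV2 with its packed V1 shim when
    // falling back to micro-batch execution; fail if the source offered no V1 fallback.
    def fallBackToV1(plan: LogicalPlan): LogicalPlan = plan transform {
      case StreamingRelationV2(_, _, _, _, Some(v1Relation)) => v1Relation
      case StreamingRelationV2(_, sourceName, _, _, None) =>
        throw new UnsupportedOperationException(
          s"Source $sourceName has no V1 relation to fall back to for micro-batch execution.")
    }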
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala
index 167e991ca62f8..4aba76cad367e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala
@@ -72,8 +72,7 @@ object StreamingSymmetricHashJoinHelper extends Logging {
* left AND right AND joined is equivalent to full.
*
* Note that left and right do not necessarily contain *all* conjuncts which satisfy
- * their condition. Any conjuncts after the first nondeterministic one are treated as
- * nondeterministic for purposes of the split.
+ * their condition.
*
* @param leftSideOnly Deterministic conjuncts which reference only the left side of the join.
* @param rightSideOnly Deterministic conjuncts which reference only the right side of the join.
@@ -111,7 +110,7 @@ object StreamingSymmetricHashJoinHelper extends Logging {
// Span rather than partition, because nondeterministic expressions don't commute
// across AND.
val (deterministicConjuncts, nonDeterministicConjuncts) =
- splitConjunctivePredicates(condition.get).span(_.deterministic)
+ splitConjunctivePredicates(condition.get).partition(_.deterministic)
val (leftConjuncts, nonLeftConjuncts) = deterministicConjuncts.partition { cond =>
cond.references.subsetOf(left.outputSet)
@@ -204,7 +203,7 @@ object StreamingSymmetricHashJoinHelper extends Logging {
/**
* A custom RDD that allows partitions to be "zipped" together, while ensuring the tasks'
* preferred location is based on which executors have the required join state stores already
- * loaded. This is class is a modified verion of [[ZippedPartitionsRDD2]].
+ * loaded. This class is a modified version of [[ZippedPartitionsRDD2]].
*/
class StateStoreAwareZipPartitionsRDD[A: ClassTag, B: ClassTag, V: ClassTag](
sc: SparkContext,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala
index 271bc4da99c08..19e3e55cb2829 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.{Experimental, InterfaceStability}
import org.apache.spark.sql.streaming.Trigger
/**
- * A [[Trigger]] that process only one batch of data in a streaming query then terminates
+ * A [[Trigger]] that processes only one batch of data in a streaming query then terminates
* the query.
*/
@Experimental
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
new file mode 100644
index 0000000000000..d79e4bd65f563
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous
+
+import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit}
+import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong}
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark._
+import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader}
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.continuous._
+import org.apache.spark.sql.sources.v2.reader._
+import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, PartitionOffset}
+import org.apache.spark.sql.streaming.ProcessingTime
+import org.apache.spark.util.{SystemClock, ThreadUtils}
+
+class ContinuousDataSourceRDD(
+ sc: SparkContext,
+ sqlContext: SQLContext,
+ @transient private val readTasks: java.util.List[ReadTask[UnsafeRow]])
+ extends RDD[UnsafeRow](sc, Nil) {
+
+ private val dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize
+ private val epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs
+
+ override protected def getPartitions: Array[Partition] = {
+ readTasks.asScala.zipWithIndex.map {
+ case (readTask, index) => new DataSourceRDDPartition(index, readTask)
+ }.toArray
+ }
+
+ override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = {
+ val reader = split.asInstanceOf[DataSourceRDDPartition].readTask.createDataReader()
+
+ val runId = context.getLocalProperty(ContinuousExecution.RUN_ID_KEY)
+
+ // This queue contains two types of messages:
+ // * (null, null) representing an epoch boundary.
+ // * (row, off) containing a data row and its corresponding PartitionOffset.
+ val queue = new ArrayBlockingQueue[(UnsafeRow, PartitionOffset)](dataQueueSize)
+
+ val epochPollFailed = new AtomicBoolean(false)
+ val epochPollExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor(
+ s"epoch-poll--${runId}--${context.partitionId()}")
+ val epochPollRunnable = new EpochPollRunnable(queue, context, epochPollFailed)
+ epochPollExecutor.scheduleWithFixedDelay(
+ epochPollRunnable, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS)
+
+ // Important sequencing - we must get start offset before the data reader thread begins
+ val startOffset = ContinuousDataSourceRDD.getBaseReader(reader).getOffset
+
+ val dataReaderFailed = new AtomicBoolean(false)
+ val dataReaderThread = new DataReaderThread(reader, queue, context, dataReaderFailed)
+ dataReaderThread.setDaemon(true)
+ dataReaderThread.start()
+
+ context.addTaskCompletionListener(_ => {
+ reader.close()
+ dataReaderThread.interrupt()
+ epochPollExecutor.shutdown()
+ })
+
+ val epochEndpoint = EpochCoordinatorRef.get(runId, SparkEnv.get)
+ new Iterator[UnsafeRow] {
+ private val POLL_TIMEOUT_MS = 1000
+
+ private var currentEntry: (UnsafeRow, PartitionOffset) = _
+ private var currentOffset: PartitionOffset = startOffset
+ private var currentEpoch =
+ context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
+
+ override def hasNext(): Boolean = {
+ while (currentEntry == null) {
+ if (context.isInterrupted() || context.isCompleted()) {
+ currentEntry = (null, null)
+ }
+ if (dataReaderFailed.get()) {
+ throw new SparkException("data read failed", dataReaderThread.failureReason)
+ }
+ if (epochPollFailed.get()) {
+ throw new SparkException("epoch poll failed", epochPollRunnable.failureReason)
+ }
+ currentEntry = queue.poll(POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS)
+ }
+
+ currentEntry match {
+ // epoch boundary marker
+ case (null, null) =>
+ epochEndpoint.send(ReportPartitionOffset(
+ context.partitionId(),
+ currentEpoch,
+ currentOffset))
+ currentEpoch += 1
+ currentEntry = null
+ false
+ // real row
+ case (_, offset) =>
+ currentOffset = offset
+ true
+ }
+ }
+
+ override def next(): UnsafeRow = {
+ if (currentEntry == null) throw new NoSuchElementException("No current row was set")
+ val r = currentEntry._1
+ currentEntry = null
+ r
+ }
+ }
+ }
+
+ override def getPreferredLocations(split: Partition): Seq[String] = {
+ split.asInstanceOf[DataSourceRDDPartition].readTask.preferredLocations()
+ }
+}
+
+case class EpochPackedPartitionOffset(epoch: Long) extends PartitionOffset
+
+class EpochPollRunnable(
+ queue: BlockingQueue[(UnsafeRow, PartitionOffset)],
+ context: TaskContext,
+ failedFlag: AtomicBoolean)
+ extends Thread with Logging {
+ private[continuous] var failureReason: Throwable = _
+
+ private val epochEndpoint = EpochCoordinatorRef.get(
+ context.getLocalProperty(ContinuousExecution.RUN_ID_KEY), SparkEnv.get)
+ private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
+
+ override def run(): Unit = {
+ try {
+ val newEpoch = epochEndpoint.askSync[Long](GetCurrentEpoch)
+ for (i <- currentEpoch to newEpoch - 1) {
+ queue.put((null, null))
+ logDebug(s"Sent marker to start epoch ${i + 1}")
+ }
+ currentEpoch = newEpoch
+ } catch {
+ case t: Throwable =>
+ failureReason = t
+ failedFlag.set(true)
+ throw t
+ }
+ }
+}
+
+class DataReaderThread(
+ reader: DataReader[UnsafeRow],
+ queue: BlockingQueue[(UnsafeRow, PartitionOffset)],
+ context: TaskContext,
+ failedFlag: AtomicBoolean)
+ extends Thread(
+ s"continuous-reader--${context.partitionId()}--" +
+ s"${context.getLocalProperty(ContinuousExecution.RUN_ID_KEY)}") {
+ private[continuous] var failureReason: Throwable = _
+
+ override def run(): Unit = {
+ val baseReader = ContinuousDataSourceRDD.getBaseReader(reader)
+ try {
+ while (!context.isInterrupted && !context.isCompleted()) {
+ if (!reader.next()) {
+ // Check again, since reader.next() might have blocked through an incoming interrupt.
+ if (!context.isInterrupted && !context.isCompleted()) {
+ throw new IllegalStateException(
+ "Continuous reader reported no elements! Reader should have blocked waiting.")
+ } else {
+ return
+ }
+ }
+
+ queue.put((reader.get().copy(), baseReader.getOffset))
+ }
+ } catch {
+ case _: InterruptedException if context.isInterrupted() =>
+ // Continuous shutdown always involves an interrupt; do nothing and shut down quietly.
+
+ case t: Throwable =>
+ failureReason = t
+ failedFlag.set(true)
+ // Don't rethrow the exception in this thread. It's not needed, and the default Spark
+ // exception handler will kill the executor.
+ }
+ }
+}
+
+object ContinuousDataSourceRDD {
+ private[continuous] def getBaseReader(reader: DataReader[UnsafeRow]): ContinuousDataReader[_] = {
+ reader match {
+ case r: ContinuousDataReader[UnsafeRow] => r
+ case wrapped: RowToUnsafeDataReader =>
+ wrapped.rowReader.asInstanceOf[ContinuousDataReader[Row]]
+ case _ =>
+ throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}")
+ }
+ }
+}
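To make the queue contract in `compute` easier to follow, here is a minimal, Spark-free sketch of the same protocol; plain strings stand in for `UnsafeRow` and `PartitionOffset`, and all names and values are illustrative only.

    import java.util.concurrent.{ArrayBlockingQueue, TimeUnit}

    object QueueProtocolSketch extends App {
      val queue = new ArrayBlockingQueue[(String, String)](16)
      queue.put(("row-1", "offset-1"))
      queue.put((null, null))          // epoch boundary marker, as in the RDD above
      queue.put(("row-2", "offset-2"))

      var epoch = 0L
      var entry = queue.poll(100, TimeUnit.MILLISECONDS)
      while (entry != null) {
        entry match {
          case (null, null) =>
            // In the real code this is where ReportPartitionOffset is sent for the finished epoch.
            epoch += 1
          case (row, offset) =>
            println(s"epoch=$epoch row=$row lastOffset=$offset")
        }
        entry = queue.poll(100, TimeUnit.MILLISECONDS)
      }
    }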
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
new file mode 100644
index 0000000000000..9657b5e26d770
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
@@ -0,0 +1,350 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.{ArrayBuffer, Map => MutableMap}
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.sql.{AnalysisException, SparkSession}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.SQLExecution
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2}
+import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _}
+import org.apache.spark.sql.sources.v2.DataSourceV2Options
+import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport}
+import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, Offset, PartitionOffset}
+import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter
+import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.{Clock, Utils}
+
+class ContinuousExecution(
+ sparkSession: SparkSession,
+ name: String,
+ checkpointRoot: String,
+ analyzedPlan: LogicalPlan,
+ sink: ContinuousWriteSupport,
+ trigger: Trigger,
+ triggerClock: Clock,
+ outputMode: OutputMode,
+ extraOptions: Map[String, String],
+ deleteCheckpointOnStop: Boolean)
+ extends StreamExecution(
+ sparkSession, name, checkpointRoot, analyzedPlan, sink,
+ trigger, triggerClock, outputMode, deleteCheckpointOnStop) {
+
+ @volatile protected var continuousSources: Seq[ContinuousReader] = Seq.empty
+ override protected def sources: Seq[BaseStreamingSource] = continuousSources
+
+ override lazy val logicalPlan: LogicalPlan = {
+ assert(queryExecutionThread eq Thread.currentThread,
+ "logicalPlan must be initialized in QueryExecutionThread " +
+ s"but the current thread was ${Thread.currentThread}")
+ val toExecutionRelationMap = MutableMap[StreamingRelationV2, ContinuousExecutionRelation]()
+ analyzedPlan.transform {
+ case r @ StreamingRelationV2(
+ source: ContinuousReadSupport, _, extraReaderOptions, output, _) =>
+ toExecutionRelationMap.getOrElseUpdate(r, {
+ ContinuousExecutionRelation(source, extraReaderOptions, output)(sparkSession)
+ })
+ case StreamingRelationV2(_, sourceName, _, _, _) =>
+ throw new AnalysisException(
+ s"Data source $sourceName does not support continuous processing.")
+ }
+ }
+
+ private val triggerExecutor = trigger match {
+ case ContinuousTrigger(t) => ProcessingTimeExecutor(ProcessingTime(t), triggerClock)
+ case _ => throw new IllegalStateException(s"Unsupported type of trigger: $trigger")
+ }
+
+ override protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = {
+ do {
+ try {
+ runContinuous(sparkSessionForStream)
+ } catch {
+ case _: InterruptedException if state.get().equals(RECONFIGURING) =>
+ // swallow exception and run again
+ state.set(ACTIVE)
+ }
+ } while (state.get() == ACTIVE)
+ }
+
+ /**
+ * Populate the start offsets to start the execution at the current offsets stored in the sink
+ * (i.e. avoid reprocessing data that we have already processed). This function must be called
+ * before any processing occurs and will populate the following fields:
+ * - currentBatchId
+ * - committedOffsets
+ * The basic structure of this method is as follows:
+ *
+ * Identify (from the commit log) the latest epoch that has committed
+ * IF last epoch exists THEN
+ * Get end offsets for the epoch
+ * Set those offsets as the current commit progress
+ * Set the next epoch ID as the last + 1
+ * Return the end offsets of the last epoch as start for the next one
+ * DONE
+ * ELSE
+ * Start a new query log
+ * DONE
+ */
+ private def getStartOffsets(sparkSessionToRunBatches: SparkSession): OffsetSeq = {
+ // Note that this will need a slight modification for exactly once. If ending offsets were
+ // reported but not committed for any epochs, we must replay exactly to those offsets.
+ // For at least once, we can just ignore those reports and risk duplicates.
+ commitLog.getLatest() match {
+ case Some((latestEpochId, _)) =>
+ val nextOffsets = offsetLog.get(latestEpochId).getOrElse {
+ throw new IllegalStateException(
+ s"Batch $latestEpochId was committed without end epoch offsets!")
+ }
+ committedOffsets = nextOffsets.toStreamProgress(sources)
+
+ // Forcibly align commit and offset logs by slicing off any spurious offset logs from
+ // a previous run. We can't allow commits to an epoch that a previous run reached but
+ // this run has not.
+ offsetLog.purgeAfter(latestEpochId)
+
+ currentBatchId = latestEpochId + 1
+ logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets")
+ nextOffsets
+ case None =>
+ // We are starting this stream for the first time. Offsets are all None.
+ logInfo(s"Starting new streaming query.")
+ currentBatchId = 0
+ OffsetSeq.fill(continuousSources.map(_ => null): _*)
+ }
+ }
+
+ /**
+ * Do a continuous run.
+ * @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with.
+ */
+ private def runContinuous(sparkSessionForQuery: SparkSession): Unit = {
+ // A list of attributes that will need to be updated.
+ val replacements = new ArrayBuffer[(Attribute, Attribute)]
+ // Translate from continuous relation to the underlying data source.
+ var nextSourceId = 0
+ continuousSources = logicalPlan.collect {
+ case ContinuousExecutionRelation(dataSource, extraReaderOptions, output) =>
+ val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId"
+ nextSourceId += 1
+
+ dataSource.createContinuousReader(
+ java.util.Optional.empty[StructType](),
+ metadataPath,
+ new DataSourceV2Options(extraReaderOptions.asJava))
+ }
+ uniqueSources = continuousSources.distinct
+
+ val offsets = getStartOffsets(sparkSessionForQuery)
+
+ var insertedSourceId = 0
+ val withNewSources = logicalPlan transform {
+ case ContinuousExecutionRelation(_, _, output) =>
+ val reader = continuousSources(insertedSourceId)
+ insertedSourceId += 1
+ val newOutput = reader.readSchema().toAttributes
+
+ assert(output.size == newOutput.size,
+ s"Invalid reader: ${Utils.truncatedString(output, ",")} != " +
+ s"${Utils.truncatedString(newOutput, ",")}")
+ replacements ++= output.zip(newOutput)
+
+ val loggedOffset = offsets.offsets(0)
+ val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json))
+ reader.setOffset(java.util.Optional.ofNullable(realOffset.orNull))
+ new StreamingDataSourceV2Relation(newOutput, reader)
+ }
+
+ // Rewire the plan to use the new attributes that were returned by the source.
+ val replacementMap = AttributeMap(replacements)
+ val triggerLogicalPlan = withNewSources transformAllExpressions {
+ case a: Attribute if replacementMap.contains(a) =>
+ replacementMap(a).withMetadata(a.metadata)
+ case (_: CurrentTimestamp | _: CurrentDate) =>
+ throw new IllegalStateException(
+ "CurrentTimestamp and CurrentDate not yet supported for continuous processing")
+ }
+
+ val writer = sink.createContinuousWriter(
+ s"$runId",
+ triggerLogicalPlan.schema,
+ outputMode,
+ new DataSourceV2Options(extraOptions.asJava))
+ val withSink = WriteToDataSourceV2(writer.get(), triggerLogicalPlan)
+
+ val reader = withSink.collect {
+ case DataSourceV2Relation(_, r: ContinuousReader) => r
+ }.head
+
+ reportTimeTaken("queryPlanning") {
+ lastExecution = new IncrementalExecution(
+ sparkSessionForQuery,
+ withSink,
+ outputMode,
+ checkpointFile("state"),
+ runId,
+ currentBatchId,
+ offsetSeqMetadata)
+ lastExecution.executedPlan // Force the lazy generation of execution plan
+ }
+
+ sparkSession.sparkContext.setLocalProperty(
+ ContinuousExecution.START_EPOCH_KEY, currentBatchId.toString)
+ sparkSession.sparkContext.setLocalProperty(
+ ContinuousExecution.RUN_ID_KEY, runId.toString)
+
+ // Use the parent Spark session for the endpoint since it's where this query ID is registered.
+ val epochEndpoint =
+ EpochCoordinatorRef.create(
+ writer.get(), reader, this, currentBatchId, sparkSession, SparkEnv.get)
+ val epochUpdateThread = new Thread(new Runnable {
+ override def run: Unit = {
+ try {
+ triggerExecutor.execute(() => {
+ startTrigger()
+
+ if (reader.needsReconfiguration()) {
+ state.set(RECONFIGURING)
+ stopSources()
+ if (queryExecutionThread.isAlive) {
+ sparkSession.sparkContext.cancelJobGroup(runId.toString)
+ queryExecutionThread.interrupt()
+ // No need to join - this thread is about to end anyway.
+ }
+ false
+ } else if (isActive) {
+ currentBatchId = epochEndpoint.askSync[Long](IncrementAndGetEpoch)
+ logInfo(s"New epoch $currentBatchId is starting.")
+ true
+ } else {
+ false
+ }
+ })
+ } catch {
+ case _: InterruptedException =>
+ // Cleanly stop the query.
+ return
+ }
+ }
+ }, s"epoch update thread for $prettyIdString")
+
+ try {
+ epochUpdateThread.setDaemon(true)
+ epochUpdateThread.start()
+
+ reportTimeTaken("runContinuous") {
+ SQLExecution.withNewExecutionId(
+ sparkSessionForQuery, lastExecution)(lastExecution.toRdd)
+ }
+ } finally {
+ SparkEnv.get.rpcEnv.stop(epochEndpoint)
+
+ epochUpdateThread.interrupt()
+ epochUpdateThread.join()
+ }
+ }
+
+ /**
+ * Report ending partition offsets for the given reader at the given epoch.
+ */
+ def addOffset(
+ epoch: Long, reader: ContinuousReader, partitionOffsets: Seq[PartitionOffset]): Unit = {
+ assert(continuousSources.length == 1, "only one continuous source supported currently")
+
+ if (partitionOffsets.contains(null)) {
+ // If any offset is null, that means the corresponding partition hasn't seen any data yet, so
+ // there's nothing meaningful to add to the offset log; skip recording this epoch.
+ return
+ }
+ val globalOffset = reader.mergeOffsets(partitionOffsets.toArray)
+ synchronized {
+ if (queryExecutionThread.isAlive) {
+ offsetLog.add(epoch, OffsetSeq.fill(globalOffset))
+ } else {
+ return
+ }
+ }
+ }
+
+ /**
+ * Mark the specified epoch as committed. All readers must have reported end offsets for the epoch
+ * before this is called.
+ */
+ def commit(epoch: Long): Unit = {
+ assert(continuousSources.length == 1, "only one continuous source supported currently")
+ assert(offsetLog.get(epoch).isDefined, s"offset for epoch $epoch not reported before commit")
+ synchronized {
+ if (queryExecutionThread.isAlive) {
+ commitLog.add(epoch)
+ val offset = offsetLog.get(epoch).get.offsets(0).get
+ committedOffsets ++= Seq(continuousSources(0) -> offset)
+ } else {
+ return
+ }
+ }
+
+ if (minLogEntriesToMaintain < currentBatchId) {
+ offsetLog.purge(currentBatchId - minLogEntriesToMaintain)
+ commitLog.purge(currentBatchId - minLogEntriesToMaintain)
+ }
+
+ awaitProgressLock.lock()
+ try {
+ awaitProgressLockCondition.signalAll()
+ } finally {
+ awaitProgressLock.unlock()
+ }
+ }
+
+ /**
+ * Blocks the current thread until execution has committed at or after the specified epoch.
+ */
+ private[sql] def awaitEpoch(epoch: Long): Unit = {
+ def notDone = {
+ val latestCommit = commitLog.getLatest()
+ latestCommit match {
+ case Some((latestEpoch, _)) =>
+ latestEpoch < epoch
+ case None => true
+ }
+ }
+
+ while (notDone) {
+ awaitProgressLock.lock()
+ try {
+ awaitProgressLockCondition.await(100, TimeUnit.MILLISECONDS)
+ if (streamDeathCause != null) {
+ throw streamDeathCause
+ }
+ } finally {
+ awaitProgressLock.unlock()
+ }
+ }
+ }
+}
+
+object ContinuousExecution {
+ val START_EPOCH_KEY = "__continuous_start_epoch"
+ val RUN_ID_KEY = "__run_id"
+}
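A Spark-free sketch of the resume logic that `getStartOffsets` describes; the log contents and names below are illustrative stand-ins, not the real metadata log API. The next epoch is one past the latest committed epoch, and that epoch's recorded end offsets become the committed offsets.

    object ResumeSketch extends App {
      // Toy stand-ins for the commit log and offset log; keys are epoch ids.
      val commitLog = Map(0L -> "committed", 1L -> "committed")
      val offsetLog = Map(0L -> "offsets-0", 1L -> "offsets-1", 2L -> "offsets-2")

      val (startEpoch, committedOffsets) = commitLog.keys.reduceOption(_ max _) match {
        case Some(latestEpoch) =>
          val ends = offsetLog.getOrElse(latestEpoch,
            throw new IllegalStateException(s"Batch $latestEpoch was committed without end offsets!"))
          // In the real code, offsetLog.purgeAfter(latestEpoch) drops the spurious epoch-2 entry.
          (latestEpoch + 1, Some(ends))
        case None =>
          (0L, None) // first run: no offsets yet
      }
      println(s"resuming at epoch $startEpoch with committed offsets $committedOffsets")
    }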
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
new file mode 100644
index 0000000000000..b4b21e7d2052f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous
+
+import scala.collection.JavaConverters._
+
+import org.json4s.DefaultFormats
+import org.json4s.jackson.Serialization
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset, ValueRunTimeMsPair}
+import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2
+import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options}
+import org.apache.spark.sql.sources.v2.reader._
+import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset}
+import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType}
+
+case class RateStreamPartitionOffset(
+ partition: Int, currentValue: Long, currentTimeMs: Long) extends PartitionOffset
+
+class RateStreamContinuousReader(options: DataSourceV2Options)
+ extends ContinuousReader {
+ implicit val defaultFormats: DefaultFormats = DefaultFormats
+
+ val creationTime = System.currentTimeMillis()
+
+ val numPartitions = options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt
+ val rowsPerSecond = options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong
+ val perPartitionRate = rowsPerSecond.toDouble / numPartitions.toDouble
+
+ override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = {
+ assert(offsets.length == numPartitions)
+ val tuples = offsets.map {
+ case RateStreamPartitionOffset(i, currVal, nextRead) =>
+ (i, ValueRunTimeMsPair(currVal, nextRead))
+ }
+ RateStreamOffset(Map(tuples: _*))
+ }
+
+ override def deserializeOffset(json: String): Offset = {
+ RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json))
+ }
+
+ override def readSchema(): StructType = RateSourceProvider.SCHEMA
+
+ private var offset: Offset = _
+
+ override def setOffset(offset: java.util.Optional[Offset]): Unit = {
+ this.offset = offset.orElse(RateStreamSourceV2.createInitialOffset(numPartitions, creationTime))
+ }
+
+ override def getStartOffset(): Offset = offset
+
+ override def createReadTasks(): java.util.List[ReadTask[Row]] = {
+ val partitionStartMap = offset match {
+ case off: RateStreamOffset => off.partitionToValueAndRunTimeMs
+ case off =>
+ throw new IllegalArgumentException(
+ s"invalid offset type ${off.getClass()} for ContinuousRateSource")
+ }
+ if (partitionStartMap.keySet.size != numPartitions) {
+ throw new IllegalArgumentException(
+ s"The previous run contained ${partitionStartMap.keySet.size} partitions, but" +
+ s" $numPartitions partitions are currently configured. The numPartitions option" +
+ " cannot be changed.")
+ }
+
+ Range(0, numPartitions).map { i =>
+ val start = partitionStartMap(i)
+ // Have each partition advance by numPartitions each row, with starting points staggered
+ // by their partition index.
+ RateStreamContinuousReadTask(
+ start.value,
+ start.runTimeMs,
+ i,
+ numPartitions,
+ perPartitionRate)
+ .asInstanceOf[ReadTask[Row]]
+ }.asJava
+ }
+
+ override def commit(end: Offset): Unit = {}
+ override def stop(): Unit = {}
+
+}
+
+case class RateStreamContinuousReadTask(
+ startValue: Long,
+ startTimeMs: Long,
+ partitionIndex: Int,
+ increment: Long,
+ rowsPerSecond: Double)
+ extends ReadTask[Row] {
+ override def createDataReader(): DataReader[Row] =
+ new RateStreamContinuousDataReader(
+ startValue, startTimeMs, partitionIndex, increment, rowsPerSecond)
+}
+
+class RateStreamContinuousDataReader(
+ startValue: Long,
+ startTimeMs: Long,
+ partitionIndex: Int,
+ increment: Long,
+ rowsPerSecond: Double)
+ extends ContinuousDataReader[Row] {
+ private var nextReadTime: Long = startTimeMs
+ private val readTimeIncrement: Long = (1000 / rowsPerSecond).toLong
+
+ private var currentValue = startValue
+ private var currentRow: Row = null
+
+ override def next(): Boolean = {
+ currentValue += increment
+ nextReadTime += readTimeIncrement
+
+ try {
+ while (System.currentTimeMillis < nextReadTime) {
+ Thread.sleep(nextReadTime - System.currentTimeMillis)
+ }
+ } catch {
+ case _: InterruptedException =>
+ // Someone's trying to end the task; just let them.
+ return false
+ }
+
+ currentRow = Row(
+ DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromMillis(nextReadTime)),
+ currentValue)
+
+ true
+ }
+
+ override def get: Row = currentRow
+
+ override def close(): Unit = {}
+
+ override def getOffset(): PartitionOffset =
+ RateStreamPartitionOffset(partitionIndex, currentValue, nextReadTime)
+}
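The staggering comment in `createReadTasks` above can be checked with plain arithmetic; this standalone sketch (illustrative values only) shows how each partition's value stream interleaves with the others.

    object RateStaggerSketch extends App {
      val numPartitions = 3
      // createInitialOffset stores i - numPartitions as the (exclusive) starting value, and the
      // reader advances by numPartitions per row, so partition i emits i, i + 3, i + 6, ...
      (0 until numPartitions).foreach { i =>
        val startValue = (i - numPartitions).toLong
        val firstRows = Iterator.iterate(startValue)(_ + numPartitions).drop(1).take(4).toList
        println(s"partition $i -> $firstRows")
      }
      // partition 0 -> List(0, 3, 6, 9)
      // partition 1 -> List(1, 4, 7, 10)
      // partition 2 -> List(2, 5, 8, 11)
    }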
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTrigger.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTrigger.scala
new file mode 100644
index 0000000000000..90e1766c4d9f1
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTrigger.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous
+
+import java.util.concurrent.TimeUnit
+
+import scala.concurrent.duration.Duration
+
+import org.apache.commons.lang3.StringUtils
+
+import org.apache.spark.annotation.{Experimental, InterfaceStability}
+import org.apache.spark.sql.streaming.{ProcessingTime, Trigger}
+import org.apache.spark.unsafe.types.CalendarInterval
+
+/**
+ * A [[Trigger]] that continuously processes streaming data, asynchronously checkpointing at
+ * the specified interval.
+ */
+@InterfaceStability.Evolving
+case class ContinuousTrigger(intervalMs: Long) extends Trigger {
+ require(intervalMs >= 0, "the interval of trigger should not be negative")
+}
+
+private[sql] object ContinuousTrigger {
+ def apply(interval: String): ContinuousTrigger = {
+ if (StringUtils.isBlank(interval)) {
+ throw new IllegalArgumentException(
+ "interval cannot be null or blank.")
+ }
+ val cal = if (interval.startsWith("interval")) {
+ CalendarInterval.fromString(interval)
+ } else {
+ CalendarInterval.fromString("interval " + interval)
+ }
+ if (cal == null) {
+ throw new IllegalArgumentException(s"Invalid interval: $interval")
+ }
+ if (cal.months > 0) {
+ throw new IllegalArgumentException(s"Doesn't support month or year interval: $interval")
+ }
+ new ContinuousTrigger(cal.microseconds / 1000)
+ }
+
+ def apply(interval: Duration): ContinuousTrigger = {
+ ContinuousTrigger(interval.toMillis)
+ }
+
+ def create(interval: String): ContinuousTrigger = {
+ apply(interval)
+ }
+
+ def create(interval: Long, unit: TimeUnit): ContinuousTrigger = {
+ ContinuousTrigger(unit.toMillis(interval))
+ }
+}
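A small usage sketch for the factory methods above (the object name is hypothetical, and the snippet sits in the same package because the companion object is `private[sql]`): both the string form and the `Duration` form reduce to a plain millisecond interval.

    package org.apache.spark.sql.execution.streaming.continuous

    import scala.concurrent.duration._

    // Hypothetical usage sketch of the ContinuousTrigger factories defined above.
    object ContinuousTriggerSketch extends App {
      val fromString = ContinuousTrigger("1 second") // parsed through CalendarInterval
      val fromDuration = ContinuousTrigger(1.minute) // converted from a Scala Duration
      assert(fromString.intervalMs == 1000L)
      assert(fromDuration.intervalMs == 60000L)
    }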
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
new file mode 100644
index 0000000000000..98017c3ac6a33
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous
+
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.mutable
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.Logging
+import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper
+import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, PartitionOffset}
+import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter
+import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage
+import org.apache.spark.util.RpcUtils
+
+private[continuous] sealed trait EpochCoordinatorMessage extends Serializable
+
+// Driver epoch trigger message
+/**
+ * Atomically increment the current epoch and get the new value.
+ */
+private[sql] case object IncrementAndGetEpoch extends EpochCoordinatorMessage
+
+// Init messages
+/**
+ * Set the reader and writer partition counts. Tasks may not be started until the coordinator
+ * has acknowledged these messages.
+ */
+private[sql] case class SetReaderPartitions(numPartitions: Int) extends EpochCoordinatorMessage
+case class SetWriterPartitions(numPartitions: Int) extends EpochCoordinatorMessage
+
+// Partition task messages
+/**
+ * Get the current epoch.
+ */
+private[sql] case object GetCurrentEpoch extends EpochCoordinatorMessage
+/**
+ * Commit a partition at the specified epoch with the given message.
+ */
+private[sql] case class CommitPartitionEpoch(
+ partitionId: Int,
+ epoch: Long,
+ message: WriterCommitMessage) extends EpochCoordinatorMessage
+/**
+ * Report that a partition is ending the specified epoch at the specified offset.
+ */
+private[sql] case class ReportPartitionOffset(
+ partitionId: Int,
+ epoch: Long,
+ offset: PartitionOffset) extends EpochCoordinatorMessage
+
+
+/** Helper object used to create a reference to [[EpochCoordinator]]. */
+private[sql] object EpochCoordinatorRef extends Logging {
+ private def endpointName(runId: String) = s"EpochCoordinator-$runId"
+
+ /**
+ * Create a reference to a new [[EpochCoordinator]].
+ */
+ def create(
+ writer: ContinuousWriter,
+ reader: ContinuousReader,
+ query: ContinuousExecution,
+ startEpoch: Long,
+ session: SparkSession,
+ env: SparkEnv): RpcEndpointRef = synchronized {
+ val coordinator = new EpochCoordinator(
+ writer, reader, query, startEpoch, session, env.rpcEnv)
+ val ref = env.rpcEnv.setupEndpoint(endpointName(query.runId.toString()), coordinator)
+ logInfo("Registered EpochCoordinator endpoint")
+ ref
+ }
+
+ def get(runId: String, env: SparkEnv): RpcEndpointRef = synchronized {
+ val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName(runId), env.conf, env.rpcEnv)
+ logDebug("Retrieved existing EpochCoordinator endpoint")
+ rpcEndpointRef
+ }
+}
+
+/**
+ * Handles three major epoch coordination tasks for continuous processing:
+ *
+ * * Maintains a local epoch counter (the "driver epoch"), incremented by IncrementAndGetEpoch
+ * and pollable from executors by GetCurrentEpoch. Note that this epoch is *not* immediately
+ * reflected anywhere in ContinuousExecution.
+ * * Collates ReportPartitionOffset messages, and forwards to ContinuousExecution when all
+ * readers have ended a given epoch.
+ * * Collates CommitPartitionEpoch messages, and forwards to ContinuousExecution when all writer
+ * partitions have committed and all reader partitions have reported end offsets for a given epoch.
+ */
+private[continuous] class EpochCoordinator(
+ writer: ContinuousWriter,
+ reader: ContinuousReader,
+ query: ContinuousExecution,
+ startEpoch: Long,
+ session: SparkSession,
+ override val rpcEnv: RpcEnv)
+ extends ThreadSafeRpcEndpoint with Logging {
+
+ private var numReaderPartitions: Int = _
+ private var numWriterPartitions: Int = _
+
+ private var currentDriverEpoch = startEpoch
+
+ // (epoch, partition) -> message
+ private val partitionCommits =
+ mutable.Map[(Long, Int), WriterCommitMessage]()
+ // (epoch, partition) -> offset
+ private val partitionOffsets =
+ mutable.Map[(Long, Int), PartitionOffset]()
+
+ private def resolveCommitsAtEpoch(epoch: Long) = {
+ val thisEpochCommits =
+ partitionCommits.collect { case ((e, _), msg) if e == epoch => msg }
+ val nextEpochOffsets =
+ partitionOffsets.collect { case ((e, _), o) if e == epoch => o }
+
+ if (thisEpochCommits.size == numWriterPartitions &&
+ nextEpochOffsets.size == numReaderPartitions) {
+ logDebug(s"Epoch $epoch has received commits from all partitions. Committing globally.")
+ // Sequencing is important here. We must commit to the writer before recording the commit
+ // in the query, or we will end up dropping the commit if we restart in the middle.
+ writer.commit(epoch, thisEpochCommits.toArray)
+ query.commit(epoch)
+
+ // Cleanup state from before this epoch, now that we know all partitions are forever past it.
+ for (k <- partitionCommits.keys.filter { case (e, _) => e < epoch }) {
+ partitionCommits.remove(k)
+ }
+ for (k <- partitionOffsets.keys.filter { case (e, _) => e < epoch }) {
+ partitionOffsets.remove(k)
+ }
+ }
+ }
+
+ override def receive: PartialFunction[Any, Unit] = {
+ case CommitPartitionEpoch(partitionId, epoch, message) =>
+ logDebug(s"Got commit from partition $partitionId at epoch $epoch: $message")
+ if (!partitionCommits.isDefinedAt((epoch, partitionId))) {
+ partitionCommits.put((epoch, partitionId), message)
+ resolveCommitsAtEpoch(epoch)
+ }
+
+ case ReportPartitionOffset(partitionId, epoch, offset) =>
+ partitionOffsets.put((epoch, partitionId), offset)
+ val thisEpochOffsets =
+ partitionOffsets.collect { case ((e, _), o) if e == epoch => o }
+ if (thisEpochOffsets.size == numReaderPartitions) {
+ logDebug(s"Epoch $epoch has offsets reported from all partitions: $thisEpochOffsets")
+ query.addOffset(epoch, reader, thisEpochOffsets.toSeq)
+ resolveCommitsAtEpoch(epoch)
+ }
+ }
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case GetCurrentEpoch =>
+ val result = currentDriverEpoch
+ logDebug(s"Epoch $result")
+ context.reply(result)
+
+ case IncrementAndGetEpoch =>
+ currentDriverEpoch += 1
+ context.reply(currentDriverEpoch)
+
+ case SetReaderPartitions(numPartitions) =>
+ numReaderPartitions = numPartitions
+ context.reply(())
+
+ case SetWriterPartitions(numPartitions) =>
+ numWriterPartitions = numPartitions
+ context.reply(())
+ }
+}
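The commit-resolution rule in `resolveCommitsAtEpoch` is easy to model outside Spark; this sketch (hypothetical names, toy payloads) shows the condition under which an epoch becomes globally committable.

    import scala.collection.mutable

    object EpochResolveSketch extends App {
      val numWriterPartitions = 2
      val numReaderPartitions = 2
      val partitionCommits = mutable.Map[(Long, Int), String]() // (epoch, partition) -> commit message
      val partitionOffsets = mutable.Map[(Long, Int), String]() // (epoch, partition) -> end offset

      // Mirrors the condition in resolveCommitsAtEpoch: every writer partition has committed and
      // every reader partition has reported an end offset for the epoch.
      def readyToCommit(epoch: Long): Boolean =
        partitionCommits.count { case ((e, _), _) => e == epoch } == numWriterPartitions &&
          partitionOffsets.count { case ((e, _), _) => e == epoch } == numReaderPartitions

      partitionCommits((0L, 0)) = "commit-msg-0"
      partitionOffsets((0L, 0)) = "offset-0"
      println(readyToCommit(0L)) // false: partition 1 has not reported yet

      partitionCommits((0L, 1)) = "commit-msg-1"
      partitionOffsets((0L, 1)) = "offset-1"
      println(readyToCommit(0L)) // true: the epoch can be committed globally
    }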
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala
new file mode 100644
index 0000000000000..c0ed12cec25ef
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import java.util.Optional
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import org.json4s.DefaultFormats
+import org.json4s.jackson.Serialization
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair}
+import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options}
+import org.apache.spark.sql.sources.v2.reader._
+import org.apache.spark.sql.sources.v2.streaming.MicroBatchReadSupport
+import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset}
+import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType}
+import org.apache.spark.util.{ManualClock, SystemClock}
+
+/**
+ * This is a temporary register as we build out v2 migration. Microbatch read support should
+ * be implemented in the same register as v1.
+ */
+class RateSourceProviderV2 extends DataSourceV2 with MicroBatchReadSupport with DataSourceRegister {
+ override def createMicroBatchReader(
+ schema: Optional[StructType],
+ checkpointLocation: String,
+ options: DataSourceV2Options): MicroBatchReader = {
+ new RateStreamMicroBatchReader(options)
+ }
+
+ override def shortName(): String = "ratev2"
+}
+
+class RateStreamMicroBatchReader(options: DataSourceV2Options)
+ extends MicroBatchReader {
+ implicit val defaultFormats: DefaultFormats = DefaultFormats
+
+ val clock = {
+ // The option to use a manual clock is provided only for unit testing purposes.
+ if (options.get("useManualClock").orElse("false").toBoolean) new ManualClock
+ else new SystemClock
+ }
+
+ private val numPartitions =
+ options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt
+ private val rowsPerSecond =
+ options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong
+
+ // The interval (in milliseconds) between rows in each partition.
+ // e.g. if there are 4 global rows per second, and 2 partitions, each partition
+ // should output rows every (1000 * 2 / 4) = 500 ms.
+ private val msPerPartitionBetweenRows = (1000 * numPartitions) / rowsPerSecond
+
+ override def readSchema(): StructType = {
+ StructType(
+ StructField("timestamp", TimestampType, false) ::
+ StructField("value", LongType, false) :: Nil)
+ }
+
+ val creationTimeMs = clock.getTimeMillis()
+
+ private var start: RateStreamOffset = _
+ private var end: RateStreamOffset = _
+
+ override def setOffsetRange(
+ start: Optional[Offset],
+ end: Optional[Offset]): Unit = {
+ this.start = start.orElse(
+ RateStreamSourceV2.createInitialOffset(numPartitions, creationTimeMs))
+ .asInstanceOf[RateStreamOffset]
+
+ this.end = end.orElse {
+ val currentTime = clock.getTimeMillis()
+ RateStreamOffset(
+ this.start.partitionToValueAndRunTimeMs.map {
+ case startOffset @ (part, ValueRunTimeMsPair(currentVal, currentReadTime)) =>
+ // Calculate the number of rows we should advance in this partition (based on the
+ // current time), and output a corresponding offset.
+ val readInterval = currentTime - currentReadTime
+ val numNewRows = readInterval / msPerPartitionBetweenRows
+ if (numNewRows <= 0) {
+ startOffset
+ } else {
+ (part, ValueRunTimeMsPair(
+ currentVal + (numNewRows * numPartitions),
+ currentReadTime + (numNewRows * msPerPartitionBetweenRows)))
+ }
+ }
+ )
+ }.asInstanceOf[RateStreamOffset]
+ }
+
+ override def getStartOffset(): Offset = {
+ if (start == null) throw new IllegalStateException("start offset not set")
+ start
+ }
+ override def getEndOffset(): Offset = {
+ if (end == null) throw new IllegalStateException("end offset not set")
+ end
+ }
+
+ override def deserializeOffset(json: String): Offset = {
+ RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json))
+ }
+
+ override def createReadTasks(): java.util.List[ReadTask[Row]] = {
+ val startMap = start.partitionToValueAndRunTimeMs
+ val endMap = end.partitionToValueAndRunTimeMs
+ endMap.keys.toSeq.map { part =>
+ val ValueRunTimeMsPair(endVal, _) = endMap(part)
+ val ValueRunTimeMsPair(startVal, startTimeMs) = startMap(part)
+
+ val packedRows = mutable.ListBuffer[(Long, Long)]()
+ var outVal = startVal + numPartitions
+ var outTimeMs = startTimeMs
+ while (outVal <= endVal) {
+ packedRows.append((outTimeMs, outVal))
+ outVal += numPartitions
+ outTimeMs += msPerPartitionBetweenRows
+ }
+
+ RateStreamBatchTask(packedRows).asInstanceOf[ReadTask[Row]]
+ }.toList.asJava
+ }
+
+ override def commit(end: Offset): Unit = {}
+ override def stop(): Unit = {}
+}
+
+case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends ReadTask[Row] {
+ override def createDataReader(): DataReader[Row] = new RateStreamBatchReader(vals)
+}
+
+class RateStreamBatchReader(vals: Seq[(Long, Long)]) extends DataReader[Row] {
+ var currentIndex = -1
+
+ override def next(): Boolean = {
+ // Return true as long as the new index is in the seq.
+ currentIndex += 1
+ currentIndex < vals.size
+ }
+
+ override def get(): Row = {
+ Row(
+ DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromMillis(vals(currentIndex)._1)),
+ vals(currentIndex)._2)
+ }
+
+ override def close(): Unit = {}
+}
+
+object RateStreamSourceV2 {
+ val NUM_PARTITIONS = "numPartitions"
+ val ROWS_PER_SECOND = "rowsPerSecond"
+
+ private[sql] def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = {
+ RateStreamOffset(
+ Range(0, numPartitions).map { i =>
+ // Note that the starting offset is exclusive, so we have to decrement the starting value
+ // by the increment that will later be applied. The first row output in each
+ // partition will have a value equal to the partition index.
+ (i,
+ ValueRunTimeMsPair(
+ (i - numPartitions).toLong,
+ creationTimeMs))
+ }.toMap)
+ }
+}
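For reference, a minimal usage sketch of the micro-batch rate source added above. It assumes a running `SparkSession` named `spark` and that the provider is registered so its short name resolves; the `"ratev2"` name and the `numPartitions`/`rowsPerSecond` option keys come from the code in this file.

```scala
// Hedged sketch: read from the v2 rate source registered under "ratev2".
// With numPartitions = 2 and rowsPerSecond = 10, each partition emits a row
// every (1000 * 2 / 10) = 200 ms, matching msPerPartitionBetweenRows above.
val rateStream = spark.readStream
  .format("ratev2")
  .option("numPartitions", "2")
  .option("rowsPerSecond", "10")
  .load()  // schema: (timestamp: Timestamp, value: Long)
```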
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
new file mode 100644
index 0000000000000..da7c31cf62428
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import javax.annotation.concurrent.GuardedBy
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.control.NonFatal
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update}
+import org.apache.spark.sql.execution.streaming.Sink
+import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options}
+import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport}
+import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter
+import org.apache.spark.sql.sources.v2.writer._
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.StructType
+
+/**
+ * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
+ * tests and does not provide durability.
+ */
+class MemorySinkV2 extends DataSourceV2
+ with MicroBatchWriteSupport with ContinuousWriteSupport with Logging {
+
+ override def createMicroBatchWriter(
+ queryId: String,
+ batchId: Long,
+ schema: StructType,
+ mode: OutputMode,
+ options: DataSourceV2Options): java.util.Optional[DataSourceV2Writer] = {
+ java.util.Optional.of(new MemoryWriter(this, batchId, mode))
+ }
+
+ override def createContinuousWriter(
+ queryId: String,
+ schema: StructType,
+ mode: OutputMode,
+ options: DataSourceV2Options): java.util.Optional[ContinuousWriter] = {
+ java.util.Optional.of(new ContinuousMemoryWriter(this, mode))
+ }
+
+ private case class AddedData(batchId: Long, data: Array[Row])
+
+ /** An ordered list of batches that have been written to this [[Sink]]. */
+ @GuardedBy("this")
+ private val batches = new ArrayBuffer[AddedData]()
+
+ /** Returns all rows that are stored in this [[Sink]]. */
+ def allData: Seq[Row] = synchronized {
+ batches.flatMap(_.data)
+ }
+
+ def latestBatchId: Option[Long] = synchronized {
+ batches.lastOption.map(_.batchId)
+ }
+
+ def latestBatchData: Seq[Row] = synchronized {
+ batches.lastOption.toSeq.flatten(_.data)
+ }
+
+ def toDebugString: String = synchronized {
+ batches.map { case AddedData(batchId, data) =>
+ val dataStr = try data.mkString(" ") catch {
+ case NonFatal(e) => "[Error converting to string]"
+ }
+ s"$batchId: $dataStr"
+ }.mkString("\n")
+ }
+
+ def write(batchId: Long, outputMode: OutputMode, newRows: Array[Row]): Unit = {
+ val notCommitted = synchronized {
+ latestBatchId.isEmpty || batchId > latestBatchId.get
+ }
+ if (notCommitted) {
+ logDebug(s"Committing batch $batchId to $this")
+ outputMode match {
+ case Append | Update =>
+ val rows = AddedData(batchId, newRows)
+ synchronized { batches += rows }
+
+ case Complete =>
+ val rows = AddedData(batchId, newRows)
+ synchronized {
+ batches.clear()
+ batches += rows
+ }
+
+ case _ =>
+ throw new IllegalArgumentException(
+ s"Output mode $outputMode is not supported by MemorySink")
+ }
+ } else {
+ logDebug(s"Skipping already committed batch: $batchId")
+ }
+ }
+
+ def clear(): Unit = synchronized {
+ batches.clear()
+ }
+
+ override def toString(): String = "MemorySinkV2"
+}
+
+case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {}
+
+class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode)
+ extends DataSourceV2Writer with Logging {
+
+ override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode)
+
+ def commit(messages: Array[WriterCommitMessage]): Unit = {
+ val newRows = messages.flatMap {
+ case message: MemoryWriterCommitMessage => message.data
+ }
+ sink.write(batchId, outputMode, newRows)
+ }
+
+ override def abort(messages: Array[WriterCommitMessage]): Unit = {
+ // Don't accept any of the new input.
+ }
+}
+
+class ContinuousMemoryWriter(val sink: MemorySinkV2, outputMode: OutputMode)
+ extends ContinuousWriter {
+
+ override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode)
+
+ override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
+ val newRows = messages.flatMap {
+ case message: MemoryWriterCommitMessage => message.data
+ }
+ sink.write(epochId, outputMode, newRows)
+ }
+
+ override def abort(messages: Array[WriterCommitMessage]): Unit = {
+ // Don't accept any of the new input.
+ }
+}
+
+case class MemoryWriterFactory(outputMode: OutputMode) extends DataWriterFactory[Row] {
+ def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = {
+ new MemoryDataWriter(partitionId, outputMode)
+ }
+}
+
+class MemoryDataWriter(partition: Int, outputMode: OutputMode)
+ extends DataWriter[Row] with Logging {
+
+ private val data = mutable.Buffer[Row]()
+
+ override def write(row: Row): Unit = {
+ data.append(row)
+ }
+
+ override def commit(): MemoryWriterCommitMessage = {
+ val msg = MemoryWriterCommitMessage(partition, data.clone())
+ data.clear()
+ msg
+ }
+
+ override def abort(): Unit = {}
+}
+
+
+/**
+ * Used to query the data that has been written into a [[MemorySinkV2]].
+ */
+case class MemoryPlanV2(sink: MemorySinkV2, override val output: Seq[Attribute]) extends LeafNode {
+ private val sizePerRow = output.map(_.dataType.defaultSize).sum
+
+ override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size)
+}
+
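A small sketch of the commit semantics in `MemorySinkV2.write` above: `Append`/`Update` accumulate batches, while `Complete` replaces everything previously stored. Only the classes from this file plus Spark's `Row` and `OutputMode` APIs are used; the row contents are illustrative.

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.streaming.OutputMode

val sink = new MemorySinkV2
sink.write(0, OutputMode.Append(), Array(Row("a")))
sink.write(1, OutputMode.Append(), Array(Row("b")))
assert(sink.allData == Seq(Row("a"), Row("b")))

// Complete mode clears earlier batches before storing the new one.
sink.write(2, OutputMode.Complete(), Array(Row("c")))
assert(sink.allData == Seq(Row("c")))
```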
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala
deleted file mode 100644
index e49546830286b..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala
+++ /dev/null
@@ -1,153 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming.state
-
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, CaseWhen, CreateNamedStruct, GetStructField, IsNull, Literal, UnsafeRow}
-import org.apache.spark.sql.execution.ObjectOperator
-import org.apache.spark.sql.execution.streaming.GroupStateImpl
-import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP
-import org.apache.spark.sql.types.{IntegerType, LongType, StructType}
-
-
-/**
- * Class to serialize/write/read/deserialize state for
- * [[org.apache.spark.sql.execution.streaming.FlatMapGroupsWithStateExec]].
- */
-class FlatMapGroupsWithState_StateManager(
- stateEncoder: ExpressionEncoder[Any],
- shouldStoreTimestamp: Boolean) extends Serializable {
-
- /** Schema of the state rows saved in the state store */
- val stateSchema = {
- val schema = new StructType().add("groupState", stateEncoder.schema, nullable = true)
- if (shouldStoreTimestamp) schema.add("timeoutTimestamp", LongType) else schema
- }
-
- /** Get deserialized state and corresponding timeout timestamp for a key */
- def getState(store: StateStore, keyRow: UnsafeRow): FlatMapGroupsWithState_StateData = {
- val stateRow = store.get(keyRow)
- stateDataForGets.withNew(
- keyRow, stateRow, getStateObj(stateRow), getTimestamp(stateRow))
- }
-
- /** Put state and timeout timestamp for a key */
- def putState(store: StateStore, keyRow: UnsafeRow, state: Any, timestamp: Long): Unit = {
- val stateRow = getStateRow(state)
- setTimestamp(stateRow, timestamp)
- store.put(keyRow, stateRow)
- }
-
- /** Removed all information related to a key */
- def removeState(store: StateStore, keyRow: UnsafeRow): Unit = {
- store.remove(keyRow)
- }
-
- /** Get all the keys and corresponding state rows in the state store */
- def getAllState(store: StateStore): Iterator[FlatMapGroupsWithState_StateData] = {
- val stateDataForGetAllState = FlatMapGroupsWithState_StateData()
- store.getRange(None, None).map { pair =>
- stateDataForGetAllState.withNew(
- pair.key, pair.value, getStateObjFromRow(pair.value), getTimestamp(pair.value))
- }
- }
-
- // Ordinals of the information stored in the state row
- private lazy val nestedStateOrdinal = 0
- private lazy val timeoutTimestampOrdinal = 1
-
- // Get the serializer for the state, taking into account whether we need to save timestamps
- private val stateSerializer = {
- val nestedStateExpr = CreateNamedStruct(
- stateEncoder.namedExpressions.flatMap(e => Seq(Literal(e.name), e)))
- if (shouldStoreTimestamp) {
- Seq(nestedStateExpr, Literal(GroupStateImpl.NO_TIMESTAMP))
- } else {
- Seq(nestedStateExpr)
- }
- }
-
- // Get the deserializer for the state. Note that this must be done in the driver, as
- // resolving and binding of deserializer expressions to the encoded type can be safely done
- // only in the driver.
- private val stateDeserializer = {
- val boundRefToNestedState = BoundReference(nestedStateOrdinal, stateEncoder.schema, true)
- val deser = stateEncoder.resolveAndBind().deserializer.transformUp {
- case BoundReference(ordinal, _, _) => GetStructField(boundRefToNestedState, ordinal)
- }
- CaseWhen(Seq(IsNull(boundRefToNestedState) -> Literal(null)), elseValue = deser)
- }
-
- // Converters for translating state between rows and Java objects
- private lazy val getStateObjFromRow = ObjectOperator.deserializeRowToObject(
- stateDeserializer, stateSchema.toAttributes)
- private lazy val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer)
-
- // Reusable instance for returning state information
- private lazy val stateDataForGets = FlatMapGroupsWithState_StateData()
-
- /** Returns the state as Java object if defined */
- private def getStateObj(stateRow: UnsafeRow): Any = {
- if (stateRow == null) null
- else getStateObjFromRow(stateRow)
- }
-
- /** Returns the row for an updated state */
- private def getStateRow(obj: Any): UnsafeRow = {
- val row = getStateRowFromObj(obj)
- if (obj == null) {
- row.setNullAt(nestedStateOrdinal)
- }
- row
- }
-
- /** Returns the timeout timestamp of a state row is set */
- private def getTimestamp(stateRow: UnsafeRow): Long = {
- if (shouldStoreTimestamp && stateRow != null) {
- stateRow.getLong(timeoutTimestampOrdinal)
- } else NO_TIMESTAMP
- }
-
- /** Set the timestamp in a state row */
- private def setTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = {
- if (shouldStoreTimestamp) stateRow.setLong(timeoutTimestampOrdinal, timeoutTimestamps)
- }
-}
-
-/**
- * Class to capture deserialized state and timestamp return by the state manager.
- * This is intended for reuse.
- */
-case class FlatMapGroupsWithState_StateData(
- var keyRow: UnsafeRow = null,
- var stateRow: UnsafeRow = null,
- var stateObj: Any = null,
- var timeoutTimestamp: Long = -1) {
- def withNew(
- newKeyRow: UnsafeRow,
- newStateRow: UnsafeRow,
- newStateObj: Any,
- newTimeout: Long): this.type = {
- keyRow = newKeyRow
- stateRow = newStateRow
- stateObj = newStateObj
- timeoutTimestamp = newTimeout
- this
- }
-}
-
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
index 43cec4807ae4d..73a105266e1c1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
@@ -27,17 +27,14 @@ import org.apache.spark.internal.Logging
import org.apache.spark.scheduler._
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.metric._
-import org.apache.spark.status.LiveEntity
+import org.apache.spark.sql.internal.StaticSQLConf._
+import org.apache.spark.status.{ElementTrackingStore, KVUtils, LiveEntity}
import org.apache.spark.status.config._
-import org.apache.spark.ui.SparkUI
-import org.apache.spark.util.kvstore.KVStore
-private[sql] class SQLAppStatusListener(
+class SQLAppStatusListener(
conf: SparkConf,
- kvstore: KVStore,
- live: Boolean,
- ui: Option[SparkUI] = None)
- extends SparkListener with Logging {
+ kvstore: ElementTrackingStore,
+ live: Boolean) extends SparkListener with Logging {
// How often to flush intermediate state of a live execution to the store. When replaying logs,
// never flush (only do the very last write).
@@ -49,7 +46,27 @@ private[sql] class SQLAppStatusListener(
private val liveExecutions = new ConcurrentHashMap[Long, LiveExecutionData]()
private val stageMetrics = new ConcurrentHashMap[Int, LiveStageMetrics]()
- private var uiInitialized = false
+ // Returns true if this listener has no live data. Exposed for tests only.
+ private[sql] def noLiveData(): Boolean = {
+ liveExecutions.isEmpty && stageMetrics.isEmpty
+ }
+
+ kvstore.addTrigger(classOf[SQLExecutionUIData], conf.get(UI_RETAINED_EXECUTIONS)) { count =>
+ cleanupExecutions(count)
+ }
+
+ kvstore.onFlush {
+ if (!live) {
+ val now = System.nanoTime()
+ liveExecutions.values.asScala.foreach { exec =>
+ // This saves the partial aggregated metrics to the store; this works currently because
+ // when the SHS sees an updated event log, all old data for the application is thrown
+ // away.
+ exec.metricsValues = aggregateMetrics(exec)
+ exec.write(kvstore, now)
+ }
+ }
+ }
override def onJobStart(event: SparkListenerJobStart): Unit = {
val executionIdString = event.properties.getProperty(SQLExecution.EXECUTION_ID_KEY)
@@ -70,7 +87,7 @@ private[sql] class SQLAppStatusListener(
}
exec.jobs = exec.jobs + (jobId -> JobExecutionStatus.RUNNING)
- exec.stages = event.stageIds.toSet
+ exec.stages ++= event.stageIds.toSet
update(exec)
}
@@ -82,7 +99,7 @@ private[sql] class SQLAppStatusListener(
// Reset the metrics tracking object for the new attempt.
Option(stageMetrics.get(event.stageInfo.stageId)).foreach { metrics =>
metrics.taskMetrics.clear()
- metrics.attemptId = event.stageInfo.attemptId
+ metrics.attemptId = event.stageInfo.attemptNumber
}
}
@@ -158,7 +175,7 @@ private[sql] class SQLAppStatusListener(
// Check the execution again for whether the aggregated metrics data has been calculated.
// This can happen if the UI is requesting this data, and the onExecutionEnd handler is
- // running at the same time. The metrics calculcated for the UI can be innacurate in that
+ // running at the same time. The metrics calculated for the UI can be inaccurate in that
// case, since the onExecutionEnd handler will clean up tracked stage metrics.
if (exec.metricsValues != null) {
exec.metricsValues
@@ -212,14 +229,6 @@ private[sql] class SQLAppStatusListener(
}
private def onExecutionStart(event: SparkListenerSQLExecutionStart): Unit = {
- // Install the SQL tab in a live app if it hasn't been initialized yet.
- if (!uiInitialized) {
- ui.foreach { _ui =>
- new SQLTab(new SQLAppStatusStore(kvstore, Some(this)), _ui)
- }
- uiInitialized = true
- }
-
val SparkListenerSQLExecutionStart(executionId, description, details,
physicalPlanDescription, sparkPlanInfo, time) = event
@@ -317,6 +326,17 @@ private[sql] class SQLAppStatusListener(
}
}
+ private def cleanupExecutions(count: Long): Unit = {
+ val countToDelete = count - conf.get(UI_RETAINED_EXECUTIONS)
+ if (countToDelete <= 0) {
+ return
+ }
+
+ val toDelete = KVUtils.viewToSeq(kvstore.view(classOf[SQLExecutionUIData]),
+ countToDelete.toInt) { e => e.completionTime.isDefined }
+ toDelete.foreach { e => kvstore.delete(e.getClass(), e.executionId) }
+ }
+
}
private class LiveExecutionData(val executionId: Long) extends LiveEntity {
@@ -360,7 +380,7 @@ private class LiveStageMetrics(
val accumulatorIds: Array[Long],
val taskMetrics: ConcurrentHashMap[Long, LiveTaskMetrics])
-private[sql] class LiveTaskMetrics(
+private class LiveTaskMetrics(
val ids: Array[Long],
val values: Array[Long],
val succeeded: Boolean)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala
index 586d3ae411c74..910f2e52fdbb3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala
@@ -25,21 +25,17 @@ import scala.collection.mutable.ArrayBuffer
import com.fasterxml.jackson.databind.annotation.JsonDeserialize
-import org.apache.spark.{JobExecutionStatus, SparkConf}
-import org.apache.spark.scheduler.SparkListener
-import org.apache.spark.status.AppStatusPlugin
+import org.apache.spark.JobExecutionStatus
import org.apache.spark.status.KVUtils.KVIndexParam
-import org.apache.spark.ui.SparkUI
-import org.apache.spark.util.Utils
import org.apache.spark.util.kvstore.KVStore
/**
* Provides a view of a KVStore with methods that make it easy to query SQL-specific state. There's
* no state kept in this class, so it's ok to have multiple instances of it in an application.
*/
-private[sql] class SQLAppStatusStore(
+class SQLAppStatusStore(
store: KVStore,
- listener: Option[SQLAppStatusListener] = None) {
+ val listener: Option[SQLAppStatusListener] = None) {
def executionsList(): Seq[SQLExecutionUIData] = {
store.view(classOf[SQLExecutionUIData]).asScala.toSeq
@@ -74,47 +70,9 @@ private[sql] class SQLAppStatusStore(
def planGraph(executionId: Long): SparkPlanGraph = {
store.read(classOf[SparkPlanGraphWrapper], executionId).toSparkPlanGraph()
}
-
-}
-
-/**
- * An AppStatusPlugin for handling the SQL UI and listeners.
- */
-private[sql] class SQLAppStatusPlugin extends AppStatusPlugin {
-
- override def setupListeners(
- conf: SparkConf,
- store: KVStore,
- addListenerFn: SparkListener => Unit,
- live: Boolean): Unit = {
- // For live applications, the listener is installed in [[setupUI]]. This also avoids adding
- // the listener when the UI is disabled. Force installation during testing, though.
- if (!live || Utils.isTesting) {
- val listener = new SQLAppStatusListener(conf, store, live, None)
- addListenerFn(listener)
- }
- }
-
- override def setupUI(ui: SparkUI): Unit = {
- ui.sc match {
- case Some(sc) =>
- // If this is a live application, then install a listener that will enable the SQL
- // tab as soon as there's a SQL event posted to the bus.
- val listener = new SQLAppStatusListener(sc.conf, ui.store.store, true, Some(ui))
- sc.listenerBus.addToStatusQueue(listener)
-
- case _ =>
- // For a replayed application, only add the tab if the store already contains SQL data.
- val sqlStore = new SQLAppStatusStore(ui.store.store)
- if (sqlStore.executionsCount() > 0) {
- new SQLTab(sqlStore, ui)
- }
- }
- }
-
}
-private[sql] class SQLExecutionUIData(
+class SQLExecutionUIData(
@KVIndexParam val executionId: Long,
val description: String,
val details: String,
@@ -132,10 +90,9 @@ private[sql] class SQLExecutionUIData(
* from the SQL listener instance.
*/
@JsonDeserialize(keyAs = classOf[JLong])
- val metricValues: Map[Long, String]
- )
+ val metricValues: Map[Long, String])
-private[sql] class SparkPlanGraphWrapper(
+class SparkPlanGraphWrapper(
@KVIndexParam val executionId: Long,
val nodes: Seq[SparkPlanGraphNodeWrapper],
val edges: Seq[SparkPlanGraphEdge]) {
@@ -146,7 +103,7 @@ private[sql] class SparkPlanGraphWrapper(
}
-private[sql] class SparkPlanGraphClusterWrapper(
+class SparkPlanGraphClusterWrapper(
val id: Long,
val name: String,
val desc: String,
@@ -162,7 +119,7 @@ private[sql] class SparkPlanGraphClusterWrapper(
}
/** Only one of the values should be set. */
-private[sql] class SparkPlanGraphNodeWrapper(
+class SparkPlanGraphNodeWrapper(
val node: SparkPlanGraphNode,
val cluster: SparkPlanGraphClusterWrapper) {
@@ -173,7 +130,7 @@ private[sql] class SparkPlanGraphNodeWrapper(
}
-private[sql] case class SQLPlanMetric(
+case class SQLPlanMetric(
name: String,
accumulatorId: Long,
metricType: String)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLHistoryServerPlugin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLHistoryServerPlugin.scala
new file mode 100644
index 0000000000000..522d0cf79bffa
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLHistoryServerPlugin.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.ui
+
+import org.apache.spark.SparkConf
+import org.apache.spark.scheduler.SparkListener
+import org.apache.spark.status.{AppHistoryServerPlugin, ElementTrackingStore}
+import org.apache.spark.ui.SparkUI
+
+class SQLHistoryServerPlugin extends AppHistoryServerPlugin {
+ override def createListeners(conf: SparkConf, store: ElementTrackingStore): Seq[SparkListener] = {
+ Seq(new SQLAppStatusListener(conf, store, live = false))
+ }
+
+ override def setupUI(ui: SparkUI): Unit = {
+ val sqlStatusStore = new SQLAppStatusStore(ui.store.store)
+ if (sqlStatusStore.executionsCount() > 0) {
+ new SQLTab(sqlStatusStore, ui)
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 058c38c8cb8f4..1e076207bc607 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -86,7 +86,7 @@ abstract class Aggregator[-IN, BUF, OUT] extends Serializable {
def bufferEncoder: Encoder[BUF]
/**
- * Specifies the `Encoder` for the final ouput value type.
+ * Specifies the `Encoder` for the final output value type.
* @since 2.0.0
*/
def outputEncoder: Encoder[OUT]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index 03b654f830520..40a058d2cadd2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -66,6 +66,7 @@ case class UserDefinedFunction protected[sql] (
*
* @since 1.3.0
*/
+ @scala.annotation.varargs
def apply(exprs: Column*): Column = {
Column(ScalaUDF(
f,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 3e4659b9eae60..0d11682d80a3c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -24,6 +24,7 @@ import scala.util.Try
import scala.util.control.NonFatal
import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.sql.api.java._
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -32,7 +33,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, ResolvedHint}
import org.apache.spark.sql.execution.SparkSqlParser
import org.apache.spark.sql.expressions.UserDefinedFunction
-import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -2171,7 +2171,8 @@ object functions {
def base64(e: Column): Column = withExpr { Base64(e.expr) }
/**
- * Concatenates multiple input string columns together into a single string column.
+ * Concatenates multiple input columns together into a single column.
+ * If all inputs are binary, concat returns the output as binary. Otherwise, it returns it as a string.
*
* @group string_funcs
* @since 1.5.0
@@ -2797,6 +2798,21 @@ object functions {
TruncDate(date.expr, Literal(format))
}
+ /**
+ * Returns timestamp truncated to the unit specified by the format.
+ *
+ * @param format: 'year', 'yyyy', 'yy' for truncate by year,
+ * 'month', 'mon', 'mm' for truncate by month,
+ * 'day', 'dd' for truncate by day,
+ * Other options are: 'second', 'minute', 'hour', 'week', 'quarter'
+ *
+ * @group datetime_funcs
+ * @since 2.3.0
+ */
+ def date_trunc(format: String, timestamp: Column): Column = withExpr {
+ TruncTimestamp(Literal(format), timestamp.expr)
+ }
+
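A hedged usage sketch for the `date_trunc` function added above; the DataFrame `df` and its timestamp column `ts` are illustrative.

```scala
import org.apache.spark.sql.functions.{col, date_trunc}

// Truncate a timestamp column down to the start of its hour.
val hourly = df.select(date_trunc("hour", col("ts")).as("ts_hour"))
```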
/**
* Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders
* that time as a timestamp in the given time zone. For example, 'GMT+1' would yield
@@ -3238,42 +3254,66 @@ object functions {
*/
def map_values(e: Column): Column = withExpr { MapValues(e.expr) }
- //////////////////////////////////////////////////////////////////////////////////////////////
- //////////////////////////////////////////////////////////////////////////////////////////////
-
// scalastyle:off line.size.limit
// scalastyle:off parameter.number
/* Use the following code to generate:
- (0 to 10).map { x =>
+
+ (0 to 10).foreach { x =>
val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"})
val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _)
val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor(typeTag[A$i]).dataType :: $s"})
println(s"""
- /**
- * Defines a deterministic user-defined function of ${x} arguments as user-defined
- * function (UDF). The data types are automatically inferred based on the function's
- * signature. To change a UDF to nondeterministic, call the API
- * `UserDefinedFunction.asNondeterministic()`.
- *
- * @group udf_funcs
- * @since 1.3.0
- */
- def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = {
- val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
- val inputTypes = Try($inputTypes).toOption
- val udf = UserDefinedFunction(f, dataType, inputTypes)
- if (nullable) udf else udf.asNonNullable()
- }""")
+ |/**
+ | * Defines a Scala closure of $x arguments as user-defined function (UDF).
+ | * The data types are automatically inferred based on the Scala closure's
+ | * signature. By default the returned UDF is deterministic. To change it to
+ | * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
+ | *
+ | * @group udf_funcs
+ | * @since 1.3.0
+ | */
+ |def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = {
+ | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
+ | val inputTypes = Try($inputTypes).toOption
+ | val udf = UserDefinedFunction(f, dataType, inputTypes)
+ | if (nullable) udf else udf.asNonNullable()
+ |}""".stripMargin)
+ }
+
+ (0 to 10).foreach { i =>
+ val extTypeArgs = (0 to i).map(_ => "_").mkString(", ")
+ val anyTypeArgs = (0 to i).map(_ => "Any").mkString(", ")
+ val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs]]"
+ val anyParams = (1 to i).map(_ => "_: Any").mkString(", ")
+ val funcCall = if (i == 0) "() => func" else "func"
+ println(s"""
+ |/**
+ | * Defines a Java UDF$i instance as user-defined function (UDF).
+ | * The caller must specify the output data type, and there is no automatic input type coercion.
+ | * By default the returned UDF is deterministic. To change it to nondeterministic, call the
+ | * API `UserDefinedFunction.asNondeterministic()`.
+ | *
+ | * @group udf_funcs
+ | * @since 2.3.0
+ | */
+ |def udf(f: UDF$i[$extTypeArgs], returnType: DataType): UserDefinedFunction = {
+ | val func = f$anyCast.call($anyParams)
+ | UserDefinedFunction($funcCall, returnType, inputTypes = None)
+ |}""".stripMargin)
}
*/
+ //////////////////////////////////////////////////////////////////////////////////////////////
+ // Scala UDF functions
+ //////////////////////////////////////////////////////////////////////////////////////////////
+
/**
- * Defines a deterministic user-defined function of 0 arguments as user-defined
- * function (UDF). The data types are automatically inferred based on the function's
- * signature. To change a UDF to nondeterministic, call the API
- * `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Scala closure of 0 arguments as user-defined function (UDF).
+ * The data types are automatically inferred based on the Scala closure's
+ * signature. By default the returned UDF is deterministic. To change it to
+ * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 1.3.0
@@ -3286,10 +3326,10 @@ object functions {
}
/**
- * Defines a deterministic user-defined function of 1 arguments as user-defined
- * function (UDF). The data types are automatically inferred based on the function's
- * signature. To change a UDF to nondeterministic, call the API
- * `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Scala closure of 1 arguments as user-defined function (UDF).
+ * The data types are automatically inferred based on the Scala closure's
+ * signature. By default the returned UDF is deterministic. To change it to
+ * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 1.3.0
@@ -3302,10 +3342,10 @@ object functions {
}
/**
- * Defines a deterministic user-defined function of 2 arguments as user-defined
- * function (UDF). The data types are automatically inferred based on the function's
- * signature. To change a UDF to nondeterministic, call the API
- * `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Scala closure of 2 arguments as user-defined function (UDF).
+ * The data types are automatically inferred based on the Scala closure's
+ * signature. By default the returned UDF is deterministic. To change it to
+ * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 1.3.0
@@ -3318,10 +3358,10 @@ object functions {
}
/**
- * Defines a deterministic user-defined function of 3 arguments as user-defined
- * function (UDF). The data types are automatically inferred based on the function's
- * signature. To change a UDF to nondeterministic, call the API
- * `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Scala closure of 3 arguments as user-defined function (UDF).
+ * The data types are automatically inferred based on the Scala closure's
+ * signature. By default the returned UDF is deterministic. To change it to
+ * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 1.3.0
@@ -3334,10 +3374,10 @@ object functions {
}
/**
- * Defines a deterministic user-defined function of 4 arguments as user-defined
- * function (UDF). The data types are automatically inferred based on the function's
- * signature. To change a UDF to nondeterministic, call the API
- * `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Scala closure of 4 arguments as user-defined function (UDF).
+ * The data types are automatically inferred based on the Scala closure's
+ * signature. By default the returned UDF is deterministic. To change it to
+ * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 1.3.0
@@ -3350,10 +3390,10 @@ object functions {
}
/**
- * Defines a deterministic user-defined function of 5 arguments as user-defined
- * function (UDF). The data types are automatically inferred based on the function's
- * signature. To change a UDF to nondeterministic, call the API
- * `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Scala closure of 5 arguments as user-defined function (UDF).
+ * The data types are automatically inferred based on the Scala closure's
+ * signature. By default the returned UDF is deterministic. To change it to
+ * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 1.3.0
@@ -3366,10 +3406,10 @@ object functions {
}
/**
- * Defines a deterministic user-defined function of 6 arguments as user-defined
- * function (UDF). The data types are automatically inferred based on the function's
- * signature. To change a UDF to nondeterministic, call the API
- * `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Scala closure of 6 arguments as user-defined function (UDF).
+ * The data types are automatically inferred based on the Scala closure's
+ * signature. By default the returned UDF is deterministic. To change it to
+ * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 1.3.0
@@ -3382,10 +3422,10 @@ object functions {
}
/**
- * Defines a deterministic user-defined function of 7 arguments as user-defined
- * function (UDF). The data types are automatically inferred based on the function's
- * signature. To change a UDF to nondeterministic, call the API
- * `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Scala closure of 7 arguments as user-defined function (UDF).
+ * The data types are automatically inferred based on the Scala closure's
+ * signature. By default the returned UDF is deterministic. To change it to
+ * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 1.3.0
@@ -3398,10 +3438,10 @@ object functions {
}
/**
- * Defines a deterministic user-defined function of 8 arguments as user-defined
- * function (UDF). The data types are automatically inferred based on the function's
- * signature. To change a UDF to nondeterministic, call the API
- * `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Scala closure of 8 arguments as user-defined function (UDF).
+ * The data types are automatically inferred based on the Scala closure's
+ * signature. By default the returned UDF is deterministic. To change it to
+ * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 1.3.0
@@ -3414,10 +3454,10 @@ object functions {
}
/**
- * Defines a deterministic user-defined function of 9 arguments as user-defined
- * function (UDF). The data types are automatically inferred based on the function's
- * signature. To change a UDF to nondeterministic, call the API
- * `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Scala closure of 9 arguments as user-defined function (UDF).
+ * The data types are automatically inferred based on the Scala closure's
+ * signature. By default the returned UDF is deterministic. To change it to
+ * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 1.3.0
@@ -3430,10 +3470,10 @@ object functions {
}
/**
- * Defines a deterministic user-defined function of 10 arguments as user-defined
- * function (UDF). The data types are automatically inferred based on the function's
- * signature. To change a UDF to nondeterministic, call the API
- * `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Scala closure of 10 arguments as user-defined function (UDF).
+ * The data types are automatically inferred based on the Scala closure's
+ * signature. By default the returned UDF is deterministic. To change it to
+ * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 1.3.0
@@ -3445,13 +3485,172 @@ object functions {
if (nullable) udf else udf.asNonNullable()
}
+ //////////////////////////////////////////////////////////////////////////////////////////////
+ // Java UDF functions
+ //////////////////////////////////////////////////////////////////////////////////////////////
+
+ /**
+ * Defines a Java UDF0 instance as user-defined function (UDF).
+ * The caller must specify the output data type, and there is no automatic input type coercion.
+ * By default the returned UDF is deterministic. To change it to nondeterministic, call the
+ * API `UserDefinedFunction.asNondeterministic()`.
+ *
+ * @group udf_funcs
+ * @since 2.3.0
+ */
+ def udf(f: UDF0[_], returnType: DataType): UserDefinedFunction = {
+ val func = f.asInstanceOf[UDF0[Any]].call()
+ UserDefinedFunction(() => func, returnType, inputTypes = None)
+ }
+
+ /**
+ * Defines a Java UDF1 instance as user-defined function (UDF).
+ * The caller must specify the output data type, and there is no automatic input type coercion.
+ * By default the returned UDF is deterministic. To change it to nondeterministic, call the
+ * API `UserDefinedFunction.asNondeterministic()`.
+ *
+ * @group udf_funcs
+ * @since 2.3.0
+ */
+ def udf(f: UDF1[_, _], returnType: DataType): UserDefinedFunction = {
+ val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any)
+ UserDefinedFunction(func, returnType, inputTypes = None)
+ }
+
+ /**
+ * Defines a Java UDF2 instance as user-defined function (UDF).
+ * The caller must specify the output data type, and there is no automatic input type coercion.
+ * By default the returned UDF is deterministic. To change it to nondeterministic, call the
+ * API `UserDefinedFunction.asNondeterministic()`.
+ *
+ * @group udf_funcs
+ * @since 2.3.0
+ */
+ def udf(f: UDF2[_, _, _], returnType: DataType): UserDefinedFunction = {
+ val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any)
+ UserDefinedFunction(func, returnType, inputTypes = None)
+ }
+
+ /**
+ * Defines a Java UDF3 instance as user-defined function (UDF).
+ * The caller must specify the output data type, and there is no automatic input type coercion.
+ * By default the returned UDF is deterministic. To change it to nondeterministic, call the
+ * API `UserDefinedFunction.asNondeterministic()`.
+ *
+ * @group udf_funcs
+ * @since 2.3.0
+ */
+ def udf(f: UDF3[_, _, _, _], returnType: DataType): UserDefinedFunction = {
+ val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any)
+ UserDefinedFunction(func, returnType, inputTypes = None)
+ }
+
+ /**
+ * Defines a Java UDF4 instance as user-defined function (UDF).
+ * The caller must specify the output data type, and there is no automatic input type coercion.
+ * By default the returned UDF is deterministic. To change it to nondeterministic, call the
+ * API `UserDefinedFunction.asNondeterministic()`.
+ *
+ * @group udf_funcs
+ * @since 2.3.0
+ */
+ def udf(f: UDF4[_, _, _, _, _], returnType: DataType): UserDefinedFunction = {
+ val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any)
+ UserDefinedFunction(func, returnType, inputTypes = None)
+ }
+
+ /**
+ * Defines a Java UDF5 instance as user-defined function (UDF).
+ * The caller must specify the output data type, and there is no automatic input type coercion.
+ * By default the returned UDF is deterministic. To change it to nondeterministic, call the
+ * API `UserDefinedFunction.asNondeterministic()`.
+ *
+ * @group udf_funcs
+ * @since 2.3.0
+ */
+ def udf(f: UDF5[_, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
+ val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any)
+ UserDefinedFunction(func, returnType, inputTypes = None)
+ }
+
+ /**
+ * Defines a Java UDF6 instance as user-defined function (UDF).
+ * The caller must specify the output data type, and there is no automatic input type coercion.
+ * By default the returned UDF is deterministic. To change it to nondeterministic, call the
+ * API `UserDefinedFunction.asNondeterministic()`.
+ *
+ * @group udf_funcs
+ * @since 2.3.0
+ */
+ def udf(f: UDF6[_, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
+ val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
+ UserDefinedFunction(func, returnType, inputTypes = None)
+ }
+
+ /**
+ * Defines a Java UDF7 instance as user-defined function (UDF).
+ * The caller must specify the output data type, and there is no automatic input type coercion.
+ * By default the returned UDF is deterministic. To change it to nondeterministic, call the
+ * API `UserDefinedFunction.asNondeterministic()`.
+ *
+ * @group udf_funcs
+ * @since 2.3.0
+ */
+ def udf(f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
+ val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
+ UserDefinedFunction(func, returnType, inputTypes = None)
+ }
+
+ /**
+ * Defines a Java UDF8 instance as user-defined function (UDF).
+ * The caller must specify the output data type, and there is no automatic input type coercion.
+ * By default the returned UDF is deterministic. To change it to nondeterministic, call the
+ * API `UserDefinedFunction.asNondeterministic()`.
+ *
+ * @group udf_funcs
+ * @since 2.3.0
+ */
+ def udf(f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
+ val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
+ UserDefinedFunction(func, returnType, inputTypes = None)
+ }
+
+ /**
+ * Defines a Java UDF9 instance as user-defined function (UDF).
+ * The caller must specify the output data type, and there is no automatic input type coercion.
+ * By default the returned UDF is deterministic. To change it to nondeterministic, call the
+ * API `UserDefinedFunction.asNondeterministic()`.
+ *
+ * @group udf_funcs
+ * @since 2.3.0
+ */
+ def udf(f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
+ val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
+ UserDefinedFunction(func, returnType, inputTypes = None)
+ }
+
+ /**
+ * Defines a Java UDF10 instance as user-defined function (UDF).
+ * The caller must specify the output data type, and there is no automatic input type coercion.
+ * By default the returned UDF is deterministic. To change it to nondeterministic, call the
+ * API `UserDefinedFunction.asNondeterministic()`.
+ *
+ * @group udf_funcs
+ * @since 2.3.0
+ */
+ def udf(f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
+ val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
+ UserDefinedFunction(func, returnType, inputTypes = None)
+ }
+
// scalastyle:on parameter.number
// scalastyle:on line.size.limit
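A minimal sketch of calling one of the Java UDF overloads added above from Scala. `UDF1`, `udf`, and `IntegerType` are existing Spark APIs; the DataFrame `df` and its `name` column are assumptions.

```scala
import org.apache.spark.sql.api.java.UDF1
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.IntegerType

// Wrap a Java-style UDF1; the return type must be given explicitly and no
// automatic input type coercion is performed, per the docs above.
val strLen = udf(new UDF1[String, Integer] {
  override def call(s: String): Integer = s.length
}, IntegerType)

val withLen = df.select(strLen(col("name")).as("name_length"))
```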
/**
* Defines a deterministic user-defined function (UDF) using a Scala closure. For this variant,
* the caller must specify the output data type, and there is no automatic input type coercion.
- * To change a UDF to nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
+ * By default the returned UDF is deterministic. To change it to nondeterministic, call the
+ * API `UserDefinedFunction.asNondeterministic()`.
*
* @param f A closure in Scala
* @param dataType The output data type of the UDF
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
index fdd25330c5e67..6ae307bce10c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
@@ -480,7 +480,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
if (tableMetadata.tableType == CatalogTableType.VIEW) {
// Temp or persistent views: refresh (or invalidate) any metadata/data cached
// in the plan recursively.
- table.queryExecution.analyzed.foreach(_.refresh())
+ table.queryExecution.analyzed.refresh()
} else {
// Non-temp tables: refresh the metadata cache.
sessionCatalog.refreshTable(tableIdent)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala
index b9515ec7bca2a..dac463641cfab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala
@@ -73,6 +73,7 @@ object HiveSerDe {
val key = source.toLowerCase(Locale.ROOT) match {
case s if s.startsWith("org.apache.spark.sql.parquet") => "parquet"
case s if s.startsWith("org.apache.spark.sql.orc") => "orc"
+ case s if s.startsWith("org.apache.spark.sql.hive.orc") => "orc"
case s if s.equals("orcfile") => "orc"
case s if s.equals("parquetfile") => "parquet"
case s if s.equals("avrofile") => "avro"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
index 81359cedebea9..0ff39004fa008 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
@@ -28,11 +28,12 @@ import org.apache.hadoop.fs.FsUrlStreamHandlerFactory
import org.apache.spark.{SparkConf, SparkContext, SparkException}
import org.apache.spark.internal.Logging
-import org.apache.spark.scheduler.LiveListenerBus
-import org.apache.spark.sql.{SparkSession, SQLContext}
+import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.execution.CacheManager
+import org.apache.spark.sql.execution.ui.{SQLAppStatusListener, SQLAppStatusStore, SQLTab}
import org.apache.spark.sql.internal.StaticSQLConf._
+import org.apache.spark.status.ElementTrackingStore
import org.apache.spark.util.{MutableURLClassLoader, Utils}
@@ -82,6 +83,19 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging {
*/
val cacheManager: CacheManager = new CacheManager
+ /**
+ * A status store to query SQL status/metrics of this Spark application, based on SQL-specific
+ * [[org.apache.spark.scheduler.SparkListenerEvent]]s.
+ */
+ val statusStore: SQLAppStatusStore = {
+ val kvStore = sparkContext.statusStore.store.asInstanceOf[ElementTrackingStore]
+ val listener = new SQLAppStatusListener(sparkContext.conf, kvStore, live = true)
+ sparkContext.listenerBus.addToStatusQueue(listener)
+ val statusStore = new SQLAppStatusStore(kvStore, Some(listener))
+ sparkContext.ui.foreach(new SQLTab(statusStore, _))
+ statusStore
+ }
+
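For context, a hedged sketch of how the new `statusStore` can be queried. Since `SharedState` is internal, assume code that already holds a `SQLAppStatusStore` instance (for example, Spark's own UI code); `executionsList()` and the `completionTime` field both appear elsewhere in this patch.

```scala
import org.apache.spark.sql.execution.ui.SQLAppStatusStore

// List the SQL executions that have not completed yet.
def runningExecutionIds(statusStore: SQLAppStatusStore): Seq[Long] = {
  statusStore.executionsList().filter(_.completionTime.isEmpty).map(_.executionId)
}
```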
/**
* A catalog that interacts with external systems.
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala
index f3bfea5f6bfc8..8b92c8b4f56b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.types.{DataType, MetadataBuilder}
/**
* AggregatedDialect can unify multiple dialects into one virtual Dialect.
* Dialects are tried in order, and the first dialect that does not return a
- * neutral element will will.
+ * neutral element will win.
*
* @param dialects List of dialects.
*/
@@ -63,4 +63,8 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect
case _ => Some(false)
}
}
+
+ override def getTruncateQuery(table: String): String = {
+ dialects.head.getTruncateQuery(table)
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 7c38ed68c0413..83d87a11810c1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -116,6 +116,18 @@ abstract class JdbcDialect extends Serializable {
s"SELECT * FROM $table WHERE 1=0"
}
+ /**
+ * The SQL query that should be used to truncate a table. Dialects can override this method to
+ * return a query that is suitable for a particular database. For PostgreSQL, for instance,
+ * a different query is used to prevent "TRUNCATE" affecting other tables.
+ * @param table The name of the table.
+ * @return The SQL query to use for truncating a table
+ */
+ @Since("2.3.0")
+ def getTruncateQuery(table: String): String = {
+ s"TRUNCATE TABLE $table"
+ }
+
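A hedged sketch of a third-party dialect overriding the new `getTruncateQuery` hook. The dialect name and JDBC URL prefix are made up; `JdbcDialect` and `JdbcDialects.registerDialect` are existing extension points.

```scala
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}

// Illustrative custom dialect: issue a cascading truncate for a hypothetical
// "mydb" database instead of the default "TRUNCATE TABLE $table".
object MyDbDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")
  override def getTruncateQuery(table: String): String = s"TRUNCATE TABLE $table CASCADE"
}

JdbcDialects.registerDialect(MyDbDialect)
```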
/**
* Override connection specific properties to run before a select is made. This is in place to
* allow dialects that need special treatment to optimize behavior.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
index e3f106c41c7ff..6ef77f24460be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
@@ -18,7 +18,10 @@
package org.apache.spark.sql.jdbc
import java.sql.{Date, Timestamp, Types}
+import java.util.TimeZone
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -29,6 +32,13 @@ private case object OracleDialect extends JdbcDialect {
override def canHandle(url: String): Boolean = url.startsWith("jdbc:oracle")
+ private def supportTimeZoneTypes: Boolean = {
+ val timeZone = DateTimeUtils.getTimeZone(SQLConf.get.sessionLocalTimeZone)
+ // TODO: support timezone types when users are not using the JVM timezone, which
+ // is the default value of SESSION_LOCAL_TIMEZONE
+ timeZone == TimeZone.getDefault
+ }
+
override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
sqlType match {
@@ -49,7 +59,8 @@ private case object OracleDialect extends JdbcDialect {
case _ if scale == -127L => Option(DecimalType(DecimalType.MAX_PRECISION, 10))
case _ => None
}
- case TIMESTAMPTZ => Some(TimestampType) // Value for Timestamp with Time Zone in Oracle
+ case TIMESTAMPTZ if supportTimeZoneTypes
+ => Some(TimestampType) // Value for Timestamp with Time Zone in Oracle
case BINARY_FLOAT => Some(FloatType) // Value for OracleTypes.BINARY_FLOAT
case BINARY_DOUBLE => Some(DoubleType) // Value for OracleTypes.BINARY_DOUBLE
case _ => None
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index 4f61a328f47ca..13a2035f4d0c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -85,6 +85,17 @@ private object PostgresDialect extends JdbcDialect {
s"SELECT 1 FROM $table LIMIT 1"
}
+ /**
+ * The SQL query used to truncate a table. For Postgres, the default behaviour is to
+ * also truncate any descendant tables. As this is a (possibly unwanted) side-effect,
+ * the Postgres dialect adds 'ONLY' to truncate only the table in question.
+ * @param table The name of the table.
+ * @return The SQL query to use for truncating a table
+ */
+ override def getTruncateQuery(table: String): String = {
+ s"TRUNCATE TABLE ONLY $table"
+ }
+
override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = {
super.beforeFetch(connection, properties)
@@ -97,8 +108,7 @@ private object PostgresDialect extends JdbcDialect {
if (properties.getOrElse(JDBCOptions.JDBC_BATCH_FETCH_SIZE, "0").toInt > 0) {
connection.setAutoCommit(false)
}
-
}
- override def isCascadingTruncateTable(): Option[Boolean] = Some(true)
+ override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index a42e28053a96a..52f2e2639cd86 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.streaming
-import java.util.Locale
+import java.util.{Locale, Optional}
import scala.collection.JavaConverters._
@@ -26,8 +26,12 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.DataSource
-import org.apache.spark.sql.execution.streaming.StreamingRelation
+import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2}
+import org.apache.spark.sql.sources.StreamSourceProvider
+import org.apache.spark.sql.sources.v2.DataSourceV2Options
+import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport}
import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
/**
* Interface used to load a streaming `Dataset` from external storage systems (e.g. file systems,
@@ -153,13 +157,45 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
"read files of Hive data source directly.")
}
- val dataSource =
- DataSource(
- sparkSession,
- userSpecifiedSchema = userSpecifiedSchema,
- className = source,
- options = extraOptions.toMap)
- Dataset.ofRows(sparkSession, StreamingRelation(dataSource))
+ val ds = DataSource.lookupDataSource(source, sparkSession.sqlContext.conf).newInstance()
+ val options = new DataSourceV2Options(extraOptions.asJava)
+ // We need to generate the V1 data source so we can pass it to the V2 relation as a shim.
+ // We can't be sure at this point whether we'll actually want to use V2, since we don't know the
+ // writer or whether the query is continuous.
+ val v1DataSource = DataSource(
+ sparkSession,
+ userSpecifiedSchema = userSpecifiedSchema,
+ className = source,
+ options = extraOptions.toMap)
+ val v1Relation = ds match {
+ case _: StreamSourceProvider => Some(StreamingRelation(v1DataSource))
+ case _ => None
+ }
+ ds match {
+ case s: MicroBatchReadSupport =>
+ val tempReader = s.createMicroBatchReader(
+ Optional.ofNullable(userSpecifiedSchema.orNull),
+ Utils.createTempDir(namePrefix = "temporaryReader").getCanonicalPath,
+ options)
+ Dataset.ofRows(
+ sparkSession,
+ StreamingRelationV2(
+ s, source, extraOptions.toMap,
+ tempReader.readSchema().toAttributes, v1Relation)(sparkSession))
+ case s: ContinuousReadSupport =>
+ val tempReader = s.createContinuousReader(
+ Optional.ofNullable(userSpecifiedSchema.orNull),
+ Utils.createTempDir(namePrefix = "temporaryReader").getCanonicalPath,
+ options)
+ Dataset.ofRows(
+ sparkSession,
+ StreamingRelationV2(
+ s, source, extraOptions.toMap,
+ tempReader.readSchema().toAttributes, v1Relation)(sparkSession))
+ case _ =>
+ // Code path for data source v1.
+ Dataset.ofRows(sparkSession, StreamingRelation(v1DataSource))
+ }
}
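A hedged usage sketch of how this dispatch looks from the API surface (the source name and option below are only examples): providers implementing MicroBatchReadSupport or ContinuousReadSupport are planned through StreamingRelationV2, with the V1 relation carried along as a shim when the provider is also a StreamSourceProvider, while everything else stays on the plain V1 StreamingRelation path.

// The caller's code is the same either way; the resolved provider decides the branch.
val stream = spark.readStream
  .format("rate")                 // example: the built-in rate source
  .option("rowsPerSecond", "10")
  .load()                         // V2-capable providers -> StreamingRelationV2, otherwise V1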
/**
@@ -239,17 +275,20 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
*
* `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be
* considered in every trigger.
- * `sep` (default `,`): sets the single character as a separator for each
+ * `sep` (default `,`): sets a single character as a separator for each
* field and value.
* `encoding` (default `UTF-8`): decodes the CSV files by the given encoding
* type.
- * `quote` (default `"`): sets the single character used for escaping quoted values where
+ * `quote` (default `"`): sets a single character used for escaping quoted values where
* the separator can be part of the value. If you would like to turn off quotations, you need to
set not `null` but an empty string. This behaviour is different from
* `com.databricks.spark.csv`.
- * `escape` (default `\`): sets the single character used for escaping quotes inside
+ * `escape` (default `\`): sets a single character used for escaping quotes inside
* an already quoted value.
- * `comment` (default empty string): sets the single character used for skipping lines
+ * `charToEscapeQuoteEscaping` (default `escape` or `\0`): sets a single character used for
+ * escaping the escape for the quote character. The default value is the escape character when
+ * the escape and quote characters are different, and `\0` otherwise.
+ * `comment` (default empty string): sets a single character used for skipping lines
* beginning with this character. By default, it is disabled.
* `header` (default `false`): uses the first line as names of columns.
* `inferSchema` (default `false`): infers the input schema automatically from data. It
@@ -298,6 +337,21 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
*/
def csv(path: String): DataFrame = format("csv").load(path)
+ /**
+ * Loads an ORC file stream, returning the result as a `DataFrame`.
+ *
+ * You can set the following ORC-specific option(s) for reading ORC files:
+ *
+ * `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be
+ * considered in every trigger.
+ *
+ *
+ * @since 2.3.0
+ */
+ def orc(path: String): DataFrame = {
+ format("orc").load(path)
+ }
+
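A short usage sketch for the new method; the schema and path below are placeholders, and as with other file-based streams a user-specified schema is assumed to be required.

import org.apache.spark.sql.types.StructType

val orcStream = spark.readStream
  .schema(new StructType().add("id", "long").add("name", "string"))
  .option("maxFilesPerTrigger", "1")   // cap how many new files each trigger picks up
  .orc("/path/to/orc/dir")             // placeholder directory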
/**
* Loads a Parquet file stream, returning the result as a `DataFrame`.
*
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index 0be69b98abc8a..db588ae282f38 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -26,7 +26,9 @@ import org.apache.spark.sql.{AnalysisException, Dataset, ForeachWriter}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.DataSource
-import org.apache.spark.sql.execution.streaming.{ForeachSink, MemoryPlan, MemorySink}
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
+import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2}
/**
* Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems,
@@ -240,14 +242,23 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
if (extraOptions.get("queryName").isEmpty) {
throw new AnalysisException("queryName must be specified for memory sink")
}
- val sink = new MemorySink(df.schema, outputMode)
- val resultDf = Dataset.ofRows(df.sparkSession, new MemoryPlan(sink))
+ val (sink, resultDf) = trigger match {
+ case _: ContinuousTrigger =>
+ val s = new MemorySinkV2()
+ val r = Dataset.ofRows(df.sparkSession, new MemoryPlanV2(s, df.schema.toAttributes))
+ (s, r)
+ case _ =>
+ val s = new MemorySink(df.schema, outputMode)
+ val r = Dataset.ofRows(df.sparkSession, new MemoryPlan(s))
+ (s, r)
+ }
val chkpointLoc = extraOptions.get("checkpointLocation")
val recoverFromChkpoint = outputMode == OutputMode.Complete()
val query = df.sparkSession.sessionState.streamingQueryManager.startQuery(
extraOptions.get("queryName"),
chkpointLoc,
df,
+ extraOptions.toMap,
sink,
outputMode,
useTempCheckpointLocation = true,
@@ -262,6 +273,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
extraOptions.get("queryName"),
extraOptions.get("checkpointLocation"),
df,
+ extraOptions.toMap,
sink,
outputMode,
useTempCheckpointLocation = true,
@@ -277,6 +289,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
extraOptions.get("queryName"),
extraOptions.get("checkpointLocation"),
df,
+ extraOptions.toMap,
dataSource.createSink(outputMode),
outputMode,
useTempCheckpointLocation = source == "console",
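A hedged sketch of the memory-sink branch above, where df stands for any streaming Dataset: queryName remains mandatory, and the trigger alone decides whether the query is backed by MemorySinkV2 (continuous) or the existing MemorySink (micro-batch). Trigger.Continuous is assumed to be available in this build.

import org.apache.spark.sql.streaming.Trigger

val query = df.writeStream
  .format("memory")
  .queryName("results")                      // required for the memory sink
  .trigger(Trigger.Continuous("1 second"))   // routes to MemorySinkV2; omit to keep MemorySink
  .start()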
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
index 48b0ea20e5da1..4b27e0d4ef47b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
@@ -29,8 +29,10 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker
import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger}
import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport}
import org.apache.spark.util.{Clock, SystemClock, Utils}
/**
@@ -188,7 +190,8 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
userSpecifiedName: Option[String],
userSpecifiedCheckpointLocation: Option[String],
df: DataFrame,
- sink: Sink,
+ extraOptions: Map[String, String],
+ sink: BaseStreamingSink,
outputMode: OutputMode,
useTempCheckpointLocation: Boolean,
recoverFromCheckpointLocation: Boolean,
@@ -237,16 +240,36 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
"is not supported in streaming DataFrames/Datasets and will be disabled.")
}
- new StreamingQueryWrapper(new StreamExecution(
- sparkSession,
- userSpecifiedName.orNull,
- checkpointLocation,
- analyzedPlan,
- sink,
- trigger,
- triggerClock,
- outputMode,
- deleteCheckpointOnStop))
+ (sink, trigger) match {
+ case (v2Sink: ContinuousWriteSupport, trigger: ContinuousTrigger) =>
+ UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode)
+ new StreamingQueryWrapper(new ContinuousExecution(
+ sparkSession,
+ userSpecifiedName.orNull,
+ checkpointLocation,
+ analyzedPlan,
+ v2Sink,
+ trigger,
+ triggerClock,
+ outputMode,
+ extraOptions,
+ deleteCheckpointOnStop))
+ case (_: MicroBatchWriteSupport, _) | (_: Sink, _) =>
+ new StreamingQueryWrapper(new MicroBatchExecution(
+ sparkSession,
+ userSpecifiedName.orNull,
+ checkpointLocation,
+ analyzedPlan,
+ sink,
+ trigger,
+ triggerClock,
+ outputMode,
+ extraOptions,
+ deleteCheckpointOnStop))
+ case (_: ContinuousWriteSupport, t) if !t.isInstanceOf[ContinuousTrigger] =>
+ throw new AnalysisException(
+ "Sink only supports continuous writes, but a continuous trigger was not specified.")
+ }
}
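Reading the match above: a ContinuousWriteSupport sink paired with a ContinuousTrigger becomes a ContinuousExecution, a MicroBatchWriteSupport or classic Sink becomes a MicroBatchExecution, and a continuous-only sink started without a continuous trigger is rejected. A hedged sketch of that last, failing combination (the provider name is hypothetical):

// "some.continuous.only.sink" is a made-up provider that implements only
// ContinuousWriteSupport; with the default micro-batch trigger the query now
// fails fast in startQuery with an AnalysisException rather than misbehaving later.
val attempt = scala.util.Try {
  df.writeStream
    .format("some.continuous.only.sink")
    .option("checkpointLocation", "/tmp/checkpoints")   // placeholder path
    .start()
}
assert(attempt.isFailure)   // message: the sink only supports continuous writes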
/**
@@ -269,7 +292,8 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
userSpecifiedName: Option[String],
userSpecifiedCheckpointLocation: Option[String],
df: DataFrame,
- sink: Sink,
+ extraOptions: Map[String, String],
+ sink: BaseStreamingSink,
outputMode: OutputMode,
useTempCheckpointLocation: Boolean = false,
recoverFromCheckpointLocation: Boolean = true,
@@ -279,6 +303,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
userSpecifiedName,
userSpecifiedCheckpointLocation,
df,
+ extraOptions,
sink,
outputMode,
useTempCheckpointLocation,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
index cedc1dce4a703..0dcb666e2c3e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
@@ -152,7 +152,7 @@ class StreamingQueryProgress private[sql](
* @param endOffset The ending offset for data being read.
* @param numInputRows The number of records read from this source.
* @param inputRowsPerSecond The rate at which data is arriving from this source.
- * @param processedRowsPerSecond The rate at which data from this source is being procressed by
+ * @param processedRowsPerSecond The rate at which data from this source is being processed by
* Spark.
* @since 2.1.0
*/
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index b007093dad84b..69a2904f5f3fe 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -36,6 +36,7 @@
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.test.TestSparkSession;
import org.apache.spark.sql.types.*;
import org.apache.spark.util.sketch.BloomFilter;
@@ -455,4 +456,15 @@ public void testCircularReferenceBean() {
CircularReference1Bean bean = new CircularReference1Bean();
spark.createDataFrame(Arrays.asList(bean), CircularReference1Bean.class);
}
+
+ @Test
+ public void testUDF() {
+ UserDefinedFunction foo = udf((Integer i, String s) -> i.toString() + s, DataTypes.StringType);
+ Dataset df = spark.table("testData").select(foo.apply(col("key"), col("value")));
+ String[] result = df.collectAsList().stream().map(row -> row.getString(0))
+ .toArray(String[]::new);
+ String[] expected = spark.table("testData").collectAsList().stream()
+ .map(row -> row.get(0).toString() + row.getString(1)).toArray(String[]::new);
+ Assert.assertArrayEquals(expected, result);
+ }
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java
index 447a71d284fbb..288f5e7426c05 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java
@@ -47,7 +47,7 @@ public MyDoubleAvg() {
_inputDataType = DataTypes.createStructType(inputFields);
// The buffer has two values, bufferSum for storing the current sum and
- // bufferCount for storing the number of non-null input values that have been contribuetd
+ // bufferCount for storing the number of non-null input values that have been contributed
// to the current sum.
List bufferFields = new ArrayList<>();
bufferFields.add(DataTypes.createStructField("bufferSum", DataTypes.DoubleType, true));
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
index 1e1384549a410..c5070b734d521 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
@@ -60,3 +60,12 @@ SELECT a, COUNT(1) FROM testData WHERE false GROUP BY a;
-- Aggregate with empty input and empty GroupBy expressions.
SELECT COUNT(1) FROM testData WHERE false;
SELECT 1 FROM (SELECT COUNT(1) FROM testData WHERE false) t;
+
+-- Aggregate with empty GroupBy expressions and filter on top
+SELECT 1 FROM (
+ SELECT 1 AS z,
+ MIN(a.x)
+ FROM (SELECT 1 AS x) a
+ WHERE false
+) b
+WHERE b.z != b.z
diff --git a/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql b/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql
new file mode 100644
index 0000000000000..8afa3270f4de4
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql
@@ -0,0 +1,28 @@
+CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1) AS GROUPING(a);
+CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1) AS GROUPING(a);
+
+CREATE TEMPORARY VIEW empty_table AS SELECT a FROM t2 WHERE false;
+
+SELECT * FROM t1 INNER JOIN empty_table;
+SELECT * FROM t1 CROSS JOIN empty_table;
+SELECT * FROM t1 LEFT OUTER JOIN empty_table;
+SELECT * FROM t1 RIGHT OUTER JOIN empty_table;
+SELECT * FROM t1 FULL OUTER JOIN empty_table;
+SELECT * FROM t1 LEFT SEMI JOIN empty_table;
+SELECT * FROM t1 LEFT ANTI JOIN empty_table;
+
+SELECT * FROM empty_table INNER JOIN t1;
+SELECT * FROM empty_table CROSS JOIN t1;
+SELECT * FROM empty_table LEFT OUTER JOIN t1;
+SELECT * FROM empty_table RIGHT OUTER JOIN t1;
+SELECT * FROM empty_table FULL OUTER JOIN t1;
+SELECT * FROM empty_table LEFT SEMI JOIN t1;
+SELECT * FROM empty_table LEFT ANTI JOIN t1;
+
+SELECT * FROM empty_table INNER JOIN empty_table;
+SELECT * FROM empty_table CROSS JOIN empty_table;
+SELECT * FROM empty_table LEFT OUTER JOIN empty_table;
+SELECT * FROM empty_table RIGHT OUTER JOIN empty_table;
+SELECT * FROM empty_table FULL OUTER JOIN empty_table;
+SELECT * FROM empty_table LEFT SEMI JOIN empty_table;
+SELECT * FROM empty_table LEFT ANTI JOIN empty_table;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
index 40d0c064f5c44..4113734e1707e 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
@@ -24,3 +24,26 @@ select left("abcd", 2), left("abcd", 5), left("abcd", '2'), left("abcd", null);
select left(null, -2), left("abcd", -2), left("abcd", 0), left("abcd", 'a');
select right("abcd", 2), right("abcd", 5), right("abcd", '2'), right("abcd", null);
select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a');
+
+-- turn off concatBinaryAsString
+set spark.sql.function.concatBinaryAsString=false;
+
+-- Check if Catalyst combines nested `Concat`s when concatBinaryAsString=false
+EXPLAIN SELECT ((col1 || col2) || (col3 || col4)) col
+FROM (
+ SELECT
+ string(id) col1,
+ string(id + 1) col2,
+ encode(string(id + 2), 'utf-8') col3,
+ encode(string(id + 3), 'utf-8') col4
+ FROM range(10)
+);
+
+EXPLAIN SELECT (col1 || (col3 || col4)) col
+FROM (
+ SELECT
+ string(id) col1,
+ encode(string(id + 2), 'utf-8') col3,
+ encode(string(id + 3), 'utf-8') col4
+ FROM range(10)
+)
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/binaryComparison.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/binaryComparison.sql
new file mode 100644
index 0000000000000..522322ac480be
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/binaryComparison.sql
@@ -0,0 +1,287 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+-- Binary Comparison
+
+CREATE TEMPORARY VIEW t AS SELECT 1;
+
+SELECT cast(1 as binary) = '1' FROM t;
+SELECT cast(1 as binary) > '2' FROM t;
+SELECT cast(1 as binary) >= '2' FROM t;
+SELECT cast(1 as binary) < '2' FROM t;
+SELECT cast(1 as binary) <= '2' FROM t;
+SELECT cast(1 as binary) <> '2' FROM t;
+SELECT cast(1 as binary) = cast(null as string) FROM t;
+SELECT cast(1 as binary) > cast(null as string) FROM t;
+SELECT cast(1 as binary) >= cast(null as string) FROM t;
+SELECT cast(1 as binary) < cast(null as string) FROM t;
+SELECT cast(1 as binary) <= cast(null as string) FROM t;
+SELECT cast(1 as binary) <> cast(null as string) FROM t;
+SELECT '1' = cast(1 as binary) FROM t;
+SELECT '2' > cast(1 as binary) FROM t;
+SELECT '2' >= cast(1 as binary) FROM t;
+SELECT '2' < cast(1 as binary) FROM t;
+SELECT '2' <= cast(1 as binary) FROM t;
+SELECT '2' <> cast(1 as binary) FROM t;
+SELECT cast(null as string) = cast(1 as binary) FROM t;
+SELECT cast(null as string) > cast(1 as binary) FROM t;
+SELECT cast(null as string) >= cast(1 as binary) FROM t;
+SELECT cast(null as string) < cast(1 as binary) FROM t;
+SELECT cast(null as string) <= cast(1 as binary) FROM t;
+SELECT cast(null as string) <> cast(1 as binary) FROM t;
+SELECT cast(1 as tinyint) = '1' FROM t;
+SELECT cast(1 as tinyint) > '2' FROM t;
+SELECT cast(1 as tinyint) >= '2' FROM t;
+SELECT cast(1 as tinyint) < '2' FROM t;
+SELECT cast(1 as tinyint) <= '2' FROM t;
+SELECT cast(1 as tinyint) <> '2' FROM t;
+SELECT cast(1 as tinyint) = cast(null as string) FROM t;
+SELECT cast(1 as tinyint) > cast(null as string) FROM t;
+SELECT cast(1 as tinyint) >= cast(null as string) FROM t;
+SELECT cast(1 as tinyint) < cast(null as string) FROM t;
+SELECT cast(1 as tinyint) <= cast(null as string) FROM t;
+SELECT cast(1 as tinyint) <> cast(null as string) FROM t;
+SELECT '1' = cast(1 as tinyint) FROM t;
+SELECT '2' > cast(1 as tinyint) FROM t;
+SELECT '2' >= cast(1 as tinyint) FROM t;
+SELECT '2' < cast(1 as tinyint) FROM t;
+SELECT '2' <= cast(1 as tinyint) FROM t;
+SELECT '2' <> cast(1 as tinyint) FROM t;
+SELECT cast(null as string) = cast(1 as tinyint) FROM t;
+SELECT cast(null as string) > cast(1 as tinyint) FROM t;
+SELECT cast(null as string) >= cast(1 as tinyint) FROM t;
+SELECT cast(null as string) < cast(1 as tinyint) FROM t;
+SELECT cast(null as string) <= cast(1 as tinyint) FROM t;
+SELECT cast(null as string) <> cast(1 as tinyint) FROM t;
+SELECT cast(1 as smallint) = '1' FROM t;
+SELECT cast(1 as smallint) > '2' FROM t;
+SELECT cast(1 as smallint) >= '2' FROM t;
+SELECT cast(1 as smallint) < '2' FROM t;
+SELECT cast(1 as smallint) <= '2' FROM t;
+SELECT cast(1 as smallint) <> '2' FROM t;
+SELECT cast(1 as smallint) = cast(null as string) FROM t;
+SELECT cast(1 as smallint) > cast(null as string) FROM t;
+SELECT cast(1 as smallint) >= cast(null as string) FROM t;
+SELECT cast(1 as smallint) < cast(null as string) FROM t;
+SELECT cast(1 as smallint) <= cast(null as string) FROM t;
+SELECT cast(1 as smallint) <> cast(null as string) FROM t;
+SELECT '1' = cast(1 as smallint) FROM t;
+SELECT '2' > cast(1 as smallint) FROM t;
+SELECT '2' >= cast(1 as smallint) FROM t;
+SELECT '2' < cast(1 as smallint) FROM t;
+SELECT '2' <= cast(1 as smallint) FROM t;
+SELECT '2' <> cast(1 as smallint) FROM t;
+SELECT cast(null as string) = cast(1 as smallint) FROM t;
+SELECT cast(null as string) > cast(1 as smallint) FROM t;
+SELECT cast(null as string) >= cast(1 as smallint) FROM t;
+SELECT cast(null as string) < cast(1 as smallint) FROM t;
+SELECT cast(null as string) <= cast(1 as smallint) FROM t;
+SELECT cast(null as string) <> cast(1 as smallint) FROM t;
+SELECT cast(1 as int) = '1' FROM t;
+SELECT cast(1 as int) > '2' FROM t;
+SELECT cast(1 as int) >= '2' FROM t;
+SELECT cast(1 as int) < '2' FROM t;
+SELECT cast(1 as int) <= '2' FROM t;
+SELECT cast(1 as int) <> '2' FROM t;
+SELECT cast(1 as int) = cast(null as string) FROM t;
+SELECT cast(1 as int) > cast(null as string) FROM t;
+SELECT cast(1 as int) >= cast(null as string) FROM t;
+SELECT cast(1 as int) < cast(null as string) FROM t;
+SELECT cast(1 as int) <= cast(null as string) FROM t;
+SELECT cast(1 as int) <> cast(null as string) FROM t;
+SELECT '1' = cast(1 as int) FROM t;
+SELECT '2' > cast(1 as int) FROM t;
+SELECT '2' >= cast(1 as int) FROM t;
+SELECT '2' < cast(1 as int) FROM t;
+SELECT '2' <> cast(1 as int) FROM t;
+SELECT '2' <= cast(1 as int) FROM t;
+SELECT cast(null as string) = cast(1 as int) FROM t;
+SELECT cast(null as string) > cast(1 as int) FROM t;
+SELECT cast(null as string) >= cast(1 as int) FROM t;
+SELECT cast(null as string) < cast(1 as int) FROM t;
+SELECT cast(null as string) <> cast(1 as int) FROM t;
+SELECT cast(null as string) <= cast(1 as int) FROM t;
+SELECT cast(1 as bigint) = '1' FROM t;
+SELECT cast(1 as bigint) > '2' FROM t;
+SELECT cast(1 as bigint) >= '2' FROM t;
+SELECT cast(1 as bigint) < '2' FROM t;
+SELECT cast(1 as bigint) <= '2' FROM t;
+SELECT cast(1 as bigint) <> '2' FROM t;
+SELECT cast(1 as bigint) = cast(null as string) FROM t;
+SELECT cast(1 as bigint) > cast(null as string) FROM t;
+SELECT cast(1 as bigint) >= cast(null as string) FROM t;
+SELECT cast(1 as bigint) < cast(null as string) FROM t;
+SELECT cast(1 as bigint) <= cast(null as string) FROM t;
+SELECT cast(1 as bigint) <> cast(null as string) FROM t;
+SELECT '1' = cast(1 as bigint) FROM t;
+SELECT '2' > cast(1 as bigint) FROM t;
+SELECT '2' >= cast(1 as bigint) FROM t;
+SELECT '2' < cast(1 as bigint) FROM t;
+SELECT '2' <= cast(1 as bigint) FROM t;
+SELECT '2' <> cast(1 as bigint) FROM t;
+SELECT cast(null as string) = cast(1 as bigint) FROM t;
+SELECT cast(null as string) > cast(1 as bigint) FROM t;
+SELECT cast(null as string) >= cast(1 as bigint) FROM t;
+SELECT cast(null as string) < cast(1 as bigint) FROM t;
+SELECT cast(null as string) <= cast(1 as bigint) FROM t;
+SELECT cast(null as string) <> cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(10, 0)) = '1' FROM t;
+SELECT cast(1 as decimal(10, 0)) > '2' FROM t;
+SELECT cast(1 as decimal(10, 0)) >= '2' FROM t;
+SELECT cast(1 as decimal(10, 0)) < '2' FROM t;
+SELECT cast(1 as decimal(10, 0)) <> '2' FROM t;
+SELECT cast(1 as decimal(10, 0)) <= '2' FROM t;
+SELECT cast(1 as decimal(10, 0)) = cast(null as string) FROM t;
+SELECT cast(1 as decimal(10, 0)) > cast(null as string) FROM t;
+SELECT cast(1 as decimal(10, 0)) >= cast(null as string) FROM t;
+SELECT cast(1 as decimal(10, 0)) < cast(null as string) FROM t;
+SELECT cast(1 as decimal(10, 0)) <> cast(null as string) FROM t;
+SELECT cast(1 as decimal(10, 0)) <= cast(null as string) FROM t;
+SELECT '1' = cast(1 as decimal(10, 0)) FROM t;
+SELECT '2' > cast(1 as decimal(10, 0)) FROM t;
+SELECT '2' >= cast(1 as decimal(10, 0)) FROM t;
+SELECT '2' < cast(1 as decimal(10, 0)) FROM t;
+SELECT '2' <= cast(1 as decimal(10, 0)) FROM t;
+SELECT '2' <> cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(null as string) = cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(null as string) > cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(null as string) >= cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(null as string) < cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(null as string) <= cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(null as string) <> cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as double) = '1' FROM t;
+SELECT cast(1 as double) > '2' FROM t;
+SELECT cast(1 as double) >= '2' FROM t;
+SELECT cast(1 as double) < '2' FROM t;
+SELECT cast(1 as double) <= '2' FROM t;
+SELECT cast(1 as double) <> '2' FROM t;
+SELECT cast(1 as double) = cast(null as string) FROM t;
+SELECT cast(1 as double) > cast(null as string) FROM t;
+SELECT cast(1 as double) >= cast(null as string) FROM t;
+SELECT cast(1 as double) < cast(null as string) FROM t;
+SELECT cast(1 as double) <= cast(null as string) FROM t;
+SELECT cast(1 as double) <> cast(null as string) FROM t;
+SELECT '1' = cast(1 as double) FROM t;
+SELECT '2' > cast(1 as double) FROM t;
+SELECT '2' >= cast(1 as double) FROM t;
+SELECT '2' < cast(1 as double) FROM t;
+SELECT '2' <= cast(1 as double) FROM t;
+SELECT '2' <> cast(1 as double) FROM t;
+SELECT cast(null as string) = cast(1 as double) FROM t;
+SELECT cast(null as string) > cast(1 as double) FROM t;
+SELECT cast(null as string) >= cast(1 as double) FROM t;
+SELECT cast(null as string) < cast(1 as double) FROM t;
+SELECT cast(null as string) <= cast(1 as double) FROM t;
+SELECT cast(null as string) <> cast(1 as double) FROM t;
+SELECT cast(1 as float) = '1' FROM t;
+SELECT cast(1 as float) > '2' FROM t;
+SELECT cast(1 as float) >= '2' FROM t;
+SELECT cast(1 as float) < '2' FROM t;
+SELECT cast(1 as float) <= '2' FROM t;
+SELECT cast(1 as float) <> '2' FROM t;
+SELECT cast(1 as float) = cast(null as string) FROM t;
+SELECT cast(1 as float) > cast(null as string) FROM t;
+SELECT cast(1 as float) >= cast(null as string) FROM t;
+SELECT cast(1 as float) < cast(null as string) FROM t;
+SELECT cast(1 as float) <= cast(null as string) FROM t;
+SELECT cast(1 as float) <> cast(null as string) FROM t;
+SELECT '1' = cast(1 as float) FROM t;
+SELECT '2' > cast(1 as float) FROM t;
+SELECT '2' >= cast(1 as float) FROM t;
+SELECT '2' < cast(1 as float) FROM t;
+SELECT '2' <= cast(1 as float) FROM t;
+SELECT '2' <> cast(1 as float) FROM t;
+SELECT cast(null as string) = cast(1 as float) FROM t;
+SELECT cast(null as string) > cast(1 as float) FROM t;
+SELECT cast(null as string) >= cast(1 as float) FROM t;
+SELECT cast(null as string) < cast(1 as float) FROM t;
+SELECT cast(null as string) <= cast(1 as float) FROM t;
+SELECT cast(null as string) <> cast(1 as float) FROM t;
+-- the following queries return 1 if the search condition is satisfied
+-- and return nothing if the search condition is not satisfied
+SELECT '1996-09-09' = date('1996-09-09') FROM t;
+SELECT '1996-9-10' > date('1996-09-09') FROM t;
+SELECT '1996-9-10' >= date('1996-09-09') FROM t;
+SELECT '1996-9-10' < date('1996-09-09') FROM t;
+SELECT '1996-9-10' <= date('1996-09-09') FROM t;
+SELECT '1996-9-10' <> date('1996-09-09') FROM t;
+SELECT cast(null as string) = date('1996-09-09') FROM t;
+SELECT cast(null as string) > date('1996-09-09') FROM t;
+SELECT cast(null as string) >= date('1996-09-09') FROM t;
+SELECT cast(null as string) < date('1996-09-09') FROM t;
+SELECT cast(null as string) <= date('1996-09-09') FROM t;
+SELECT cast(null as string) <> date('1996-09-09') FROM t;
+SELECT date('1996-09-09') = '1996-09-09' FROM t;
+SELECT date('1996-9-10') > '1996-09-09' FROM t;
+SELECT date('1996-9-10') >= '1996-09-09' FROM t;
+SELECT date('1996-9-10') < '1996-09-09' FROM t;
+SELECT date('1996-9-10') <= '1996-09-09' FROM t;
+SELECT date('1996-9-10') <> '1996-09-09' FROM t;
+SELECT date('1996-09-09') = cast(null as string) FROM t;
+SELECT date('1996-9-10') > cast(null as string) FROM t;
+SELECT date('1996-9-10') >= cast(null as string) FROM t;
+SELECT date('1996-9-10') < cast(null as string) FROM t;
+SELECT date('1996-9-10') <= cast(null as string) FROM t;
+SELECT date('1996-9-10') <> cast(null as string) FROM t;
+SELECT '1996-09-09 12:12:12.4' = timestamp('1996-09-09 12:12:12.4') FROM t;
+SELECT '1996-09-09 12:12:12.5' > timestamp('1996-09-09 12:12:12.4') FROM t;
+SELECT '1996-09-09 12:12:12.5' >= timestamp('1996-09-09 12:12:12.4') FROM t;
+SELECT '1996-09-09 12:12:12.5' < timestamp('1996-09-09 12:12:12.4') FROM t;
+SELECT '1996-09-09 12:12:12.5' <= timestamp('1996-09-09 12:12:12.4') FROM t;
+SELECT '1996-09-09 12:12:12.5' <> timestamp('1996-09-09 12:12:12.4') FROM t;
+SELECT cast(null as string) = timestamp('1996-09-09 12:12:12.4') FROM t;
+SELECT cast(null as string) > timestamp('1996-09-09 12:12:12.4') FROM t;
+SELECT cast(null as string) >= timestamp('1996-09-09 12:12:12.4') FROM t;
+SELECT cast(null as string) < timestamp('1996-09-09 12:12:12.4') FROM t;
+SELECT cast(null as string) <= timestamp('1996-09-09 12:12:12.4') FROM t;
+SELECT cast(null as string) <> timestamp('1996-09-09 12:12:12.4') FROM t;
+SELECT timestamp('1996-09-09 12:12:12.4') = '1996-09-09 12:12:12.4' FROM t;
+SELECT timestamp('1996-09-09 12:12:12.5') > '1996-09-09 12:12:12.4' FROM t;
+SELECT timestamp('1996-09-09 12:12:12.5') >= '1996-09-09 12:12:12.4' FROM t;
+SELECT timestamp('1996-09-09 12:12:12.5') < '1996-09-09 12:12:12.4' FROM t;
+SELECT timestamp('1996-09-09 12:12:12.5') <= '1996-09-09 12:12:12.4' FROM t;
+SELECT timestamp('1996-09-09 12:12:12.5') <> '1996-09-09 12:12:12.4' FROM t;
+SELECT timestamp('1996-09-09 12:12:12.4') = cast(null as string) FROM t;
+SELECT timestamp('1996-09-09 12:12:12.5') > cast(null as string) FROM t;
+SELECT timestamp('1996-09-09 12:12:12.5') >= cast(null as string) FROM t;
+SELECT timestamp('1996-09-09 12:12:12.5') < cast(null as string) FROM t;
+SELECT timestamp('1996-09-09 12:12:12.5') <= cast(null as string) FROM t;
+SELECT timestamp('1996-09-09 12:12:12.5') <> cast(null as string) FROM t;
+SELECT ' ' = X'0020' FROM t;
+SELECT ' ' > X'001F' FROM t;
+SELECT ' ' >= X'001F' FROM t;
+SELECT ' ' < X'001F' FROM t;
+SELECT ' ' <= X'001F' FROM t;
+SELECT ' ' <> X'001F' FROM t;
+SELECT cast(null as string) = X'0020' FROM t;
+SELECT cast(null as string) > X'001F' FROM t;
+SELECT cast(null as string) >= X'001F' FROM t;
+SELECT cast(null as string) < X'001F' FROM t;
+SELECT cast(null as string) <= X'001F' FROM t;
+SELECT cast(null as string) <> X'001F' FROM t;
+SELECT X'0020' = ' ' FROM t;
+SELECT X'001F' > ' ' FROM t;
+SELECT X'001F' >= ' ' FROM t;
+SELECT X'001F' < ' ' FROM t;
+SELECT X'001F' <= ' ' FROM t;
+SELECT X'001F' <> ' ' FROM t;
+SELECT X'0020' = cast(null as string) FROM t;
+SELECT X'001F' > cast(null as string) FROM t;
+SELECT X'001F' >= cast(null as string) FROM t;
+SELECT X'001F' < cast(null as string) FROM t;
+SELECT X'001F' <= cast(null as string) FROM t;
+SELECT X'001F' <> cast(null as string) FROM t;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/booleanEquality.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/booleanEquality.sql
new file mode 100644
index 0000000000000..442f2355f8e3a
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/booleanEquality.sql
@@ -0,0 +1,122 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+CREATE TEMPORARY VIEW t AS SELECT 1;
+
+SELECT true = cast(1 as tinyint) FROM t;
+SELECT true = cast(1 as smallint) FROM t;
+SELECT true = cast(1 as int) FROM t;
+SELECT true = cast(1 as bigint) FROM t;
+SELECT true = cast(1 as float) FROM t;
+SELECT true = cast(1 as double) FROM t;
+SELECT true = cast(1 as decimal(10, 0)) FROM t;
+SELECT true = cast(1 as string) FROM t;
+SELECT true = cast('1' as binary) FROM t;
+SELECT true = cast(1 as boolean) FROM t;
+SELECT true = cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT true = cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT true <=> cast(1 as tinyint) FROM t;
+SELECT true <=> cast(1 as smallint) FROM t;
+SELECT true <=> cast(1 as int) FROM t;
+SELECT true <=> cast(1 as bigint) FROM t;
+SELECT true <=> cast(1 as float) FROM t;
+SELECT true <=> cast(1 as double) FROM t;
+SELECT true <=> cast(1 as decimal(10, 0)) FROM t;
+SELECT true <=> cast(1 as string) FROM t;
+SELECT true <=> cast('1' as binary) FROM t;
+SELECT true <=> cast(1 as boolean) FROM t;
+SELECT true <=> cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT true <=> cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(1 as tinyint) = true FROM t;
+SELECT cast(1 as smallint) = true FROM t;
+SELECT cast(1 as int) = true FROM t;
+SELECT cast(1 as bigint) = true FROM t;
+SELECT cast(1 as float) = true FROM t;
+SELECT cast(1 as double) = true FROM t;
+SELECT cast(1 as decimal(10, 0)) = true FROM t;
+SELECT cast(1 as string) = true FROM t;
+SELECT cast('1' as binary) = true FROM t;
+SELECT cast(1 as boolean) = true FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) = true FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) = true FROM t;
+
+SELECT cast(1 as tinyint) <=> true FROM t;
+SELECT cast(1 as smallint) <=> true FROM t;
+SELECT cast(1 as int) <=> true FROM t;
+SELECT cast(1 as bigint) <=> true FROM t;
+SELECT cast(1 as float) <=> true FROM t;
+SELECT cast(1 as double) <=> true FROM t;
+SELECT cast(1 as decimal(10, 0)) <=> true FROM t;
+SELECT cast(1 as string) <=> true FROM t;
+SELECT cast('1' as binary) <=> true FROM t;
+SELECT cast(1 as boolean) <=> true FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <=> true FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) <=> true FROM t;
+
+SELECT false = cast(0 as tinyint) FROM t;
+SELECT false = cast(0 as smallint) FROM t;
+SELECT false = cast(0 as int) FROM t;
+SELECT false = cast(0 as bigint) FROM t;
+SELECT false = cast(0 as float) FROM t;
+SELECT false = cast(0 as double) FROM t;
+SELECT false = cast(0 as decimal(10, 0)) FROM t;
+SELECT false = cast(0 as string) FROM t;
+SELECT false = cast('0' as binary) FROM t;
+SELECT false = cast(0 as boolean) FROM t;
+SELECT false = cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT false = cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT false <=> cast(0 as tinyint) FROM t;
+SELECT false <=> cast(0 as smallint) FROM t;
+SELECT false <=> cast(0 as int) FROM t;
+SELECT false <=> cast(0 as bigint) FROM t;
+SELECT false <=> cast(0 as float) FROM t;
+SELECT false <=> cast(0 as double) FROM t;
+SELECT false <=> cast(0 as decimal(10, 0)) FROM t;
+SELECT false <=> cast(0 as string) FROM t;
+SELECT false <=> cast('0' as binary) FROM t;
+SELECT false <=> cast(0 as boolean) FROM t;
+SELECT false <=> cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT false <=> cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(0 as tinyint) = false FROM t;
+SELECT cast(0 as smallint) = false FROM t;
+SELECT cast(0 as int) = false FROM t;
+SELECT cast(0 as bigint) = false FROM t;
+SELECT cast(0 as float) = false FROM t;
+SELECT cast(0 as double) = false FROM t;
+SELECT cast(0 as decimal(10, 0)) = false FROM t;
+SELECT cast(0 as string) = false FROM t;
+SELECT cast('0' as binary) = false FROM t;
+SELECT cast(0 as boolean) = false FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) = false FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) = false FROM t;
+
+SELECT cast(0 as tinyint) <=> false FROM t;
+SELECT cast(0 as smallint) <=> false FROM t;
+SELECT cast(0 as int) <=> false FROM t;
+SELECT cast(0 as bigint) <=> false FROM t;
+SELECT cast(0 as float) <=> false FROM t;
+SELECT cast(0 as double) <=> false FROM t;
+SELECT cast(0 as decimal(10, 0)) <=> false FROM t;
+SELECT cast(0 as string) <=> false FROM t;
+SELECT cast('0' as binary) <=> false FROM t;
+SELECT cast(0 as boolean) <=> false FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <=> false FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) <=> false FROM t;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/caseWhenCoercion.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/caseWhenCoercion.sql
new file mode 100644
index 0000000000000..a780529fdca8c
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/caseWhenCoercion.sql
@@ -0,0 +1,174 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+CREATE TEMPORARY VIEW t AS SELECT 1;
+
+SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as tinyint) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as smallint) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as int) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as bigint) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as float) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as double) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as decimal(10, 0)) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as string) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast('2' as binary) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as boolean) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast('2017-12-11 09:30:00' as date) END FROM t;
+
+SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as tinyint) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as smallint) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as int) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as bigint) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as float) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as double) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as decimal(10, 0)) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as string) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast('2' as binary) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as boolean) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast('2017-12-11 09:30:00' as date) END FROM t;
+
+SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as tinyint) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as smallint) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as int) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as bigint) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as float) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as double) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as decimal(10, 0)) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as string) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as int) ELSE cast('2' as binary) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as boolean) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as int) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as int) ELSE cast('2017-12-11 09:30:00' as date) END FROM t;
+
+SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as tinyint) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as smallint) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as int) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as bigint) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as float) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as double) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as decimal(10, 0)) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as string) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast('2' as binary) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as boolean) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast('2017-12-11 09:30:00' as date) END FROM t;
+
+SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as tinyint) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as smallint) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as int) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as bigint) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as float) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as double) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as decimal(10, 0)) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as string) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as float) ELSE cast('2' as binary) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as boolean) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as float) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as float) ELSE cast('2017-12-11 09:30:00' as date) END FROM t;
+
+SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as tinyint) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as smallint) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as int) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as bigint) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as float) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as double) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as decimal(10, 0)) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as string) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as double) ELSE cast('2' as binary) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as boolean) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as double) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as double) ELSE cast('2017-12-11 09:30:00' as date) END FROM t;
+
+SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as tinyint) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as smallint) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as int) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as bigint) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as float) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as double) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as decimal(10, 0)) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as string) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast('2' as binary) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as boolean) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast('2017-12-11 09:30:00' as date) END FROM t;
+
+SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as tinyint) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as smallint) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as int) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as bigint) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as float) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as double) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as decimal(10, 0)) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as string) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as string) ELSE cast('2' as binary) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as boolean) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as string) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as string) ELSE cast('2017-12-11 09:30:00' as date) END FROM t;
+
+SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as tinyint) END FROM t;
+SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as smallint) END FROM t;
+SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as int) END FROM t;
+SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as bigint) END FROM t;
+SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as float) END FROM t;
+SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as double) END FROM t;
+SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as decimal(10, 0)) END FROM t;
+SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as string) END FROM t;
+SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast('2' as binary) END FROM t;
+SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as boolean) END FROM t;
+SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t;
+SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast('2017-12-11 09:30:00' as date) END FROM t;
+
+SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as tinyint) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as smallint) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as int) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as bigint) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as float) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as double) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as decimal(10, 0)) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as string) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast('2' as binary) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as boolean) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t;
+SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast('2017-12-11 09:30:00' as date) END FROM t;
+
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as tinyint) END FROM t;
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as smallint) END FROM t;
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as int) END FROM t;
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as bigint) END FROM t;
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as float) END FROM t;
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as double) END FROM t;
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as decimal(10, 0)) END FROM t;
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as string) END FROM t;
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast('2' as binary) END FROM t;
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as boolean) END FROM t;
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t;
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast('2017-12-11 09:30:00' as date) END FROM t;
+
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as tinyint) END FROM t;
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as smallint) END FROM t;
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as int) END FROM t;
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as bigint) END FROM t;
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as float) END FROM t;
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as double) END FROM t;
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as decimal(10, 0)) END FROM t;
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as string) END FROM t;
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast('2' as binary) END FROM t;
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as boolean) END FROM t;
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t;
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast('2017-12-11 09:30:00' as date) END FROM t;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql
new file mode 100644
index 0000000000000..0beebec5702fd
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql
@@ -0,0 +1,93 @@
+-- Concatenate mixed inputs (output type is string)
+SELECT (col1 || col2 || col3) col
+FROM (
+ SELECT
+ id col1,
+ string(id + 1) col2,
+ encode(string(id + 2), 'utf-8') col3
+ FROM range(10)
+);
+
+SELECT ((col1 || col2) || (col3 || col4) || col5) col
+FROM (
+ SELECT
+ 'prefix_' col1,
+ id col2,
+ string(id + 1) col3,
+ encode(string(id + 2), 'utf-8') col4,
+ CAST(id AS DOUBLE) col5
+ FROM range(10)
+);
+
+SELECT ((col1 || col2) || (col3 || col4)) col
+FROM (
+ SELECT
+ string(id) col1,
+ string(id + 1) col2,
+ encode(string(id + 2), 'utf-8') col3,
+ encode(string(id + 3), 'utf-8') col4
+ FROM range(10)
+);
+
+-- turn on concatBinaryAsString
+set spark.sql.function.concatBinaryAsString=true;
+
+SELECT (col1 || col2) col
+FROM (
+ SELECT
+ encode(string(id), 'utf-8') col1,
+ encode(string(id + 1), 'utf-8') col2
+ FROM range(10)
+);
+
+SELECT (col1 || col2 || col3 || col4) col
+FROM (
+ SELECT
+ encode(string(id), 'utf-8') col1,
+ encode(string(id + 1), 'utf-8') col2,
+ encode(string(id + 2), 'utf-8') col3,
+ encode(string(id + 3), 'utf-8') col4
+ FROM range(10)
+);
+
+SELECT ((col1 || col2) || (col3 || col4)) col
+FROM (
+ SELECT
+ encode(string(id), 'utf-8') col1,
+ encode(string(id + 1), 'utf-8') col2,
+ encode(string(id + 2), 'utf-8') col3,
+ encode(string(id + 3), 'utf-8') col4
+ FROM range(10)
+);
+
+-- turn off concatBinaryAsString
+set spark.sql.function.concatBinaryAsString=false;
+
+-- Concatenate binary inputs (output type is binary)
+SELECT (col1 || col2) col
+FROM (
+ SELECT
+ encode(string(id), 'utf-8') col1,
+ encode(string(id + 1), 'utf-8') col2
+ FROM range(10)
+);
+
+SELECT (col1 || col2 || col3 || col4) col
+FROM (
+ SELECT
+ encode(string(id), 'utf-8') col1,
+ encode(string(id + 1), 'utf-8') col2,
+ encode(string(id + 2), 'utf-8') col3,
+ encode(string(id + 3), 'utf-8') col4
+ FROM range(10)
+);
+
+SELECT ((col1 || col2) || (col3 || col4)) col
+FROM (
+ SELECT
+ encode(string(id), 'utf-8') col1,
+ encode(string(id + 1), 'utf-8') col2,
+ encode(string(id + 2), 'utf-8') col3,
+ encode(string(id + 3), 'utf-8') col4
+ FROM range(10)
+);
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/dateTimeOperations.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/dateTimeOperations.sql
new file mode 100644
index 0000000000000..1e98221867965
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/dateTimeOperations.sql
@@ -0,0 +1,60 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+CREATE TEMPORARY VIEW t AS SELECT 1;
+
+select cast(1 as tinyint) + interval 2 day;
+select cast(1 as smallint) + interval 2 day;
+select cast(1 as int) + interval 2 day;
+select cast(1 as bigint) + interval 2 day;
+select cast(1 as float) + interval 2 day;
+select cast(1 as double) + interval 2 day;
+select cast(1 as decimal(10, 0)) + interval 2 day;
+select cast('2017-12-11' as string) + interval 2 day;
+select cast('2017-12-11 09:30:00' as string) + interval 2 day;
+select cast('1' as binary) + interval 2 day;
+select cast(1 as boolean) + interval 2 day;
+select cast('2017-12-11 09:30:00.0' as timestamp) + interval 2 day;
+select cast('2017-12-11 09:30:00' as date) + interval 2 day;
+
+select interval 2 day + cast(1 as tinyint);
+select interval 2 day + cast(1 as smallint);
+select interval 2 day + cast(1 as int);
+select interval 2 day + cast(1 as bigint);
+select interval 2 day + cast(1 as float);
+select interval 2 day + cast(1 as double);
+select interval 2 day + cast(1 as decimal(10, 0));
+select interval 2 day + cast('2017-12-11' as string);
+select interval 2 day + cast('2017-12-11 09:30:00' as string);
+select interval 2 day + cast('1' as binary);
+select interval 2 day + cast(1 as boolean);
+select interval 2 day + cast('2017-12-11 09:30:00.0' as timestamp);
+select interval 2 day + cast('2017-12-11 09:30:00' as date);
+
+select cast(1 as tinyint) - interval 2 day;
+select cast(1 as smallint) - interval 2 day;
+select cast(1 as int) - interval 2 day;
+select cast(1 as bigint) - interval 2 day;
+select cast(1 as float) - interval 2 day;
+select cast(1 as double) - interval 2 day;
+select cast(1 as decimal(10, 0)) - interval 2 day;
+select cast('2017-12-11' as string) - interval 2 day;
+select cast('2017-12-11 09:30:00' as string) - interval 2 day;
+select cast('1' as binary) - interval 2 day;
+select cast(1 as boolean) - interval 2 day;
+select cast('2017-12-11 09:30:00.0' as timestamp) - interval 2 day;
+select cast('2017-12-11 09:30:00' as date) - interval 2 day;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql
new file mode 100644
index 0000000000000..c8e108ac2c45e
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql
@@ -0,0 +1,33 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+CREATE TEMPORARY VIEW t AS SELECT 1.0 as a, 0.0 as b;
+
+-- division, remainder and pmod by 0 return NULL
+select a / b from t;
+select a % b from t;
+select pmod(a, b) from t;
+
+-- arithmetic operations causing an overflow return NULL
+select (5e36 + 0.1) + 5e36;
+select (-4e36 - 0.1) - 7e36;
+select 12345678901234567890.0 * 12345678901234567890.0;
+select 1e35 / 0.1;
+
+-- arithmetic operations causing a precision loss return NULL
+select 123456789123456789.1234567890 * 1.123456789123456789;
+select 0.001 / 9876543210987654321098765432109876543.2;
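The overflow cases above lean on Spark capping decimal results at precision 38, so a value that needs more integral digits than the capped type allows cannot be represented and, per the comment in the file, comes back as NULL. A minimal sketch of the same effect with explicit types (an illustration only, not part of the patch):

  -- 5e36 has 37 integral digits and fits DECIMAL(38, 1), but the sum needs 38 integral
  -- digits plus the fractional digit, which no longer fits once precision is capped at 38.
  select cast(5e36 as decimal(38, 1)) + cast(5e36 as decimal(38, 1));  -- expected NULL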
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalPrecision.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalPrecision.sql
new file mode 100644
index 0000000000000..8b04864b18ce3
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalPrecision.sql
@@ -0,0 +1,1448 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+CREATE TEMPORARY VIEW t AS SELECT 1;
+
+SELECT cast(1 as tinyint) + cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as tinyint) + cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as tinyint) + cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as tinyint) + cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as smallint) + cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as smallint) + cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as smallint) + cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as smallint) + cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as int) + cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as int) + cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as int) + cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as int) + cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as bigint) + cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as bigint) + cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as bigint) + cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as bigint) + cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as float) + cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as float) + cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as float) + cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as float) + cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as double) + cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as double) + cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as double) + cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as double) + cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as decimal(10, 0)) + cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) + cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) + cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) + cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('1' as binary) + cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('1' as binary) + cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('1' as binary) + cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('1' as binary) + cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) + cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) + cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) + cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) + cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('2017-12-11 09:30:00' as date) + cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) + cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) + cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) + cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) + cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(5, 0)) + cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(10, 0)) + cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(20, 0)) + cast(1 as tinyint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) + cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(5, 0)) + cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(10, 0)) + cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(20, 0)) + cast(1 as smallint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) + cast(1 as int) FROM t;
+SELECT cast(1 as decimal(5, 0)) + cast(1 as int) FROM t;
+SELECT cast(1 as decimal(10, 0)) + cast(1 as int) FROM t;
+SELECT cast(1 as decimal(20, 0)) + cast(1 as int) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) + cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(5, 0)) + cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(10, 0)) + cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(20, 0)) + cast(1 as bigint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) + cast(1 as float) FROM t;
+SELECT cast(1 as decimal(5, 0)) + cast(1 as float) FROM t;
+SELECT cast(1 as decimal(10, 0)) + cast(1 as float) FROM t;
+SELECT cast(1 as decimal(20, 0)) + cast(1 as float) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) + cast(1 as double) FROM t;
+SELECT cast(1 as decimal(5, 0)) + cast(1 as double) FROM t;
+SELECT cast(1 as decimal(10, 0)) + cast(1 as double) FROM t;
+SELECT cast(1 as decimal(20, 0)) + cast(1 as double) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) + cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(5, 0)) + cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) + cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(20, 0)) + cast(1 as decimal(10, 0)) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) + cast(1 as string) FROM t;
+SELECT cast(1 as decimal(5, 0)) + cast(1 as string) FROM t;
+SELECT cast(1 as decimal(10, 0)) + cast(1 as string) FROM t;
+SELECT cast(1 as decimal(20, 0)) + cast(1 as string) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) + cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(5, 0)) + cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(10, 0)) + cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(20, 0)) + cast('1' as binary) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) + cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(5, 0)) + cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(10, 0)) + cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(20, 0)) + cast(1 as boolean) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) + cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(5, 0)) + cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(10, 0)) + cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(20, 0)) + cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) + cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(5, 0)) + cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(10, 0)) + cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(20, 0)) + cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(1 as tinyint) - cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as tinyint) - cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as tinyint) - cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as tinyint) - cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as smallint) - cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as smallint) - cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as smallint) - cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as smallint) - cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as int) - cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as int) - cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as int) - cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as int) - cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as bigint) - cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as bigint) - cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as bigint) - cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as bigint) - cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as float) - cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as float) - cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as float) - cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as float) - cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as double) - cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as double) - cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as double) - cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as double) - cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as decimal(10, 0)) - cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) - cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) - cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) - cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('1' as binary) - cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('1' as binary) - cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('1' as binary) - cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('1' as binary) - cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) - cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) - cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) - cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) - cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('2017-12-11 09:30:00' as date) - cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) - cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) - cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) - cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) - cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(5, 0)) - cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(10, 0)) - cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(20, 0)) - cast(1 as tinyint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) - cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(5, 0)) - cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(10, 0)) - cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(20, 0)) - cast(1 as smallint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) - cast(1 as int) FROM t;
+SELECT cast(1 as decimal(5, 0)) - cast(1 as int) FROM t;
+SELECT cast(1 as decimal(10, 0)) - cast(1 as int) FROM t;
+SELECT cast(1 as decimal(20, 0)) - cast(1 as int) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) - cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(5, 0)) - cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(10, 0)) - cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(20, 0)) - cast(1 as bigint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) - cast(1 as float) FROM t;
+SELECT cast(1 as decimal(5, 0)) - cast(1 as float) FROM t;
+SELECT cast(1 as decimal(10, 0)) - cast(1 as float) FROM t;
+SELECT cast(1 as decimal(20, 0)) - cast(1 as float) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) - cast(1 as double) FROM t;
+SELECT cast(1 as decimal(5, 0)) - cast(1 as double) FROM t;
+SELECT cast(1 as decimal(10, 0)) - cast(1 as double) FROM t;
+SELECT cast(1 as decimal(20, 0)) - cast(1 as double) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) - cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(5, 0)) - cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) - cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(20, 0)) - cast(1 as decimal(10, 0)) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) - cast(1 as string) FROM t;
+SELECT cast(1 as decimal(5, 0)) - cast(1 as string) FROM t;
+SELECT cast(1 as decimal(10, 0)) - cast(1 as string) FROM t;
+SELECT cast(1 as decimal(20, 0)) - cast(1 as string) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) - cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(5, 0)) - cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(10, 0)) - cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(20, 0)) - cast('1' as binary) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) - cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(5, 0)) - cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(10, 0)) - cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(20, 0)) - cast(1 as boolean) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) - cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(5, 0)) - cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(10, 0)) - cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(20, 0)) - cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) - cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(5, 0)) - cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(10, 0)) - cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(20, 0)) - cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(1 as tinyint) * cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as tinyint) * cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as tinyint) * cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as tinyint) * cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as smallint) * cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as smallint) * cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as smallint) * cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as smallint) * cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as int) * cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as int) * cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as int) * cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as int) * cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as bigint) * cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as bigint) * cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as bigint) * cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as bigint) * cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as float) * cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as float) * cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as float) * cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as float) * cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as double) * cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as double) * cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as double) * cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as double) * cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as decimal(10, 0)) * cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) * cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) * cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) * cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('1' as binary) * cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('1' as binary) * cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('1' as binary) * cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('1' as binary) * cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) * cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) * cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) * cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) * cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('2017-12-11 09:30:00' as date) * cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) * cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) * cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) * cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) * cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(5, 0)) * cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(10, 0)) * cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(20, 0)) * cast(1 as tinyint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) * cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(5, 0)) * cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(10, 0)) * cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(20, 0)) * cast(1 as smallint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) * cast(1 as int) FROM t;
+SELECT cast(1 as decimal(5, 0)) * cast(1 as int) FROM t;
+SELECT cast(1 as decimal(10, 0)) * cast(1 as int) FROM t;
+SELECT cast(1 as decimal(20, 0)) * cast(1 as int) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) * cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(5, 0)) * cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(10, 0)) * cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(20, 0)) * cast(1 as bigint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) * cast(1 as float) FROM t;
+SELECT cast(1 as decimal(5, 0)) * cast(1 as float) FROM t;
+SELECT cast(1 as decimal(10, 0)) * cast(1 as float) FROM t;
+SELECT cast(1 as decimal(20, 0)) * cast(1 as float) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) * cast(1 as double) FROM t;
+SELECT cast(1 as decimal(5, 0)) * cast(1 as double) FROM t;
+SELECT cast(1 as decimal(10, 0)) * cast(1 as double) FROM t;
+SELECT cast(1 as decimal(20, 0)) * cast(1 as double) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) * cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(5, 0)) * cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) * cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(20, 0)) * cast(1 as decimal(10, 0)) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) * cast(1 as string) FROM t;
+SELECT cast(1 as decimal(5, 0)) * cast(1 as string) FROM t;
+SELECT cast(1 as decimal(10, 0)) * cast(1 as string) FROM t;
+SELECT cast(1 as decimal(20, 0)) * cast(1 as string) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) * cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(5, 0)) * cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(10, 0)) * cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(20, 0)) * cast('1' as binary) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) * cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(5, 0)) * cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(10, 0)) * cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(20, 0)) * cast(1 as boolean) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) * cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(5, 0)) * cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(10, 0)) * cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(20, 0)) * cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) * cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(5, 0)) * cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(10, 0)) * cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(20, 0)) * cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(1 as tinyint) / cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as tinyint) / cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as tinyint) / cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as tinyint) / cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as smallint) / cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as smallint) / cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as smallint) / cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as smallint) / cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as int) / cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as int) / cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as int) / cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as int) / cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as bigint) / cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as bigint) / cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as bigint) / cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as bigint) / cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as float) / cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as float) / cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as float) / cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as float) / cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as double) / cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as double) / cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as double) / cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as double) / cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as decimal(10, 0)) / cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) / cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) / cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) / cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('1' as binary) / cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('1' as binary) / cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('1' as binary) / cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('1' as binary) / cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) / cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(5, 0)) / cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(10, 0)) / cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(20, 0)) / cast(1 as tinyint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) / cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(5, 0)) / cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(10, 0)) / cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(20, 0)) / cast(1 as smallint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) / cast(1 as int) FROM t;
+SELECT cast(1 as decimal(5, 0)) / cast(1 as int) FROM t;
+SELECT cast(1 as decimal(10, 0)) / cast(1 as int) FROM t;
+SELECT cast(1 as decimal(20, 0)) / cast(1 as int) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) / cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(5, 0)) / cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(10, 0)) / cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(20, 0)) / cast(1 as bigint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) / cast(1 as float) FROM t;
+SELECT cast(1 as decimal(5, 0)) / cast(1 as float) FROM t;
+SELECT cast(1 as decimal(10, 0)) / cast(1 as float) FROM t;
+SELECT cast(1 as decimal(20, 0)) / cast(1 as float) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) / cast(1 as double) FROM t;
+SELECT cast(1 as decimal(5, 0)) / cast(1 as double) FROM t;
+SELECT cast(1 as decimal(10, 0)) / cast(1 as double) FROM t;
+SELECT cast(1 as decimal(20, 0)) / cast(1 as double) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) / cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(5, 0)) / cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) / cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(20, 0)) / cast(1 as decimal(10, 0)) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) / cast(1 as string) FROM t;
+SELECT cast(1 as decimal(5, 0)) / cast(1 as string) FROM t;
+SELECT cast(1 as decimal(10, 0)) / cast(1 as string) FROM t;
+SELECT cast(1 as decimal(20, 0)) / cast(1 as string) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) / cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(5, 0)) / cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(10, 0)) / cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(20, 0)) / cast('1' as binary) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) / cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(5, 0)) / cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(10, 0)) / cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(20, 0)) / cast(1 as boolean) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(5, 0)) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(10, 0)) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(20, 0)) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) / cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(5, 0)) / cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(10, 0)) / cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(20, 0)) / cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(1 as tinyint) % cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as tinyint) % cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as tinyint) % cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as tinyint) % cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as smallint) % cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as smallint) % cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as smallint) % cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as smallint) % cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as int) % cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as int) % cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as int) % cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as int) % cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as bigint) % cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as bigint) % cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as bigint) % cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as bigint) % cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as float) % cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as float) % cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as float) % cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as float) % cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as double) % cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as double) % cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as double) % cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as double) % cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as decimal(10, 0)) % cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) % cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) % cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) % cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('1' as binary) % cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('1' as binary) % cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('1' as binary) % cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('1' as binary) % cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) % cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) % cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) % cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) % cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('2017-12-11 09:30:00' as date) % cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) % cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) % cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) % cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) % cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(5, 0)) % cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(10, 0)) % cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(20, 0)) % cast(1 as tinyint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) % cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(5, 0)) % cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(10, 0)) % cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(20, 0)) % cast(1 as smallint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) % cast(1 as int) FROM t;
+SELECT cast(1 as decimal(5, 0)) % cast(1 as int) FROM t;
+SELECT cast(1 as decimal(10, 0)) % cast(1 as int) FROM t;
+SELECT cast(1 as decimal(20, 0)) % cast(1 as int) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) % cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(5, 0)) % cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(10, 0)) % cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(20, 0)) % cast(1 as bigint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) % cast(1 as float) FROM t;
+SELECT cast(1 as decimal(5, 0)) % cast(1 as float) FROM t;
+SELECT cast(1 as decimal(10, 0)) % cast(1 as float) FROM t;
+SELECT cast(1 as decimal(20, 0)) % cast(1 as float) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) % cast(1 as double) FROM t;
+SELECT cast(1 as decimal(5, 0)) % cast(1 as double) FROM t;
+SELECT cast(1 as decimal(10, 0)) % cast(1 as double) FROM t;
+SELECT cast(1 as decimal(20, 0)) % cast(1 as double) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) % cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(5, 0)) % cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) % cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(20, 0)) % cast(1 as decimal(10, 0)) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) % cast(1 as string) FROM t;
+SELECT cast(1 as decimal(5, 0)) % cast(1 as string) FROM t;
+SELECT cast(1 as decimal(10, 0)) % cast(1 as string) FROM t;
+SELECT cast(1 as decimal(20, 0)) % cast(1 as string) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) % cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(5, 0)) % cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(10, 0)) % cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(20, 0)) % cast('1' as binary) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) % cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(5, 0)) % cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(10, 0)) % cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(20, 0)) % cast(1 as boolean) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) % cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(5, 0)) % cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(10, 0)) % cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(20, 0)) % cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) % cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(5, 0)) % cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(10, 0)) % cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(20, 0)) % cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT pmod(cast(1 as tinyint), cast(1 as decimal(3, 0))) FROM t;
+SELECT pmod(cast(1 as tinyint), cast(1 as decimal(5, 0))) FROM t;
+SELECT pmod(cast(1 as tinyint), cast(1 as decimal(10, 0))) FROM t;
+SELECT pmod(cast(1 as tinyint), cast(1 as decimal(20, 0))) FROM t;
+
+SELECT pmod(cast(1 as smallint), cast(1 as decimal(3, 0))) FROM t;
+SELECT pmod(cast(1 as smallint), cast(1 as decimal(5, 0))) FROM t;
+SELECT pmod(cast(1 as smallint), cast(1 as decimal(10, 0))) FROM t;
+SELECT pmod(cast(1 as smallint), cast(1 as decimal(20, 0))) FROM t;
+
+SELECT pmod(cast(1 as int), cast(1 as decimal(3, 0))) FROM t;
+SELECT pmod(cast(1 as int), cast(1 as decimal(5, 0))) FROM t;
+SELECT pmod(cast(1 as int), cast(1 as decimal(10, 0))) FROM t;
+SELECT pmod(cast(1 as int), cast(1 as decimal(20, 0))) FROM t;
+
+SELECT pmod(cast(1 as bigint), cast(1 as decimal(3, 0))) FROM t;
+SELECT pmod(cast(1 as bigint), cast(1 as decimal(5, 0))) FROM t;
+SELECT pmod(cast(1 as bigint), cast(1 as decimal(10, 0))) FROM t;
+SELECT pmod(cast(1 as bigint), cast(1 as decimal(20, 0))) FROM t;
+
+SELECT pmod(cast(1 as float), cast(1 as decimal(3, 0))) FROM t;
+SELECT pmod(cast(1 as float), cast(1 as decimal(5, 0))) FROM t;
+SELECT pmod(cast(1 as float), cast(1 as decimal(10, 0))) FROM t;
+SELECT pmod(cast(1 as float), cast(1 as decimal(20, 0))) FROM t;
+
+SELECT pmod(cast(1 as double), cast(1 as decimal(3, 0))) FROM t;
+SELECT pmod(cast(1 as double), cast(1 as decimal(5, 0))) FROM t;
+SELECT pmod(cast(1 as double), cast(1 as decimal(10, 0))) FROM t;
+SELECT pmod(cast(1 as double), cast(1 as decimal(20, 0))) FROM t;
+
+SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as decimal(3, 0))) FROM t;
+SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as decimal(5, 0))) FROM t;
+SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as decimal(10, 0))) FROM t;
+SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as decimal(20, 0))) FROM t;
+
+SELECT pmod(cast('1' as binary), cast(1 as decimal(3, 0))) FROM t;
+SELECT pmod(cast('1' as binary), cast(1 as decimal(5, 0))) FROM t;
+SELECT pmod(cast('1' as binary), cast(1 as decimal(10, 0))) FROM t;
+SELECT pmod(cast('1' as binary), cast(1 as decimal(20, 0))) FROM t;
+
+SELECT pmod(cast('2017-12-11 09:30:00.0' as timestamp), cast(1 as decimal(3, 0))) FROM t;
+SELECT pmod(cast('2017-12-11 09:30:00.0' as timestamp), cast(1 as decimal(5, 0))) FROM t;
+SELECT pmod(cast('2017-12-11 09:30:00.0' as timestamp), cast(1 as decimal(10, 0))) FROM t;
+SELECT pmod(cast('2017-12-11 09:30:00.0' as timestamp), cast(1 as decimal(20, 0))) FROM t;
+
+SELECT pmod(cast('2017-12-11 09:30:00' as date), cast(1 as decimal(3, 0))) FROM t;
+SELECT pmod(cast('2017-12-11 09:30:00' as date), cast(1 as decimal(5, 0))) FROM t;
+SELECT pmod(cast('2017-12-11 09:30:00' as date), cast(1 as decimal(10, 0))) FROM t;
+SELECT pmod(cast('2017-12-11 09:30:00' as date), cast(1 as decimal(20, 0))) FROM t;
+
+SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as tinyint)) FROM t;
+SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as tinyint)) FROM t;
+SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as tinyint)) FROM t;
+SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as tinyint)) FROM t;
+
+SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as smallint)) FROM t;
+SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as smallint)) FROM t;
+SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as smallint)) FROM t;
+SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as smallint)) FROM t;
+
+SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as int)) FROM t;
+SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as int)) FROM t;
+SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as int)) FROM t;
+SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as int)) FROM t;
+
+SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as bigint)) FROM t;
+SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as bigint)) FROM t;
+SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as bigint)) FROM t;
+SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as bigint)) FROM t;
+
+SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as float)) FROM t;
+SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as float)) FROM t;
+SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as float)) FROM t;
+SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as float)) FROM t;
+
+SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as double)) FROM t;
+SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as double)) FROM t;
+SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as double)) FROM t;
+SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as double)) FROM t;
+
+SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as decimal(10, 0))) FROM t;
+SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as decimal(10, 0))) FROM t;
+SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as decimal(10, 0))) FROM t;
+SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as decimal(10, 0))) FROM t;
+
+SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as string)) FROM t;
+SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as string)) FROM t;
+SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as string)) FROM t;
+SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as string)) FROM t;
+
+SELECT pmod(cast(1 as decimal(3, 0)) , cast('1' as binary)) FROM t;
+SELECT pmod(cast(1 as decimal(5, 0)) , cast('1' as binary)) FROM t;
+SELECT pmod(cast(1 as decimal(10, 0)), cast('1' as binary)) FROM t;
+SELECT pmod(cast(1 as decimal(20, 0)), cast('1' as binary)) FROM t;
+
+SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as boolean)) FROM t;
+SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as boolean)) FROM t;
+SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as boolean)) FROM t;
+SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as boolean)) FROM t;
+
+SELECT pmod(cast(1 as decimal(3, 0)) , cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT pmod(cast(1 as decimal(5, 0)) , cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT pmod(cast(1 as decimal(10, 0)), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT pmod(cast(1 as decimal(20, 0)), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+
+SELECT pmod(cast(1 as decimal(3, 0)) , cast('2017-12-11 09:30:00' as date)) FROM t;
+SELECT pmod(cast(1 as decimal(5, 0)) , cast('2017-12-11 09:30:00' as date)) FROM t;
+SELECT pmod(cast(1 as decimal(10, 0)), cast('2017-12-11 09:30:00' as date)) FROM t;
+SELECT pmod(cast(1 as decimal(20, 0)), cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT cast(1 as tinyint) = cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as tinyint) = cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as tinyint) = cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as tinyint) = cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as smallint) = cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as smallint) = cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as smallint) = cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as smallint) = cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as int) = cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as int) = cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as int) = cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as int) = cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as bigint) = cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as bigint) = cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as bigint) = cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as bigint) = cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as float) = cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as float) = cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as float) = cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as float) = cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as double) = cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as double) = cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as double) = cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as double) = cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as decimal(10, 0)) = cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) = cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) = cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) = cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('1' as binary) = cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('1' as binary) = cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('1' as binary) = cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('1' as binary) = cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) = cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) = cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) = cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) = cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('2017-12-11 09:30:00' as date) = cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) = cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) = cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) = cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) = cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(5, 0)) = cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(10, 0)) = cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(20, 0)) = cast(1 as tinyint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) = cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(5, 0)) = cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(10, 0)) = cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(20, 0)) = cast(1 as smallint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) = cast(1 as int) FROM t;
+SELECT cast(1 as decimal(5, 0)) = cast(1 as int) FROM t;
+SELECT cast(1 as decimal(10, 0)) = cast(1 as int) FROM t;
+SELECT cast(1 as decimal(20, 0)) = cast(1 as int) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) = cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(5, 0)) = cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(10, 0)) = cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(20, 0)) = cast(1 as bigint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) = cast(1 as float) FROM t;
+SELECT cast(1 as decimal(5, 0)) = cast(1 as float) FROM t;
+SELECT cast(1 as decimal(10, 0)) = cast(1 as float) FROM t;
+SELECT cast(1 as decimal(20, 0)) = cast(1 as float) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) = cast(1 as double) FROM t;
+SELECT cast(1 as decimal(5, 0)) = cast(1 as double) FROM t;
+SELECT cast(1 as decimal(10, 0)) = cast(1 as double) FROM t;
+SELECT cast(1 as decimal(20, 0)) = cast(1 as double) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) = cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(5, 0)) = cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) = cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(20, 0)) = cast(1 as decimal(10, 0)) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) = cast(1 as string) FROM t;
+SELECT cast(1 as decimal(5, 0)) = cast(1 as string) FROM t;
+SELECT cast(1 as decimal(10, 0)) = cast(1 as string) FROM t;
+SELECT cast(1 as decimal(20, 0)) = cast(1 as string) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) = cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(5, 0)) = cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(10, 0)) = cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(20, 0)) = cast('1' as binary) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) = cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(5, 0)) = cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(10, 0)) = cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(20, 0)) = cast(1 as boolean) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) = cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(5, 0)) = cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(10, 0)) = cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(20, 0)) = cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) = cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(5, 0)) = cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(10, 0)) = cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(20, 0)) = cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(1 as tinyint) <=> cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as tinyint) <=> cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as tinyint) <=> cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as tinyint) <=> cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as smallint) <=> cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as smallint) <=> cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as smallint) <=> cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as smallint) <=> cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as int) <=> cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as int) <=> cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as int) <=> cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as int) <=> cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as bigint) <=> cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as bigint) <=> cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as bigint) <=> cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as bigint) <=> cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as float) <=> cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as float) <=> cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as float) <=> cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as float) <=> cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as double) <=> cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as double) <=> cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as double) <=> cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as double) <=> cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as decimal(10, 0)) <=> cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) <=> cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) <=> cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) <=> cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('1' as binary) <=> cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('1' as binary) <=> cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('1' as binary) <=> cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('1' as binary) <=> cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <=> cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <=> cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <=> cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <=> cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('2017-12-11 09:30:00' as date) <=> cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) <=> cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) <=> cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) <=> cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <=> cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(5, 0)) <=> cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(10, 0)) <=> cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(20, 0)) <=> cast(1 as tinyint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <=> cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(5, 0)) <=> cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(10, 0)) <=> cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(20, 0)) <=> cast(1 as smallint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <=> cast(1 as int) FROM t;
+SELECT cast(1 as decimal(5, 0)) <=> cast(1 as int) FROM t;
+SELECT cast(1 as decimal(10, 0)) <=> cast(1 as int) FROM t;
+SELECT cast(1 as decimal(20, 0)) <=> cast(1 as int) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <=> cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(5, 0)) <=> cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(10, 0)) <=> cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(20, 0)) <=> cast(1 as bigint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <=> cast(1 as float) FROM t;
+SELECT cast(1 as decimal(5, 0)) <=> cast(1 as float) FROM t;
+SELECT cast(1 as decimal(10, 0)) <=> cast(1 as float) FROM t;
+SELECT cast(1 as decimal(20, 0)) <=> cast(1 as float) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <=> cast(1 as double) FROM t;
+SELECT cast(1 as decimal(5, 0)) <=> cast(1 as double) FROM t;
+SELECT cast(1 as decimal(10, 0)) <=> cast(1 as double) FROM t;
+SELECT cast(1 as decimal(20, 0)) <=> cast(1 as double) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <=> cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(5, 0)) <=> cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) <=> cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(20, 0)) <=> cast(1 as decimal(10, 0)) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <=> cast(1 as string) FROM t;
+SELECT cast(1 as decimal(5, 0)) <=> cast(1 as string) FROM t;
+SELECT cast(1 as decimal(10, 0)) <=> cast(1 as string) FROM t;
+SELECT cast(1 as decimal(20, 0)) <=> cast(1 as string) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <=> cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(5, 0)) <=> cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(10, 0)) <=> cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(20, 0)) <=> cast('1' as binary) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <=> cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(5, 0)) <=> cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(10, 0)) <=> cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(20, 0)) <=> cast(1 as boolean) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <=> cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(5, 0)) <=> cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(10, 0)) <=> cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(20, 0)) <=> cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <=> cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(5, 0)) <=> cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(10, 0)) <=> cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(20, 0)) <=> cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(1 as tinyint) < cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as tinyint) < cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as tinyint) < cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as tinyint) < cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as smallint) < cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as smallint) < cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as smallint) < cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as smallint) < cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as int) < cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as int) < cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as int) < cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as int) < cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as bigint) < cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as bigint) < cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as bigint) < cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as bigint) < cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as float) < cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as float) < cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as float) < cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as float) < cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as double) < cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as double) < cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as double) < cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as double) < cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as decimal(10, 0)) < cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) < cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) < cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) < cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('1' as binary) < cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('1' as binary) < cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('1' as binary) < cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('1' as binary) < cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) < cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) < cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) < cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) < cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('2017-12-11 09:30:00' as date) < cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) < cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) < cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) < cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) < cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(5, 0)) < cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(10, 0)) < cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(20, 0)) < cast(1 as tinyint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) < cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(5, 0)) < cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(10, 0)) < cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(20, 0)) < cast(1 as smallint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) < cast(1 as int) FROM t;
+SELECT cast(1 as decimal(5, 0)) < cast(1 as int) FROM t;
+SELECT cast(1 as decimal(10, 0)) < cast(1 as int) FROM t;
+SELECT cast(1 as decimal(20, 0)) < cast(1 as int) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) < cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(5, 0)) < cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(10, 0)) < cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(20, 0)) < cast(1 as bigint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) < cast(1 as float) FROM t;
+SELECT cast(1 as decimal(5, 0)) < cast(1 as float) FROM t;
+SELECT cast(1 as decimal(10, 0)) < cast(1 as float) FROM t;
+SELECT cast(1 as decimal(20, 0)) < cast(1 as float) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) < cast(1 as double) FROM t;
+SELECT cast(1 as decimal(5, 0)) < cast(1 as double) FROM t;
+SELECT cast(1 as decimal(10, 0)) < cast(1 as double) FROM t;
+SELECT cast(1 as decimal(20, 0)) < cast(1 as double) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) < cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(5, 0)) < cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) < cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(20, 0)) < cast(1 as decimal(10, 0)) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) < cast(1 as string) FROM t;
+SELECT cast(1 as decimal(5, 0)) < cast(1 as string) FROM t;
+SELECT cast(1 as decimal(10, 0)) < cast(1 as string) FROM t;
+SELECT cast(1 as decimal(20, 0)) < cast(1 as string) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) < cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(5, 0)) < cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(10, 0)) < cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(20, 0)) < cast('1' as binary) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) < cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(5, 0)) < cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(10, 0)) < cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(20, 0)) < cast(1 as boolean) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) < cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(5, 0)) < cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(10, 0)) < cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(20, 0)) < cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) < cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(5, 0)) < cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(10, 0)) < cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(20, 0)) < cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(1 as tinyint) <= cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as tinyint) <= cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as tinyint) <= cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as tinyint) <= cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as smallint) <= cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as smallint) <= cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as smallint) <= cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as smallint) <= cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as int) <= cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as int) <= cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as int) <= cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as int) <= cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as bigint) <= cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as bigint) <= cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as bigint) <= cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as bigint) <= cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as float) <= cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as float) <= cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as float) <= cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as float) <= cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as double) <= cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as double) <= cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as double) <= cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as double) <= cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as decimal(10, 0)) <= cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) <= cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) <= cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) <= cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('1' as binary) <= cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('1' as binary) <= cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('1' as binary) <= cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('1' as binary) <= cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <= cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <= cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <= cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <= cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('2017-12-11 09:30:00' as date) <= cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) <= cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) <= cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) <= cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <= cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(5, 0)) <= cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(10, 0)) <= cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(20, 0)) <= cast(1 as tinyint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <= cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(5, 0)) <= cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(10, 0)) <= cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(20, 0)) <= cast(1 as smallint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <= cast(1 as int) FROM t;
+SELECT cast(1 as decimal(5, 0)) <= cast(1 as int) FROM t;
+SELECT cast(1 as decimal(10, 0)) <= cast(1 as int) FROM t;
+SELECT cast(1 as decimal(20, 0)) <= cast(1 as int) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <= cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(5, 0)) <= cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(10, 0)) <= cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(20, 0)) <= cast(1 as bigint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <= cast(1 as float) FROM t;
+SELECT cast(1 as decimal(5, 0)) <= cast(1 as float) FROM t;
+SELECT cast(1 as decimal(10, 0)) <= cast(1 as float) FROM t;
+SELECT cast(1 as decimal(20, 0)) <= cast(1 as float) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <= cast(1 as double) FROM t;
+SELECT cast(1 as decimal(5, 0)) <= cast(1 as double) FROM t;
+SELECT cast(1 as decimal(10, 0)) <= cast(1 as double) FROM t;
+SELECT cast(1 as decimal(20, 0)) <= cast(1 as double) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <= cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(5, 0)) <= cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) <= cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(20, 0)) <= cast(1 as decimal(10, 0)) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <= cast(1 as string) FROM t;
+SELECT cast(1 as decimal(5, 0)) <= cast(1 as string) FROM t;
+SELECT cast(1 as decimal(10, 0)) <= cast(1 as string) FROM t;
+SELECT cast(1 as decimal(20, 0)) <= cast(1 as string) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <= cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(5, 0)) <= cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(10, 0)) <= cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(20, 0)) <= cast('1' as binary) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <= cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(5, 0)) <= cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(10, 0)) <= cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(20, 0)) <= cast(1 as boolean) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <= cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(5, 0)) <= cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(10, 0)) <= cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(20, 0)) <= cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <= cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(5, 0)) <= cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(10, 0)) <= cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(20, 0)) <= cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(1 as tinyint) > cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as tinyint) > cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as tinyint) > cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as tinyint) > cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as smallint) > cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as smallint) > cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as smallint) > cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as smallint) > cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as int) > cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as int) > cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as int) > cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as int) > cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as bigint) > cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as bigint) > cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as bigint) > cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as bigint) > cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as float) > cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as float) > cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as float) > cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as float) > cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as double) > cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as double) > cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as double) > cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as double) > cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as decimal(10, 0)) > cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) > cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) > cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) > cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('1' as binary) > cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('1' as binary) > cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('1' as binary) > cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('1' as binary) > cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) > cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) > cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) > cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) > cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('2017-12-11 09:30:00' as date) > cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) > cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) > cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) > cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) > cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(5, 0)) > cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(10, 0)) > cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(20, 0)) > cast(1 as tinyint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) > cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(5, 0)) > cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(10, 0)) > cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(20, 0)) > cast(1 as smallint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) > cast(1 as int) FROM t;
+SELECT cast(1 as decimal(5, 0)) > cast(1 as int) FROM t;
+SELECT cast(1 as decimal(10, 0)) > cast(1 as int) FROM t;
+SELECT cast(1 as decimal(20, 0)) > cast(1 as int) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) > cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(5, 0)) > cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(10, 0)) > cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(20, 0)) > cast(1 as bigint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) > cast(1 as float) FROM t;
+SELECT cast(1 as decimal(5, 0)) > cast(1 as float) FROM t;
+SELECT cast(1 as decimal(10, 0)) > cast(1 as float) FROM t;
+SELECT cast(1 as decimal(20, 0)) > cast(1 as float) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) > cast(1 as double) FROM t;
+SELECT cast(1 as decimal(5, 0)) > cast(1 as double) FROM t;
+SELECT cast(1 as decimal(10, 0)) > cast(1 as double) FROM t;
+SELECT cast(1 as decimal(20, 0)) > cast(1 as double) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) > cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(5, 0)) > cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) > cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(20, 0)) > cast(1 as decimal(10, 0)) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) > cast(1 as string) FROM t;
+SELECT cast(1 as decimal(5, 0)) > cast(1 as string) FROM t;
+SELECT cast(1 as decimal(10, 0)) > cast(1 as string) FROM t;
+SELECT cast(1 as decimal(20, 0)) > cast(1 as string) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) > cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(5, 0)) > cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(10, 0)) > cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(20, 0)) > cast('1' as binary) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) > cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(5, 0)) > cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(10, 0)) > cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(20, 0)) > cast(1 as boolean) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) > cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(5, 0)) > cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(10, 0)) > cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(20, 0)) > cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) > cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(5, 0)) > cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(10, 0)) > cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(20, 0)) > cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(1 as tinyint) >= cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as tinyint) >= cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as tinyint) >= cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as tinyint) >= cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as smallint) >= cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as smallint) >= cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as smallint) >= cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as smallint) >= cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as int) >= cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as int) >= cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as int) >= cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as int) >= cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as bigint) >= cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as bigint) >= cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as bigint) >= cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as bigint) >= cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as float) >= cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as float) >= cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as float) >= cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as float) >= cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as double) >= cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as double) >= cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as double) >= cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as double) >= cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as decimal(10, 0)) >= cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) >= cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) >= cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) >= cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('1' as binary) >= cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('1' as binary) >= cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('1' as binary) >= cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('1' as binary) >= cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) >= cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) >= cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) >= cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) >= cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('2017-12-11 09:30:00' as date) >= cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) >= cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) >= cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) >= cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) >= cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(5, 0)) >= cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(10, 0)) >= cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(20, 0)) >= cast(1 as tinyint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) >= cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(5, 0)) >= cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(10, 0)) >= cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(20, 0)) >= cast(1 as smallint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) >= cast(1 as int) FROM t;
+SELECT cast(1 as decimal(5, 0)) >= cast(1 as int) FROM t;
+SELECT cast(1 as decimal(10, 0)) >= cast(1 as int) FROM t;
+SELECT cast(1 as decimal(20, 0)) >= cast(1 as int) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) >= cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(5, 0)) >= cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(10, 0)) >= cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(20, 0)) >= cast(1 as bigint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) >= cast(1 as float) FROM t;
+SELECT cast(1 as decimal(5, 0)) >= cast(1 as float) FROM t;
+SELECT cast(1 as decimal(10, 0)) >= cast(1 as float) FROM t;
+SELECT cast(1 as decimal(20, 0)) >= cast(1 as float) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) >= cast(1 as double) FROM t;
+SELECT cast(1 as decimal(5, 0)) >= cast(1 as double) FROM t;
+SELECT cast(1 as decimal(10, 0)) >= cast(1 as double) FROM t;
+SELECT cast(1 as decimal(20, 0)) >= cast(1 as double) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) >= cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(5, 0)) >= cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) >= cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(20, 0)) >= cast(1 as decimal(10, 0)) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) >= cast(1 as string) FROM t;
+SELECT cast(1 as decimal(5, 0)) >= cast(1 as string) FROM t;
+SELECT cast(1 as decimal(10, 0)) >= cast(1 as string) FROM t;
+SELECT cast(1 as decimal(20, 0)) >= cast(1 as string) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) >= cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(5, 0)) >= cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(10, 0)) >= cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(20, 0)) >= cast('1' as binary) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) >= cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(5, 0)) >= cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(10, 0)) >= cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(20, 0)) >= cast(1 as boolean) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) >= cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(5, 0)) >= cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(10, 0)) >= cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(20, 0)) >= cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) >= cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(5, 0)) >= cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(10, 0)) >= cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(20, 0)) >= cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(1 as tinyint) <> cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as tinyint) <> cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as tinyint) <> cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as tinyint) <> cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as smallint) <> cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as smallint) <> cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as smallint) <> cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as smallint) <> cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as int) <> cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as int) <> cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as int) <> cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as int) <> cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as bigint) <> cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as bigint) <> cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as bigint) <> cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as bigint) <> cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as float) <> cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as float) <> cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as float) <> cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as float) <> cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as double) <> cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as double) <> cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as double) <> cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as double) <> cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as decimal(10, 0)) <> cast(1 as decimal(3, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) <> cast(1 as decimal(5, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) <> cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) <> cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('1' as binary) <> cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('1' as binary) <> cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('1' as binary) <> cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('1' as binary) <> cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <> cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <> cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <> cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <> cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast('2017-12-11 09:30:00' as date) <> cast(1 as decimal(3, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) <> cast(1 as decimal(5, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) <> cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) <> cast(1 as decimal(20, 0)) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <> cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(5, 0)) <> cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(10, 0)) <> cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(20, 0)) <> cast(1 as tinyint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <> cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(5, 0)) <> cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(10, 0)) <> cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(20, 0)) <> cast(1 as smallint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <> cast(1 as int) FROM t;
+SELECT cast(1 as decimal(5, 0)) <> cast(1 as int) FROM t;
+SELECT cast(1 as decimal(10, 0)) <> cast(1 as int) FROM t;
+SELECT cast(1 as decimal(20, 0)) <> cast(1 as int) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <> cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(5, 0)) <> cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(10, 0)) <> cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(20, 0)) <> cast(1 as bigint) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <> cast(1 as float) FROM t;
+SELECT cast(1 as decimal(5, 0)) <> cast(1 as float) FROM t;
+SELECT cast(1 as decimal(10, 0)) <> cast(1 as float) FROM t;
+SELECT cast(1 as decimal(20, 0)) <> cast(1 as float) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <> cast(1 as double) FROM t;
+SELECT cast(1 as decimal(5, 0)) <> cast(1 as double) FROM t;
+SELECT cast(1 as decimal(10, 0)) <> cast(1 as double) FROM t;
+SELECT cast(1 as decimal(20, 0)) <> cast(1 as double) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <> cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(5, 0)) <> cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) <> cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(20, 0)) <> cast(1 as decimal(10, 0)) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <> cast(1 as string) FROM t;
+SELECT cast(1 as decimal(5, 0)) <> cast(1 as string) FROM t;
+SELECT cast(1 as decimal(10, 0)) <> cast(1 as string) FROM t;
+SELECT cast(1 as decimal(20, 0)) <> cast(1 as string) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <> cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(5, 0)) <> cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(10, 0)) <> cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(20, 0)) <> cast('1' as binary) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <> cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(5, 0)) <> cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(10, 0)) <> cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(20, 0)) <> cast(1 as boolean) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <> cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(5, 0)) <> cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(10, 0)) <> cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(20, 0)) <> cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+
+SELECT cast(1 as decimal(3, 0)) <> cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(5, 0)) <> cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(10, 0)) <> cast('2017-12-11 09:30:00' as date) FROM t;
+SELECT cast(1 as decimal(20, 0)) <> cast('2017-12-11 09:30:00' as date) FROM t;
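(Aside, not part of the patch: the block above pairs each decimal precision with every other primitive type under the comparison operators. One quick way to see which common type the analyzer actually picks for such a pairing is to look at the analyzed plan; the Scala sketch below is illustrative only and assumes a hypothetical active SparkSession named `spark`.)

// Illustrative sketch, not part of this patch: inspect the casts the analyzer
// inserts when a small decimal is compared with an int.
val cmp = spark.sql("SELECT cast(1 as decimal(3, 0)) < cast(1 as int) AS res")
println(cmp.queryExecution.analyzed.numberedTreeString)  // shows the widening casts added around '<'
println(cmp.schema("res").dataType)                      // BooleanType: the comparison itself is boolean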
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/division.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/division.sql
new file mode 100644
index 0000000000000..d669740ddd9ca
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/division.sql
@@ -0,0 +1,174 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+CREATE TEMPORARY VIEW t AS SELECT 1;
+
+SELECT cast(1 as tinyint) / cast(1 as tinyint) FROM t;
+SELECT cast(1 as tinyint) / cast(1 as smallint) FROM t;
+SELECT cast(1 as tinyint) / cast(1 as int) FROM t;
+SELECT cast(1 as tinyint) / cast(1 as bigint) FROM t;
+SELECT cast(1 as tinyint) / cast(1 as float) FROM t;
+SELECT cast(1 as tinyint) / cast(1 as double) FROM t;
+SELECT cast(1 as tinyint) / cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as tinyint) / cast(1 as string) FROM t;
+SELECT cast(1 as tinyint) / cast('1' as binary) FROM t;
+SELECT cast(1 as tinyint) / cast(1 as boolean) FROM t;
+SELECT cast(1 as tinyint) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as tinyint) / cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(1 as smallint) / cast(1 as tinyint) FROM t;
+SELECT cast(1 as smallint) / cast(1 as smallint) FROM t;
+SELECT cast(1 as smallint) / cast(1 as int) FROM t;
+SELECT cast(1 as smallint) / cast(1 as bigint) FROM t;
+SELECT cast(1 as smallint) / cast(1 as float) FROM t;
+SELECT cast(1 as smallint) / cast(1 as double) FROM t;
+SELECT cast(1 as smallint) / cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as smallint) / cast(1 as string) FROM t;
+SELECT cast(1 as smallint) / cast('1' as binary) FROM t;
+SELECT cast(1 as smallint) / cast(1 as boolean) FROM t;
+SELECT cast(1 as smallint) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as smallint) / cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(1 as int) / cast(1 as tinyint) FROM t;
+SELECT cast(1 as int) / cast(1 as smallint) FROM t;
+SELECT cast(1 as int) / cast(1 as int) FROM t;
+SELECT cast(1 as int) / cast(1 as bigint) FROM t;
+SELECT cast(1 as int) / cast(1 as float) FROM t;
+SELECT cast(1 as int) / cast(1 as double) FROM t;
+SELECT cast(1 as int) / cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as int) / cast(1 as string) FROM t;
+SELECT cast(1 as int) / cast('1' as binary) FROM t;
+SELECT cast(1 as int) / cast(1 as boolean) FROM t;
+SELECT cast(1 as int) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as int) / cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(1 as bigint) / cast(1 as tinyint) FROM t;
+SELECT cast(1 as bigint) / cast(1 as smallint) FROM t;
+SELECT cast(1 as bigint) / cast(1 as int) FROM t;
+SELECT cast(1 as bigint) / cast(1 as bigint) FROM t;
+SELECT cast(1 as bigint) / cast(1 as float) FROM t;
+SELECT cast(1 as bigint) / cast(1 as double) FROM t;
+SELECT cast(1 as bigint) / cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as bigint) / cast(1 as string) FROM t;
+SELECT cast(1 as bigint) / cast('1' as binary) FROM t;
+SELECT cast(1 as bigint) / cast(1 as boolean) FROM t;
+SELECT cast(1 as bigint) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as bigint) / cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(1 as float) / cast(1 as tinyint) FROM t;
+SELECT cast(1 as float) / cast(1 as smallint) FROM t;
+SELECT cast(1 as float) / cast(1 as int) FROM t;
+SELECT cast(1 as float) / cast(1 as bigint) FROM t;
+SELECT cast(1 as float) / cast(1 as float) FROM t;
+SELECT cast(1 as float) / cast(1 as double) FROM t;
+SELECT cast(1 as float) / cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as float) / cast(1 as string) FROM t;
+SELECT cast(1 as float) / cast('1' as binary) FROM t;
+SELECT cast(1 as float) / cast(1 as boolean) FROM t;
+SELECT cast(1 as float) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as float) / cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(1 as double) / cast(1 as tinyint) FROM t;
+SELECT cast(1 as double) / cast(1 as smallint) FROM t;
+SELECT cast(1 as double) / cast(1 as int) FROM t;
+SELECT cast(1 as double) / cast(1 as bigint) FROM t;
+SELECT cast(1 as double) / cast(1 as float) FROM t;
+SELECT cast(1 as double) / cast(1 as double) FROM t;
+SELECT cast(1 as double) / cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as double) / cast(1 as string) FROM t;
+SELECT cast(1 as double) / cast('1' as binary) FROM t;
+SELECT cast(1 as double) / cast(1 as boolean) FROM t;
+SELECT cast(1 as double) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as double) / cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(1 as decimal(10, 0)) / cast(1 as tinyint) FROM t;
+SELECT cast(1 as decimal(10, 0)) / cast(1 as smallint) FROM t;
+SELECT cast(1 as decimal(10, 0)) / cast(1 as int) FROM t;
+SELECT cast(1 as decimal(10, 0)) / cast(1 as bigint) FROM t;
+SELECT cast(1 as decimal(10, 0)) / cast(1 as float) FROM t;
+SELECT cast(1 as decimal(10, 0)) / cast(1 as double) FROM t;
+SELECT cast(1 as decimal(10, 0)) / cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) / cast(1 as string) FROM t;
+SELECT cast(1 as decimal(10, 0)) / cast('1' as binary) FROM t;
+SELECT cast(1 as decimal(10, 0)) / cast(1 as boolean) FROM t;
+SELECT cast(1 as decimal(10, 0)) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(10, 0)) / cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(1 as string) / cast(1 as tinyint) FROM t;
+SELECT cast(1 as string) / cast(1 as smallint) FROM t;
+SELECT cast(1 as string) / cast(1 as int) FROM t;
+SELECT cast(1 as string) / cast(1 as bigint) FROM t;
+SELECT cast(1 as string) / cast(1 as float) FROM t;
+SELECT cast(1 as string) / cast(1 as double) FROM t;
+SELECT cast(1 as string) / cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as string) / cast(1 as string) FROM t;
+SELECT cast(1 as string) / cast('1' as binary) FROM t;
+SELECT cast(1 as string) / cast(1 as boolean) FROM t;
+SELECT cast(1 as string) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as string) / cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast('1' as binary) / cast(1 as tinyint) FROM t;
+SELECT cast('1' as binary) / cast(1 as smallint) FROM t;
+SELECT cast('1' as binary) / cast(1 as int) FROM t;
+SELECT cast('1' as binary) / cast(1 as bigint) FROM t;
+SELECT cast('1' as binary) / cast(1 as float) FROM t;
+SELECT cast('1' as binary) / cast(1 as double) FROM t;
+SELECT cast('1' as binary) / cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('1' as binary) / cast(1 as string) FROM t;
+SELECT cast('1' as binary) / cast('1' as binary) FROM t;
+SELECT cast('1' as binary) / cast(1 as boolean) FROM t;
+SELECT cast('1' as binary) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast('1' as binary) / cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(1 as boolean) / cast(1 as tinyint) FROM t;
+SELECT cast(1 as boolean) / cast(1 as smallint) FROM t;
+SELECT cast(1 as boolean) / cast(1 as int) FROM t;
+SELECT cast(1 as boolean) / cast(1 as bigint) FROM t;
+SELECT cast(1 as boolean) / cast(1 as float) FROM t;
+SELECT cast(1 as boolean) / cast(1 as double) FROM t;
+SELECT cast(1 as boolean) / cast(1 as decimal(10, 0)) FROM t;
+SELECT cast(1 as boolean) / cast(1 as string) FROM t;
+SELECT cast(1 as boolean) / cast('1' as binary) FROM t;
+SELECT cast(1 as boolean) / cast(1 as boolean) FROM t;
+SELECT cast(1 as boolean) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as boolean) / cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as tinyint) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as smallint) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as int) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as bigint) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as float) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as double) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as string) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast('1' as binary) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as boolean) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as tinyint) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as smallint) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as int) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as bigint) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as float) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as double) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as decimal(10, 0)) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as string) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) / cast('1' as binary) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as boolean) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) / cast('2017-12-11 09:30:00' as date) FROM t;
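(Aside, not part of the patch: division.sql runs the same type matrix through the `/` operator. In Spark SQL `/` is fractional division, so even `int / int` is expected to come back as a double; the sketch below, again assuming a hypothetical SparkSession `spark`, shows one way to confirm the result type.)

// Illustrative sketch only, not part of this patch.
val div = spark.sql("SELECT cast(1 as int) / cast(1 as int) AS res FROM range(1)")
println(div.schema("res").dataType)  // expected: DoubleType, since '/' performs fractional division
div.show()                           // expected value: 1.0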
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql
new file mode 100644
index 0000000000000..717616f91db05
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql
@@ -0,0 +1,44 @@
+-- Mixed inputs (output type is string)
+SELECT elt(2, col1, col2, col3, col4, col5) col
+FROM (
+ SELECT
+ 'prefix_' col1,
+ id col2,
+ string(id + 1) col3,
+ encode(string(id + 2), 'utf-8') col4,
+ CAST(id AS DOUBLE) col5
+ FROM range(10)
+);
+
+SELECT elt(3, col1, col2, col3, col4) col
+FROM (
+ SELECT
+ string(id) col1,
+ string(id + 1) col2,
+ encode(string(id + 2), 'utf-8') col3,
+ encode(string(id + 3), 'utf-8') col4
+ FROM range(10)
+);
+
+-- turn on eltOutputAsString
+set spark.sql.function.eltOutputAsString=true;
+
+SELECT elt(1, col1, col2) col
+FROM (
+ SELECT
+ encode(string(id), 'utf-8') col1,
+ encode(string(id + 1), 'utf-8') col2
+ FROM range(10)
+);
+
+-- turn off eltOutputAsString
+set spark.sql.function.eltOutputAsString=false;
+
+-- Elt binary inputs (output type is binary)
+SELECT elt(2, col1, col2) col
+FROM (
+ SELECT
+ encode(string(id), 'utf-8') col1,
+ encode(string(id + 1), 'utf-8') col2
+ FROM range(10)
+);
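(Aside, not part of the patch: elt.sql toggles `spark.sql.function.eltOutputAsString` to pin down when elt keeps a binary output and when it falls back to string. The Scala sketch below reproduces that check interactively; it assumes a hypothetical SparkSession `spark` and is illustrative only.)

// Sketch only: elt over all-binary inputs, with and without the flag.
spark.conf.set("spark.sql.function.eltOutputAsString", "false")
val asBinary = spark.sql("SELECT elt(2, encode('a', 'utf-8'), encode('b', 'utf-8')) AS col")
println(asBinary.schema("col").dataType)  // expected BinaryType while the flag is off

spark.conf.set("spark.sql.function.eltOutputAsString", "true")
val asString = spark.sql("SELECT elt(2, encode('a', 'utf-8'), encode('b', 'utf-8')) AS col")
println(asString.schema("col").dataType)  // expected StringType once the flag is on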
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/ifCoercion.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/ifCoercion.sql
new file mode 100644
index 0000000000000..42597f169daec
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/ifCoercion.sql
@@ -0,0 +1,174 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+CREATE TEMPORARY VIEW t AS SELECT 1;
+
+SELECT IF(true, cast(1 as tinyint), cast(2 as tinyint)) FROM t;
+SELECT IF(true, cast(1 as tinyint), cast(2 as smallint)) FROM t;
+SELECT IF(true, cast(1 as tinyint), cast(2 as int)) FROM t;
+SELECT IF(true, cast(1 as tinyint), cast(2 as bigint)) FROM t;
+SELECT IF(true, cast(1 as tinyint), cast(2 as float)) FROM t;
+SELECT IF(true, cast(1 as tinyint), cast(2 as double)) FROM t;
+SELECT IF(true, cast(1 as tinyint), cast(2 as decimal(10, 0))) FROM t;
+SELECT IF(true, cast(1 as tinyint), cast(2 as string)) FROM t;
+SELECT IF(true, cast(1 as tinyint), cast('2' as binary)) FROM t;
+SELECT IF(true, cast(1 as tinyint), cast(2 as boolean)) FROM t;
+SELECT IF(true, cast(1 as tinyint), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT IF(true, cast(1 as tinyint), cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT IF(true, cast(1 as smallint), cast(2 as tinyint)) FROM t;
+SELECT IF(true, cast(1 as smallint), cast(2 as smallint)) FROM t;
+SELECT IF(true, cast(1 as smallint), cast(2 as int)) FROM t;
+SELECT IF(true, cast(1 as smallint), cast(2 as bigint)) FROM t;
+SELECT IF(true, cast(1 as smallint), cast(2 as float)) FROM t;
+SELECT IF(true, cast(1 as smallint), cast(2 as double)) FROM t;
+SELECT IF(true, cast(1 as smallint), cast(2 as decimal(10, 0))) FROM t;
+SELECT IF(true, cast(1 as smallint), cast(2 as string)) FROM t;
+SELECT IF(true, cast(1 as smallint), cast('2' as binary)) FROM t;
+SELECT IF(true, cast(1 as smallint), cast(2 as boolean)) FROM t;
+SELECT IF(true, cast(1 as smallint), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT IF(true, cast(1 as smallint), cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT IF(true, cast(1 as int), cast(2 as tinyint)) FROM t;
+SELECT IF(true, cast(1 as int), cast(2 as smallint)) FROM t;
+SELECT IF(true, cast(1 as int), cast(2 as int)) FROM t;
+SELECT IF(true, cast(1 as int), cast(2 as bigint)) FROM t;
+SELECT IF(true, cast(1 as int), cast(2 as float)) FROM t;
+SELECT IF(true, cast(1 as int), cast(2 as double)) FROM t;
+SELECT IF(true, cast(1 as int), cast(2 as decimal(10, 0))) FROM t;
+SELECT IF(true, cast(1 as int), cast(2 as string)) FROM t;
+SELECT IF(true, cast(1 as int), cast('2' as binary)) FROM t;
+SELECT IF(true, cast(1 as int), cast(2 as boolean)) FROM t;
+SELECT IF(true, cast(1 as int), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT IF(true, cast(1 as int), cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT IF(true, cast(1 as bigint), cast(2 as tinyint)) FROM t;
+SELECT IF(true, cast(1 as bigint), cast(2 as smallint)) FROM t;
+SELECT IF(true, cast(1 as bigint), cast(2 as int)) FROM t;
+SELECT IF(true, cast(1 as bigint), cast(2 as bigint)) FROM t;
+SELECT IF(true, cast(1 as bigint), cast(2 as float)) FROM t;
+SELECT IF(true, cast(1 as bigint), cast(2 as double)) FROM t;
+SELECT IF(true, cast(1 as bigint), cast(2 as decimal(10, 0))) FROM t;
+SELECT IF(true, cast(1 as bigint), cast(2 as string)) FROM t;
+SELECT IF(true, cast(1 as bigint), cast('2' as binary)) FROM t;
+SELECT IF(true, cast(1 as bigint), cast(2 as boolean)) FROM t;
+SELECT IF(true, cast(1 as bigint), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT IF(true, cast(1 as bigint), cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT IF(true, cast(1 as float), cast(2 as tinyint)) FROM t;
+SELECT IF(true, cast(1 as float), cast(2 as smallint)) FROM t;
+SELECT IF(true, cast(1 as float), cast(2 as int)) FROM t;
+SELECT IF(true, cast(1 as float), cast(2 as bigint)) FROM t;
+SELECT IF(true, cast(1 as float), cast(2 as float)) FROM t;
+SELECT IF(true, cast(1 as float), cast(2 as double)) FROM t;
+SELECT IF(true, cast(1 as float), cast(2 as decimal(10, 0))) FROM t;
+SELECT IF(true, cast(1 as float), cast(2 as string)) FROM t;
+SELECT IF(true, cast(1 as float), cast('2' as binary)) FROM t;
+SELECT IF(true, cast(1 as float), cast(2 as boolean)) FROM t;
+SELECT IF(true, cast(1 as float), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT IF(true, cast(1 as float), cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT IF(true, cast(1 as double), cast(2 as tinyint)) FROM t;
+SELECT IF(true, cast(1 as double), cast(2 as smallint)) FROM t;
+SELECT IF(true, cast(1 as double), cast(2 as int)) FROM t;
+SELECT IF(true, cast(1 as double), cast(2 as bigint)) FROM t;
+SELECT IF(true, cast(1 as double), cast(2 as float)) FROM t;
+SELECT IF(true, cast(1 as double), cast(2 as double)) FROM t;
+SELECT IF(true, cast(1 as double), cast(2 as decimal(10, 0))) FROM t;
+SELECT IF(true, cast(1 as double), cast(2 as string)) FROM t;
+SELECT IF(true, cast(1 as double), cast('2' as binary)) FROM t;
+SELECT IF(true, cast(1 as double), cast(2 as boolean)) FROM t;
+SELECT IF(true, cast(1 as double), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT IF(true, cast(1 as double), cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as tinyint)) FROM t;
+SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as smallint)) FROM t;
+SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as int)) FROM t;
+SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as bigint)) FROM t;
+SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as float)) FROM t;
+SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as double)) FROM t;
+SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as decimal(10, 0))) FROM t;
+SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as string)) FROM t;
+SELECT IF(true, cast(1 as decimal(10, 0)), cast('2' as binary)) FROM t;
+SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as boolean)) FROM t;
+SELECT IF(true, cast(1 as decimal(10, 0)), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT IF(true, cast(1 as decimal(10, 0)), cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT IF(true, cast(1 as string), cast(2 as tinyint)) FROM t;
+SELECT IF(true, cast(1 as string), cast(2 as smallint)) FROM t;
+SELECT IF(true, cast(1 as string), cast(2 as int)) FROM t;
+SELECT IF(true, cast(1 as string), cast(2 as bigint)) FROM t;
+SELECT IF(true, cast(1 as string), cast(2 as float)) FROM t;
+SELECT IF(true, cast(1 as string), cast(2 as double)) FROM t;
+SELECT IF(true, cast(1 as string), cast(2 as decimal(10, 0))) FROM t;
+SELECT IF(true, cast(1 as string), cast(2 as string)) FROM t;
+SELECT IF(true, cast(1 as string), cast('2' as binary)) FROM t;
+SELECT IF(true, cast(1 as string), cast(2 as boolean)) FROM t;
+SELECT IF(true, cast(1 as string), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT IF(true, cast(1 as string), cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT IF(true, cast('1' as binary), cast(2 as tinyint)) FROM t;
+SELECT IF(true, cast('1' as binary), cast(2 as smallint)) FROM t;
+SELECT IF(true, cast('1' as binary), cast(2 as int)) FROM t;
+SELECT IF(true, cast('1' as binary), cast(2 as bigint)) FROM t;
+SELECT IF(true, cast('1' as binary), cast(2 as float)) FROM t;
+SELECT IF(true, cast('1' as binary), cast(2 as double)) FROM t;
+SELECT IF(true, cast('1' as binary), cast(2 as decimal(10, 0))) FROM t;
+SELECT IF(true, cast('1' as binary), cast(2 as string)) FROM t;
+SELECT IF(true, cast('1' as binary), cast('2' as binary)) FROM t;
+SELECT IF(true, cast('1' as binary), cast(2 as boolean)) FROM t;
+SELECT IF(true, cast('1' as binary), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT IF(true, cast('1' as binary), cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT IF(true, cast(1 as boolean), cast(2 as tinyint)) FROM t;
+SELECT IF(true, cast(1 as boolean), cast(2 as smallint)) FROM t;
+SELECT IF(true, cast(1 as boolean), cast(2 as int)) FROM t;
+SELECT IF(true, cast(1 as boolean), cast(2 as bigint)) FROM t;
+SELECT IF(true, cast(1 as boolean), cast(2 as float)) FROM t;
+SELECT IF(true, cast(1 as boolean), cast(2 as double)) FROM t;
+SELECT IF(true, cast(1 as boolean), cast(2 as decimal(10, 0))) FROM t;
+SELECT IF(true, cast(1 as boolean), cast(2 as string)) FROM t;
+SELECT IF(true, cast(1 as boolean), cast('2' as binary)) FROM t;
+SELECT IF(true, cast(1 as boolean), cast(2 as boolean)) FROM t;
+SELECT IF(true, cast(1 as boolean), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT IF(true, cast(1 as boolean), cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as tinyint)) FROM t;
+SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as smallint)) FROM t;
+SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as int)) FROM t;
+SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as bigint)) FROM t;
+SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as float)) FROM t;
+SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as double)) FROM t;
+SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as decimal(10, 0))) FROM t;
+SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as string)) FROM t;
+SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast('2' as binary)) FROM t;
+SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as boolean)) FROM t;
+SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as tinyint)) FROM t;
+SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as smallint)) FROM t;
+SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as int)) FROM t;
+SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as bigint)) FROM t;
+SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as float)) FROM t;
+SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as double)) FROM t;
+SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as decimal(10, 0))) FROM t;
+SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as string)) FROM t;
+SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast('2' as binary)) FROM t;
+SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as boolean)) FROM t;
+SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast('2017-12-11 09:30:00' as date)) FROM t;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/implicitTypeCasts.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/implicitTypeCasts.sql
new file mode 100644
index 0000000000000..6de22b8b7c3de
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/implicitTypeCasts.sql
@@ -0,0 +1,72 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+-- ImplicitTypeCasts
+
+CREATE TEMPORARY VIEW t AS SELECT 1;
+
+SELECT 1 + '2' FROM t;
+SELECT 1 - '2' FROM t;
+SELECT 1 * '2' FROM t;
+SELECT 4 / '2' FROM t;
+SELECT 1.1 + '2' FROM t;
+SELECT 1.1 - '2' FROM t;
+SELECT 1.1 * '2' FROM t;
+SELECT 4.4 / '2' FROM t;
+SELECT 1.1 + '2.2' FROM t;
+SELECT 1.1 - '2.2' FROM t;
+SELECT 1.1 * '2.2' FROM t;
+SELECT 4.4 / '2.2' FROM t;
+
+-- concatenation
+SELECT '$' || cast(1 as smallint) || '$' FROM t;
+SELECT '$' || 1 || '$' FROM t;
+SELECT '$' || cast(1 as bigint) || '$' FROM t;
+SELECT '$' || cast(1.1 as float) || '$' FROM t;
+SELECT '$' || cast(1.1 as double) || '$' FROM t;
+SELECT '$' || 1.1 || '$' FROM t;
+SELECT '$' || cast(1.1 as decimal(8,3)) || '$' FROM t;
+SELECT '$' || 'abcd' || '$' FROM t;
+SELECT '$' || date('1996-09-09') || '$' FROM t;
+SELECT '$' || timestamp('1996-09-09 10:11:12.4') || '$' FROM t;
+
+-- length functions
+SELECT length(cast(1 as smallint)) FROM t;
+SELECT length(cast(1 as int)) FROM t;
+SELECT length(cast(1 as bigint)) FROM t;
+SELECT length(cast(1.1 as float)) FROM t;
+SELECT length(cast(1.1 as double)) FROM t;
+SELECT length(1.1) FROM t;
+SELECT length(cast(1.1 as decimal(8,3))) FROM t;
+SELECT length('four') FROM t;
+SELECT length(date('1996-09-10')) FROM t;
+SELECT length(timestamp('1996-09-10 10:11:12.4')) FROM t;
+
+-- extract
+SELECT year('1996-01-10') FROM t;
+SELECT month('1996-01-10') FROM t;
+SELECT day('1996-01-10') FROM t;
+SELECT hour('10:11:12') FROM t;
+SELECT minute('10:11:12') FROM t;
+SELECT second('10:11:12') FROM t;
+
+-- like
+SELECT 1 like '%' FROM t;
+SELECT date('1996-09-10') like '19%' FROM t;
+SELECT '1' like 1 FROM t;
+SELECT '1 ' like 1 FROM t;
+SELECT '1996-09-10' like date('1996-09-10') FROM t;
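(Aside, not part of the patch: implicitTypeCasts.sql mixes string literals into arithmetic, concatenation, length, extract, and LIKE. For the arithmetic cases the string operand is expected to be implicitly cast to double before the operator runs; the sketch below, assuming a hypothetical SparkSession `spark`, shows how that surfaces in the result schema.)

// Sketch only, not part of this patch: how 1 + '2' ends up typed after implicit casts.
val sum = spark.sql("SELECT 1 + '2' AS col")
println(sum.schema("col").dataType)  // expected DoubleType: the string side is widened to double
sum.show()                           // expected value: 3.0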
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/inConversion.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/inConversion.sql
new file mode 100644
index 0000000000000..39dbe7268fba0
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/inConversion.sql
@@ -0,0 +1,330 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+CREATE TEMPORARY VIEW t AS SELECT 1;
+
+SELECT cast(1 as tinyint) in (cast(1 as tinyint)) FROM t;
+SELECT cast(1 as tinyint) in (cast(1 as smallint)) FROM t;
+SELECT cast(1 as tinyint) in (cast(1 as int)) FROM t;
+SELECT cast(1 as tinyint) in (cast(1 as bigint)) FROM t;
+SELECT cast(1 as tinyint) in (cast(1 as float)) FROM t;
+SELECT cast(1 as tinyint) in (cast(1 as double)) FROM t;
+SELECT cast(1 as tinyint) in (cast(1 as decimal(10, 0))) FROM t;
+SELECT cast(1 as tinyint) in (cast(1 as string)) FROM t;
+SELECT cast(1 as tinyint) in (cast('1' as binary)) FROM t;
+SELECT cast(1 as tinyint) in (cast(1 as boolean)) FROM t;
+SELECT cast(1 as tinyint) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT cast(1 as tinyint) in (cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT cast(1 as smallint) in (cast(1 as tinyint)) FROM t;
+SELECT cast(1 as smallint) in (cast(1 as smallint)) FROM t;
+SELECT cast(1 as smallint) in (cast(1 as int)) FROM t;
+SELECT cast(1 as smallint) in (cast(1 as bigint)) FROM t;
+SELECT cast(1 as smallint) in (cast(1 as float)) FROM t;
+SELECT cast(1 as smallint) in (cast(1 as double)) FROM t;
+SELECT cast(1 as smallint) in (cast(1 as decimal(10, 0))) FROM t;
+SELECT cast(1 as smallint) in (cast(1 as string)) FROM t;
+SELECT cast(1 as smallint) in (cast('1' as binary)) FROM t;
+SELECT cast(1 as smallint) in (cast(1 as boolean)) FROM t;
+SELECT cast(1 as smallint) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT cast(1 as smallint) in (cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT cast(1 as int) in (cast(1 as tinyint)) FROM t;
+SELECT cast(1 as int) in (cast(1 as smallint)) FROM t;
+SELECT cast(1 as int) in (cast(1 as int)) FROM t;
+SELECT cast(1 as int) in (cast(1 as bigint)) FROM t;
+SELECT cast(1 as int) in (cast(1 as float)) FROM t;
+SELECT cast(1 as int) in (cast(1 as double)) FROM t;
+SELECT cast(1 as int) in (cast(1 as decimal(10, 0))) FROM t;
+SELECT cast(1 as int) in (cast(1 as string)) FROM t;
+SELECT cast(1 as int) in (cast('1' as binary)) FROM t;
+SELECT cast(1 as int) in (cast(1 as boolean)) FROM t;
+SELECT cast(1 as int) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT cast(1 as int) in (cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT cast(1 as bigint) in (cast(1 as tinyint)) FROM t;
+SELECT cast(1 as bigint) in (cast(1 as smallint)) FROM t;
+SELECT cast(1 as bigint) in (cast(1 as int)) FROM t;
+SELECT cast(1 as bigint) in (cast(1 as bigint)) FROM t;
+SELECT cast(1 as bigint) in (cast(1 as float)) FROM t;
+SELECT cast(1 as bigint) in (cast(1 as double)) FROM t;
+SELECT cast(1 as bigint) in (cast(1 as decimal(10, 0))) FROM t;
+SELECT cast(1 as bigint) in (cast(1 as string)) FROM t;
+SELECT cast(1 as bigint) in (cast('1' as binary)) FROM t;
+SELECT cast(1 as bigint) in (cast(1 as boolean)) FROM t;
+SELECT cast(1 as bigint) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT cast(1 as bigint) in (cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT cast(1 as float) in (cast(1 as tinyint)) FROM t;
+SELECT cast(1 as float) in (cast(1 as smallint)) FROM t;
+SELECT cast(1 as float) in (cast(1 as int)) FROM t;
+SELECT cast(1 as float) in (cast(1 as bigint)) FROM t;
+SELECT cast(1 as float) in (cast(1 as float)) FROM t;
+SELECT cast(1 as float) in (cast(1 as double)) FROM t;
+SELECT cast(1 as float) in (cast(1 as decimal(10, 0))) FROM t;
+SELECT cast(1 as float) in (cast(1 as string)) FROM t;
+SELECT cast(1 as float) in (cast('1' as binary)) FROM t;
+SELECT cast(1 as float) in (cast(1 as boolean)) FROM t;
+SELECT cast(1 as float) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT cast(1 as float) in (cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT cast(1 as double) in (cast(1 as tinyint)) FROM t;
+SELECT cast(1 as double) in (cast(1 as smallint)) FROM t;
+SELECT cast(1 as double) in (cast(1 as int)) FROM t;
+SELECT cast(1 as double) in (cast(1 as bigint)) FROM t;
+SELECT cast(1 as double) in (cast(1 as float)) FROM t;
+SELECT cast(1 as double) in (cast(1 as double)) FROM t;
+SELECT cast(1 as double) in (cast(1 as decimal(10, 0))) FROM t;
+SELECT cast(1 as double) in (cast(1 as string)) FROM t;
+SELECT cast(1 as double) in (cast('1' as binary)) FROM t;
+SELECT cast(1 as double) in (cast(1 as boolean)) FROM t;
+SELECT cast(1 as double) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT cast(1 as double) in (cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT cast(1 as decimal(10, 0)) in (cast(1 as tinyint)) FROM t;
+SELECT cast(1 as decimal(10, 0)) in (cast(1 as smallint)) FROM t;
+SELECT cast(1 as decimal(10, 0)) in (cast(1 as int)) FROM t;
+SELECT cast(1 as decimal(10, 0)) in (cast(1 as bigint)) FROM t;
+SELECT cast(1 as decimal(10, 0)) in (cast(1 as float)) FROM t;
+SELECT cast(1 as decimal(10, 0)) in (cast(1 as double)) FROM t;
+SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0))) FROM t;
+SELECT cast(1 as decimal(10, 0)) in (cast(1 as string)) FROM t;
+SELECT cast(1 as decimal(10, 0)) in (cast('1' as binary)) FROM t;
+SELECT cast(1 as decimal(10, 0)) in (cast(1 as boolean)) FROM t;
+SELECT cast(1 as decimal(10, 0)) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT cast(1 as decimal(10, 0)) in (cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT cast(1 as string) in (cast(1 as tinyint)) FROM t;
+SELECT cast(1 as string) in (cast(1 as smallint)) FROM t;
+SELECT cast(1 as string) in (cast(1 as int)) FROM t;
+SELECT cast(1 as string) in (cast(1 as bigint)) FROM t;
+SELECT cast(1 as string) in (cast(1 as float)) FROM t;
+SELECT cast(1 as string) in (cast(1 as double)) FROM t;
+SELECT cast(1 as string) in (cast(1 as decimal(10, 0))) FROM t;
+SELECT cast(1 as string) in (cast(1 as string)) FROM t;
+SELECT cast(1 as string) in (cast('1' as binary)) FROM t;
+SELECT cast(1 as string) in (cast(1 as boolean)) FROM t;
+SELECT cast(1 as string) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT cast(1 as string) in (cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT cast('1' as binary) in (cast(1 as tinyint)) FROM t;
+SELECT cast('1' as binary) in (cast(1 as smallint)) FROM t;
+SELECT cast('1' as binary) in (cast(1 as int)) FROM t;
+SELECT cast('1' as binary) in (cast(1 as bigint)) FROM t;
+SELECT cast('1' as binary) in (cast(1 as float)) FROM t;
+SELECT cast('1' as binary) in (cast(1 as double)) FROM t;
+SELECT cast('1' as binary) in (cast(1 as decimal(10, 0))) FROM t;
+SELECT cast('1' as binary) in (cast(1 as string)) FROM t;
+SELECT cast('1' as binary) in (cast('1' as binary)) FROM t;
+SELECT cast('1' as binary) in (cast(1 as boolean)) FROM t;
+SELECT cast('1' as binary) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT cast('1' as binary) in (cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT true in (cast(1 as tinyint)) FROM t;
+SELECT true in (cast(1 as smallint)) FROM t;
+SELECT true in (cast(1 as int)) FROM t;
+SELECT true in (cast(1 as bigint)) FROM t;
+SELECT true in (cast(1 as float)) FROM t;
+SELECT true in (cast(1 as double)) FROM t;
+SELECT true in (cast(1 as decimal(10, 0))) FROM t;
+SELECT true in (cast(1 as string)) FROM t;
+SELECT true in (cast('1' as binary)) FROM t;
+SELECT true in (cast(1 as boolean)) FROM t;
+SELECT true in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT true in (cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as tinyint)) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as smallint)) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as int)) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as bigint)) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as float)) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as double)) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as decimal(10, 0))) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as string)) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2' as binary)) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as boolean)) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as tinyint)) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as smallint)) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as int)) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as bigint)) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as float)) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as double)) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as decimal(10, 0))) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as string)) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) in (cast('2' as binary)) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as boolean)) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as tinyint)) FROM t;
+SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as smallint)) FROM t;
+SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as int)) FROM t;
+SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as bigint)) FROM t;
+SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as float)) FROM t;
+SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as double)) FROM t;
+SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as decimal(10, 0))) FROM t;
+SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as string)) FROM t;
+SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast('1' as binary)) FROM t;
+SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as boolean)) FROM t;
+SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as tinyint)) FROM t;
+SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as smallint)) FROM t;
+SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as int)) FROM t;
+SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as bigint)) FROM t;
+SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as float)) FROM t;
+SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as double)) FROM t;
+SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as decimal(10, 0))) FROM t;
+SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as string)) FROM t;
+SELECT cast(1 as smallint) in (cast(1 as smallint), cast('1' as binary)) FROM t;
+SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as boolean)) FROM t;
+SELECT cast(1 as smallint) in (cast(1 as smallint), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT cast(1 as smallint) in (cast(1 as smallint), cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT cast(1 as int) in (cast(1 as int), cast(1 as tinyint)) FROM t;
+SELECT cast(1 as int) in (cast(1 as int), cast(1 as smallint)) FROM t;
+SELECT cast(1 as int) in (cast(1 as int), cast(1 as int)) FROM t;
+SELECT cast(1 as int) in (cast(1 as int), cast(1 as bigint)) FROM t;
+SELECT cast(1 as int) in (cast(1 as int), cast(1 as float)) FROM t;
+SELECT cast(1 as int) in (cast(1 as int), cast(1 as double)) FROM t;
+SELECT cast(1 as int) in (cast(1 as int), cast(1 as decimal(10, 0))) FROM t;
+SELECT cast(1 as int) in (cast(1 as int), cast(1 as string)) FROM t;
+SELECT cast(1 as int) in (cast(1 as int), cast('1' as binary)) FROM t;
+SELECT cast(1 as int) in (cast(1 as int), cast(1 as boolean)) FROM t;
+SELECT cast(1 as int) in (cast(1 as int), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT cast(1 as int) in (cast(1 as int), cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as tinyint)) FROM t;
+SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as smallint)) FROM t;
+SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as int)) FROM t;
+SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as bigint)) FROM t;
+SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as float)) FROM t;
+SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as double)) FROM t;
+SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as decimal(10, 0))) FROM t;
+SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as string)) FROM t;
+SELECT cast(1 as bigint) in (cast(1 as bigint), cast('1' as binary)) FROM t;
+SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as boolean)) FROM t;
+SELECT cast(1 as bigint) in (cast(1 as bigint), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT cast(1 as bigint) in (cast(1 as bigint), cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT cast(1 as float) in (cast(1 as float), cast(1 as tinyint)) FROM t;
+SELECT cast(1 as float) in (cast(1 as float), cast(1 as smallint)) FROM t;
+SELECT cast(1 as float) in (cast(1 as float), cast(1 as int)) FROM t;
+SELECT cast(1 as float) in (cast(1 as float), cast(1 as bigint)) FROM t;
+SELECT cast(1 as float) in (cast(1 as float), cast(1 as float)) FROM t;
+SELECT cast(1 as float) in (cast(1 as float), cast(1 as double)) FROM t;
+SELECT cast(1 as float) in (cast(1 as float), cast(1 as decimal(10, 0))) FROM t;
+SELECT cast(1 as float) in (cast(1 as float), cast(1 as string)) FROM t;
+SELECT cast(1 as float) in (cast(1 as float), cast('1' as binary)) FROM t;
+SELECT cast(1 as float) in (cast(1 as float), cast(1 as boolean)) FROM t;
+SELECT cast(1 as float) in (cast(1 as float), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT cast(1 as float) in (cast(1 as float), cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT cast(1 as double) in (cast(1 as double), cast(1 as tinyint)) FROM t;
+SELECT cast(1 as double) in (cast(1 as double), cast(1 as smallint)) FROM t;
+SELECT cast(1 as double) in (cast(1 as double), cast(1 as int)) FROM t;
+SELECT cast(1 as double) in (cast(1 as double), cast(1 as bigint)) FROM t;
+SELECT cast(1 as double) in (cast(1 as double), cast(1 as float)) FROM t;
+SELECT cast(1 as double) in (cast(1 as double), cast(1 as double)) FROM t;
+SELECT cast(1 as double) in (cast(1 as double), cast(1 as decimal(10, 0))) FROM t;
+SELECT cast(1 as double) in (cast(1 as double), cast(1 as string)) FROM t;
+SELECT cast(1 as double) in (cast(1 as double), cast('1' as binary)) FROM t;
+SELECT cast(1 as double) in (cast(1 as double), cast(1 as boolean)) FROM t;
+SELECT cast(1 as double) in (cast(1 as double), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT cast(1 as double) in (cast(1 as double), cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as tinyint)) FROM t;
+SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as smallint)) FROM t;
+SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as int)) FROM t;
+SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as bigint)) FROM t;
+SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as float)) FROM t;
+SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as double)) FROM t;
+SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as decimal(10, 0))) FROM t;
+SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as string)) FROM t;
+SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast('1' as binary)) FROM t;
+SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as boolean)) FROM t;
+SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT cast(1 as string) in (cast(1 as string), cast(1 as tinyint)) FROM t;
+SELECT cast(1 as string) in (cast(1 as string), cast(1 as smallint)) FROM t;
+SELECT cast(1 as string) in (cast(1 as string), cast(1 as int)) FROM t;
+SELECT cast(1 as string) in (cast(1 as string), cast(1 as bigint)) FROM t;
+SELECT cast(1 as string) in (cast(1 as string), cast(1 as float)) FROM t;
+SELECT cast(1 as string) in (cast(1 as string), cast(1 as double)) FROM t;
+SELECT cast(1 as string) in (cast(1 as string), cast(1 as decimal(10, 0))) FROM t;
+SELECT cast(1 as string) in (cast(1 as string), cast(1 as string)) FROM t;
+SELECT cast(1 as string) in (cast(1 as string), cast('1' as binary)) FROM t;
+SELECT cast(1 as string) in (cast(1 as string), cast(1 as boolean)) FROM t;
+SELECT cast(1 as string) in (cast(1 as string), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT cast(1 as string) in (cast(1 as string), cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as tinyint)) FROM t;
+SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as smallint)) FROM t;
+SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as int)) FROM t;
+SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as bigint)) FROM t;
+SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as float)) FROM t;
+SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as double)) FROM t;
+SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as decimal(10, 0))) FROM t;
+SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as string)) FROM t;
+SELECT cast('1' as binary) in (cast('1' as binary), cast('1' as binary)) FROM t;
+SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as boolean)) FROM t;
+SELECT cast('1' as binary) in (cast('1' as binary), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT cast('1' as binary) in (cast('1' as binary), cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as tinyint)) FROM t;
+SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as smallint)) FROM t;
+SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as int)) FROM t;
+SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as bigint)) FROM t;
+SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as float)) FROM t;
+SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as double)) FROM t;
+SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as decimal(10, 0))) FROM t;
+SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as string)) FROM t;
+SELECT cast('1' as boolean) in (cast('1' as boolean), cast('1' as binary)) FROM t;
+SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as boolean)) FROM t;
+SELECT cast('1' as boolean) in (cast('1' as boolean), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT cast('1' as boolean) in (cast('1' as boolean), cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast(1 as tinyint)) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast(1 as smallint)) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast(1 as int)) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast(1 as bigint)) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast(1 as float)) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast(1 as double)) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast(1 as decimal(10, 0))) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast(1 as string)) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast('1' as binary)) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast(1 as boolean)) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast(1 as tinyint)) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast(1 as smallint)) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast(1 as int)) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast(1 as bigint)) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast(1 as float)) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast(1 as double)) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast(1 as decimal(10, 0))) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast(1 as string)) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast('1' as binary)) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast(1 as boolean)) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast('2017-12-11 09:30:00' as date)) FROM t;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/promoteStrings.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/promoteStrings.sql
new file mode 100644
index 0000000000000..a5603a184578d
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/promoteStrings.sql
@@ -0,0 +1,364 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+CREATE TEMPORARY VIEW t AS SELECT 1;
+
+-- Binary arithmetic
+SELECT '1' + cast(1 as tinyint) FROM t;
+SELECT '1' + cast(1 as smallint) FROM t;
+SELECT '1' + cast(1 as int) FROM t;
+SELECT '1' + cast(1 as bigint) FROM t;
+SELECT '1' + cast(1 as float) FROM t;
+SELECT '1' + cast(1 as double) FROM t;
+SELECT '1' + cast(1 as decimal(10, 0)) FROM t;
+SELECT '1' + '1' FROM t;
+SELECT '1' + cast('1' as binary) FROM t;
+SELECT '1' + cast(1 as boolean) FROM t;
+SELECT '1' + cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT '1' + cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT '1' - cast(1 as tinyint) FROM t;
+SELECT '1' - cast(1 as smallint) FROM t;
+SELECT '1' - cast(1 as int) FROM t;
+SELECT '1' - cast(1 as bigint) FROM t;
+SELECT '1' - cast(1 as float) FROM t;
+SELECT '1' - cast(1 as double) FROM t;
+SELECT '1' - cast(1 as decimal(10, 0)) FROM t;
+SELECT '1' - '1' FROM t;
+SELECT '1' - cast('1' as binary) FROM t;
+SELECT '1' - cast(1 as boolean) FROM t;
+SELECT '1' - cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT '1' - cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT '1' * cast(1 as tinyint) FROM t;
+SELECT '1' * cast(1 as smallint) FROM t;
+SELECT '1' * cast(1 as int) FROM t;
+SELECT '1' * cast(1 as bigint) FROM t;
+SELECT '1' * cast(1 as float) FROM t;
+SELECT '1' * cast(1 as double) FROM t;
+SELECT '1' * cast(1 as decimal(10, 0)) FROM t;
+SELECT '1' * '1' FROM t;
+SELECT '1' * cast('1' as binary) FROM t;
+SELECT '1' * cast(1 as boolean) FROM t;
+SELECT '1' * cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT '1' * cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT '1' / cast(1 as tinyint) FROM t;
+SELECT '1' / cast(1 as smallint) FROM t;
+SELECT '1' / cast(1 as int) FROM t;
+SELECT '1' / cast(1 as bigint) FROM t;
+SELECT '1' / cast(1 as float) FROM t;
+SELECT '1' / cast(1 as double) FROM t;
+SELECT '1' / cast(1 as decimal(10, 0)) FROM t;
+SELECT '1' / '1' FROM t;
+SELECT '1' / cast('1' as binary) FROM t;
+SELECT '1' / cast(1 as boolean) FROM t;
+SELECT '1' / cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT '1' / cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT '1' % cast(1 as tinyint) FROM t;
+SELECT '1' % cast(1 as smallint) FROM t;
+SELECT '1' % cast(1 as int) FROM t;
+SELECT '1' % cast(1 as bigint) FROM t;
+SELECT '1' % cast(1 as float) FROM t;
+SELECT '1' % cast(1 as double) FROM t;
+SELECT '1' % cast(1 as decimal(10, 0)) FROM t;
+SELECT '1' % '1' FROM t;
+SELECT '1' % cast('1' as binary) FROM t;
+SELECT '1' % cast(1 as boolean) FROM t;
+SELECT '1' % cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT '1' % cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT pmod('1', cast(1 as tinyint)) FROM t;
+SELECT pmod('1', cast(1 as smallint)) FROM t;
+SELECT pmod('1', cast(1 as int)) FROM t;
+SELECT pmod('1', cast(1 as bigint)) FROM t;
+SELECT pmod('1', cast(1 as float)) FROM t;
+SELECT pmod('1', cast(1 as double)) FROM t;
+SELECT pmod('1', cast(1 as decimal(10, 0))) FROM t;
+SELECT pmod('1', '1') FROM t;
+SELECT pmod('1', cast('1' as binary)) FROM t;
+SELECT pmod('1', cast(1 as boolean)) FROM t;
+SELECT pmod('1', cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT pmod('1', cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT cast(1 as tinyint) + '1' FROM t;
+SELECT cast(1 as smallint) + '1' FROM t;
+SELECT cast(1 as int) + '1' FROM t;
+SELECT cast(1 as bigint) + '1' FROM t;
+SELECT cast(1 as float) + '1' FROM t;
+SELECT cast(1 as double) + '1' FROM t;
+SELECT cast(1 as decimal(10, 0)) + '1' FROM t;
+SELECT cast('1' as binary) + '1' FROM t;
+SELECT cast(1 as boolean) + '1' FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) + '1' FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) + '1' FROM t;
+
+SELECT cast(1 as tinyint) - '1' FROM t;
+SELECT cast(1 as smallint) - '1' FROM t;
+SELECT cast(1 as int) - '1' FROM t;
+SELECT cast(1 as bigint) - '1' FROM t;
+SELECT cast(1 as float) - '1' FROM t;
+SELECT cast(1 as double) - '1' FROM t;
+SELECT cast(1 as decimal(10, 0)) - '1' FROM t;
+SELECT cast('1' as binary) - '1' FROM t;
+SELECT cast(1 as boolean) - '1' FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) - '1' FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) - '1' FROM t;
+
+SELECT cast(1 as tinyint) * '1' FROM t;
+SELECT cast(1 as smallint) * '1' FROM t;
+SELECT cast(1 as int) * '1' FROM t;
+SELECT cast(1 as bigint) * '1' FROM t;
+SELECT cast(1 as float) * '1' FROM t;
+SELECT cast(1 as double) * '1' FROM t;
+SELECT cast(1 as decimal(10, 0)) * '1' FROM t;
+SELECT cast('1' as binary) * '1' FROM t;
+SELECT cast(1 as boolean) * '1' FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) * '1' FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) * '1' FROM t;
+
+SELECT cast(1 as tinyint) / '1' FROM t;
+SELECT cast(1 as smallint) / '1' FROM t;
+SELECT cast(1 as int) / '1' FROM t;
+SELECT cast(1 as bigint) / '1' FROM t;
+SELECT cast(1 as float) / '1' FROM t;
+SELECT cast(1 as double) / '1' FROM t;
+SELECT cast(1 as decimal(10, 0)) / '1' FROM t;
+SELECT cast('1' as binary) / '1' FROM t;
+SELECT cast(1 as boolean) / '1' FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) / '1' FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) / '1' FROM t;
+
+SELECT cast(1 as tinyint) % '1' FROM t;
+SELECT cast(1 as smallint) % '1' FROM t;
+SELECT cast(1 as int) % '1' FROM t;
+SELECT cast(1 as bigint) % '1' FROM t;
+SELECT cast(1 as float) % '1' FROM t;
+SELECT cast(1 as double) % '1' FROM t;
+SELECT cast(1 as decimal(10, 0)) % '1' FROM t;
+SELECT cast('1' as binary) % '1' FROM t;
+SELECT cast(1 as boolean) % '1' FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) % '1' FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) % '1' FROM t;
+
+SELECT pmod(cast(1 as tinyint), '1') FROM t;
+SELECT pmod(cast(1 as smallint), '1') FROM t;
+SELECT pmod(cast(1 as int), '1') FROM t;
+SELECT pmod(cast(1 as bigint), '1') FROM t;
+SELECT pmod(cast(1 as float), '1') FROM t;
+SELECT pmod(cast(1 as double), '1') FROM t;
+SELECT pmod(cast(1 as decimal(10, 0)), '1') FROM t;
+SELECT pmod(cast('1' as binary), '1') FROM t;
+SELECT pmod(cast(1 as boolean), '1') FROM t;
+SELECT pmod(cast('2017-12-11 09:30:00.0' as timestamp), '1') FROM t;
+SELECT pmod(cast('2017-12-11 09:30:00' as date), '1') FROM t;
+
+-- Equality
+SELECT '1' = cast(1 as tinyint) FROM t;
+SELECT '1' = cast(1 as smallint) FROM t;
+SELECT '1' = cast(1 as int) FROM t;
+SELECT '1' = cast(1 as bigint) FROM t;
+SELECT '1' = cast(1 as float) FROM t;
+SELECT '1' = cast(1 as double) FROM t;
+SELECT '1' = cast(1 as decimal(10, 0)) FROM t;
+SELECT '1' = '1' FROM t;
+SELECT '1' = cast('1' as binary) FROM t;
+SELECT '1' = cast(1 as boolean) FROM t;
+SELECT '1' = cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT '1' = cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(1 as tinyint) = '1' FROM t;
+SELECT cast(1 as smallint) = '1' FROM t;
+SELECT cast(1 as int) = '1' FROM t;
+SELECT cast(1 as bigint) = '1' FROM t;
+SELECT cast(1 as float) = '1' FROM t;
+SELECT cast(1 as double) = '1' FROM t;
+SELECT cast(1 as decimal(10, 0)) = '1' FROM t;
+SELECT cast('1' as binary) = '1' FROM t;
+SELECT cast(1 as boolean) = '1' FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) = '1' FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) = '1' FROM t;
+
+SELECT '1' <=> cast(1 as tinyint) FROM t;
+SELECT '1' <=> cast(1 as smallint) FROM t;
+SELECT '1' <=> cast(1 as int) FROM t;
+SELECT '1' <=> cast(1 as bigint) FROM t;
+SELECT '1' <=> cast(1 as float) FROM t;
+SELECT '1' <=> cast(1 as double) FROM t;
+SELECT '1' <=> cast(1 as decimal(10, 0)) FROM t;
+SELECT '1' <=> '1' FROM t;
+SELECT '1' <=> cast('1' as binary) FROM t;
+SELECT '1' <=> cast(1 as boolean) FROM t;
+SELECT '1' <=> cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT '1' <=> cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(1 as tinyint) <=> '1' FROM t;
+SELECT cast(1 as smallint) <=> '1' FROM t;
+SELECT cast(1 as int) <=> '1' FROM t;
+SELECT cast(1 as bigint) <=> '1' FROM t;
+SELECT cast(1 as float) <=> '1' FROM t;
+SELECT cast(1 as double) <=> '1' FROM t;
+SELECT cast(1 as decimal(10, 0)) <=> '1' FROM t;
+SELECT cast('1' as binary) <=> '1' FROM t;
+SELECT cast(1 as boolean) <=> '1' FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <=> '1' FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) <=> '1' FROM t;
+
+-- Binary comparison
+SELECT '1' < cast(1 as tinyint) FROM t;
+SELECT '1' < cast(1 as smallint) FROM t;
+SELECT '1' < cast(1 as int) FROM t;
+SELECT '1' < cast(1 as bigint) FROM t;
+SELECT '1' < cast(1 as float) FROM t;
+SELECT '1' < cast(1 as double) FROM t;
+SELECT '1' < cast(1 as decimal(10, 0)) FROM t;
+SELECT '1' < '1' FROM t;
+SELECT '1' < cast('1' as binary) FROM t;
+SELECT '1' < cast(1 as boolean) FROM t;
+SELECT '1' < cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT '1' < cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT '1' <= cast(1 as tinyint) FROM t;
+SELECT '1' <= cast(1 as smallint) FROM t;
+SELECT '1' <= cast(1 as int) FROM t;
+SELECT '1' <= cast(1 as bigint) FROM t;
+SELECT '1' <= cast(1 as float) FROM t;
+SELECT '1' <= cast(1 as double) FROM t;
+SELECT '1' <= cast(1 as decimal(10, 0)) FROM t;
+SELECT '1' <= '1' FROM t;
+SELECT '1' <= cast('1' as binary) FROM t;
+SELECT '1' <= cast(1 as boolean) FROM t;
+SELECT '1' <= cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT '1' <= cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT '1' > cast(1 as tinyint) FROM t;
+SELECT '1' > cast(1 as smallint) FROM t;
+SELECT '1' > cast(1 as int) FROM t;
+SELECT '1' > cast(1 as bigint) FROM t;
+SELECT '1' > cast(1 as float) FROM t;
+SELECT '1' > cast(1 as double) FROM t;
+SELECT '1' > cast(1 as decimal(10, 0)) FROM t;
+SELECT '1' > '1' FROM t;
+SELECT '1' > cast('1' as binary) FROM t;
+SELECT '1' > cast(1 as boolean) FROM t;
+SELECT '1' > cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT '1' > cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT '1' >= cast(1 as tinyint) FROM t;
+SELECT '1' >= cast(1 as smallint) FROM t;
+SELECT '1' >= cast(1 as int) FROM t;
+SELECT '1' >= cast(1 as bigint) FROM t;
+SELECT '1' >= cast(1 as float) FROM t;
+SELECT '1' >= cast(1 as double) FROM t;
+SELECT '1' >= cast(1 as decimal(10, 0)) FROM t;
+SELECT '1' >= '1' FROM t;
+SELECT '1' >= cast('1' as binary) FROM t;
+SELECT '1' >= cast(1 as boolean) FROM t;
+SELECT '1' >= cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT '1' >= cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT '1' <> cast(1 as tinyint) FROM t;
+SELECT '1' <> cast(1 as smallint) FROM t;
+SELECT '1' <> cast(1 as int) FROM t;
+SELECT '1' <> cast(1 as bigint) FROM t;
+SELECT '1' <> cast(1 as float) FROM t;
+SELECT '1' <> cast(1 as double) FROM t;
+SELECT '1' <> cast(1 as decimal(10, 0)) FROM t;
+SELECT '1' <> '1' FROM t;
+SELECT '1' <> cast('1' as binary) FROM t;
+SELECT '1' <> cast(1 as boolean) FROM t;
+SELECT '1' <> cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT '1' <> cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(1 as tinyint) < '1' FROM t;
+SELECT cast(1 as smallint) < '1' FROM t;
+SELECT cast(1 as int) < '1' FROM t;
+SELECT cast(1 as bigint) < '1' FROM t;
+SELECT cast(1 as float) < '1' FROM t;
+SELECT cast(1 as double) < '1' FROM t;
+SELECT cast(1 as decimal(10, 0)) < '1' FROM t;
+SELECT '1' < '1' FROM t;
+SELECT cast('1' as binary) < '1' FROM t;
+SELECT cast(1 as boolean) < '1' FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) < '1' FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) < '1' FROM t;
+
+SELECT cast(1 as tinyint) <= '1' FROM t;
+SELECT cast(1 as smallint) <= '1' FROM t;
+SELECT cast(1 as int) <= '1' FROM t;
+SELECT cast(1 as bigint) <= '1' FROM t;
+SELECT cast(1 as float) <= '1' FROM t;
+SELECT cast(1 as double) <= '1' FROM t;
+SELECT cast(1 as decimal(10, 0)) <= '1' FROM t;
+SELECT '1' <= '1' FROM t;
+SELECT cast('1' as binary) <= '1' FROM t;
+SELECT cast(1 as boolean) <= '1' FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <= '1' FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) <= '1' FROM t;
+
+SELECT cast(1 as tinyint) > '1' FROM t;
+SELECT cast(1 as smallint) > '1' FROM t;
+SELECT cast(1 as int) > '1' FROM t;
+SELECT cast(1 as bigint) > '1' FROM t;
+SELECT cast(1 as float) > '1' FROM t;
+SELECT cast(1 as double) > '1' FROM t;
+SELECT cast(1 as decimal(10, 0)) > '1' FROM t;
+SELECT '1' > '1' FROM t;
+SELECT cast('1' as binary) > '1' FROM t;
+SELECT cast(1 as boolean) > '1' FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) > '1' FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) > '1' FROM t;
+
+SELECT cast(1 as tinyint) >= '1' FROM t;
+SELECT cast(1 as smallint) >= '1' FROM t;
+SELECT cast(1 as int) >= '1' FROM t;
+SELECT cast(1 as bigint) >= '1' FROM t;
+SELECT cast(1 as float) >= '1' FROM t;
+SELECT cast(1 as double) >= '1' FROM t;
+SELECT cast(1 as decimal(10, 0)) >= '1' FROM t;
+SELECT '1' >= '1' FROM t;
+SELECT cast('1' as binary) >= '1' FROM t;
+SELECT cast(1 as boolean) >= '1' FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) >= '1' FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) >= '1' FROM t;
+
+SELECT cast(1 as tinyint) <> '1' FROM t;
+SELECT cast(1 as smallint) <> '1' FROM t;
+SELECT cast(1 as int) <> '1' FROM t;
+SELECT cast(1 as bigint) <> '1' FROM t;
+SELECT cast(1 as float) <> '1' FROM t;
+SELECT cast(1 as double) <> '1' FROM t;
+SELECT cast(1 as decimal(10, 0)) <> '1' FROM t;
+SELECT '1' <> '1' FROM t;
+SELECT cast('1' as binary) <> '1' FROM t;
+SELECT cast(1 as boolean) <> '1' FROM t;
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <> '1' FROM t;
+SELECT cast('2017-12-11 09:30:00' as date) <> '1' FROM t;
+
+-- Functions
+SELECT abs('1') FROM t;
+SELECT sum('1') FROM t;
+SELECT avg('1') FROM t;
+SELECT stddev_pop('1') FROM t;
+SELECT stddev_samp('1') FROM t;
+SELECT - '1' FROM t;
+SELECT + '1' FROM t;
+SELECT var_pop('1') FROM t;
+SELECT var_samp('1') FROM t;
+SELECT skewness('1') FROM t;
+SELECT kurtosis('1') FROM t;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/stringCastAndExpressions.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/stringCastAndExpressions.sql
new file mode 100644
index 0000000000000..f17adb56dee91
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/stringCastAndExpressions.sql
@@ -0,0 +1,57 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+CREATE TEMPORARY VIEW t AS SELECT 'aa' as a;
+
+-- casting to data types which are unable to represent the string input returns NULL
+select cast(a as byte) from t;
+select cast(a as short) from t;
+select cast(a as int) from t;
+select cast(a as long) from t;
+select cast(a as float) from t;
+select cast(a as double) from t;
+select cast(a as decimal) from t;
+select cast(a as boolean) from t;
+select cast(a as timestamp) from t;
+select cast(a as date) from t;
+-- casting to binary works correctly
+select cast(a as binary) from t;
+-- casting to array, struct or map throws exception
+select cast(a as array<string>) from t;
+select cast(a as struct<s:string>) from t;
+select cast(a as map<string, string>) from t;
+
+-- all timestamp/date expressions return NULL if bad input strings are provided
+select to_timestamp(a) from t;
+select to_timestamp('2018-01-01', a) from t;
+select to_unix_timestamp(a) from t;
+select to_unix_timestamp('2018-01-01', a) from t;
+select unix_timestamp(a) from t;
+select unix_timestamp('2018-01-01', a) from t;
+select from_unixtime(a) from t;
+select from_unixtime('2018-01-01', a) from t;
+select next_day(a, 'MO') from t;
+select next_day('2018-01-01', a) from t;
+select trunc(a, 'MM') from t;
+select trunc('2018-01-01', a) from t;
+
+-- some functions return NULL if bad input is provided
+select unhex('-123');
+select sha2(a, a) from t;
+select get_json_object(a, a) from t;
+select json_tuple(a, a) from t;
+select from_json(a, 'a INT') from t;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/widenSetOperationTypes.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/widenSetOperationTypes.sql
new file mode 100644
index 0000000000000..66e9689850d93
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/widenSetOperationTypes.sql
@@ -0,0 +1,175 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+CREATE TEMPORARY VIEW t AS SELECT 1;
+
+-- UNION
+SELECT cast(1 as tinyint) FROM t UNION SELECT cast(2 as tinyint) FROM t;
+SELECT cast(1 as tinyint) FROM t UNION SELECT cast(2 as smallint) FROM t;
+SELECT cast(1 as tinyint) FROM t UNION SELECT cast(2 as int) FROM t;
+SELECT cast(1 as tinyint) FROM t UNION SELECT cast(2 as bigint) FROM t;
+SELECT cast(1 as tinyint) FROM t UNION SELECT cast(2 as float) FROM t;
+SELECT cast(1 as tinyint) FROM t UNION SELECT cast(2 as double) FROM t;
+SELECT cast(1 as tinyint) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t;
+SELECT cast(1 as tinyint) FROM t UNION SELECT cast(2 as string) FROM t;
+SELECT cast(1 as tinyint) FROM t UNION SELECT cast('2' as binary) FROM t;
+SELECT cast(1 as tinyint) FROM t UNION SELECT cast(2 as boolean) FROM t;
+SELECT cast(1 as tinyint) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as tinyint) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(1 as smallint) FROM t UNION SELECT cast(2 as tinyint) FROM t;
+SELECT cast(1 as smallint) FROM t UNION SELECT cast(2 as smallint) FROM t;
+SELECT cast(1 as smallint) FROM t UNION SELECT cast(2 as int) FROM t;
+SELECT cast(1 as smallint) FROM t UNION SELECT cast(2 as bigint) FROM t;
+SELECT cast(1 as smallint) FROM t UNION SELECT cast(2 as float) FROM t;
+SELECT cast(1 as smallint) FROM t UNION SELECT cast(2 as double) FROM t;
+SELECT cast(1 as smallint) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t;
+SELECT cast(1 as smallint) FROM t UNION SELECT cast(2 as string) FROM t;
+SELECT cast(1 as smallint) FROM t UNION SELECT cast('2' as binary) FROM t;
+SELECT cast(1 as smallint) FROM t UNION SELECT cast(2 as boolean) FROM t;
+SELECT cast(1 as smallint) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as smallint) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(1 as int) FROM t UNION SELECT cast(2 as tinyint) FROM t;
+SELECT cast(1 as int) FROM t UNION SELECT cast(2 as smallint) FROM t;
+SELECT cast(1 as int) FROM t UNION SELECT cast(2 as int) FROM t;
+SELECT cast(1 as int) FROM t UNION SELECT cast(2 as bigint) FROM t;
+SELECT cast(1 as int) FROM t UNION SELECT cast(2 as float) FROM t;
+SELECT cast(1 as int) FROM t UNION SELECT cast(2 as double) FROM t;
+SELECT cast(1 as int) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t;
+SELECT cast(1 as int) FROM t UNION SELECT cast(2 as string) FROM t;
+SELECT cast(1 as int) FROM t UNION SELECT cast('2' as binary) FROM t;
+SELECT cast(1 as int) FROM t UNION SELECT cast(2 as boolean) FROM t;
+SELECT cast(1 as int) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as int) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(1 as bigint) FROM t UNION SELECT cast(2 as tinyint) FROM t;
+SELECT cast(1 as bigint) FROM t UNION SELECT cast(2 as smallint) FROM t;
+SELECT cast(1 as bigint) FROM t UNION SELECT cast(2 as int) FROM t;
+SELECT cast(1 as bigint) FROM t UNION SELECT cast(2 as bigint) FROM t;
+SELECT cast(1 as bigint) FROM t UNION SELECT cast(2 as float) FROM t;
+SELECT cast(1 as bigint) FROM t UNION SELECT cast(2 as double) FROM t;
+SELECT cast(1 as bigint) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t;
+SELECT cast(1 as bigint) FROM t UNION SELECT cast(2 as string) FROM t;
+SELECT cast(1 as bigint) FROM t UNION SELECT cast('2' as binary) FROM t;
+SELECT cast(1 as bigint) FROM t UNION SELECT cast(2 as boolean) FROM t;
+SELECT cast(1 as bigint) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as bigint) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(1 as float) FROM t UNION SELECT cast(2 as tinyint) FROM t;
+SELECT cast(1 as float) FROM t UNION SELECT cast(2 as smallint) FROM t;
+SELECT cast(1 as float) FROM t UNION SELECT cast(2 as int) FROM t;
+SELECT cast(1 as float) FROM t UNION SELECT cast(2 as bigint) FROM t;
+SELECT cast(1 as float) FROM t UNION SELECT cast(2 as float) FROM t;
+SELECT cast(1 as float) FROM t UNION SELECT cast(2 as double) FROM t;
+SELECT cast(1 as float) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t;
+SELECT cast(1 as float) FROM t UNION SELECT cast(2 as string) FROM t;
+SELECT cast(1 as float) FROM t UNION SELECT cast('2' as binary) FROM t;
+SELECT cast(1 as float) FROM t UNION SELECT cast(2 as boolean) FROM t;
+SELECT cast(1 as float) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as float) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(1 as double) FROM t UNION SELECT cast(2 as tinyint) FROM t;
+SELECT cast(1 as double) FROM t UNION SELECT cast(2 as smallint) FROM t;
+SELECT cast(1 as double) FROM t UNION SELECT cast(2 as int) FROM t;
+SELECT cast(1 as double) FROM t UNION SELECT cast(2 as bigint) FROM t;
+SELECT cast(1 as double) FROM t UNION SELECT cast(2 as float) FROM t;
+SELECT cast(1 as double) FROM t UNION SELECT cast(2 as double) FROM t;
+SELECT cast(1 as double) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t;
+SELECT cast(1 as double) FROM t UNION SELECT cast(2 as string) FROM t;
+SELECT cast(1 as double) FROM t UNION SELECT cast('2' as binary) FROM t;
+SELECT cast(1 as double) FROM t UNION SELECT cast(2 as boolean) FROM t;
+SELECT cast(1 as double) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as double) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast(2 as tinyint) FROM t;
+SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast(2 as smallint) FROM t;
+SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast(2 as int) FROM t;
+SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast(2 as bigint) FROM t;
+SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast(2 as float) FROM t;
+SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast(2 as double) FROM t;
+SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t;
+SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast(2 as string) FROM t;
+SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast('2' as binary) FROM t;
+SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast(2 as boolean) FROM t;
+SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(1 as string) FROM t UNION SELECT cast(2 as tinyint) FROM t;
+SELECT cast(1 as string) FROM t UNION SELECT cast(2 as smallint) FROM t;
+SELECT cast(1 as string) FROM t UNION SELECT cast(2 as int) FROM t;
+SELECT cast(1 as string) FROM t UNION SELECT cast(2 as bigint) FROM t;
+SELECT cast(1 as string) FROM t UNION SELECT cast(2 as float) FROM t;
+SELECT cast(1 as string) FROM t UNION SELECT cast(2 as double) FROM t;
+SELECT cast(1 as string) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t;
+SELECT cast(1 as string) FROM t UNION SELECT cast(2 as string) FROM t;
+SELECT cast(1 as string) FROM t UNION SELECT cast('2' as binary) FROM t;
+SELECT cast(1 as string) FROM t UNION SELECT cast(2 as boolean) FROM t;
+SELECT cast(1 as string) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as string) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as tinyint) FROM t;
+SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as smallint) FROM t;
+SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as int) FROM t;
+SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as bigint) FROM t;
+SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as float) FROM t;
+SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as double) FROM t;
+SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t;
+SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as string) FROM t;
+SELECT cast('1' as binary) FROM t UNION SELECT cast('2' as binary) FROM t;
+SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as boolean) FROM t;
+SELECT cast('1' as binary) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast('1' as binary) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as tinyint) FROM t;
+SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as smallint) FROM t;
+SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as int) FROM t;
+SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as bigint) FROM t;
+SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as float) FROM t;
+SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as double) FROM t;
+SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t;
+SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as string) FROM t;
+SELECT cast(1 as boolean) FROM t UNION SELECT cast('2' as binary) FROM t;
+SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as boolean) FROM t;
+SELECT cast(1 as boolean) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast(1 as boolean) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as tinyint) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as smallint) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as int) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as bigint) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as float) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as double) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as string) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast('2' as binary) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as boolean) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t;
+
+SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as tinyint) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as smallint) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as int) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as bigint) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as float) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as double) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as string) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast('2' as binary) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as boolean) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t;
+SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/windowFrameCoercion.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/windowFrameCoercion.sql
new file mode 100644
index 0000000000000..5cd3538757499
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/windowFrameCoercion.sql
@@ -0,0 +1,44 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+CREATE TEMPORARY VIEW t AS SELECT 1;
+
+SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as tinyint)) FROM t;
+SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as smallint)) FROM t;
+SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as int)) FROM t;
+SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as bigint)) FROM t;
+SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as float)) FROM t;
+SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as double)) FROM t;
+SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as decimal(10, 0))) FROM t;
+SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as string)) FROM t;
+SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast('1' as binary)) FROM t;
+SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as boolean)) FROM t;
+SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast('2017-12-11 09:30:00.0' as timestamp)) FROM t;
+SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast('2017-12-11 09:30:00' as date)) FROM t;
+
+SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as tinyint) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t;
+SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as smallint) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t;
+SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as int) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t;
+SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as bigint) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t;
+SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as float) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t;
+SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as double) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t;
+SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as decimal(10, 0)) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t;
+SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as string) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t;
+SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast('1' as binary) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t;
+SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as boolean) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t;
+SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast('2017-12-11 09:30:00.0' as timestamp) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t;
+SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast('2017-12-11 09:30:00' as date) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t;
diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out
index 7b2f46f6c2a66..bbb6851e69c7e 100644
--- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out
@@ -44,6 +44,7 @@ struct<>
-- !query 4 output
+
-- !query 5
select current_date, current_timestamp from ttf1
-- !query 5 schema
@@ -63,6 +64,7 @@ struct<>
-- !query 6 output
+
-- !query 7
select current_date = current_date(), current_timestamp = current_timestamp(), a, b from ttf2
-- !query 7 schema
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
index 986bb01c13fe4..c1abc6dff754b 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 25
+-- Number of queries: 26
-- !query 0
@@ -227,3 +227,17 @@ SELECT 1 FROM (SELECT COUNT(1) FROM testData WHERE false) t
struct<1:int>
-- !query 24 output
1
+
+
+-- !query 25
+SELECT 1 from (
+ SELECT 1 AS z,
+ MIN(a.x)
+ FROM (select 1 as x) a
+ WHERE false
+) b
+where b.z != b.z
+-- !query 25 schema
+struct<1:int>
+-- !query 25 output
+
diff --git a/sql/core/src/test/resources/sql-tests/results/join-empty-relation.sql.out b/sql/core/src/test/resources/sql-tests/results/join-empty-relation.sql.out
new file mode 100644
index 0000000000000..857073a827f24
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/join-empty-relation.sql.out
@@ -0,0 +1,194 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 24
+
+
+-- !query 0
+CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1) AS GROUPING(a)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1) AS GROUPING(a)
+-- !query 1 schema
+struct<>
+-- !query 1 output
+
+
+
+-- !query 2
+CREATE TEMPORARY VIEW empty_table as SELECT a FROM t2 WHERE false
+-- !query 2 schema
+struct<>
+-- !query 2 output
+
+
+
+-- !query 3
+SELECT * FROM t1 INNER JOIN empty_table
+-- !query 3 schema
+struct<a:int,a:int>
+-- !query 3 output
+
+
+
+-- !query 4
+SELECT * FROM t1 CROSS JOIN empty_table
+-- !query 4 schema
+struct<a:int,a:int>
+-- !query 4 output
+
+
+
+-- !query 5
+SELECT * FROM t1 LEFT OUTER JOIN empty_table
+-- !query 5 schema
+struct<a:int,a:int>
+-- !query 5 output
+1 NULL
+
+
+-- !query 6
+SELECT * FROM t1 RIGHT OUTER JOIN empty_table
+-- !query 6 schema
+struct<a:int,a:int>
+-- !query 6 output
+
+
+
+-- !query 7
+SELECT * FROM t1 FULL OUTER JOIN empty_table
+-- !query 7 schema
+struct<a:int,a:int>
+-- !query 7 output
+1 NULL
+
+
+-- !query 8
+SELECT * FROM t1 LEFT SEMI JOIN empty_table
+-- !query 8 schema
+struct<a:int>
+-- !query 8 output
+
+
+
+-- !query 9
+SELECT * FROM t1 LEFT ANTI JOIN empty_table
+-- !query 9 schema
+struct<a:int>
+-- !query 9 output
+1
+
+
+-- !query 10
+SELECT * FROM empty_table INNER JOIN t1
+-- !query 10 schema
+struct<a:int,a:int>
+-- !query 10 output
+
+
+
+-- !query 11
+SELECT * FROM empty_table CROSS JOIN t1
+-- !query 11 schema
+struct<a:int,a:int>
+-- !query 11 output
+
+
+
+-- !query 12
+SELECT * FROM empty_table LEFT OUTER JOIN t1
+-- !query 12 schema
+struct<a:int,a:int>
+-- !query 12 output
+
+
+
+-- !query 13
+SELECT * FROM empty_table RIGHT OUTER JOIN t1
+-- !query 13 schema
+struct<a:int,a:int>
+-- !query 13 output
+NULL 1
+
+
+-- !query 14
+SELECT * FROM empty_table FULL OUTER JOIN t1
+-- !query 14 schema
+struct<a:int,a:int>
+-- !query 14 output
+NULL 1
+
+
+-- !query 15
+SELECT * FROM empty_table LEFT SEMI JOIN t1
+-- !query 15 schema
+struct<a:int>
+-- !query 15 output
+
+
+
+-- !query 16
+SELECT * FROM empty_table LEFT ANTI JOIN t1
+-- !query 16 schema
+struct<a:int>
+-- !query 16 output
+
+
+
+-- !query 17
+SELECT * FROM empty_table INNER JOIN empty_table
+-- !query 17 schema
+struct<a:int,a:int>
+-- !query 17 output
+
+
+
+-- !query 18
+SELECT * FROM empty_table CROSS JOIN empty_table
+-- !query 18 schema
+struct<a:int,a:int>
+-- !query 18 output
+
+
+
+-- !query 19
+SELECT * FROM empty_table LEFT OUTER JOIN empty_table
+-- !query 19 schema
+struct<a:int,a:int>
+-- !query 19 output
+
+
+
+-- !query 20
+SELECT * FROM empty_table RIGHT OUTER JOIN empty_table
+-- !query 20 schema
+struct<a:int,a:int>
+-- !query 20 output
+
+
+
+-- !query 21
+SELECT * FROM empty_table FULL OUTER JOIN empty_table
+-- !query 21 schema
+struct<a:int,a:int>
+-- !query 21 output
+
+
+
+-- !query 22
+SELECT * FROM empty_table LEFT SEMI JOIN empty_table
+-- !query 22 schema
+struct<a:int>
+-- !query 22 output
+
+
+
+-- !query 23
+SELECT * FROM empty_table LEFT ANTI JOIN empty_table
+-- !query 23 schema
+struct<a:int>
+-- !query 23 output
+
diff --git a/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out
index 8cd0d51da64f5..d51f6d37e4b41 100644
--- a/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 31
+-- Number of queries: 32
-- !query 0
@@ -34,225 +34,225 @@ struct<(CAST(1.5 AS DOUBLE) = CAST(1.51 AS DOUBLE)):boolean>
false
--- !query 3
-select 1 > '1'
--- !query 3 schema
-struct<(1 > CAST(1 AS INT)):boolean>
--- !query 3 output
-false
-
-
-- !query 4
-select 2 > '1.0'
+select 1 > '1'
-- !query 4 schema
-struct<(2 > CAST(1.0 AS INT)):boolean>
+struct<(1 > CAST(1 AS INT)):boolean>
-- !query 4 output
-true
+false
-- !query 5
-select 2 > '2.0'
+select 2 > '1.0'
-- !query 5 schema
-struct<(2 > CAST(2.0 AS INT)):boolean>
+struct<(2 > CAST(1.0 AS INT)):boolean>
-- !query 5 output
-false
+true
-- !query 6
-select 2 > '2.2'
+select 2 > '2.0'
-- !query 6 schema
-struct<(2 > CAST(2.2 AS INT)):boolean>
+struct<(2 > CAST(2.0 AS INT)):boolean>
-- !query 6 output
false
-- !query 7
-select '1.5' > 0.5
+select 2 > '2.2'
-- !query 7 schema
-struct<(CAST(1.5 AS DOUBLE) > CAST(0.5 AS DOUBLE)):boolean>
+struct<(2 > CAST(2.2 AS INT)):boolean>
-- !query 7 output
-true
+false
-- !query 8
-select to_date('2009-07-30 04:17:52') > to_date('2009-07-30 04:17:52')
+select '1.5' > 0.5
-- !query 8 schema
-struct<(to_date('2009-07-30 04:17:52') > to_date('2009-07-30 04:17:52')):boolean>
+struct<(CAST(1.5 AS DOUBLE) > CAST(0.5 AS DOUBLE)):boolean>
-- !query 8 output
-false
+true
-- !query 9
-select to_date('2009-07-30 04:17:52') > '2009-07-30 04:17:52'
+select to_date('2009-07-30 04:17:52') > to_date('2009-07-30 04:17:52')
-- !query 9 schema
-struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) > 2009-07-30 04:17:52):boolean>
+struct<(to_date('2009-07-30 04:17:52') > to_date('2009-07-30 04:17:52')):boolean>
-- !query 9 output
false
-- !query 10
-select 1 >= '1'
+select to_date('2009-07-30 04:17:52') > '2009-07-30 04:17:52'
-- !query 10 schema
-struct<(1 >= CAST(1 AS INT)):boolean>
+struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) > 2009-07-30 04:17:52):boolean>
-- !query 10 output
-true
+false
-- !query 11
-select 2 >= '1.0'
+select 1 >= '1'
-- !query 11 schema
-struct<(2 >= CAST(1.0 AS INT)):boolean>
+struct<(1 >= CAST(1 AS INT)):boolean>
-- !query 11 output
true
-- !query 12
-select 2 >= '2.0'
+select 2 >= '1.0'
-- !query 12 schema
-struct<(2 >= CAST(2.0 AS INT)):boolean>
+struct<(2 >= CAST(1.0 AS INT)):boolean>
-- !query 12 output
true
-- !query 13
-select 2.0 >= '2.2'
+select 2 >= '2.0'
-- !query 13 schema
-struct<(CAST(2.0 AS DOUBLE) >= CAST(2.2 AS DOUBLE)):boolean>
+struct<(2 >= CAST(2.0 AS INT)):boolean>
-- !query 13 output
-false
+true
-- !query 14
-select '1.5' >= 0.5
+select 2.0 >= '2.2'
-- !query 14 schema
-struct<(CAST(1.5 AS DOUBLE) >= CAST(0.5 AS DOUBLE)):boolean>
+struct<(CAST(2.0 AS DOUBLE) >= CAST(2.2 AS DOUBLE)):boolean>
-- !query 14 output
-true
+false
-- !query 15
-select to_date('2009-07-30 04:17:52') >= to_date('2009-07-30 04:17:52')
+select '1.5' >= 0.5
-- !query 15 schema
-struct<(to_date('2009-07-30 04:17:52') >= to_date('2009-07-30 04:17:52')):boolean>
+struct<(CAST(1.5 AS DOUBLE) >= CAST(0.5 AS DOUBLE)):boolean>
-- !query 15 output
true
-- !query 16
-select to_date('2009-07-30 04:17:52') >= '2009-07-30 04:17:52'
+select to_date('2009-07-30 04:17:52') >= to_date('2009-07-30 04:17:52')
-- !query 16 schema
-struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) >= 2009-07-30 04:17:52):boolean>
+struct<(to_date('2009-07-30 04:17:52') >= to_date('2009-07-30 04:17:52')):boolean>
-- !query 16 output
-false
+true
-- !query 17
-select 1 < '1'
+select to_date('2009-07-30 04:17:52') >= '2009-07-30 04:17:52'
-- !query 17 schema
-struct<(1 < CAST(1 AS INT)):boolean>
+struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) >= 2009-07-30 04:17:52):boolean>
-- !query 17 output
false
-- !query 18
-select 2 < '1.0'
+select 1 < '1'
-- !query 18 schema
-struct<(2 < CAST(1.0 AS INT)):boolean>
+struct<(1 < CAST(1 AS INT)):boolean>
-- !query 18 output
false
-- !query 19
-select 2 < '2.0'
+select 2 < '1.0'
-- !query 19 schema
-struct<(2 < CAST(2.0 AS INT)):boolean>
+struct<(2 < CAST(1.0 AS INT)):boolean>
-- !query 19 output
false
-- !query 20
-select 2.0 < '2.2'
+select 2 < '2.0'
-- !query 20 schema
-struct<(CAST(2.0 AS DOUBLE) < CAST(2.2 AS DOUBLE)):boolean>
+struct<(2 < CAST(2.0 AS INT)):boolean>
-- !query 20 output
-true
+false
-- !query 21
-select 0.5 < '1.5'
+select 2.0 < '2.2'
-- !query 21 schema
-struct<(CAST(0.5 AS DOUBLE) < CAST(1.5 AS DOUBLE)):boolean>
+struct<(CAST(2.0 AS DOUBLE) < CAST(2.2 AS DOUBLE)):boolean>
-- !query 21 output
true
-- !query 22
-select to_date('2009-07-30 04:17:52') < to_date('2009-07-30 04:17:52')
+select 0.5 < '1.5'
-- !query 22 schema
-struct<(to_date('2009-07-30 04:17:52') < to_date('2009-07-30 04:17:52')):boolean>
+struct<(CAST(0.5 AS DOUBLE) < CAST(1.5 AS DOUBLE)):boolean>
-- !query 22 output
-false
+true
-- !query 23
-select to_date('2009-07-30 04:17:52') < '2009-07-30 04:17:52'
+select to_date('2009-07-30 04:17:52') < to_date('2009-07-30 04:17:52')
-- !query 23 schema
-struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) < 2009-07-30 04:17:52):boolean>
+struct<(to_date('2009-07-30 04:17:52') < to_date('2009-07-30 04:17:52')):boolean>
-- !query 23 output
-true
+false
-- !query 24
-select 1 <= '1'
+select to_date('2009-07-30 04:17:52') < '2009-07-30 04:17:52'
-- !query 24 schema
-struct<(1 <= CAST(1 AS INT)):boolean>
+struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) < 2009-07-30 04:17:52):boolean>
-- !query 24 output
true
-- !query 25
-select 2 <= '1.0'
+select 1 <= '1'
-- !query 25 schema
-struct<(2 <= CAST(1.0 AS INT)):boolean>
+struct<(1 <= CAST(1 AS INT)):boolean>
-- !query 25 output
-false
+true
-- !query 26
-select 2 <= '2.0'
+select 2 <= '1.0'
-- !query 26 schema
-struct<(2 <= CAST(2.0 AS INT)):boolean>
+struct<(2 <= CAST(1.0 AS INT)):boolean>
-- !query 26 output
-true
+false
-- !query 27
-select 2.0 <= '2.2'
+select 2 <= '2.0'
-- !query 27 schema
-struct<(CAST(2.0 AS DOUBLE) <= CAST(2.2 AS DOUBLE)):boolean>
+struct<(2 <= CAST(2.0 AS INT)):boolean>
-- !query 27 output
true
-- !query 28
-select 0.5 <= '1.5'
+select 2.0 <= '2.2'
-- !query 28 schema
-struct<(CAST(0.5 AS DOUBLE) <= CAST(1.5 AS DOUBLE)):boolean>
+struct<(CAST(2.0 AS DOUBLE) <= CAST(2.2 AS DOUBLE)):boolean>
-- !query 28 output
true
-- !query 29
-select to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52')
+select 0.5 <= '1.5'
-- !query 29 schema
-struct<(to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52')):boolean>
+struct<(CAST(0.5 AS DOUBLE) <= CAST(1.5 AS DOUBLE)):boolean>
-- !query 29 output
true
-- !query 30
-select to_date('2009-07-30 04:17:52') <= '2009-07-30 04:17:52'
+select to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52')
-- !query 30 schema
-struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) <= 2009-07-30 04:17:52):boolean>
+struct<(to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52')):boolean>
-- !query 30 output
true
+
+
+-- !query 31
+select to_date('2009-07-30 04:17:52') <= '2009-07-30 04:17:52'
+-- !query 31 schema
+struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) <= 2009-07-30 04:17:52):boolean>
+-- !query 31 output
+true
diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
index 2d9b3d7d2ca33..d5f8705a35ed6 100644
--- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 12
+-- Number of queries: 15
-- !query 0
@@ -118,3 +118,46 @@ select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a')
struct
-- !query 11 output
NULL NULL
+
+
+-- !query 12
+set spark.sql.function.concatBinaryAsString=false
+-- !query 12 schema
+struct
+-- !query 12 output
+spark.sql.function.concatBinaryAsString false
+
+
+-- !query 13
+EXPLAIN SELECT ((col1 || col2) || (col3 || col4)) col
+FROM (
+ SELECT
+ string(id) col1,
+ string(id + 1) col2,
+ encode(string(id + 2), 'utf-8') col3,
+ encode(string(id + 3), 'utf-8') col4
+ FROM range(10)
+)
+-- !query 13 schema
+struct
+-- !query 13 output
+== Physical Plan ==
+*Project [concat(cast(id#xL as string), cast((id#xL + 1) as string), cast(encode(cast((id#xL + 2) as string), utf-8) as string), cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x]
++- *Range (0, 10, step=1, splits=2)
+
+
+-- !query 14
+EXPLAIN SELECT (col1 || (col3 || col4)) col
+FROM (
+ SELECT
+ string(id) col1,
+ encode(string(id + 2), 'utf-8') col3,
+ encode(string(id + 3), 'utf-8') col4
+ FROM range(10)
+)
+-- !query 14 schema
+struct
+-- !query 14 output
+== Physical Plan ==
+*Project [concat(cast(id#xL as string), cast(encode(cast((id#xL + 2) as string), utf-8) as string), cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x]
++- *Range (0, 10, step=1, splits=2)
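(Illustrative sketch, not part of the patch: queries 12-14 above record that, once spark.sql.function.concatBinaryAsString is set to false, nested || expressions over a mix of string and binary columns are still flattened into a single concat over strings, because the binary children are cast to string. The sketch below replays query 13 against a local SparkSession; the object and app names are ours, the SQL is taken verbatim from the golden file.)

import org.apache.spark.sql.SparkSession

object ConcatExplainCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("concat-explain").getOrCreate()
    // Same setting as query 12 above.
    spark.sql("SET spark.sql.function.concatBinaryAsString=false")
    // Query 13 above, verbatim.
    val df = spark.sql(
      """SELECT ((col1 || col2) || (col3 || col4)) col
        |FROM (
        |  SELECT
        |    string(id) col1,
        |    string(id + 1) col2,
        |    encode(string(id + 2), 'utf-8') col3,
        |    encode(string(id + 3), 'utf-8') col4
        |  FROM range(10)
        |)""".stripMargin)
    // The physical plan should show one flattened concat(...) over string casts, as in the golden output.
    df.explain()
    spark.stop()
  }
}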
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/binaryComparison.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/binaryComparison.sql.out
new file mode 100644
index 0000000000000..2914d6015ea88
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/binaryComparison.sql.out
@@ -0,0 +1,2146 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 265
+
+
+-- !query 0
+CREATE TEMPORARY VIEW t AS SELECT 1
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+SELECT cast(1 as binary) = '1' FROM t
+-- !query 1 schema
+struct<>
+-- !query 1 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7
+
+
+-- !query 2
+SELECT cast(1 as binary) > '2' FROM t
+-- !query 2 schema
+struct<>
+-- !query 2 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7
+
+
+-- !query 3
+SELECT cast(1 as binary) >= '2' FROM t
+-- !query 3 schema
+struct<>
+-- !query 3 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7
+
+
+-- !query 4
+SELECT cast(1 as binary) < '2' FROM t
+-- !query 4 schema
+struct<>
+-- !query 4 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7
+
+
+-- !query 5
+SELECT cast(1 as binary) <= '2' FROM t
+-- !query 5 schema
+struct<>
+-- !query 5 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7
+
+
+-- !query 6
+SELECT cast(1 as binary) <> '2' FROM t
+-- !query 6 schema
+struct<>
+-- !query 6 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7
+
+
+-- !query 7
+SELECT cast(1 as binary) = cast(null as string) FROM t
+-- !query 7 schema
+struct<>
+-- !query 7 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7
+
+
+-- !query 8
+SELECT cast(1 as binary) > cast(null as string) FROM t
+-- !query 8 schema
+struct<>
+-- !query 8 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7
+
+
+-- !query 9
+SELECT cast(1 as binary) >= cast(null as string) FROM t
+-- !query 9 schema
+struct<>
+-- !query 9 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7
+
+
+-- !query 10
+SELECT cast(1 as binary) < cast(null as string) FROM t
+-- !query 10 schema
+struct<>
+-- !query 10 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7
+
+
+-- !query 11
+SELECT cast(1 as binary) <= cast(null as string) FROM t
+-- !query 11 schema
+struct<>
+-- !query 11 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7
+
+
+-- !query 12
+SELECT cast(1 as binary) <> cast(null as string) FROM t
+-- !query 12 schema
+struct<>
+-- !query 12 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7
+
+
+-- !query 13
+SELECT '1' = cast(1 as binary) FROM t
+-- !query 13 schema
+struct<>
+-- !query 13 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 13
+
+
+-- !query 14
+SELECT '2' > cast(1 as binary) FROM t
+-- !query 14 schema
+struct<>
+-- !query 14 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 13
+
+
+-- !query 15
+SELECT '2' >= cast(1 as binary) FROM t
+-- !query 15 schema
+struct<>
+-- !query 15 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 14
+
+
+-- !query 16
+SELECT '2' < cast(1 as binary) FROM t
+-- !query 16 schema
+struct<>
+-- !query 16 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 13
+
+
+-- !query 17
+SELECT '2' <= cast(1 as binary) FROM t
+-- !query 17 schema
+struct<>
+-- !query 17 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 14
+
+
+-- !query 18
+SELECT '2' <> cast(1 as binary) FROM t
+-- !query 18 schema
+struct<>
+-- !query 18 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 14
+
+
+-- !query 19
+SELECT cast(null as string) = cast(1 as binary) FROM t
+-- !query 19 schema
+struct<>
+-- !query 19 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 30
+
+
+-- !query 20
+SELECT cast(null as string) > cast(1 as binary) FROM t
+-- !query 20 schema
+struct<>
+-- !query 20 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 30
+
+
+-- !query 21
+SELECT cast(null as string) >= cast(1 as binary) FROM t
+-- !query 21 schema
+struct<>
+-- !query 21 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 31
+
+
+-- !query 22
+SELECT cast(null as string) < cast(1 as binary) FROM t
+-- !query 22 schema
+struct<>
+-- !query 22 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 30
+
+
+-- !query 23
+SELECT cast(null as string) <= cast(1 as binary) FROM t
+-- !query 23 schema
+struct<>
+-- !query 23 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 31
+
+
+-- !query 24
+SELECT cast(null as string) <> cast(1 as binary) FROM t
+-- !query 24 schema
+struct<>
+-- !query 24 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 31
+
+
+-- !query 25
+SELECT cast(1 as tinyint) = '1' FROM t
+-- !query 25 schema
+struct<(CAST(1 AS TINYINT) = CAST(1 AS TINYINT)):boolean>
+-- !query 25 output
+true
+
+
+-- !query 26
+SELECT cast(1 as tinyint) > '2' FROM t
+-- !query 26 schema
+struct<(CAST(1 AS TINYINT) > CAST(2 AS TINYINT)):boolean>
+-- !query 26 output
+false
+
+
+-- !query 27
+SELECT cast(1 as tinyint) >= '2' FROM t
+-- !query 27 schema
+struct<(CAST(1 AS TINYINT) >= CAST(2 AS TINYINT)):boolean>
+-- !query 27 output
+false
+
+
+-- !query 28
+SELECT cast(1 as tinyint) < '2' FROM t
+-- !query 28 schema
+struct<(CAST(1 AS TINYINT) < CAST(2 AS TINYINT)):boolean>
+-- !query 28 output
+true
+
+
+-- !query 29
+SELECT cast(1 as tinyint) <= '2' FROM t
+-- !query 29 schema
+struct<(CAST(1 AS TINYINT) <= CAST(2 AS TINYINT)):boolean>
+-- !query 29 output
+true
+
+
+-- !query 30
+SELECT cast(1 as tinyint) <> '2' FROM t
+-- !query 30 schema
+struct<(NOT (CAST(1 AS TINYINT) = CAST(2 AS TINYINT))):boolean>
+-- !query 30 output
+true
+
+
+-- !query 31
+SELECT cast(1 as tinyint) = cast(null as string) FROM t
+-- !query 31 schema
+struct<(CAST(1 AS TINYINT) = CAST(CAST(NULL AS STRING) AS TINYINT)):boolean>
+-- !query 31 output
+NULL
+
+
+-- !query 32
+SELECT cast(1 as tinyint) > cast(null as string) FROM t
+-- !query 32 schema
+struct<(CAST(1 AS TINYINT) > CAST(CAST(NULL AS STRING) AS TINYINT)):boolean>
+-- !query 32 output
+NULL
+
+
+-- !query 33
+SELECT cast(1 as tinyint) >= cast(null as string) FROM t
+-- !query 33 schema
+struct<(CAST(1 AS TINYINT) >= CAST(CAST(NULL AS STRING) AS TINYINT)):boolean>
+-- !query 33 output
+NULL
+
+
+-- !query 34
+SELECT cast(1 as tinyint) < cast(null as string) FROM t
+-- !query 34 schema
+struct<(CAST(1 AS TINYINT) < CAST(CAST(NULL AS STRING) AS TINYINT)):boolean>
+-- !query 34 output
+NULL
+
+
+-- !query 35
+SELECT cast(1 as tinyint) <= cast(null as string) FROM t
+-- !query 35 schema
+struct<(CAST(1 AS TINYINT) <= CAST(CAST(NULL AS STRING) AS TINYINT)):boolean>
+-- !query 35 output
+NULL
+
+
+-- !query 36
+SELECT cast(1 as tinyint) <> cast(null as string) FROM t
+-- !query 36 schema
+struct<(NOT (CAST(1 AS TINYINT) = CAST(CAST(NULL AS STRING) AS TINYINT))):boolean>
+-- !query 36 output
+NULL
+
+
+-- !query 37
+SELECT '1' = cast(1 as tinyint) FROM t
+-- !query 37 schema
+struct<(CAST(1 AS TINYINT) = CAST(1 AS TINYINT)):boolean>
+-- !query 37 output
+true
+
+
+-- !query 38
+SELECT '2' > cast(1 as tinyint) FROM t
+-- !query 38 schema
+struct<(CAST(2 AS TINYINT) > CAST(1 AS TINYINT)):boolean>
+-- !query 38 output
+true
+
+
+-- !query 39
+SELECT '2' >= cast(1 as tinyint) FROM t
+-- !query 39 schema
+struct<(CAST(2 AS TINYINT) >= CAST(1 AS TINYINT)):boolean>
+-- !query 39 output
+true
+
+
+-- !query 40
+SELECT '2' < cast(1 as tinyint) FROM t
+-- !query 40 schema
+struct<(CAST(2 AS TINYINT) < CAST(1 AS TINYINT)):boolean>
+-- !query 40 output
+false
+
+
+-- !query 41
+SELECT '2' <= cast(1 as tinyint) FROM t
+-- !query 41 schema
+struct<(CAST(2 AS TINYINT) <= CAST(1 AS TINYINT)):boolean>
+-- !query 41 output
+false
+
+
+-- !query 42
+SELECT '2' <> cast(1 as tinyint) FROM t
+-- !query 42 schema
+struct<(NOT (CAST(2 AS TINYINT) = CAST(1 AS TINYINT))):boolean>
+-- !query 42 output
+true
+
+
+-- !query 43
+SELECT cast(null as string) = cast(1 as tinyint) FROM t
+-- !query 43 schema
+struct<(CAST(CAST(NULL AS STRING) AS TINYINT) = CAST(1 AS TINYINT)):boolean>
+-- !query 43 output
+NULL
+
+
+-- !query 44
+SELECT cast(null as string) > cast(1 as tinyint) FROM t
+-- !query 44 schema
+struct<(CAST(CAST(NULL AS STRING) AS TINYINT) > CAST(1 AS TINYINT)):boolean>
+-- !query 44 output
+NULL
+
+
+-- !query 45
+SELECT cast(null as string) >= cast(1 as tinyint) FROM t
+-- !query 45 schema
+struct<(CAST(CAST(NULL AS STRING) AS TINYINT) >= CAST(1 AS TINYINT)):boolean>
+-- !query 45 output
+NULL
+
+
+-- !query 46
+SELECT cast(null as string) < cast(1 as tinyint) FROM t
+-- !query 46 schema
+struct<(CAST(CAST(NULL AS STRING) AS TINYINT) < CAST(1 AS TINYINT)):boolean>
+-- !query 46 output
+NULL
+
+
+-- !query 47
+SELECT cast(null as string) <= cast(1 as tinyint) FROM t
+-- !query 47 schema
+struct<(CAST(CAST(NULL AS STRING) AS TINYINT) <= CAST(1 AS TINYINT)):boolean>
+-- !query 47 output
+NULL
+
+
+-- !query 48
+SELECT cast(null as string) <> cast(1 as tinyint) FROM t
+-- !query 48 schema
+struct<(NOT (CAST(CAST(NULL AS STRING) AS TINYINT) = CAST(1 AS TINYINT))):boolean>
+-- !query 48 output
+NULL
+
+
+-- !query 49
+SELECT cast(1 as smallint) = '1' FROM t
+-- !query 49 schema
+struct<(CAST(1 AS SMALLINT) = CAST(1 AS SMALLINT)):boolean>
+-- !query 49 output
+true
+
+
+-- !query 50
+SELECT cast(1 as smallint) > '2' FROM t
+-- !query 50 schema
+struct<(CAST(1 AS SMALLINT) > CAST(2 AS SMALLINT)):boolean>
+-- !query 50 output
+false
+
+
+-- !query 51
+SELECT cast(1 as smallint) >= '2' FROM t
+-- !query 51 schema
+struct<(CAST(1 AS SMALLINT) >= CAST(2 AS SMALLINT)):boolean>
+-- !query 51 output
+false
+
+
+-- !query 52
+SELECT cast(1 as smallint) < '2' FROM t
+-- !query 52 schema
+struct<(CAST(1 AS SMALLINT) < CAST(2 AS SMALLINT)):boolean>
+-- !query 52 output
+true
+
+
+-- !query 53
+SELECT cast(1 as smallint) <= '2' FROM t
+-- !query 53 schema
+struct<(CAST(1 AS SMALLINT) <= CAST(2 AS SMALLINT)):boolean>
+-- !query 53 output
+true
+
+
+-- !query 54
+SELECT cast(1 as smallint) <> '2' FROM t
+-- !query 54 schema
+struct<(NOT (CAST(1 AS SMALLINT) = CAST(2 AS SMALLINT))):boolean>
+-- !query 54 output
+true
+
+
+-- !query 55
+SELECT cast(1 as smallint) = cast(null as string) FROM t
+-- !query 55 schema
+struct<(CAST(1 AS SMALLINT) = CAST(CAST(NULL AS STRING) AS SMALLINT)):boolean>
+-- !query 55 output
+NULL
+
+
+-- !query 56
+SELECT cast(1 as smallint) > cast(null as string) FROM t
+-- !query 56 schema
+struct<(CAST(1 AS SMALLINT) > CAST(CAST(NULL AS STRING) AS SMALLINT)):boolean>
+-- !query 56 output
+NULL
+
+
+-- !query 57
+SELECT cast(1 as smallint) >= cast(null as string) FROM t
+-- !query 57 schema
+struct<(CAST(1 AS SMALLINT) >= CAST(CAST(NULL AS STRING) AS SMALLINT)):boolean>
+-- !query 57 output
+NULL
+
+
+-- !query 58
+SELECT cast(1 as smallint) < cast(null as string) FROM t
+-- !query 58 schema
+struct<(CAST(1 AS SMALLINT) < CAST(CAST(NULL AS STRING) AS SMALLINT)):boolean>
+-- !query 58 output
+NULL
+
+
+-- !query 59
+SELECT cast(1 as smallint) <= cast(null as string) FROM t
+-- !query 59 schema
+struct<(CAST(1 AS SMALLINT) <= CAST(CAST(NULL AS STRING) AS SMALLINT)):boolean>
+-- !query 59 output
+NULL
+
+
+-- !query 60
+SELECT cast(1 as smallint) <> cast(null as string) FROM t
+-- !query 60 schema
+struct<(NOT (CAST(1 AS SMALLINT) = CAST(CAST(NULL AS STRING) AS SMALLINT))):boolean>
+-- !query 60 output
+NULL
+
+
+-- !query 61
+SELECT '1' = cast(1 as smallint) FROM t
+-- !query 61 schema
+struct<(CAST(1 AS SMALLINT) = CAST(1 AS SMALLINT)):boolean>
+-- !query 61 output
+true
+
+
+-- !query 62
+SELECT '2' > cast(1 as smallint) FROM t
+-- !query 62 schema
+struct<(CAST(2 AS SMALLINT) > CAST(1 AS SMALLINT)):boolean>
+-- !query 62 output
+true
+
+
+-- !query 63
+SELECT '2' >= cast(1 as smallint) FROM t
+-- !query 63 schema
+struct<(CAST(2 AS SMALLINT) >= CAST(1 AS SMALLINT)):boolean>
+-- !query 63 output
+true
+
+
+-- !query 64
+SELECT '2' < cast(1 as smallint) FROM t
+-- !query 64 schema
+struct<(CAST(2 AS SMALLINT) < CAST(1 AS SMALLINT)):boolean>
+-- !query 64 output
+false
+
+
+-- !query 65
+SELECT '2' <= cast(1 as smallint) FROM t
+-- !query 65 schema
+struct<(CAST(2 AS SMALLINT) <= CAST(1 AS SMALLINT)):boolean>
+-- !query 65 output
+false
+
+
+-- !query 66
+SELECT '2' <> cast(1 as smallint) FROM t
+-- !query 66 schema
+struct<(NOT (CAST(2 AS SMALLINT) = CAST(1 AS SMALLINT))):boolean>
+-- !query 66 output
+true
+
+
+-- !query 67
+SELECT cast(null as string) = cast(1 as smallint) FROM t
+-- !query 67 schema
+struct<(CAST(CAST(NULL AS STRING) AS SMALLINT) = CAST(1 AS SMALLINT)):boolean>
+-- !query 67 output
+NULL
+
+
+-- !query 68
+SELECT cast(null as string) > cast(1 as smallint) FROM t
+-- !query 68 schema
+struct<(CAST(CAST(NULL AS STRING) AS SMALLINT) > CAST(1 AS SMALLINT)):boolean>
+-- !query 68 output
+NULL
+
+
+-- !query 69
+SELECT cast(null as string) >= cast(1 as smallint) FROM t
+-- !query 69 schema
+struct<(CAST(CAST(NULL AS STRING) AS SMALLINT) >= CAST(1 AS SMALLINT)):boolean>
+-- !query 69 output
+NULL
+
+
+-- !query 70
+SELECT cast(null as string) < cast(1 as smallint) FROM t
+-- !query 70 schema
+struct<(CAST(CAST(NULL AS STRING) AS SMALLINT) < CAST(1 AS SMALLINT)):boolean>
+-- !query 70 output
+NULL
+
+
+-- !query 71
+SELECT cast(null as string) <= cast(1 as smallint) FROM t
+-- !query 71 schema
+struct<(CAST(CAST(NULL AS STRING) AS SMALLINT) <= CAST(1 AS SMALLINT)):boolean>
+-- !query 71 output
+NULL
+
+
+-- !query 72
+SELECT cast(null as string) <> cast(1 as smallint) FROM t
+-- !query 72 schema
+struct<(NOT (CAST(CAST(NULL AS STRING) AS SMALLINT) = CAST(1 AS SMALLINT))):boolean>
+-- !query 72 output
+NULL
+
+
+-- !query 73
+SELECT cast(1 as int) = '1' FROM t
+-- !query 73 schema
+struct<(CAST(1 AS INT) = CAST(1 AS INT)):boolean>
+-- !query 73 output
+true
+
+
+-- !query 74
+SELECT cast(1 as int) > '2' FROM t
+-- !query 74 schema
+struct<(CAST(1 AS INT) > CAST(2 AS INT)):boolean>
+-- !query 74 output
+false
+
+
+-- !query 75
+SELECT cast(1 as int) >= '2' FROM t
+-- !query 75 schema
+struct<(CAST(1 AS INT) >= CAST(2 AS INT)):boolean>
+-- !query 75 output
+false
+
+
+-- !query 76
+SELECT cast(1 as int) < '2' FROM t
+-- !query 76 schema
+struct<(CAST(1 AS INT) < CAST(2 AS INT)):boolean>
+-- !query 76 output
+true
+
+
+-- !query 77
+SELECT cast(1 as int) <= '2' FROM t
+-- !query 77 schema
+struct<(CAST(1 AS INT) <= CAST(2 AS INT)):boolean>
+-- !query 77 output
+true
+
+
+-- !query 78
+SELECT cast(1 as int) <> '2' FROM t
+-- !query 78 schema
+struct<(NOT (CAST(1 AS INT) = CAST(2 AS INT))):boolean>
+-- !query 78 output
+true
+
+
+-- !query 79
+SELECT cast(1 as int) = cast(null as string) FROM t
+-- !query 79 schema
+struct<(CAST(1 AS INT) = CAST(CAST(NULL AS STRING) AS INT)):boolean>
+-- !query 79 output
+NULL
+
+
+-- !query 80
+SELECT cast(1 as int) > cast(null as string) FROM t
+-- !query 80 schema
+struct<(CAST(1 AS INT) > CAST(CAST(NULL AS STRING) AS INT)):boolean>
+-- !query 80 output
+NULL
+
+
+-- !query 81
+SELECT cast(1 as int) >= cast(null as string) FROM t
+-- !query 81 schema
+struct<(CAST(1 AS INT) >= CAST(CAST(NULL AS STRING) AS INT)):boolean>
+-- !query 81 output
+NULL
+
+
+-- !query 82
+SELECT cast(1 as int) < cast(null as string) FROM t
+-- !query 82 schema
+struct<(CAST(1 AS INT) < CAST(CAST(NULL AS STRING) AS INT)):boolean>
+-- !query 82 output
+NULL
+
+
+-- !query 83
+SELECT cast(1 as int) <= cast(null as string) FROM t
+-- !query 83 schema
+struct<(CAST(1 AS INT) <= CAST(CAST(NULL AS STRING) AS INT)):boolean>
+-- !query 83 output
+NULL
+
+
+-- !query 84
+SELECT cast(1 as int) <> cast(null as string) FROM t
+-- !query 84 schema
+struct<(NOT (CAST(1 AS INT) = CAST(CAST(NULL AS STRING) AS INT))):boolean>
+-- !query 84 output
+NULL
+
+
+-- !query 85
+SELECT '1' = cast(1 as int) FROM t
+-- !query 85 schema
+struct<(CAST(1 AS INT) = CAST(1 AS INT)):boolean>
+-- !query 85 output
+true
+
+
+-- !query 86
+SELECT '2' > cast(1 as int) FROM t
+-- !query 86 schema
+struct<(CAST(2 AS INT) > CAST(1 AS INT)):boolean>
+-- !query 86 output
+true
+
+
+-- !query 87
+SELECT '2' >= cast(1 as int) FROM t
+-- !query 87 schema
+struct<(CAST(2 AS INT) >= CAST(1 AS INT)):boolean>
+-- !query 87 output
+true
+
+
+-- !query 88
+SELECT '2' < cast(1 as int) FROM t
+-- !query 88 schema
+struct<(CAST(2 AS INT) < CAST(1 AS INT)):boolean>
+-- !query 88 output
+false
+
+
+-- !query 89
+SELECT '2' <> cast(1 as int) FROM t
+-- !query 89 schema
+struct<(NOT (CAST(2 AS INT) = CAST(1 AS INT))):boolean>
+-- !query 89 output
+true
+
+
+-- !query 90
+SELECT '2' <= cast(1 as int) FROM t
+-- !query 90 schema
+struct<(CAST(2 AS INT) <= CAST(1 AS INT)):boolean>
+-- !query 90 output
+false
+
+
+-- !query 91
+SELECT cast(null as string) = cast(1 as int) FROM t
+-- !query 91 schema
+struct<(CAST(CAST(NULL AS STRING) AS INT) = CAST(1 AS INT)):boolean>
+-- !query 91 output
+NULL
+
+
+-- !query 92
+SELECT cast(null as string) > cast(1 as int) FROM t
+-- !query 92 schema
+struct<(CAST(CAST(NULL AS STRING) AS INT) > CAST(1 AS INT)):boolean>
+-- !query 92 output
+NULL
+
+
+-- !query 93
+SELECT cast(null as string) >= cast(1 as int) FROM t
+-- !query 93 schema
+struct<(CAST(CAST(NULL AS STRING) AS INT) >= CAST(1 AS INT)):boolean>
+-- !query 93 output
+NULL
+
+
+-- !query 94
+SELECT cast(null as string) < cast(1 as int) FROM t
+-- !query 94 schema
+struct<(CAST(CAST(NULL AS STRING) AS INT) < CAST(1 AS INT)):boolean>
+-- !query 94 output
+NULL
+
+
+-- !query 95
+SELECT cast(null as string) <> cast(1 as int) FROM t
+-- !query 95 schema
+struct<(NOT (CAST(CAST(NULL AS STRING) AS INT) = CAST(1 AS INT))):boolean>
+-- !query 95 output
+NULL
+
+
+-- !query 96
+SELECT cast(null as string) <= cast(1 as int) FROM t
+-- !query 96 schema
+struct<(CAST(CAST(NULL AS STRING) AS INT) <= CAST(1 AS INT)):boolean>
+-- !query 96 output
+NULL
+
+
+-- !query 97
+SELECT cast(1 as bigint) = '1' FROM t
+-- !query 97 schema
+struct<(CAST(1 AS BIGINT) = CAST(1 AS BIGINT)):boolean>
+-- !query 97 output
+true
+
+
+-- !query 98
+SELECT cast(1 as bigint) > '2' FROM t
+-- !query 98 schema
+struct<(CAST(1 AS BIGINT) > CAST(2 AS BIGINT)):boolean>
+-- !query 98 output
+false
+
+
+-- !query 99
+SELECT cast(1 as bigint) >= '2' FROM t
+-- !query 99 schema
+struct<(CAST(1 AS BIGINT) >= CAST(2 AS BIGINT)):boolean>
+-- !query 99 output
+false
+
+
+-- !query 100
+SELECT cast(1 as bigint) < '2' FROM t
+-- !query 100 schema
+struct<(CAST(1 AS BIGINT) < CAST(2 AS BIGINT)):boolean>
+-- !query 100 output
+true
+
+
+-- !query 101
+SELECT cast(1 as bigint) <= '2' FROM t
+-- !query 101 schema
+struct<(CAST(1 AS BIGINT) <= CAST(2 AS BIGINT)):boolean>
+-- !query 101 output
+true
+
+
+-- !query 102
+SELECT cast(1 as bigint) <> '2' FROM t
+-- !query 102 schema
+struct<(NOT (CAST(1 AS BIGINT) = CAST(2 AS BIGINT))):boolean>
+-- !query 102 output
+true
+
+
+-- !query 103
+SELECT cast(1 as bigint) = cast(null as string) FROM t
+-- !query 103 schema
+struct<(CAST(1 AS BIGINT) = CAST(CAST(NULL AS STRING) AS BIGINT)):boolean>
+-- !query 103 output
+NULL
+
+
+-- !query 104
+SELECT cast(1 as bigint) > cast(null as string) FROM t
+-- !query 104 schema
+struct<(CAST(1 AS BIGINT) > CAST(CAST(NULL AS STRING) AS BIGINT)):boolean>
+-- !query 104 output
+NULL
+
+
+-- !query 105
+SELECT cast(1 as bigint) >= cast(null as string) FROM t
+-- !query 105 schema
+struct<(CAST(1 AS BIGINT) >= CAST(CAST(NULL AS STRING) AS BIGINT)):boolean>
+-- !query 105 output
+NULL
+
+
+-- !query 106
+SELECT cast(1 as bigint) < cast(null as string) FROM t
+-- !query 106 schema
+struct<(CAST(1 AS BIGINT) < CAST(CAST(NULL AS STRING) AS BIGINT)):boolean>
+-- !query 106 output
+NULL
+
+
+-- !query 107
+SELECT cast(1 as bigint) <= cast(null as string) FROM t
+-- !query 107 schema
+struct<(CAST(1 AS BIGINT) <= CAST(CAST(NULL AS STRING) AS BIGINT)):boolean>
+-- !query 107 output
+NULL
+
+
+-- !query 108
+SELECT cast(1 as bigint) <> cast(null as string) FROM t
+-- !query 108 schema
+struct<(NOT (CAST(1 AS BIGINT) = CAST(CAST(NULL AS STRING) AS BIGINT))):boolean>
+-- !query 108 output
+NULL
+
+
+-- !query 109
+SELECT '1' = cast(1 as bigint) FROM t
+-- !query 109 schema
+struct<(CAST(1 AS BIGINT) = CAST(1 AS BIGINT)):boolean>
+-- !query 109 output
+true
+
+
+-- !query 110
+SELECT '2' > cast(1 as bigint) FROM t
+-- !query 110 schema
+struct<(CAST(2 AS BIGINT) > CAST(1 AS BIGINT)):boolean>
+-- !query 110 output
+true
+
+
+-- !query 111
+SELECT '2' >= cast(1 as bigint) FROM t
+-- !query 111 schema
+struct<(CAST(2 AS BIGINT) >= CAST(1 AS BIGINT)):boolean>
+-- !query 111 output
+true
+
+
+-- !query 112
+SELECT '2' < cast(1 as bigint) FROM t
+-- !query 112 schema
+struct<(CAST(2 AS BIGINT) < CAST(1 AS BIGINT)):boolean>
+-- !query 112 output
+false
+
+
+-- !query 113
+SELECT '2' <= cast(1 as bigint) FROM t
+-- !query 113 schema
+struct<(CAST(2 AS BIGINT) <= CAST(1 AS BIGINT)):boolean>
+-- !query 113 output
+false
+
+
+-- !query 114
+SELECT '2' <> cast(1 as bigint) FROM t
+-- !query 114 schema
+struct<(NOT (CAST(2 AS BIGINT) = CAST(1 AS BIGINT))):boolean>
+-- !query 114 output
+true
+
+
+-- !query 115
+SELECT cast(null as string) = cast(1 as bigint) FROM t
+-- !query 115 schema
+struct<(CAST(CAST(NULL AS STRING) AS BIGINT) = CAST(1 AS BIGINT)):boolean>
+-- !query 115 output
+NULL
+
+
+-- !query 116
+SELECT cast(null as string) > cast(1 as bigint) FROM t
+-- !query 116 schema
+struct<(CAST(CAST(NULL AS STRING) AS BIGINT) > CAST(1 AS BIGINT)):boolean>
+-- !query 116 output
+NULL
+
+
+-- !query 117
+SELECT cast(null as string) >= cast(1 as bigint) FROM t
+-- !query 117 schema
+struct<(CAST(CAST(NULL AS STRING) AS BIGINT) >= CAST(1 AS BIGINT)):boolean>
+-- !query 117 output
+NULL
+
+
+-- !query 118
+SELECT cast(null as string) < cast(1 as bigint) FROM t
+-- !query 118 schema
+struct<(CAST(CAST(NULL AS STRING) AS BIGINT) < CAST(1 AS BIGINT)):boolean>
+-- !query 118 output
+NULL
+
+
+-- !query 119
+SELECT cast(null as string) <= cast(1 as bigint) FROM t
+-- !query 119 schema
+struct<(CAST(CAST(NULL AS STRING) AS BIGINT) <= CAST(1 AS BIGINT)):boolean>
+-- !query 119 output
+NULL
+
+
+-- !query 120
+SELECT cast(null as string) <> cast(1 as bigint) FROM t
+-- !query 120 schema
+struct<(NOT (CAST(CAST(NULL AS STRING) AS BIGINT) = CAST(1 AS BIGINT))):boolean>
+-- !query 120 output
+NULL
+
+
+-- !query 121
+SELECT cast(1 as decimal(10, 0)) = '1' FROM t
+-- !query 121 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) = CAST(1 AS DOUBLE)):boolean>
+-- !query 121 output
+true
+
+
+-- !query 122
+SELECT cast(1 as decimal(10, 0)) > '2' FROM t
+-- !query 122 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) > CAST(2 AS DOUBLE)):boolean>
+-- !query 122 output
+false
+
+
+-- !query 123
+SELECT cast(1 as decimal(10, 0)) >= '2' FROM t
+-- !query 123 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) >= CAST(2 AS DOUBLE)):boolean>
+-- !query 123 output
+false
+
+
+-- !query 124
+SELECT cast(1 as decimal(10, 0)) < '2' FROM t
+-- !query 124 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) < CAST(2 AS DOUBLE)):boolean>
+-- !query 124 output
+true
+
+
+-- !query 125
+SELECT cast(1 as decimal(10, 0)) <> '2' FROM t
+-- !query 125 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) = CAST(2 AS DOUBLE))):boolean>
+-- !query 125 output
+true
+
+
+-- !query 126
+SELECT cast(1 as decimal(10, 0)) <= '2' FROM t
+-- !query 126 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) <= CAST(2 AS DOUBLE)):boolean>
+-- !query 126 output
+true
+
+
+-- !query 127
+SELECT cast(1 as decimal(10, 0)) = cast(null as string) FROM t
+-- !query 127 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) = CAST(CAST(NULL AS STRING) AS DOUBLE)):boolean>
+-- !query 127 output
+NULL
+
+
+-- !query 128
+SELECT cast(1 as decimal(10, 0)) > cast(null as string) FROM t
+-- !query 128 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) > CAST(CAST(NULL AS STRING) AS DOUBLE)):boolean>
+-- !query 128 output
+NULL
+
+
+-- !query 129
+SELECT cast(1 as decimal(10, 0)) >= cast(null as string) FROM t
+-- !query 129 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) >= CAST(CAST(NULL AS STRING) AS DOUBLE)):boolean>
+-- !query 129 output
+NULL
+
+
+-- !query 130
+SELECT cast(1 as decimal(10, 0)) < cast(null as string) FROM t
+-- !query 130 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) < CAST(CAST(NULL AS STRING) AS DOUBLE)):boolean>
+-- !query 130 output
+NULL
+
+
+-- !query 131
+SELECT cast(1 as decimal(10, 0)) <> cast(null as string) FROM t
+-- !query 131 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) = CAST(CAST(NULL AS STRING) AS DOUBLE))):boolean>
+-- !query 131 output
+NULL
+
+
+-- !query 132
+SELECT cast(1 as decimal(10, 0)) <= cast(null as string) FROM t
+-- !query 132 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) <= CAST(CAST(NULL AS STRING) AS DOUBLE)):boolean>
+-- !query 132 output
+NULL
+
+
+-- !query 133
+SELECT '1' = cast(1 as decimal(10, 0)) FROM t
+-- !query 133 schema
+struct<(CAST(1 AS DOUBLE) = CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean>
+-- !query 133 output
+true
+
+
+-- !query 134
+SELECT '2' > cast(1 as decimal(10, 0)) FROM t
+-- !query 134 schema
+struct<(CAST(2 AS DOUBLE) > CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean>
+-- !query 134 output
+true
+
+
+-- !query 135
+SELECT '2' >= cast(1 as decimal(10, 0)) FROM t
+-- !query 135 schema
+struct<(CAST(2 AS DOUBLE) >= CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean>
+-- !query 135 output
+true
+
+
+-- !query 136
+SELECT '2' < cast(1 as decimal(10, 0)) FROM t
+-- !query 136 schema
+struct<(CAST(2 AS DOUBLE) < CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean>
+-- !query 136 output
+false
+
+
+-- !query 137
+SELECT '2' <= cast(1 as decimal(10, 0)) FROM t
+-- !query 137 schema
+struct<(CAST(2 AS DOUBLE) <= CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean>
+-- !query 137 output
+false
+
+
+-- !query 138
+SELECT '2' <> cast(1 as decimal(10, 0)) FROM t
+-- !query 138 schema
+struct<(NOT (CAST(2 AS DOUBLE) = CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE))):boolean>
+-- !query 138 output
+true
+
+
+-- !query 139
+SELECT cast(null as string) = cast(1 as decimal(10, 0)) FROM t
+-- !query 139 schema
+struct<(CAST(CAST(NULL AS STRING) AS DOUBLE) = CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean>
+-- !query 139 output
+NULL
+
+
+-- !query 140
+SELECT cast(null as string) > cast(1 as decimal(10, 0)) FROM t
+-- !query 140 schema
+struct<(CAST(CAST(NULL AS STRING) AS DOUBLE) > CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean>
+-- !query 140 output
+NULL
+
+
+-- !query 141
+SELECT cast(null as string) >= cast(1 as decimal(10, 0)) FROM t
+-- !query 141 schema
+struct<(CAST(CAST(NULL AS STRING) AS DOUBLE) >= CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean>
+-- !query 141 output
+NULL
+
+
+-- !query 142
+SELECT cast(null as string) < cast(1 as decimal(10, 0)) FROM t
+-- !query 142 schema
+struct<(CAST(CAST(NULL AS STRING) AS DOUBLE) < CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean>
+-- !query 142 output
+NULL
+
+
+-- !query 143
+SELECT cast(null as string) <= cast(1 as decimal(10, 0)) FROM t
+-- !query 143 schema
+struct<(CAST(CAST(NULL AS STRING) AS DOUBLE) <= CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean>
+-- !query 143 output
+NULL
+
+
+-- !query 144
+SELECT cast(null as string) <> cast(1 as decimal(10, 0)) FROM t
+-- !query 144 schema
+struct<(NOT (CAST(CAST(NULL AS STRING) AS DOUBLE) = CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE))):boolean>
+-- !query 144 output
+NULL
+
+
+-- !query 145
+SELECT cast(1 as double) = '1' FROM t
+-- !query 145 schema
+struct<(CAST(1 AS DOUBLE) = CAST(1 AS DOUBLE)):boolean>
+-- !query 145 output
+true
+
+
+-- !query 146
+SELECT cast(1 as double) > '2' FROM t
+-- !query 146 schema
+struct<(CAST(1 AS DOUBLE) > CAST(2 AS DOUBLE)):boolean>
+-- !query 146 output
+false
+
+
+-- !query 147
+SELECT cast(1 as double) >= '2' FROM t
+-- !query 147 schema
+struct<(CAST(1 AS DOUBLE) >= CAST(2 AS DOUBLE)):boolean>
+-- !query 147 output
+false
+
+
+-- !query 148
+SELECT cast(1 as double) < '2' FROM t
+-- !query 148 schema
+struct<(CAST(1 AS DOUBLE) < CAST(2 AS DOUBLE)):boolean>
+-- !query 148 output
+true
+
+
+-- !query 149
+SELECT cast(1 as double) <= '2' FROM t
+-- !query 149 schema
+struct<(CAST(1 AS DOUBLE) <= CAST(2 AS DOUBLE)):boolean>
+-- !query 149 output
+true
+
+
+-- !query 150
+SELECT cast(1 as double) <> '2' FROM t
+-- !query 150 schema
+struct<(NOT (CAST(1 AS DOUBLE) = CAST(2 AS DOUBLE))):boolean>
+-- !query 150 output
+true
+
+
+-- !query 151
+SELECT cast(1 as double) = cast(null as string) FROM t
+-- !query 151 schema
+struct<(CAST(1 AS DOUBLE) = CAST(CAST(NULL AS STRING) AS DOUBLE)):boolean>
+-- !query 151 output
+NULL
+
+
+-- !query 152
+SELECT cast(1 as double) > cast(null as string) FROM t
+-- !query 152 schema
+struct<(CAST(1 AS DOUBLE) > CAST(CAST(NULL AS STRING) AS DOUBLE)):boolean>
+-- !query 152 output
+NULL
+
+
+-- !query 153
+SELECT cast(1 as double) >= cast(null as string) FROM t
+-- !query 153 schema
+struct<(CAST(1 AS DOUBLE) >= CAST(CAST(NULL AS STRING) AS DOUBLE)):boolean>
+-- !query 153 output
+NULL
+
+
+-- !query 154
+SELECT cast(1 as double) < cast(null as string) FROM t
+-- !query 154 schema
+struct<(CAST(1 AS DOUBLE) < CAST(CAST(NULL AS STRING) AS DOUBLE)):boolean>
+-- !query 154 output
+NULL
+
+
+-- !query 155
+SELECT cast(1 as double) <= cast(null as string) FROM t
+-- !query 155 schema
+struct<(CAST(1 AS DOUBLE) <= CAST(CAST(NULL AS STRING) AS DOUBLE)):boolean>
+-- !query 155 output
+NULL
+
+
+-- !query 156
+SELECT cast(1 as double) <> cast(null as string) FROM t
+-- !query 156 schema
+struct<(NOT (CAST(1 AS DOUBLE) = CAST(CAST(NULL AS STRING) AS DOUBLE))):boolean>
+-- !query 156 output
+NULL
+
+
+-- !query 157
+SELECT '1' = cast(1 as double) FROM t
+-- !query 157 schema
+struct<(CAST(1 AS DOUBLE) = CAST(1 AS DOUBLE)):boolean>
+-- !query 157 output
+true
+
+
+-- !query 158
+SELECT '2' > cast(1 as double) FROM t
+-- !query 158 schema
+struct<(CAST(2 AS DOUBLE) > CAST(1 AS DOUBLE)):boolean>
+-- !query 158 output
+true
+
+
+-- !query 159
+SELECT '2' >= cast(1 as double) FROM t
+-- !query 159 schema
+struct<(CAST(2 AS DOUBLE) >= CAST(1 AS DOUBLE)):boolean>
+-- !query 159 output
+true
+
+
+-- !query 160
+SELECT '2' < cast(1 as double) FROM t
+-- !query 160 schema
+struct<(CAST(2 AS DOUBLE) < CAST(1 AS DOUBLE)):boolean>
+-- !query 160 output
+false
+
+
+-- !query 161
+SELECT '2' <= cast(1 as double) FROM t
+-- !query 161 schema
+struct<(CAST(2 AS DOUBLE) <= CAST(1 AS DOUBLE)):boolean>
+-- !query 161 output
+false
+
+
+-- !query 162
+SELECT '2' <> cast(1 as double) FROM t
+-- !query 162 schema
+struct<(NOT (CAST(2 AS DOUBLE) = CAST(1 AS DOUBLE))):boolean>
+-- !query 162 output
+true
+
+
+-- !query 163
+SELECT cast(null as string) = cast(1 as double) FROM t
+-- !query 163 schema
+struct<(CAST(CAST(NULL AS STRING) AS DOUBLE) = CAST(1 AS DOUBLE)):boolean>
+-- !query 163 output
+NULL
+
+
+-- !query 164
+SELECT cast(null as string) > cast(1 as double) FROM t
+-- !query 164 schema
+struct<(CAST(CAST(NULL AS STRING) AS DOUBLE) > CAST(1 AS DOUBLE)):boolean>
+-- !query 164 output
+NULL
+
+
+-- !query 165
+SELECT cast(null as string) >= cast(1 as double) FROM t
+-- !query 165 schema
+struct<(CAST(CAST(NULL AS STRING) AS DOUBLE) >= CAST(1 AS DOUBLE)):boolean>
+-- !query 165 output
+NULL
+
+
+-- !query 166
+SELECT cast(null as string) < cast(1 as double) FROM t
+-- !query 166 schema
+struct<(CAST(CAST(NULL AS STRING) AS DOUBLE) < CAST(1 AS DOUBLE)):boolean>
+-- !query 166 output
+NULL
+
+
+-- !query 167
+SELECT cast(null as string) <= cast(1 as double) FROM t
+-- !query 167 schema
+struct<(CAST(CAST(NULL AS STRING) AS DOUBLE) <= CAST(1 AS DOUBLE)):boolean>
+-- !query 167 output
+NULL
+
+
+-- !query 168
+SELECT cast(null as string) <> cast(1 as double) FROM t
+-- !query 168 schema
+struct<(NOT (CAST(CAST(NULL AS STRING) AS DOUBLE) = CAST(1 AS DOUBLE))):boolean>
+-- !query 168 output
+NULL
+
+
+-- !query 169
+SELECT cast(1 as float) = '1' FROM t
+-- !query 169 schema
+struct<(CAST(1 AS FLOAT) = CAST(1 AS FLOAT)):boolean>
+-- !query 169 output
+true
+
+
+-- !query 170
+SELECT cast(1 as float) > '2' FROM t
+-- !query 170 schema
+struct<(CAST(1 AS FLOAT) > CAST(2 AS FLOAT)):boolean>
+-- !query 170 output
+false
+
+
+-- !query 171
+SELECT cast(1 as float) >= '2' FROM t
+-- !query 171 schema
+struct<(CAST(1 AS FLOAT) >= CAST(2 AS FLOAT)):boolean>
+-- !query 171 output
+false
+
+
+-- !query 172
+SELECT cast(1 as float) < '2' FROM t
+-- !query 172 schema
+struct<(CAST(1 AS FLOAT) < CAST(2 AS FLOAT)):boolean>
+-- !query 172 output
+true
+
+
+-- !query 173
+SELECT cast(1 as float) <= '2' FROM t
+-- !query 173 schema
+struct<(CAST(1 AS FLOAT) <= CAST(2 AS FLOAT)):boolean>
+-- !query 173 output
+true
+
+
+-- !query 174
+SELECT cast(1 as float) <> '2' FROM t
+-- !query 174 schema
+struct<(NOT (CAST(1 AS FLOAT) = CAST(2 AS FLOAT))):boolean>
+-- !query 174 output
+true
+
+
+-- !query 175
+SELECT cast(1 as float) = cast(null as string) FROM t
+-- !query 175 schema
+struct<(CAST(1 AS FLOAT) = CAST(CAST(NULL AS STRING) AS FLOAT)):boolean>
+-- !query 175 output
+NULL
+
+
+-- !query 176
+SELECT cast(1 as float) > cast(null as string) FROM t
+-- !query 176 schema
+struct<(CAST(1 AS FLOAT) > CAST(CAST(NULL AS STRING) AS FLOAT)):boolean>
+-- !query 176 output
+NULL
+
+
+-- !query 177
+SELECT cast(1 as float) >= cast(null as string) FROM t
+-- !query 177 schema
+struct<(CAST(1 AS FLOAT) >= CAST(CAST(NULL AS STRING) AS FLOAT)):boolean>
+-- !query 177 output
+NULL
+
+
+-- !query 178
+SELECT cast(1 as float) < cast(null as string) FROM t
+-- !query 178 schema
+struct<(CAST(1 AS FLOAT) < CAST(CAST(NULL AS STRING) AS FLOAT)):boolean>
+-- !query 178 output
+NULL
+
+
+-- !query 179
+SELECT cast(1 as float) <= cast(null as string) FROM t
+-- !query 179 schema
+struct<(CAST(1 AS FLOAT) <= CAST(CAST(NULL AS STRING) AS FLOAT)):boolean>
+-- !query 179 output
+NULL
+
+
+-- !query 180
+SELECT cast(1 as float) <> cast(null as string) FROM t
+-- !query 180 schema
+struct<(NOT (CAST(1 AS FLOAT) = CAST(CAST(NULL AS STRING) AS FLOAT))):boolean>
+-- !query 180 output
+NULL
+
+
+-- !query 181
+SELECT '1' = cast(1 as float) FROM t
+-- !query 181 schema
+struct<(CAST(1 AS FLOAT) = CAST(1 AS FLOAT)):boolean>
+-- !query 181 output
+true
+
+
+-- !query 182
+SELECT '2' > cast(1 as float) FROM t
+-- !query 182 schema
+struct<(CAST(2 AS FLOAT) > CAST(1 AS FLOAT)):boolean>
+-- !query 182 output
+true
+
+
+-- !query 183
+SELECT '2' >= cast(1 as float) FROM t
+-- !query 183 schema
+struct<(CAST(2 AS FLOAT) >= CAST(1 AS FLOAT)):boolean>
+-- !query 183 output
+true
+
+
+-- !query 184
+SELECT '2' < cast(1 as float) FROM t
+-- !query 184 schema
+struct<(CAST(2 AS FLOAT) < CAST(1 AS FLOAT)):boolean>
+-- !query 184 output
+false
+
+
+-- !query 185
+SELECT '2' <= cast(1 as float) FROM t
+-- !query 185 schema
+struct<(CAST(2 AS FLOAT) <= CAST(1 AS FLOAT)):boolean>
+-- !query 185 output
+false
+
+
+-- !query 186
+SELECT '2' <> cast(1 as float) FROM t
+-- !query 186 schema
+struct<(NOT (CAST(2 AS FLOAT) = CAST(1 AS FLOAT))):boolean>
+-- !query 186 output
+true
+
+
+-- !query 187
+SELECT cast(null as string) = cast(1 as float) FROM t
+-- !query 187 schema
+struct<(CAST(CAST(NULL AS STRING) AS FLOAT) = CAST(1 AS FLOAT)):boolean>
+-- !query 187 output
+NULL
+
+
+-- !query 188
+SELECT cast(null as string) > cast(1 as float) FROM t
+-- !query 188 schema
+struct<(CAST(CAST(NULL AS STRING) AS FLOAT) > CAST(1 AS FLOAT)):boolean>
+-- !query 188 output
+NULL
+
+
+-- !query 189
+SELECT cast(null as string) >= cast(1 as float) FROM t
+-- !query 189 schema
+struct<(CAST(CAST(NULL AS STRING) AS FLOAT) >= CAST(1 AS FLOAT)):boolean>
+-- !query 189 output
+NULL
+
+
+-- !query 190
+SELECT cast(null as string) < cast(1 as float) FROM t
+-- !query 190 schema
+struct<(CAST(CAST(NULL AS STRING) AS FLOAT) < CAST(1 AS FLOAT)):boolean>
+-- !query 190 output
+NULL
+
+
+-- !query 191
+SELECT cast(null as string) <= cast(1 as float) FROM t
+-- !query 191 schema
+struct<(CAST(CAST(NULL AS STRING) AS FLOAT) <= CAST(1 AS FLOAT)):boolean>
+-- !query 191 output
+NULL
+
+
+-- !query 192
+SELECT cast(null as string) <> cast(1 as float) FROM t
+-- !query 192 schema
+struct<(NOT (CAST(CAST(NULL AS STRING) AS FLOAT) = CAST(1 AS FLOAT))):boolean>
+-- !query 192 output
+NULL
+
+
+-- !query 193
+SELECT '1996-09-09' = date('1996-09-09') FROM t
+-- !query 193 schema
+struct<(1996-09-09 = CAST(CAST(1996-09-09 AS DATE) AS STRING)):boolean>
+-- !query 193 output
+true
+
+
+-- !query 194
+SELECT '1996-9-10' > date('1996-09-09') FROM t
+-- !query 194 schema
+struct<(1996-9-10 > CAST(CAST(1996-09-09 AS DATE) AS STRING)):boolean>
+-- !query 194 output
+true
+
+
+-- !query 195
+SELECT '1996-9-10' >= date('1996-09-09') FROM t
+-- !query 195 schema
+struct<(1996-9-10 >= CAST(CAST(1996-09-09 AS DATE) AS STRING)):boolean>
+-- !query 195 output
+true
+
+
+-- !query 196
+SELECT '1996-9-10' < date('1996-09-09') FROM t
+-- !query 196 schema
+struct<(1996-9-10 < CAST(CAST(1996-09-09 AS DATE) AS STRING)):boolean>
+-- !query 196 output
+false
+
+
+-- !query 197
+SELECT '1996-9-10' <= date('1996-09-09') FROM t
+-- !query 197 schema
+struct<(1996-9-10 <= CAST(CAST(1996-09-09 AS DATE) AS STRING)):boolean>
+-- !query 197 output
+false
+
+
+-- !query 198
+SELECT '1996-9-10' <> date('1996-09-09') FROM t
+-- !query 198 schema
+struct<(NOT (1996-9-10 = CAST(CAST(1996-09-09 AS DATE) AS STRING))):boolean>
+-- !query 198 output
+true
+
+
+-- !query 199
+SELECT cast(null as string) = date('1996-09-09') FROM t
+-- !query 199 schema
+struct<(CAST(NULL AS STRING) = CAST(CAST(1996-09-09 AS DATE) AS STRING)):boolean>
+-- !query 199 output
+NULL
+
+
+-- !query 200
+SELECT cast(null as string)> date('1996-09-09') FROM t
+-- !query 200 schema
+struct<(CAST(NULL AS STRING) > CAST(CAST(1996-09-09 AS DATE) AS STRING)):boolean>
+-- !query 200 output
+NULL
+
+
+-- !query 201
+SELECT cast(null as string)>= date('1996-09-09') FROM t
+-- !query 201 schema
+struct<(CAST(NULL AS STRING) >= CAST(CAST(1996-09-09 AS DATE) AS STRING)):boolean>
+-- !query 201 output
+NULL
+
+
+-- !query 202
+SELECT cast(null as string)< date('1996-09-09') FROM t
+-- !query 202 schema
+struct<(CAST(NULL AS STRING) < CAST(CAST(1996-09-09 AS DATE) AS STRING)):boolean>
+-- !query 202 output
+NULL
+
+
+-- !query 203
+SELECT cast(null as string)<= date('1996-09-09') FROM t
+-- !query 203 schema
+struct<(CAST(NULL AS STRING) <= CAST(CAST(1996-09-09 AS DATE) AS STRING)):boolean>
+-- !query 203 output
+NULL
+
+
+-- !query 204
+SELECT cast(null as string)<> date('1996-09-09') FROM t
+-- !query 204 schema
+struct<(NOT (CAST(NULL AS STRING) = CAST(CAST(1996-09-09 AS DATE) AS STRING))):boolean>
+-- !query 204 output
+NULL
+
+
+-- !query 205
+SELECT date('1996-09-09') = '1996-09-09' FROM t
+-- !query 205 schema
+struct<(CAST(CAST(1996-09-09 AS DATE) AS STRING) = 1996-09-09):boolean>
+-- !query 205 output
+true
+
+
+-- !query 206
+SELECT date('1996-9-10') > '1996-09-09' FROM t
+-- !query 206 schema
+struct<(CAST(CAST(1996-9-10 AS DATE) AS STRING) > 1996-09-09):boolean>
+-- !query 206 output
+true
+
+
+-- !query 207
+SELECT date('1996-9-10') >= '1996-09-09' FROM t
+-- !query 207 schema
+struct<(CAST(CAST(1996-9-10 AS DATE) AS STRING) >= 1996-09-09):boolean>
+-- !query 207 output
+true
+
+
+-- !query 208
+SELECT date('1996-9-10') < '1996-09-09' FROM t
+-- !query 208 schema
+struct<(CAST(CAST(1996-9-10 AS DATE) AS STRING) < 1996-09-09):boolean>
+-- !query 208 output
+false
+
+
+-- !query 209
+SELECT date('1996-9-10') <= '1996-09-09' FROM t
+-- !query 209 schema
+struct<(CAST(CAST(1996-9-10 AS DATE) AS STRING) <= 1996-09-09):boolean>
+-- !query 209 output
+false
+
+
+-- !query 210
+SELECT date('1996-9-10') <> '1996-09-09' FROM t
+-- !query 210 schema
+struct<(NOT (CAST(CAST(1996-9-10 AS DATE) AS STRING) = 1996-09-09)):boolean>
+-- !query 210 output
+true
+
+
+-- !query 211
+SELECT date('1996-09-09') = cast(null as string) FROM t
+-- !query 211 schema
+struct<(CAST(CAST(1996-09-09 AS DATE) AS STRING) = CAST(NULL AS STRING)):boolean>
+-- !query 211 output
+NULL
+
+
+-- !query 212
+SELECT date('1996-9-10') > cast(null as string) FROM t
+-- !query 212 schema
+struct<(CAST(CAST(1996-9-10 AS DATE) AS STRING) > CAST(NULL AS STRING)):boolean>
+-- !query 212 output
+NULL
+
+
+-- !query 213
+SELECT date('1996-9-10') >= cast(null as string) FROM t
+-- !query 213 schema
+struct<(CAST(CAST(1996-9-10 AS DATE) AS STRING) >= CAST(NULL AS STRING)):boolean>
+-- !query 213 output
+NULL
+
+
+-- !query 214
+SELECT date('1996-9-10') < cast(null as string) FROM t
+-- !query 214 schema
+struct<(CAST(CAST(1996-9-10 AS DATE) AS STRING) < CAST(NULL AS STRING)):boolean>
+-- !query 214 output
+NULL
+
+
+-- !query 215
+SELECT date('1996-9-10') <= cast(null as string) FROM t
+-- !query 215 schema
+struct<(CAST(CAST(1996-9-10 AS DATE) AS STRING) <= CAST(NULL AS STRING)):boolean>
+-- !query 215 output
+NULL
+
+
+-- !query 216
+SELECT date('1996-9-10') <> cast(null as string) FROM t
+-- !query 216 schema
+struct<(NOT (CAST(CAST(1996-9-10 AS DATE) AS STRING) = CAST(NULL AS STRING))):boolean>
+-- !query 216 output
+NULL
+
+
+-- !query 217
+SELECT '1996-09-09 12:12:12.4' = timestamp('1996-09-09 12:12:12.4') FROM t
+-- !query 217 schema
+struct<(CAST(1996-09-09 12:12:12.4 AS TIMESTAMP) = CAST(1996-09-09 12:12:12.4 AS TIMESTAMP)):boolean>
+-- !query 217 output
+true
+
+
+-- !query 218
+SELECT '1996-09-09 12:12:12.5' > timestamp('1996-09-09 12:12:12.4') FROM t
+-- !query 218 schema
+struct<(1996-09-09 12:12:12.5 > CAST(CAST(1996-09-09 12:12:12.4 AS TIMESTAMP) AS STRING)):boolean>
+-- !query 218 output
+true
+
+
+-- !query 219
+SELECT '1996-09-09 12:12:12.5' >= timestamp('1996-09-09 12:12:12.4') FROM t
+-- !query 219 schema
+struct<(1996-09-09 12:12:12.5 >= CAST(CAST(1996-09-09 12:12:12.4 AS TIMESTAMP) AS STRING)):boolean>
+-- !query 219 output
+true
+
+
+-- !query 220
+SELECT '1996-09-09 12:12:12.5' < timestamp('1996-09-09 12:12:12.4') FROM t
+-- !query 220 schema
+struct<(1996-09-09 12:12:12.5 < CAST(CAST(1996-09-09 12:12:12.4 AS TIMESTAMP) AS STRING)):boolean>
+-- !query 220 output
+false
+
+
+-- !query 221
+SELECT '1996-09-09 12:12:12.5' <= timestamp('1996-09-09 12:12:12.4') FROM t
+-- !query 221 schema
+struct<(1996-09-09 12:12:12.5 <= CAST(CAST(1996-09-09 12:12:12.4 AS TIMESTAMP) AS STRING)):boolean>
+-- !query 221 output
+false
+
+
+-- !query 222
+SELECT '1996-09-09 12:12:12.5' <> timestamp('1996-09-09 12:12:12.4') FROM t
+-- !query 222 schema
+struct<(NOT (CAST(1996-09-09 12:12:12.5 AS TIMESTAMP) = CAST(1996-09-09 12:12:12.4 AS TIMESTAMP))):boolean>
+-- !query 222 output
+true
+
+
+-- !query 223
+SELECT cast(null as string) = timestamp('1996-09-09 12:12:12.4') FROM t
+-- !query 223 schema
+struct<(CAST(CAST(NULL AS STRING) AS TIMESTAMP) = CAST(1996-09-09 12:12:12.4 AS TIMESTAMP)):boolean>
+-- !query 223 output
+NULL
+
+
+-- !query 224
+SELECT cast(null as string) > timestamp('1996-09-09 12:12:12.4') FROM t
+-- !query 224 schema
+struct<(CAST(NULL AS STRING) > CAST(CAST(1996-09-09 12:12:12.4 AS TIMESTAMP) AS STRING)):boolean>
+-- !query 224 output
+NULL
+
+
+-- !query 225
+SELECT cast(null as string) >= timestamp('1996-09-09 12:12:12.4') FROM t
+-- !query 225 schema
+struct<(CAST(NULL AS STRING) >= CAST(CAST(1996-09-09 12:12:12.4 AS TIMESTAMP) AS STRING)):boolean>
+-- !query 225 output
+NULL
+
+
+-- !query 226
+SELECT cast(null as string) < timestamp('1996-09-09 12:12:12.4') FROM t
+-- !query 226 schema
+struct<(CAST(NULL AS STRING) < CAST(CAST(1996-09-09 12:12:12.4 AS TIMESTAMP) AS STRING)):boolean>
+-- !query 226 output
+NULL
+
+
+-- !query 227
+SELECT cast(null as string) <= timestamp('1996-09-09 12:12:12.4') FROM t
+-- !query 227 schema
+struct<(CAST(NULL AS STRING) <= CAST(CAST(1996-09-09 12:12:12.4 AS TIMESTAMP) AS STRING)):boolean>
+-- !query 227 output
+NULL
+
+
+-- !query 228
+SELECT cast(null as string) <> timestamp('1996-09-09 12:12:12.4') FROM t
+-- !query 228 schema
+struct<(NOT (CAST(CAST(NULL AS STRING) AS TIMESTAMP) = CAST(1996-09-09 12:12:12.4 AS TIMESTAMP))):boolean>
+-- !query 228 output
+NULL
+
+
+-- !query 229
+SELECT timestamp('1996-09-09 12:12:12.4' )= '1996-09-09 12:12:12.4' FROM t
+-- !query 229 schema
+struct<(CAST(1996-09-09 12:12:12.4 AS TIMESTAMP) = CAST(1996-09-09 12:12:12.4 AS TIMESTAMP)):boolean>
+-- !query 229 output
+true
+
+
+-- !query 230
+SELECT timestamp('1996-09-09 12:12:12.5' )> '1996-09-09 12:12:12.4' FROM t
+-- !query 230 schema
+struct<(CAST(CAST(1996-09-09 12:12:12.5 AS TIMESTAMP) AS STRING) > 1996-09-09 12:12:12.4):boolean>
+-- !query 230 output
+true
+
+
+-- !query 231
+SELECT timestamp('1996-09-09 12:12:12.5' )>= '1996-09-09 12:12:12.4' FROM t
+-- !query 231 schema
+struct<(CAST(CAST(1996-09-09 12:12:12.5 AS TIMESTAMP) AS STRING) >= 1996-09-09 12:12:12.4):boolean>
+-- !query 231 output
+true
+
+
+-- !query 232
+SELECT timestamp('1996-09-09 12:12:12.5' )< '1996-09-09 12:12:12.4' FROM t
+-- !query 232 schema
+struct<(CAST(CAST(1996-09-09 12:12:12.5 AS TIMESTAMP) AS STRING) < 1996-09-09 12:12:12.4):boolean>
+-- !query 232 output
+false
+
+
+-- !query 233
+SELECT timestamp('1996-09-09 12:12:12.5' )<= '1996-09-09 12:12:12.4' FROM t
+-- !query 233 schema
+struct<(CAST(CAST(1996-09-09 12:12:12.5 AS TIMESTAMP) AS STRING) <= 1996-09-09 12:12:12.4):boolean>
+-- !query 233 output
+false
+
+
+-- !query 234
+SELECT timestamp('1996-09-09 12:12:12.5' )<> '1996-09-09 12:12:12.4' FROM t
+-- !query 234 schema
+struct<(NOT (CAST(1996-09-09 12:12:12.5 AS TIMESTAMP) = CAST(1996-09-09 12:12:12.4 AS TIMESTAMP))):boolean>
+-- !query 234 output
+true
+
+
+-- !query 235
+SELECT timestamp('1996-09-09 12:12:12.4' )= cast(null as string) FROM t
+-- !query 235 schema
+struct<(CAST(1996-09-09 12:12:12.4 AS TIMESTAMP) = CAST(CAST(NULL AS STRING) AS TIMESTAMP)):boolean>
+-- !query 235 output
+NULL
+
+
+-- !query 236
+SELECT timestamp('1996-09-09 12:12:12.5' )> cast(null as string) FROM t
+-- !query 236 schema
+struct<(CAST(CAST(1996-09-09 12:12:12.5 AS TIMESTAMP) AS STRING) > CAST(NULL AS STRING)):boolean>
+-- !query 236 output
+NULL
+
+
+-- !query 237
+SELECT timestamp('1996-09-09 12:12:12.5' )>= cast(null as string) FROM t
+-- !query 237 schema
+struct<(CAST(CAST(1996-09-09 12:12:12.5 AS TIMESTAMP) AS STRING) >= CAST(NULL AS STRING)):boolean>
+-- !query 237 output
+NULL
+
+
+-- !query 238
+SELECT timestamp('1996-09-09 12:12:12.5' )< cast(null as string) FROM t
+-- !query 238 schema
+struct<(CAST(CAST(1996-09-09 12:12:12.5 AS TIMESTAMP) AS STRING) < CAST(NULL AS STRING)):boolean>
+-- !query 238 output
+NULL
+
+
+-- !query 239
+SELECT timestamp('1996-09-09 12:12:12.5' )<= cast(null as string) FROM t
+-- !query 239 schema
+struct<(CAST(CAST(1996-09-09 12:12:12.5 AS TIMESTAMP) AS STRING) <= CAST(NULL AS STRING)):boolean>
+-- !query 239 output
+NULL
+
+
+-- !query 240
+SELECT timestamp('1996-09-09 12:12:12.5' )<> cast(null as string) FROM t
+-- !query 240 schema
+struct<(NOT (CAST(1996-09-09 12:12:12.5 AS TIMESTAMP) = CAST(CAST(NULL AS STRING) AS TIMESTAMP))):boolean>
+-- !query 240 output
+NULL
+
+
+-- !query 241
+SELECT ' ' = X'0020' FROM t
+-- !query 241 schema
+struct<(CAST( AS BINARY) = X'0020'):boolean>
+-- !query 241 output
+false
+
+
+-- !query 242
+SELECT ' ' > X'001F' FROM t
+-- !query 242 schema
+struct<(CAST( AS BINARY) > X'001F'):boolean>
+-- !query 242 output
+true
+
+
+-- !query 243
+SELECT ' ' >= X'001F' FROM t
+-- !query 243 schema
+struct<(CAST( AS BINARY) >= X'001F'):boolean>
+-- !query 243 output
+true
+
+
+-- !query 244
+SELECT ' ' < X'001F' FROM t
+-- !query 244 schema
+struct<(CAST( AS BINARY) < X'001F'):boolean>
+-- !query 244 output
+false
+
+
+-- !query 245
+SELECT ' ' <= X'001F' FROM t
+-- !query 245 schema
+struct<(CAST( AS BINARY) <= X'001F'):boolean>
+-- !query 245 output
+false
+
+
+-- !query 246
+SELECT ' ' <> X'001F' FROM t
+-- !query 246 schema
+struct<(NOT (CAST( AS BINARY) = X'001F')):boolean>
+-- !query 246 output
+true
+
+
+-- !query 247
+SELECT cast(null as string) = X'0020' FROM t
+-- !query 247 schema
+struct<(CAST(CAST(NULL AS STRING) AS BINARY) = X'0020'):boolean>
+-- !query 247 output
+NULL
+
+
+-- !query 248
+SELECT cast(null as string) > X'001F' FROM t
+-- !query 248 schema
+struct<(CAST(CAST(NULL AS STRING) AS BINARY) > X'001F'):boolean>
+-- !query 248 output
+NULL
+
+
+-- !query 249
+SELECT cast(null as string) >= X'001F' FROM t
+-- !query 249 schema
+struct<(CAST(CAST(NULL AS STRING) AS BINARY) >= X'001F'):boolean>
+-- !query 249 output
+NULL
+
+
+-- !query 250
+SELECT cast(null as string) < X'001F' FROM t
+-- !query 250 schema
+struct<(CAST(CAST(NULL AS STRING) AS BINARY) < X'001F'):boolean>
+-- !query 250 output
+NULL
+
+
+-- !query 251
+SELECT cast(null as string) <= X'001F' FROM t
+-- !query 251 schema
+struct<(CAST(CAST(NULL AS STRING) AS BINARY) <= X'001F'):boolean>
+-- !query 251 output
+NULL
+
+
+-- !query 252
+SELECT cast(null as string) <> X'001F' FROM t
+-- !query 252 schema
+struct<(NOT (CAST(CAST(NULL AS STRING) AS BINARY) = X'001F')):boolean>
+-- !query 252 output
+NULL
+
+
+-- !query 253
+SELECT X'0020' = ' ' FROM t
+-- !query 253 schema
+struct<(X'0020' = CAST( AS BINARY)):boolean>
+-- !query 253 output
+false
+
+
+-- !query 254
+SELECT X'001F' > ' ' FROM t
+-- !query 254 schema
+struct<(X'001F' > CAST( AS BINARY)):boolean>
+-- !query 254 output
+false
+
+
+-- !query 255
+SELECT X'001F' >= ' ' FROM t
+-- !query 255 schema
+struct<(X'001F' >= CAST( AS BINARY)):boolean>
+-- !query 255 output
+false
+
+
+-- !query 256
+SELECT X'001F' < ' ' FROM t
+-- !query 256 schema
+struct<(X'001F' < CAST( AS BINARY)):boolean>
+-- !query 256 output
+true
+
+
+-- !query 257
+SELECT X'001F' <= ' ' FROM t
+-- !query 257 schema
+struct<(X'001F' <= CAST( AS BINARY)):boolean>
+-- !query 257 output
+true
+
+
+-- !query 258
+SELECT X'001F' <> ' ' FROM t
+-- !query 258 schema
+struct<(NOT (X'001F' = CAST( AS BINARY))):boolean>
+-- !query 258 output
+true
+
+
+-- !query 259
+SELECT X'0020' = cast(null as string) FROM t
+-- !query 259 schema
+struct<(X'0020' = CAST(CAST(NULL AS STRING) AS BINARY)):boolean>
+-- !query 259 output
+NULL
+
+
+-- !query 260
+SELECT X'001F' > cast(null as string) FROM t
+-- !query 260 schema
+struct<(X'001F' > CAST(CAST(NULL AS STRING) AS BINARY)):boolean>
+-- !query 260 output
+NULL
+
+
+-- !query 261
+SELECT X'001F' >= cast(null as string) FROM t
+-- !query 261 schema
+struct<(X'001F' >= CAST(CAST(NULL AS STRING) AS BINARY)):boolean>
+-- !query 261 output
+NULL
+
+
+-- !query 262
+SELECT X'001F' < cast(null as string) FROM t
+-- !query 262 schema
+struct<(X'001F' < CAST(CAST(NULL AS STRING) AS BINARY)):boolean>
+-- !query 262 output
+NULL
+
+
+-- !query 263
+SELECT X'001F' <= cast(null as string) FROM t
+-- !query 263 schema
+struct<(X'001F' <= CAST(CAST(NULL AS STRING) AS BINARY)):boolean>
+-- !query 263 output
+NULL
+
+
+-- !query 264
+SELECT X'001F' <> cast(null as string) FROM t
+-- !query 264 schema
+struct<(NOT (X'001F' = CAST(CAST(NULL AS STRING) AS BINARY))):boolean>
+-- !query 264 output
+NULL
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/booleanEquality.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/booleanEquality.sql.out
new file mode 100644
index 0000000000000..46775d79ff4a2
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/booleanEquality.sql.out
@@ -0,0 +1,802 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 97
+
+
+-- !query 0
+CREATE TEMPORARY VIEW t AS SELECT 1
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+SELECT true = cast(1 as tinyint) FROM t
+-- !query 1 schema
+struct<(CAST(true AS TINYINT) = CAST(1 AS TINYINT)):boolean>
+-- !query 1 output
+true
+
+
+-- !query 2
+SELECT true = cast(1 as smallint) FROM t
+-- !query 2 schema
+struct<(CAST(true AS SMALLINT) = CAST(1 AS SMALLINT)):boolean>
+-- !query 2 output
+true
+
+
+-- !query 3
+SELECT true = cast(1 as int) FROM t
+-- !query 3 schema
+struct<(CAST(true AS INT) = CAST(1 AS INT)):boolean>
+-- !query 3 output
+true
+
+
+-- !query 4
+SELECT true = cast(1 as bigint) FROM t
+-- !query 4 schema
+struct<(CAST(true AS BIGINT) = CAST(1 AS BIGINT)):boolean>
+-- !query 4 output
+true
+
+
+-- !query 5
+SELECT true = cast(1 as float) FROM t
+-- !query 5 schema
+struct<(CAST(true AS FLOAT) = CAST(1 AS FLOAT)):boolean>
+-- !query 5 output
+true
+
+
+-- !query 6
+SELECT true = cast(1 as double) FROM t
+-- !query 6 schema
+struct<(CAST(true AS DOUBLE) = CAST(1 AS DOUBLE)):boolean>
+-- !query 6 output
+true
+
+
+-- !query 7
+SELECT true = cast(1 as decimal(10, 0)) FROM t
+-- !query 7 schema
+struct<(CAST(true AS DECIMAL(10,0)) = CAST(1 AS DECIMAL(10,0))):boolean>
+-- !query 7 output
+true
+
+
+-- !query 8
+SELECT true = cast(1 as string) FROM t
+-- !query 8 schema
+struct<(true = CAST(CAST(1 AS STRING) AS BOOLEAN)):boolean>
+-- !query 8 output
+true
+
+
+-- !query 9
+SELECT true = cast('1' as binary) FROM t
+-- !query 9 schema
+struct<>
+-- !query 9 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(true = CAST('1' AS BINARY))' due to data type mismatch: differing types in '(true = CAST('1' AS BINARY))' (boolean and binary).; line 1 pos 7
+
+
+-- !query 10
+SELECT true = cast(1 as boolean) FROM t
+-- !query 10 schema
+struct<(true = CAST(1 AS BOOLEAN)):boolean>
+-- !query 10 output
+true
+
+
+-- !query 11
+SELECT true = cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 11 schema
+struct<>
+-- !query 11 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(true = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(true = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (boolean and timestamp).; line 1 pos 7
+
+
+-- !query 12
+SELECT true = cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 12 schema
+struct<>
+-- !query 12 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(true = CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(true = CAST('2017-12-11 09:30:00' AS DATE))' (boolean and date).; line 1 pos 7
+
+
+-- !query 13
+SELECT true <=> cast(1 as tinyint) FROM t
+-- !query 13 schema
+struct<(CAST(true AS TINYINT) <=> CAST(1 AS TINYINT)):boolean>
+-- !query 13 output
+true
+
+
+-- !query 14
+SELECT true <=> cast(1 as smallint) FROM t
+-- !query 14 schema
+struct<(CAST(true AS SMALLINT) <=> CAST(1 AS SMALLINT)):boolean>
+-- !query 14 output
+true
+
+
+-- !query 15
+SELECT true <=> cast(1 as int) FROM t
+-- !query 15 schema
+struct<(CAST(true AS INT) <=> CAST(1 AS INT)):boolean>
+-- !query 15 output
+true
+
+
+-- !query 16
+SELECT true <=> cast(1 as bigint) FROM t
+-- !query 16 schema
+struct<(CAST(true AS BIGINT) <=> CAST(1 AS BIGINT)):boolean>
+-- !query 16 output
+true
+
+
+-- !query 17
+SELECT true <=> cast(1 as float) FROM t
+-- !query 17 schema
+struct<(CAST(true AS FLOAT) <=> CAST(1 AS FLOAT)):boolean>
+-- !query 17 output
+true
+
+
+-- !query 18
+SELECT true <=> cast(1 as double) FROM t
+-- !query 18 schema
+struct<(CAST(true AS DOUBLE) <=> CAST(1 AS DOUBLE)):boolean>
+-- !query 18 output
+true
+
+
+-- !query 19
+SELECT true <=> cast(1 as decimal(10, 0)) FROM t
+-- !query 19 schema
+struct<(CAST(true AS DECIMAL(10,0)) <=> CAST(1 AS DECIMAL(10,0))):boolean>
+-- !query 19 output
+true
+
+
+-- !query 20
+SELECT true <=> cast(1 as string) FROM t
+-- !query 20 schema
+struct<(true <=> CAST(CAST(1 AS STRING) AS BOOLEAN)):boolean>
+-- !query 20 output
+true
+
+
+-- !query 21
+SELECT true <=> cast('1' as binary) FROM t
+-- !query 21 schema
+struct<>
+-- !query 21 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(true <=> CAST('1' AS BINARY))' due to data type mismatch: differing types in '(true <=> CAST('1' AS BINARY))' (boolean and binary).; line 1 pos 7
+
+
+-- !query 22
+SELECT true <=> cast(1 as boolean) FROM t
+-- !query 22 schema
+struct<(true <=> CAST(1 AS BOOLEAN)):boolean>
+-- !query 22 output
+true
+
+
+-- !query 23
+SELECT true <=> cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 23 schema
+struct<>
+-- !query 23 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(true <=> CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(true <=> CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (boolean and timestamp).; line 1 pos 7
+
+
+-- !query 24
+SELECT true <=> cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 24 schema
+struct<>
+-- !query 24 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(true <=> CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(true <=> CAST('2017-12-11 09:30:00' AS DATE))' (boolean and date).; line 1 pos 7
+
+
+-- !query 25
+SELECT cast(1 as tinyint) = true FROM t
+-- !query 25 schema
+struct<(CAST(1 AS TINYINT) = CAST(true AS TINYINT)):boolean>
+-- !query 25 output
+true
+
+
+-- !query 26
+SELECT cast(1 as smallint) = true FROM t
+-- !query 26 schema
+struct<(CAST(1 AS SMALLINT) = CAST(true AS SMALLINT)):boolean>
+-- !query 26 output
+true
+
+
+-- !query 27
+SELECT cast(1 as int) = true FROM t
+-- !query 27 schema
+struct<(CAST(1 AS INT) = CAST(true AS INT)):boolean>
+-- !query 27 output
+true
+
+
+-- !query 28
+SELECT cast(1 as bigint) = true FROM t
+-- !query 28 schema
+struct<(CAST(1 AS BIGINT) = CAST(true AS BIGINT)):boolean>
+-- !query 28 output
+true
+
+
+-- !query 29
+SELECT cast(1 as float) = true FROM t
+-- !query 29 schema
+struct<(CAST(1 AS FLOAT) = CAST(true AS FLOAT)):boolean>
+-- !query 29 output
+true
+
+
+-- !query 30
+SELECT cast(1 as double) = true FROM t
+-- !query 30 schema
+struct<(CAST(1 AS DOUBLE) = CAST(true AS DOUBLE)):boolean>
+-- !query 30 output
+true
+
+
+-- !query 31
+SELECT cast(1 as decimal(10, 0)) = true FROM t
+-- !query 31 schema
+struct<(CAST(1 AS DECIMAL(10,0)) = CAST(true AS DECIMAL(10,0))):boolean>
+-- !query 31 output
+true
+
+
+-- !query 32
+SELECT cast(1 as string) = true FROM t
+-- !query 32 schema
+struct<(CAST(CAST(1 AS STRING) AS BOOLEAN) = true):boolean>
+-- !query 32 output
+true
+
+
+-- !query 33
+SELECT cast('1' as binary) = true FROM t
+-- !query 33 schema
+struct<>
+-- !query 33 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) = true)' due to data type mismatch: differing types in '(CAST('1' AS BINARY) = true)' (binary and boolean).; line 1 pos 7
+
+
+-- !query 34
+SELECT cast(1 as boolean) = true FROM t
+-- !query 34 schema
+struct<(CAST(1 AS BOOLEAN) = true):boolean>
+-- !query 34 output
+true
+
+
+-- !query 35
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) = true FROM t
+-- !query 35 schema
+struct<>
+-- !query 35 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = true)' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = true)' (timestamp and boolean).; line 1 pos 7
+
+
+-- !query 36
+SELECT cast('2017-12-11 09:30:00' as date) = true FROM t
+-- !query 36 schema
+struct<>
+-- !query 36 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) = true)' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) = true)' (date and boolean).; line 1 pos 7
+
+
+-- !query 37
+SELECT cast(1 as tinyint) <=> true FROM t
+-- !query 37 schema
+struct<(CAST(1 AS TINYINT) <=> CAST(true AS TINYINT)):boolean>
+-- !query 37 output
+true
+
+
+-- !query 38
+SELECT cast(1 as smallint) <=> true FROM t
+-- !query 38 schema
+struct<(CAST(1 AS SMALLINT) <=> CAST(true AS SMALLINT)):boolean>
+-- !query 38 output
+true
+
+
+-- !query 39
+SELECT cast(1 as int) <=> true FROM t
+-- !query 39 schema
+struct<(CAST(1 AS INT) <=> CAST(true AS INT)):boolean>
+-- !query 39 output
+true
+
+
+-- !query 40
+SELECT cast(1 as bigint) <=> true FROM t
+-- !query 40 schema
+struct<(CAST(1 AS BIGINT) <=> CAST(true AS BIGINT)):boolean>
+-- !query 40 output
+true
+
+
+-- !query 41
+SELECT cast(1 as float) <=> true FROM t
+-- !query 41 schema
+struct<(CAST(1 AS FLOAT) <=> CAST(true AS FLOAT)):boolean>
+-- !query 41 output
+true
+
+
+-- !query 42
+SELECT cast(1 as double) <=> true FROM t
+-- !query 42 schema
+struct<(CAST(1 AS DOUBLE) <=> CAST(true AS DOUBLE)):boolean>
+-- !query 42 output
+true
+
+
+-- !query 43
+SELECT cast(1 as decimal(10, 0)) <=> true FROM t
+-- !query 43 schema
+struct<(CAST(1 AS DECIMAL(10,0)) <=> CAST(true AS DECIMAL(10,0))):boolean>
+-- !query 43 output
+true
+
+
+-- !query 44
+SELECT cast(1 as string) <=> true FROM t
+-- !query 44 schema
+struct<(CAST(CAST(1 AS STRING) AS BOOLEAN) <=> true):boolean>
+-- !query 44 output
+true
+
+
+-- !query 45
+SELECT cast('1' as binary) <=> true FROM t
+-- !query 45 schema
+struct<>
+-- !query 45 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) <=> true)' due to data type mismatch: differing types in '(CAST('1' AS BINARY) <=> true)' (binary and boolean).; line 1 pos 7
+
+
+-- !query 46
+SELECT cast(1 as boolean) <=> true FROM t
+-- !query 46 schema
+struct<(CAST(1 AS BOOLEAN) <=> true):boolean>
+-- !query 46 output
+true
+
+
+-- !query 47
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <=> true FROM t
+-- !query 47 schema
+struct<>
+-- !query 47 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <=> true)' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <=> true)' (timestamp and boolean).; line 1 pos 7
+
+
+-- !query 48
+SELECT cast('2017-12-11 09:30:00' as date) <=> true FROM t
+-- !query 48 schema
+struct<>
+-- !query 48 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) <=> true)' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) <=> true)' (date and boolean).; line 1 pos 7
+
+
+-- !query 49
+SELECT false = cast(0 as tinyint) FROM t
+-- !query 49 schema
+struct<(CAST(false AS TINYINT) = CAST(0 AS TINYINT)):boolean>
+-- !query 49 output
+true
+
+
+-- !query 50
+SELECT false = cast(0 as smallint) FROM t
+-- !query 50 schema
+struct<(CAST(false AS SMALLINT) = CAST(0 AS SMALLINT)):boolean>
+-- !query 50 output
+true
+
+
+-- !query 51
+SELECT false = cast(0 as int) FROM t
+-- !query 51 schema
+struct<(CAST(false AS INT) = CAST(0 AS INT)):boolean>
+-- !query 51 output
+true
+
+
+-- !query 52
+SELECT false = cast(0 as bigint) FROM t
+-- !query 52 schema
+struct<(CAST(false AS BIGINT) = CAST(0 AS BIGINT)):boolean>
+-- !query 52 output
+true
+
+
+-- !query 53
+SELECT false = cast(0 as float) FROM t
+-- !query 53 schema
+struct<(CAST(false AS FLOAT) = CAST(0 AS FLOAT)):boolean>
+-- !query 53 output
+true
+
+
+-- !query 54
+SELECT false = cast(0 as double) FROM t
+-- !query 54 schema
+struct<(CAST(false AS DOUBLE) = CAST(0 AS DOUBLE)):boolean>
+-- !query 54 output
+true
+
+
+-- !query 55
+SELECT false = cast(0 as decimal(10, 0)) FROM t
+-- !query 55 schema
+struct<(CAST(false AS DECIMAL(10,0)) = CAST(0 AS DECIMAL(10,0))):boolean>
+-- !query 55 output
+true
+
+
+-- !query 56
+SELECT false = cast(0 as string) FROM t
+-- !query 56 schema
+struct<(false = CAST(CAST(0 AS STRING) AS BOOLEAN)):boolean>
+-- !query 56 output
+true
+
+
+-- !query 57
+SELECT false = cast('0' as binary) FROM t
+-- !query 57 schema
+struct<>
+-- !query 57 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(false = CAST('0' AS BINARY))' due to data type mismatch: differing types in '(false = CAST('0' AS BINARY))' (boolean and binary).; line 1 pos 7
+
+
+-- !query 58
+SELECT false = cast(0 as boolean) FROM t
+-- !query 58 schema
+struct<(false = CAST(0 AS BOOLEAN)):boolean>
+-- !query 58 output
+true
+
+
+-- !query 59
+SELECT false = cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 59 schema
+struct<>
+-- !query 59 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(false = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(false = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (boolean and timestamp).; line 1 pos 7
+
+
+-- !query 60
+SELECT false = cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 60 schema
+struct<>
+-- !query 60 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(false = CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(false = CAST('2017-12-11 09:30:00' AS DATE))' (boolean and date).; line 1 pos 7
+
+
+-- !query 61
+SELECT false <=> cast(0 as tinyint) FROM t
+-- !query 61 schema
+struct<(CAST(false AS TINYINT) <=> CAST(0 AS TINYINT)):boolean>
+-- !query 61 output
+true
+
+
+-- !query 62
+SELECT false <=> cast(0 as smallint) FROM t
+-- !query 62 schema
+struct<(CAST(false AS SMALLINT) <=> CAST(0 AS SMALLINT)):boolean>
+-- !query 62 output
+true
+
+
+-- !query 63
+SELECT false <=> cast(0 as int) FROM t
+-- !query 63 schema
+struct<(CAST(false AS INT) <=> CAST(0 AS INT)):boolean>
+-- !query 63 output
+true
+
+
+-- !query 64
+SELECT false <=> cast(0 as bigint) FROM t
+-- !query 64 schema
+struct<(CAST(false AS BIGINT) <=> CAST(0 AS BIGINT)):boolean>
+-- !query 64 output
+true
+
+
+-- !query 65
+SELECT false <=> cast(0 as float) FROM t
+-- !query 65 schema
+struct<(CAST(false AS FLOAT) <=> CAST(0 AS FLOAT)):boolean>
+-- !query 65 output
+true
+
+
+-- !query 66
+SELECT false <=> cast(0 as double) FROM t
+-- !query 66 schema
+struct<(CAST(false AS DOUBLE) <=> CAST(0 AS DOUBLE)):boolean>
+-- !query 66 output
+true
+
+
+-- !query 67
+SELECT false <=> cast(0 as decimal(10, 0)) FROM t
+-- !query 67 schema
+struct<(CAST(false AS DECIMAL(10,0)) <=> CAST(0 AS DECIMAL(10,0))):boolean>
+-- !query 67 output
+true
+
+
+-- !query 68
+SELECT false <=> cast(0 as string) FROM t
+-- !query 68 schema
+struct<(false <=> CAST(CAST(0 AS STRING) AS BOOLEAN)):boolean>
+-- !query 68 output
+true
+
+
+-- !query 69
+SELECT false <=> cast('0' as binary) FROM t
+-- !query 69 schema
+struct<>
+-- !query 69 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(false <=> CAST('0' AS BINARY))' due to data type mismatch: differing types in '(false <=> CAST('0' AS BINARY))' (boolean and binary).; line 1 pos 7
+
+
+-- !query 70
+SELECT false <=> cast(0 as boolean) FROM t
+-- !query 70 schema
+struct<(false <=> CAST(0 AS BOOLEAN)):boolean>
+-- !query 70 output
+true
+
+
+-- !query 71
+SELECT false <=> cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 71 schema
+struct<>
+-- !query 71 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(false <=> CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(false <=> CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (boolean and timestamp).; line 1 pos 7
+
+
+-- !query 72
+SELECT false <=> cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 72 schema
+struct<>
+-- !query 72 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(false <=> CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(false <=> CAST('2017-12-11 09:30:00' AS DATE))' (boolean and date).; line 1 pos 7
+
+
+-- !query 73
+SELECT cast(0 as tinyint) = false FROM t
+-- !query 73 schema
+struct<(CAST(0 AS TINYINT) = CAST(false AS TINYINT)):boolean>
+-- !query 73 output
+true
+
+
+-- !query 74
+SELECT cast(0 as smallint) = false FROM t
+-- !query 74 schema
+struct<(CAST(0 AS SMALLINT) = CAST(false AS SMALLINT)):boolean>
+-- !query 74 output
+true
+
+
+-- !query 75
+SELECT cast(0 as int) = false FROM t
+-- !query 75 schema
+struct<(CAST(0 AS INT) = CAST(false AS INT)):boolean>
+-- !query 75 output
+true
+
+
+-- !query 76
+SELECT cast(0 as bigint) = false FROM t
+-- !query 76 schema
+struct<(CAST(0 AS BIGINT) = CAST(false AS BIGINT)):boolean>
+-- !query 76 output
+true
+
+
+-- !query 77
+SELECT cast(0 as float) = false FROM t
+-- !query 77 schema
+struct<(CAST(0 AS FLOAT) = CAST(false AS FLOAT)):boolean>
+-- !query 77 output
+true
+
+
+-- !query 78
+SELECT cast(0 as double) = false FROM t
+-- !query 78 schema
+struct<(CAST(0 AS DOUBLE) = CAST(false AS DOUBLE)):boolean>
+-- !query 78 output
+true
+
+
+-- !query 79
+SELECT cast(0 as decimal(10, 0)) = false FROM t
+-- !query 79 schema
+struct<(CAST(0 AS DECIMAL(10,0)) = CAST(false AS DECIMAL(10,0))):boolean>
+-- !query 79 output
+true
+
+
+-- !query 80
+SELECT cast(0 as string) = false FROM t
+-- !query 80 schema
+struct<(CAST(CAST(0 AS STRING) AS BOOLEAN) = false):boolean>
+-- !query 80 output
+true
+
+
+-- !query 81
+SELECT cast('0' as binary) = false FROM t
+-- !query 81 schema
+struct<>
+-- !query 81 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('0' AS BINARY) = false)' due to data type mismatch: differing types in '(CAST('0' AS BINARY) = false)' (binary and boolean).; line 1 pos 7
+
+
+-- !query 82
+SELECT cast(0 as boolean) = false FROM t
+-- !query 82 schema
+struct<(CAST(0 AS BOOLEAN) = false):boolean>
+-- !query 82 output
+true
+
+
+-- !query 83
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) = false FROM t
+-- !query 83 schema
+struct<>
+-- !query 83 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = false)' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = false)' (timestamp and boolean).; line 1 pos 7
+
+
+-- !query 84
+SELECT cast('2017-12-11 09:30:00' as date) = false FROM t
+-- !query 84 schema
+struct<>
+-- !query 84 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) = false)' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) = false)' (date and boolean).; line 1 pos 7
+
+
+-- !query 85
+SELECT cast(0 as tinyint) <=> false FROM t
+-- !query 85 schema
+struct<(CAST(0 AS TINYINT) <=> CAST(false AS TINYINT)):boolean>
+-- !query 85 output
+true
+
+
+-- !query 86
+SELECT cast(0 as smallint) <=> false FROM t
+-- !query 86 schema
+struct<(CAST(0 AS SMALLINT) <=> CAST(false AS SMALLINT)):boolean>
+-- !query 86 output
+true
+
+
+-- !query 87
+SELECT cast(0 as int) <=> false FROM t
+-- !query 87 schema
+struct<(CAST(0 AS INT) <=> CAST(false AS INT)):boolean>
+-- !query 87 output
+true
+
+
+-- !query 88
+SELECT cast(0 as bigint) <=> false FROM t
+-- !query 88 schema
+struct<(CAST(0 AS BIGINT) <=> CAST(false AS BIGINT)):boolean>
+-- !query 88 output
+true
+
+
+-- !query 89
+SELECT cast(0 as float) <=> false FROM t
+-- !query 89 schema
+struct<(CAST(0 AS FLOAT) <=> CAST(false AS FLOAT)):boolean>
+-- !query 89 output
+true
+
+
+-- !query 90
+SELECT cast(0 as double) <=> false FROM t
+-- !query 90 schema
+struct<(CAST(0 AS DOUBLE) <=> CAST(false AS DOUBLE)):boolean>
+-- !query 90 output
+true
+
+
+-- !query 91
+SELECT cast(0 as decimal(10, 0)) <=> false FROM t
+-- !query 91 schema
+struct<(CAST(0 AS DECIMAL(10,0)) <=> CAST(false AS DECIMAL(10,0))):boolean>
+-- !query 91 output
+true
+
+
+-- !query 92
+SELECT cast(0 as string) <=> false FROM t
+-- !query 92 schema
+struct<(CAST(CAST(0 AS STRING) AS BOOLEAN) <=> false):boolean>
+-- !query 92 output
+true
+
+
+-- !query 93
+SELECT cast('0' as binary) <=> false FROM t
+-- !query 93 schema
+struct<>
+-- !query 93 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('0' AS BINARY) <=> false)' due to data type mismatch: differing types in '(CAST('0' AS BINARY) <=> false)' (binary and boolean).; line 1 pos 7
+
+
+-- !query 94
+SELECT cast(0 as boolean) <=> false FROM t
+-- !query 94 schema
+struct<(CAST(0 AS BOOLEAN) <=> false):boolean>
+-- !query 94 output
+true
+
+
+-- !query 95
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <=> false FROM t
+-- !query 95 schema
+struct<>
+-- !query 95 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <=> false)' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <=> false)' (timestamp and boolean).; line 1 pos 7
+
+
+-- !query 96
+SELECT cast('2017-12-11 09:30:00' as date) <=> false FROM t
+-- !query 96 schema
+struct<>
+-- !query 96 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) <=> false)' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) <=> false)' (date and boolean).; line 1 pos 7
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/caseWhenCoercion.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/caseWhenCoercion.sql.out
new file mode 100644
index 0000000000000..a739f8d73181c
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/caseWhenCoercion.sql.out
@@ -0,0 +1,1232 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 145
+
+
+-- !query 0
+CREATE TEMPORARY VIEW t AS SELECT 1
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as tinyint) END FROM t
+-- !query 1 schema
+struct
+-- !query 1 output
+1
+
+
+-- !query 2
+SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as smallint) END FROM t
+-- !query 2 schema
+struct
+-- !query 2 output
+1
+
+
+-- !query 3
+SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as int) END FROM t
+-- !query 3 schema
+struct
+-- !query 3 output
+1
+
+
+-- !query 4
+SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as bigint) END FROM t
+-- !query 4 schema
+struct
+-- !query 4 output
+1
+
+
+-- !query 5
+SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as float) END FROM t
+-- !query 5 schema
+struct
+-- !query 5 output
+1.0
+
+
+-- !query 6
+SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as double) END FROM t
+-- !query 6 schema
+struct
+-- !query 6 output
+1.0
+
+
+-- !query 7
+SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as decimal(10, 0)) END FROM t
+-- !query 7 schema
+struct
+-- !query 7 output
+1
+
+
+-- !query 8
+SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as string) END FROM t
+-- !query 8 schema
+struct
+-- !query 8 output
+1
+
+
+-- !query 9
+SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast('2' as binary) END FROM t
+-- !query 9 schema
+struct<>
+-- !query 9 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS TINYINT) ELSE CAST('2' AS BINARY) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 10
+SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as boolean) END FROM t
+-- !query 10 schema
+struct<>
+-- !query 10 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS TINYINT) ELSE CAST(2 AS BOOLEAN) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 11
+SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t
+-- !query 11 schema
+struct<>
+-- !query 11 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS TINYINT) ELSE CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 12
+SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast('2017-12-11 09:30:00' as date) END FROM t
+-- !query 12 schema
+struct<>
+-- !query 12 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS TINYINT) ELSE CAST('2017-12-11 09:30:00' AS DATE) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 13
+SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as tinyint) END FROM t
+-- !query 13 schema
+struct
+-- !query 13 output
+1
+
+
+-- !query 14
+SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as smallint) END FROM t
+-- !query 14 schema
+struct
+-- !query 14 output
+1
+
+
+-- !query 15
+SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as int) END FROM t
+-- !query 15 schema
+struct
+-- !query 15 output
+1
+
+
+-- !query 16
+SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as bigint) END FROM t
+-- !query 16 schema
+struct
+-- !query 16 output
+1
+
+
+-- !query 17
+SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as float) END FROM t
+-- !query 17 schema
+struct
+-- !query 17 output
+1.0
+
+
+-- !query 18
+SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as double) END FROM t
+-- !query 18 schema
+struct
+-- !query 18 output
+1.0
+
+
+-- !query 19
+SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as decimal(10, 0)) END FROM t
+-- !query 19 schema
+struct
+-- !query 19 output
+1
+
+
+-- !query 20
+SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as string) END FROM t
+-- !query 20 schema
+struct
+-- !query 20 output
+1
+
+
+-- !query 21
+SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast('2' as binary) END FROM t
+-- !query 21 schema
+struct<>
+-- !query 21 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS SMALLINT) ELSE CAST('2' AS BINARY) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 22
+SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as boolean) END FROM t
+-- !query 22 schema
+struct<>
+-- !query 22 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS SMALLINT) ELSE CAST(2 AS BOOLEAN) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 23
+SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t
+-- !query 23 schema
+struct<>
+-- !query 23 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS SMALLINT) ELSE CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 24
+SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast('2017-12-11 09:30:00' as date) END FROM t
+-- !query 24 schema
+struct<>
+-- !query 24 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS SMALLINT) ELSE CAST('2017-12-11 09:30:00' AS DATE) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 25
+SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as tinyint) END FROM t
+-- !query 25 schema
+struct
+-- !query 25 output
+1
+
+
+-- !query 26
+SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as smallint) END FROM t
+-- !query 26 schema
+struct
+-- !query 26 output
+1
+
+
+-- !query 27
+SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as int) END FROM t
+-- !query 27 schema
+struct
+-- !query 27 output
+1
+
+
+-- !query 28
+SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as bigint) END FROM t
+-- !query 28 schema
+struct
+-- !query 28 output
+1
+
+
+-- !query 29
+SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as float) END FROM t
+-- !query 29 schema
+struct
+-- !query 29 output
+1.0
+
+
+-- !query 30
+SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as double) END FROM t
+-- !query 30 schema
+struct
+-- !query 30 output
+1.0
+
+
+-- !query 31
+SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as decimal(10, 0)) END FROM t
+-- !query 31 schema
+struct
+-- !query 31 output
+1
+
+
+-- !query 32
+SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as string) END FROM t
+-- !query 32 schema
+struct
+-- !query 32 output
+1
+
+
+-- !query 33
+SELECT CASE WHEN true THEN cast(1 as int) ELSE cast('2' as binary) END FROM t
+-- !query 33 schema
+struct<>
+-- !query 33 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS INT) ELSE CAST('2' AS BINARY) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 34
+SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as boolean) END FROM t
+-- !query 34 schema
+struct<>
+-- !query 34 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS INT) ELSE CAST(2 AS BOOLEAN) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 35
+SELECT CASE WHEN true THEN cast(1 as int) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t
+-- !query 35 schema
+struct<>
+-- !query 35 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS INT) ELSE CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 36
+SELECT CASE WHEN true THEN cast(1 as int) ELSE cast('2017-12-11 09:30:00' as date) END FROM t
+-- !query 36 schema
+struct<>
+-- !query 36 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS INT) ELSE CAST('2017-12-11 09:30:00' AS DATE) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 37
+SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as tinyint) END FROM t
+-- !query 37 schema
+struct
+-- !query 37 output
+1
+
+
+-- !query 38
+SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as smallint) END FROM t
+-- !query 38 schema
+struct
+-- !query 38 output
+1
+
+
+-- !query 39
+SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as int) END FROM t
+-- !query 39 schema
+struct
+-- !query 39 output
+1
+
+
+-- !query 40
+SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as bigint) END FROM t
+-- !query 40 schema
+struct
+-- !query 40 output
+1
+
+
+-- !query 41
+SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as float) END FROM t
+-- !query 41 schema
+struct
+-- !query 41 output
+1.0
+
+
+-- !query 42
+SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as double) END FROM t
+-- !query 42 schema
+struct
+-- !query 42 output
+1.0
+
+
+-- !query 43
+SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as decimal(10, 0)) END FROM t
+-- !query 43 schema
+struct
+-- !query 43 output
+1
+
+
+-- !query 44
+SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as string) END FROM t
+-- !query 44 schema
+struct
+-- !query 44 output
+1
+
+
+-- !query 45
+SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast('2' as binary) END FROM t
+-- !query 45 schema
+struct<>
+-- !query 45 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS BIGINT) ELSE CAST('2' AS BINARY) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 46
+SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as boolean) END FROM t
+-- !query 46 schema
+struct<>
+-- !query 46 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS BIGINT) ELSE CAST(2 AS BOOLEAN) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 47
+SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t
+-- !query 47 schema
+struct<>
+-- !query 47 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS BIGINT) ELSE CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 48
+SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast('2017-12-11 09:30:00' as date) END FROM t
+-- !query 48 schema
+struct<>
+-- !query 48 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS BIGINT) ELSE CAST('2017-12-11 09:30:00' AS DATE) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 49
+SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as tinyint) END FROM t
+-- !query 49 schema
+struct
+-- !query 49 output
+1.0
+
+
+-- !query 50
+SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as smallint) END FROM t
+-- !query 50 schema
+struct
+-- !query 50 output
+1.0
+
+
+-- !query 51
+SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as int) END FROM t
+-- !query 51 schema
+struct
+-- !query 51 output
+1.0
+
+
+-- !query 52
+SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as bigint) END FROM t
+-- !query 52 schema
+struct
+-- !query 52 output
+1.0
+
+
+-- !query 53
+SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as float) END FROM t
+-- !query 53 schema
+struct
+-- !query 53 output
+1.0
+
+
+-- !query 54
+SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as double) END FROM t
+-- !query 54 schema
+struct
+-- !query 54 output
+1.0
+
+
+-- !query 55
+SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as decimal(10, 0)) END FROM t
+-- !query 55 schema
+struct
+-- !query 55 output
+1.0
+
+
+-- !query 56
+SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as string) END FROM t
+-- !query 56 schema
+struct
+-- !query 56 output
+1.0
+
+
+-- !query 57
+SELECT CASE WHEN true THEN cast(1 as float) ELSE cast('2' as binary) END FROM t
+-- !query 57 schema
+struct<>
+-- !query 57 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS FLOAT) ELSE CAST('2' AS BINARY) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 58
+SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as boolean) END FROM t
+-- !query 58 schema
+struct<>
+-- !query 58 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS FLOAT) ELSE CAST(2 AS BOOLEAN) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 59
+SELECT CASE WHEN true THEN cast(1 as float) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t
+-- !query 59 schema
+struct<>
+-- !query 59 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS FLOAT) ELSE CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 60
+SELECT CASE WHEN true THEN cast(1 as float) ELSE cast('2017-12-11 09:30:00' as date) END FROM t
+-- !query 60 schema
+struct<>
+-- !query 60 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS FLOAT) ELSE CAST('2017-12-11 09:30:00' AS DATE) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 61
+SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as tinyint) END FROM t
+-- !query 61 schema
+struct
+-- !query 61 output
+1.0
+
+
+-- !query 62
+SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as smallint) END FROM t
+-- !query 62 schema
+struct
+-- !query 62 output
+1.0
+
+
+-- !query 63
+SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as int) END FROM t
+-- !query 63 schema
+struct
+-- !query 63 output
+1.0
+
+
+-- !query 64
+SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as bigint) END FROM t
+-- !query 64 schema
+struct
+-- !query 64 output
+1.0
+
+
+-- !query 65
+SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as float) END FROM t
+-- !query 65 schema
+struct
+-- !query 65 output
+1.0
+
+
+-- !query 66
+SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as double) END FROM t
+-- !query 66 schema
+struct
+-- !query 66 output
+1.0
+
+
+-- !query 67
+SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as decimal(10, 0)) END FROM t
+-- !query 67 schema
+struct
+-- !query 67 output
+1.0
+
+
+-- !query 68
+SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as string) END FROM t
+-- !query 68 schema
+struct
+-- !query 68 output
+1.0
+
+
+-- !query 69
+SELECT CASE WHEN true THEN cast(1 as double) ELSE cast('2' as binary) END FROM t
+-- !query 69 schema
+struct<>
+-- !query 69 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS DOUBLE) ELSE CAST('2' AS BINARY) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 70
+SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as boolean) END FROM t
+-- !query 70 schema
+struct<>
+-- !query 70 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS DOUBLE) ELSE CAST(2 AS BOOLEAN) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 71
+SELECT CASE WHEN true THEN cast(1 as double) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t
+-- !query 71 schema
+struct<>
+-- !query 71 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS DOUBLE) ELSE CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 72
+SELECT CASE WHEN true THEN cast(1 as double) ELSE cast('2017-12-11 09:30:00' as date) END FROM t
+-- !query 72 schema
+struct<>
+-- !query 72 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS DOUBLE) ELSE CAST('2017-12-11 09:30:00' AS DATE) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 73
+SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as tinyint) END FROM t
+-- !query 73 schema
+struct
+-- !query 73 output
+1
+
+
+-- !query 74
+SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as smallint) END FROM t
+-- !query 74 schema
+struct
+-- !query 74 output
+1
+
+
+-- !query 75
+SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as int) END FROM t
+-- !query 75 schema
+struct
+-- !query 75 output
+1
+
+
+-- !query 76
+SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as bigint) END FROM t
+-- !query 76 schema
+struct
+-- !query 76 output
+1
+
+
+-- !query 77
+SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as float) END FROM t
+-- !query 77 schema
+struct
+-- !query 77 output
+1.0
+
+
+-- !query 78
+SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as double) END FROM t
+-- !query 78 schema
+struct
+-- !query 78 output
+1.0
+
+
+-- !query 79
+SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as decimal(10, 0)) END FROM t
+-- !query 79 schema
+struct
+-- !query 79 output
+1
+
+
+-- !query 80
+SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as string) END FROM t
+-- !query 80 schema
+struct
+-- !query 80 output
+1
+
+
+-- !query 81
+SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast('2' as binary) END FROM t
+-- !query 81 schema
+struct<>
+-- !query 81 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS DECIMAL(10,0)) ELSE CAST('2' AS BINARY) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 82
+SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as boolean) END FROM t
+-- !query 82 schema
+struct<>
+-- !query 82 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS DECIMAL(10,0)) ELSE CAST(2 AS BOOLEAN) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 83
+SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t
+-- !query 83 schema
+struct<>
+-- !query 83 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS DECIMAL(10,0)) ELSE CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 84
+SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast('2017-12-11 09:30:00' as date) END FROM t
+-- !query 84 schema
+struct<>
+-- !query 84 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS DECIMAL(10,0)) ELSE CAST('2017-12-11 09:30:00' AS DATE) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 85
+SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as tinyint) END FROM t
+-- !query 85 schema
+struct
+-- !query 85 output
+1
+
+
+-- !query 86
+SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as smallint) END FROM t
+-- !query 86 schema
+struct
+-- !query 86 output
+1
+
+
+-- !query 87
+SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as int) END FROM t
+-- !query 87 schema
+struct
+-- !query 87 output
+1
+
+
+-- !query 88
+SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as bigint) END FROM t
+-- !query 88 schema
+struct
+-- !query 88 output
+1
+
+
+-- !query 89
+SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as float) END FROM t
+-- !query 89 schema
+struct
+-- !query 89 output
+1
+
+
+-- !query 90
+SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as double) END FROM t
+-- !query 90 schema
+struct
+-- !query 90 output
+1
+
+
+-- !query 91
+SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as decimal(10, 0)) END FROM t
+-- !query 91 schema
+struct
+-- !query 91 output
+1
+
+
+-- !query 92
+SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as string) END FROM t
+-- !query 92 schema
+struct
+-- !query 92 output
+1
+
+
+-- !query 93
+SELECT CASE WHEN true THEN cast(1 as string) ELSE cast('2' as binary) END FROM t
+-- !query 93 schema
+struct<>
+-- !query 93 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS STRING) ELSE CAST('2' AS BINARY) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 94
+SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as boolean) END FROM t
+-- !query 94 schema
+struct<>
+-- !query 94 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS STRING) ELSE CAST(2 AS BOOLEAN) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 95
+SELECT CASE WHEN true THEN cast(1 as string) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t
+-- !query 95 schema
+struct
+-- !query 95 output
+1
+
+
+-- !query 96
+SELECT CASE WHEN true THEN cast(1 as string) ELSE cast('2017-12-11 09:30:00' as date) END FROM t
+-- !query 96 schema
+struct
+-- !query 96 output
+1
+
+
+-- !query 97
+SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as tinyint) END FROM t
+-- !query 97 schema
+struct<>
+-- !query 97 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST('1' AS BINARY) ELSE CAST(2 AS TINYINT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 98
+SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as smallint) END FROM t
+-- !query 98 schema
+struct<>
+-- !query 98 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST('1' AS BINARY) ELSE CAST(2 AS SMALLINT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 99
+SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as int) END FROM t
+-- !query 99 schema
+struct<>
+-- !query 99 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST('1' AS BINARY) ELSE CAST(2 AS INT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 100
+SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as bigint) END FROM t
+-- !query 100 schema
+struct<>
+-- !query 100 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST('1' AS BINARY) ELSE CAST(2 AS BIGINT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 101
+SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as float) END FROM t
+-- !query 101 schema
+struct<>
+-- !query 101 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST('1' AS BINARY) ELSE CAST(2 AS FLOAT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 102
+SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as double) END FROM t
+-- !query 102 schema
+struct<>
+-- !query 102 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST('1' AS BINARY) ELSE CAST(2 AS DOUBLE) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 103
+SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as decimal(10, 0)) END FROM t
+-- !query 103 schema
+struct<>
+-- !query 103 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST('1' AS BINARY) ELSE CAST(2 AS DECIMAL(10,0)) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 104
+SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as string) END FROM t
+-- !query 104 schema
+struct<>
+-- !query 104 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST('1' AS BINARY) ELSE CAST(2 AS STRING) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 105
+SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast('2' as binary) END FROM t
+-- !query 105 schema
+struct
+-- !query 105 output
+1
+
+
+-- !query 106
+SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as boolean) END FROM t
+-- !query 106 schema
+struct<>
+-- !query 106 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST('1' AS BINARY) ELSE CAST(2 AS BOOLEAN) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 107
+SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t
+-- !query 107 schema
+struct<>
+-- !query 107 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST('1' AS BINARY) ELSE CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 108
+SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast('2017-12-11 09:30:00' as date) END FROM t
+-- !query 108 schema
+struct<>
+-- !query 108 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST('1' AS BINARY) ELSE CAST('2017-12-11 09:30:00' AS DATE) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 109
+SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as tinyint) END FROM t
+-- !query 109 schema
+struct<>
+-- !query 109 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS BOOLEAN) ELSE CAST(2 AS TINYINT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 110
+SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as smallint) END FROM t
+-- !query 110 schema
+struct<>
+-- !query 110 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS BOOLEAN) ELSE CAST(2 AS SMALLINT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 111
+SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as int) END FROM t
+-- !query 111 schema
+struct<>
+-- !query 111 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS BOOLEAN) ELSE CAST(2 AS INT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 112
+SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as bigint) END FROM t
+-- !query 112 schema
+struct<>
+-- !query 112 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS BOOLEAN) ELSE CAST(2 AS BIGINT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 113
+SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as float) END FROM t
+-- !query 113 schema
+struct<>
+-- !query 113 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS BOOLEAN) ELSE CAST(2 AS FLOAT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 114
+SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as double) END FROM t
+-- !query 114 schema
+struct<>
+-- !query 114 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS BOOLEAN) ELSE CAST(2 AS DOUBLE) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 115
+SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as decimal(10, 0)) END FROM t
+-- !query 115 schema
+struct<>
+-- !query 115 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS BOOLEAN) ELSE CAST(2 AS DECIMAL(10,0)) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 116
+SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as string) END FROM t
+-- !query 116 schema
+struct<>
+-- !query 116 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS BOOLEAN) ELSE CAST(2 AS STRING) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 117
+SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast('2' as binary) END FROM t
+-- !query 117 schema
+struct<>
+-- !query 117 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS BOOLEAN) ELSE CAST('2' AS BINARY) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 118
+SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as boolean) END FROM t
+-- !query 118 schema
+struct
+-- !query 118 output
+true
+
+
+-- !query 119
+SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t
+-- !query 119 schema
+struct<>
+-- !query 119 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS BOOLEAN) ELSE CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 120
+SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast('2017-12-11 09:30:00' as date) END FROM t
+-- !query 120 schema
+struct<>
+-- !query 120 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST(1 AS BOOLEAN) ELSE CAST('2017-12-11 09:30:00' AS DATE) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 121
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as tinyint) END FROM t
+-- !query 121 schema
+struct<>
+-- !query 121 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) ELSE CAST(2 AS TINYINT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 122
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as smallint) END FROM t
+-- !query 122 schema
+struct<>
+-- !query 122 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) ELSE CAST(2 AS SMALLINT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 123
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as int) END FROM t
+-- !query 123 schema
+struct<>
+-- !query 123 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) ELSE CAST(2 AS INT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 124
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as bigint) END FROM t
+-- !query 124 schema
+struct<>
+-- !query 124 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) ELSE CAST(2 AS BIGINT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 125
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as float) END FROM t
+-- !query 125 schema
+struct<>
+-- !query 125 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) ELSE CAST(2 AS FLOAT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 126
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as double) END FROM t
+-- !query 126 schema
+struct<>
+-- !query 126 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) ELSE CAST(2 AS DOUBLE) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 127
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as decimal(10, 0)) END FROM t
+-- !query 127 schema
+struct<>
+-- !query 127 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) ELSE CAST(2 AS DECIMAL(10,0)) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 128
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as string) END FROM t
+-- !query 128 schema
+struct
+-- !query 128 output
+2017-12-12 09:30:00
+
+
+-- !query 129
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast('2' as binary) END FROM t
+-- !query 129 schema
+struct<>
+-- !query 129 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) ELSE CAST('2' AS BINARY) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 130
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as boolean) END FROM t
+-- !query 130 schema
+struct<>
+-- !query 130 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) ELSE CAST(2 AS BOOLEAN) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 131
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t
+-- !query 131 schema
+struct
+-- !query 131 output
+2017-12-12 09:30:00
+
+
+-- !query 132
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast('2017-12-11 09:30:00' as date) END FROM t
+-- !query 132 schema
+struct
+-- !query 132 output
+2017-12-12 09:30:00
+
+
+-- !query 133
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as tinyint) END FROM t
+-- !query 133 schema
+struct<>
+-- !query 133 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00' AS DATE) ELSE CAST(2 AS TINYINT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 134
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as smallint) END FROM t
+-- !query 134 schema
+struct<>
+-- !query 134 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00' AS DATE) ELSE CAST(2 AS SMALLINT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 135
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as int) END FROM t
+-- !query 135 schema
+struct<>
+-- !query 135 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00' AS DATE) ELSE CAST(2 AS INT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 136
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as bigint) END FROM t
+-- !query 136 schema
+struct<>
+-- !query 136 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00' AS DATE) ELSE CAST(2 AS BIGINT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 137
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as float) END FROM t
+-- !query 137 schema
+struct<>
+-- !query 137 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00' AS DATE) ELSE CAST(2 AS FLOAT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 138
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as double) END FROM t
+-- !query 138 schema
+struct<>
+-- !query 138 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00' AS DATE) ELSE CAST(2 AS DOUBLE) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 139
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as decimal(10, 0)) END FROM t
+-- !query 139 schema
+struct<>
+-- !query 139 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00' AS DATE) ELSE CAST(2 AS DECIMAL(10,0)) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 140
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as string) END FROM t
+-- !query 140 schema
+struct
+-- !query 140 output
+2017-12-12
+
+
+-- !query 141
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast('2' as binary) END FROM t
+-- !query 141 schema
+struct<>
+-- !query 141 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00' AS DATE) ELSE CAST('2' AS BINARY) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 142
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as boolean) END FROM t
+-- !query 142 schema
+struct<>
+-- !query 142 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00' AS DATE) ELSE CAST(2 AS BOOLEAN) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7
+
+
+-- !query 143
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t
+-- !query 143 schema
+struct
+-- !query 143 output
+2017-12-12 00:00:00
+
+
+-- !query 144
+SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast('2017-12-11 09:30:00' as date) END FROM t
+-- !query 144 schema
+struct
+-- !query 144 output
+2017-12-12
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out
new file mode 100644
index 0000000000000..09729fdc2ec32
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out
@@ -0,0 +1,239 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 11
+
+
+-- !query 0
+SELECT (col1 || col2 || col3) col
+FROM (
+ SELECT
+ id col1,
+ string(id + 1) col2,
+ encode(string(id + 2), 'utf-8') col3
+ FROM range(10)
+)
+-- !query 0 schema
+struct<col:string>
+-- !query 0 output
+012
+123
+234
+345
+456
+567
+678
+789
+8910
+91011
+
+
+-- !query 1
+SELECT ((col1 || col2) || (col3 || col4) || col5) col
+FROM (
+ SELECT
+ 'prefix_' col1,
+ id col2,
+ string(id + 1) col3,
+ encode(string(id + 2), 'utf-8') col4,
+ CAST(id AS DOUBLE) col5
+ FROM range(10)
+)
+-- !query 1 schema
+struct<col:string>
+-- !query 1 output
+prefix_0120.0
+prefix_1231.0
+prefix_2342.0
+prefix_3453.0
+prefix_4564.0
+prefix_5675.0
+prefix_6786.0
+prefix_7897.0
+prefix_89108.0
+prefix_910119.0
+
+
+-- !query 2
+SELECT ((col1 || col2) || (col3 || col4)) col
+FROM (
+ SELECT
+ string(id) col1,
+ string(id + 1) col2,
+ encode(string(id + 2), 'utf-8') col3,
+ encode(string(id + 3), 'utf-8') col4
+ FROM range(10)
+)
+-- !query 2 schema
+struct<col:string>
+-- !query 2 output
+0123
+1234
+2345
+3456
+4567
+5678
+6789
+78910
+891011
+9101112
+
+
+-- !query 3
+set spark.sql.function.concatBinaryAsString=true
+-- !query 3 schema
+struct<key:string,value:string>
+-- !query 3 output
+spark.sql.function.concatBinaryAsString true
+
+
+-- !query 4
+SELECT (col1 || col2) col
+FROM (
+ SELECT
+ encode(string(id), 'utf-8') col1,
+ encode(string(id + 1), 'utf-8') col2
+ FROM range(10)
+)
+-- !query 4 schema
+struct<col:string>
+-- !query 4 output
+01
+12
+23
+34
+45
+56
+67
+78
+89
+910
+
+
+-- !query 5
+SELECT (col1 || col2 || col3 || col4) col
+FROM (
+ SELECT
+ encode(string(id), 'utf-8') col1,
+ encode(string(id + 1), 'utf-8') col2,
+ encode(string(id + 2), 'utf-8') col3,
+ encode(string(id + 3), 'utf-8') col4
+ FROM range(10)
+)
+-- !query 5 schema
+struct<col:string>
+-- !query 5 output
+0123
+1234
+2345
+3456
+4567
+5678
+6789
+78910
+891011
+9101112
+
+
+-- !query 6
+SELECT ((col1 || col2) || (col3 || col4)) col
+FROM (
+ SELECT
+ encode(string(id), 'utf-8') col1,
+ encode(string(id + 1), 'utf-8') col2,
+ encode(string(id + 2), 'utf-8') col3,
+ encode(string(id + 3), 'utf-8') col4
+ FROM range(10)
+)
+-- !query 6 schema
+struct<col:string>
+-- !query 6 output
+0123
+1234
+2345
+3456
+4567
+5678
+6789
+78910
+891011
+9101112
+
+
+-- !query 7
+set spark.sql.function.concatBinaryAsString=false
+-- !query 7 schema
+struct<key:string,value:string>
+-- !query 7 output
+spark.sql.function.concatBinaryAsString false
+
+
+-- !query 8
+SELECT (col1 || col2) col
+FROM (
+ SELECT
+ encode(string(id), 'utf-8') col1,
+ encode(string(id + 1), 'utf-8') col2
+ FROM range(10)
+)
+-- !query 8 schema
+struct<col:binary>
+-- !query 8 output
+01
+12
+23
+34
+45
+56
+67
+78
+89
+910
+
+
+-- !query 9
+SELECT (col1 || col2 || col3 || col4) col
+FROM (
+ SELECT
+ encode(string(id), 'utf-8') col1,
+ encode(string(id + 1), 'utf-8') col2,
+ encode(string(id + 2), 'utf-8') col3,
+ encode(string(id + 3), 'utf-8') col4
+ FROM range(10)
+)
+-- !query 9 schema
+struct<col:binary>
+-- !query 9 output
+0123
+1234
+2345
+3456
+4567
+5678
+6789
+78910
+891011
+9101112
+
+
+-- !query 10
+SELECT ((col1 || col2) || (col3 || col4)) col
+FROM (
+ SELECT
+ encode(string(id), 'utf-8') col1,
+ encode(string(id + 1), 'utf-8') col2,
+ encode(string(id + 2), 'utf-8') col3,
+ encode(string(id + 3), 'utf-8') col4
+ FROM range(10)
+)
+-- !query 10 schema
+struct<col:binary>
+-- !query 10 output
+0123
+1234
+2345
+3456
+4567
+5678
+6789
+78910
+891011
+9101112
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/dateTimeOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/dateTimeOperations.sql.out
new file mode 100644
index 0000000000000..12c1d1617679f
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/dateTimeOperations.sql.out
@@ -0,0 +1,349 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 40
+
+
+-- !query 0
+CREATE TEMPORARY VIEW t AS SELECT 1
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+select cast(1 as tinyint) + interval 2 day
+-- !query 1 schema
+struct<>
+-- !query 1 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS TINYINT) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS TINYINT) + interval 2 days)' (tinyint and calendarinterval).; line 1 pos 7
+
+
+-- !query 2
+select cast(1 as smallint) + interval 2 day
+-- !query 2 schema
+struct<>
+-- !query 2 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS SMALLINT) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS SMALLINT) + interval 2 days)' (smallint and calendarinterval).; line 1 pos 7
+
+
+-- !query 3
+select cast(1 as int) + interval 2 day
+-- !query 3 schema
+struct<>
+-- !query 3 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS INT) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS INT) + interval 2 days)' (int and calendarinterval).; line 1 pos 7
+
+
+-- !query 4
+select cast(1 as bigint) + interval 2 day
+-- !query 4 schema
+struct<>
+-- !query 4 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS BIGINT) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS BIGINT) + interval 2 days)' (bigint and calendarinterval).; line 1 pos 7
+
+
+-- !query 5
+select cast(1 as float) + interval 2 day
+-- !query 5 schema
+struct<>
+-- !query 5 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS FLOAT) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS FLOAT) + interval 2 days)' (float and calendarinterval).; line 1 pos 7
+
+
+-- !query 6
+select cast(1 as double) + interval 2 day
+-- !query 6 schema
+struct<>
+-- !query 6 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DOUBLE) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS DOUBLE) + interval 2 days)' (double and calendarinterval).; line 1 pos 7
+
+
+-- !query 7
+select cast(1 as decimal(10, 0)) + interval 2 day
+-- !query 7 schema
+struct<>
+-- !query 7 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) + interval 2 days)' (decimal(10,0) and calendarinterval).; line 1 pos 7
+
+
+-- !query 8
+select cast('2017-12-11' as string) + interval 2 day
+-- !query 8 schema
+struct
+-- !query 8 output
+2017-12-13 00:00:00
+
+
+-- !query 9
+select cast('2017-12-11 09:30:00' as string) + interval 2 day
+-- !query 9 schema
+struct
+-- !query 9 output
+2017-12-13 09:30:00
+
+
+-- !query 10
+select cast('1' as binary) + interval 2 day
+-- !query 10 schema
+struct<>
+-- !query 10 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) + interval 2 days)' due to data type mismatch: differing types in '(CAST('1' AS BINARY) + interval 2 days)' (binary and calendarinterval).; line 1 pos 7
+
+
+-- !query 11
+select cast(1 as boolean) + interval 2 day
+-- !query 11 schema
+struct<>
+-- !query 11 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS BOOLEAN) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) + interval 2 days)' (boolean and calendarinterval).; line 1 pos 7
+
+
+-- !query 12
+select cast('2017-12-11 09:30:00.0' as timestamp) + interval 2 day
+-- !query 12 schema
+struct
+-- !query 12 output
+2017-12-13 09:30:00
+
+
+-- !query 13
+select cast('2017-12-11 09:30:00' as date) + interval 2 day
+-- !query 13 schema
+struct
+-- !query 13 output
+2017-12-13
+
+
+-- !query 14
+select interval 2 day + cast(1 as tinyint)
+-- !query 14 schema
+struct<>
+-- !query 14 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(interval 2 days + CAST(1 AS TINYINT))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS TINYINT))' (calendarinterval and tinyint).; line 1 pos 7
+
+
+-- !query 15
+select interval 2 day + cast(1 as smallint)
+-- !query 15 schema
+struct<>
+-- !query 15 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(interval 2 days + CAST(1 AS SMALLINT))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS SMALLINT))' (calendarinterval and smallint).; line 1 pos 7
+
+
+-- !query 16
+select interval 2 day + cast(1 as int)
+-- !query 16 schema
+struct<>
+-- !query 16 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(interval 2 days + CAST(1 AS INT))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS INT))' (calendarinterval and int).; line 1 pos 7
+
+
+-- !query 17
+select interval 2 day + cast(1 as bigint)
+-- !query 17 schema
+struct<>
+-- !query 17 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(interval 2 days + CAST(1 AS BIGINT))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS BIGINT))' (calendarinterval and bigint).; line 1 pos 7
+
+
+-- !query 18
+select interval 2 day + cast(1 as float)
+-- !query 18 schema
+struct<>
+-- !query 18 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(interval 2 days + CAST(1 AS FLOAT))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS FLOAT))' (calendarinterval and float).; line 1 pos 7
+
+
+-- !query 19
+select interval 2 day + cast(1 as double)
+-- !query 19 schema
+struct<>
+-- !query 19 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(interval 2 days + CAST(1 AS DOUBLE))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS DOUBLE))' (calendarinterval and double).; line 1 pos 7
+
+
+-- !query 20
+select interval 2 day + cast(1 as decimal(10, 0))
+-- !query 20 schema
+struct<>
+-- !query 20 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(interval 2 days + CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS DECIMAL(10,0)))' (calendarinterval and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 21
+select interval 2 day + cast('2017-12-11' as string)
+-- !query 21 schema
+struct
+-- !query 21 output
+2017-12-13 00:00:00
+
+
+-- !query 22
+select interval 2 day + cast('2017-12-11 09:30:00' as string)
+-- !query 22 schema
+struct
+-- !query 22 output
+2017-12-13 09:30:00
+
+
+-- !query 23
+select interval 2 day + cast('1' as binary)
+-- !query 23 schema
+struct<>
+-- !query 23 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(interval 2 days + CAST('1' AS BINARY))' due to data type mismatch: differing types in '(interval 2 days + CAST('1' AS BINARY))' (calendarinterval and binary).; line 1 pos 7
+
+
+-- !query 24
+select interval 2 day + cast(1 as boolean)
+-- !query 24 schema
+struct<>
+-- !query 24 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(interval 2 days + CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS BOOLEAN))' (calendarinterval and boolean).; line 1 pos 7
+
+
+-- !query 25
+select interval 2 day + cast('2017-12-11 09:30:00.0' as timestamp)
+-- !query 25 schema
+struct
+-- !query 25 output
+2017-12-13 09:30:00
+
+
+-- !query 26
+select interval 2 day + cast('2017-12-11 09:30:00' as date)
+-- !query 26 schema
+struct
+-- !query 26 output
+2017-12-13
+
+
+-- !query 27
+select cast(1 as tinyint) - interval 2 day
+-- !query 27 schema
+struct<>
+-- !query 27 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS TINYINT) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS TINYINT) - interval 2 days)' (tinyint and calendarinterval).; line 1 pos 7
+
+
+-- !query 28
+select cast(1 as smallint) - interval 2 day
+-- !query 28 schema
+struct<>
+-- !query 28 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS SMALLINT) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS SMALLINT) - interval 2 days)' (smallint and calendarinterval).; line 1 pos 7
+
+
+-- !query 29
+select cast(1 as int) - interval 2 day
+-- !query 29 schema
+struct<>
+-- !query 29 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS INT) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS INT) - interval 2 days)' (int and calendarinterval).; line 1 pos 7
+
+
+-- !query 30
+select cast(1 as bigint) - interval 2 day
+-- !query 30 schema
+struct<>
+-- !query 30 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS BIGINT) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS BIGINT) - interval 2 days)' (bigint and calendarinterval).; line 1 pos 7
+
+
+-- !query 31
+select cast(1 as float) - interval 2 day
+-- !query 31 schema
+struct<>
+-- !query 31 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS FLOAT) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS FLOAT) - interval 2 days)' (float and calendarinterval).; line 1 pos 7
+
+
+-- !query 32
+select cast(1 as double) - interval 2 day
+-- !query 32 schema
+struct<>
+-- !query 32 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DOUBLE) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS DOUBLE) - interval 2 days)' (double and calendarinterval).; line 1 pos 7
+
+
+-- !query 33
+select cast(1 as decimal(10, 0)) - interval 2 day
+-- !query 33 schema
+struct<>
+-- !query 33 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) - interval 2 days)' (decimal(10,0) and calendarinterval).; line 1 pos 7
+
+
+-- !query 34
+select cast('2017-12-11' as string) - interval 2 day
+-- !query 34 schema
+struct
+-- !query 34 output
+2017-12-09 00:00:00
+
+
+-- !query 35
+select cast('2017-12-11 09:30:00' as string) - interval 2 day
+-- !query 35 schema
+struct
+-- !query 35 output
+2017-12-09 09:30:00
+
+
+-- !query 36
+select cast('1' as binary) - interval 2 day
+-- !query 36 schema
+struct<>
+-- !query 36 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) - interval 2 days)' due to data type mismatch: differing types in '(CAST('1' AS BINARY) - interval 2 days)' (binary and calendarinterval).; line 1 pos 7
+
+
+-- !query 37
+select cast(1 as boolean) - interval 2 day
+-- !query 37 schema
+struct<>
+-- !query 37 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS BOOLEAN) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) - interval 2 days)' (boolean and calendarinterval).; line 1 pos 7
+
+
+-- !query 38
+select cast('2017-12-11 09:30:00.0' as timestamp) - interval 2 day
+-- !query 38 schema
+struct
+-- !query 38 output
+2017-12-09 09:30:00
+
+
+-- !query 39
+select cast('2017-12-11 09:30:00' as date) - interval 2 day
+-- !query 39 schema
+struct
+-- !query 39 output
+2017-12-09
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out
new file mode 100644
index 0000000000000..ce02f6adc456c
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out
@@ -0,0 +1,82 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 10
+
+
+-- !query 0
+CREATE TEMPORARY VIEW t AS SELECT 1.0 as a, 0.0 as b
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+select a / b from t
+-- !query 1 schema
+struct<(CAST(a AS DECIMAL(2,1)) / CAST(b AS DECIMAL(2,1))):decimal(8,6)>
+-- !query 1 output
+NULL
+
+
+-- !query 2
+select a % b from t
+-- !query 2 schema
+struct<(CAST(a AS DECIMAL(2,1)) % CAST(b AS DECIMAL(2,1))):decimal(1,1)>
+-- !query 2 output
+NULL
+
+
+-- !query 3
+select pmod(a, b) from t
+-- !query 3 schema
+struct
+-- !query 3 output
+NULL
+
+
+-- !query 4
+select (5e36 + 0.1) + 5e36
+-- !query 4 schema
+struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)>
+-- !query 4 output
+NULL
+
+
+-- !query 5
+select (-4e36 - 0.1) - 7e36
+-- !query 5 schema
+struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)>
+-- !query 5 output
+NULL
+
+
+-- !query 6
+select 12345678901234567890.0 * 12345678901234567890.0
+-- !query 6 schema
+struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)>
+-- !query 6 output
+NULL
+
+
+-- !query 7
+select 1e35 / 0.1
+-- !query 7 schema
+struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,3)>
+-- !query 7 output
+NULL
+
+
+-- !query 8
+select 123456789123456789.1234567890 * 1.123456789123456789
+-- !query 8 schema
+struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,28)>
+-- !query 8 output
+NULL
+
+
+-- !query 9
+select 0.001 / 9876543210987654321098765432109876543.2
+-- !query 9 schema
+struct<(CAST(0.001 AS DECIMAL(38,3)) / CAST(9876543210987654321098765432109876543.2 AS DECIMAL(38,3))):decimal(38,37)>
+-- !query 9 output
+NULL
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out
new file mode 100644
index 0000000000000..ebc8201ed5a1d
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out
@@ -0,0 +1,9514 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 1145
+
+
+-- !query 0
+CREATE TEMPORARY VIEW t AS SELECT 1
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+SELECT cast(1 as tinyint) + cast(1 as decimal(3, 0)) FROM t
+-- !query 1 schema
+struct<(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) + CAST(1 AS DECIMAL(3,0))):decimal(4,0)>
+-- !query 1 output
+2
+
+
+-- !query 2
+SELECT cast(1 as tinyint) + cast(1 as decimal(5, 0)) FROM t
+-- !query 2 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(6,0)) + CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(6,0))):decimal(6,0)>
+-- !query 2 output
+2
+
+
+-- !query 3
+SELECT cast(1 as tinyint) + cast(1 as decimal(10, 0)) FROM t
+-- !query 3 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(11,0)) + CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0))):decimal(11,0)>
+-- !query 3 output
+2
+
+
+-- !query 4
+SELECT cast(1 as tinyint) + cast(1 as decimal(20, 0)) FROM t
+-- !query 4 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(21,0)) + CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0))):decimal(21,0)>
+-- !query 4 output
+2
+
+
+-- !query 5
+SELECT cast(1 as smallint) + cast(1 as decimal(3, 0)) FROM t
+-- !query 5 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(6,0)) + CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(6,0))):decimal(6,0)>
+-- !query 5 output
+2
+
+
+-- !query 6
+SELECT cast(1 as smallint) + cast(1 as decimal(5, 0)) FROM t
+-- !query 6 schema
+struct<(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) + CAST(1 AS DECIMAL(5,0))):decimal(6,0)>
+-- !query 6 output
+2
+
+
+-- !query 7
+SELECT cast(1 as smallint) + cast(1 as decimal(10, 0)) FROM t
+-- !query 7 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(11,0)) + CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0))):decimal(11,0)>
+-- !query 7 output
+2
+
+
+-- !query 8
+SELECT cast(1 as smallint) + cast(1 as decimal(20, 0)) FROM t
+-- !query 8 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(21,0)) + CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0))):decimal(21,0)>
+-- !query 8 output
+2
+
+
+-- !query 9
+SELECT cast(1 as int) + cast(1 as decimal(3, 0)) FROM t
+-- !query 9 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(11,0)) + CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(11,0))):decimal(11,0)>
+-- !query 9 output
+2
+
+
+-- !query 10
+SELECT cast(1 as int) + cast(1 as decimal(5, 0)) FROM t
+-- !query 10 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(11,0)) + CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(11,0))):decimal(11,0)>
+-- !query 10 output
+2
+
+
+-- !query 11
+SELECT cast(1 as int) + cast(1 as decimal(10, 0)) FROM t
+-- !query 11 schema
+struct<(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) + CAST(1 AS DECIMAL(10,0))):decimal(11,0)>
+-- !query 11 output
+2
+
+
+-- !query 12
+SELECT cast(1 as int) + cast(1 as decimal(20, 0)) FROM t
+-- !query 12 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(21,0)) + CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0))):decimal(21,0)>
+-- !query 12 output
+2
+
+
+-- !query 13
+SELECT cast(1 as bigint) + cast(1 as decimal(3, 0)) FROM t
+-- !query 13 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(21,0)) + CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(21,0))):decimal(21,0)>
+-- !query 13 output
+2
+
+
+-- !query 14
+SELECT cast(1 as bigint) + cast(1 as decimal(5, 0)) FROM t
+-- !query 14 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(21,0)) + CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(21,0))):decimal(21,0)>
+-- !query 14 output
+2
+
+
+-- !query 15
+SELECT cast(1 as bigint) + cast(1 as decimal(10, 0)) FROM t
+-- !query 15 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(21,0)) + CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(21,0))):decimal(21,0)>
+-- !query 15 output
+2
+
+
+-- !query 16
+SELECT cast(1 as bigint) + cast(1 as decimal(20, 0)) FROM t
+-- !query 16 schema
+struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) + CAST(1 AS DECIMAL(20,0))):decimal(21,0)>
+-- !query 16 output
+2
+
+
+-- !query 17
+SELECT cast(1 as float) + cast(1 as decimal(3, 0)) FROM t
+-- !query 17 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) + CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):double>
+-- !query 17 output
+2.0
+
+
+-- !query 18
+SELECT cast(1 as float) + cast(1 as decimal(5, 0)) FROM t
+-- !query 18 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) + CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):double>
+-- !query 18 output
+2.0
+
+
+-- !query 19
+SELECT cast(1 as float) + cast(1 as decimal(10, 0)) FROM t
+-- !query 19 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) + CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):double>
+-- !query 19 output
+2.0
+
+
+-- !query 20
+SELECT cast(1 as float) + cast(1 as decimal(20, 0)) FROM t
+-- !query 20 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) + CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):double>
+-- !query 20 output
+2.0
+
+
+-- !query 21
+SELECT cast(1 as double) + cast(1 as decimal(3, 0)) FROM t
+-- !query 21 schema
+struct<(CAST(1 AS DOUBLE) + CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):double>
+-- !query 21 output
+2.0
+
+
+-- !query 22
+SELECT cast(1 as double) + cast(1 as decimal(5, 0)) FROM t
+-- !query 22 schema
+struct<(CAST(1 AS DOUBLE) + CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):double>
+-- !query 22 output
+2.0
+
+
+-- !query 23
+SELECT cast(1 as double) + cast(1 as decimal(10, 0)) FROM t
+-- !query 23 schema
+struct<(CAST(1 AS DOUBLE) + CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):double>
+-- !query 23 output
+2.0
+
+
+-- !query 24
+SELECT cast(1 as double) + cast(1 as decimal(20, 0)) FROM t
+-- !query 24 schema
+struct<(CAST(1 AS DOUBLE) + CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):double>
+-- !query 24 output
+2.0
+
+
+-- !query 25
+SELECT cast(1 as decimal(10, 0)) + cast(1 as decimal(3, 0)) FROM t
+-- !query 25 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0)) + CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(11,0))):decimal(11,0)>
+-- !query 25 output
+2
+
+
+-- !query 26
+SELECT cast(1 as decimal(10, 0)) + cast(1 as decimal(5, 0)) FROM t
+-- !query 26 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0)) + CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(11,0))):decimal(11,0)>
+-- !query 26 output
+2
+
+
+-- !query 27
+SELECT cast(1 as decimal(10, 0)) + cast(1 as decimal(10, 0)) FROM t
+-- !query 27 schema
+struct<(CAST(1 AS DECIMAL(10,0)) + CAST(1 AS DECIMAL(10,0))):decimal(11,0)>
+-- !query 27 output
+2
+
+
+-- !query 28
+SELECT cast(1 as decimal(10, 0)) + cast(1 as decimal(20, 0)) FROM t
+-- !query 28 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(21,0)) + CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0))):decimal(21,0)>
+-- !query 28 output
+2
+
+
+-- !query 29
+SELECT cast('1' as binary) + cast(1 as decimal(3, 0)) FROM t
+-- !query 29 schema
+struct<>
+-- !query 29 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) + CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) + CAST(1 AS DECIMAL(3,0)))' (binary and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 30
+SELECT cast('1' as binary) + cast(1 as decimal(5, 0)) FROM t
+-- !query 30 schema
+struct<>
+-- !query 30 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) + CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) + CAST(1 AS DECIMAL(5,0)))' (binary and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 31
+SELECT cast('1' as binary) + cast(1 as decimal(10, 0)) FROM t
+-- !query 31 schema
+struct<>
+-- !query 31 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) + CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) + CAST(1 AS DECIMAL(10,0)))' (binary and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 32
+SELECT cast('1' as binary) + cast(1 as decimal(20, 0)) FROM t
+-- !query 32 schema
+struct<>
+-- !query 32 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) + CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) + CAST(1 AS DECIMAL(20,0)))' (binary and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 33
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) + cast(1 as decimal(3, 0)) FROM t
+-- !query 33 schema
+struct<>
+-- !query 33 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) + CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) + CAST(1 AS DECIMAL(3,0)))' (timestamp and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 34
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) + cast(1 as decimal(5, 0)) FROM t
+-- !query 34 schema
+struct<>
+-- !query 34 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) + CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) + CAST(1 AS DECIMAL(5,0)))' (timestamp and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 35
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) + cast(1 as decimal(10, 0)) FROM t
+-- !query 35 schema
+struct<>
+-- !query 35 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) + CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) + CAST(1 AS DECIMAL(10,0)))' (timestamp and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 36
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) + cast(1 as decimal(20, 0)) FROM t
+-- !query 36 schema
+struct<>
+-- !query 36 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) + CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) + CAST(1 AS DECIMAL(20,0)))' (timestamp and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 37
+SELECT cast('2017-12-11 09:30:00' as date) + cast(1 as decimal(3, 0)) FROM t
+-- !query 37 schema
+struct<>
+-- !query 37 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) + CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) + CAST(1 AS DECIMAL(3,0)))' (date and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 38
+SELECT cast('2017-12-11 09:30:00' as date) + cast(1 as decimal(5, 0)) FROM t
+-- !query 38 schema
+struct<>
+-- !query 38 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) + CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) + CAST(1 AS DECIMAL(5,0)))' (date and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 39
+SELECT cast('2017-12-11 09:30:00' as date) + cast(1 as decimal(10, 0)) FROM t
+-- !query 39 schema
+struct<>
+-- !query 39 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) + CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) + CAST(1 AS DECIMAL(10,0)))' (date and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 40
+SELECT cast('2017-12-11 09:30:00' as date) + cast(1 as decimal(20, 0)) FROM t
+-- !query 40 schema
+struct<>
+-- !query 40 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) + CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) + CAST(1 AS DECIMAL(20,0)))' (date and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 41
+SELECT cast(1 as decimal(3, 0)) + cast(1 as tinyint) FROM t
+-- !query 41 schema
+struct<(CAST(1 AS DECIMAL(3,0)) + CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0))):decimal(4,0)>
+-- !query 41 output
+2
+
+
+-- !query 42
+SELECT cast(1 as decimal(5, 0)) + cast(1 as tinyint) FROM t
+-- !query 42 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(6,0)) + CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(6,0))):decimal(6,0)>
+-- !query 42 output
+2
+
+
+-- !query 43
+SELECT cast(1 as decimal(10, 0)) + cast(1 as tinyint) FROM t
+-- !query 43 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0)) + CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(11,0))):decimal(11,0)>
+-- !query 43 output
+2
+
+
+-- !query 44
+SELECT cast(1 as decimal(20, 0)) + cast(1 as tinyint) FROM t
+-- !query 44 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0)) + CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(21,0))):decimal(21,0)>
+-- !query 44 output
+2
+
+
+-- !query 45
+SELECT cast(1 as decimal(3, 0)) + cast(1 as smallint) FROM t
+-- !query 45 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(6,0)) + CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(6,0))):decimal(6,0)>
+-- !query 45 output
+2
+
+
+-- !query 46
+SELECT cast(1 as decimal(5, 0)) + cast(1 as smallint) FROM t
+-- !query 46 schema
+struct<(CAST(1 AS DECIMAL(5,0)) + CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0))):decimal(6,0)>
+-- !query 46 output
+2
+
+
+-- !query 47
+SELECT cast(1 as decimal(10, 0)) + cast(1 as smallint) FROM t
+-- !query 47 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0)) + CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(11,0))):decimal(11,0)>
+-- !query 47 output
+2
+
+
+-- !query 48
+SELECT cast(1 as decimal(20, 0)) + cast(1 as smallint) FROM t
+-- !query 48 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0)) + CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(21,0))):decimal(21,0)>
+-- !query 48 output
+2
+
+
+-- !query 49
+SELECT cast(1 as decimal(3, 0)) + cast(1 as int) FROM t
+-- !query 49 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(11,0)) + CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(11,0))):decimal(11,0)>
+-- !query 49 output
+2
+
+
+-- !query 50
+SELECT cast(1 as decimal(5, 0)) + cast(1 as int) FROM t
+-- !query 50 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(11,0)) + CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(11,0))):decimal(11,0)>
+-- !query 50 output
+2
+
+
+-- !query 51
+SELECT cast(1 as decimal(10, 0)) + cast(1 as int) FROM t
+-- !query 51 schema
+struct<(CAST(1 AS DECIMAL(10,0)) + CAST(CAST(1 AS INT) AS DECIMAL(10,0))):decimal(11,0)>
+-- !query 51 output
+2
+
+
+-- !query 52
+SELECT cast(1 as decimal(20, 0)) + cast(1 as int) FROM t
+-- !query 52 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0)) + CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(21,0))):decimal(21,0)>
+-- !query 52 output
+2
+
+
+-- !query 53
+SELECT cast(1 as decimal(3, 0)) + cast(1 as bigint) FROM t
+-- !query 53 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(21,0)) + CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(21,0))):decimal(21,0)>
+-- !query 53 output
+2
+
+
+-- !query 54
+SELECT cast(1 as decimal(5, 0)) + cast(1 as bigint) FROM t
+-- !query 54 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(21,0)) + CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(21,0))):decimal(21,0)>
+-- !query 54 output
+2
+
+
+-- !query 55
+SELECT cast(1 as decimal(10, 0)) + cast(1 as bigint) FROM t
+-- !query 55 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(21,0)) + CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(21,0))):decimal(21,0)>
+-- !query 55 output
+2
+
+
+-- !query 56
+SELECT cast(1 as decimal(20, 0)) + cast(1 as bigint) FROM t
+-- !query 56 schema
+struct<(CAST(1 AS DECIMAL(20,0)) + CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):decimal(21,0)>
+-- !query 56 output
+2
+
+
+-- !query 57
+SELECT cast(1 as decimal(3, 0)) + cast(1 as float) FROM t
+-- !query 57 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) + CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 57 output
+2.0
+
+
+-- !query 58
+SELECT cast(1 as decimal(5, 0)) + cast(1 as float) FROM t
+-- !query 58 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) + CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 58 output
+2.0
+
+
+-- !query 59
+SELECT cast(1 as decimal(10, 0)) + cast(1 as float) FROM t
+-- !query 59 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) + CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 59 output
+2.0
+
+
+-- !query 60
+SELECT cast(1 as decimal(20, 0)) + cast(1 as float) FROM t
+-- !query 60 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) + CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 60 output
+2.0
+
+
+-- !query 61
+SELECT cast(1 as decimal(3, 0)) + cast(1 as double) FROM t
+-- !query 61 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) + CAST(1 AS DOUBLE)):double>
+-- !query 61 output
+2.0
+
+
+-- !query 62
+SELECT cast(1 as decimal(5, 0)) + cast(1 as double) FROM t
+-- !query 62 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) + CAST(1 AS DOUBLE)):double>
+-- !query 62 output
+2.0
+
+
+-- !query 63
+SELECT cast(1 as decimal(10, 0)) + cast(1 as double) FROM t
+-- !query 63 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) + CAST(1 AS DOUBLE)):double>
+-- !query 63 output
+2.0
+
+
+-- !query 64
+SELECT cast(1 as decimal(20, 0)) + cast(1 as double) FROM t
+-- !query 64 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) + CAST(1 AS DOUBLE)):double>
+-- !query 64 output
+2.0
+
+
+-- !query 65
+SELECT cast(1 as decimal(3, 0)) + cast(1 as decimal(10, 0)) FROM t
+-- !query 65 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(11,0)) + CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0))):decimal(11,0)>
+-- !query 65 output
+2
+
+
+-- !query 66
+SELECT cast(1 as decimal(5, 0)) + cast(1 as decimal(10, 0)) FROM t
+-- !query 66 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(11,0)) + CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0))):decimal(11,0)>
+-- !query 66 output
+2
+
+
+-- !query 67
+SELECT cast(1 as decimal(10, 0)) + cast(1 as decimal(10, 0)) FROM t
+-- !query 67 schema
+struct<(CAST(1 AS DECIMAL(10,0)) + CAST(1 AS DECIMAL(10,0))):decimal(11,0)>
+-- !query 67 output
+2
+
+
+-- !query 68
+SELECT cast(1 as decimal(20, 0)) + cast(1 as decimal(10, 0)) FROM t
+-- !query 68 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0)) + CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(21,0))):decimal(21,0)>
+-- !query 68 output
+2
+
+
+-- !query 69
+SELECT cast(1 as decimal(3, 0)) + cast(1 as string) FROM t
+-- !query 69 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) + CAST(CAST(1 AS STRING) AS DOUBLE)):double>
+-- !query 69 output
+2.0
+
+
+-- !query 70
+SELECT cast(1 as decimal(5, 0)) + cast(1 as string) FROM t
+-- !query 70 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) + CAST(CAST(1 AS STRING) AS DOUBLE)):double>
+-- !query 70 output
+2.0
+
+
+-- !query 71
+SELECT cast(1 as decimal(10, 0)) + cast(1 as string) FROM t
+-- !query 71 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) + CAST(CAST(1 AS STRING) AS DOUBLE)):double>
+-- !query 71 output
+2.0
+
+
+-- !query 72
+SELECT cast(1 as decimal(20, 0)) + cast(1 as string) FROM t
+-- !query 72 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) + CAST(CAST(1 AS STRING) AS DOUBLE)):double>
+-- !query 72 output
+2.0
+
+
+-- !query 73
+SELECT cast(1 as decimal(3, 0)) + cast('1' as binary) FROM t
+-- !query 73 schema
+struct<>
+-- !query 73 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) + CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) + CAST('1' AS BINARY))' (decimal(3,0) and binary).; line 1 pos 7
+
+
+-- !query 74
+SELECT cast(1 as decimal(5, 0)) + cast('1' as binary) FROM t
+-- !query 74 schema
+struct<>
+-- !query 74 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) + CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) + CAST('1' AS BINARY))' (decimal(5,0) and binary).; line 1 pos 7
+
+
+-- !query 75
+SELECT cast(1 as decimal(10, 0)) + cast('1' as binary) FROM t
+-- !query 75 schema
+struct<>
+-- !query 75 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) + CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) + CAST('1' AS BINARY))' (decimal(10,0) and binary).; line 1 pos 7
+
+
+-- !query 76
+SELECT cast(1 as decimal(20, 0)) + cast('1' as binary) FROM t
+-- !query 76 schema
+struct<>
+-- !query 76 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) + CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) + CAST('1' AS BINARY))' (decimal(20,0) and binary).; line 1 pos 7
+
+
+-- !query 77
+SELECT cast(1 as decimal(3, 0)) + cast(1 as boolean) FROM t
+-- !query 77 schema
+struct<>
+-- !query 77 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) + CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) + CAST(1 AS BOOLEAN))' (decimal(3,0) and boolean).; line 1 pos 7
+
+
+-- !query 78
+SELECT cast(1 as decimal(5, 0)) + cast(1 as boolean) FROM t
+-- !query 78 schema
+struct<>
+-- !query 78 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) + CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) + CAST(1 AS BOOLEAN))' (decimal(5,0) and boolean).; line 1 pos 7
+
+
+-- !query 79
+SELECT cast(1 as decimal(10, 0)) + cast(1 as boolean) FROM t
+-- !query 79 schema
+struct<>
+-- !query 79 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) + CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) + CAST(1 AS BOOLEAN))' (decimal(10,0) and boolean).; line 1 pos 7
+
+
+-- !query 80
+SELECT cast(1 as decimal(20, 0)) + cast(1 as boolean) FROM t
+-- !query 80 schema
+struct<>
+-- !query 80 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) + CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) + CAST(1 AS BOOLEAN))' (decimal(20,0) and boolean).; line 1 pos 7
+
+
+-- !query 81
+SELECT cast(1 as decimal(3, 0)) + cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 81 schema
+struct<>
+-- !query 81 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) + CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) + CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(3,0) and timestamp).; line 1 pos 7
+
+
+-- !query 82
+SELECT cast(1 as decimal(5, 0)) + cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 82 schema
+struct<>
+-- !query 82 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) + CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) + CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(5,0) and timestamp).; line 1 pos 7
+
+
+-- !query 83
+SELECT cast(1 as decimal(10, 0)) + cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 83 schema
+struct<>
+-- !query 83 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) + CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) + CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(10,0) and timestamp).; line 1 pos 7
+
+
+-- !query 84
+SELECT cast(1 as decimal(20, 0)) + cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 84 schema
+struct<>
+-- !query 84 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) + CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) + CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(20,0) and timestamp).; line 1 pos 7
+
+
+-- !query 85
+SELECT cast(1 as decimal(3, 0)) + cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 85 schema
+struct<>
+-- !query 85 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) + CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) + CAST('2017-12-11 09:30:00' AS DATE))' (decimal(3,0) and date).; line 1 pos 7
+
+
+-- !query 86
+SELECT cast(1 as decimal(5, 0)) + cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 86 schema
+struct<>
+-- !query 86 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) + CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) + CAST('2017-12-11 09:30:00' AS DATE))' (decimal(5,0) and date).; line 1 pos 7
+
+
+-- !query 87
+SELECT cast(1 as decimal(10, 0)) + cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 87 schema
+struct<>
+-- !query 87 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) + CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) + CAST('2017-12-11 09:30:00' AS DATE))' (decimal(10,0) and date).; line 1 pos 7
+
+
+-- !query 88
+SELECT cast(1 as decimal(20, 0)) + cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 88 schema
+struct<>
+-- !query 88 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) + CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) + CAST('2017-12-11 09:30:00' AS DATE))' (decimal(20,0) and date).; line 1 pos 7
+
+
+-- !query 89
+SELECT cast(1 as tinyint) - cast(1 as decimal(3, 0)) FROM t
+-- !query 89 schema
+struct<(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) - CAST(1 AS DECIMAL(3,0))):decimal(4,0)>
+-- !query 89 output
+0
+
+
+-- !query 90
+SELECT cast(1 as tinyint) - cast(1 as decimal(5, 0)) FROM t
+-- !query 90 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(6,0)) - CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(6,0))):decimal(6,0)>
+-- !query 90 output
+0
+
+
+-- !query 91
+SELECT cast(1 as tinyint) - cast(1 as decimal(10, 0)) FROM t
+-- !query 91 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(11,0)) - CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0))):decimal(11,0)>
+-- !query 91 output
+0
+
+
+-- !query 92
+SELECT cast(1 as tinyint) - cast(1 as decimal(20, 0)) FROM t
+-- !query 92 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(21,0)) - CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0))):decimal(21,0)>
+-- !query 92 output
+0
+
+
+-- !query 93
+SELECT cast(1 as smallint) - cast(1 as decimal(3, 0)) FROM t
+-- !query 93 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(6,0)) - CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(6,0))):decimal(6,0)>
+-- !query 93 output
+0
+
+
+-- !query 94
+SELECT cast(1 as smallint) - cast(1 as decimal(5, 0)) FROM t
+-- !query 94 schema
+struct<(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) - CAST(1 AS DECIMAL(5,0))):decimal(6,0)>
+-- !query 94 output
+0
+
+
+-- !query 95
+SELECT cast(1 as smallint) - cast(1 as decimal(10, 0)) FROM t
+-- !query 95 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(11,0)) - CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0))):decimal(11,0)>
+-- !query 95 output
+0
+
+
+-- !query 96
+SELECT cast(1 as smallint) - cast(1 as decimal(20, 0)) FROM t
+-- !query 96 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(21,0)) - CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0))):decimal(21,0)>
+-- !query 96 output
+0
+
+
+-- !query 97
+SELECT cast(1 as int) - cast(1 as decimal(3, 0)) FROM t
+-- !query 97 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(11,0)) - CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(11,0))):decimal(11,0)>
+-- !query 97 output
+0
+
+
+-- !query 98
+SELECT cast(1 as int) - cast(1 as decimal(5, 0)) FROM t
+-- !query 98 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(11,0)) - CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(11,0))):decimal(11,0)>
+-- !query 98 output
+0
+
+
+-- !query 99
+SELECT cast(1 as int) - cast(1 as decimal(10, 0)) FROM t
+-- !query 99 schema
+struct<(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) - CAST(1 AS DECIMAL(10,0))):decimal(11,0)>
+-- !query 99 output
+0
+
+
+-- !query 100
+SELECT cast(1 as int) - cast(1 as decimal(20, 0)) FROM t
+-- !query 100 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(21,0)) - CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0))):decimal(21,0)>
+-- !query 100 output
+0
+
+
+-- !query 101
+SELECT cast(1 as bigint) - cast(1 as decimal(3, 0)) FROM t
+-- !query 101 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(21,0)) - CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(21,0))):decimal(21,0)>
+-- !query 101 output
+0
+
+
+-- !query 102
+SELECT cast(1 as bigint) - cast(1 as decimal(5, 0)) FROM t
+-- !query 102 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(21,0)) - CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(21,0))):decimal(21,0)>
+-- !query 102 output
+0
+
+
+-- !query 103
+SELECT cast(1 as bigint) - cast(1 as decimal(10, 0)) FROM t
+-- !query 103 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(21,0)) - CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(21,0))):decimal(21,0)>
+-- !query 103 output
+0
+
+
+-- !query 104
+SELECT cast(1 as bigint) - cast(1 as decimal(20, 0)) FROM t
+-- !query 104 schema
+struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) - CAST(1 AS DECIMAL(20,0))):decimal(21,0)>
+-- !query 104 output
+0
+
+
+-- !query 105
+SELECT cast(1 as float) - cast(1 as decimal(3, 0)) FROM t
+-- !query 105 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) - CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):double>
+-- !query 105 output
+0.0
+
+
+-- !query 106
+SELECT cast(1 as float) - cast(1 as decimal(5, 0)) FROM t
+-- !query 106 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) - CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):double>
+-- !query 106 output
+0.0
+
+
+-- !query 107
+SELECT cast(1 as float) - cast(1 as decimal(10, 0)) FROM t
+-- !query 107 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) - CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):double>
+-- !query 107 output
+0.0
+
+
+-- !query 108
+SELECT cast(1 as float) - cast(1 as decimal(20, 0)) FROM t
+-- !query 108 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) - CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):double>
+-- !query 108 output
+0.0
+
+
+-- !query 109
+SELECT cast(1 as double) - cast(1 as decimal(3, 0)) FROM t
+-- !query 109 schema
+struct<(CAST(1 AS DOUBLE) - CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):double>
+-- !query 109 output
+0.0
+
+
+-- !query 110
+SELECT cast(1 as double) - cast(1 as decimal(5, 0)) FROM t
+-- !query 110 schema
+struct<(CAST(1 AS DOUBLE) - CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):double>
+-- !query 110 output
+0.0
+
+
+-- !query 111
+SELECT cast(1 as double) - cast(1 as decimal(10, 0)) FROM t
+-- !query 111 schema
+struct<(CAST(1 AS DOUBLE) - CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):double>
+-- !query 111 output
+0.0
+
+
+-- !query 112
+SELECT cast(1 as double) - cast(1 as decimal(20, 0)) FROM t
+-- !query 112 schema
+struct<(CAST(1 AS DOUBLE) - CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):double>
+-- !query 112 output
+0.0
+
+
+-- !query 113
+SELECT cast(1 as decimal(10, 0)) - cast(1 as decimal(3, 0)) FROM t
+-- !query 113 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0)) - CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(11,0))):decimal(11,0)>
+-- !query 113 output
+0
+
+
+-- !query 114
+SELECT cast(1 as decimal(10, 0)) - cast(1 as decimal(5, 0)) FROM t
+-- !query 114 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0)) - CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(11,0))):decimal(11,0)>
+-- !query 114 output
+0
+
+
+-- !query 115
+SELECT cast(1 as decimal(10, 0)) - cast(1 as decimal(10, 0)) FROM t
+-- !query 115 schema
+struct<(CAST(1 AS DECIMAL(10,0)) - CAST(1 AS DECIMAL(10,0))):decimal(11,0)>
+-- !query 115 output
+0
+
+
+-- !query 116
+SELECT cast(1 as decimal(10, 0)) - cast(1 as decimal(20, 0)) FROM t
+-- !query 116 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(21,0)) - CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0))):decimal(21,0)>
+-- !query 116 output
+0
+
+
+-- !query 117
+SELECT cast('1' as binary) - cast(1 as decimal(3, 0)) FROM t
+-- !query 117 schema
+struct<>
+-- !query 117 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) - CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) - CAST(1 AS DECIMAL(3,0)))' (binary and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 118
+SELECT cast('1' as binary) - cast(1 as decimal(5, 0)) FROM t
+-- !query 118 schema
+struct<>
+-- !query 118 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) - CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) - CAST(1 AS DECIMAL(5,0)))' (binary and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 119
+SELECT cast('1' as binary) - cast(1 as decimal(10, 0)) FROM t
+-- !query 119 schema
+struct<>
+-- !query 119 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) - CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) - CAST(1 AS DECIMAL(10,0)))' (binary and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 120
+SELECT cast('1' as binary) - cast(1 as decimal(20, 0)) FROM t
+-- !query 120 schema
+struct<>
+-- !query 120 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) - CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) - CAST(1 AS DECIMAL(20,0)))' (binary and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 121
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) - cast(1 as decimal(3, 0)) FROM t
+-- !query 121 schema
+struct<>
+-- !query 121 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) - CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) - CAST(1 AS DECIMAL(3,0)))' (timestamp and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 122
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) - cast(1 as decimal(5, 0)) FROM t
+-- !query 122 schema
+struct<>
+-- !query 122 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) - CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) - CAST(1 AS DECIMAL(5,0)))' (timestamp and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 123
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) - cast(1 as decimal(10, 0)) FROM t
+-- !query 123 schema
+struct<>
+-- !query 123 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) - CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) - CAST(1 AS DECIMAL(10,0)))' (timestamp and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 124
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) - cast(1 as decimal(20, 0)) FROM t
+-- !query 124 schema
+struct<>
+-- !query 124 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) - CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) - CAST(1 AS DECIMAL(20,0)))' (timestamp and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 125
+SELECT cast('2017-12-11 09:30:00' as date) - cast(1 as decimal(3, 0)) FROM t
+-- !query 125 schema
+struct<>
+-- !query 125 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) - CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) - CAST(1 AS DECIMAL(3,0)))' (date and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 126
+SELECT cast('2017-12-11 09:30:00' as date) - cast(1 as decimal(5, 0)) FROM t
+-- !query 126 schema
+struct<>
+-- !query 126 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) - CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) - CAST(1 AS DECIMAL(5,0)))' (date and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 127
+SELECT cast('2017-12-11 09:30:00' as date) - cast(1 as decimal(10, 0)) FROM t
+-- !query 127 schema
+struct<>
+-- !query 127 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) - CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) - CAST(1 AS DECIMAL(10,0)))' (date and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 128
+SELECT cast('2017-12-11 09:30:00' as date) - cast(1 as decimal(20, 0)) FROM t
+-- !query 128 schema
+struct<>
+-- !query 128 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) - CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) - CAST(1 AS DECIMAL(20,0)))' (date and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 129
+SELECT cast(1 as decimal(3, 0)) - cast(1 as tinyint) FROM t
+-- !query 129 schema
+struct<(CAST(1 AS DECIMAL(3,0)) - CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0))):decimal(4,0)>
+-- !query 129 output
+0
+
+
+-- !query 130
+SELECT cast(1 as decimal(5, 0)) - cast(1 as tinyint) FROM t
+-- !query 130 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(6,0)) - CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(6,0))):decimal(6,0)>
+-- !query 130 output
+0
+
+
+-- !query 131
+SELECT cast(1 as decimal(10, 0)) - cast(1 as tinyint) FROM t
+-- !query 131 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0)) - CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(11,0))):decimal(11,0)>
+-- !query 131 output
+0
+
+
+-- !query 132
+SELECT cast(1 as decimal(20, 0)) - cast(1 as tinyint) FROM t
+-- !query 132 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0)) - CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(21,0))):decimal(21,0)>
+-- !query 132 output
+0
+
+
+-- !query 133
+SELECT cast(1 as decimal(3, 0)) - cast(1 as smallint) FROM t
+-- !query 133 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(6,0)) - CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(6,0))):decimal(6,0)>
+-- !query 133 output
+0
+
+
+-- !query 134
+SELECT cast(1 as decimal(5, 0)) - cast(1 as smallint) FROM t
+-- !query 134 schema
+struct<(CAST(1 AS DECIMAL(5,0)) - CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0))):decimal(6,0)>
+-- !query 134 output
+0
+
+
+-- !query 135
+SELECT cast(1 as decimal(10, 0)) - cast(1 as smallint) FROM t
+-- !query 135 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0)) - CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(11,0))):decimal(11,0)>
+-- !query 135 output
+0
+
+
+-- !query 136
+SELECT cast(1 as decimal(20, 0)) - cast(1 as smallint) FROM t
+-- !query 136 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0)) - CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(21,0))):decimal(21,0)>
+-- !query 136 output
+0
+
+
+-- !query 137
+SELECT cast(1 as decimal(3, 0)) - cast(1 as int) FROM t
+-- !query 137 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(11,0)) - CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(11,0))):decimal(11,0)>
+-- !query 137 output
+0
+
+
+-- !query 138
+SELECT cast(1 as decimal(5, 0)) - cast(1 as int) FROM t
+-- !query 138 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(11,0)) - CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(11,0))):decimal(11,0)>
+-- !query 138 output
+0
+
+
+-- !query 139
+SELECT cast(1 as decimal(10, 0)) - cast(1 as int) FROM t
+-- !query 139 schema
+struct<(CAST(1 AS DECIMAL(10,0)) - CAST(CAST(1 AS INT) AS DECIMAL(10,0))):decimal(11,0)>
+-- !query 139 output
+0
+
+
+-- !query 140
+SELECT cast(1 as decimal(20, 0)) - cast(1 as int) FROM t
+-- !query 140 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0)) - CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(21,0))):decimal(21,0)>
+-- !query 140 output
+0
+
+
+-- !query 141
+SELECT cast(1 as decimal(3, 0)) - cast(1 as bigint) FROM t
+-- !query 141 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(21,0)) - CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(21,0))):decimal(21,0)>
+-- !query 141 output
+0
+
+
+-- !query 142
+SELECT cast(1 as decimal(5, 0)) - cast(1 as bigint) FROM t
+-- !query 142 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(21,0)) - CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(21,0))):decimal(21,0)>
+-- !query 142 output
+0
+
+
+-- !query 143
+SELECT cast(1 as decimal(10, 0)) - cast(1 as bigint) FROM t
+-- !query 143 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(21,0)) - CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(21,0))):decimal(21,0)>
+-- !query 143 output
+0
+
+
+-- !query 144
+SELECT cast(1 as decimal(20, 0)) - cast(1 as bigint) FROM t
+-- !query 144 schema
+struct<(CAST(1 AS DECIMAL(20,0)) - CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):decimal(21,0)>
+-- !query 144 output
+0
+
+
+-- !query 145
+SELECT cast(1 as decimal(3, 0)) - cast(1 as float) FROM t
+-- !query 145 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) - CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 145 output
+0.0
+
+
+-- !query 146
+SELECT cast(1 as decimal(5, 0)) - cast(1 as float) FROM t
+-- !query 146 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) - CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 146 output
+0.0
+
+
+-- !query 147
+SELECT cast(1 as decimal(10, 0)) - cast(1 as float) FROM t
+-- !query 147 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) - CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 147 output
+0.0
+
+
+-- !query 148
+SELECT cast(1 as decimal(20, 0)) - cast(1 as float) FROM t
+-- !query 148 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) - CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 148 output
+0.0
+
+
+-- !query 149
+SELECT cast(1 as decimal(3, 0)) - cast(1 as double) FROM t
+-- !query 149 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) - CAST(1 AS DOUBLE)):double>
+-- !query 149 output
+0.0
+
+
+-- !query 150
+SELECT cast(1 as decimal(5, 0)) - cast(1 as double) FROM t
+-- !query 150 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) - CAST(1 AS DOUBLE)):double>
+-- !query 150 output
+0.0
+
+
+-- !query 151
+SELECT cast(1 as decimal(10, 0)) - cast(1 as double) FROM t
+-- !query 151 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) - CAST(1 AS DOUBLE)):double>
+-- !query 151 output
+0.0
+
+
+-- !query 152
+SELECT cast(1 as decimal(20, 0)) - cast(1 as double) FROM t
+-- !query 152 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) - CAST(1 AS DOUBLE)):double>
+-- !query 152 output
+0.0
+
+
+-- !query 153
+SELECT cast(1 as decimal(3, 0)) - cast(1 as decimal(10, 0)) FROM t
+-- !query 153 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(11,0)) - CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0))):decimal(11,0)>
+-- !query 153 output
+0
+
+
+-- !query 154
+SELECT cast(1 as decimal(5, 0)) - cast(1 as decimal(10, 0)) FROM t
+-- !query 154 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(11,0)) - CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0))):decimal(11,0)>
+-- !query 154 output
+0
+
+
+-- !query 155
+SELECT cast(1 as decimal(10, 0)) - cast(1 as decimal(10, 0)) FROM t
+-- !query 155 schema
+struct<(CAST(1 AS DECIMAL(10,0)) - CAST(1 AS DECIMAL(10,0))):decimal(11,0)>
+-- !query 155 output
+0
+
+
+-- !query 156
+SELECT cast(1 as decimal(20, 0)) - cast(1 as decimal(10, 0)) FROM t
+-- !query 156 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0)) - CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(21,0))):decimal(21,0)>
+-- !query 156 output
+0
+
+
+-- !query 157
+SELECT cast(1 as decimal(3, 0)) - cast(1 as string) FROM t
+-- !query 157 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) - CAST(CAST(1 AS STRING) AS DOUBLE)):double>
+-- !query 157 output
+0.0
+
+
+-- !query 158
+SELECT cast(1 as decimal(5, 0)) - cast(1 as string) FROM t
+-- !query 158 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) - CAST(CAST(1 AS STRING) AS DOUBLE)):double>
+-- !query 158 output
+0.0
+
+
+-- !query 159
+SELECT cast(1 as decimal(10, 0)) - cast(1 as string) FROM t
+-- !query 159 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) - CAST(CAST(1 AS STRING) AS DOUBLE)):double>
+-- !query 159 output
+0.0
+
+
+-- !query 160
+SELECT cast(1 as decimal(20, 0)) - cast(1 as string) FROM t
+-- !query 160 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) - CAST(CAST(1 AS STRING) AS DOUBLE)):double>
+-- !query 160 output
+0.0
+
+
+-- !query 161
+SELECT cast(1 as decimal(3, 0)) - cast('1' as binary) FROM t
+-- !query 161 schema
+struct<>
+-- !query 161 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) - CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) - CAST('1' AS BINARY))' (decimal(3,0) and binary).; line 1 pos 7
+
+
+-- !query 162
+SELECT cast(1 as decimal(5, 0)) - cast('1' as binary) FROM t
+-- !query 162 schema
+struct<>
+-- !query 162 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) - CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) - CAST('1' AS BINARY))' (decimal(5,0) and binary).; line 1 pos 7
+
+
+-- !query 163
+SELECT cast(1 as decimal(10, 0)) - cast('1' as binary) FROM t
+-- !query 163 schema
+struct<>
+-- !query 163 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) - CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) - CAST('1' AS BINARY))' (decimal(10,0) and binary).; line 1 pos 7
+
+
+-- !query 164
+SELECT cast(1 as decimal(20, 0)) - cast('1' as binary) FROM t
+-- !query 164 schema
+struct<>
+-- !query 164 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) - CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) - CAST('1' AS BINARY))' (decimal(20,0) and binary).; line 1 pos 7
+
+
+-- !query 165
+SELECT cast(1 as decimal(3, 0)) - cast(1 as boolean) FROM t
+-- !query 165 schema
+struct<>
+-- !query 165 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) - CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) - CAST(1 AS BOOLEAN))' (decimal(3,0) and boolean).; line 1 pos 7
+
+
+-- !query 166
+SELECT cast(1 as decimal(5, 0)) - cast(1 as boolean) FROM t
+-- !query 166 schema
+struct<>
+-- !query 166 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) - CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) - CAST(1 AS BOOLEAN))' (decimal(5,0) and boolean).; line 1 pos 7
+
+
+-- !query 167
+SELECT cast(1 as decimal(10, 0)) - cast(1 as boolean) FROM t
+-- !query 167 schema
+struct<>
+-- !query 167 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) - CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) - CAST(1 AS BOOLEAN))' (decimal(10,0) and boolean).; line 1 pos 7
+
+
+-- !query 168
+SELECT cast(1 as decimal(20, 0)) - cast(1 as boolean) FROM t
+-- !query 168 schema
+struct<>
+-- !query 168 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) - CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) - CAST(1 AS BOOLEAN))' (decimal(20,0) and boolean).; line 1 pos 7
+
+
+-- !query 169
+SELECT cast(1 as decimal(3, 0)) - cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 169 schema
+struct<>
+-- !query 169 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) - CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) - CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(3,0) and timestamp).; line 1 pos 7
+
+
+-- !query 170
+SELECT cast(1 as decimal(5, 0)) - cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 170 schema
+struct<>
+-- !query 170 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) - CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) - CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(5,0) and timestamp).; line 1 pos 7
+
+
+-- !query 171
+SELECT cast(1 as decimal(10, 0)) - cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 171 schema
+struct<>
+-- !query 171 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) - CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) - CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(10,0) and timestamp).; line 1 pos 7
+
+
+-- !query 172
+SELECT cast(1 as decimal(20, 0)) - cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 172 schema
+struct<>
+-- !query 172 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) - CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) - CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(20,0) and timestamp).; line 1 pos 7
+
+
+-- !query 173
+SELECT cast(1 as decimal(3, 0)) - cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 173 schema
+struct<>
+-- !query 173 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) - CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) - CAST('2017-12-11 09:30:00' AS DATE))' (decimal(3,0) and date).; line 1 pos 7
+
+
+-- !query 174
+SELECT cast(1 as decimal(5, 0)) - cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 174 schema
+struct<>
+-- !query 174 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) - CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) - CAST('2017-12-11 09:30:00' AS DATE))' (decimal(5,0) and date).; line 1 pos 7
+
+
+-- !query 175
+SELECT cast(1 as decimal(10, 0)) - cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 175 schema
+struct<>
+-- !query 175 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) - CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) - CAST('2017-12-11 09:30:00' AS DATE))' (decimal(10,0) and date).; line 1 pos 7
+
+
+-- !query 176
+SELECT cast(1 as decimal(20, 0)) - cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 176 schema
+struct<>
+-- !query 176 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) - CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) - CAST('2017-12-11 09:30:00' AS DATE))' (decimal(20,0) and date).; line 1 pos 7
+
+
+-- !query 177
+SELECT cast(1 as tinyint) * cast(1 as decimal(3, 0)) FROM t
+-- !query 177 schema
+struct<(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) * CAST(1 AS DECIMAL(3,0))):decimal(7,0)>
+-- !query 177 output
+1
+
+
+-- !query 178
+SELECT cast(1 as tinyint) * cast(1 as decimal(5, 0)) FROM t
+-- !query 178 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0)) * CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0))):decimal(9,0)>
+-- !query 178 output
+1
+
+
+-- !query 179
+SELECT cast(1 as tinyint) * cast(1 as decimal(10, 0)) FROM t
+-- !query 179 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0)) * CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(14,0)>
+-- !query 179 output
+1
+
+
+-- !query 180
+SELECT cast(1 as tinyint) * cast(1 as decimal(20, 0)) FROM t
+-- !query 180 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0)) * CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(24,0)>
+-- !query 180 output
+1
+
+
+-- !query 181
+SELECT cast(1 as smallint) * cast(1 as decimal(3, 0)) FROM t
+-- !query 181 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0)) * CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0))):decimal(9,0)>
+-- !query 181 output
+1
+
+
+-- !query 182
+SELECT cast(1 as smallint) * cast(1 as decimal(5, 0)) FROM t
+-- !query 182 schema
+struct<(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) * CAST(1 AS DECIMAL(5,0))):decimal(11,0)>
+-- !query 182 output
+1
+
+
+-- !query 183
+SELECT cast(1 as smallint) * cast(1 as decimal(10, 0)) FROM t
+-- !query 183 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0)) * CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(16,0)>
+-- !query 183 output
+1
+
+
+-- !query 184
+SELECT cast(1 as smallint) * cast(1 as decimal(20, 0)) FROM t
+-- !query 184 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0)) * CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(26,0)>
+-- !query 184 output
+1
+
+
+-- !query 185
+SELECT cast(1 as int) * cast(1 as decimal(3, 0)) FROM t
+-- !query 185 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) * CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):decimal(14,0)>
+-- !query 185 output
+1
+
+
+-- !query 186
+SELECT cast(1 as int) * cast(1 as decimal(5, 0)) FROM t
+-- !query 186 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) * CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):decimal(16,0)>
+-- !query 186 output
+1
+
+
+-- !query 187
+SELECT cast(1 as int) * cast(1 as decimal(10, 0)) FROM t
+-- !query 187 schema
+struct<(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) * CAST(1 AS DECIMAL(10,0))):decimal(21,0)>
+-- !query 187 output
+1
+
+
+-- !query 188
+SELECT cast(1 as int) * cast(1 as decimal(20, 0)) FROM t
+-- !query 188 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0)) * CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(31,0)>
+-- !query 188 output
+1
+
+
+-- !query 189
+SELECT cast(1 as bigint) * cast(1 as decimal(3, 0)) FROM t
+-- !query 189 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) * CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0))):decimal(24,0)>
+-- !query 189 output
+1
+
+
+-- !query 190
+SELECT cast(1 as bigint) * cast(1 as decimal(5, 0)) FROM t
+-- !query 190 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) * CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0))):decimal(26,0)>
+-- !query 190 output
+1
+
+
+-- !query 191
+SELECT cast(1 as bigint) * cast(1 as decimal(10, 0)) FROM t
+-- !query 191 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) * CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):decimal(31,0)>
+-- !query 191 output
+1
+
+
+-- !query 192
+SELECT cast(1 as bigint) * cast(1 as decimal(20, 0)) FROM t
+-- !query 192 schema
+struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) * CAST(1 AS DECIMAL(20,0))):decimal(38,0)>
+-- !query 192 output
+1
+
+
+-- !query 193
+SELECT cast(1 as float) * cast(1 as decimal(3, 0)) FROM t
+-- !query 193 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) * CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):double>
+-- !query 193 output
+1.0
+
+
+-- !query 194
+SELECT cast(1 as float) * cast(1 as decimal(5, 0)) FROM t
+-- !query 194 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) * CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):double>
+-- !query 194 output
+1.0
+
+
+-- !query 195
+SELECT cast(1 as float) * cast(1 as decimal(10, 0)) FROM t
+-- !query 195 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) * CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):double>
+-- !query 195 output
+1.0
+
+
+-- !query 196
+SELECT cast(1 as float) * cast(1 as decimal(20, 0)) FROM t
+-- !query 196 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) * CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):double>
+-- !query 196 output
+1.0
+
+
+-- !query 197
+SELECT cast(1 as double) * cast(1 as decimal(3, 0)) FROM t
+-- !query 197 schema
+struct<(CAST(1 AS DOUBLE) * CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):double>
+-- !query 197 output
+1.0
+
+
+-- !query 198
+SELECT cast(1 as double) * cast(1 as decimal(5, 0)) FROM t
+-- !query 198 schema
+struct<(CAST(1 AS DOUBLE) * CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):double>
+-- !query 198 output
+1.0
+
+
+-- !query 199
+SELECT cast(1 as double) * cast(1 as decimal(10, 0)) FROM t
+-- !query 199 schema
+struct<(CAST(1 AS DOUBLE) * CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):double>
+-- !query 199 output
+1.0
+
+
+-- !query 200
+SELECT cast(1 as double) * cast(1 as decimal(20, 0)) FROM t
+-- !query 200 schema
+struct<(CAST(1 AS DOUBLE) * CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):double>
+-- !query 200 output
+1.0
+
+
+-- !query 201
+SELECT cast(1 as decimal(10, 0)) * cast(1 as decimal(3, 0)) FROM t
+-- !query 201 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) * CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):decimal(14,0)>
+-- !query 201 output
+1
+
+
+-- !query 202
+SELECT cast(1 as decimal(10, 0)) * cast(1 as decimal(5, 0)) FROM t
+-- !query 202 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) * CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):decimal(16,0)>
+-- !query 202 output
+1
+
+
+-- !query 203
+SELECT cast(1 as decimal(10, 0)) * cast(1 as decimal(10, 0)) FROM t
+-- !query 203 schema
+struct<(CAST(1 AS DECIMAL(10,0)) * CAST(1 AS DECIMAL(10,0))):decimal(21,0)>
+-- !query 203 output
+1
+
+
+-- !query 204
+SELECT cast(1 as decimal(10, 0)) * cast(1 as decimal(20, 0)) FROM t
+-- !query 204 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) * CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(31,0)>
+-- !query 204 output
+1
+
+
+-- !query 205
+SELECT cast('1' as binary) * cast(1 as decimal(3, 0)) FROM t
+-- !query 205 schema
+struct<>
+-- !query 205 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) * CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) * CAST(1 AS DECIMAL(3,0)))' (binary and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 206
+SELECT cast('1' as binary) * cast(1 as decimal(5, 0)) FROM t
+-- !query 206 schema
+struct<>
+-- !query 206 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) * CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) * CAST(1 AS DECIMAL(5,0)))' (binary and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 207
+SELECT cast('1' as binary) * cast(1 as decimal(10, 0)) FROM t
+-- !query 207 schema
+struct<>
+-- !query 207 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) * CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) * CAST(1 AS DECIMAL(10,0)))' (binary and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 208
+SELECT cast('1' as binary) * cast(1 as decimal(20, 0)) FROM t
+-- !query 208 schema
+struct<>
+-- !query 208 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) * CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) * CAST(1 AS DECIMAL(20,0)))' (binary and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 209
+SELECT cast('2017*12*11 09:30:00.0' as timestamp) * cast(1 as decimal(3, 0)) FROM t
+-- !query 209 schema
+struct<>
+-- !query 209 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017*12*11 09:30:00.0' AS TIMESTAMP) * CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017*12*11 09:30:00.0' AS TIMESTAMP) * CAST(1 AS DECIMAL(3,0)))' (timestamp and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 210
+SELECT cast('2017*12*11 09:30:00.0' as timestamp) * cast(1 as decimal(5, 0)) FROM t
+-- !query 210 schema
+struct<>
+-- !query 210 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017*12*11 09:30:00.0' AS TIMESTAMP) * CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017*12*11 09:30:00.0' AS TIMESTAMP) * CAST(1 AS DECIMAL(5,0)))' (timestamp and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 211
+SELECT cast('2017*12*11 09:30:00.0' as timestamp) * cast(1 as decimal(10, 0)) FROM t
+-- !query 211 schema
+struct<>
+-- !query 211 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017*12*11 09:30:00.0' AS TIMESTAMP) * CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017*12*11 09:30:00.0' AS TIMESTAMP) * CAST(1 AS DECIMAL(10,0)))' (timestamp and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 212
+SELECT cast('2017*12*11 09:30:00.0' as timestamp) * cast(1 as decimal(20, 0)) FROM t
+-- !query 212 schema
+struct<>
+-- !query 212 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017*12*11 09:30:00.0' AS TIMESTAMP) * CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017*12*11 09:30:00.0' AS TIMESTAMP) * CAST(1 AS DECIMAL(20,0)))' (timestamp and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 213
+SELECT cast('2017*12*11 09:30:00' as date) * cast(1 as decimal(3, 0)) FROM t
+-- !query 213 schema
+struct<>
+-- !query 213 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017*12*11 09:30:00' AS DATE) * CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017*12*11 09:30:00' AS DATE) * CAST(1 AS DECIMAL(3,0)))' (date and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 214
+SELECT cast('2017*12*11 09:30:00' as date) * cast(1 as decimal(5, 0)) FROM t
+-- !query 214 schema
+struct<>
+-- !query 214 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017*12*11 09:30:00' AS DATE) * CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017*12*11 09:30:00' AS DATE) * CAST(1 AS DECIMAL(5,0)))' (date and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 215
+SELECT cast('2017*12*11 09:30:00' as date) * cast(1 as decimal(10, 0)) FROM t
+-- !query 215 schema
+struct<>
+-- !query 215 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017*12*11 09:30:00' AS DATE) * CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017*12*11 09:30:00' AS DATE) * CAST(1 AS DECIMAL(10,0)))' (date and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 216
+SELECT cast('2017*12*11 09:30:00' as date) * cast(1 as decimal(20, 0)) FROM t
+-- !query 216 schema
+struct<>
+-- !query 216 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017*12*11 09:30:00' AS DATE) * CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017*12*11 09:30:00' AS DATE) * CAST(1 AS DECIMAL(20,0)))' (date and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 217
+SELECT cast(1 as decimal(3, 0)) * cast(1 as tinyint) FROM t
+-- !query 217 schema
+struct<(CAST(1 AS DECIMAL(3,0)) * CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0))):decimal(7,0)>
+-- !query 217 output
+1
+
+
+-- !query 218
+SELECT cast(1 as decimal(5, 0)) * cast(1 as tinyint) FROM t
+-- !query 218 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0)) * CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0))):decimal(9,0)>
+-- !query 218 output
+1
+
+
+-- !query 219
+SELECT cast(1 as decimal(10, 0)) * cast(1 as tinyint) FROM t
+-- !query 219 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) * CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0))):decimal(14,0)>
+-- !query 219 output
+1
+
+
+-- !query 220
+SELECT cast(1 as decimal(20, 0)) * cast(1 as tinyint) FROM t
+-- !query 220 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) * CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0))):decimal(24,0)>
+-- !query 220 output
+1
+
+
+-- !query 221
+SELECT cast(1 as decimal(3, 0)) * cast(1 as smallint) FROM t
+-- !query 221 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0)) * CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0))):decimal(9,0)>
+-- !query 221 output
+1
+
+
+-- !query 222
+SELECT cast(1 as decimal(5, 0)) * cast(1 as smallint) FROM t
+-- !query 222 schema
+struct<(CAST(1 AS DECIMAL(5,0)) * CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0))):decimal(11,0)>
+-- !query 222 output
+1
+
+
+-- !query 223
+SELECT cast(1 as decimal(10, 0)) * cast(1 as smallint) FROM t
+-- !query 223 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) * CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0))):decimal(16,0)>
+-- !query 223 output
+1
+
+
+-- !query 224
+SELECT cast(1 as decimal(20, 0)) * cast(1 as smallint) FROM t
+-- !query 224 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) * CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0))):decimal(26,0)>
+-- !query 224 output
+1
+
+
+-- !query 225
+SELECT cast(1 as decimal(3, 0)) * cast(1 as int) FROM t
+-- !query 225 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) * CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(14,0)>
+-- !query 225 output
+1
+
+
+-- !query 226
+SELECT cast(1 as decimal(5, 0)) * cast(1 as int) FROM t
+-- !query 226 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) * CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(16,0)>
+-- !query 226 output
+1
+
+
+-- !query 227
+SELECT cast(1 as decimal(10, 0)) * cast(1 as int) FROM t
+-- !query 227 schema
+struct<(CAST(1 AS DECIMAL(10,0)) * CAST(CAST(1 AS INT) AS DECIMAL(10,0))):decimal(21,0)>
+-- !query 227 output
+1
+
+
+-- !query 228
+SELECT cast(1 as decimal(20, 0)) * cast(1 as int) FROM t
+-- !query 228 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) * CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0))):decimal(31,0)>
+-- !query 228 output
+1
+
+
+-- !query 229
+SELECT cast(1 as decimal(3, 0)) * cast(1 as bigint) FROM t
+-- !query 229 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0)) * CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(24,0)>
+-- !query 229 output
+1
+
+
+-- !query 230
+SELECT cast(1 as decimal(5, 0)) * cast(1 as bigint) FROM t
+-- !query 230 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0)) * CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(26,0)>
+-- !query 230 output
+1
+
+
+-- !query 231
+SELECT cast(1 as decimal(10, 0)) * cast(1 as bigint) FROM t
+-- !query 231 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) * CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(31,0)>
+-- !query 231 output
+1
+
+
+-- !query 232
+SELECT cast(1 as decimal(20, 0)) * cast(1 as bigint) FROM t
+-- !query 232 schema
+struct<(CAST(1 AS DECIMAL(20,0)) * CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):decimal(38,0)>
+-- !query 232 output
+1
+
+
+-- !query 233
+SELECT cast(1 as decimal(3, 0)) * cast(1 as float) FROM t
+-- !query 233 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) * CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 233 output
+1.0
+
+
+-- !query 234
+SELECT cast(1 as decimal(5, 0)) * cast(1 as float) FROM t
+-- !query 234 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) * CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 234 output
+1.0
+
+
+-- !query 235
+SELECT cast(1 as decimal(10, 0)) * cast(1 as float) FROM t
+-- !query 235 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) * CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 235 output
+1.0
+
+
+-- !query 236
+SELECT cast(1 as decimal(20, 0)) * cast(1 as float) FROM t
+-- !query 236 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) * CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 236 output
+1.0
+
+
+-- !query 237
+SELECT cast(1 as decimal(3, 0)) * cast(1 as double) FROM t
+-- !query 237 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) * CAST(1 AS DOUBLE)):double>
+-- !query 237 output
+1.0
+
+
+-- !query 238
+SELECT cast(1 as decimal(5, 0)) * cast(1 as double) FROM t
+-- !query 238 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) * CAST(1 AS DOUBLE)):double>
+-- !query 238 output
+1.0
+
+
+-- !query 239
+SELECT cast(1 as decimal(10, 0)) * cast(1 as double) FROM t
+-- !query 239 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) * CAST(1 AS DOUBLE)):double>
+-- !query 239 output
+1.0
+
+
+-- !query 240
+SELECT cast(1 as decimal(20, 0)) * cast(1 as double) FROM t
+-- !query 240 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) * CAST(1 AS DOUBLE)):double>
+-- !query 240 output
+1.0
+
+
+-- !query 241
+SELECT cast(1 as decimal(3, 0)) * cast(1 as decimal(10, 0)) FROM t
+-- !query 241 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) * CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(14,0)>
+-- !query 241 output
+1
+
+
+-- !query 242
+SELECT cast(1 as decimal(5, 0)) * cast(1 as decimal(10, 0)) FROM t
+-- !query 242 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) * CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(16,0)>
+-- !query 242 output
+1
+
+
+-- !query 243
+SELECT cast(1 as decimal(10, 0)) * cast(1 as decimal(10, 0)) FROM t
+-- !query 243 schema
+struct<(CAST(1 AS DECIMAL(10,0)) * CAST(1 AS DECIMAL(10,0))):decimal(21,0)>
+-- !query 243 output
+1
+
+
+-- !query 244
+SELECT cast(1 as decimal(20, 0)) * cast(1 as decimal(10, 0)) FROM t
+-- !query 244 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) * CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):decimal(31,0)>
+-- !query 244 output
+1
+
+
+-- !query 245
+SELECT cast(1 as decimal(3, 0)) * cast(1 as string) FROM t
+-- !query 245 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) * CAST(CAST(1 AS STRING) AS DOUBLE)):double>
+-- !query 245 output
+1.0
+
+
+-- !query 246
+SELECT cast(1 as decimal(5, 0)) * cast(1 as string) FROM t
+-- !query 246 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) * CAST(CAST(1 AS STRING) AS DOUBLE)):double>
+-- !query 246 output
+1.0
+
+
+-- !query 247
+SELECT cast(1 as decimal(10, 0)) * cast(1 as string) FROM t
+-- !query 247 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) * CAST(CAST(1 AS STRING) AS DOUBLE)):double>
+-- !query 247 output
+1.0
+
+
+-- !query 248
+SELECT cast(1 as decimal(20, 0)) * cast(1 as string) FROM t
+-- !query 248 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) * CAST(CAST(1 AS STRING) AS DOUBLE)):double>
+-- !query 248 output
+1.0
+
+
+-- !query 249
+SELECT cast(1 as decimal(3, 0)) * cast('1' as binary) FROM t
+-- !query 249 schema
+struct<>
+-- !query 249 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) * CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) * CAST('1' AS BINARY))' (decimal(3,0) and binary).; line 1 pos 7
+
+
+-- !query 250
+SELECT cast(1 as decimal(5, 0)) * cast('1' as binary) FROM t
+-- !query 250 schema
+struct<>
+-- !query 250 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) * CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) * CAST('1' AS BINARY))' (decimal(5,0) and binary).; line 1 pos 7
+
+
+-- !query 251
+SELECT cast(1 as decimal(10, 0)) * cast('1' as binary) FROM t
+-- !query 251 schema
+struct<>
+-- !query 251 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) * CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) * CAST('1' AS BINARY))' (decimal(10,0) and binary).; line 1 pos 7
+
+
+-- !query 252
+SELECT cast(1 as decimal(20, 0)) * cast('1' as binary) FROM t
+-- !query 252 schema
+struct<>
+-- !query 252 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) * CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) * CAST('1' AS BINARY))' (decimal(20,0) and binary).; line 1 pos 7
+
+
+-- !query 253
+SELECT cast(1 as decimal(3, 0)) * cast(1 as boolean) FROM t
+-- !query 253 schema
+struct<>
+-- !query 253 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) * CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) * CAST(1 AS BOOLEAN))' (decimal(3,0) and boolean).; line 1 pos 7
+
+
+-- !query 254
+SELECT cast(1 as decimal(5, 0)) * cast(1 as boolean) FROM t
+-- !query 254 schema
+struct<>
+-- !query 254 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) * CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) * CAST(1 AS BOOLEAN))' (decimal(5,0) and boolean).; line 1 pos 7
+
+
+-- !query 255
+SELECT cast(1 as decimal(10, 0)) * cast(1 as boolean) FROM t
+-- !query 255 schema
+struct<>
+-- !query 255 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) * CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) * CAST(1 AS BOOLEAN))' (decimal(10,0) and boolean).; line 1 pos 7
+
+
+-- !query 256
+SELECT cast(1 as decimal(20, 0)) * cast(1 as boolean) FROM t
+-- !query 256 schema
+struct<>
+-- !query 256 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) * CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) * CAST(1 AS BOOLEAN))' (decimal(20,0) and boolean).; line 1 pos 7
+
+
+-- !query 257
+SELECT cast(1 as decimal(3, 0)) * cast('2017*12*11 09:30:00.0' as timestamp) FROM t
+-- !query 257 schema
+struct<>
+-- !query 257 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) * CAST('2017*12*11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) * CAST('2017*12*11 09:30:00.0' AS TIMESTAMP))' (decimal(3,0) and timestamp).; line 1 pos 7
+
+
+-- !query 258
+SELECT cast(1 as decimal(5, 0)) * cast('2017*12*11 09:30:00.0' as timestamp) FROM t
+-- !query 258 schema
+struct<>
+-- !query 258 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) * CAST('2017*12*11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) * CAST('2017*12*11 09:30:00.0' AS TIMESTAMP))' (decimal(5,0) and timestamp).; line 1 pos 7
+
+
+-- !query 259
+SELECT cast(1 as decimal(10, 0)) * cast('2017*12*11 09:30:00.0' as timestamp) FROM t
+-- !query 259 schema
+struct<>
+-- !query 259 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) * CAST('2017*12*11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) * CAST('2017*12*11 09:30:00.0' AS TIMESTAMP))' (decimal(10,0) and timestamp).; line 1 pos 7
+
+
+-- !query 260
+SELECT cast(1 as decimal(20, 0)) * cast('2017*12*11 09:30:00.0' as timestamp) FROM t
+-- !query 260 schema
+struct<>
+-- !query 260 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) * CAST('2017*12*11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) * CAST('2017*12*11 09:30:00.0' AS TIMESTAMP))' (decimal(20,0) and timestamp).; line 1 pos 7
+
+
+-- !query 261
+SELECT cast(1 as decimal(3, 0)) * cast('2017*12*11 09:30:00' as date) FROM t
+-- !query 261 schema
+struct<>
+-- !query 261 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) * CAST('2017*12*11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) * CAST('2017*12*11 09:30:00' AS DATE))' (decimal(3,0) and date).; line 1 pos 7
+
+
+-- !query 262
+SELECT cast(1 as decimal(5, 0)) * cast('2017*12*11 09:30:00' as date) FROM t
+-- !query 262 schema
+struct<>
+-- !query 262 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) * CAST('2017*12*11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) * CAST('2017*12*11 09:30:00' AS DATE))' (decimal(5,0) and date).; line 1 pos 7
+
+
+-- !query 263
+SELECT cast(1 as decimal(10, 0)) * cast('2017*12*11 09:30:00' as date) FROM t
+-- !query 263 schema
+struct<>
+-- !query 263 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) * CAST('2017*12*11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) * CAST('2017*12*11 09:30:00' AS DATE))' (decimal(10,0) and date).; line 1 pos 7
+
+
+-- !query 264
+SELECT cast(1 as decimal(20, 0)) * cast('2017*12*11 09:30:00' as date) FROM t
+-- !query 264 schema
+struct<>
+-- !query 264 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) * CAST('2017*12*11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) * CAST('2017*12*11 09:30:00' AS DATE))' (decimal(20,0) and date).; line 1 pos 7
+
+
+-- !query 265
+SELECT cast(1 as tinyint) / cast(1 as decimal(3, 0)) FROM t
+-- !query 265 schema
+struct<(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) / CAST(1 AS DECIMAL(3,0))):decimal(9,6)>
+-- !query 265 output
+1
+
+
+-- !query 266
+SELECT cast(1 as tinyint) / cast(1 as decimal(5, 0)) FROM t
+-- !query 266 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0)) / CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0))):decimal(9,6)>
+-- !query 266 output
+1
+
+
+-- !query 267
+SELECT cast(1 as tinyint) / cast(1 as decimal(10, 0)) FROM t
+-- !query 267 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0)) / CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(14,11)>
+-- !query 267 output
+1
+
+
+-- !query 268
+SELECT cast(1 as tinyint) / cast(1 as decimal(20, 0)) FROM t
+-- !query 268 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0)) / CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(24,21)>
+-- !query 268 output
+1
+
+
+-- !query 269
+SELECT cast(1 as smallint) / cast(1 as decimal(3, 0)) FROM t
+-- !query 269 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0)) / CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0))):decimal(11,6)>
+-- !query 269 output
+1
+
+
+-- !query 270
+SELECT cast(1 as smallint) / cast(1 as decimal(5, 0)) FROM t
+-- !query 270 schema
+struct<(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) / CAST(1 AS DECIMAL(5,0))):decimal(11,6)>
+-- !query 270 output
+1
+
+
+-- !query 271
+SELECT cast(1 as smallint) / cast(1 as decimal(10, 0)) FROM t
+-- !query 271 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0)) / CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(16,11)>
+-- !query 271 output
+1
+
+
+-- !query 272
+SELECT cast(1 as smallint) / cast(1 as decimal(20, 0)) FROM t
+-- !query 272 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0)) / CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(26,21)>
+-- !query 272 output
+1
+
+
+-- !query 273
+SELECT cast(1 as int) / cast(1 as decimal(3, 0)) FROM t
+-- !query 273 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) / CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):decimal(16,6)>
+-- !query 273 output
+1
+
+
+-- !query 274
+SELECT cast(1 as int) / cast(1 as decimal(5, 0)) FROM t
+-- !query 274 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) / CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):decimal(16,6)>
+-- !query 274 output
+1
+
+
+-- !query 275
+SELECT cast(1 as int) / cast(1 as decimal(10, 0)) FROM t
+-- !query 275 schema
+struct<(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) / CAST(1 AS DECIMAL(10,0))):decimal(21,11)>
+-- !query 275 output
+1
+
+
+-- !query 276
+SELECT cast(1 as int) / cast(1 as decimal(20, 0)) FROM t
+-- !query 276 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0)) / CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(31,21)>
+-- !query 276 output
+1
+
+
+-- !query 277
+SELECT cast(1 as bigint) / cast(1 as decimal(3, 0)) FROM t
+-- !query 277 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) / CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0))):decimal(26,6)>
+-- !query 277 output
+1
+
+
+-- !query 278
+SELECT cast(1 as bigint) / cast(1 as decimal(5, 0)) FROM t
+-- !query 278 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) / CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0))):decimal(26,6)>
+-- !query 278 output
+1
+
+
+-- !query 279
+SELECT cast(1 as bigint) / cast(1 as decimal(10, 0)) FROM t
+-- !query 279 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) / CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):decimal(31,11)>
+-- !query 279 output
+1
+
+
+-- !query 280
+SELECT cast(1 as bigint) / cast(1 as decimal(20, 0)) FROM t
+-- !query 280 schema
+struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) / CAST(1 AS DECIMAL(20,0))):decimal(38,19)>
+-- !query 280 output
+1
+
+
+-- !query 281
+SELECT cast(1 as float) / cast(1 as decimal(3, 0)) FROM t
+-- !query 281 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) / CAST(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) AS DOUBLE)):double>
+-- !query 281 output
+1.0
+
+
+-- !query 282
+SELECT cast(1 as float) / cast(1 as decimal(5, 0)) FROM t
+-- !query 282 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) / CAST(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) AS DOUBLE)):double>
+-- !query 282 output
+1.0
+
+
+-- !query 283
+SELECT cast(1 as float) / cast(1 as decimal(10, 0)) FROM t
+-- !query 283 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) / CAST(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) AS DOUBLE)):double>
+-- !query 283 output
+1.0
+
+
+-- !query 284
+SELECT cast(1 as float) / cast(1 as decimal(20, 0)) FROM t
+-- !query 284 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) / CAST(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) AS DOUBLE)):double>
+-- !query 284 output
+1.0
+
+
+-- !query 285
+SELECT cast(1 as double) / cast(1 as decimal(3, 0)) FROM t
+-- !query 285 schema
+struct<(CAST(1 AS DOUBLE) / CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):double>
+-- !query 285 output
+1.0
+
+
+-- !query 286
+SELECT cast(1 as double) / cast(1 as decimal(5, 0)) FROM t
+-- !query 286 schema
+struct<(CAST(1 AS DOUBLE) / CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):double>
+-- !query 286 output
+1.0
+
+
+-- !query 287
+SELECT cast(1 as double) / cast(1 as decimal(10, 0)) FROM t
+-- !query 287 schema
+struct<(CAST(1 AS DOUBLE) / CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):double>
+-- !query 287 output
+1.0
+
+
+-- !query 288
+SELECT cast(1 as double) / cast(1 as decimal(20, 0)) FROM t
+-- !query 288 schema
+struct<(CAST(1 AS DOUBLE) / CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):double>
+-- !query 288 output
+1.0
+
+
+-- !query 289
+SELECT cast(1 as decimal(10, 0)) / cast(1 as decimal(3, 0)) FROM t
+-- !query 289 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) / CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):decimal(16,6)>
+-- !query 289 output
+1
+
+
+-- !query 290
+SELECT cast(1 as decimal(10, 0)) / cast(1 as decimal(5, 0)) FROM t
+-- !query 290 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) / CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):decimal(16,6)>
+-- !query 290 output
+1
+
+
+-- !query 291
+SELECT cast(1 as decimal(10, 0)) / cast(1 as decimal(10, 0)) FROM t
+-- !query 291 schema
+struct<(CAST(1 AS DECIMAL(10,0)) / CAST(1 AS DECIMAL(10,0))):decimal(21,11)>
+-- !query 291 output
+1
+
+
+-- !query 292
+SELECT cast(1 as decimal(10, 0)) / cast(1 as decimal(20, 0)) FROM t
+-- !query 292 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) / CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(31,21)>
+-- !query 292 output
+1
+
+
+-- !query 293
+SELECT cast('1' as binary) / cast(1 as decimal(3, 0)) FROM t
+-- !query 293 schema
+struct<>
+-- !query 293 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) / CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST(1 AS DECIMAL(3,0)))' (binary and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 294
+SELECT cast('1' as binary) / cast(1 as decimal(5, 0)) FROM t
+-- !query 294 schema
+struct<>
+-- !query 294 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) / CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST(1 AS DECIMAL(5,0)))' (binary and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 295
+SELECT cast('1' as binary) / cast(1 as decimal(10, 0)) FROM t
+-- !query 295 schema
+struct<>
+-- !query 295 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) / CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST(1 AS DECIMAL(10,0)))' (binary and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 296
+SELECT cast('1' as binary) / cast(1 as decimal(20, 0)) FROM t
+-- !query 296 schema
+struct<>
+-- !query 296 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) / CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST(1 AS DECIMAL(20,0)))' (binary and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 297
+SELECT cast('2017/12/11 09:30:00.0' as timestamp) / cast(1 as decimal(3, 0)) FROM t
+-- !query 297 schema
+struct<>
+-- !query 297 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017/12/11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017/12/11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS DECIMAL(3,0)))' (timestamp and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 298
+SELECT cast('2017/12/11 09:30:00.0' as timestamp) / cast(1 as decimal(5, 0)) FROM t
+-- !query 298 schema
+struct<>
+-- !query 298 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017/12/11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017/12/11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS DECIMAL(5,0)))' (timestamp and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 299
+SELECT cast('2017/12/11 09:30:00.0' as timestamp) / cast(1 as decimal(10, 0)) FROM t
+-- !query 299 schema
+struct<>
+-- !query 299 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017/12/11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017/12/11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS DECIMAL(10,0)))' (timestamp and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 300
+SELECT cast('2017/12/11 09:30:00.0' as timestamp) / cast(1 as decimal(20, 0)) FROM t
+-- !query 300 schema
+struct<>
+-- !query 300 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017/12/11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017/12/11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS DECIMAL(20,0)))' (timestamp and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 301
+SELECT cast('2017/12/11 09:30:00' as date) / cast(1 as decimal(3, 0)) FROM t
+-- !query 301 schema
+struct<>
+-- !query 301 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017/12/11 09:30:00' AS DATE) / CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017/12/11 09:30:00' AS DATE) / CAST(1 AS DECIMAL(3,0)))' (date and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 302
+SELECT cast('2017/12/11 09:30:00' as date) / cast(1 as decimal(5, 0)) FROM t
+-- !query 302 schema
+struct<>
+-- !query 302 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017/12/11 09:30:00' AS DATE) / CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017/12/11 09:30:00' AS DATE) / CAST(1 AS DECIMAL(5,0)))' (date and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 303
+SELECT cast('2017/12/11 09:30:00' as date) / cast(1 as decimal(10, 0)) FROM t
+-- !query 303 schema
+struct<>
+-- !query 303 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017/12/11 09:30:00' AS DATE) / CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017/12/11 09:30:00' AS DATE) / CAST(1 AS DECIMAL(10,0)))' (date and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 304
+SELECT cast('2017/12/11 09:30:00' as date) / cast(1 as decimal(20, 0)) FROM t
+-- !query 304 schema
+struct<>
+-- !query 304 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017/12/11 09:30:00' AS DATE) / CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017/12/11 09:30:00' AS DATE) / CAST(1 AS DECIMAL(20,0)))' (date and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 305
+SELECT cast(1 as decimal(3, 0)) / cast(1 as tinyint) FROM t
+-- !query 305 schema
+struct<(CAST(1 AS DECIMAL(3,0)) / CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0))):decimal(9,6)>
+-- !query 305 output
+1
+
+
+-- !query 306
+SELECT cast(1 as decimal(5, 0)) / cast(1 as tinyint) FROM t
+-- !query 306 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0)) / CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0))):decimal(11,6)>
+-- !query 306 output
+1
+
+
+-- !query 307
+SELECT cast(1 as decimal(10, 0)) / cast(1 as tinyint) FROM t
+-- !query 307 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) / CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0))):decimal(16,6)>
+-- !query 307 output
+1
+
+
+-- !query 308
+SELECT cast(1 as decimal(20, 0)) / cast(1 as tinyint) FROM t
+-- !query 308 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) / CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0))):decimal(26,6)>
+-- !query 308 output
+1
+
+
+-- !query 309
+SELECT cast(1 as decimal(3, 0)) / cast(1 as smallint) FROM t
+-- !query 309 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0)) / CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0))):decimal(9,6)>
+-- !query 309 output
+1
+
+
+-- !query 310
+SELECT cast(1 as decimal(5, 0)) / cast(1 as smallint) FROM t
+-- !query 310 schema
+struct<(CAST(1 AS DECIMAL(5,0)) / CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0))):decimal(11,6)>
+-- !query 310 output
+1
+
+
+-- !query 311
+SELECT cast(1 as decimal(10, 0)) / cast(1 as smallint) FROM t
+-- !query 311 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) / CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0))):decimal(16,6)>
+-- !query 311 output
+1
+
+
+-- !query 312
+SELECT cast(1 as decimal(20, 0)) / cast(1 as smallint) FROM t
+-- !query 312 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) / CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0))):decimal(26,6)>
+-- !query 312 output
+1
+
+
+-- !query 313
+SELECT cast(1 as decimal(3, 0)) / cast(1 as int) FROM t
+-- !query 313 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) / CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(14,11)>
+-- !query 313 output
+1
+
+
+-- !query 314
+SELECT cast(1 as decimal(5, 0)) / cast(1 as int) FROM t
+-- !query 314 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) / CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(16,11)>
+-- !query 314 output
+1
+
+
+-- !query 315
+SELECT cast(1 as decimal(10, 0)) / cast(1 as int) FROM t
+-- !query 315 schema
+struct<(CAST(1 AS DECIMAL(10,0)) / CAST(CAST(1 AS INT) AS DECIMAL(10,0))):decimal(21,11)>
+-- !query 315 output
+1
+
+
+-- !query 316
+SELECT cast(1 as decimal(20, 0)) / cast(1 as int) FROM t
+-- !query 316 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) / CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0))):decimal(31,11)>
+-- !query 316 output
+1
+
+
+-- !query 317
+SELECT cast(1 as decimal(3, 0)) / cast(1 as bigint) FROM t
+-- !query 317 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0)) / CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(24,21)>
+-- !query 317 output
+1
+
+
+-- !query 318
+SELECT cast(1 as decimal(5, 0)) / cast(1 as bigint) FROM t
+-- !query 318 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0)) / CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(26,21)>
+-- !query 318 output
+1
+
+
+-- !query 319
+SELECT cast(1 as decimal(10, 0)) / cast(1 as bigint) FROM t
+-- !query 319 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) / CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(31,21)>
+-- !query 319 output
+1
+
+
+-- !query 320
+SELECT cast(1 as decimal(20, 0)) / cast(1 as bigint) FROM t
+-- !query 320 schema
+struct<(CAST(1 AS DECIMAL(20,0)) / CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):decimal(38,19)>
+-- !query 320 output
+1
+
+
+-- !query 321
+SELECT cast(1 as decimal(3, 0)) / cast(1 as float) FROM t
+-- !query 321 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) / CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 321 output
+1.0
+
+
+-- !query 322
+SELECT cast(1 as decimal(5, 0)) / cast(1 as float) FROM t
+-- !query 322 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) / CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 322 output
+1.0
+
+
+-- !query 323
+SELECT cast(1 as decimal(10, 0)) / cast(1 as float) FROM t
+-- !query 323 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) / CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 323 output
+1.0
+
+
+-- !query 324
+SELECT cast(1 as decimal(20, 0)) / cast(1 as float) FROM t
+-- !query 324 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) / CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 324 output
+1.0
+
+
+-- !query 325
+SELECT cast(1 as decimal(3, 0)) / cast(1 as double) FROM t
+-- !query 325 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) / CAST(1 AS DOUBLE)):double>
+-- !query 325 output
+1.0
+
+
+-- !query 326
+SELECT cast(1 as decimal(5, 0)) / cast(1 as double) FROM t
+-- !query 326 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) / CAST(1 AS DOUBLE)):double>
+-- !query 326 output
+1.0
+
+
+-- !query 327
+SELECT cast(1 as decimal(10, 0)) / cast(1 as double) FROM t
+-- !query 327 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) / CAST(1 AS DOUBLE)):double>
+-- !query 327 output
+1.0
+
+
+-- !query 328
+SELECT cast(1 as decimal(20, 0)) / cast(1 as double) FROM t
+-- !query 328 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) / CAST(1 AS DOUBLE)):double>
+-- !query 328 output
+1.0
+
+
+-- !query 329
+SELECT cast(1 as decimal(3, 0)) / cast(1 as decimal(10, 0)) FROM t
+-- !query 329 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) / CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(14,11)>
+-- !query 329 output
+1
+
+
+-- !query 330
+SELECT cast(1 as decimal(5, 0)) / cast(1 as decimal(10, 0)) FROM t
+-- !query 330 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) / CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(16,11)>
+-- !query 330 output
+1
+
+
+-- !query 331
+SELECT cast(1 as decimal(10, 0)) / cast(1 as decimal(10, 0)) FROM t
+-- !query 331 schema
+struct<(CAST(1 AS DECIMAL(10,0)) / CAST(1 AS DECIMAL(10,0))):decimal(21,11)>
+-- !query 331 output
+1
+
+
+-- !query 332
+SELECT cast(1 as decimal(20, 0)) / cast(1 as decimal(10, 0)) FROM t
+-- !query 332 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) / CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):decimal(31,11)>
+-- !query 332 output
+1
+
+
+-- !query 333
+SELECT cast(1 as decimal(3, 0)) / cast(1 as string) FROM t
+-- !query 333 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) / CAST(CAST(1 AS STRING) AS DOUBLE)):double>
+-- !query 333 output
+1.0
+
+
+-- !query 334
+SELECT cast(1 as decimal(5, 0)) / cast(1 as string) FROM t
+-- !query 334 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) / CAST(CAST(1 AS STRING) AS DOUBLE)):double>
+-- !query 334 output
+1.0
+
+
+-- !query 335
+SELECT cast(1 as decimal(10, 0)) / cast(1 as string) FROM t
+-- !query 335 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) / CAST(CAST(1 AS STRING) AS DOUBLE)):double>
+-- !query 335 output
+1.0
+
+
+-- !query 336
+SELECT cast(1 as decimal(20, 0)) / cast(1 as string) FROM t
+-- !query 336 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) / CAST(CAST(1 AS STRING) AS DOUBLE)):double>
+-- !query 336 output
+1.0
+
+
+-- !query 337
+SELECT cast(1 as decimal(3, 0)) / cast('1' as binary) FROM t
+-- !query 337 schema
+struct<>
+-- !query 337 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) / CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) / CAST('1' AS BINARY))' (decimal(3,0) and binary).; line 1 pos 7
+
+
+-- !query 338
+SELECT cast(1 as decimal(5, 0)) / cast('1' as binary) FROM t
+-- !query 338 schema
+struct<>
+-- !query 338 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) / CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) / CAST('1' AS BINARY))' (decimal(5,0) and binary).; line 1 pos 7
+
+
+-- !query 339
+SELECT cast(1 as decimal(10, 0)) / cast('1' as binary) FROM t
+-- !query 339 schema
+struct<>
+-- !query 339 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) / CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) / CAST('1' AS BINARY))' (decimal(10,0) and binary).; line 1 pos 7
+
+
+-- !query 340
+SELECT cast(1 as decimal(20, 0)) / cast('1' as binary) FROM t
+-- !query 340 schema
+struct<>
+-- !query 340 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) / CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) / CAST('1' AS BINARY))' (decimal(20,0) and binary).; line 1 pos 7
+
+
+-- !query 341
+SELECT cast(1 as decimal(3, 0)) / cast(1 as boolean) FROM t
+-- !query 341 schema
+struct<>
+-- !query 341 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) / CAST(1 AS BOOLEAN))' (decimal(3,0) and boolean).; line 1 pos 7
+
+
+-- !query 342
+SELECT cast(1 as decimal(5, 0)) / cast(1 as boolean) FROM t
+-- !query 342 schema
+struct<>
+-- !query 342 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) / CAST(1 AS BOOLEAN))' (decimal(5,0) and boolean).; line 1 pos 7
+
+
+-- !query 343
+SELECT cast(1 as decimal(10, 0)) / cast(1 as boolean) FROM t
+-- !query 343 schema
+struct<>
+-- !query 343 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) / CAST(1 AS BOOLEAN))' (decimal(10,0) and boolean).; line 1 pos 7
+
+
+-- !query 344
+SELECT cast(1 as decimal(20, 0)) / cast(1 as boolean) FROM t
+-- !query 344 schema
+struct<>
+-- !query 344 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) / CAST(1 AS BOOLEAN))' (decimal(20,0) and boolean).; line 1 pos 7
+
+
+-- !query 345
+SELECT cast(1 as decimal(3, 0)) / cast('2017/12/11 09:30:00.0' as timestamp) FROM t
+-- !query 345 schema
+struct<>
+-- !query 345 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) / CAST('2017/12/11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) / CAST('2017/12/11 09:30:00.0' AS TIMESTAMP))' (decimal(3,0) and timestamp).; line 1 pos 7
+
+
+-- !query 346
+SELECT cast(1 as decimal(5, 0)) / cast('2017/12/11 09:30:00.0' as timestamp) FROM t
+-- !query 346 schema
+struct<>
+-- !query 346 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) / CAST('2017/12/11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) / CAST('2017/12/11 09:30:00.0' AS TIMESTAMP))' (decimal(5,0) and timestamp).; line 1 pos 7
+
+
+-- !query 347
+SELECT cast(1 as decimal(10, 0)) / cast('2017/12/11 09:30:00.0' as timestamp) FROM t
+-- !query 347 schema
+struct<>
+-- !query 347 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) / CAST('2017/12/11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) / CAST('2017/12/11 09:30:00.0' AS TIMESTAMP))' (decimal(10,0) and timestamp).; line 1 pos 7
+
+
+-- !query 348
+SELECT cast(1 as decimal(20, 0)) / cast('2017/12/11 09:30:00.0' as timestamp) FROM t
+-- !query 348 schema
+struct<>
+-- !query 348 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) / CAST('2017/12/11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) / CAST('2017/12/11 09:30:00.0' AS TIMESTAMP))' (decimal(20,0) and timestamp).; line 1 pos 7
+
+
+-- !query 349
+SELECT cast(1 as decimal(3, 0)) / cast('2017/12/11 09:30:00' as date) FROM t
+-- !query 349 schema
+struct<>
+-- !query 349 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) / CAST('2017/12/11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) / CAST('2017/12/11 09:30:00' AS DATE))' (decimal(3,0) and date).; line 1 pos 7
+
+
+-- !query 350
+SELECT cast(1 as decimal(5, 0)) / cast('2017/12/11 09:30:00' as date) FROM t
+-- !query 350 schema
+struct<>
+-- !query 350 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) / CAST('2017/12/11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) / CAST('2017/12/11 09:30:00' AS DATE))' (decimal(5,0) and date).; line 1 pos 7
+
+
+-- !query 351
+SELECT cast(1 as decimal(10, 0)) / cast('2017/12/11 09:30:00' as date) FROM t
+-- !query 351 schema
+struct<>
+-- !query 351 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) / CAST('2017/12/11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) / CAST('2017/12/11 09:30:00' AS DATE))' (decimal(10,0) and date).; line 1 pos 7
+
+
+-- !query 352
+SELECT cast(1 as decimal(20, 0)) / cast('2017/12/11 09:30:00' as date) FROM t
+-- !query 352 schema
+struct<>
+-- !query 352 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) / CAST('2017/12/11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) / CAST('2017/12/11 09:30:00' AS DATE))' (decimal(20,0) and date).; line 1 pos 7
+
+
+-- !query 353
+SELECT cast(1 as tinyint) % cast(1 as decimal(3, 0)) FROM t
+-- !query 353 schema
+struct<(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) % CAST(1 AS DECIMAL(3,0))):decimal(3,0)>
+-- !query 353 output
+0
+
+
+-- !query 354
+SELECT cast(1 as tinyint) % cast(1 as decimal(5, 0)) FROM t
+-- !query 354 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0)) % CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0))):decimal(3,0)>
+-- !query 354 output
+0
+
+
+-- !query 355
+SELECT cast(1 as tinyint) % cast(1 as decimal(10, 0)) FROM t
+-- !query 355 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0)) % CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(3,0)>
+-- !query 355 output
+0
+
+
+-- !query 356
+SELECT cast(1 as tinyint) % cast(1 as decimal(20, 0)) FROM t
+-- !query 356 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0)) % CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(3,0)>
+-- !query 356 output
+0
+
+
+-- !query 357
+SELECT cast(1 as smallint) % cast(1 as decimal(3, 0)) FROM t
+-- !query 357 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0)) % CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0))):decimal(3,0)>
+-- !query 357 output
+0
+
+
+-- !query 358
+SELECT cast(1 as smallint) % cast(1 as decimal(5, 0)) FROM t
+-- !query 358 schema
+struct<(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) % CAST(1 AS DECIMAL(5,0))):decimal(5,0)>
+-- !query 358 output
+0
+
+
+-- !query 359
+SELECT cast(1 as smallint) % cast(1 as decimal(10, 0)) FROM t
+-- !query 359 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0)) % CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(5,0)>
+-- !query 359 output
+0
+
+
+-- !query 360
+SELECT cast(1 as smallint) % cast(1 as decimal(20, 0)) FROM t
+-- !query 360 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0)) % CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(5,0)>
+-- !query 360 output
+0
+
+
+-- !query 361
+SELECT cast(1 as int) % cast(1 as decimal(3, 0)) FROM t
+-- !query 361 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) % CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):decimal(3,0)>
+-- !query 361 output
+0
+
+
+-- !query 362
+SELECT cast(1 as int) % cast(1 as decimal(5, 0)) FROM t
+-- !query 362 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) % CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):decimal(5,0)>
+-- !query 362 output
+0
+
+
+-- !query 363
+SELECT cast(1 as int) % cast(1 as decimal(10, 0)) FROM t
+-- !query 363 schema
+struct<(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) % CAST(1 AS DECIMAL(10,0))):decimal(10,0)>
+-- !query 363 output
+0
+
+
+-- !query 364
+SELECT cast(1 as int) % cast(1 as decimal(20, 0)) FROM t
+-- !query 364 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0)) % CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(10,0)>
+-- !query 364 output
+0
+
+
+-- !query 365
+SELECT cast(1 as bigint) % cast(1 as decimal(3, 0)) FROM t
+-- !query 365 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) % CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0))):decimal(3,0)>
+-- !query 365 output
+0
+
+
+-- !query 366
+SELECT cast(1 as bigint) % cast(1 as decimal(5, 0)) FROM t
+-- !query 366 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) % CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0))):decimal(5,0)>
+-- !query 366 output
+0
+
+
+-- !query 367
+SELECT cast(1 as bigint) % cast(1 as decimal(10, 0)) FROM t
+-- !query 367 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) % CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):decimal(10,0)>
+-- !query 367 output
+0
+
+
+-- !query 368
+SELECT cast(1 as bigint) % cast(1 as decimal(20, 0)) FROM t
+-- !query 368 schema
+struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) % CAST(1 AS DECIMAL(20,0))):decimal(20,0)>
+-- !query 368 output
+0
+
+
+-- !query 369
+SELECT cast(1 as float) % cast(1 as decimal(3, 0)) FROM t
+-- !query 369 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) % CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):double>
+-- !query 369 output
+0.0
+
+
+-- !query 370
+SELECT cast(1 as float) % cast(1 as decimal(5, 0)) FROM t
+-- !query 370 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) % CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):double>
+-- !query 370 output
+0.0
+
+
+-- !query 371
+SELECT cast(1 as float) % cast(1 as decimal(10, 0)) FROM t
+-- !query 371 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) % CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):double>
+-- !query 371 output
+0.0
+
+
+-- !query 372
+SELECT cast(1 as float) % cast(1 as decimal(20, 0)) FROM t
+-- !query 372 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) % CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):double>
+-- !query 372 output
+0.0
+
+
+-- !query 373
+SELECT cast(1 as double) % cast(1 as decimal(3, 0)) FROM t
+-- !query 373 schema
+struct<(CAST(1 AS DOUBLE) % CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):double>
+-- !query 373 output
+0.0
+
+
+-- !query 374
+SELECT cast(1 as double) % cast(1 as decimal(5, 0)) FROM t
+-- !query 374 schema
+struct<(CAST(1 AS DOUBLE) % CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):double>
+-- !query 374 output
+0.0
+
+
+-- !query 375
+SELECT cast(1 as double) % cast(1 as decimal(10, 0)) FROM t
+-- !query 375 schema
+struct<(CAST(1 AS DOUBLE) % CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):double>
+-- !query 375 output
+0.0
+
+
+-- !query 376
+SELECT cast(1 as double) % cast(1 as decimal(20, 0)) FROM t
+-- !query 376 schema
+struct<(CAST(1 AS DOUBLE) % CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):double>
+-- !query 376 output
+0.0
+
+
+-- !query 377
+SELECT cast(1 as decimal(10, 0)) % cast(1 as decimal(3, 0)) FROM t
+-- !query 377 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) % CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):decimal(3,0)>
+-- !query 377 output
+0
+
+
+-- !query 378
+SELECT cast(1 as decimal(10, 0)) % cast(1 as decimal(5, 0)) FROM t
+-- !query 378 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) % CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):decimal(5,0)>
+-- !query 378 output
+0
+
+
+-- !query 379
+SELECT cast(1 as decimal(10, 0)) % cast(1 as decimal(10, 0)) FROM t
+-- !query 379 schema
+struct<(CAST(1 AS DECIMAL(10,0)) % CAST(1 AS DECIMAL(10,0))):decimal(10,0)>
+-- !query 379 output
+0
+
+
+-- !query 380
+SELECT cast(1 as decimal(10, 0)) % cast(1 as decimal(20, 0)) FROM t
+-- !query 380 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) % CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(10,0)>
+-- !query 380 output
+0
+
+
+-- !query 381
+SELECT cast('1' as binary) % cast(1 as decimal(3, 0)) FROM t
+-- !query 381 schema
+struct<>
+-- !query 381 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) % CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) % CAST(1 AS DECIMAL(3,0)))' (binary and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 382
+SELECT cast('1' as binary) % cast(1 as decimal(5, 0)) FROM t
+-- !query 382 schema
+struct<>
+-- !query 382 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) % CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) % CAST(1 AS DECIMAL(5,0)))' (binary and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 383
+SELECT cast('1' as binary) % cast(1 as decimal(10, 0)) FROM t
+-- !query 383 schema
+struct<>
+-- !query 383 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) % CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) % CAST(1 AS DECIMAL(10,0)))' (binary and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 384
+SELECT cast('1' as binary) % cast(1 as decimal(20, 0)) FROM t
+-- !query 384 schema
+struct<>
+-- !query 384 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) % CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) % CAST(1 AS DECIMAL(20,0)))' (binary and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 385
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) % cast(1 as decimal(3, 0)) FROM t
+-- !query 385 schema
+struct<>
+-- !query 385 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) % CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) % CAST(1 AS DECIMAL(3,0)))' (timestamp and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 386
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) % cast(1 as decimal(5, 0)) FROM t
+-- !query 386 schema
+struct<>
+-- !query 386 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) % CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) % CAST(1 AS DECIMAL(5,0)))' (timestamp and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 387
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) % cast(1 as decimal(10, 0)) FROM t
+-- !query 387 schema
+struct<>
+-- !query 387 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) % CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) % CAST(1 AS DECIMAL(10,0)))' (timestamp and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 388
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) % cast(1 as decimal(20, 0)) FROM t
+-- !query 388 schema
+struct<>
+-- !query 388 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) % CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) % CAST(1 AS DECIMAL(20,0)))' (timestamp and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 389
+SELECT cast('2017-12-11 09:30:00' as date) % cast(1 as decimal(3, 0)) FROM t
+-- !query 389 schema
+struct<>
+-- !query 389 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) % CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) % CAST(1 AS DECIMAL(3,0)))' (date and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 390
+SELECT cast('2017-12-11 09:30:00' as date) % cast(1 as decimal(5, 0)) FROM t
+-- !query 390 schema
+struct<>
+-- !query 390 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) % CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) % CAST(1 AS DECIMAL(5,0)))' (date and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 391
+SELECT cast('2017-12-11 09:30:00' as date) % cast(1 as decimal(10, 0)) FROM t
+-- !query 391 schema
+struct<>
+-- !query 391 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) % CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) % CAST(1 AS DECIMAL(10,0)))' (date and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 392
+SELECT cast('2017-12-11 09:30:00' as date) % cast(1 as decimal(20, 0)) FROM t
+-- !query 392 schema
+struct<>
+-- !query 392 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) % CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) % CAST(1 AS DECIMAL(20,0)))' (date and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 393
+SELECT cast(1 as decimal(3, 0)) % cast(1 as tinyint) FROM t
+-- !query 393 schema
+struct<(CAST(1 AS DECIMAL(3,0)) % CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0))):decimal(3,0)>
+-- !query 393 output
+0
+
+
+-- !query 394
+SELECT cast(1 as decimal(5, 0)) % cast(1 as tinyint) FROM t
+-- !query 394 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0)) % CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0))):decimal(3,0)>
+-- !query 394 output
+0
+
+
+-- !query 395
+SELECT cast(1 as decimal(10, 0)) % cast(1 as tinyint) FROM t
+-- !query 395 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) % CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0))):decimal(3,0)>
+-- !query 395 output
+0
+
+
+-- !query 396
+SELECT cast(1 as decimal(20, 0)) % cast(1 as tinyint) FROM t
+-- !query 396 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) % CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0))):decimal(3,0)>
+-- !query 396 output
+0
+
+
+-- !query 397
+SELECT cast(1 as decimal(3, 0)) % cast(1 as smallint) FROM t
+-- !query 397 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0)) % CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0))):decimal(3,0)>
+-- !query 397 output
+0
+
+
+-- !query 398
+SELECT cast(1 as decimal(5, 0)) % cast(1 as smallint) FROM t
+-- !query 398 schema
+struct<(CAST(1 AS DECIMAL(5,0)) % CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0))):decimal(5,0)>
+-- !query 398 output
+0
+
+
+-- !query 399
+SELECT cast(1 as decimal(10, 0)) % cast(1 as smallint) FROM t
+-- !query 399 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) % CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0))):decimal(5,0)>
+-- !query 399 output
+0
+
+
+-- !query 400
+SELECT cast(1 as decimal(20, 0)) % cast(1 as smallint) FROM t
+-- !query 400 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) % CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0))):decimal(5,0)>
+-- !query 400 output
+0
+
+
+-- !query 401
+SELECT cast(1 as decimal(3, 0)) % cast(1 as int) FROM t
+-- !query 401 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) % CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(3,0)>
+-- !query 401 output
+0
+
+
+-- !query 402
+SELECT cast(1 as decimal(5, 0)) % cast(1 as int) FROM t
+-- !query 402 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) % CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(5,0)>
+-- !query 402 output
+0
+
+
+-- !query 403
+SELECT cast(1 as decimal(10, 0)) % cast(1 as int) FROM t
+-- !query 403 schema
+struct<(CAST(1 AS DECIMAL(10,0)) % CAST(CAST(1 AS INT) AS DECIMAL(10,0))):decimal(10,0)>
+-- !query 403 output
+0
+
+
+-- !query 404
+SELECT cast(1 as decimal(20, 0)) % cast(1 as int) FROM t
+-- !query 404 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) % CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0))):decimal(10,0)>
+-- !query 404 output
+0
+
+
+-- !query 405
+SELECT cast(1 as decimal(3, 0)) % cast(1 as bigint) FROM t
+-- !query 405 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0)) % CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(3,0)>
+-- !query 405 output
+0
+
+
+-- !query 406
+SELECT cast(1 as decimal(5, 0)) % cast(1 as bigint) FROM t
+-- !query 406 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0)) % CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(5,0)>
+-- !query 406 output
+0
+
+
+-- !query 407
+SELECT cast(1 as decimal(10, 0)) % cast(1 as bigint) FROM t
+-- !query 407 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) % CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(10,0)>
+-- !query 407 output
+0
+
+
+-- !query 408
+SELECT cast(1 as decimal(20, 0)) % cast(1 as bigint) FROM t
+-- !query 408 schema
+struct<(CAST(1 AS DECIMAL(20,0)) % CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):decimal(20,0)>
+-- !query 408 output
+0
+
+
+-- !query 409
+SELECT cast(1 as decimal(3, 0)) % cast(1 as float) FROM t
+-- !query 409 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) % CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 409 output
+0.0
+
+
+-- !query 410
+SELECT cast(1 as decimal(5, 0)) % cast(1 as float) FROM t
+-- !query 410 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) % CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 410 output
+0.0
+
+
+-- !query 411
+SELECT cast(1 as decimal(10, 0)) % cast(1 as float) FROM t
+-- !query 411 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) % CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 411 output
+0.0
+
+
+-- !query 412
+SELECT cast(1 as decimal(20, 0)) % cast(1 as float) FROM t
+-- !query 412 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) % CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 412 output
+0.0
+
+
+-- !query 413
+SELECT cast(1 as decimal(3, 0)) % cast(1 as double) FROM t
+-- !query 413 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) % CAST(1 AS DOUBLE)):double>
+-- !query 413 output
+0.0
+
+
+-- !query 414
+SELECT cast(1 as decimal(5, 0)) % cast(1 as double) FROM t
+-- !query 414 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) % CAST(1 AS DOUBLE)):double>
+-- !query 414 output
+0.0
+
+
+-- !query 415
+SELECT cast(1 as decimal(10, 0)) % cast(1 as double) FROM t
+-- !query 415 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) % CAST(1 AS DOUBLE)):double>
+-- !query 415 output
+0.0
+
+
+-- !query 416
+SELECT cast(1 as decimal(20, 0)) % cast(1 as double) FROM t
+-- !query 416 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) % CAST(1 AS DOUBLE)):double>
+-- !query 416 output
+0.0
+
+
+-- !query 417
+SELECT cast(1 as decimal(3, 0)) % cast(1 as decimal(10, 0)) FROM t
+-- !query 417 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) % CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(3,0)>
+-- !query 417 output
+0
+
+
+-- !query 418
+SELECT cast(1 as decimal(5, 0)) % cast(1 as decimal(10, 0)) FROM t
+-- !query 418 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) % CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(5,0)>
+-- !query 418 output
+0
+
+
+-- !query 419
+SELECT cast(1 as decimal(10, 0)) % cast(1 as decimal(10, 0)) FROM t
+-- !query 419 schema
+struct<(CAST(1 AS DECIMAL(10,0)) % CAST(1 AS DECIMAL(10,0))):decimal(10,0)>
+-- !query 419 output
+0
+
+
+-- !query 420
+SELECT cast(1 as decimal(20, 0)) % cast(1 as decimal(10, 0)) FROM t
+-- !query 420 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) % CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):decimal(10,0)>
+-- !query 420 output
+0
+
+
+-- !query 421
+SELECT cast(1 as decimal(3, 0)) % cast(1 as string) FROM t
+-- !query 421 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) % CAST(CAST(1 AS STRING) AS DOUBLE)):double>
+-- !query 421 output
+0.0
+
+
+-- !query 422
+SELECT cast(1 as decimal(5, 0)) % cast(1 as string) FROM t
+-- !query 422 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) % CAST(CAST(1 AS STRING) AS DOUBLE)):double>
+-- !query 422 output
+0.0
+
+
+-- !query 423
+SELECT cast(1 as decimal(10, 0)) % cast(1 as string) FROM t
+-- !query 423 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) % CAST(CAST(1 AS STRING) AS DOUBLE)):double>
+-- !query 423 output
+0.0
+
+
+-- !query 424
+SELECT cast(1 as decimal(20, 0)) % cast(1 as string) FROM t
+-- !query 424 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) % CAST(CAST(1 AS STRING) AS DOUBLE)):double>
+-- !query 424 output
+0.0
+
+
+-- !query 425
+SELECT cast(1 as decimal(3, 0)) % cast('1' as binary) FROM t
+-- !query 425 schema
+struct<>
+-- !query 425 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) % CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) % CAST('1' AS BINARY))' (decimal(3,0) and binary).; line 1 pos 7
+
+
+-- !query 426
+SELECT cast(1 as decimal(5, 0)) % cast('1' as binary) FROM t
+-- !query 426 schema
+struct<>
+-- !query 426 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) % CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) % CAST('1' AS BINARY))' (decimal(5,0) and binary).; line 1 pos 7
+
+
+-- !query 427
+SELECT cast(1 as decimal(10, 0)) % cast('1' as binary) FROM t
+-- !query 427 schema
+struct<>
+-- !query 427 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) % CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) % CAST('1' AS BINARY))' (decimal(10,0) and binary).; line 1 pos 7
+
+
+-- !query 428
+SELECT cast(1 as decimal(20, 0)) % cast('1' as binary) FROM t
+-- !query 428 schema
+struct<>
+-- !query 428 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) % CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) % CAST('1' AS BINARY))' (decimal(20,0) and binary).; line 1 pos 7
+
+
+-- !query 429
+SELECT cast(1 as decimal(3, 0)) % cast(1 as boolean) FROM t
+-- !query 429 schema
+struct<>
+-- !query 429 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) % CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) % CAST(1 AS BOOLEAN))' (decimal(3,0) and boolean).; line 1 pos 7
+
+
+-- !query 430
+SELECT cast(1 as decimal(5, 0)) % cast(1 as boolean) FROM t
+-- !query 430 schema
+struct<>
+-- !query 430 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) % CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) % CAST(1 AS BOOLEAN))' (decimal(5,0) and boolean).; line 1 pos 7
+
+
+-- !query 431
+SELECT cast(1 as decimal(10, 0)) % cast(1 as boolean) FROM t
+-- !query 431 schema
+struct<>
+-- !query 431 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) % CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) % CAST(1 AS BOOLEAN))' (decimal(10,0) and boolean).; line 1 pos 7
+
+
+-- !query 432
+SELECT cast(1 as decimal(20, 0)) % cast(1 as boolean) FROM t
+-- !query 432 schema
+struct<>
+-- !query 432 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) % CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) % CAST(1 AS BOOLEAN))' (decimal(20,0) and boolean).; line 1 pos 7
+
+
+-- !query 433
+SELECT cast(1 as decimal(3, 0)) % cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 433 schema
+struct<>
+-- !query 433 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) % CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) % CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(3,0) and timestamp).; line 1 pos 7
+
+
+-- !query 434
+SELECT cast(1 as decimal(5, 0)) % cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 434 schema
+struct<>
+-- !query 434 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) % CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) % CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(5,0) and timestamp).; line 1 pos 7
+
+
+-- !query 435
+SELECT cast(1 as decimal(10, 0)) % cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 435 schema
+struct<>
+-- !query 435 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) % CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) % CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(10,0) and timestamp).; line 1 pos 7
+
+
+-- !query 436
+SELECT cast(1 as decimal(20, 0)) % cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 436 schema
+struct<>
+-- !query 436 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) % CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) % CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(20,0) and timestamp).; line 1 pos 7
+
+
+-- !query 437
+SELECT cast(1 as decimal(3, 0)) % cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 437 schema
+struct<>
+-- !query 437 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) % CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) % CAST('2017-12-11 09:30:00' AS DATE))' (decimal(3,0) and date).; line 1 pos 7
+
+
+-- !query 438
+SELECT cast(1 as decimal(5, 0)) % cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 438 schema
+struct<>
+-- !query 438 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) % CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) % CAST('2017-12-11 09:30:00' AS DATE))' (decimal(5,0) and date).; line 1 pos 7
+
+
+-- !query 439
+SELECT cast(1 as decimal(10, 0)) % cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 439 schema
+struct<>
+-- !query 439 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) % CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) % CAST('2017-12-11 09:30:00' AS DATE))' (decimal(10,0) and date).; line 1 pos 7
+
+
+-- !query 440
+SELECT cast(1 as decimal(20, 0)) % cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 440 schema
+struct<>
+-- !query 440 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) % CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) % CAST('2017-12-11 09:30:00' AS DATE))' (decimal(20,0) and date).; line 1 pos 7
+
+
+-- !query 441
+SELECT pmod(cast(1 as tinyint), cast(1 as decimal(3, 0))) FROM t
+-- !query 441 schema
+struct<pmod(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)), CAST(1 AS DECIMAL(3,0))):decimal(3,0)>
+-- !query 441 output
+0
+
+
+-- !query 442
+SELECT pmod(cast(1 as tinyint), cast(1 as decimal(5, 0))) FROM t
+-- !query 442 schema
+struct<pmod(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0)), CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0))):decimal(3,0)>
+-- !query 442 output
+0
+
+
+-- !query 443
+SELECT pmod(cast(1 as tinyint), cast(1 as decimal(10, 0))) FROM t
+-- !query 443 schema
+struct<pmod(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0)), CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(3,0)>
+-- !query 443 output
+0
+
+
+-- !query 444
+SELECT pmod(cast(1 as tinyint), cast(1 as decimal(20, 0))) FROM t
+-- !query 444 schema
+struct<pmod(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0)), CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(3,0)>
+-- !query 444 output
+0
+
+
+-- !query 445
+SELECT pmod(cast(1 as smallint), cast(1 as decimal(3, 0))) FROM t
+-- !query 445 schema
+struct<pmod(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0)), CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0))):decimal(3,0)>
+-- !query 445 output
+0
+
+
+-- !query 446
+SELECT pmod(cast(1 as smallint), cast(1 as decimal(5, 0))) FROM t
+-- !query 446 schema
+struct<pmod(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)), CAST(1 AS DECIMAL(5,0))):decimal(5,0)>
+-- !query 446 output
+0
+
+
+-- !query 447
+SELECT pmod(cast(1 as smallint), cast(1 as decimal(10, 0))) FROM t
+-- !query 447 schema
+struct<pmod(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0)), CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(5,0)>
+-- !query 447 output
+0
+
+
+-- !query 448
+SELECT pmod(cast(1 as smallint), cast(1 as decimal(20, 0))) FROM t
+-- !query 448 schema
+struct<pmod(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0)), CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(5,0)>
+-- !query 448 output
+0
+
+
+-- !query 449
+SELECT pmod(cast(1 as int), cast(1 as decimal(3, 0))) FROM t
+-- !query 449 schema
+struct<pmod(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)), CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):decimal(3,0)>
+-- !query 449 output
+0
+
+
+-- !query 450
+SELECT pmod(cast(1 as int), cast(1 as decimal(5, 0))) FROM t
+-- !query 450 schema
+struct<pmod(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)), CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):decimal(5,0)>
+-- !query 450 output
+0
+
+
+-- !query 451
+SELECT pmod(cast(1 as int), cast(1 as decimal(10, 0))) FROM t
+-- !query 451 schema
+struct<pmod(CAST(CAST(1 AS INT) AS DECIMAL(10,0)), CAST(1 AS DECIMAL(10,0))):decimal(10,0)>
+-- !query 451 output
+0
+
+
+-- !query 452
+SELECT pmod(cast(1 as int), cast(1 as decimal(20, 0))) FROM t
+-- !query 452 schema
+struct<pmod(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0)), CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(10,0)>
+-- !query 452 output
+0
+
+
+-- !query 453
+SELECT pmod(cast(1 as bigint), cast(1 as decimal(3, 0))) FROM t
+-- !query 453 schema
+struct<pmod(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)), CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0))):decimal(3,0)>
+-- !query 453 output
+0
+
+
+-- !query 454
+SELECT pmod(cast(1 as bigint), cast(1 as decimal(5, 0))) FROM t
+-- !query 454 schema
+struct<pmod(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)), CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0))):decimal(5,0)>
+-- !query 454 output
+0
+
+
+-- !query 455
+SELECT pmod(cast(1 as bigint), cast(1 as decimal(10, 0))) FROM t
+-- !query 455 schema
+struct<pmod(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)), CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):decimal(10,0)>
+-- !query 455 output
+0
+
+
+-- !query 456
+SELECT pmod(cast(1 as bigint), cast(1 as decimal(20, 0))) FROM t
+-- !query 456 schema
+struct<pmod(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)), CAST(1 AS DECIMAL(20,0))):decimal(20,0)>
+-- !query 456 output
+0
+
+
+-- !query 457
+SELECT pmod(cast(1 as float), cast(1 as decimal(3, 0))) FROM t
+-- !query 457 schema
+struct<pmod(CAST(CAST(1 AS FLOAT) AS DOUBLE), CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):double>
+-- !query 457 output
+0.0
+
+
+-- !query 458
+SELECT pmod(cast(1 as float), cast(1 as decimal(5, 0))) FROM t
+-- !query 458 schema
+struct<pmod(CAST(CAST(1 AS FLOAT) AS DOUBLE), CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):double>
+-- !query 458 output
+0.0
+
+
+-- !query 459
+SELECT pmod(cast(1 as float), cast(1 as decimal(10, 0))) FROM t
+-- !query 459 schema
+struct<pmod(CAST(CAST(1 AS FLOAT) AS DOUBLE), CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):double>
+-- !query 459 output
+0.0
+
+
+-- !query 460
+SELECT pmod(cast(1 as float), cast(1 as decimal(20, 0))) FROM t
+-- !query 460 schema
+struct<pmod(CAST(CAST(1 AS FLOAT) AS DOUBLE), CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):double>
+-- !query 460 output
+0.0
+
+
+-- !query 461
+SELECT pmod(cast(1 as double), cast(1 as decimal(3, 0))) FROM t
+-- !query 461 schema
+struct<pmod(CAST(1 AS DOUBLE), CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):double>
+-- !query 461 output
+0.0
+
+
+-- !query 462
+SELECT pmod(cast(1 as double), cast(1 as decimal(5, 0))) FROM t
+-- !query 462 schema
+struct<pmod(CAST(1 AS DOUBLE), CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):double>
+-- !query 462 output
+0.0
+
+
+-- !query 463
+SELECT pmod(cast(1 as double), cast(1 as decimal(10, 0))) FROM t
+-- !query 463 schema
+struct<pmod(CAST(1 AS DOUBLE), CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):double>
+-- !query 463 output
+0.0
+
+
+-- !query 464
+SELECT pmod(cast(1 as double), cast(1 as decimal(20, 0))) FROM t
+-- !query 464 schema
+struct<pmod(CAST(1 AS DOUBLE), CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):double>
+-- !query 464 output
+0.0
+
+
+-- !query 465
+SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as decimal(3, 0))) FROM t
+-- !query 465 schema
+struct<pmod(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)), CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):decimal(3,0)>
+-- !query 465 output
+0
+
+
+-- !query 466
+SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as decimal(5, 0))) FROM t
+-- !query 466 schema
+struct<pmod(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)), CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):decimal(5,0)>
+-- !query 466 output
+0
+
+
+-- !query 467
+SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as decimal(10, 0))) FROM t
+-- !query 467 schema
+struct<pmod(CAST(1 AS DECIMAL(10,0)), CAST(1 AS DECIMAL(10,0))):decimal(10,0)>
+-- !query 467 output
+0
+
+
+-- !query 468
+SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as decimal(20, 0))) FROM t
+-- !query 468 schema
+struct<pmod(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)), CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(10,0)>
+-- !query 468 output
+0
+
+
+-- !query 469
+SELECT pmod(cast('1' as binary), cast(1 as decimal(3, 0))) FROM t
+-- !query 469 schema
+struct<>
+-- !query 469 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'pmod(CAST('1' AS BINARY), CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in 'pmod(CAST('1' AS BINARY), CAST(1 AS DECIMAL(3,0)))' (binary and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 470
+SELECT pmod(cast('1' as binary), cast(1 as decimal(5, 0))) FROM t
+-- !query 470 schema
+struct<>
+-- !query 470 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'pmod(CAST('1' AS BINARY), CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in 'pmod(CAST('1' AS BINARY), CAST(1 AS DECIMAL(5,0)))' (binary and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 471
+SELECT pmod(cast('1' as binary), cast(1 as decimal(10, 0))) FROM t
+-- !query 471 schema
+struct<>
+-- !query 471 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'pmod(CAST('1' AS BINARY), CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in 'pmod(CAST('1' AS BINARY), CAST(1 AS DECIMAL(10,0)))' (binary and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 472
+SELECT pmod(cast('1' as binary), cast(1 as decimal(20, 0))) FROM t
+-- !query 472 schema
+struct<>
+-- !query 472 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'pmod(CAST('1' AS BINARY), CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in 'pmod(CAST('1' AS BINARY), CAST(1 AS DECIMAL(20,0)))' (binary and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 473
+SELECT pmod(cast('2017-12-11 09:30:00.0' as timestamp), cast(1 as decimal(3, 0))) FROM t
+-- !query 473 schema
+struct<>
+-- !query 473 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'pmod(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP), CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in 'pmod(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP), CAST(1 AS DECIMAL(3,0)))' (timestamp and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 474
+SELECT pmod(cast('2017-12-11 09:30:00.0' as timestamp), cast(1 as decimal(5, 0))) FROM t
+-- !query 474 schema
+struct<>
+-- !query 474 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'pmod(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP), CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in 'pmod(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP), CAST(1 AS DECIMAL(5,0)))' (timestamp and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 475
+SELECT pmod(cast('2017-12-11 09:30:00.0' as timestamp), cast(1 as decimal(10, 0))) FROM t
+-- !query 475 schema
+struct<>
+-- !query 475 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'pmod(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP), CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in 'pmod(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP), CAST(1 AS DECIMAL(10,0)))' (timestamp and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 476
+SELECT pmod(cast('2017-12-11 09:30:00.0' as timestamp), cast(1 as decimal(20, 0))) FROM t
+-- !query 476 schema
+struct<>
+-- !query 476 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'pmod(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP), CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in 'pmod(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP), CAST(1 AS DECIMAL(20,0)))' (timestamp and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 477
+SELECT pmod(cast('2017-12-11 09:30:00' as date), cast(1 as decimal(3, 0))) FROM t
+-- !query 477 schema
+struct<>
+-- !query 477 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'pmod(CAST('2017-12-11 09:30:00' AS DATE), CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in 'pmod(CAST('2017-12-11 09:30:00' AS DATE), CAST(1 AS DECIMAL(3,0)))' (date and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 478
+SELECT pmod(cast('2017-12-11 09:30:00' as date), cast(1 as decimal(5, 0))) FROM t
+-- !query 478 schema
+struct<>
+-- !query 478 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'pmod(CAST('2017-12-11 09:30:00' AS DATE), CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in 'pmod(CAST('2017-12-11 09:30:00' AS DATE), CAST(1 AS DECIMAL(5,0)))' (date and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 479
+SELECT pmod(cast('2017-12-11 09:30:00' as date), cast(1 as decimal(10, 0))) FROM t
+-- !query 479 schema
+struct<>
+-- !query 479 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'pmod(CAST('2017-12-11 09:30:00' AS DATE), CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in 'pmod(CAST('2017-12-11 09:30:00' AS DATE), CAST(1 AS DECIMAL(10,0)))' (date and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 480
+SELECT pmod(cast('2017-12-11 09:30:00' as date), cast(1 as decimal(20, 0))) FROM t
+-- !query 480 schema
+struct<>
+-- !query 480 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'pmod(CAST('2017-12-11 09:30:00' AS DATE), CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in 'pmod(CAST('2017-12-11 09:30:00' AS DATE), CAST(1 AS DECIMAL(20,0)))' (date and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 481
+SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as tinyint)) FROM t
+-- !query 481 schema
+struct<pmod(CAST(1 AS DECIMAL(3,0)), CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0))):decimal(3,0)>
+-- !query 481 output
+0
+
+
+-- !query 482
+SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as tinyint)) FROM t
+-- !query 482 schema
+struct<pmod(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0)), CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0))):decimal(3,0)>
+-- !query 482 output
+0
+
+
+-- !query 483
+SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as tinyint)) FROM t
+-- !query 483 schema
+struct<pmod(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)), CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0))):decimal(3,0)>
+-- !query 483 output
+0
+
+
+-- !query 484
+SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as tinyint)) FROM t
+-- !query 484 schema
+struct<pmod(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)), CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0))):decimal(3,0)>
+-- !query 484 output
+0
+
+
+-- !query 485
+SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as smallint)) FROM t
+-- !query 485 schema
+struct<pmod(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0)), CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0))):decimal(3,0)>
+-- !query 485 output
+0
+
+
+-- !query 486
+SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as smallint)) FROM t
+-- !query 486 schema
+struct<pmod(CAST(1 AS DECIMAL(5,0)), CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0))):decimal(5,0)>
+-- !query 486 output
+0
+
+
+-- !query 487
+SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as smallint)) FROM t
+-- !query 487 schema
+struct<pmod(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)), CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0))):decimal(5,0)>
+-- !query 487 output
+0
+
+
+-- !query 488
+SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as smallint)) FROM t
+-- !query 488 schema
+struct<pmod(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)), CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0))):decimal(5,0)>
+-- !query 488 output
+0
+
+
+-- !query 489
+SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as int)) FROM t
+-- !query 489 schema
+struct<pmod(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)), CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(3,0)>
+-- !query 489 output
+0
+
+
+-- !query 490
+SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as int)) FROM t
+-- !query 490 schema
+struct<pmod(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)), CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(5,0)>
+-- !query 490 output
+0
+
+
+-- !query 491
+SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as int)) FROM t
+-- !query 491 schema
+struct<pmod(CAST(1 AS DECIMAL(10,0)), CAST(CAST(1 AS INT) AS DECIMAL(10,0))):decimal(10,0)>
+-- !query 491 output
+0
+
+
+-- !query 492
+SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as int)) FROM t
+-- !query 492 schema
+struct<pmod(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)), CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0))):decimal(10,0)>
+-- !query 492 output
+0
+
+
+-- !query 493
+SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as bigint)) FROM t
+-- !query 493 schema
+struct<pmod(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0)), CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(3,0)>
+-- !query 493 output
+0
+
+
+-- !query 494
+SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as bigint)) FROM t
+-- !query 494 schema
+struct<pmod(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0)), CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(5,0)>
+-- !query 494 output
+0
+
+
+-- !query 495
+SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as bigint)) FROM t
+-- !query 495 schema
+struct<pmod(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)), CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(10,0)>
+-- !query 495 output
+0
+
+
+-- !query 496
+SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as bigint)) FROM t
+-- !query 496 schema
+struct<pmod(CAST(1 AS DECIMAL(20,0)), CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):decimal(20,0)>
+-- !query 496 output
+0
+
+
+-- !query 497
+SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as float)) FROM t
+-- !query 497 schema
+struct<pmod(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE), CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 497 output
+0.0
+
+
+-- !query 498
+SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as float)) FROM t
+-- !query 498 schema
+struct<pmod(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE), CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 498 output
+0.0
+
+
+-- !query 499
+SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as float)) FROM t
+-- !query 499 schema
+struct<pmod(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE), CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 499 output
+0.0
+
+
+-- !query 500
+SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as float)) FROM t
+-- !query 500 schema
+struct<pmod(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE), CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 500 output
+0.0
+
+
+-- !query 501
+SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as double)) FROM t
+-- !query 501 schema
+struct<pmod(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE), CAST(1 AS DOUBLE)):double>
+-- !query 501 output
+0.0
+
+
+-- !query 502
+SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as double)) FROM t
+-- !query 502 schema
+struct<pmod(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE), CAST(1 AS DOUBLE)):double>
+-- !query 502 output
+0.0
+
+
+-- !query 503
+SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as double)) FROM t
+-- !query 503 schema
+struct<pmod(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE), CAST(1 AS DOUBLE)):double>
+-- !query 503 output
+0.0
+
+
+-- !query 504
+SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as double)) FROM t
+-- !query 504 schema
+struct<pmod(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE), CAST(1 AS DOUBLE)):double>
+-- !query 504 output
+0.0
+
+
+-- !query 505
+SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as decimal(10, 0))) FROM t
+-- !query 505 schema
+struct<pmod(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)), CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(3,0)>
+-- !query 505 output
+0
+
+
+-- !query 506
+SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as decimal(10, 0))) FROM t
+-- !query 506 schema
+struct<pmod(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)), CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(5,0)>
+-- !query 506 output
+0
+
+
+-- !query 507
+SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as decimal(10, 0))) FROM t
+-- !query 507 schema
+struct<pmod(CAST(1 AS DECIMAL(10,0)), CAST(1 AS DECIMAL(10,0))):decimal(10,0)>
+-- !query 507 output
+0
+
+
+-- !query 508
+SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as decimal(10, 0))) FROM t
+-- !query 508 schema
+struct<pmod(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)), CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):decimal(10,0)>
+-- !query 508 output
+0
+
+
+-- !query 509
+SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as string)) FROM t
+-- !query 509 schema
+struct<pmod(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE), CAST(CAST(1 AS STRING) AS DOUBLE)):double>
+-- !query 509 output
+0.0
+
+
+-- !query 510
+SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as string)) FROM t
+-- !query 510 schema
+struct<pmod(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE), CAST(CAST(1 AS STRING) AS DOUBLE)):double>
+-- !query 510 output
+0.0
+
+
+-- !query 511
+SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as string)) FROM t
+-- !query 511 schema
+struct<pmod(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE), CAST(CAST(1 AS STRING) AS DOUBLE)):double>
+-- !query 511 output
+0.0
+
+
+-- !query 512
+SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as string)) FROM t
+-- !query 512 schema
+struct<pmod(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE), CAST(CAST(1 AS STRING) AS DOUBLE)):double>
+-- !query 512 output
+0.0
+
+
+-- !query 513
+SELECT pmod(cast(1 as decimal(3, 0)) , cast('1' as binary)) FROM t
+-- !query 513 schema
+struct<>
+-- !query 513 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'pmod(CAST(1 AS DECIMAL(3,0)), CAST('1' AS BINARY))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(3,0)), CAST('1' AS BINARY))' (decimal(3,0) and binary).; line 1 pos 7
+
+
+-- !query 514
+SELECT pmod(cast(1 as decimal(5, 0)) , cast('1' as binary)) FROM t
+-- !query 514 schema
+struct<>
+-- !query 514 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'pmod(CAST(1 AS DECIMAL(5,0)), CAST('1' AS BINARY))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(5,0)), CAST('1' AS BINARY))' (decimal(5,0) and binary).; line 1 pos 7
+
+
+-- !query 515
+SELECT pmod(cast(1 as decimal(10, 0)), cast('1' as binary)) FROM t
+-- !query 515 schema
+struct<>
+-- !query 515 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'pmod(CAST(1 AS DECIMAL(10,0)), CAST('1' AS BINARY))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(10,0)), CAST('1' AS BINARY))' (decimal(10,0) and binary).; line 1 pos 7
+
+
+-- !query 516
+SELECT pmod(cast(1 as decimal(20, 0)), cast('1' as binary)) FROM t
+-- !query 516 schema
+struct<>
+-- !query 516 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'pmod(CAST(1 AS DECIMAL(20,0)), CAST('1' AS BINARY))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(20,0)), CAST('1' AS BINARY))' (decimal(20,0) and binary).; line 1 pos 7
+
+
+-- !query 517
+SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as boolean)) FROM t
+-- !query 517 schema
+struct<>
+-- !query 517 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'pmod(CAST(1 AS DECIMAL(3,0)), CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(3,0)), CAST(1 AS BOOLEAN))' (decimal(3,0) and boolean).; line 1 pos 7
+
+
+-- !query 518
+SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as boolean)) FROM t
+-- !query 518 schema
+struct<>
+-- !query 518 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'pmod(CAST(1 AS DECIMAL(5,0)), CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(5,0)), CAST(1 AS BOOLEAN))' (decimal(5,0) and boolean).; line 1 pos 7
+
+
+-- !query 519
+SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as boolean)) FROM t
+-- !query 519 schema
+struct<>
+-- !query 519 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'pmod(CAST(1 AS DECIMAL(10,0)), CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(10,0)), CAST(1 AS BOOLEAN))' (decimal(10,0) and boolean).; line 1 pos 7
+
+
+-- !query 520
+SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as boolean)) FROM t
+-- !query 520 schema
+struct<>
+-- !query 520 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'pmod(CAST(1 AS DECIMAL(20,0)), CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(20,0)), CAST(1 AS BOOLEAN))' (decimal(20,0) and boolean).; line 1 pos 7
+
+
+-- !query 521
+SELECT pmod(cast(1 as decimal(3, 0)) , cast('2017-12-11 09:30:00.0' as timestamp)) FROM t
+-- !query 521 schema
+struct<>
+-- !query 521 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'pmod(CAST(1 AS DECIMAL(3,0)), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(3,0)), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(3,0) and timestamp).; line 1 pos 7
+
+
+-- !query 522
+SELECT pmod(cast(1 as decimal(5, 0)) , cast('2017-12-11 09:30:00.0' as timestamp)) FROM t
+-- !query 522 schema
+struct<>
+-- !query 522 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'pmod(CAST(1 AS DECIMAL(5,0)), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(5,0)), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(5,0) and timestamp).; line 1 pos 7
+
+
+-- !query 523
+SELECT pmod(cast(1 as decimal(10, 0)), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t
+-- !query 523 schema
+struct<>
+-- !query 523 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'pmod(CAST(1 AS DECIMAL(10,0)), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(10,0)), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(10,0) and timestamp).; line 1 pos 7
+
+
+-- !query 524
+SELECT pmod(cast(1 as decimal(20, 0)), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t
+-- !query 524 schema
+struct<>
+-- !query 524 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'pmod(CAST(1 AS DECIMAL(20,0)), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(20,0)), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(20,0) and timestamp).; line 1 pos 7
+
+
+-- !query 525
+SELECT pmod(cast(1 as decimal(3, 0)) , cast('2017-12-11 09:30:00' as date)) FROM t
+-- !query 525 schema
+struct<>
+-- !query 525 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'pmod(CAST(1 AS DECIMAL(3,0)), CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(3,0)), CAST('2017-12-11 09:30:00' AS DATE))' (decimal(3,0) and date).; line 1 pos 7
+
+
+-- !query 526
+SELECT pmod(cast(1 as decimal(5, 0)) , cast('2017-12-11 09:30:00' as date)) FROM t
+-- !query 526 schema
+struct<>
+-- !query 526 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'pmod(CAST(1 AS DECIMAL(5,0)), CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(5,0)), CAST('2017-12-11 09:30:00' AS DATE))' (decimal(5,0) and date).; line 1 pos 7
+
+
+-- !query 527
+SELECT pmod(cast(1 as decimal(10, 0)), cast('2017-12-11 09:30:00' as date)) FROM t
+-- !query 527 schema
+struct<>
+-- !query 527 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'pmod(CAST(1 AS DECIMAL(10,0)), CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(10,0)), CAST('2017-12-11 09:30:00' AS DATE))' (decimal(10,0) and date).; line 1 pos 7
+
+
+-- !query 528
+SELECT pmod(cast(1 as decimal(20, 0)), cast('2017-12-11 09:30:00' as date)) FROM t
+-- !query 528 schema
+struct<>
+-- !query 528 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'pmod(CAST(1 AS DECIMAL(20,0)), CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(20,0)), CAST('2017-12-11 09:30:00' AS DATE))' (decimal(20,0) and date).; line 1 pos 7
+
+
+-- !query 529
+SELECT cast(1 as tinyint) = cast(1 as decimal(3, 0)) FROM t
+-- !query 529 schema
+struct<(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) = CAST(1 AS DECIMAL(3,0))):boolean>
+-- !query 529 output
+true
+
+
+-- !query 530
+SELECT cast(1 as tinyint) = cast(1 as decimal(5, 0)) FROM t
+-- !query 530 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0)) = CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0))):boolean>
+-- !query 530 output
+true
+
+
+-- !query 531
+SELECT cast(1 as tinyint) = cast(1 as decimal(10, 0)) FROM t
+-- !query 531 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 531 output
+true
+
+
+-- !query 532
+SELECT cast(1 as tinyint) = cast(1 as decimal(20, 0)) FROM t
+-- !query 532 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 532 output
+true
+
+
+-- !query 533
+SELECT cast(1 as smallint) = cast(1 as decimal(3, 0)) FROM t
+-- !query 533 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0)) = CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0))):boolean>
+-- !query 533 output
+true
+
+
+-- !query 534
+SELECT cast(1 as smallint) = cast(1 as decimal(5, 0)) FROM t
+-- !query 534 schema
+struct<(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) = CAST(1 AS DECIMAL(5,0))):boolean>
+-- !query 534 output
+true
+
+
+-- !query 535
+SELECT cast(1 as smallint) = cast(1 as decimal(10, 0)) FROM t
+-- !query 535 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 535 output
+true
+
+
+-- !query 536
+SELECT cast(1 as smallint) = cast(1 as decimal(20, 0)) FROM t
+-- !query 536 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 536 output
+true
+
+
+-- !query 537
+SELECT cast(1 as int) = cast(1 as decimal(3, 0)) FROM t
+-- !query 537 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean>
+-- !query 537 output
+true
+
+
+-- !query 538
+SELECT cast(1 as int) = cast(1 as decimal(5, 0)) FROM t
+-- !query 538 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean>
+-- !query 538 output
+true
+
+
+-- !query 539
+SELECT cast(1 as int) = cast(1 as decimal(10, 0)) FROM t
+-- !query 539 schema
+struct<(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) = CAST(1 AS DECIMAL(10,0))):boolean>
+-- !query 539 output
+true
+
+
+-- !query 540
+SELECT cast(1 as int) = cast(1 as decimal(20, 0)) FROM t
+-- !query 540 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 540 output
+true
+
+
+-- !query 541
+SELECT cast(1 as bigint) = cast(1 as decimal(3, 0)) FROM t
+-- !query 541 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0))):boolean>
+-- !query 541 output
+true
+
+
+-- !query 542
+SELECT cast(1 as bigint) = cast(1 as decimal(5, 0)) FROM t
+-- !query 542 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0))):boolean>
+-- !query 542 output
+true
+
+
+-- !query 543
+SELECT cast(1 as bigint) = cast(1 as decimal(10, 0)) FROM t
+-- !query 543 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean>
+-- !query 543 output
+true
+
+
+-- !query 544
+SELECT cast(1 as bigint) = cast(1 as decimal(20, 0)) FROM t
+-- !query 544 schema
+struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) = CAST(1 AS DECIMAL(20,0))):boolean>
+-- !query 544 output
+true
+
+
+-- !query 545
+SELECT cast(1 as float) = cast(1 as decimal(3, 0)) FROM t
+-- !query 545 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) = CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):boolean>
+-- !query 545 output
+true
+
+
+-- !query 546
+SELECT cast(1 as float) = cast(1 as decimal(5, 0)) FROM t
+-- !query 546 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) = CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):boolean>
+-- !query 546 output
+true
+
+
+-- !query 547
+SELECT cast(1 as float) = cast(1 as decimal(10, 0)) FROM t
+-- !query 547 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) = CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean>
+-- !query 547 output
+true
+
+
+-- !query 548
+SELECT cast(1 as float) = cast(1 as decimal(20, 0)) FROM t
+-- !query 548 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) = CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):boolean>
+-- !query 548 output
+true
+
+
+-- !query 549
+SELECT cast(1 as double) = cast(1 as decimal(3, 0)) FROM t
+-- !query 549 schema
+struct<(CAST(1 AS DOUBLE) = CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):boolean>
+-- !query 549 output
+true
+
+
+-- !query 550
+SELECT cast(1 as double) = cast(1 as decimal(5, 0)) FROM t
+-- !query 550 schema
+struct<(CAST(1 AS DOUBLE) = CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):boolean>
+-- !query 550 output
+true
+
+
+-- !query 551
+SELECT cast(1 as double) = cast(1 as decimal(10, 0)) FROM t
+-- !query 551 schema
+struct<(CAST(1 AS DOUBLE) = CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean>
+-- !query 551 output
+true
+
+
+-- !query 552
+SELECT cast(1 as double) = cast(1 as decimal(20, 0)) FROM t
+-- !query 552 schema
+struct<(CAST(1 AS DOUBLE) = CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):boolean>
+-- !query 552 output
+true
+
+
+-- !query 553
+SELECT cast(1 as decimal(10, 0)) = cast(1 as decimal(3, 0)) FROM t
+-- !query 553 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean>
+-- !query 553 output
+true
+
+
+-- !query 554
+SELECT cast(1 as decimal(10, 0)) = cast(1 as decimal(5, 0)) FROM t
+-- !query 554 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean>
+-- !query 554 output
+true
+
+
+-- !query 555
+SELECT cast(1 as decimal(10, 0)) = cast(1 as decimal(10, 0)) FROM t
+-- !query 555 schema
+struct<(CAST(1 AS DECIMAL(10,0)) = CAST(1 AS DECIMAL(10,0))):boolean>
+-- !query 555 output
+true
+
+
+-- !query 556
+SELECT cast(1 as decimal(10, 0)) = cast(1 as decimal(20, 0)) FROM t
+-- !query 556 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 556 output
+true
+
+
+-- !query 557
+SELECT cast('1' as binary) = cast(1 as decimal(3, 0)) FROM t
+-- !query 557 schema
+struct<>
+-- !query 557 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(3,0)))' (binary and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 558
+SELECT cast('1' as binary) = cast(1 as decimal(5, 0)) FROM t
+-- !query 558 schema
+struct<>
+-- !query 558 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(5,0)))' (binary and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 559
+SELECT cast('1' as binary) = cast(1 as decimal(10, 0)) FROM t
+-- !query 559 schema
+struct<>
+-- !query 559 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(10,0)))' (binary and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 560
+SELECT cast('1' as binary) = cast(1 as decimal(20, 0)) FROM t
+-- !query 560 schema
+struct<>
+-- !query 560 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(20,0)))' (binary and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 561
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) = cast(1 as decimal(3, 0)) FROM t
+-- !query 561 schema
+struct<>
+-- !query 561 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(3,0)))' (timestamp and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 562
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) = cast(1 as decimal(5, 0)) FROM t
+-- !query 562 schema
+struct<>
+-- !query 562 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(5,0)))' (timestamp and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 563
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) = cast(1 as decimal(10, 0)) FROM t
+-- !query 563 schema
+struct<>
+-- !query 563 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(10,0)))' (timestamp and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 564
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) = cast(1 as decimal(20, 0)) FROM t
+-- !query 564 schema
+struct<>
+-- !query 564 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(20,0)))' (timestamp and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 565
+SELECT cast('2017-12-11 09:30:00' as date) = cast(1 as decimal(3, 0)) FROM t
+-- !query 565 schema
+struct<>
+-- !query 565 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(3,0)))' (date and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 566
+SELECT cast('2017-12-11 09:30:00' as date) = cast(1 as decimal(5, 0)) FROM t
+-- !query 566 schema
+struct<>
+-- !query 566 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(5,0)))' (date and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 567
+SELECT cast('2017-12-11 09:30:00' as date) = cast(1 as decimal(10, 0)) FROM t
+-- !query 567 schema
+struct<>
+-- !query 567 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(10,0)))' (date and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 568
+SELECT cast('2017-12-11 09:30:00' as date) = cast(1 as decimal(20, 0)) FROM t
+-- !query 568 schema
+struct<>
+-- !query 568 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(20,0)))' (date and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 569
+SELECT cast(1 as decimal(3, 0)) = cast(1 as tinyint) FROM t
+-- !query 569 schema
+struct<(CAST(1 AS DECIMAL(3,0)) = CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0))):boolean>
+-- !query 569 output
+true
+
+
+-- !query 570
+SELECT cast(1 as decimal(5, 0)) = cast(1 as tinyint) FROM t
+-- !query 570 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0)) = CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0))):boolean>
+-- !query 570 output
+true
+
+
+-- !query 571
+SELECT cast(1 as decimal(10, 0)) = cast(1 as tinyint) FROM t
+-- !query 571 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) = CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean>
+-- !query 571 output
+true
+
+
+-- !query 572
+SELECT cast(1 as decimal(20, 0)) = cast(1 as tinyint) FROM t
+-- !query 572 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) = CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0))):boolean>
+-- !query 572 output
+true
+
+
+-- !query 573
+SELECT cast(1 as decimal(3, 0)) = cast(1 as smallint) FROM t
+-- !query 573 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0)) = CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0))):boolean>
+-- !query 573 output
+true
+
+
+-- !query 574
+SELECT cast(1 as decimal(5, 0)) = cast(1 as smallint) FROM t
+-- !query 574 schema
+struct<(CAST(1 AS DECIMAL(5,0)) = CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0))):boolean>
+-- !query 574 output
+true
+
+
+-- !query 575
+SELECT cast(1 as decimal(10, 0)) = cast(1 as smallint) FROM t
+-- !query 575 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) = CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean>
+-- !query 575 output
+true
+
+
+-- !query 576
+SELECT cast(1 as decimal(20, 0)) = cast(1 as smallint) FROM t
+-- !query 576 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) = CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0))):boolean>
+-- !query 576 output
+true
+
+
+-- !query 577
+SELECT cast(1 as decimal(3, 0)) = cast(1 as int) FROM t
+-- !query 577 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) = CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 577 output
+true
+
+
+-- !query 578
+SELECT cast(1 as decimal(5, 0)) = cast(1 as int) FROM t
+-- !query 578 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) = CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 578 output
+true
+
+
+-- !query 579
+SELECT cast(1 as decimal(10, 0)) = cast(1 as int) FROM t
+-- !query 579 schema
+struct<(CAST(1 AS DECIMAL(10,0)) = CAST(CAST(1 AS INT) AS DECIMAL(10,0))):boolean>
+-- !query 579 output
+true
+
+
+-- !query 580
+SELECT cast(1 as decimal(20, 0)) = cast(1 as int) FROM t
+-- !query 580 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) = CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean>
+-- !query 580 output
+true
+
+
+-- !query 581
+SELECT cast(1 as decimal(3, 0)) = cast(1 as bigint) FROM t
+-- !query 581 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0)) = CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 581 output
+true
+
+
+-- !query 582
+SELECT cast(1 as decimal(5, 0)) = cast(1 as bigint) FROM t
+-- !query 582 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0)) = CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 582 output
+true
+
+
+-- !query 583
+SELECT cast(1 as decimal(10, 0)) = cast(1 as bigint) FROM t
+-- !query 583 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) = CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 583 output
+true
+
+
+-- !query 584
+SELECT cast(1 as decimal(20, 0)) = cast(1 as bigint) FROM t
+-- !query 584 schema
+struct<(CAST(1 AS DECIMAL(20,0)) = CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):boolean>
+-- !query 584 output
+true
+
+
+-- !query 585
+SELECT cast(1 as decimal(3, 0)) = cast(1 as float) FROM t
+-- !query 585 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) = CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean>
+-- !query 585 output
+true
+
+
+-- !query 586
+SELECT cast(1 as decimal(5, 0)) = cast(1 as float) FROM t
+-- !query 586 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) = CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean>
+-- !query 586 output
+true
+
+
+-- !query 587
+SELECT cast(1 as decimal(10, 0)) = cast(1 as float) FROM t
+-- !query 587 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) = CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean>
+-- !query 587 output
+true
+
+
+-- !query 588
+SELECT cast(1 as decimal(20, 0)) = cast(1 as float) FROM t
+-- !query 588 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) = CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean>
+-- !query 588 output
+true
+
+
+-- !query 589
+SELECT cast(1 as decimal(3, 0)) = cast(1 as double) FROM t
+-- !query 589 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) = CAST(1 AS DOUBLE)):boolean>
+-- !query 589 output
+true
+
+
+-- !query 590
+SELECT cast(1 as decimal(5, 0)) = cast(1 as double) FROM t
+-- !query 590 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) = CAST(1 AS DOUBLE)):boolean>
+-- !query 590 output
+true
+
+
+-- !query 591
+SELECT cast(1 as decimal(10, 0)) = cast(1 as double) FROM t
+-- !query 591 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) = CAST(1 AS DOUBLE)):boolean>
+-- !query 591 output
+true
+
+
+-- !query 592
+SELECT cast(1 as decimal(20, 0)) = cast(1 as double) FROM t
+-- !query 592 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) = CAST(1 AS DOUBLE)):boolean>
+-- !query 592 output
+true
+
+
+-- !query 593
+SELECT cast(1 as decimal(3, 0)) = cast(1 as decimal(10, 0)) FROM t
+-- !query 593 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 593 output
+true
+
+
+-- !query 594
+SELECT cast(1 as decimal(5, 0)) = cast(1 as decimal(10, 0)) FROM t
+-- !query 594 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 594 output
+true
+
+
+-- !query 595
+SELECT cast(1 as decimal(10, 0)) = cast(1 as decimal(10, 0)) FROM t
+-- !query 595 schema
+struct<(CAST(1 AS DECIMAL(10,0)) = CAST(1 AS DECIMAL(10,0))):boolean>
+-- !query 595 output
+true
+
+
+-- !query 596
+SELECT cast(1 as decimal(20, 0)) = cast(1 as decimal(10, 0)) FROM t
+-- !query 596 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean>
+-- !query 596 output
+true
+
+
+-- !query 597
+SELECT cast(1 as decimal(3, 0)) = cast(1 as string) FROM t
+-- !query 597 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) = CAST(CAST(1 AS STRING) AS DOUBLE)):boolean>
+-- !query 597 output
+true
+
+
+-- !query 598
+SELECT cast(1 as decimal(5, 0)) = cast(1 as string) FROM t
+-- !query 598 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) = CAST(CAST(1 AS STRING) AS DOUBLE)):boolean>
+-- !query 598 output
+true
+
+
+-- !query 599
+SELECT cast(1 as decimal(10, 0)) = cast(1 as string) FROM t
+-- !query 599 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) = CAST(CAST(1 AS STRING) AS DOUBLE)):boolean>
+-- !query 599 output
+true
+
+
+-- !query 600
+SELECT cast(1 as decimal(20, 0)) = cast(1 as string) FROM t
+-- !query 600 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) = CAST(CAST(1 AS STRING) AS DOUBLE)):boolean>
+-- !query 600 output
+true
+
+
+-- !query 601
+SELECT cast(1 as decimal(3, 0)) = cast('1' as binary) FROM t
+-- !query 601 schema
+struct<>
+-- !query 601 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) = CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) = CAST('1' AS BINARY))' (decimal(3,0) and binary).; line 1 pos 7
+
+
+-- !query 602
+SELECT cast(1 as decimal(5, 0)) = cast('1' as binary) FROM t
+-- !query 602 schema
+struct<>
+-- !query 602 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) = CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) = CAST('1' AS BINARY))' (decimal(5,0) and binary).; line 1 pos 7
+
+
+-- !query 603
+SELECT cast(1 as decimal(10, 0)) = cast('1' as binary) FROM t
+-- !query 603 schema
+struct<>
+-- !query 603 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) = CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) = CAST('1' AS BINARY))' (decimal(10,0) and binary).; line 1 pos 7
+
+
+-- !query 604
+SELECT cast(1 as decimal(20, 0)) = cast('1' as binary) FROM t
+-- !query 604 schema
+struct<>
+-- !query 604 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) = CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) = CAST('1' AS BINARY))' (decimal(20,0) and binary).; line 1 pos 7
+
+
+-- !query 605
+SELECT cast(1 as decimal(3, 0)) = cast(1 as boolean) FROM t
+-- !query 605 schema
+struct<(CAST(1 AS DECIMAL(3,0)) = CAST(CAST(1 AS BOOLEAN) AS DECIMAL(3,0))):boolean>
+-- !query 605 output
+true
+
+
+-- !query 606
+SELECT cast(1 as decimal(5, 0)) = cast(1 as boolean) FROM t
+-- !query 606 schema
+struct<(CAST(1 AS DECIMAL(5,0)) = CAST(CAST(1 AS BOOLEAN) AS DECIMAL(5,0))):boolean>
+-- !query 606 output
+true
+
+
+-- !query 607
+SELECT cast(1 as decimal(10, 0)) = cast(1 as boolean) FROM t
+-- !query 607 schema
+struct<(CAST(1 AS DECIMAL(10,0)) = CAST(CAST(1 AS BOOLEAN) AS DECIMAL(10,0))):boolean>
+-- !query 607 output
+true
+
+
+-- !query 608
+SELECT cast(1 as decimal(20, 0)) = cast(1 as boolean) FROM t
+-- !query 608 schema
+struct<(CAST(1 AS DECIMAL(20,0)) = CAST(CAST(1 AS BOOLEAN) AS DECIMAL(20,0))):boolean>
+-- !query 608 output
+true
+
+
+-- !query 609
+SELECT cast(1 as decimal(3, 0)) = cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 609 schema
+struct<>
+-- !query 609 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(3,0) and timestamp).; line 1 pos 7
+
+
+-- !query 610
+SELECT cast(1 as decimal(5, 0)) = cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 610 schema
+struct<>
+-- !query 610 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(5,0) and timestamp).; line 1 pos 7
+
+
+-- !query 611
+SELECT cast(1 as decimal(10, 0)) = cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 611 schema
+struct<>
+-- !query 611 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(10,0) and timestamp).; line 1 pos 7
+
+
+-- !query 612
+SELECT cast(1 as decimal(20, 0)) = cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 612 schema
+struct<>
+-- !query 612 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(20,0) and timestamp).; line 1 pos 7
+
+
+-- !query 613
+SELECT cast(1 as decimal(3, 0)) = cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 613 schema
+struct<>
+-- !query 613 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) = CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) = CAST('2017-12-11 09:30:00' AS DATE))' (decimal(3,0) and date).; line 1 pos 7
+
+
+-- !query 614
+SELECT cast(1 as decimal(5, 0)) = cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 614 schema
+struct<>
+-- !query 614 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) = CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) = CAST('2017-12-11 09:30:00' AS DATE))' (decimal(5,0) and date).; line 1 pos 7
+
+
+-- !query 615
+SELECT cast(1 as decimal(10, 0)) = cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 615 schema
+struct<>
+-- !query 615 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) = CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) = CAST('2017-12-11 09:30:00' AS DATE))' (decimal(10,0) and date).; line 1 pos 7
+
+
+-- !query 616
+SELECT cast(1 as decimal(20, 0)) = cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 616 schema
+struct<>
+-- !query 616 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) = CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) = CAST('2017-12-11 09:30:00' AS DATE))' (decimal(20,0) and date).; line 1 pos 7
+
+
+-- !query 617
+SELECT cast(1 as tinyint) <=> cast(1 as decimal(3, 0)) FROM t
+-- !query 617 schema
+struct<(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) <=> CAST(1 AS DECIMAL(3,0))):boolean>
+-- !query 617 output
+true
+
+
+-- !query 618
+SELECT cast(1 as tinyint) <=> cast(1 as decimal(5, 0)) FROM t
+-- !query 618 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0)) <=> CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0))):boolean>
+-- !query 618 output
+true
+
+
+-- !query 619
+SELECT cast(1 as tinyint) <=> cast(1 as decimal(10, 0)) FROM t
+-- !query 619 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0)) <=> CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 619 output
+true
+
+
+-- !query 620
+SELECT cast(1 as tinyint) <=> cast(1 as decimal(20, 0)) FROM t
+-- !query 620 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0)) <=> CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 620 output
+true
+
+
+-- !query 621
+SELECT cast(1 as smallint) <=> cast(1 as decimal(3, 0)) FROM t
+-- !query 621 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0)) <=> CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0))):boolean>
+-- !query 621 output
+true
+
+
+-- !query 622
+SELECT cast(1 as smallint) <=> cast(1 as decimal(5, 0)) FROM t
+-- !query 622 schema
+struct<(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) <=> CAST(1 AS DECIMAL(5,0))):boolean>
+-- !query 622 output
+true
+
+
+-- !query 623
+SELECT cast(1 as smallint) <=> cast(1 as decimal(10, 0)) FROM t
+-- !query 623 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0)) <=> CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 623 output
+true
+
+
+-- !query 624
+SELECT cast(1 as smallint) <=> cast(1 as decimal(20, 0)) FROM t
+-- !query 624 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0)) <=> CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 624 output
+true
+
+
+-- !query 625
+SELECT cast(1 as int) <=> cast(1 as decimal(3, 0)) FROM t
+-- !query 625 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) <=> CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean>
+-- !query 625 output
+true
+
+
+-- !query 626
+SELECT cast(1 as int) <=> cast(1 as decimal(5, 0)) FROM t
+-- !query 626 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) <=> CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean>
+-- !query 626 output
+true
+
+
+-- !query 627
+SELECT cast(1 as int) <=> cast(1 as decimal(10, 0)) FROM t
+-- !query 627 schema
+struct<(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) <=> CAST(1 AS DECIMAL(10,0))):boolean>
+-- !query 627 output
+true
+
+
+-- !query 628
+SELECT cast(1 as int) <=> cast(1 as decimal(20, 0)) FROM t
+-- !query 628 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0)) <=> CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 628 output
+true
+
+
+-- !query 629
+SELECT cast(1 as bigint) <=> cast(1 as decimal(3, 0)) FROM t
+-- !query 629 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) <=> CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0))):boolean>
+-- !query 629 output
+true
+
+
+-- !query 630
+SELECT cast(1 as bigint) <=> cast(1 as decimal(5, 0)) FROM t
+-- !query 630 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) <=> CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0))):boolean>
+-- !query 630 output
+true
+
+
+-- !query 631
+SELECT cast(1 as bigint) <=> cast(1 as decimal(10, 0)) FROM t
+-- !query 631 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) <=> CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean>
+-- !query 631 output
+true
+
+
+-- !query 632
+SELECT cast(1 as bigint) <=> cast(1 as decimal(20, 0)) FROM t
+-- !query 632 schema
+struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) <=> CAST(1 AS DECIMAL(20,0))):boolean>
+-- !query 632 output
+true
+
+
+-- !query 633
+SELECT cast(1 as float) <=> cast(1 as decimal(3, 0)) FROM t
+-- !query 633 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) <=> CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):boolean>
+-- !query 633 output
+true
+
+
+-- !query 634
+SELECT cast(1 as float) <=> cast(1 as decimal(5, 0)) FROM t
+-- !query 634 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) <=> CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):boolean>
+-- !query 634 output
+true
+
+
+-- !query 635
+SELECT cast(1 as float) <=> cast(1 as decimal(10, 0)) FROM t
+-- !query 635 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) <=> CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean>
+-- !query 635 output
+true
+
+
+-- !query 636
+SELECT cast(1 as float) <=> cast(1 as decimal(20, 0)) FROM t
+-- !query 636 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) <=> CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):boolean>
+-- !query 636 output
+true
+
+
+-- !query 637
+SELECT cast(1 as double) <=> cast(1 as decimal(3, 0)) FROM t
+-- !query 637 schema
+struct<(CAST(1 AS DOUBLE) <=> CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):boolean>
+-- !query 637 output
+true
+
+
+-- !query 638
+SELECT cast(1 as double) <=> cast(1 as decimal(5, 0)) FROM t
+-- !query 638 schema
+struct<(CAST(1 AS DOUBLE) <=> CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):boolean>
+-- !query 638 output
+true
+
+
+-- !query 639
+SELECT cast(1 as double) <=> cast(1 as decimal(10, 0)) FROM t
+-- !query 639 schema
+struct<(CAST(1 AS DOUBLE) <=> CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean>
+-- !query 639 output
+true
+
+
+-- !query 640
+SELECT cast(1 as double) <=> cast(1 as decimal(20, 0)) FROM t
+-- !query 640 schema
+struct<(CAST(1 AS DOUBLE) <=> CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):boolean>
+-- !query 640 output
+true
+
+
+-- !query 641
+SELECT cast(1 as decimal(10, 0)) <=> cast(1 as decimal(3, 0)) FROM t
+-- !query 641 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) <=> CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean>
+-- !query 641 output
+true
+
+
+-- !query 642
+SELECT cast(1 as decimal(10, 0)) <=> cast(1 as decimal(5, 0)) FROM t
+-- !query 642 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) <=> CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean>
+-- !query 642 output
+true
+
+
+-- !query 643
+SELECT cast(1 as decimal(10, 0)) <=> cast(1 as decimal(10, 0)) FROM t
+-- !query 643 schema
+struct<(CAST(1 AS DECIMAL(10,0)) <=> CAST(1 AS DECIMAL(10,0))):boolean>
+-- !query 643 output
+true
+
+
+-- !query 644
+SELECT cast(1 as decimal(10, 0)) <=> cast(1 as decimal(20, 0)) FROM t
+-- !query 644 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) <=> CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 644 output
+true
+
+
+-- !query 645
+SELECT cast('1' as binary) <=> cast(1 as decimal(3, 0)) FROM t
+-- !query 645 schema
+struct<>
+-- !query 645 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) <=> CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) <=> CAST(1 AS DECIMAL(3,0)))' (binary and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 646
+SELECT cast('1' as binary) <=> cast(1 as decimal(5, 0)) FROM t
+-- !query 646 schema
+struct<>
+-- !query 646 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) <=> CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) <=> CAST(1 AS DECIMAL(5,0)))' (binary and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 647
+SELECT cast('1' as binary) <=> cast(1 as decimal(10, 0)) FROM t
+-- !query 647 schema
+struct<>
+-- !query 647 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) <=> CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) <=> CAST(1 AS DECIMAL(10,0)))' (binary and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 648
+SELECT cast('1' as binary) <=> cast(1 as decimal(20, 0)) FROM t
+-- !query 648 schema
+struct<>
+-- !query 648 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) <=> CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) <=> CAST(1 AS DECIMAL(20,0)))' (binary and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 649
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <=> cast(1 as decimal(3, 0)) FROM t
+-- !query 649 schema
+struct<>
+-- !query 649 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <=> CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <=> CAST(1 AS DECIMAL(3,0)))' (timestamp and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 650
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <=> cast(1 as decimal(5, 0)) FROM t
+-- !query 650 schema
+struct<>
+-- !query 650 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <=> CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <=> CAST(1 AS DECIMAL(5,0)))' (timestamp and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 651
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <=> cast(1 as decimal(10, 0)) FROM t
+-- !query 651 schema
+struct<>
+-- !query 651 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <=> CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <=> CAST(1 AS DECIMAL(10,0)))' (timestamp and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 652
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <=> cast(1 as decimal(20, 0)) FROM t
+-- !query 652 schema
+struct<>
+-- !query 652 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <=> CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <=> CAST(1 AS DECIMAL(20,0)))' (timestamp and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 653
+SELECT cast('2017-12-11 09:30:00' as date) <=> cast(1 as decimal(3, 0)) FROM t
+-- !query 653 schema
+struct<>
+-- !query 653 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) <=> CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) <=> CAST(1 AS DECIMAL(3,0)))' (date and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 654
+SELECT cast('2017-12-11 09:30:00' as date) <=> cast(1 as decimal(5, 0)) FROM t
+-- !query 654 schema
+struct<>
+-- !query 654 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) <=> CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) <=> CAST(1 AS DECIMAL(5,0)))' (date and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 655
+SELECT cast('2017-12-11 09:30:00' as date) <=> cast(1 as decimal(10, 0)) FROM t
+-- !query 655 schema
+struct<>
+-- !query 655 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) <=> CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) <=> CAST(1 AS DECIMAL(10,0)))' (date and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 656
+SELECT cast('2017-12-11 09:30:00' as date) <=> cast(1 as decimal(20, 0)) FROM t
+-- !query 656 schema
+struct<>
+-- !query 656 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) <=> CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) <=> CAST(1 AS DECIMAL(20,0)))' (date and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 657
+SELECT cast(1 as decimal(3, 0)) <=> cast(1 as tinyint) FROM t
+-- !query 657 schema
+struct<(CAST(1 AS DECIMAL(3,0)) <=> CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0))):boolean>
+-- !query 657 output
+true
+
+
+-- !query 658
+SELECT cast(1 as decimal(5, 0)) <=> cast(1 as tinyint) FROM t
+-- !query 658 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0)) <=> CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0))):boolean>
+-- !query 658 output
+true
+
+
+-- !query 659
+SELECT cast(1 as decimal(10, 0)) <=> cast(1 as tinyint) FROM t
+-- !query 659 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) <=> CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean>
+-- !query 659 output
+true
+
+
+-- !query 660
+SELECT cast(1 as decimal(20, 0)) <=> cast(1 as tinyint) FROM t
+-- !query 660 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) <=> CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0))):boolean>
+-- !query 660 output
+true
+
+
+-- !query 661
+SELECT cast(1 as decimal(3, 0)) <=> cast(1 as smallint) FROM t
+-- !query 661 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0)) <=> CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0))):boolean>
+-- !query 661 output
+true
+
+
+-- !query 662
+SELECT cast(1 as decimal(5, 0)) <=> cast(1 as smallint) FROM t
+-- !query 662 schema
+struct<(CAST(1 AS DECIMAL(5,0)) <=> CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0))):boolean>
+-- !query 662 output
+true
+
+
+-- !query 663
+SELECT cast(1 as decimal(10, 0)) <=> cast(1 as smallint) FROM t
+-- !query 663 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) <=> CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean>
+-- !query 663 output
+true
+
+
+-- !query 664
+SELECT cast(1 as decimal(20, 0)) <=> cast(1 as smallint) FROM t
+-- !query 664 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) <=> CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0))):boolean>
+-- !query 664 output
+true
+
+
+-- !query 665
+SELECT cast(1 as decimal(3, 0)) <=> cast(1 as int) FROM t
+-- !query 665 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) <=> CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 665 output
+true
+
+
+-- !query 666
+SELECT cast(1 as decimal(5, 0)) <=> cast(1 as int) FROM t
+-- !query 666 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) <=> CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 666 output
+true
+
+
+-- !query 667
+SELECT cast(1 as decimal(10, 0)) <=> cast(1 as int) FROM t
+-- !query 667 schema
+struct<(CAST(1 AS DECIMAL(10,0)) <=> CAST(CAST(1 AS INT) AS DECIMAL(10,0))):boolean>
+-- !query 667 output
+true
+
+
+-- !query 668
+SELECT cast(1 as decimal(20, 0)) <=> cast(1 as int) FROM t
+-- !query 668 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) <=> CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean>
+-- !query 668 output
+true
+
+
+-- !query 669
+SELECT cast(1 as decimal(3, 0)) <=> cast(1 as bigint) FROM t
+-- !query 669 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0)) <=> CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 669 output
+true
+
+
+-- !query 670
+SELECT cast(1 as decimal(5, 0)) <=> cast(1 as bigint) FROM t
+-- !query 670 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0)) <=> CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 670 output
+true
+
+
+-- !query 671
+SELECT cast(1 as decimal(10, 0)) <=> cast(1 as bigint) FROM t
+-- !query 671 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) <=> CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 671 output
+true
+
+
+-- !query 672
+SELECT cast(1 as decimal(20, 0)) <=> cast(1 as bigint) FROM t
+-- !query 672 schema
+struct<(CAST(1 AS DECIMAL(20,0)) <=> CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):boolean>
+-- !query 672 output
+true
+
+
+-- !query 673
+SELECT cast(1 as decimal(3, 0)) <=> cast(1 as float) FROM t
+-- !query 673 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) <=> CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean>
+-- !query 673 output
+true
+
+
+-- !query 674
+SELECT cast(1 as decimal(5, 0)) <=> cast(1 as float) FROM t
+-- !query 674 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) <=> CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean>
+-- !query 674 output
+true
+
+
+-- !query 675
+SELECT cast(1 as decimal(10, 0)) <=> cast(1 as float) FROM t
+-- !query 675 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) <=> CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean>
+-- !query 675 output
+true
+
+
+-- !query 676
+SELECT cast(1 as decimal(20, 0)) <=> cast(1 as float) FROM t
+-- !query 676 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) <=> CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean>
+-- !query 676 output
+true
+
+
+-- !query 677
+SELECT cast(1 as decimal(3, 0)) <=> cast(1 as double) FROM t
+-- !query 677 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) <=> CAST(1 AS DOUBLE)):boolean>
+-- !query 677 output
+true
+
+
+-- !query 678
+SELECT cast(1 as decimal(5, 0)) <=> cast(1 as double) FROM t
+-- !query 678 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) <=> CAST(1 AS DOUBLE)):boolean>
+-- !query 678 output
+true
+
+
+-- !query 679
+SELECT cast(1 as decimal(10, 0)) <=> cast(1 as double) FROM t
+-- !query 679 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) <=> CAST(1 AS DOUBLE)):boolean>
+-- !query 679 output
+true
+
+
+-- !query 680
+SELECT cast(1 as decimal(20, 0)) <=> cast(1 as double) FROM t
+-- !query 680 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) <=> CAST(1 AS DOUBLE)):boolean>
+-- !query 680 output
+true
+
+
+-- !query 681
+SELECT cast(1 as decimal(3, 0)) <=> cast(1 as decimal(10, 0)) FROM t
+-- !query 681 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) <=> CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 681 output
+true
+
+
+-- !query 682
+SELECT cast(1 as decimal(5, 0)) <=> cast(1 as decimal(10, 0)) FROM t
+-- !query 682 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) <=> CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 682 output
+true
+
+
+-- !query 683
+SELECT cast(1 as decimal(10, 0)) <=> cast(1 as decimal(10, 0)) FROM t
+-- !query 683 schema
+struct<(CAST(1 AS DECIMAL(10,0)) <=> CAST(1 AS DECIMAL(10,0))):boolean>
+-- !query 683 output
+true
+
+
+-- !query 684
+SELECT cast(1 as decimal(20, 0)) <=> cast(1 as decimal(10, 0)) FROM t
+-- !query 684 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) <=> CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean>
+-- !query 684 output
+true
+
+
+-- !query 685
+SELECT cast(1 as decimal(3, 0)) <=> cast(1 as string) FROM t
+-- !query 685 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) <=> CAST(CAST(1 AS STRING) AS DOUBLE)):boolean>
+-- !query 685 output
+true
+
+
+-- !query 686
+SELECT cast(1 as decimal(5, 0)) <=> cast(1 as string) FROM t
+-- !query 686 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) <=> CAST(CAST(1 AS STRING) AS DOUBLE)):boolean>
+-- !query 686 output
+true
+
+
+-- !query 687
+SELECT cast(1 as decimal(10, 0)) <=> cast(1 as string) FROM t
+-- !query 687 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) <=> CAST(CAST(1 AS STRING) AS DOUBLE)):boolean>
+-- !query 687 output
+true
+
+
+-- !query 688
+SELECT cast(1 as decimal(20, 0)) <=> cast(1 as string) FROM t
+-- !query 688 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) <=> CAST(CAST(1 AS STRING) AS DOUBLE)):boolean>
+-- !query 688 output
+true
+
+
+-- !query 689
+SELECT cast(1 as decimal(3, 0)) <=> cast('1' as binary) FROM t
+-- !query 689 schema
+struct<>
+-- !query 689 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) <=> CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) <=> CAST('1' AS BINARY))' (decimal(3,0) and binary).; line 1 pos 7
+
+
+-- !query 690
+SELECT cast(1 as decimal(5, 0)) <=> cast('1' as binary) FROM t
+-- !query 690 schema
+struct<>
+-- !query 690 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) <=> CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) <=> CAST('1' AS BINARY))' (decimal(5,0) and binary).; line 1 pos 7
+
+
+-- !query 691
+SELECT cast(1 as decimal(10, 0)) <=> cast('1' as binary) FROM t
+-- !query 691 schema
+struct<>
+-- !query 691 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) <=> CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) <=> CAST('1' AS BINARY))' (decimal(10,0) and binary).; line 1 pos 7
+
+
+-- !query 692
+SELECT cast(1 as decimal(20, 0)) <=> cast('1' as binary) FROM t
+-- !query 692 schema
+struct<>
+-- !query 692 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) <=> CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) <=> CAST('1' AS BINARY))' (decimal(20,0) and binary).; line 1 pos 7
+
+
+-- !query 693
+SELECT cast(1 as decimal(3, 0)) <=> cast(1 as boolean) FROM t
+-- !query 693 schema
+struct<(CAST(1 AS DECIMAL(3,0)) <=> CAST(CAST(1 AS BOOLEAN) AS DECIMAL(3,0))):boolean>
+-- !query 693 output
+true
+
+
+-- !query 694
+SELECT cast(1 as decimal(5, 0)) <=> cast(1 as boolean) FROM t
+-- !query 694 schema
+struct<(CAST(1 AS DECIMAL(5,0)) <=> CAST(CAST(1 AS BOOLEAN) AS DECIMAL(5,0))):boolean>
+-- !query 694 output
+true
+
+
+-- !query 695
+SELECT cast(1 as decimal(10, 0)) <=> cast(1 as boolean) FROM t
+-- !query 695 schema
+struct<(CAST(1 AS DECIMAL(10,0)) <=> CAST(CAST(1 AS BOOLEAN) AS DECIMAL(10,0))):boolean>
+-- !query 695 output
+true
+
+
+-- !query 696
+SELECT cast(1 as decimal(20, 0)) <=> cast(1 as boolean) FROM t
+-- !query 696 schema
+struct<(CAST(1 AS DECIMAL(20,0)) <=> CAST(CAST(1 AS BOOLEAN) AS DECIMAL(20,0))):boolean>
+-- !query 696 output
+true
+
+
+-- !query 697
+SELECT cast(1 as decimal(3, 0)) <=> cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 697 schema
+struct<>
+-- !query 697 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) <=> CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) <=> CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(3,0) and timestamp).; line 1 pos 7
+
+
+-- !query 698
+SELECT cast(1 as decimal(5, 0)) <=> cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 698 schema
+struct<>
+-- !query 698 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) <=> CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) <=> CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(5,0) and timestamp).; line 1 pos 7
+
+
+-- !query 699
+SELECT cast(1 as decimal(10, 0)) <=> cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 699 schema
+struct<>
+-- !query 699 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) <=> CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) <=> CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(10,0) and timestamp).; line 1 pos 7
+
+
+-- !query 700
+SELECT cast(1 as decimal(20, 0)) <=> cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 700 schema
+struct<>
+-- !query 700 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) <=> CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) <=> CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(20,0) and timestamp).; line 1 pos 7
+
+
+-- !query 701
+SELECT cast(1 as decimal(3, 0)) <=> cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 701 schema
+struct<>
+-- !query 701 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) <=> CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) <=> CAST('2017-12-11 09:30:00' AS DATE))' (decimal(3,0) and date).; line 1 pos 7
+
+
+-- !query 702
+SELECT cast(1 as decimal(5, 0)) <=> cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 702 schema
+struct<>
+-- !query 702 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) <=> CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) <=> CAST('2017-12-11 09:30:00' AS DATE))' (decimal(5,0) and date).; line 1 pos 7
+
+
+-- !query 703
+SELECT cast(1 as decimal(10, 0)) <=> cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 703 schema
+struct<>
+-- !query 703 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) <=> CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) <=> CAST('2017-12-11 09:30:00' AS DATE))' (decimal(10,0) and date).; line 1 pos 7
+
+
+-- !query 704
+SELECT cast(1 as decimal(20, 0)) <=> cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 704 schema
+struct<>
+-- !query 704 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) <=> CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) <=> CAST('2017-12-11 09:30:00' AS DATE))' (decimal(20,0) and date).; line 1 pos 7
+
+
+-- !query 705
+SELECT cast(1 as tinyint) < cast(1 as decimal(3, 0)) FROM t
+-- !query 705 schema
+struct<(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) < CAST(1 AS DECIMAL(3,0))):boolean>
+-- !query 705 output
+false
+
+
+-- !query 706
+SELECT cast(1 as tinyint) < cast(1 as decimal(5, 0)) FROM t
+-- !query 706 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0)) < CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0))):boolean>
+-- !query 706 output
+false
+
+
+-- !query 707
+SELECT cast(1 as tinyint) < cast(1 as decimal(10, 0)) FROM t
+-- !query 707 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0)) < CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 707 output
+false
+
+
+-- !query 708
+SELECT cast(1 as tinyint) < cast(1 as decimal(20, 0)) FROM t
+-- !query 708 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0)) < CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 708 output
+false
+
+
+-- !query 709
+SELECT cast(1 as smallint) < cast(1 as decimal(3, 0)) FROM t
+-- !query 709 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0)) < CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0))):boolean>
+-- !query 709 output
+false
+
+
+-- !query 710
+SELECT cast(1 as smallint) < cast(1 as decimal(5, 0)) FROM t
+-- !query 710 schema
+struct<(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) < CAST(1 AS DECIMAL(5,0))):boolean>
+-- !query 710 output
+false
+
+
+-- !query 711
+SELECT cast(1 as smallint) < cast(1 as decimal(10, 0)) FROM t
+-- !query 711 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0)) < CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 711 output
+false
+
+
+-- !query 712
+SELECT cast(1 as smallint) < cast(1 as decimal(20, 0)) FROM t
+-- !query 712 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0)) < CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 712 output
+false
+
+
+-- !query 713
+SELECT cast(1 as int) < cast(1 as decimal(3, 0)) FROM t
+-- !query 713 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) < CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean>
+-- !query 713 output
+false
+
+
+-- !query 714
+SELECT cast(1 as int) < cast(1 as decimal(5, 0)) FROM t
+-- !query 714 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) < CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean>
+-- !query 714 output
+false
+
+
+-- !query 715
+SELECT cast(1 as int) < cast(1 as decimal(10, 0)) FROM t
+-- !query 715 schema
+struct<(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) < CAST(1 AS DECIMAL(10,0))):boolean>
+-- !query 715 output
+false
+
+
+-- !query 716
+SELECT cast(1 as int) < cast(1 as decimal(20, 0)) FROM t
+-- !query 716 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0)) < CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 716 output
+false
+
+
+-- !query 717
+SELECT cast(1 as bigint) < cast(1 as decimal(3, 0)) FROM t
+-- !query 717 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) < CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0))):boolean>
+-- !query 717 output
+false
+
+
+-- !query 718
+SELECT cast(1 as bigint) < cast(1 as decimal(5, 0)) FROM t
+-- !query 718 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) < CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0))):boolean>
+-- !query 718 output
+false
+
+
+-- !query 719
+SELECT cast(1 as bigint) < cast(1 as decimal(10, 0)) FROM t
+-- !query 719 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) < CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean>
+-- !query 719 output
+false
+
+
+-- !query 720
+SELECT cast(1 as bigint) < cast(1 as decimal(20, 0)) FROM t
+-- !query 720 schema
+struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) < CAST(1 AS DECIMAL(20,0))):boolean>
+-- !query 720 output
+false
+
+
+-- !query 721
+SELECT cast(1 as float) < cast(1 as decimal(3, 0)) FROM t
+-- !query 721 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) < CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):boolean>
+-- !query 721 output
+false
+
+
+-- !query 722
+SELECT cast(1 as float) < cast(1 as decimal(5, 0)) FROM t
+-- !query 722 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) < CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):boolean>
+-- !query 722 output
+false
+
+
+-- !query 723
+SELECT cast(1 as float) < cast(1 as decimal(10, 0)) FROM t
+-- !query 723 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) < CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean>
+-- !query 723 output
+false
+
+
+-- !query 724
+SELECT cast(1 as float) < cast(1 as decimal(20, 0)) FROM t
+-- !query 724 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) < CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):boolean>
+-- !query 724 output
+false
+
+
+-- !query 725
+SELECT cast(1 as double) < cast(1 as decimal(3, 0)) FROM t
+-- !query 725 schema
+struct<(CAST(1 AS DOUBLE) < CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):boolean>
+-- !query 725 output
+false
+
+
+-- !query 726
+SELECT cast(1 as double) < cast(1 as decimal(5, 0)) FROM t
+-- !query 726 schema
+struct<(CAST(1 AS DOUBLE) < CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):boolean>
+-- !query 726 output
+false
+
+
+-- !query 727
+SELECT cast(1 as double) < cast(1 as decimal(10, 0)) FROM t
+-- !query 727 schema
+struct<(CAST(1 AS DOUBLE) < CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean>
+-- !query 727 output
+false
+
+
+-- !query 728
+SELECT cast(1 as double) < cast(1 as decimal(20, 0)) FROM t
+-- !query 728 schema
+struct<(CAST(1 AS DOUBLE) < CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):boolean>
+-- !query 728 output
+false
+
+
+-- !query 729
+SELECT cast(1 as decimal(10, 0)) < cast(1 as decimal(3, 0)) FROM t
+-- !query 729 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) < CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean>
+-- !query 729 output
+false
+
+
+-- !query 730
+SELECT cast(1 as decimal(10, 0)) < cast(1 as decimal(5, 0)) FROM t
+-- !query 730 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) < CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean>
+-- !query 730 output
+false
+
+
+-- !query 731
+SELECT cast(1 as decimal(10, 0)) < cast(1 as decimal(10, 0)) FROM t
+-- !query 731 schema
+struct<(CAST(1 AS DECIMAL(10,0)) < CAST(1 AS DECIMAL(10,0))):boolean>
+-- !query 731 output
+false
+
+
+-- !query 732
+SELECT cast(1 as decimal(10, 0)) < cast(1 as decimal(20, 0)) FROM t
+-- !query 732 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) < CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 732 output
+false
+
+
+-- !query 733
+SELECT cast('1' as binary) < cast(1 as decimal(3, 0)) FROM t
+-- !query 733 schema
+struct<>
+-- !query 733 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) < CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) < CAST(1 AS DECIMAL(3,0)))' (binary and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 734
+SELECT cast('1' as binary) < cast(1 as decimal(5, 0)) FROM t
+-- !query 734 schema
+struct<>
+-- !query 734 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) < CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) < CAST(1 AS DECIMAL(5,0)))' (binary and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 735
+SELECT cast('1' as binary) < cast(1 as decimal(10, 0)) FROM t
+-- !query 735 schema
+struct<>
+-- !query 735 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) < CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) < CAST(1 AS DECIMAL(10,0)))' (binary and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 736
+SELECT cast('1' as binary) < cast(1 as decimal(20, 0)) FROM t
+-- !query 736 schema
+struct<>
+-- !query 736 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) < CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) < CAST(1 AS DECIMAL(20,0)))' (binary and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 737
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) < cast(1 as decimal(3, 0)) FROM t
+-- !query 737 schema
+struct<>
+-- !query 737 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) < CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) < CAST(1 AS DECIMAL(3,0)))' (timestamp and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 738
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) < cast(1 as decimal(5, 0)) FROM t
+-- !query 738 schema
+struct<>
+-- !query 738 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) < CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) < CAST(1 AS DECIMAL(5,0)))' (timestamp and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 739
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) < cast(1 as decimal(10, 0)) FROM t
+-- !query 739 schema
+struct<>
+-- !query 739 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) < CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) < CAST(1 AS DECIMAL(10,0)))' (timestamp and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 740
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) < cast(1 as decimal(20, 0)) FROM t
+-- !query 740 schema
+struct<>
+-- !query 740 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) < CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) < CAST(1 AS DECIMAL(20,0)))' (timestamp and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 741
+SELECT cast('2017-12-11 09:30:00' as date) < cast(1 as decimal(3, 0)) FROM t
+-- !query 741 schema
+struct<>
+-- !query 741 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) < CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) < CAST(1 AS DECIMAL(3,0)))' (date and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 742
+SELECT cast('2017-12-11 09:30:00' as date) < cast(1 as decimal(5, 0)) FROM t
+-- !query 742 schema
+struct<>
+-- !query 742 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) < CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) < CAST(1 AS DECIMAL(5,0)))' (date and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 743
+SELECT cast('2017-12-11 09:30:00' as date) < cast(1 as decimal(10, 0)) FROM t
+-- !query 743 schema
+struct<>
+-- !query 743 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) < CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) < CAST(1 AS DECIMAL(10,0)))' (date and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 744
+SELECT cast('2017-12-11 09:30:00' as date) < cast(1 as decimal(20, 0)) FROM t
+-- !query 744 schema
+struct<>
+-- !query 744 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) < CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) < CAST(1 AS DECIMAL(20,0)))' (date and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 745
+SELECT cast(1 as decimal(3, 0)) < cast(1 as tinyint) FROM t
+-- !query 745 schema
+struct<(CAST(1 AS DECIMAL(3,0)) < CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0))):boolean>
+-- !query 745 output
+false
+
+
+-- !query 746
+SELECT cast(1 as decimal(5, 0)) < cast(1 as tinyint) FROM t
+-- !query 746 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0)) < CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0))):boolean>
+-- !query 746 output
+false
+
+
+-- !query 747
+SELECT cast(1 as decimal(10, 0)) < cast(1 as tinyint) FROM t
+-- !query 747 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) < CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean>
+-- !query 747 output
+false
+
+
+-- !query 748
+SELECT cast(1 as decimal(20, 0)) < cast(1 as tinyint) FROM t
+-- !query 748 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) < CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0))):boolean>
+-- !query 748 output
+false
+
+
+-- !query 749
+SELECT cast(1 as decimal(3, 0)) < cast(1 as smallint) FROM t
+-- !query 749 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0)) < CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0))):boolean>
+-- !query 749 output
+false
+
+
+-- !query 750
+SELECT cast(1 as decimal(5, 0)) < cast(1 as smallint) FROM t
+-- !query 750 schema
+struct<(CAST(1 AS DECIMAL(5,0)) < CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0))):boolean>
+-- !query 750 output
+false
+
+
+-- !query 751
+SELECT cast(1 as decimal(10, 0)) < cast(1 as smallint) FROM t
+-- !query 751 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) < CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean>
+-- !query 751 output
+false
+
+
+-- !query 752
+SELECT cast(1 as decimal(20, 0)) < cast(1 as smallint) FROM t
+-- !query 752 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) < CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0))):boolean>
+-- !query 752 output
+false
+
+
+-- !query 753
+SELECT cast(1 as decimal(3, 0)) < cast(1 as int) FROM t
+-- !query 753 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) < CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 753 output
+false
+
+
+-- !query 754
+SELECT cast(1 as decimal(5, 0)) < cast(1 as int) FROM t
+-- !query 754 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) < CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 754 output
+false
+
+
+-- !query 755
+SELECT cast(1 as decimal(10, 0)) < cast(1 as int) FROM t
+-- !query 755 schema
+struct<(CAST(1 AS DECIMAL(10,0)) < CAST(CAST(1 AS INT) AS DECIMAL(10,0))):boolean>
+-- !query 755 output
+false
+
+
+-- !query 756
+SELECT cast(1 as decimal(20, 0)) < cast(1 as int) FROM t
+-- !query 756 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) < CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean>
+-- !query 756 output
+false
+
+
+-- !query 757
+SELECT cast(1 as decimal(3, 0)) < cast(1 as bigint) FROM t
+-- !query 757 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0)) < CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 757 output
+false
+
+
+-- !query 758
+SELECT cast(1 as decimal(5, 0)) < cast(1 as bigint) FROM t
+-- !query 758 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0)) < CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 758 output
+false
+
+
+-- !query 759
+SELECT cast(1 as decimal(10, 0)) < cast(1 as bigint) FROM t
+-- !query 759 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) < CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 759 output
+false
+
+
+-- !query 760
+SELECT cast(1 as decimal(20, 0)) < cast(1 as bigint) FROM t
+-- !query 760 schema
+struct<(CAST(1 AS DECIMAL(20,0)) < CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):boolean>
+-- !query 760 output
+false
+
+
+-- !query 761
+SELECT cast(1 as decimal(3, 0)) < cast(1 as float) FROM t
+-- !query 761 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) < CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean>
+-- !query 761 output
+false
+
+
+-- !query 762
+SELECT cast(1 as decimal(5, 0)) < cast(1 as float) FROM t
+-- !query 762 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) < CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean>
+-- !query 762 output
+false
+
+
+-- !query 763
+SELECT cast(1 as decimal(10, 0)) < cast(1 as float) FROM t
+-- !query 763 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) < CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean>
+-- !query 763 output
+false
+
+
+-- !query 764
+SELECT cast(1 as decimal(20, 0)) < cast(1 as float) FROM t
+-- !query 764 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) < CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean>
+-- !query 764 output
+false
+
+
+-- !query 765
+SELECT cast(1 as decimal(3, 0)) < cast(1 as double) FROM t
+-- !query 765 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) < CAST(1 AS DOUBLE)):boolean>
+-- !query 765 output
+false
+
+
+-- !query 766
+SELECT cast(1 as decimal(5, 0)) < cast(1 as double) FROM t
+-- !query 766 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) < CAST(1 AS DOUBLE)):boolean>
+-- !query 766 output
+false
+
+
+-- !query 767
+SELECT cast(1 as decimal(10, 0)) < cast(1 as double) FROM t
+-- !query 767 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) < CAST(1 AS DOUBLE)):boolean>
+-- !query 767 output
+false
+
+
+-- !query 768
+SELECT cast(1 as decimal(20, 0)) < cast(1 as double) FROM t
+-- !query 768 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) < CAST(1 AS DOUBLE)):boolean>
+-- !query 768 output
+false
+
+
+-- !query 769
+SELECT cast(1 as decimal(3, 0)) < cast(1 as decimal(10, 0)) FROM t
+-- !query 769 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) < CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 769 output
+false
+
+
+-- !query 770
+SELECT cast(1 as decimal(5, 0)) < cast(1 as decimal(10, 0)) FROM t
+-- !query 770 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) < CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 770 output
+false
+
+
+-- !query 771
+SELECT cast(1 as decimal(10, 0)) < cast(1 as decimal(10, 0)) FROM t
+-- !query 771 schema
+struct<(CAST(1 AS DECIMAL(10,0)) < CAST(1 AS DECIMAL(10,0))):boolean>
+-- !query 771 output
+false
+
+
+-- !query 772
+SELECT cast(1 as decimal(20, 0)) < cast(1 as decimal(10, 0)) FROM t
+-- !query 772 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) < CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean>
+-- !query 772 output
+false
+
+
+-- !query 773
+SELECT cast(1 as decimal(3, 0)) < cast(1 as string) FROM t
+-- !query 773 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) < CAST(CAST(1 AS STRING) AS DOUBLE)):boolean>
+-- !query 773 output
+false
+
+
+-- !query 774
+SELECT cast(1 as decimal(5, 0)) < cast(1 as string) FROM t
+-- !query 774 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) < CAST(CAST(1 AS STRING) AS DOUBLE)):boolean>
+-- !query 774 output
+false
+
+
+-- !query 775
+SELECT cast(1 as decimal(10, 0)) < cast(1 as string) FROM t
+-- !query 775 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) < CAST(CAST(1 AS STRING) AS DOUBLE)):boolean>
+-- !query 775 output
+false
+
+
+-- !query 776
+SELECT cast(1 as decimal(20, 0)) < cast(1 as string) FROM t
+-- !query 776 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) < CAST(CAST(1 AS STRING) AS DOUBLE)):boolean>
+-- !query 776 output
+false
+
+
+-- !query 777
+SELECT cast(1 as decimal(3, 0)) < cast('1' as binary) FROM t
+-- !query 777 schema
+struct<>
+-- !query 777 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) < CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) < CAST('1' AS BINARY))' (decimal(3,0) and binary).; line 1 pos 7
+
+
+-- !query 778
+SELECT cast(1 as decimal(5, 0)) < cast('1' as binary) FROM t
+-- !query 778 schema
+struct<>
+-- !query 778 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) < CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) < CAST('1' AS BINARY))' (decimal(5,0) and binary).; line 1 pos 7
+
+
+-- !query 779
+SELECT cast(1 as decimal(10, 0)) < cast('1' as binary) FROM t
+-- !query 779 schema
+struct<>
+-- !query 779 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) < CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) < CAST('1' AS BINARY))' (decimal(10,0) and binary).; line 1 pos 7
+
+
+-- !query 780
+SELECT cast(1 as decimal(20, 0)) < cast('1' as binary) FROM t
+-- !query 780 schema
+struct<>
+-- !query 780 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) < CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) < CAST('1' AS BINARY))' (decimal(20,0) and binary).; line 1 pos 7
+
+
+-- !query 781
+SELECT cast(1 as decimal(3, 0)) < cast(1 as boolean) FROM t
+-- !query 781 schema
+struct<>
+-- !query 781 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) < CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) < CAST(1 AS BOOLEAN))' (decimal(3,0) and boolean).; line 1 pos 7
+
+
+-- !query 782
+SELECT cast(1 as decimal(5, 0)) < cast(1 as boolean) FROM t
+-- !query 782 schema
+struct<>
+-- !query 782 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) < CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) < CAST(1 AS BOOLEAN))' (decimal(5,0) and boolean).; line 1 pos 7
+
+
+-- !query 783
+SELECT cast(1 as decimal(10, 0)) < cast(1 as boolean) FROM t
+-- !query 783 schema
+struct<>
+-- !query 783 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) < CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) < CAST(1 AS BOOLEAN))' (decimal(10,0) and boolean).; line 1 pos 7
+
+
+-- !query 784
+SELECT cast(1 as decimal(20, 0)) < cast(1 as boolean) FROM t
+-- !query 784 schema
+struct<>
+-- !query 784 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) < CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) < CAST(1 AS BOOLEAN))' (decimal(20,0) and boolean).; line 1 pos 7
+
+
+-- !query 785
+SELECT cast(1 as decimal(3, 0)) < cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 785 schema
+struct<>
+-- !query 785 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) < CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) < CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(3,0) and timestamp).; line 1 pos 7
+
+
+-- !query 786
+SELECT cast(1 as decimal(5, 0)) < cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 786 schema
+struct<>
+-- !query 786 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) < CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) < CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(5,0) and timestamp).; line 1 pos 7
+
+
+-- !query 787
+SELECT cast(1 as decimal(10, 0)) < cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 787 schema
+struct<>
+-- !query 787 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) < CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) < CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(10,0) and timestamp).; line 1 pos 7
+
+
+-- !query 788
+SELECT cast(1 as decimal(20, 0)) < cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 788 schema
+struct<>
+-- !query 788 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) < CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) < CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(20,0) and timestamp).; line 1 pos 7
+
+
+-- !query 789
+SELECT cast(1 as decimal(3, 0)) < cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 789 schema
+struct<>
+-- !query 789 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) < CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) < CAST('2017-12-11 09:30:00' AS DATE))' (decimal(3,0) and date).; line 1 pos 7
+
+
+-- !query 790
+SELECT cast(1 as decimal(5, 0)) < cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 790 schema
+struct<>
+-- !query 790 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) < CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) < CAST('2017-12-11 09:30:00' AS DATE))' (decimal(5,0) and date).; line 1 pos 7
+
+
+-- !query 791
+SELECT cast(1 as decimal(10, 0)) < cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 791 schema
+struct<>
+-- !query 791 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) < CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) < CAST('2017-12-11 09:30:00' AS DATE))' (decimal(10,0) and date).; line 1 pos 7
+
+
+-- !query 792
+SELECT cast(1 as decimal(20, 0)) < cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 792 schema
+struct<>
+-- !query 792 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) < CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) < CAST('2017-12-11 09:30:00' AS DATE))' (decimal(20,0) and date).; line 1 pos 7
+
+
+-- !query 793
+SELECT cast(1 as tinyint) <= cast(1 as decimal(3, 0)) FROM t
+-- !query 793 schema
+struct<(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) <= CAST(1 AS DECIMAL(3,0))):boolean>
+-- !query 793 output
+true
+
+
+-- !query 794
+SELECT cast(1 as tinyint) <= cast(1 as decimal(5, 0)) FROM t
+-- !query 794 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0)) <= CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0))):boolean>
+-- !query 794 output
+true
+
+
+-- !query 795
+SELECT cast(1 as tinyint) <= cast(1 as decimal(10, 0)) FROM t
+-- !query 795 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0)) <= CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 795 output
+true
+
+
+-- !query 796
+SELECT cast(1 as tinyint) <= cast(1 as decimal(20, 0)) FROM t
+-- !query 796 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0)) <= CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 796 output
+true
+
+
+-- !query 797
+SELECT cast(1 as smallint) <= cast(1 as decimal(3, 0)) FROM t
+-- !query 797 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0)) <= CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0))):boolean>
+-- !query 797 output
+true
+
+
+-- !query 798
+SELECT cast(1 as smallint) <= cast(1 as decimal(5, 0)) FROM t
+-- !query 798 schema
+struct<(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) <= CAST(1 AS DECIMAL(5,0))):boolean>
+-- !query 798 output
+true
+
+
+-- !query 799
+SELECT cast(1 as smallint) <= cast(1 as decimal(10, 0)) FROM t
+-- !query 799 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0)) <= CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 799 output
+true
+
+
+-- !query 800
+SELECT cast(1 as smallint) <= cast(1 as decimal(20, 0)) FROM t
+-- !query 800 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0)) <= CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 800 output
+true
+
+
+-- !query 801
+SELECT cast(1 as int) <= cast(1 as decimal(3, 0)) FROM t
+-- !query 801 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) <= CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean>
+-- !query 801 output
+true
+
+
+-- !query 802
+SELECT cast(1 as int) <= cast(1 as decimal(5, 0)) FROM t
+-- !query 802 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) <= CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean>
+-- !query 802 output
+true
+
+
+-- !query 803
+SELECT cast(1 as int) <= cast(1 as decimal(10, 0)) FROM t
+-- !query 803 schema
+struct<(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) <= CAST(1 AS DECIMAL(10,0))):boolean>
+-- !query 803 output
+true
+
+
+-- !query 804
+SELECT cast(1 as int) <= cast(1 as decimal(20, 0)) FROM t
+-- !query 804 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0)) <= CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 804 output
+true
+
+
+-- !query 805
+SELECT cast(1 as bigint) <= cast(1 as decimal(3, 0)) FROM t
+-- !query 805 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) <= CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0))):boolean>
+-- !query 805 output
+true
+
+
+-- !query 806
+SELECT cast(1 as bigint) <= cast(1 as decimal(5, 0)) FROM t
+-- !query 806 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) <= CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0))):boolean>
+-- !query 806 output
+true
+
+
+-- !query 807
+SELECT cast(1 as bigint) <= cast(1 as decimal(10, 0)) FROM t
+-- !query 807 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) <= CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean>
+-- !query 807 output
+true
+
+
+-- !query 808
+SELECT cast(1 as bigint) <= cast(1 as decimal(20, 0)) FROM t
+-- !query 808 schema
+struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) <= CAST(1 AS DECIMAL(20,0))):boolean>
+-- !query 808 output
+true
+
+
+-- !query 809
+SELECT cast(1 as float) <= cast(1 as decimal(3, 0)) FROM t
+-- !query 809 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) <= CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):boolean>
+-- !query 809 output
+true
+
+
+-- !query 810
+SELECT cast(1 as float) <= cast(1 as decimal(5, 0)) FROM t
+-- !query 810 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) <= CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):boolean>
+-- !query 810 output
+true
+
+
+-- !query 811
+SELECT cast(1 as float) <= cast(1 as decimal(10, 0)) FROM t
+-- !query 811 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) <= CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean>
+-- !query 811 output
+true
+
+
+-- !query 812
+SELECT cast(1 as float) <= cast(1 as decimal(20, 0)) FROM t
+-- !query 812 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) <= CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):boolean>
+-- !query 812 output
+true
+
+
+-- !query 813
+SELECT cast(1 as double) <= cast(1 as decimal(3, 0)) FROM t
+-- !query 813 schema
+struct<(CAST(1 AS DOUBLE) <= CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):boolean>
+-- !query 813 output
+true
+
+
+-- !query 814
+SELECT cast(1 as double) <= cast(1 as decimal(5, 0)) FROM t
+-- !query 814 schema
+struct<(CAST(1 AS DOUBLE) <= CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):boolean>
+-- !query 814 output
+true
+
+
+-- !query 815
+SELECT cast(1 as double) <= cast(1 as decimal(10, 0)) FROM t
+-- !query 815 schema
+struct<(CAST(1 AS DOUBLE) <= CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean>
+-- !query 815 output
+true
+
+
+-- !query 816
+SELECT cast(1 as double) <= cast(1 as decimal(20, 0)) FROM t
+-- !query 816 schema
+struct<(CAST(1 AS DOUBLE) <= CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):boolean>
+-- !query 816 output
+true
+
+
+-- !query 817
+SELECT cast(1 as decimal(10, 0)) <= cast(1 as decimal(3, 0)) FROM t
+-- !query 817 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) <= CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean>
+-- !query 817 output
+true
+
+
+-- !query 818
+SELECT cast(1 as decimal(10, 0)) <= cast(1 as decimal(5, 0)) FROM t
+-- !query 818 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) <= CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean>
+-- !query 818 output
+true
+
+
+-- !query 819
+SELECT cast(1 as decimal(10, 0)) <= cast(1 as decimal(10, 0)) FROM t
+-- !query 819 schema
+struct<(CAST(1 AS DECIMAL(10,0)) <= CAST(1 AS DECIMAL(10,0))):boolean>
+-- !query 819 output
+true
+
+
+-- !query 820
+SELECT cast(1 as decimal(10, 0)) <= cast(1 as decimal(20, 0)) FROM t
+-- !query 820 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) <= CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 820 output
+true
+
+
+-- !query 821
+SELECT cast('1' as binary) <= cast(1 as decimal(3, 0)) FROM t
+-- !query 821 schema
+struct<>
+-- !query 821 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) <= CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) <= CAST(1 AS DECIMAL(3,0)))' (binary and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 822
+SELECT cast('1' as binary) <= cast(1 as decimal(5, 0)) FROM t
+-- !query 822 schema
+struct<>
+-- !query 822 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) <= CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) <= CAST(1 AS DECIMAL(5,0)))' (binary and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 823
+SELECT cast('1' as binary) <= cast(1 as decimal(10, 0)) FROM t
+-- !query 823 schema
+struct<>
+-- !query 823 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) <= CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) <= CAST(1 AS DECIMAL(10,0)))' (binary and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 824
+SELECT cast('1' as binary) <= cast(1 as decimal(20, 0)) FROM t
+-- !query 824 schema
+struct<>
+-- !query 824 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) <= CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) <= CAST(1 AS DECIMAL(20,0)))' (binary and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 825
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <= cast(1 as decimal(3, 0)) FROM t
+-- !query 825 schema
+struct<>
+-- !query 825 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <= CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <= CAST(1 AS DECIMAL(3,0)))' (timestamp and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 826
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <= cast(1 as decimal(5, 0)) FROM t
+-- !query 826 schema
+struct<>
+-- !query 826 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <= CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <= CAST(1 AS DECIMAL(5,0)))' (timestamp and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 827
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <= cast(1 as decimal(10, 0)) FROM t
+-- !query 827 schema
+struct<>
+-- !query 827 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <= CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <= CAST(1 AS DECIMAL(10,0)))' (timestamp and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 828
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <= cast(1 as decimal(20, 0)) FROM t
+-- !query 828 schema
+struct<>
+-- !query 828 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <= CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <= CAST(1 AS DECIMAL(20,0)))' (timestamp and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 829
+SELECT cast('2017-12-11 09:30:00' as date) <= cast(1 as decimal(3, 0)) FROM t
+-- !query 829 schema
+struct<>
+-- !query 829 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) <= CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) <= CAST(1 AS DECIMAL(3,0)))' (date and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 830
+SELECT cast('2017-12-11 09:30:00' as date) <= cast(1 as decimal(5, 0)) FROM t
+-- !query 830 schema
+struct<>
+-- !query 830 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) <= CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) <= CAST(1 AS DECIMAL(5,0)))' (date and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 831
+SELECT cast('2017-12-11 09:30:00' as date) <= cast(1 as decimal(10, 0)) FROM t
+-- !query 831 schema
+struct<>
+-- !query 831 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) <= CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) <= CAST(1 AS DECIMAL(10,0)))' (date and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 832
+SELECT cast('2017-12-11 09:30:00' as date) <= cast(1 as decimal(20, 0)) FROM t
+-- !query 832 schema
+struct<>
+-- !query 832 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) <= CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) <= CAST(1 AS DECIMAL(20,0)))' (date and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 833
+SELECT cast(1 as decimal(3, 0)) <= cast(1 as tinyint) FROM t
+-- !query 833 schema
+struct<(CAST(1 AS DECIMAL(3,0)) <= CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0))):boolean>
+-- !query 833 output
+true
+
+
+-- !query 834
+SELECT cast(1 as decimal(5, 0)) <= cast(1 as tinyint) FROM t
+-- !query 834 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0)) <= CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0))):boolean>
+-- !query 834 output
+true
+
+
+-- !query 835
+SELECT cast(1 as decimal(10, 0)) <= cast(1 as tinyint) FROM t
+-- !query 835 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) <= CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean>
+-- !query 835 output
+true
+
+
+-- !query 836
+SELECT cast(1 as decimal(20, 0)) <= cast(1 as tinyint) FROM t
+-- !query 836 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) <= CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0))):boolean>
+-- !query 836 output
+true
+
+
+-- !query 837
+SELECT cast(1 as decimal(3, 0)) <= cast(1 as smallint) FROM t
+-- !query 837 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0)) <= CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0))):boolean>
+-- !query 837 output
+true
+
+
+-- !query 838
+SELECT cast(1 as decimal(5, 0)) <= cast(1 as smallint) FROM t
+-- !query 838 schema
+struct<(CAST(1 AS DECIMAL(5,0)) <= CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0))):boolean>
+-- !query 838 output
+true
+
+
+-- !query 839
+SELECT cast(1 as decimal(10, 0)) <= cast(1 as smallint) FROM t
+-- !query 839 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) <= CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean>
+-- !query 839 output
+true
+
+
+-- !query 840
+SELECT cast(1 as decimal(20, 0)) <= cast(1 as smallint) FROM t
+-- !query 840 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) <= CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0))):boolean>
+-- !query 840 output
+true
+
+
+-- !query 841
+SELECT cast(1 as decimal(3, 0)) <= cast(1 as int) FROM t
+-- !query 841 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) <= CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 841 output
+true
+
+
+-- !query 842
+SELECT cast(1 as decimal(5, 0)) <= cast(1 as int) FROM t
+-- !query 842 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) <= CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 842 output
+true
+
+
+-- !query 843
+SELECT cast(1 as decimal(10, 0)) <= cast(1 as int) FROM t
+-- !query 843 schema
+struct<(CAST(1 AS DECIMAL(10,0)) <= CAST(CAST(1 AS INT) AS DECIMAL(10,0))):boolean>
+-- !query 843 output
+true
+
+
+-- !query 844
+SELECT cast(1 as decimal(20, 0)) <= cast(1 as int) FROM t
+-- !query 844 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) <= CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean>
+-- !query 844 output
+true
+
+
+-- !query 845
+SELECT cast(1 as decimal(3, 0)) <= cast(1 as bigint) FROM t
+-- !query 845 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0)) <= CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 845 output
+true
+
+
+-- !query 846
+SELECT cast(1 as decimal(5, 0)) <= cast(1 as bigint) FROM t
+-- !query 846 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0)) <= CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 846 output
+true
+
+
+-- !query 847
+SELECT cast(1 as decimal(10, 0)) <= cast(1 as bigint) FROM t
+-- !query 847 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) <= CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 847 output
+true
+
+
+-- !query 848
+SELECT cast(1 as decimal(20, 0)) <= cast(1 as bigint) FROM t
+-- !query 848 schema
+struct<(CAST(1 AS DECIMAL(20,0)) <= CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):boolean>
+-- !query 848 output
+true
+
+
+-- !query 849
+SELECT cast(1 as decimal(3, 0)) <= cast(1 as float) FROM t
+-- !query 849 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) <= CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean>
+-- !query 849 output
+true
+
+
+-- !query 850
+SELECT cast(1 as decimal(5, 0)) <= cast(1 as float) FROM t
+-- !query 850 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) <= CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean>
+-- !query 850 output
+true
+
+
+-- !query 851
+SELECT cast(1 as decimal(10, 0)) <= cast(1 as float) FROM t
+-- !query 851 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) <= CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean>
+-- !query 851 output
+true
+
+
+-- !query 852
+SELECT cast(1 as decimal(20, 0)) <= cast(1 as float) FROM t
+-- !query 852 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) <= CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean>
+-- !query 852 output
+true
+
+
+-- !query 853
+SELECT cast(1 as decimal(3, 0)) <= cast(1 as double) FROM t
+-- !query 853 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) <= CAST(1 AS DOUBLE)):boolean>
+-- !query 853 output
+true
+
+
+-- !query 854
+SELECT cast(1 as decimal(5, 0)) <= cast(1 as double) FROM t
+-- !query 854 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) <= CAST(1 AS DOUBLE)):boolean>
+-- !query 854 output
+true
+
+
+-- !query 855
+SELECT cast(1 as decimal(10, 0)) <= cast(1 as double) FROM t
+-- !query 855 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) <= CAST(1 AS DOUBLE)):boolean>
+-- !query 855 output
+true
+
+
+-- !query 856
+SELECT cast(1 as decimal(20, 0)) <= cast(1 as double) FROM t
+-- !query 856 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) <= CAST(1 AS DOUBLE)):boolean>
+-- !query 856 output
+true
+
+
+-- !query 857
+SELECT cast(1 as decimal(3, 0)) <= cast(1 as decimal(10, 0)) FROM t
+-- !query 857 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) <= CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 857 output
+true
+
+
+-- !query 858
+SELECT cast(1 as decimal(5, 0)) <= cast(1 as decimal(10, 0)) FROM t
+-- !query 858 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) <= CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 858 output
+true
+
+
+-- !query 859
+SELECT cast(1 as decimal(10, 0)) <= cast(1 as decimal(10, 0)) FROM t
+-- !query 859 schema
+struct<(CAST(1 AS DECIMAL(10,0)) <= CAST(1 AS DECIMAL(10,0))):boolean>
+-- !query 859 output
+true
+
+
+-- !query 860
+SELECT cast(1 as decimal(20, 0)) <= cast(1 as decimal(10, 0)) FROM t
+-- !query 860 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) <= CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean>
+-- !query 860 output
+true
+
+
+-- !query 861
+SELECT cast(1 as decimal(3, 0)) <= cast(1 as string) FROM t
+-- !query 861 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) <= CAST(CAST(1 AS STRING) AS DOUBLE)):boolean>
+-- !query 861 output
+true
+
+
+-- !query 862
+SELECT cast(1 as decimal(5, 0)) <= cast(1 as string) FROM t
+-- !query 862 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) <= CAST(CAST(1 AS STRING) AS DOUBLE)):boolean>
+-- !query 862 output
+true
+
+
+-- !query 863
+SELECT cast(1 as decimal(10, 0)) <= cast(1 as string) FROM t
+-- !query 863 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) <= CAST(CAST(1 AS STRING) AS DOUBLE)):boolean>
+-- !query 863 output
+true
+
+
+-- !query 864
+SELECT cast(1 as decimal(20, 0)) <= cast(1 as string) FROM t
+-- !query 864 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) <= CAST(CAST(1 AS STRING) AS DOUBLE)):boolean>
+-- !query 864 output
+true
+
+
+-- !query 865
+SELECT cast(1 as decimal(3, 0)) <= cast('1' as binary) FROM t
+-- !query 865 schema
+struct<>
+-- !query 865 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) <= CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) <= CAST('1' AS BINARY))' (decimal(3,0) and binary).; line 1 pos 7
+
+
+-- !query 866
+SELECT cast(1 as decimal(5, 0)) <= cast('1' as binary) FROM t
+-- !query 866 schema
+struct<>
+-- !query 866 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) <= CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) <= CAST('1' AS BINARY))' (decimal(5,0) and binary).; line 1 pos 7
+
+
+-- !query 867
+SELECT cast(1 as decimal(10, 0)) <= cast('1' as binary) FROM t
+-- !query 867 schema
+struct<>
+-- !query 867 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) <= CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) <= CAST('1' AS BINARY))' (decimal(10,0) and binary).; line 1 pos 7
+
+
+-- !query 868
+SELECT cast(1 as decimal(20, 0)) <= cast('1' as binary) FROM t
+-- !query 868 schema
+struct<>
+-- !query 868 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) <= CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) <= CAST('1' AS BINARY))' (decimal(20,0) and binary).; line 1 pos 7
+
+
+-- !query 869
+SELECT cast(1 as decimal(3, 0)) <= cast(1 as boolean) FROM t
+-- !query 869 schema
+struct<>
+-- !query 869 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) <= CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) <= CAST(1 AS BOOLEAN))' (decimal(3,0) and boolean).; line 1 pos 7
+
+
+-- !query 870
+SELECT cast(1 as decimal(5, 0)) <= cast(1 as boolean) FROM t
+-- !query 870 schema
+struct<>
+-- !query 870 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) <= CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) <= CAST(1 AS BOOLEAN))' (decimal(5,0) and boolean).; line 1 pos 7
+
+
+-- !query 871
+SELECT cast(1 as decimal(10, 0)) <= cast(1 as boolean) FROM t
+-- !query 871 schema
+struct<>
+-- !query 871 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) <= CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) <= CAST(1 AS BOOLEAN))' (decimal(10,0) and boolean).; line 1 pos 7
+
+
+-- !query 872
+SELECT cast(1 as decimal(20, 0)) <= cast(1 as boolean) FROM t
+-- !query 872 schema
+struct<>
+-- !query 872 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) <= CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) <= CAST(1 AS BOOLEAN))' (decimal(20,0) and boolean).; line 1 pos 7
+
+
+-- !query 873
+SELECT cast(1 as decimal(3, 0)) <= cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 873 schema
+struct<>
+-- !query 873 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) <= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) <= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(3,0) and timestamp).; line 1 pos 7
+
+
+-- !query 874
+SELECT cast(1 as decimal(5, 0)) <= cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 874 schema
+struct<>
+-- !query 874 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) <= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) <= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(5,0) and timestamp).; line 1 pos 7
+
+
+-- !query 875
+SELECT cast(1 as decimal(10, 0)) <= cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 875 schema
+struct<>
+-- !query 875 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) <= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) <= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(10,0) and timestamp).; line 1 pos 7
+
+
+-- !query 876
+SELECT cast(1 as decimal(20, 0)) <= cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 876 schema
+struct<>
+-- !query 876 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) <= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) <= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(20,0) and timestamp).; line 1 pos 7
+
+
+-- !query 877
+SELECT cast(1 as decimal(3, 0)) <= cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 877 schema
+struct<>
+-- !query 877 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) <= CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) <= CAST('2017-12-11 09:30:00' AS DATE))' (decimal(3,0) and date).; line 1 pos 7
+
+
+-- !query 878
+SELECT cast(1 as decimal(5, 0)) <= cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 878 schema
+struct<>
+-- !query 878 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) <= CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) <= CAST('2017-12-11 09:30:00' AS DATE))' (decimal(5,0) and date).; line 1 pos 7
+
+
+-- !query 879
+SELECT cast(1 as decimal(10, 0)) <= cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 879 schema
+struct<>
+-- !query 879 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) <= CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) <= CAST('2017-12-11 09:30:00' AS DATE))' (decimal(10,0) and date).; line 1 pos 7
+
+
+-- !query 880
+SELECT cast(1 as decimal(20, 0)) <= cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 880 schema
+struct<>
+-- !query 880 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) <= CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) <= CAST('2017-12-11 09:30:00' AS DATE))' (decimal(20,0) and date).; line 1 pos 7
+
+
+-- !query 881
+SELECT cast(1 as tinyint) > cast(1 as decimal(3, 0)) FROM t
+-- !query 881 schema
+struct<(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) > CAST(1 AS DECIMAL(3,0))):boolean>
+-- !query 881 output
+false
+
+
+-- !query 882
+SELECT cast(1 as tinyint) > cast(1 as decimal(5, 0)) FROM t
+-- !query 882 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0)) > CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0))):boolean>
+-- !query 882 output
+false
+
+
+-- !query 883
+SELECT cast(1 as tinyint) > cast(1 as decimal(10, 0)) FROM t
+-- !query 883 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0)) > CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 883 output
+false
+
+
+-- !query 884
+SELECT cast(1 as tinyint) > cast(1 as decimal(20, 0)) FROM t
+-- !query 884 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0)) > CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 884 output
+false
+
+
+-- !query 885
+SELECT cast(1 as smallint) > cast(1 as decimal(3, 0)) FROM t
+-- !query 885 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0)) > CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0))):boolean>
+-- !query 885 output
+false
+
+
+-- !query 886
+SELECT cast(1 as smallint) > cast(1 as decimal(5, 0)) FROM t
+-- !query 886 schema
+struct<(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) > CAST(1 AS DECIMAL(5,0))):boolean>
+-- !query 886 output
+false
+
+
+-- !query 887
+SELECT cast(1 as smallint) > cast(1 as decimal(10, 0)) FROM t
+-- !query 887 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0)) > CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 887 output
+false
+
+
+-- !query 888
+SELECT cast(1 as smallint) > cast(1 as decimal(20, 0)) FROM t
+-- !query 888 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0)) > CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 888 output
+false
+
+
+-- !query 889
+SELECT cast(1 as int) > cast(1 as decimal(3, 0)) FROM t
+-- !query 889 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) > CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean>
+-- !query 889 output
+false
+
+
+-- !query 890
+SELECT cast(1 as int) > cast(1 as decimal(5, 0)) FROM t
+-- !query 890 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) > CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean>
+-- !query 890 output
+false
+
+
+-- !query 891
+SELECT cast(1 as int) > cast(1 as decimal(10, 0)) FROM t
+-- !query 891 schema
+struct<(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) > CAST(1 AS DECIMAL(10,0))):boolean>
+-- !query 891 output
+false
+
+
+-- !query 892
+SELECT cast(1 as int) > cast(1 as decimal(20, 0)) FROM t
+-- !query 892 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0)) > CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 892 output
+false
+
+
+-- !query 893
+SELECT cast(1 as bigint) > cast(1 as decimal(3, 0)) FROM t
+-- !query 893 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) > CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0))):boolean>
+-- !query 893 output
+false
+
+
+-- !query 894
+SELECT cast(1 as bigint) > cast(1 as decimal(5, 0)) FROM t
+-- !query 894 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) > CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0))):boolean>
+-- !query 894 output
+false
+
+
+-- !query 895
+SELECT cast(1 as bigint) > cast(1 as decimal(10, 0)) FROM t
+-- !query 895 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) > CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean>
+-- !query 895 output
+false
+
+
+-- !query 896
+SELECT cast(1 as bigint) > cast(1 as decimal(20, 0)) FROM t
+-- !query 896 schema
+struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) > CAST(1 AS DECIMAL(20,0))):boolean>
+-- !query 896 output
+false
+
+
+-- !query 897
+SELECT cast(1 as float) > cast(1 as decimal(3, 0)) FROM t
+-- !query 897 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) > CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):boolean>
+-- !query 897 output
+false
+
+
+-- !query 898
+SELECT cast(1 as float) > cast(1 as decimal(5, 0)) FROM t
+-- !query 898 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) > CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):boolean>
+-- !query 898 output
+false
+
+
+-- !query 899
+SELECT cast(1 as float) > cast(1 as decimal(10, 0)) FROM t
+-- !query 899 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) > CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean>
+-- !query 899 output
+false
+
+
+-- !query 900
+SELECT cast(1 as float) > cast(1 as decimal(20, 0)) FROM t
+-- !query 900 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) > CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):boolean>
+-- !query 900 output
+false
+
+
+-- !query 901
+SELECT cast(1 as double) > cast(1 as decimal(3, 0)) FROM t
+-- !query 901 schema
+struct<(CAST(1 AS DOUBLE) > CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):boolean>
+-- !query 901 output
+false
+
+
+-- !query 902
+SELECT cast(1 as double) > cast(1 as decimal(5, 0)) FROM t
+-- !query 902 schema
+struct<(CAST(1 AS DOUBLE) > CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):boolean>
+-- !query 902 output
+false
+
+
+-- !query 903
+SELECT cast(1 as double) > cast(1 as decimal(10, 0)) FROM t
+-- !query 903 schema
+struct<(CAST(1 AS DOUBLE) > CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean>
+-- !query 903 output
+false
+
+
+-- !query 904
+SELECT cast(1 as double) > cast(1 as decimal(20, 0)) FROM t
+-- !query 904 schema
+struct<(CAST(1 AS DOUBLE) > CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):boolean>
+-- !query 904 output
+false
+
+
+-- !query 905
+SELECT cast(1 as decimal(10, 0)) > cast(1 as decimal(3, 0)) FROM t
+-- !query 905 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) > CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean>
+-- !query 905 output
+false
+
+
+-- !query 906
+SELECT cast(1 as decimal(10, 0)) > cast(1 as decimal(5, 0)) FROM t
+-- !query 906 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) > CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean>
+-- !query 906 output
+false
+
+
+-- !query 907
+SELECT cast(1 as decimal(10, 0)) > cast(1 as decimal(10, 0)) FROM t
+-- !query 907 schema
+struct<(CAST(1 AS DECIMAL(10,0)) > CAST(1 AS DECIMAL(10,0))):boolean>
+-- !query 907 output
+false
+
+
+-- !query 908
+SELECT cast(1 as decimal(10, 0)) > cast(1 as decimal(20, 0)) FROM t
+-- !query 908 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) > CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 908 output
+false
+
+
+-- !query 909
+SELECT cast('1' as binary) > cast(1 as decimal(3, 0)) FROM t
+-- !query 909 schema
+struct<>
+-- !query 909 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) > CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) > CAST(1 AS DECIMAL(3,0)))' (binary and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 910
+SELECT cast('1' as binary) > cast(1 as decimal(5, 0)) FROM t
+-- !query 910 schema
+struct<>
+-- !query 910 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) > CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) > CAST(1 AS DECIMAL(5,0)))' (binary and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 911
+SELECT cast('1' as binary) > cast(1 as decimal(10, 0)) FROM t
+-- !query 911 schema
+struct<>
+-- !query 911 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) > CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) > CAST(1 AS DECIMAL(10,0)))' (binary and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 912
+SELECT cast('1' as binary) > cast(1 as decimal(20, 0)) FROM t
+-- !query 912 schema
+struct<>
+-- !query 912 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) > CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) > CAST(1 AS DECIMAL(20,0)))' (binary and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 913
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) > cast(1 as decimal(3, 0)) FROM t
+-- !query 913 schema
+struct<>
+-- !query 913 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) > CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) > CAST(1 AS DECIMAL(3,0)))' (timestamp and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 914
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) > cast(1 as decimal(5, 0)) FROM t
+-- !query 914 schema
+struct<>
+-- !query 914 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) > CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) > CAST(1 AS DECIMAL(5,0)))' (timestamp and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 915
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) > cast(1 as decimal(10, 0)) FROM t
+-- !query 915 schema
+struct<>
+-- !query 915 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) > CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) > CAST(1 AS DECIMAL(10,0)))' (timestamp and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 916
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) > cast(1 as decimal(20, 0)) FROM t
+-- !query 916 schema
+struct<>
+-- !query 916 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) > CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) > CAST(1 AS DECIMAL(20,0)))' (timestamp and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 917
+SELECT cast('2017-12-11 09:30:00' as date) > cast(1 as decimal(3, 0)) FROM t
+-- !query 917 schema
+struct<>
+-- !query 917 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) > CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) > CAST(1 AS DECIMAL(3,0)))' (date and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 918
+SELECT cast('2017-12-11 09:30:00' as date) > cast(1 as decimal(5, 0)) FROM t
+-- !query 918 schema
+struct<>
+-- !query 918 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) > CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) > CAST(1 AS DECIMAL(5,0)))' (date and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 919
+SELECT cast('2017-12-11 09:30:00' as date) > cast(1 as decimal(10, 0)) FROM t
+-- !query 919 schema
+struct<>
+-- !query 919 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) > CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) > CAST(1 AS DECIMAL(10,0)))' (date and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 920
+SELECT cast('2017-12-11 09:30:00' as date) > cast(1 as decimal(20, 0)) FROM t
+-- !query 920 schema
+struct<>
+-- !query 920 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) > CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) > CAST(1 AS DECIMAL(20,0)))' (date and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 921
+SELECT cast(1 as decimal(3, 0)) > cast(1 as tinyint) FROM t
+-- !query 921 schema
+struct<(CAST(1 AS DECIMAL(3,0)) > CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0))):boolean>
+-- !query 921 output
+false
+
+
+-- !query 922
+SELECT cast(1 as decimal(5, 0)) > cast(1 as tinyint) FROM t
+-- !query 922 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0)) > CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0))):boolean>
+-- !query 922 output
+false
+
+
+-- !query 923
+SELECT cast(1 as decimal(10, 0)) > cast(1 as tinyint) FROM t
+-- !query 923 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) > CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean>
+-- !query 923 output
+false
+
+
+-- !query 924
+SELECT cast(1 as decimal(20, 0)) > cast(1 as tinyint) FROM t
+-- !query 924 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) > CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0))):boolean>
+-- !query 924 output
+false
+
+
+-- !query 925
+SELECT cast(1 as decimal(3, 0)) > cast(1 as smallint) FROM t
+-- !query 925 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0)) > CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0))):boolean>
+-- !query 925 output
+false
+
+
+-- !query 926
+SELECT cast(1 as decimal(5, 0)) > cast(1 as smallint) FROM t
+-- !query 926 schema
+struct<(CAST(1 AS DECIMAL(5,0)) > CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0))):boolean>
+-- !query 926 output
+false
+
+
+-- !query 927
+SELECT cast(1 as decimal(10, 0)) > cast(1 as smallint) FROM t
+-- !query 927 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) > CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean>
+-- !query 927 output
+false
+
+
+-- !query 928
+SELECT cast(1 as decimal(20, 0)) > cast(1 as smallint) FROM t
+-- !query 928 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) > CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0))):boolean>
+-- !query 928 output
+false
+
+
+-- !query 929
+SELECT cast(1 as decimal(3, 0)) > cast(1 as int) FROM t
+-- !query 929 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) > CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 929 output
+false
+
+
+-- !query 930
+SELECT cast(1 as decimal(5, 0)) > cast(1 as int) FROM t
+-- !query 930 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) > CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 930 output
+false
+
+
+-- !query 931
+SELECT cast(1 as decimal(10, 0)) > cast(1 as int) FROM t
+-- !query 931 schema
+struct<(CAST(1 AS DECIMAL(10,0)) > CAST(CAST(1 AS INT) AS DECIMAL(10,0))):boolean>
+-- !query 931 output
+false
+
+
+-- !query 932
+SELECT cast(1 as decimal(20, 0)) > cast(1 as int) FROM t
+-- !query 932 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) > CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean>
+-- !query 932 output
+false
+
+
+-- !query 933
+SELECT cast(1 as decimal(3, 0)) > cast(1 as bigint) FROM t
+-- !query 933 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0)) > CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 933 output
+false
+
+
+-- !query 934
+SELECT cast(1 as decimal(5, 0)) > cast(1 as bigint) FROM t
+-- !query 934 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0)) > CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 934 output
+false
+
+
+-- !query 935
+SELECT cast(1 as decimal(10, 0)) > cast(1 as bigint) FROM t
+-- !query 935 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) > CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 935 output
+false
+
+
+-- !query 936
+SELECT cast(1 as decimal(20, 0)) > cast(1 as bigint) FROM t
+-- !query 936 schema
+struct<(CAST(1 AS DECIMAL(20,0)) > CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):boolean>
+-- !query 936 output
+false
+
+
+-- !query 937
+SELECT cast(1 as decimal(3, 0)) > cast(1 as float) FROM t
+-- !query 937 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) > CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean>
+-- !query 937 output
+false
+
+
+-- !query 938
+SELECT cast(1 as decimal(5, 0)) > cast(1 as float) FROM t
+-- !query 938 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) > CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean>
+-- !query 938 output
+false
+
+
+-- !query 939
+SELECT cast(1 as decimal(10, 0)) > cast(1 as float) FROM t
+-- !query 939 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) > CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean>
+-- !query 939 output
+false
+
+
+-- !query 940
+SELECT cast(1 as decimal(20, 0)) > cast(1 as float) FROM t
+-- !query 940 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) > CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean>
+-- !query 940 output
+false
+
+
+-- !query 941
+SELECT cast(1 as decimal(3, 0)) > cast(1 as double) FROM t
+-- !query 941 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) > CAST(1 AS DOUBLE)):boolean>
+-- !query 941 output
+false
+
+
+-- !query 942
+SELECT cast(1 as decimal(5, 0)) > cast(1 as double) FROM t
+-- !query 942 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) > CAST(1 AS DOUBLE)):boolean>
+-- !query 942 output
+false
+
+
+-- !query 943
+SELECT cast(1 as decimal(10, 0)) > cast(1 as double) FROM t
+-- !query 943 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) > CAST(1 AS DOUBLE)):boolean>
+-- !query 943 output
+false
+
+
+-- !query 944
+SELECT cast(1 as decimal(20, 0)) > cast(1 as double) FROM t
+-- !query 944 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) > CAST(1 AS DOUBLE)):boolean>
+-- !query 944 output
+false
+
+
+-- !query 945
+SELECT cast(1 as decimal(3, 0)) > cast(1 as decimal(10, 0)) FROM t
+-- !query 945 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) > CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 945 output
+false
+
+
+-- !query 946
+SELECT cast(1 as decimal(5, 0)) > cast(1 as decimal(10, 0)) FROM t
+-- !query 946 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) > CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 946 output
+false
+
+
+-- !query 947
+SELECT cast(1 as decimal(10, 0)) > cast(1 as decimal(10, 0)) FROM t
+-- !query 947 schema
+struct<(CAST(1 AS DECIMAL(10,0)) > CAST(1 AS DECIMAL(10,0))):boolean>
+-- !query 947 output
+false
+
+
+-- !query 948
+SELECT cast(1 as decimal(20, 0)) > cast(1 as decimal(10, 0)) FROM t
+-- !query 948 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) > CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean>
+-- !query 948 output
+false
+
+
+-- !query 949
+SELECT cast(1 as decimal(3, 0)) > cast(1 as string) FROM t
+-- !query 949 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) > CAST(CAST(1 AS STRING) AS DOUBLE)):boolean>
+-- !query 949 output
+false
+
+
+-- !query 950
+SELECT cast(1 as decimal(5, 0)) > cast(1 as string) FROM t
+-- !query 950 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) > CAST(CAST(1 AS STRING) AS DOUBLE)):boolean>
+-- !query 950 output
+false
+
+
+-- !query 951
+SELECT cast(1 as decimal(10, 0)) > cast(1 as string) FROM t
+-- !query 951 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) > CAST(CAST(1 AS STRING) AS DOUBLE)):boolean>
+-- !query 951 output
+false
+
+
+-- !query 952
+SELECT cast(1 as decimal(20, 0)) > cast(1 as string) FROM t
+-- !query 952 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) > CAST(CAST(1 AS STRING) AS DOUBLE)):boolean>
+-- !query 952 output
+false
+
+
+-- !query 953
+SELECT cast(1 as decimal(3, 0)) > cast('1' as binary) FROM t
+-- !query 953 schema
+struct<>
+-- !query 953 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) > CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) > CAST('1' AS BINARY))' (decimal(3,0) and binary).; line 1 pos 7
+
+
+-- !query 954
+SELECT cast(1 as decimal(5, 0)) > cast('1' as binary) FROM t
+-- !query 954 schema
+struct<>
+-- !query 954 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) > CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) > CAST('1' AS BINARY))' (decimal(5,0) and binary).; line 1 pos 7
+
+
+-- !query 955
+SELECT cast(1 as decimal(10, 0)) > cast('1' as binary) FROM t
+-- !query 955 schema
+struct<>
+-- !query 955 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) > CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) > CAST('1' AS BINARY))' (decimal(10,0) and binary).; line 1 pos 7
+
+
+-- !query 956
+SELECT cast(1 as decimal(20, 0)) > cast('1' as binary) FROM t
+-- !query 956 schema
+struct<>
+-- !query 956 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) > CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) > CAST('1' AS BINARY))' (decimal(20,0) and binary).; line 1 pos 7
+
+
+-- !query 957
+SELECT cast(1 as decimal(3, 0)) > cast(1 as boolean) FROM t
+-- !query 957 schema
+struct<>
+-- !query 957 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) > CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) > CAST(1 AS BOOLEAN))' (decimal(3,0) and boolean).; line 1 pos 7
+
+
+-- !query 958
+SELECT cast(1 as decimal(5, 0)) > cast(1 as boolean) FROM t
+-- !query 958 schema
+struct<>
+-- !query 958 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) > CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) > CAST(1 AS BOOLEAN))' (decimal(5,0) and boolean).; line 1 pos 7
+
+
+-- !query 959
+SELECT cast(1 as decimal(10, 0)) > cast(1 as boolean) FROM t
+-- !query 959 schema
+struct<>
+-- !query 959 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) > CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) > CAST(1 AS BOOLEAN))' (decimal(10,0) and boolean).; line 1 pos 7
+
+
+-- !query 960
+SELECT cast(1 as decimal(20, 0)) > cast(1 as boolean) FROM t
+-- !query 960 schema
+struct<>
+-- !query 960 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) > CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) > CAST(1 AS BOOLEAN))' (decimal(20,0) and boolean).; line 1 pos 7
+
+
+-- !query 961
+SELECT cast(1 as decimal(3, 0)) > cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 961 schema
+struct<>
+-- !query 961 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) > CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) > CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(3,0) and timestamp).; line 1 pos 7
+
+
+-- !query 962
+SELECT cast(1 as decimal(5, 0)) > cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 962 schema
+struct<>
+-- !query 962 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) > CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) > CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(5,0) and timestamp).; line 1 pos 7
+
+
+-- !query 963
+SELECT cast(1 as decimal(10, 0)) > cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 963 schema
+struct<>
+-- !query 963 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) > CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) > CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(10,0) and timestamp).; line 1 pos 7
+
+
+-- !query 964
+SELECT cast(1 as decimal(20, 0)) > cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 964 schema
+struct<>
+-- !query 964 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) > CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) > CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(20,0) and timestamp).; line 1 pos 7
+
+
+-- !query 965
+SELECT cast(1 as decimal(3, 0)) > cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 965 schema
+struct<>
+-- !query 965 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) > CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) > CAST('2017-12-11 09:30:00' AS DATE))' (decimal(3,0) and date).; line 1 pos 7
+
+
+-- !query 966
+SELECT cast(1 as decimal(5, 0)) > cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 966 schema
+struct<>
+-- !query 966 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) > CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) > CAST('2017-12-11 09:30:00' AS DATE))' (decimal(5,0) and date).; line 1 pos 7
+
+
+-- !query 967
+SELECT cast(1 as decimal(10, 0)) > cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 967 schema
+struct<>
+-- !query 967 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) > CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) > CAST('2017-12-11 09:30:00' AS DATE))' (decimal(10,0) and date).; line 1 pos 7
+
+
+-- !query 968
+SELECT cast(1 as decimal(20, 0)) > cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 968 schema
+struct<>
+-- !query 968 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) > CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) > CAST('2017-12-11 09:30:00' AS DATE))' (decimal(20,0) and date).; line 1 pos 7
+
+
+-- !query 969
+SELECT cast(1 as tinyint) >= cast(1 as decimal(3, 0)) FROM t
+-- !query 969 schema
+struct<(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) >= CAST(1 AS DECIMAL(3,0))):boolean>
+-- !query 969 output
+true
+
+
+-- !query 970
+SELECT cast(1 as tinyint) >= cast(1 as decimal(5, 0)) FROM t
+-- !query 970 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0)) >= CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0))):boolean>
+-- !query 970 output
+true
+
+
+-- !query 971
+SELECT cast(1 as tinyint) >= cast(1 as decimal(10, 0)) FROM t
+-- !query 971 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0)) >= CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 971 output
+true
+
+
+-- !query 972
+SELECT cast(1 as tinyint) >= cast(1 as decimal(20, 0)) FROM t
+-- !query 972 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0)) >= CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 972 output
+true
+
+
+-- !query 973
+SELECT cast(1 as smallint) >= cast(1 as decimal(3, 0)) FROM t
+-- !query 973 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0)) >= CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0))):boolean>
+-- !query 973 output
+true
+
+
+-- !query 974
+SELECT cast(1 as smallint) >= cast(1 as decimal(5, 0)) FROM t
+-- !query 974 schema
+struct<(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) >= CAST(1 AS DECIMAL(5,0))):boolean>
+-- !query 974 output
+true
+
+
+-- !query 975
+SELECT cast(1 as smallint) >= cast(1 as decimal(10, 0)) FROM t
+-- !query 975 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0)) >= CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 975 output
+true
+
+
+-- !query 976
+SELECT cast(1 as smallint) >= cast(1 as decimal(20, 0)) FROM t
+-- !query 976 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0)) >= CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 976 output
+true
+
+
+-- !query 977
+SELECT cast(1 as int) >= cast(1 as decimal(3, 0)) FROM t
+-- !query 977 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) >= CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean>
+-- !query 977 output
+true
+
+
+-- !query 978
+SELECT cast(1 as int) >= cast(1 as decimal(5, 0)) FROM t
+-- !query 978 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) >= CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean>
+-- !query 978 output
+true
+
+
+-- !query 979
+SELECT cast(1 as int) >= cast(1 as decimal(10, 0)) FROM t
+-- !query 979 schema
+struct<(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) >= CAST(1 AS DECIMAL(10,0))):boolean>
+-- !query 979 output
+true
+
+
+-- !query 980
+SELECT cast(1 as int) >= cast(1 as decimal(20, 0)) FROM t
+-- !query 980 schema
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0)) >= CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 980 output
+true
+
+
+-- !query 981
+SELECT cast(1 as bigint) >= cast(1 as decimal(3, 0)) FROM t
+-- !query 981 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) >= CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0))):boolean>
+-- !query 981 output
+true
+
+
+-- !query 982
+SELECT cast(1 as bigint) >= cast(1 as decimal(5, 0)) FROM t
+-- !query 982 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) >= CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0))):boolean>
+-- !query 982 output
+true
+
+
+-- !query 983
+SELECT cast(1 as bigint) >= cast(1 as decimal(10, 0)) FROM t
+-- !query 983 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) >= CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean>
+-- !query 983 output
+true
+
+
+-- !query 984
+SELECT cast(1 as bigint) >= cast(1 as decimal(20, 0)) FROM t
+-- !query 984 schema
+struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) >= CAST(1 AS DECIMAL(20,0))):boolean>
+-- !query 984 output
+true
+
+
+-- !query 985
+SELECT cast(1 as float) >= cast(1 as decimal(3, 0)) FROM t
+-- !query 985 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) >= CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):boolean>
+-- !query 985 output
+true
+
+
+-- !query 986
+SELECT cast(1 as float) >= cast(1 as decimal(5, 0)) FROM t
+-- !query 986 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) >= CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):boolean>
+-- !query 986 output
+true
+
+
+-- !query 987
+SELECT cast(1 as float) >= cast(1 as decimal(10, 0)) FROM t
+-- !query 987 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) >= CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean>
+-- !query 987 output
+true
+
+
+-- !query 988
+SELECT cast(1 as float) >= cast(1 as decimal(20, 0)) FROM t
+-- !query 988 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) >= CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):boolean>
+-- !query 988 output
+true
+
+
+-- !query 989
+SELECT cast(1 as double) >= cast(1 as decimal(3, 0)) FROM t
+-- !query 989 schema
+struct<(CAST(1 AS DOUBLE) >= CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):boolean>
+-- !query 989 output
+true
+
+
+-- !query 990
+SELECT cast(1 as double) >= cast(1 as decimal(5, 0)) FROM t
+-- !query 990 schema
+struct<(CAST(1 AS DOUBLE) >= CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):boolean>
+-- !query 990 output
+true
+
+
+-- !query 991
+SELECT cast(1 as double) >= cast(1 as decimal(10, 0)) FROM t
+-- !query 991 schema
+struct<(CAST(1 AS DOUBLE) >= CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean>
+-- !query 991 output
+true
+
+
+-- !query 992
+SELECT cast(1 as double) >= cast(1 as decimal(20, 0)) FROM t
+-- !query 992 schema
+struct<(CAST(1 AS DOUBLE) >= CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):boolean>
+-- !query 992 output
+true
+
+
+-- !query 993
+SELECT cast(1 as decimal(10, 0)) >= cast(1 as decimal(3, 0)) FROM t
+-- !query 993 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) >= CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean>
+-- !query 993 output
+true
+
+
+-- !query 994
+SELECT cast(1 as decimal(10, 0)) >= cast(1 as decimal(5, 0)) FROM t
+-- !query 994 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) >= CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean>
+-- !query 994 output
+true
+
+
+-- !query 995
+SELECT cast(1 as decimal(10, 0)) >= cast(1 as decimal(10, 0)) FROM t
+-- !query 995 schema
+struct<(CAST(1 AS DECIMAL(10,0)) >= CAST(1 AS DECIMAL(10,0))):boolean>
+-- !query 995 output
+true
+
+
+-- !query 996
+SELECT cast(1 as decimal(10, 0)) >= cast(1 as decimal(20, 0)) FROM t
+-- !query 996 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) >= CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 996 output
+true
+
+
+-- !query 997
+SELECT cast('1' as binary) >= cast(1 as decimal(3, 0)) FROM t
+-- !query 997 schema
+struct<>
+-- !query 997 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) >= CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) >= CAST(1 AS DECIMAL(3,0)))' (binary and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 998
+SELECT cast('1' as binary) >= cast(1 as decimal(5, 0)) FROM t
+-- !query 998 schema
+struct<>
+-- !query 998 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) >= CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) >= CAST(1 AS DECIMAL(5,0)))' (binary and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 999
+SELECT cast('1' as binary) >= cast(1 as decimal(10, 0)) FROM t
+-- !query 999 schema
+struct<>
+-- !query 999 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) >= CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) >= CAST(1 AS DECIMAL(10,0)))' (binary and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 1000
+SELECT cast('1' as binary) >= cast(1 as decimal(20, 0)) FROM t
+-- !query 1000 schema
+struct<>
+-- !query 1000 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) >= CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) >= CAST(1 AS DECIMAL(20,0)))' (binary and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 1001
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) >= cast(1 as decimal(3, 0)) FROM t
+-- !query 1001 schema
+struct<>
+-- !query 1001 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) >= CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) >= CAST(1 AS DECIMAL(3,0)))' (timestamp and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 1002
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) >= cast(1 as decimal(5, 0)) FROM t
+-- !query 1002 schema
+struct<>
+-- !query 1002 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) >= CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) >= CAST(1 AS DECIMAL(5,0)))' (timestamp and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 1003
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) >= cast(1 as decimal(10, 0)) FROM t
+-- !query 1003 schema
+struct<>
+-- !query 1003 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) >= CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) >= CAST(1 AS DECIMAL(10,0)))' (timestamp and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 1004
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) >= cast(1 as decimal(20, 0)) FROM t
+-- !query 1004 schema
+struct<>
+-- !query 1004 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) >= CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) >= CAST(1 AS DECIMAL(20,0)))' (timestamp and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 1005
+SELECT cast('2017-12-11 09:30:00' as date) >= cast(1 as decimal(3, 0)) FROM t
+-- !query 1005 schema
+struct<>
+-- !query 1005 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) >= CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) >= CAST(1 AS DECIMAL(3,0)))' (date and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 1006
+SELECT cast('2017-12-11 09:30:00' as date) >= cast(1 as decimal(5, 0)) FROM t
+-- !query 1006 schema
+struct<>
+-- !query 1006 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) >= CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) >= CAST(1 AS DECIMAL(5,0)))' (date and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 1007
+SELECT cast('2017-12-11 09:30:00' as date) >= cast(1 as decimal(10, 0)) FROM t
+-- !query 1007 schema
+struct<>
+-- !query 1007 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) >= CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) >= CAST(1 AS DECIMAL(10,0)))' (date and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 1008
+SELECT cast('2017-12-11 09:30:00' as date) >= cast(1 as decimal(20, 0)) FROM t
+-- !query 1008 schema
+struct<>
+-- !query 1008 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) >= CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) >= CAST(1 AS DECIMAL(20,0)))' (date and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 1009
+SELECT cast(1 as decimal(3, 0)) >= cast(1 as tinyint) FROM t
+-- !query 1009 schema
+struct<(CAST(1 AS DECIMAL(3,0)) >= CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0))):boolean>
+-- !query 1009 output
+true
+
+
+-- !query 1010
+SELECT cast(1 as decimal(5, 0)) >= cast(1 as tinyint) FROM t
+-- !query 1010 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0)) >= CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0))):boolean>
+-- !query 1010 output
+true
+
+
+-- !query 1011
+SELECT cast(1 as decimal(10, 0)) >= cast(1 as tinyint) FROM t
+-- !query 1011 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) >= CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean>
+-- !query 1011 output
+true
+
+
+-- !query 1012
+SELECT cast(1 as decimal(20, 0)) >= cast(1 as tinyint) FROM t
+-- !query 1012 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) >= CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0))):boolean>
+-- !query 1012 output
+true
+
+
+-- !query 1013
+SELECT cast(1 as decimal(3, 0)) >= cast(1 as smallint) FROM t
+-- !query 1013 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0)) >= CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0))):boolean>
+-- !query 1013 output
+true
+
+
+-- !query 1014
+SELECT cast(1 as decimal(5, 0)) >= cast(1 as smallint) FROM t
+-- !query 1014 schema
+struct<(CAST(1 AS DECIMAL(5,0)) >= CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0))):boolean>
+-- !query 1014 output
+true
+
+
+-- !query 1015
+SELECT cast(1 as decimal(10, 0)) >= cast(1 as smallint) FROM t
+-- !query 1015 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) >= CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean>
+-- !query 1015 output
+true
+
+
+-- !query 1016
+SELECT cast(1 as decimal(20, 0)) >= cast(1 as smallint) FROM t
+-- !query 1016 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) >= CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0))):boolean>
+-- !query 1016 output
+true
+
+
+-- !query 1017
+SELECT cast(1 as decimal(3, 0)) >= cast(1 as int) FROM t
+-- !query 1017 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) >= CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 1017 output
+true
+
+
+-- !query 1018
+SELECT cast(1 as decimal(5, 0)) >= cast(1 as int) FROM t
+-- !query 1018 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) >= CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 1018 output
+true
+
+
+-- !query 1019
+SELECT cast(1 as decimal(10, 0)) >= cast(1 as int) FROM t
+-- !query 1019 schema
+struct<(CAST(1 AS DECIMAL(10,0)) >= CAST(CAST(1 AS INT) AS DECIMAL(10,0))):boolean>
+-- !query 1019 output
+true
+
+
+-- !query 1020
+SELECT cast(1 as decimal(20, 0)) >= cast(1 as int) FROM t
+-- !query 1020 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) >= CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean>
+-- !query 1020 output
+true
+
+
+-- !query 1021
+SELECT cast(1 as decimal(3, 0)) >= cast(1 as bigint) FROM t
+-- !query 1021 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0)) >= CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 1021 output
+true
+
+
+-- !query 1022
+SELECT cast(1 as decimal(5, 0)) >= cast(1 as bigint) FROM t
+-- !query 1022 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0)) >= CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 1022 output
+true
+
+
+-- !query 1023
+SELECT cast(1 as decimal(10, 0)) >= cast(1 as bigint) FROM t
+-- !query 1023 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) >= CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean>
+-- !query 1023 output
+true
+
+
+-- !query 1024
+SELECT cast(1 as decimal(20, 0)) >= cast(1 as bigint) FROM t
+-- !query 1024 schema
+struct<(CAST(1 AS DECIMAL(20,0)) >= CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):boolean>
+-- !query 1024 output
+true
+
+
+-- !query 1025
+SELECT cast(1 as decimal(3, 0)) >= cast(1 as float) FROM t
+-- !query 1025 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) >= CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean>
+-- !query 1025 output
+true
+
+
+-- !query 1026
+SELECT cast(1 as decimal(5, 0)) >= cast(1 as float) FROM t
+-- !query 1026 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) >= CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean>
+-- !query 1026 output
+true
+
+
+-- !query 1027
+SELECT cast(1 as decimal(10, 0)) >= cast(1 as float) FROM t
+-- !query 1027 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) >= CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean>
+-- !query 1027 output
+true
+
+
+-- !query 1028
+SELECT cast(1 as decimal(20, 0)) >= cast(1 as float) FROM t
+-- !query 1028 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) >= CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean>
+-- !query 1028 output
+true
+
+
+-- !query 1029
+SELECT cast(1 as decimal(3, 0)) >= cast(1 as double) FROM t
+-- !query 1029 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) >= CAST(1 AS DOUBLE)):boolean>
+-- !query 1029 output
+true
+
+
+-- !query 1030
+SELECT cast(1 as decimal(5, 0)) >= cast(1 as double) FROM t
+-- !query 1030 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) >= CAST(1 AS DOUBLE)):boolean>
+-- !query 1030 output
+true
+
+
+-- !query 1031
+SELECT cast(1 as decimal(10, 0)) >= cast(1 as double) FROM t
+-- !query 1031 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) >= CAST(1 AS DOUBLE)):boolean>
+-- !query 1031 output
+true
+
+
+-- !query 1032
+SELECT cast(1 as decimal(20, 0)) >= cast(1 as double) FROM t
+-- !query 1032 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) >= CAST(1 AS DOUBLE)):boolean>
+-- !query 1032 output
+true
+
+
+-- !query 1033
+SELECT cast(1 as decimal(3, 0)) >= cast(1 as decimal(10, 0)) FROM t
+-- !query 1033 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) >= CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 1033 output
+true
+
+
+-- !query 1034
+SELECT cast(1 as decimal(5, 0)) >= cast(1 as decimal(10, 0)) FROM t
+-- !query 1034 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) >= CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean>
+-- !query 1034 output
+true
+
+
+-- !query 1035
+SELECT cast(1 as decimal(10, 0)) >= cast(1 as decimal(10, 0)) FROM t
+-- !query 1035 schema
+struct<(CAST(1 AS DECIMAL(10,0)) >= CAST(1 AS DECIMAL(10,0))):boolean>
+-- !query 1035 output
+true
+
+
+-- !query 1036
+SELECT cast(1 as decimal(20, 0)) >= cast(1 as decimal(10, 0)) FROM t
+-- !query 1036 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) >= CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean>
+-- !query 1036 output
+true
+
+
+-- !query 1037
+SELECT cast(1 as decimal(3, 0)) >= cast(1 as string) FROM t
+-- !query 1037 schema
+struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) >= CAST(CAST(1 AS STRING) AS DOUBLE)):boolean>
+-- !query 1037 output
+true
+
+
+-- !query 1038
+SELECT cast(1 as decimal(5, 0)) >= cast(1 as string) FROM t
+-- !query 1038 schema
+struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) >= CAST(CAST(1 AS STRING) AS DOUBLE)):boolean>
+-- !query 1038 output
+true
+
+
+-- !query 1039
+SELECT cast(1 as decimal(10, 0)) >= cast(1 as string) FROM t
+-- !query 1039 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) >= CAST(CAST(1 AS STRING) AS DOUBLE)):boolean>
+-- !query 1039 output
+true
+
+
+-- !query 1040
+SELECT cast(1 as decimal(20, 0)) >= cast(1 as string) FROM t
+-- !query 1040 schema
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) >= CAST(CAST(1 AS STRING) AS DOUBLE)):boolean>
+-- !query 1040 output
+true
+
+
+-- !query 1041
+SELECT cast(1 as decimal(3, 0)) >= cast('1' as binary) FROM t
+-- !query 1041 schema
+struct<>
+-- !query 1041 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) >= CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) >= CAST('1' AS BINARY))' (decimal(3,0) and binary).; line 1 pos 7
+
+
+-- !query 1042
+SELECT cast(1 as decimal(5, 0)) >= cast('1' as binary) FROM t
+-- !query 1042 schema
+struct<>
+-- !query 1042 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) >= CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) >= CAST('1' AS BINARY))' (decimal(5,0) and binary).; line 1 pos 7
+
+
+-- !query 1043
+SELECT cast(1 as decimal(10, 0)) >= cast('1' as binary) FROM t
+-- !query 1043 schema
+struct<>
+-- !query 1043 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) >= CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) >= CAST('1' AS BINARY))' (decimal(10,0) and binary).; line 1 pos 7
+
+
+-- !query 1044
+SELECT cast(1 as decimal(20, 0)) >= cast('1' as binary) FROM t
+-- !query 1044 schema
+struct<>
+-- !query 1044 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) >= CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) >= CAST('1' AS BINARY))' (decimal(20,0) and binary).; line 1 pos 7
+
+
+-- !query 1045
+SELECT cast(1 as decimal(3, 0)) >= cast(1 as boolean) FROM t
+-- !query 1045 schema
+struct<>
+-- !query 1045 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) >= CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) >= CAST(1 AS BOOLEAN))' (decimal(3,0) and boolean).; line 1 pos 7
+
+
+-- !query 1046
+SELECT cast(1 as decimal(5, 0)) >= cast(1 as boolean) FROM t
+-- !query 1046 schema
+struct<>
+-- !query 1046 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) >= CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) >= CAST(1 AS BOOLEAN))' (decimal(5,0) and boolean).; line 1 pos 7
+
+
+-- !query 1047
+SELECT cast(1 as decimal(10, 0)) >= cast(1 as boolean) FROM t
+-- !query 1047 schema
+struct<>
+-- !query 1047 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) >= CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) >= CAST(1 AS BOOLEAN))' (decimal(10,0) and boolean).; line 1 pos 7
+
+
+-- !query 1048
+SELECT cast(1 as decimal(20, 0)) >= cast(1 as boolean) FROM t
+-- !query 1048 schema
+struct<>
+-- !query 1048 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) >= CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) >= CAST(1 AS BOOLEAN))' (decimal(20,0) and boolean).; line 1 pos 7
+
+
+-- !query 1049
+SELECT cast(1 as decimal(3, 0)) >= cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 1049 schema
+struct<>
+-- !query 1049 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) >= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) >= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(3,0) and timestamp).; line 1 pos 7
+
+
+-- !query 1050
+SELECT cast(1 as decimal(5, 0)) >= cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 1050 schema
+struct<>
+-- !query 1050 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) >= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) >= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(5,0) and timestamp).; line 1 pos 7
+
+
+-- !query 1051
+SELECT cast(1 as decimal(10, 0)) >= cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 1051 schema
+struct<>
+-- !query 1051 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) >= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) >= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(10,0) and timestamp).; line 1 pos 7
+
+
+-- !query 1052
+SELECT cast(1 as decimal(20, 0)) >= cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 1052 schema
+struct<>
+-- !query 1052 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) >= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) >= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(20,0) and timestamp).; line 1 pos 7
+
+
+-- !query 1053
+SELECT cast(1 as decimal(3, 0)) >= cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 1053 schema
+struct<>
+-- !query 1053 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) >= CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) >= CAST('2017-12-11 09:30:00' AS DATE))' (decimal(3,0) and date).; line 1 pos 7
+
+
+-- !query 1054
+SELECT cast(1 as decimal(5, 0)) >= cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 1054 schema
+struct<>
+-- !query 1054 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) >= CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) >= CAST('2017-12-11 09:30:00' AS DATE))' (decimal(5,0) and date).; line 1 pos 7
+
+
+-- !query 1055
+SELECT cast(1 as decimal(10, 0)) >= cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 1055 schema
+struct<>
+-- !query 1055 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) >= CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) >= CAST('2017-12-11 09:30:00' AS DATE))' (decimal(10,0) and date).; line 1 pos 7
+
+
+-- !query 1056
+SELECT cast(1 as decimal(20, 0)) >= cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 1056 schema
+struct<>
+-- !query 1056 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) >= CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) >= CAST('2017-12-11 09:30:00' AS DATE))' (decimal(20,0) and date).; line 1 pos 7
+
+
+-- !query 1057
+SELECT cast(1 as tinyint) <> cast(1 as decimal(3, 0)) FROM t
+-- !query 1057 schema
+struct<(NOT (CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) = CAST(1 AS DECIMAL(3,0)))):boolean>
+-- !query 1057 output
+false
+
+
+-- !query 1058
+SELECT cast(1 as tinyint) <> cast(1 as decimal(5, 0)) FROM t
+-- !query 1058 schema
+struct<(NOT (CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0)) = CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0)))):boolean>
+-- !query 1058 output
+false
+
+
+-- !query 1059
+SELECT cast(1 as tinyint) <> cast(1 as decimal(10, 0)) FROM t
+-- !query 1059 schema
+struct<(NOT (CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)))):boolean>
+-- !query 1059 output
+false
+
+
+-- !query 1060
+SELECT cast(1 as tinyint) <> cast(1 as decimal(20, 0)) FROM t
+-- !query 1060 schema
+struct<(NOT (CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)))):boolean>
+-- !query 1060 output
+false
+
+
+-- !query 1061
+SELECT cast(1 as smallint) <> cast(1 as decimal(3, 0)) FROM t
+-- !query 1061 schema
+struct<(NOT (CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0)) = CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0)))):boolean>
+-- !query 1061 output
+false
+
+
+-- !query 1062
+SELECT cast(1 as smallint) <> cast(1 as decimal(5, 0)) FROM t
+-- !query 1062 schema
+struct<(NOT (CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) = CAST(1 AS DECIMAL(5,0)))):boolean>
+-- !query 1062 output
+false
+
+
+-- !query 1063
+SELECT cast(1 as smallint) <> cast(1 as decimal(10, 0)) FROM t
+-- !query 1063 schema
+struct<(NOT (CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)))):boolean>
+-- !query 1063 output
+false
+
+
+-- !query 1064
+SELECT cast(1 as smallint) <> cast(1 as decimal(20, 0)) FROM t
+-- !query 1064 schema
+struct<(NOT (CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)))):boolean>
+-- !query 1064 output
+false
+
+
+-- !query 1065
+SELECT cast(1 as int) <> cast(1 as decimal(3, 0)) FROM t
+-- !query 1065 schema
+struct<(NOT (CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)))):boolean>
+-- !query 1065 output
+false
+
+
+-- !query 1066
+SELECT cast(1 as int) <> cast(1 as decimal(5, 0)) FROM t
+-- !query 1066 schema
+struct<(NOT (CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)))):boolean>
+-- !query 1066 output
+false
+
+
+-- !query 1067
+SELECT cast(1 as int) <> cast(1 as decimal(10, 0)) FROM t
+-- !query 1067 schema
+struct<(NOT (CAST(CAST(1 AS INT) AS DECIMAL(10,0)) = CAST(1 AS DECIMAL(10,0)))):boolean>
+-- !query 1067 output
+false
+
+
+-- !query 1068
+SELECT cast(1 as int) <> cast(1 as decimal(20, 0)) FROM t
+-- !query 1068 schema
+struct<(NOT (CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)))):boolean>
+-- !query 1068 output
+false
+
+
+-- !query 1069
+SELECT cast(1 as bigint) <> cast(1 as decimal(3, 0)) FROM t
+-- !query 1069 schema
+struct<(NOT (CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0)))):boolean>
+-- !query 1069 output
+false
+
+
+-- !query 1070
+SELECT cast(1 as bigint) <> cast(1 as decimal(5, 0)) FROM t
+-- !query 1070 schema
+struct<(NOT (CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0)))):boolean>
+-- !query 1070 output
+false
+
+
+-- !query 1071
+SELECT cast(1 as bigint) <> cast(1 as decimal(10, 0)) FROM t
+-- !query 1071 schema
+struct<(NOT (CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)))):boolean>
+-- !query 1071 output
+false
+
+
+-- !query 1072
+SELECT cast(1 as bigint) <> cast(1 as decimal(20, 0)) FROM t
+-- !query 1072 schema
+struct<(NOT (CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) = CAST(1 AS DECIMAL(20,0)))):boolean>
+-- !query 1072 output
+false
+
+
+-- !query 1073
+SELECT cast(1 as float) <> cast(1 as decimal(3, 0)) FROM t
+-- !query 1073 schema
+struct<(NOT (CAST(CAST(1 AS FLOAT) AS DOUBLE) = CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE))):boolean>
+-- !query 1073 output
+false
+
+
+-- !query 1074
+SELECT cast(1 as float) <> cast(1 as decimal(5, 0)) FROM t
+-- !query 1074 schema
+struct<(NOT (CAST(CAST(1 AS FLOAT) AS DOUBLE) = CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE))):boolean>
+-- !query 1074 output
+false
+
+
+-- !query 1075
+SELECT cast(1 as float) <> cast(1 as decimal(10, 0)) FROM t
+-- !query 1075 schema
+struct<(NOT (CAST(CAST(1 AS FLOAT) AS DOUBLE) = CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE))):boolean>
+-- !query 1075 output
+false
+
+
+-- !query 1076
+SELECT cast(1 as float) <> cast(1 as decimal(20, 0)) FROM t
+-- !query 1076 schema
+struct<(NOT (CAST(CAST(1 AS FLOAT) AS DOUBLE) = CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE))):boolean>
+-- !query 1076 output
+false
+
+
+-- !query 1077
+SELECT cast(1 as double) <> cast(1 as decimal(3, 0)) FROM t
+-- !query 1077 schema
+struct<(NOT (CAST(1 AS DOUBLE) = CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE))):boolean>
+-- !query 1077 output
+false
+
+
+-- !query 1078
+SELECT cast(1 as double) <> cast(1 as decimal(5, 0)) FROM t
+-- !query 1078 schema
+struct<(NOT (CAST(1 AS DOUBLE) = CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE))):boolean>
+-- !query 1078 output
+false
+
+
+-- !query 1079
+SELECT cast(1 as double) <> cast(1 as decimal(10, 0)) FROM t
+-- !query 1079 schema
+struct<(NOT (CAST(1 AS DOUBLE) = CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE))):boolean>
+-- !query 1079 output
+false
+
+
+-- !query 1080
+SELECT cast(1 as double) <> cast(1 as decimal(20, 0)) FROM t
+-- !query 1080 schema
+struct<(NOT (CAST(1 AS DOUBLE) = CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE))):boolean>
+-- !query 1080 output
+false
+
+
+-- !query 1081
+SELECT cast(1 as decimal(10, 0)) <> cast(1 as decimal(3, 0)) FROM t
+-- !query 1081 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)))):boolean>
+-- !query 1081 output
+false
+
+
+-- !query 1082
+SELECT cast(1 as decimal(10, 0)) <> cast(1 as decimal(5, 0)) FROM t
+-- !query 1082 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)))):boolean>
+-- !query 1082 output
+false
+
+
+-- !query 1083
+SELECT cast(1 as decimal(10, 0)) <> cast(1 as decimal(10, 0)) FROM t
+-- !query 1083 schema
+struct<(NOT (CAST(1 AS DECIMAL(10,0)) = CAST(1 AS DECIMAL(10,0)))):boolean>
+-- !query 1083 output
+false
+
+
+-- !query 1084
+SELECT cast(1 as decimal(10, 0)) <> cast(1 as decimal(20, 0)) FROM t
+-- !query 1084 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)))):boolean>
+-- !query 1084 output
+false
+
+
+-- !query 1085
+SELECT cast('1' as binary) <> cast(1 as decimal(3, 0)) FROM t
+-- !query 1085 schema
+struct<>
+-- !query 1085 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(3,0)))' (binary and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 1086
+SELECT cast('1' as binary) <> cast(1 as decimal(5, 0)) FROM t
+-- !query 1086 schema
+struct<>
+-- !query 1086 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(5,0)))' (binary and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 1087
+SELECT cast('1' as binary) <> cast(1 as decimal(10, 0)) FROM t
+-- !query 1087 schema
+struct<>
+-- !query 1087 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(10,0)))' (binary and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 1088
+SELECT cast('1' as binary) <> cast(1 as decimal(20, 0)) FROM t
+-- !query 1088 schema
+struct<>
+-- !query 1088 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(20,0)))' (binary and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 1089
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <> cast(1 as decimal(3, 0)) FROM t
+-- !query 1089 schema
+struct<>
+-- !query 1089 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(3,0)))' (timestamp and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 1090
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <> cast(1 as decimal(5, 0)) FROM t
+-- !query 1090 schema
+struct<>
+-- !query 1090 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(5,0)))' (timestamp and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 1091
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <> cast(1 as decimal(10, 0)) FROM t
+-- !query 1091 schema
+struct<>
+-- !query 1091 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(10,0)))' (timestamp and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 1092
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) <> cast(1 as decimal(20, 0)) FROM t
+-- !query 1092 schema
+struct<>
+-- !query 1092 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(20,0)))' (timestamp and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 1093
+SELECT cast('2017-12-11 09:30:00' as date) <> cast(1 as decimal(3, 0)) FROM t
+-- !query 1093 schema
+struct<>
+-- !query 1093 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(3,0)))' (date and decimal(3,0)).; line 1 pos 7
+
+
+-- !query 1094
+SELECT cast('2017-12-11 09:30:00' as date) <> cast(1 as decimal(5, 0)) FROM t
+-- !query 1094 schema
+struct<>
+-- !query 1094 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(5,0)))' (date and decimal(5,0)).; line 1 pos 7
+
+
+-- !query 1095
+SELECT cast('2017-12-11 09:30:00' as date) <> cast(1 as decimal(10, 0)) FROM t
+-- !query 1095 schema
+struct<>
+-- !query 1095 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(10,0)))' (date and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 1096
+SELECT cast('2017-12-11 09:30:00' as date) <> cast(1 as decimal(20, 0)) FROM t
+-- !query 1096 schema
+struct<>
+-- !query 1096 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(20,0)))' (date and decimal(20,0)).; line 1 pos 7
+
+
+-- !query 1097
+SELECT cast(1 as decimal(3, 0)) <> cast(1 as tinyint) FROM t
+-- !query 1097 schema
+struct<(NOT (CAST(1 AS DECIMAL(3,0)) = CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)))):boolean>
+-- !query 1097 output
+false
+
+
+-- !query 1098
+SELECT cast(1 as decimal(5, 0)) <> cast(1 as tinyint) FROM t
+-- !query 1098 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0)) = CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0)))):boolean>
+-- !query 1098 output
+false
+
+
+-- !query 1099
+SELECT cast(1 as decimal(10, 0)) <> cast(1 as tinyint) FROM t
+-- !query 1099 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) = CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0)))):boolean>
+-- !query 1099 output
+false
+
+
+-- !query 1100
+SELECT cast(1 as decimal(20, 0)) <> cast(1 as tinyint) FROM t
+-- !query 1100 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) = CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0)))):boolean>
+-- !query 1100 output
+false
+
+
+-- !query 1101
+SELECT cast(1 as decimal(3, 0)) <> cast(1 as smallint) FROM t
+-- !query 1101 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0)) = CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0)))):boolean>
+-- !query 1101 output
+false
+
+
+-- !query 1102
+SELECT cast(1 as decimal(5, 0)) <> cast(1 as smallint) FROM t
+-- !query 1102 schema
+struct<(NOT (CAST(1 AS DECIMAL(5,0)) = CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)))):boolean>
+-- !query 1102 output
+false
+
+
+-- !query 1103
+SELECT cast(1 as decimal(10, 0)) <> cast(1 as smallint) FROM t
+-- !query 1103 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) = CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0)))):boolean>
+-- !query 1103 output
+false
+
+
+-- !query 1104
+SELECT cast(1 as decimal(20, 0)) <> cast(1 as smallint) FROM t
+-- !query 1104 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) = CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0)))):boolean>
+-- !query 1104 output
+false
+
+
+-- !query 1105
+SELECT cast(1 as decimal(3, 0)) <> cast(1 as int) FROM t
+-- !query 1105 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) = CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)))):boolean>
+-- !query 1105 output
+false
+
+
+-- !query 1106
+SELECT cast(1 as decimal(5, 0)) <> cast(1 as int) FROM t
+-- !query 1106 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) = CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)))):boolean>
+-- !query 1106 output
+false
+
+
+-- !query 1107
+SELECT cast(1 as decimal(10, 0)) <> cast(1 as int) FROM t
+-- !query 1107 schema
+struct<(NOT (CAST(1 AS DECIMAL(10,0)) = CAST(CAST(1 AS INT) AS DECIMAL(10,0)))):boolean>
+-- !query 1107 output
+false
+
+
+-- !query 1108
+SELECT cast(1 as decimal(20, 0)) <> cast(1 as int) FROM t
+-- !query 1108 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) = CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0)))):boolean>
+-- !query 1108 output
+false
+
+
+-- !query 1109
+SELECT cast(1 as decimal(3, 0)) <> cast(1 as bigint) FROM t
+-- !query 1109 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0)) = CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)))):boolean>
+-- !query 1109 output
+false
+
+
+-- !query 1110
+SELECT cast(1 as decimal(5, 0)) <> cast(1 as bigint) FROM t
+-- !query 1110 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0)) = CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)))):boolean>
+-- !query 1110 output
+false
+
+
+-- !query 1111
+SELECT cast(1 as decimal(10, 0)) <> cast(1 as bigint) FROM t
+-- !query 1111 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) = CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)))):boolean>
+-- !query 1111 output
+false
+
+
+-- !query 1112
+SELECT cast(1 as decimal(20, 0)) <> cast(1 as bigint) FROM t
+-- !query 1112 schema
+struct<(NOT (CAST(1 AS DECIMAL(20,0)) = CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)))):boolean>
+-- !query 1112 output
+false
+
+
+-- !query 1113
+SELECT cast(1 as decimal(3, 0)) <> cast(1 as float) FROM t
+-- !query 1113 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) = CAST(CAST(1 AS FLOAT) AS DOUBLE))):boolean>
+-- !query 1113 output
+false
+
+
+-- !query 1114
+SELECT cast(1 as decimal(5, 0)) <> cast(1 as float) FROM t
+-- !query 1114 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) = CAST(CAST(1 AS FLOAT) AS DOUBLE))):boolean>
+-- !query 1114 output
+false
+
+
+-- !query 1115
+SELECT cast(1 as decimal(10, 0)) <> cast(1 as float) FROM t
+-- !query 1115 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) = CAST(CAST(1 AS FLOAT) AS DOUBLE))):boolean>
+-- !query 1115 output
+false
+
+
+-- !query 1116
+SELECT cast(1 as decimal(20, 0)) <> cast(1 as float) FROM t
+-- !query 1116 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) = CAST(CAST(1 AS FLOAT) AS DOUBLE))):boolean>
+-- !query 1116 output
+false
+
+
+-- !query 1117
+SELECT cast(1 as decimal(3, 0)) <> cast(1 as double) FROM t
+-- !query 1117 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) = CAST(1 AS DOUBLE))):boolean>
+-- !query 1117 output
+false
+
+
+-- !query 1118
+SELECT cast(1 as decimal(5, 0)) <> cast(1 as double) FROM t
+-- !query 1118 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) = CAST(1 AS DOUBLE))):boolean>
+-- !query 1118 output
+false
+
+
+-- !query 1119
+SELECT cast(1 as decimal(10, 0)) <> cast(1 as double) FROM t
+-- !query 1119 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) = CAST(1 AS DOUBLE))):boolean>
+-- !query 1119 output
+false
+
+
+-- !query 1120
+SELECT cast(1 as decimal(20, 0)) <> cast(1 as double) FROM t
+-- !query 1120 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) = CAST(1 AS DOUBLE))):boolean>
+-- !query 1120 output
+false
+
+
+-- !query 1121
+SELECT cast(1 as decimal(3, 0)) <> cast(1 as decimal(10, 0)) FROM t
+-- !query 1121 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)))):boolean>
+-- !query 1121 output
+false
+
+
+-- !query 1122
+SELECT cast(1 as decimal(5, 0)) <> cast(1 as decimal(10, 0)) FROM t
+-- !query 1122 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)))):boolean>
+-- !query 1122 output
+false
+
+
+-- !query 1123
+SELECT cast(1 as decimal(10, 0)) <> cast(1 as decimal(10, 0)) FROM t
+-- !query 1123 schema
+struct<(NOT (CAST(1 AS DECIMAL(10,0)) = CAST(1 AS DECIMAL(10,0)))):boolean>
+-- !query 1123 output
+false
+
+
+-- !query 1124
+SELECT cast(1 as decimal(20, 0)) <> cast(1 as decimal(10, 0)) FROM t
+-- !query 1124 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)))):boolean>
+-- !query 1124 output
+false
+
+
+-- !query 1125
+SELECT cast(1 as decimal(3, 0)) <> cast(1 as string) FROM t
+-- !query 1125 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) = CAST(CAST(1 AS STRING) AS DOUBLE))):boolean>
+-- !query 1125 output
+false
+
+
+-- !query 1126
+SELECT cast(1 as decimal(5, 0)) <> cast(1 as string) FROM t
+-- !query 1126 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) = CAST(CAST(1 AS STRING) AS DOUBLE))):boolean>
+-- !query 1126 output
+false
+
+
+-- !query 1127
+SELECT cast(1 as decimal(10, 0)) <> cast(1 as string) FROM t
+-- !query 1127 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) = CAST(CAST(1 AS STRING) AS DOUBLE))):boolean>
+-- !query 1127 output
+false
+
+
+-- !query 1128
+SELECT cast(1 as decimal(20, 0)) <> cast(1 as string) FROM t
+-- !query 1128 schema
+struct<(NOT (CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) = CAST(CAST(1 AS STRING) AS DOUBLE))):boolean>
+-- !query 1128 output
+false
+
+
+-- !query 1129
+SELECT cast(1 as decimal(3, 0)) <> cast('1' as binary) FROM t
+-- !query 1129 schema
+struct<>
+-- !query 1129 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) = CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) = CAST('1' AS BINARY))' (decimal(3,0) and binary).; line 1 pos 7
+
+
+-- !query 1130
+SELECT cast(1 as decimal(5, 0)) <> cast('1' as binary) FROM t
+-- !query 1130 schema
+struct<>
+-- !query 1130 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) = CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) = CAST('1' AS BINARY))' (decimal(5,0) and binary).; line 1 pos 7
+
+
+-- !query 1131
+SELECT cast(1 as decimal(10, 0)) <> cast('1' as binary) FROM t
+-- !query 1131 schema
+struct<>
+-- !query 1131 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) = CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) = CAST('1' AS BINARY))' (decimal(10,0) and binary).; line 1 pos 7
+
+
+-- !query 1132
+SELECT cast(1 as decimal(20, 0)) <> cast('1' as binary) FROM t
+-- !query 1132 schema
+struct<>
+-- !query 1132 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) = CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) = CAST('1' AS BINARY))' (decimal(20,0) and binary).; line 1 pos 7
+
+
+-- !query 1133
+SELECT cast(1 as decimal(3, 0)) <> cast(1 as boolean) FROM t
+-- !query 1133 schema
+struct<(NOT (CAST(1 AS DECIMAL(3,0)) = CAST(CAST(1 AS BOOLEAN) AS DECIMAL(3,0)))):boolean>
+-- !query 1133 output
+false
+
+
+-- !query 1134
+SELECT cast(1 as decimal(5, 0)) <> cast(1 as boolean) FROM t
+-- !query 1134 schema
+struct<(NOT (CAST(1 AS DECIMAL(5,0)) = CAST(CAST(1 AS BOOLEAN) AS DECIMAL(5,0)))):boolean>
+-- !query 1134 output
+false
+
+
+-- !query 1135
+SELECT cast(1 as decimal(10, 0)) <> cast(1 as boolean) FROM t
+-- !query 1135 schema
+struct<(NOT (CAST(1 AS DECIMAL(10,0)) = CAST(CAST(1 AS BOOLEAN) AS DECIMAL(10,0)))):boolean>
+-- !query 1135 output
+false
+
+
+-- !query 1136
+SELECT cast(1 as decimal(20, 0)) <> cast(1 as boolean) FROM t
+-- !query 1136 schema
+struct<(NOT (CAST(1 AS DECIMAL(20,0)) = CAST(CAST(1 AS BOOLEAN) AS DECIMAL(20,0)))):boolean>
+-- !query 1136 output
+false
+
+
+-- !query 1137
+SELECT cast(1 as decimal(3, 0)) <> cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 1137 schema
+struct<>
+-- !query 1137 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(3,0) and timestamp).; line 1 pos 7
+
+
+-- !query 1138
+SELECT cast(1 as decimal(5, 0)) <> cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 1138 schema
+struct<>
+-- !query 1138 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(5,0) and timestamp).; line 1 pos 7
+
+
+-- !query 1139
+SELECT cast(1 as decimal(10, 0)) <> cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 1139 schema
+struct<>
+-- !query 1139 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(10,0) and timestamp).; line 1 pos 7
+
+
+-- !query 1140
+SELECT cast(1 as decimal(20, 0)) <> cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 1140 schema
+struct<>
+-- !query 1140 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(20,0) and timestamp).; line 1 pos 7
+
+
+-- !query 1141
+SELECT cast(1 as decimal(3, 0)) <> cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 1141 schema
+struct<>
+-- !query 1141 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(3,0)) = CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) = CAST('2017-12-11 09:30:00' AS DATE))' (decimal(3,0) and date).; line 1 pos 7
+
+
+-- !query 1142
+SELECT cast(1 as decimal(5, 0)) <> cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 1142 schema
+struct<>
+-- !query 1142 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(5,0)) = CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) = CAST('2017-12-11 09:30:00' AS DATE))' (decimal(5,0) and date).; line 1 pos 7
+
+
+-- !query 1143
+SELECT cast(1 as decimal(10, 0)) <> cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 1143 schema
+struct<>
+-- !query 1143 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) = CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) = CAST('2017-12-11 09:30:00' AS DATE))' (decimal(10,0) and date).; line 1 pos 7
+
+
+-- !query 1144
+SELECT cast(1 as decimal(20, 0)) <> cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 1144 schema
+struct<>
+-- !query 1144 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) = CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) = CAST('2017-12-11 09:30:00' AS DATE))' (decimal(20,0) and date).; line 1 pos 7
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/division.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/division.sql.out
new file mode 100644
index 0000000000000..017e0fea30e90
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/division.sql.out
@@ -0,0 +1,1242 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 145
+
+
+-- !query 0
+CREATE TEMPORARY VIEW t AS SELECT 1
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+SELECT cast(1 as tinyint) / cast(1 as tinyint) FROM t
+-- !query 1 schema
+struct<(CAST(CAST(1 AS TINYINT) AS DOUBLE) / CAST(CAST(1 AS TINYINT) AS DOUBLE)):double>
+-- !query 1 output
+1.0
+
+
+-- !query 2
+SELECT cast(1 as tinyint) / cast(1 as smallint) FROM t
+-- !query 2 schema
+struct<(CAST(CAST(1 AS TINYINT) AS DOUBLE) / CAST(CAST(1 AS SMALLINT) AS DOUBLE)):double>
+-- !query 2 output
+1.0
+
+
+-- !query 3
+SELECT cast(1 as tinyint) / cast(1 as int) FROM t
+-- !query 3 schema
+struct<(CAST(CAST(1 AS TINYINT) AS DOUBLE) / CAST(CAST(1 AS INT) AS DOUBLE)):double>
+-- !query 3 output
+1.0
+
+
+-- !query 4
+SELECT cast(1 as tinyint) / cast(1 as bigint) FROM t
+-- !query 4 schema
+struct<(CAST(CAST(1 AS TINYINT) AS DOUBLE) / CAST(CAST(1 AS BIGINT) AS DOUBLE)):double>
+-- !query 4 output
+1.0
+
+
+-- !query 5
+SELECT cast(1 as tinyint) / cast(1 as float) FROM t
+-- !query 5 schema
+struct<(CAST(CAST(1 AS TINYINT) AS DOUBLE) / CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 5 output
+1.0
+
+
+-- !query 6
+SELECT cast(1 as tinyint) / cast(1 as double) FROM t
+-- !query 6 schema
+struct<(CAST(CAST(1 AS TINYINT) AS DOUBLE) / CAST(CAST(1 AS DOUBLE) AS DOUBLE)):double>
+-- !query 6 output
+1.0
+
+
+-- !query 7
+SELECT cast(1 as tinyint) / cast(1 as decimal(10, 0)) FROM t
+-- !query 7 schema
+struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0)) / CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(14,11)>
+-- !query 7 output
+1
+
+
+-- !query 8
+SELECT cast(1 as tinyint) / cast(1 as string) FROM t
+-- !query 8 schema
+struct<(CAST(CAST(1 AS TINYINT) AS DOUBLE) / CAST(CAST(CAST(1 AS STRING) AS DOUBLE) AS DOUBLE)):double>
+-- !query 8 output
+1.0
+
+
+-- !query 9
+SELECT cast(1 as tinyint) / cast('1' as binary) FROM t
+-- !query 9 schema
+struct<>
+-- !query 9 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS TINYINT) / CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS TINYINT) / CAST('1' AS BINARY))' (tinyint and binary).; line 1 pos 7
+
+
+-- !query 10
+SELECT cast(1 as tinyint) / cast(1 as boolean) FROM t
+-- !query 10 schema
+struct<>
+-- !query 10 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS TINYINT) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS TINYINT) / CAST(1 AS BOOLEAN))' (tinyint and boolean).; line 1 pos 7
+
+
+-- !query 11
+SELECT cast(1 as tinyint) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 11 schema
+struct<>
+-- !query 11 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS TINYINT) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS TINYINT) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (tinyint and timestamp).; line 1 pos 7
+
+
+-- !query 12
+SELECT cast(1 as tinyint) / cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 12 schema
+struct<>
+-- !query 12 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS TINYINT) / CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS TINYINT) / CAST('2017-12-11 09:30:00' AS DATE))' (tinyint and date).; line 1 pos 7
+
+
+-- !query 13
+SELECT cast(1 as smallint) / cast(1 as tinyint) FROM t
+-- !query 13 schema
+struct<(CAST(CAST(1 AS SMALLINT) AS DOUBLE) / CAST(CAST(1 AS TINYINT) AS DOUBLE)):double>
+-- !query 13 output
+1.0
+
+
+-- !query 14
+SELECT cast(1 as smallint) / cast(1 as smallint) FROM t
+-- !query 14 schema
+struct<(CAST(CAST(1 AS SMALLINT) AS DOUBLE) / CAST(CAST(1 AS SMALLINT) AS DOUBLE)):double>
+-- !query 14 output
+1.0
+
+
+-- !query 15
+SELECT cast(1 as smallint) / cast(1 as int) FROM t
+-- !query 15 schema
+struct<(CAST(CAST(1 AS SMALLINT) AS DOUBLE) / CAST(CAST(1 AS INT) AS DOUBLE)):double>
+-- !query 15 output
+1.0
+
+
+-- !query 16
+SELECT cast(1 as smallint) / cast(1 as bigint) FROM t
+-- !query 16 schema
+struct<(CAST(CAST(1 AS SMALLINT) AS DOUBLE) / CAST(CAST(1 AS BIGINT) AS DOUBLE)):double>
+-- !query 16 output
+1.0
+
+
+-- !query 17
+SELECT cast(1 as smallint) / cast(1 as float) FROM t
+-- !query 17 schema
+struct<(CAST(CAST(1 AS SMALLINT) AS DOUBLE) / CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 17 output
+1.0
+
+
+-- !query 18
+SELECT cast(1 as smallint) / cast(1 as double) FROM t
+-- !query 18 schema
+struct<(CAST(CAST(1 AS SMALLINT) AS DOUBLE) / CAST(CAST(1 AS DOUBLE) AS DOUBLE)):double>
+-- !query 18 output
+1.0
+
+
+-- !query 19
+SELECT cast(1 as smallint) / cast(1 as decimal(10, 0)) FROM t
+-- !query 19 schema
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0)) / CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(16,11)>
+-- !query 19 output
+1
+
+
+-- !query 20
+SELECT cast(1 as smallint) / cast(1 as string) FROM t
+-- !query 20 schema
+struct<(CAST(CAST(1 AS SMALLINT) AS DOUBLE) / CAST(CAST(CAST(1 AS STRING) AS DOUBLE) AS DOUBLE)):double>
+-- !query 20 output
+1.0
+
+
+-- !query 21
+SELECT cast(1 as smallint) / cast('1' as binary) FROM t
+-- !query 21 schema
+struct<>
+-- !query 21 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS SMALLINT) / CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS SMALLINT) / CAST('1' AS BINARY))' (smallint and binary).; line 1 pos 7
+
+
+-- !query 22
+SELECT cast(1 as smallint) / cast(1 as boolean) FROM t
+-- !query 22 schema
+struct<>
+-- !query 22 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS SMALLINT) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS SMALLINT) / CAST(1 AS BOOLEAN))' (smallint and boolean).; line 1 pos 7
+
+
+-- !query 23
+SELECT cast(1 as smallint) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 23 schema
+struct<>
+-- !query 23 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS SMALLINT) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS SMALLINT) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (smallint and timestamp).; line 1 pos 7
+
+
+-- !query 24
+SELECT cast(1 as smallint) / cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 24 schema
+struct<>
+-- !query 24 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS SMALLINT) / CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS SMALLINT) / CAST('2017-12-11 09:30:00' AS DATE))' (smallint and date).; line 1 pos 7
+
+
+-- !query 25
+SELECT cast(1 as int) / cast(1 as tinyint) FROM t
+-- !query 25 schema
+struct<(CAST(CAST(1 AS INT) AS DOUBLE) / CAST(CAST(1 AS TINYINT) AS DOUBLE)):double>
+-- !query 25 output
+1.0
+
+
+-- !query 26
+SELECT cast(1 as int) / cast(1 as smallint) FROM t
+-- !query 26 schema
+struct<(CAST(CAST(1 AS INT) AS DOUBLE) / CAST(CAST(1 AS SMALLINT) AS DOUBLE)):double>
+-- !query 26 output
+1.0
+
+
+-- !query 27
+SELECT cast(1 as int) / cast(1 as int) FROM t
+-- !query 27 schema
+struct<(CAST(CAST(1 AS INT) AS DOUBLE) / CAST(CAST(1 AS INT) AS DOUBLE)):double>
+-- !query 27 output
+1.0
+
+
+-- !query 28
+SELECT cast(1 as int) / cast(1 as bigint) FROM t
+-- !query 28 schema
+struct<(CAST(CAST(1 AS INT) AS DOUBLE) / CAST(CAST(1 AS BIGINT) AS DOUBLE)):double>
+-- !query 28 output
+1.0
+
+
+-- !query 29
+SELECT cast(1 as int) / cast(1 as float) FROM t
+-- !query 29 schema
+struct<(CAST(CAST(1 AS INT) AS DOUBLE) / CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 29 output
+1.0
+
+
+-- !query 30
+SELECT cast(1 as int) / cast(1 as double) FROM t
+-- !query 30 schema
+struct<(CAST(CAST(1 AS INT) AS DOUBLE) / CAST(CAST(1 AS DOUBLE) AS DOUBLE)):double>
+-- !query 30 output
+1.0
+
+
+-- !query 31
+SELECT cast(1 as int) / cast(1 as decimal(10, 0)) FROM t
+-- !query 31 schema
+struct<(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) / CAST(1 AS DECIMAL(10,0))):decimal(21,11)>
+-- !query 31 output
+1
+
+
+-- !query 32
+SELECT cast(1 as int) / cast(1 as string) FROM t
+-- !query 32 schema
+struct<(CAST(CAST(1 AS INT) AS DOUBLE) / CAST(CAST(CAST(1 AS STRING) AS DOUBLE) AS DOUBLE)):double>
+-- !query 32 output
+1.0
+
+
+-- !query 33
+SELECT cast(1 as int) / cast('1' as binary) FROM t
+-- !query 33 schema
+struct<>
+-- !query 33 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS INT) / CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS INT) / CAST('1' AS BINARY))' (int and binary).; line 1 pos 7
+
+
+-- !query 34
+SELECT cast(1 as int) / cast(1 as boolean) FROM t
+-- !query 34 schema
+struct<>
+-- !query 34 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS INT) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS INT) / CAST(1 AS BOOLEAN))' (int and boolean).; line 1 pos 7
+
+
+-- !query 35
+SELECT cast(1 as int) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 35 schema
+struct<>
+-- !query 35 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS INT) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS INT) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (int and timestamp).; line 1 pos 7
+
+
+-- !query 36
+SELECT cast(1 as int) / cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 36 schema
+struct<>
+-- !query 36 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS INT) / CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS INT) / CAST('2017-12-11 09:30:00' AS DATE))' (int and date).; line 1 pos 7
+
+
+-- !query 37
+SELECT cast(1 as bigint) / cast(1 as tinyint) FROM t
+-- !query 37 schema
+struct<(CAST(CAST(1 AS BIGINT) AS DOUBLE) / CAST(CAST(1 AS TINYINT) AS DOUBLE)):double>
+-- !query 37 output
+1.0
+
+
+-- !query 38
+SELECT cast(1 as bigint) / cast(1 as smallint) FROM t
+-- !query 38 schema
+struct<(CAST(CAST(1 AS BIGINT) AS DOUBLE) / CAST(CAST(1 AS SMALLINT) AS DOUBLE)):double>
+-- !query 38 output
+1.0
+
+
+-- !query 39
+SELECT cast(1 as bigint) / cast(1 as int) FROM t
+-- !query 39 schema
+struct<(CAST(CAST(1 AS BIGINT) AS DOUBLE) / CAST(CAST(1 AS INT) AS DOUBLE)):double>
+-- !query 39 output
+1.0
+
+
+-- !query 40
+SELECT cast(1 as bigint) / cast(1 as bigint) FROM t
+-- !query 40 schema
+struct<(CAST(CAST(1 AS BIGINT) AS DOUBLE) / CAST(CAST(1 AS BIGINT) AS DOUBLE)):double>
+-- !query 40 output
+1.0
+
+
+-- !query 41
+SELECT cast(1 as bigint) / cast(1 as float) FROM t
+-- !query 41 schema
+struct<(CAST(CAST(1 AS BIGINT) AS DOUBLE) / CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 41 output
+1.0
+
+
+-- !query 42
+SELECT cast(1 as bigint) / cast(1 as double) FROM t
+-- !query 42 schema
+struct<(CAST(CAST(1 AS BIGINT) AS DOUBLE) / CAST(CAST(1 AS DOUBLE) AS DOUBLE)):double>
+-- !query 42 output
+1.0
+
+
+-- !query 43
+SELECT cast(1 as bigint) / cast(1 as decimal(10, 0)) FROM t
+-- !query 43 schema
+struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) / CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):decimal(31,11)>
+-- !query 43 output
+1
+
+
+-- !query 44
+SELECT cast(1 as bigint) / cast(1 as string) FROM t
+-- !query 44 schema
+struct<(CAST(CAST(1 AS BIGINT) AS DOUBLE) / CAST(CAST(CAST(1 AS STRING) AS DOUBLE) AS DOUBLE)):double>
+-- !query 44 output
+1.0
+
+
+-- !query 45
+SELECT cast(1 as bigint) / cast('1' as binary) FROM t
+-- !query 45 schema
+struct<>
+-- !query 45 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS BIGINT) / CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS BIGINT) / CAST('1' AS BINARY))' (bigint and binary).; line 1 pos 7
+
+
+-- !query 46
+SELECT cast(1 as bigint) / cast(1 as boolean) FROM t
+-- !query 46 schema
+struct<>
+-- !query 46 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS BIGINT) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS BIGINT) / CAST(1 AS BOOLEAN))' (bigint and boolean).; line 1 pos 7
+
+
+-- !query 47
+SELECT cast(1 as bigint) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 47 schema
+struct<>
+-- !query 47 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS BIGINT) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS BIGINT) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (bigint and timestamp).; line 1 pos 7
+
+
+-- !query 48
+SELECT cast(1 as bigint) / cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 48 schema
+struct<>
+-- !query 48 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS BIGINT) / CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS BIGINT) / CAST('2017-12-11 09:30:00' AS DATE))' (bigint and date).; line 1 pos 7
+
+
+-- !query 49
+SELECT cast(1 as float) / cast(1 as tinyint) FROM t
+-- !query 49 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) / CAST(CAST(1 AS TINYINT) AS DOUBLE)):double>
+-- !query 49 output
+1.0
+
+
+-- !query 50
+SELECT cast(1 as float) / cast(1 as smallint) FROM t
+-- !query 50 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) / CAST(CAST(1 AS SMALLINT) AS DOUBLE)):double>
+-- !query 50 output
+1.0
+
+
+-- !query 51
+SELECT cast(1 as float) / cast(1 as int) FROM t
+-- !query 51 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) / CAST(CAST(1 AS INT) AS DOUBLE)):double>
+-- !query 51 output
+1.0
+
+
+-- !query 52
+SELECT cast(1 as float) / cast(1 as bigint) FROM t
+-- !query 52 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) / CAST(CAST(1 AS BIGINT) AS DOUBLE)):double>
+-- !query 52 output
+1.0
+
+
+-- !query 53
+SELECT cast(1 as float) / cast(1 as float) FROM t
+-- !query 53 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) / CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 53 output
+1.0
+
+
+-- !query 54
+SELECT cast(1 as float) / cast(1 as double) FROM t
+-- !query 54 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) / CAST(CAST(1 AS DOUBLE) AS DOUBLE)):double>
+-- !query 54 output
+1.0
+
+
+-- !query 55
+SELECT cast(1 as float) / cast(1 as decimal(10, 0)) FROM t
+-- !query 55 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) / CAST(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) AS DOUBLE)):double>
+-- !query 55 output
+1.0
+
+
+-- !query 56
+SELECT cast(1 as float) / cast(1 as string) FROM t
+-- !query 56 schema
+struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) / CAST(CAST(CAST(1 AS STRING) AS DOUBLE) AS DOUBLE)):double>
+-- !query 56 output
+1.0
+
+
+-- !query 57
+SELECT cast(1 as float) / cast('1' as binary) FROM t
+-- !query 57 schema
+struct<>
+-- !query 57 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS FLOAT) / CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS FLOAT) / CAST('1' AS BINARY))' (float and binary).; line 1 pos 7
+
+
+-- !query 58
+SELECT cast(1 as float) / cast(1 as boolean) FROM t
+-- !query 58 schema
+struct<>
+-- !query 58 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS FLOAT) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS FLOAT) / CAST(1 AS BOOLEAN))' (float and boolean).; line 1 pos 7
+
+
+-- !query 59
+SELECT cast(1 as float) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 59 schema
+struct<>
+-- !query 59 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS FLOAT) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS FLOAT) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (float and timestamp).; line 1 pos 7
+
+
+-- !query 60
+SELECT cast(1 as float) / cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 60 schema
+struct<>
+-- !query 60 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS FLOAT) / CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS FLOAT) / CAST('2017-12-11 09:30:00' AS DATE))' (float and date).; line 1 pos 7
+
+
+-- !query 61
+SELECT cast(1 as double) / cast(1 as tinyint) FROM t
+-- !query 61 schema
+struct<(CAST(1 AS DOUBLE) / CAST(CAST(1 AS TINYINT) AS DOUBLE)):double>
+-- !query 61 output
+1.0
+
+
+-- !query 62
+SELECT cast(1 as double) / cast(1 as smallint) FROM t
+-- !query 62 schema
+struct<(CAST(1 AS DOUBLE) / CAST(CAST(1 AS SMALLINT) AS DOUBLE)):double>
+-- !query 62 output
+1.0
+
+
+-- !query 63
+SELECT cast(1 as double) / cast(1 as int) FROM t
+-- !query 63 schema
+struct<(CAST(1 AS DOUBLE) / CAST(CAST(1 AS INT) AS DOUBLE)):double>
+-- !query 63 output
+1.0
+
+
+-- !query 64
+SELECT cast(1 as double) / cast(1 as bigint) FROM t
+-- !query 64 schema
+struct<(CAST(1 AS DOUBLE) / CAST(CAST(1 AS BIGINT) AS DOUBLE)):double>
+-- !query 64 output
+1.0
+
+
+-- !query 65
+SELECT cast(1 as double) / cast(1 as float) FROM t
+-- !query 65 schema
+struct<(CAST(1 AS DOUBLE) / CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 65 output
+1.0
+
+
+-- !query 66
+SELECT cast(1 as double) / cast(1 as double) FROM t
+-- !query 66 schema
+struct<(CAST(1 AS DOUBLE) / CAST(1 AS DOUBLE)):double>
+-- !query 66 output
+1.0
+
+
+-- !query 67
+SELECT cast(1 as double) / cast(1 as decimal(10, 0)) FROM t
+-- !query 67 schema
+struct<(CAST(1 AS DOUBLE) / CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):double>
+-- !query 67 output
+1.0
+
+
+-- !query 68
+SELECT cast(1 as double) / cast(1 as string) FROM t
+-- !query 68 schema
+struct<(CAST(1 AS DOUBLE) / CAST(CAST(1 AS STRING) AS DOUBLE)):double>
+-- !query 68 output
+1.0
+
+
+-- !query 69
+SELECT cast(1 as double) / cast('1' as binary) FROM t
+-- !query 69 schema
+struct<>
+-- !query 69 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DOUBLE) / CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DOUBLE) / CAST('1' AS BINARY))' (double and binary).; line 1 pos 7
+
+
+-- !query 70
+SELECT cast(1 as double) / cast(1 as boolean) FROM t
+-- !query 70 schema
+struct<>
+-- !query 70 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DOUBLE) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DOUBLE) / CAST(1 AS BOOLEAN))' (double and boolean).; line 1 pos 7
+
+
+-- !query 71
+SELECT cast(1 as double) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 71 schema
+struct<>
+-- !query 71 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DOUBLE) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DOUBLE) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (double and timestamp).; line 1 pos 7
+
+
+-- !query 72
+SELECT cast(1 as double) / cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 72 schema
+struct<>
+-- !query 72 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DOUBLE) / CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DOUBLE) / CAST('2017-12-11 09:30:00' AS DATE))' (double and date).; line 1 pos 7
+
+
+-- !query 73
+SELECT cast(1 as decimal(10, 0)) / cast(1 as tinyint) FROM t
+-- !query 73 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) / CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0))):decimal(16,6)>
+-- !query 73 output
+1
+
+
+-- !query 74
+SELECT cast(1 as decimal(10, 0)) / cast(1 as smallint) FROM t
+-- !query 74 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) / CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0))):decimal(16,6)>
+-- !query 74 output
+1
+
+
+-- !query 75
+SELECT cast(1 as decimal(10, 0)) / cast(1 as int) FROM t
+-- !query 75 schema
+struct<(CAST(1 AS DECIMAL(10,0)) / CAST(CAST(1 AS INT) AS DECIMAL(10,0))):decimal(21,11)>
+-- !query 75 output
+1
+
+
+-- !query 76
+SELECT cast(1 as decimal(10, 0)) / cast(1 as bigint) FROM t
+-- !query 76 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) / CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(31,21)>
+-- !query 76 output
+1
+
+
+-- !query 77
+SELECT cast(1 as decimal(10, 0)) / cast(1 as float) FROM t
+-- !query 77 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) / CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 77 output
+1.0
+
+
+-- !query 78
+SELECT cast(1 as decimal(10, 0)) / cast(1 as double) FROM t
+-- !query 78 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) / CAST(1 AS DOUBLE)):double>
+-- !query 78 output
+1.0
+
+
+-- !query 79
+SELECT cast(1 as decimal(10, 0)) / cast(1 as decimal(10, 0)) FROM t
+-- !query 79 schema
+struct<(CAST(1 AS DECIMAL(10,0)) / CAST(1 AS DECIMAL(10,0))):decimal(21,11)>
+-- !query 79 output
+1
+
+
+-- !query 80
+SELECT cast(1 as decimal(10, 0)) / cast(1 as string) FROM t
+-- !query 80 schema
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) / CAST(CAST(1 AS STRING) AS DOUBLE)):double>
+-- !query 80 output
+1.0
+
+
+-- !query 81
+SELECT cast(1 as decimal(10, 0)) / cast('1' as binary) FROM t
+-- !query 81 schema
+struct<>
+-- !query 81 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) / CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) / CAST('1' AS BINARY))' (decimal(10,0) and binary).; line 1 pos 7
+
+
+-- !query 82
+SELECT cast(1 as decimal(10, 0)) / cast(1 as boolean) FROM t
+-- !query 82 schema
+struct<>
+-- !query 82 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) / CAST(1 AS BOOLEAN))' (decimal(10,0) and boolean).; line 1 pos 7
+
+
+-- !query 83
+SELECT cast(1 as decimal(10, 0)) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 83 schema
+struct<>
+-- !query 83 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(10,0) and timestamp).; line 1 pos 7
+
+
+-- !query 84
+SELECT cast(1 as decimal(10, 0)) / cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 84 schema
+struct<>
+-- !query 84 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) / CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) / CAST('2017-12-11 09:30:00' AS DATE))' (decimal(10,0) and date).; line 1 pos 7
+
+
+-- !query 85
+SELECT cast(1 as string) / cast(1 as tinyint) FROM t
+-- !query 85 schema
+struct<(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST(CAST(1 AS TINYINT) AS DOUBLE)):double>
+-- !query 85 output
+1.0
+
+
+-- !query 86
+SELECT cast(1 as string) / cast(1 as smallint) FROM t
+-- !query 86 schema
+struct<(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST(CAST(1 AS SMALLINT) AS DOUBLE)):double>
+-- !query 86 output
+1.0
+
+
+-- !query 87
+SELECT cast(1 as string) / cast(1 as int) FROM t
+-- !query 87 schema
+struct<(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST(CAST(1 AS INT) AS DOUBLE)):double>
+-- !query 87 output
+1.0
+
+
+-- !query 88
+SELECT cast(1 as string) / cast(1 as bigint) FROM t
+-- !query 88 schema
+struct<(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST(CAST(1 AS BIGINT) AS DOUBLE)):double>
+-- !query 88 output
+1.0
+
+
+-- !query 89
+SELECT cast(1 as string) / cast(1 as float) FROM t
+-- !query 89 schema
+struct<(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST(CAST(1 AS FLOAT) AS DOUBLE)):double>
+-- !query 89 output
+1.0
+
+
+-- !query 90
+SELECT cast(1 as string) / cast(1 as double) FROM t
+-- !query 90 schema
+struct<(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST(1 AS DOUBLE)):double>
+-- !query 90 output
+1.0
+
+
+-- !query 91
+SELECT cast(1 as string) / cast(1 as decimal(10, 0)) FROM t
+-- !query 91 schema
+struct<(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):double>
+-- !query 91 output
+1.0
+
+
+-- !query 92
+SELECT cast(1 as string) / cast(1 as string) FROM t
+-- !query 92 schema
+struct<(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST(CAST(1 AS STRING) AS DOUBLE)):double>
+-- !query 92 output
+1.0
+
+
+-- !query 93
+SELECT cast(1 as string) / cast('1' as binary) FROM t
+-- !query 93 schema
+struct<>
+-- !query 93 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST('1' AS BINARY))' (double and binary).; line 1 pos 7
+
+
+-- !query 94
+SELECT cast(1 as string) / cast(1 as boolean) FROM t
+-- !query 94 schema
+struct<>
+-- !query 94 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST(1 AS BOOLEAN))' (double and boolean).; line 1 pos 7
+
+
+-- !query 95
+SELECT cast(1 as string) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 95 schema
+struct<>
+-- !query 95 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (double and timestamp).; line 1 pos 7
+
+
+-- !query 96
+SELECT cast(1 as string) / cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 96 schema
+struct<>
+-- !query 96 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST('2017-12-11 09:30:00' AS DATE))' (double and date).; line 1 pos 7
+
+
+-- !query 97
+SELECT cast('1' as binary) / cast(1 as tinyint) FROM t
+-- !query 97 schema
+struct<>
+-- !query 97 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) / CAST(1 AS TINYINT))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST(1 AS TINYINT))' (binary and tinyint).; line 1 pos 7
+
+
+-- !query 98
+SELECT cast('1' as binary) / cast(1 as smallint) FROM t
+-- !query 98 schema
+struct<>
+-- !query 98 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) / CAST(1 AS SMALLINT))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST(1 AS SMALLINT))' (binary and smallint).; line 1 pos 7
+
+
+-- !query 99
+SELECT cast('1' as binary) / cast(1 as int) FROM t
+-- !query 99 schema
+struct<>
+-- !query 99 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) / CAST(1 AS INT))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST(1 AS INT))' (binary and int).; line 1 pos 7
+
+
+-- !query 100
+SELECT cast('1' as binary) / cast(1 as bigint) FROM t
+-- !query 100 schema
+struct<>
+-- !query 100 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) / CAST(1 AS BIGINT))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST(1 AS BIGINT))' (binary and bigint).; line 1 pos 7
+
+
+-- !query 101
+SELECT cast('1' as binary) / cast(1 as float) FROM t
+-- !query 101 schema
+struct<>
+-- !query 101 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) / CAST(1 AS FLOAT))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST(1 AS FLOAT))' (binary and float).; line 1 pos 7
+
+
+-- !query 102
+SELECT cast('1' as binary) / cast(1 as double) FROM t
+-- !query 102 schema
+struct<>
+-- !query 102 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) / CAST(1 AS DOUBLE))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST(1 AS DOUBLE))' (binary and double).; line 1 pos 7
+
+
+-- !query 103
+SELECT cast('1' as binary) / cast(1 as decimal(10, 0)) FROM t
+-- !query 103 schema
+struct<>
+-- !query 103 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) / CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST(1 AS DECIMAL(10,0)))' (binary and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 104
+SELECT cast('1' as binary) / cast(1 as string) FROM t
+-- !query 104 schema
+struct<>
+-- !query 104 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) / CAST(CAST(1 AS STRING) AS DOUBLE))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST(CAST(1 AS STRING) AS DOUBLE))' (binary and double).; line 1 pos 7
+
+
+-- !query 105
+SELECT cast('1' as binary) / cast('1' as binary) FROM t
+-- !query 105 schema
+struct<>
+-- !query 105 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) / CAST('1' AS BINARY))' due to data type mismatch: '(CAST('1' AS BINARY) / CAST('1' AS BINARY))' requires (double or decimal) type, not binary; line 1 pos 7
+
+
+-- !query 106
+SELECT cast('1' as binary) / cast(1 as boolean) FROM t
+-- !query 106 schema
+struct<>
+-- !query 106 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST(1 AS BOOLEAN))' (binary and boolean).; line 1 pos 7
+
+
+-- !query 107
+SELECT cast('1' as binary) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 107 schema
+struct<>
+-- !query 107 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (binary and timestamp).; line 1 pos 7
+
+
+-- !query 108
+SELECT cast('1' as binary) / cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 108 schema
+struct<>
+-- !query 108 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('1' AS BINARY) / CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST('2017-12-11 09:30:00' AS DATE))' (binary and date).; line 1 pos 7
+
+
+-- !query 109
+SELECT cast(1 as boolean) / cast(1 as tinyint) FROM t
+-- !query 109 schema
+struct<>
+-- !query 109 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS BOOLEAN) / CAST(1 AS TINYINT))' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) / CAST(1 AS TINYINT))' (boolean and tinyint).; line 1 pos 7
+
+
+-- !query 110
+SELECT cast(1 as boolean) / cast(1 as smallint) FROM t
+-- !query 110 schema
+struct<>
+-- !query 110 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS BOOLEAN) / CAST(1 AS SMALLINT))' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) / CAST(1 AS SMALLINT))' (boolean and smallint).; line 1 pos 7
+
+
+-- !query 111
+SELECT cast(1 as boolean) / cast(1 as int) FROM t
+-- !query 111 schema
+struct<>
+-- !query 111 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS BOOLEAN) / CAST(1 AS INT))' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) / CAST(1 AS INT))' (boolean and int).; line 1 pos 7
+
+
+-- !query 112
+SELECT cast(1 as boolean) / cast(1 as bigint) FROM t
+-- !query 112 schema
+struct<>
+-- !query 112 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS BOOLEAN) / CAST(1 AS BIGINT))' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) / CAST(1 AS BIGINT))' (boolean and bigint).; line 1 pos 7
+
+
+-- !query 113
+SELECT cast(1 as boolean) / cast(1 as float) FROM t
+-- !query 113 schema
+struct<>
+-- !query 113 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS BOOLEAN) / CAST(1 AS FLOAT))' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) / CAST(1 AS FLOAT))' (boolean and float).; line 1 pos 7
+
+
+-- !query 114
+SELECT cast(1 as boolean) / cast(1 as double) FROM t
+-- !query 114 schema
+struct<>
+-- !query 114 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS BOOLEAN) / CAST(1 AS DOUBLE))' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) / CAST(1 AS DOUBLE))' (boolean and double).; line 1 pos 7
+
+
+-- !query 115
+SELECT cast(1 as boolean) / cast(1 as decimal(10, 0)) FROM t
+-- !query 115 schema
+struct<>
+-- !query 115 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS BOOLEAN) / CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) / CAST(1 AS DECIMAL(10,0)))' (boolean and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 116
+SELECT cast(1 as boolean) / cast(1 as string) FROM t
+-- !query 116 schema
+struct<>
+-- !query 116 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS BOOLEAN) / CAST(CAST(1 AS STRING) AS DOUBLE))' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) / CAST(CAST(1 AS STRING) AS DOUBLE))' (boolean and double).; line 1 pos 7
+
+
+-- !query 117
+SELECT cast(1 as boolean) / cast('1' as binary) FROM t
+-- !query 117 schema
+struct<>
+-- !query 117 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS BOOLEAN) / CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) / CAST('1' AS BINARY))' (boolean and binary).; line 1 pos 7
+
+
+-- !query 118
+SELECT cast(1 as boolean) / cast(1 as boolean) FROM t
+-- !query 118 schema
+struct<>
+-- !query 118 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS BOOLEAN) / CAST(1 AS BOOLEAN))' due to data type mismatch: '(CAST(1 AS BOOLEAN) / CAST(1 AS BOOLEAN))' requires (double or decimal) type, not boolean; line 1 pos 7
+
+
+-- !query 119
+SELECT cast(1 as boolean) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 119 schema
+struct<>
+-- !query 119 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS BOOLEAN) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (boolean and timestamp).; line 1 pos 7
+
+
+-- !query 120
+SELECT cast(1 as boolean) / cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 120 schema
+struct<>
+-- !query 120 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST(1 AS BOOLEAN) / CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) / CAST('2017-12-11 09:30:00' AS DATE))' (boolean and date).; line 1 pos 7
+
+
+-- !query 121
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as tinyint) FROM t
+-- !query 121 schema
+struct<>
+-- !query 121 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS TINYINT))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS TINYINT))' (timestamp and tinyint).; line 1 pos 7
+
+
+-- !query 122
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as smallint) FROM t
+-- !query 122 schema
+struct<>
+-- !query 122 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS SMALLINT))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS SMALLINT))' (timestamp and smallint).; line 1 pos 7
+
+
+-- !query 123
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as int) FROM t
+-- !query 123 schema
+struct<>
+-- !query 123 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS INT))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS INT))' (timestamp and int).; line 1 pos 7
+
+
+-- !query 124
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as bigint) FROM t
+-- !query 124 schema
+struct<>
+-- !query 124 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS BIGINT))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS BIGINT))' (timestamp and bigint).; line 1 pos 7
+
+
+-- !query 125
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as float) FROM t
+-- !query 125 schema
+struct<>
+-- !query 125 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS FLOAT))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS FLOAT))' (timestamp and float).; line 1 pos 7
+
+
+-- !query 126
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as double) FROM t
+-- !query 126 schema
+struct<>
+-- !query 126 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS DOUBLE))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS DOUBLE))' (timestamp and double).; line 1 pos 7
+
+
+-- !query 127
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as decimal(10, 0)) FROM t
+-- !query 127 schema
+struct<>
+-- !query 127 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS DECIMAL(10,0)))' (timestamp and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 128
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as string) FROM t
+-- !query 128 schema
+struct<>
+-- !query 128 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(CAST(1 AS STRING) AS DOUBLE))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(CAST(1 AS STRING) AS DOUBLE))' (timestamp and double).; line 1 pos 7
+
+
+-- !query 129
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast('1' as binary) FROM t
+-- !query 129 schema
+struct<>
+-- !query 129 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST('1' AS BINARY))' (timestamp and binary).; line 1 pos 7
+
+
+-- !query 130
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as boolean) FROM t
+-- !query 130 schema
+struct<>
+-- !query 130 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS BOOLEAN))' (timestamp and boolean).; line 1 pos 7
+
+
+-- !query 131
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 131 schema
+struct<>
+-- !query 131 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' requires (double or decimal) type, not timestamp; line 1 pos 7
+
+
+-- !query 132
+SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 132 schema
+struct<>
+-- !query 132 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST('2017-12-11 09:30:00' AS DATE))' (timestamp and date).; line 1 pos 7
+
+
+-- !query 133
+SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as tinyint) FROM t
+-- !query 133 schema
+struct<>
+-- !query 133 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS TINYINT))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS TINYINT))' (date and tinyint).; line 1 pos 7
+
+
+-- !query 134
+SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as smallint) FROM t
+-- !query 134 schema
+struct<>
+-- !query 134 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS SMALLINT))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS SMALLINT))' (date and smallint).; line 1 pos 7
+
+
+-- !query 135
+SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as int) FROM t
+-- !query 135 schema
+struct<>
+-- !query 135 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS INT))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS INT))' (date and int).; line 1 pos 7
+
+
+-- !query 136
+SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as bigint) FROM t
+-- !query 136 schema
+struct<>
+-- !query 136 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS BIGINT))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS BIGINT))' (date and bigint).; line 1 pos 7
+
+
+-- !query 137
+SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as float) FROM t
+-- !query 137 schema
+struct<>
+-- !query 137 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS FLOAT))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS FLOAT))' (date and float).; line 1 pos 7
+
+
+-- !query 138
+SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as double) FROM t
+-- !query 138 schema
+struct<>
+-- !query 138 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS DOUBLE))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS DOUBLE))' (date and double).; line 1 pos 7
+
+
+-- !query 139
+SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as decimal(10, 0)) FROM t
+-- !query 139 schema
+struct<>
+-- !query 139 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS DECIMAL(10,0)))' (date and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 140
+SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as string) FROM t
+-- !query 140 schema
+struct<>
+-- !query 140 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(CAST(1 AS STRING) AS DOUBLE))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(CAST(1 AS STRING) AS DOUBLE))' (date and double).; line 1 pos 7
+
+
+-- !query 141
+SELECT cast('2017-12-11 09:30:00' as date) / cast('1' as binary) FROM t
+-- !query 141 schema
+struct<>
+-- !query 141 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) / CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) / CAST('1' AS BINARY))' (date and binary).; line 1 pos 7
+
+
+-- !query 142
+SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as boolean) FROM t
+-- !query 142 schema
+struct<>
+-- !query 142 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS BOOLEAN))' (date and boolean).; line 1 pos 7
+
+
+-- !query 143
+SELECT cast('2017-12-11 09:30:00' as date) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t
+-- !query 143 schema
+struct<>
+-- !query 143 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (date and timestamp).; line 1 pos 7
+
+
+-- !query 144
+SELECT cast('2017-12-11 09:30:00' as date) / cast('2017-12-11 09:30:00' as date) FROM t
+-- !query 144 schema
+struct<>
+-- !query 144 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) / CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: '(CAST('2017-12-11 09:30:00' AS DATE) / CAST('2017-12-11 09:30:00' AS DATE))' requires (double or decimal) type, not date; line 1 pos 7
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out
new file mode 100644
index 0000000000000..b62e1b6826045
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out
@@ -0,0 +1,115 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 6
+
+
+-- !query 0
+SELECT elt(2, col1, col2, col3, col4, col5) col
+FROM (
+ SELECT
+ 'prefix_' col1,
+ id col2,
+ string(id + 1) col3,
+ encode(string(id + 2), 'utf-8') col4,
+ CAST(id AS DOUBLE) col5
+ FROM range(10)
+)
+-- !query 0 schema
+struct<col:string>
+-- !query 0 output
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+
+
+-- !query 1
+SELECT elt(3, col1, col2, col3, col4) col
+FROM (
+ SELECT
+ string(id) col1,
+ string(id + 1) col2,
+ encode(string(id + 2), 'utf-8') col3,
+ encode(string(id + 3), 'utf-8') col4
+ FROM range(10)
+)
+-- !query 1 schema
+struct<col:string>
+-- !query 1 output
+10
+11
+2
+3
+4
+5
+6
+7
+8
+9
+
+
+-- !query 2
+set spark.sql.function.eltOutputAsString=true
+-- !query 2 schema
+struct<key:string,value:string>
+-- !query 2 output
+spark.sql.function.eltOutputAsString true
+
+
+-- !query 3
+SELECT elt(1, col1, col2) col
+FROM (
+ SELECT
+ encode(string(id), 'utf-8') col1,
+ encode(string(id + 1), 'utf-8') col2
+ FROM range(10)
+)
+-- !query 3 schema
+struct<col:string>
+-- !query 3 output
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+
+
+-- !query 4
+set spark.sql.function.eltOutputAsString=false
+-- !query 4 schema
+struct<key:string,value:string>
+-- !query 4 output
+spark.sql.function.eltOutputAsString false
+
+
+-- !query 5
+SELECT elt(2, col1, col2) col
+FROM (
+ SELECT
+ encode(string(id), 'utf-8') col1,
+ encode(string(id + 1), 'utf-8') col2
+ FROM range(10)
+)
+-- !query 5 schema
+struct<col:binary>
+-- !query 5 output
+1
+10
+2
+3
+4
+5
+6
+7
+8
+9
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/ifCoercion.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/ifCoercion.sql.out
new file mode 100644
index 0000000000000..7097027872707
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/ifCoercion.sql.out
@@ -0,0 +1,1232 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 145
+
+
+-- !query 0
+CREATE TEMPORARY VIEW t AS SELECT 1
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+SELECT IF(true, cast(1 as tinyint), cast(2 as tinyint)) FROM t
+-- !query 1 schema
+struct<(IF(true, CAST(1 AS TINYINT), CAST(2 AS TINYINT))):tinyint>
+-- !query 1 output
+1
+
+
+-- !query 2
+SELECT IF(true, cast(1 as tinyint), cast(2 as smallint)) FROM t
+-- !query 2 schema
+struct<(IF(true, CAST(CAST(1 AS TINYINT) AS SMALLINT), CAST(2 AS SMALLINT))):smallint>
+-- !query 2 output
+1
+
+
+-- !query 3
+SELECT IF(true, cast(1 as tinyint), cast(2 as int)) FROM t
+-- !query 3 schema
+struct<(IF(true, CAST(CAST(1 AS TINYINT) AS INT), CAST(2 AS INT))):int>
+-- !query 3 output
+1
+
+
+-- !query 4
+SELECT IF(true, cast(1 as tinyint), cast(2 as bigint)) FROM t
+-- !query 4 schema
+struct<(IF(true, CAST(CAST(1 AS TINYINT) AS BIGINT), CAST(2 AS BIGINT))):bigint>
+-- !query 4 output
+1
+
+
+-- !query 5
+SELECT IF(true, cast(1 as tinyint), cast(2 as float)) FROM t
+-- !query 5 schema
+struct<(IF(true, CAST(CAST(1 AS TINYINT) AS FLOAT), CAST(2 AS FLOAT))):float>
+-- !query 5 output
+1.0
+
+
+-- !query 6
+SELECT IF(true, cast(1 as tinyint), cast(2 as double)) FROM t
+-- !query 6 schema
+struct<(IF(true, CAST(CAST(1 AS TINYINT) AS DOUBLE), CAST(2 AS DOUBLE))):double>
+-- !query 6 output
+1.0
+
+
+-- !query 7
+SELECT IF(true, cast(1 as tinyint), cast(2 as decimal(10, 0))) FROM t
+-- !query 7 schema
+struct<(IF(true, CAST(CAST(1 AS TINYINT) AS DECIMAL(10,0)), CAST(2 AS DECIMAL(10,0)))):decimal(10,0)>
+-- !query 7 output
+1
+
+
+-- !query 8
+SELECT IF(true, cast(1 as tinyint), cast(2 as string)) FROM t
+-- !query 8 schema
+struct<(IF(true, CAST(CAST(1 AS TINYINT) AS STRING), CAST(2 AS STRING))):string>
+-- !query 8 output
+1
+
+
+-- !query 9
+SELECT IF(true, cast(1 as tinyint), cast('2' as binary)) FROM t
+-- !query 9 schema
+struct<>
+-- !query 9 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS TINYINT), CAST('2' AS BINARY)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS TINYINT), CAST('2' AS BINARY)))' (tinyint and binary).; line 1 pos 7
+
+
+-- !query 10
+SELECT IF(true, cast(1 as tinyint), cast(2 as boolean)) FROM t
+-- !query 10 schema
+struct<>
+-- !query 10 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS TINYINT), CAST(2 AS BOOLEAN)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS TINYINT), CAST(2 AS BOOLEAN)))' (tinyint and boolean).; line 1 pos 7
+
+
+-- !query 11
+SELECT IF(true, cast(1 as tinyint), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t
+-- !query 11 schema
+struct<>
+-- !query 11 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS TINYINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS TINYINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' (tinyint and timestamp).; line 1 pos 7
+
+
+-- !query 12
+SELECT IF(true, cast(1 as tinyint), cast('2017-12-11 09:30:00' as date)) FROM t
+-- !query 12 schema
+struct<>
+-- !query 12 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS TINYINT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS TINYINT), CAST('2017-12-11 09:30:00' AS DATE)))' (tinyint and date).; line 1 pos 7
+
+
+-- !query 13
+SELECT IF(true, cast(1 as smallint), cast(2 as tinyint)) FROM t
+-- !query 13 schema
+struct<(IF(true, CAST(1 AS SMALLINT), CAST(CAST(2 AS TINYINT) AS SMALLINT))):smallint>
+-- !query 13 output
+1
+
+
+-- !query 14
+SELECT IF(true, cast(1 as smallint), cast(2 as smallint)) FROM t
+-- !query 14 schema
+struct<(IF(true, CAST(1 AS SMALLINT), CAST(2 AS SMALLINT))):smallint>
+-- !query 14 output
+1
+
+
+-- !query 15
+SELECT IF(true, cast(1 as smallint), cast(2 as int)) FROM t
+-- !query 15 schema
+struct<(IF(true, CAST(CAST(1 AS SMALLINT) AS INT), CAST(2 AS INT))):int>
+-- !query 15 output
+1
+
+
+-- !query 16
+SELECT IF(true, cast(1 as smallint), cast(2 as bigint)) FROM t
+-- !query 16 schema
+struct<(IF(true, CAST(CAST(1 AS SMALLINT) AS BIGINT), CAST(2 AS BIGINT))):bigint>
+-- !query 16 output
+1
+
+
+-- !query 17
+SELECT IF(true, cast(1 as smallint), cast(2 as float)) FROM t
+-- !query 17 schema
+struct<(IF(true, CAST(CAST(1 AS SMALLINT) AS FLOAT), CAST(2 AS FLOAT))):float>
+-- !query 17 output
+1.0
+
+
+-- !query 18
+SELECT IF(true, cast(1 as smallint), cast(2 as double)) FROM t
+-- !query 18 schema
+struct<(IF(true, CAST(CAST(1 AS SMALLINT) AS DOUBLE), CAST(2 AS DOUBLE))):double>
+-- !query 18 output
+1.0
+
+
+-- !query 19
+SELECT IF(true, cast(1 as smallint), cast(2 as decimal(10, 0))) FROM t
+-- !query 19 schema
+struct<(IF(true, CAST(CAST(1 AS SMALLINT) AS DECIMAL(10,0)), CAST(2 AS DECIMAL(10,0)))):decimal(10,0)>
+-- !query 19 output
+1
+
+
+-- !query 20
+SELECT IF(true, cast(1 as smallint), cast(2 as string)) FROM t
+-- !query 20 schema
+struct<(IF(true, CAST(CAST(1 AS SMALLINT) AS STRING), CAST(2 AS STRING))):string>
+-- !query 20 output
+1
+
+
+-- !query 21
+SELECT IF(true, cast(1 as smallint), cast('2' as binary)) FROM t
+-- !query 21 schema
+struct<>
+-- !query 21 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS SMALLINT), CAST('2' AS BINARY)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS SMALLINT), CAST('2' AS BINARY)))' (smallint and binary).; line 1 pos 7
+
+
+-- !query 22
+SELECT IF(true, cast(1 as smallint), cast(2 as boolean)) FROM t
+-- !query 22 schema
+struct<>
+-- !query 22 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS SMALLINT), CAST(2 AS BOOLEAN)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS SMALLINT), CAST(2 AS BOOLEAN)))' (smallint and boolean).; line 1 pos 7
+
+
+-- !query 23
+SELECT IF(true, cast(1 as smallint), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t
+-- !query 23 schema
+struct<>
+-- !query 23 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS SMALLINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS SMALLINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' (smallint and timestamp).; line 1 pos 7
+
+
+-- !query 24
+SELECT IF(true, cast(1 as smallint), cast('2017-12-11 09:30:00' as date)) FROM t
+-- !query 24 schema
+struct<>
+-- !query 24 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS SMALLINT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS SMALLINT), CAST('2017-12-11 09:30:00' AS DATE)))' (smallint and date).; line 1 pos 7
+
+
+-- !query 25
+SELECT IF(true, cast(1 as int), cast(2 as tinyint)) FROM t
+-- !query 25 schema
+struct<(IF(true, CAST(1 AS INT), CAST(CAST(2 AS TINYINT) AS INT))):int>
+-- !query 25 output
+1
+
+
+-- !query 26
+SELECT IF(true, cast(1 as int), cast(2 as smallint)) FROM t
+-- !query 26 schema
+struct<(IF(true, CAST(1 AS INT), CAST(CAST(2 AS SMALLINT) AS INT))):int>
+-- !query 26 output
+1
+
+
+-- !query 27
+SELECT IF(true, cast(1 as int), cast(2 as int)) FROM t
+-- !query 27 schema
+struct<(IF(true, CAST(1 AS INT), CAST(2 AS INT))):int>
+-- !query 27 output
+1
+
+
+-- !query 28
+SELECT IF(true, cast(1 as int), cast(2 as bigint)) FROM t
+-- !query 28 schema
+struct<(IF(true, CAST(CAST(1 AS INT) AS BIGINT), CAST(2 AS BIGINT))):bigint>
+-- !query 28 output
+1
+
+
+-- !query 29
+SELECT IF(true, cast(1 as int), cast(2 as float)) FROM t
+-- !query 29 schema
+struct<(IF(true, CAST(CAST(1 AS INT) AS FLOAT), CAST(2 AS FLOAT))):float>
+-- !query 29 output
+1.0
+
+
+-- !query 30
+SELECT IF(true, cast(1 as int), cast(2 as double)) FROM t
+-- !query 30 schema
+struct<(IF(true, CAST(CAST(1 AS INT) AS DOUBLE), CAST(2 AS DOUBLE))):double>
+-- !query 30 output
+1.0
+
+
+-- !query 31
+SELECT IF(true, cast(1 as int), cast(2 as decimal(10, 0))) FROM t
+-- !query 31 schema
+struct<(IF(true, CAST(CAST(1 AS INT) AS DECIMAL(10,0)), CAST(2 AS DECIMAL(10,0)))):decimal(10,0)>
+-- !query 31 output
+1
+
+
+-- !query 32
+SELECT IF(true, cast(1 as int), cast(2 as string)) FROM t
+-- !query 32 schema
+struct<(IF(true, CAST(CAST(1 AS INT) AS STRING), CAST(2 AS STRING))):string>
+-- !query 32 output
+1
+
+
+-- !query 33
+SELECT IF(true, cast(1 as int), cast('2' as binary)) FROM t
+-- !query 33 schema
+struct<>
+-- !query 33 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS INT), CAST('2' AS BINARY)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS INT), CAST('2' AS BINARY)))' (int and binary).; line 1 pos 7
+
+
+-- !query 34
+SELECT IF(true, cast(1 as int), cast(2 as boolean)) FROM t
+-- !query 34 schema
+struct<>
+-- !query 34 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS INT), CAST(2 AS BOOLEAN)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS INT), CAST(2 AS BOOLEAN)))' (int and boolean).; line 1 pos 7
+
+
+-- !query 35
+SELECT IF(true, cast(1 as int), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t
+-- !query 35 schema
+struct<>
+-- !query 35 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS INT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS INT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' (int and timestamp).; line 1 pos 7
+
+
+-- !query 36
+SELECT IF(true, cast(1 as int), cast('2017-12-11 09:30:00' as date)) FROM t
+-- !query 36 schema
+struct<>
+-- !query 36 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS INT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS INT), CAST('2017-12-11 09:30:00' AS DATE)))' (int and date).; line 1 pos 7
+
+
+-- !query 37
+SELECT IF(true, cast(1 as bigint), cast(2 as tinyint)) FROM t
+-- !query 37 schema
+struct<(IF(true, CAST(1 AS BIGINT), CAST(CAST(2 AS TINYINT) AS BIGINT))):bigint>
+-- !query 37 output
+1
+
+
+-- !query 38
+SELECT IF(true, cast(1 as bigint), cast(2 as smallint)) FROM t
+-- !query 38 schema
+struct<(IF(true, CAST(1 AS BIGINT), CAST(CAST(2 AS SMALLINT) AS BIGINT))):bigint>
+-- !query 38 output
+1
+
+
+-- !query 39
+SELECT IF(true, cast(1 as bigint), cast(2 as int)) FROM t
+-- !query 39 schema
+struct<(IF(true, CAST(1 AS BIGINT), CAST(CAST(2 AS INT) AS BIGINT))):bigint>
+-- !query 39 output
+1
+
+
+-- !query 40
+SELECT IF(true, cast(1 as bigint), cast(2 as bigint)) FROM t
+-- !query 40 schema
+struct<(IF(true, CAST(1 AS BIGINT), CAST(2 AS BIGINT))):bigint>
+-- !query 40 output
+1
+
+
+-- !query 41
+SELECT IF(true, cast(1 as bigint), cast(2 as float)) FROM t
+-- !query 41 schema
+struct<(IF(true, CAST(CAST(1 AS BIGINT) AS FLOAT), CAST(2 AS FLOAT))):float>
+-- !query 41 output
+1.0
+
+
+-- !query 42
+SELECT IF(true, cast(1 as bigint), cast(2 as double)) FROM t
+-- !query 42 schema
+struct<(IF(true, CAST(CAST(1 AS BIGINT) AS DOUBLE), CAST(2 AS DOUBLE))):double>
+-- !query 42 output
+1.0
+
+
+-- !query 43
+SELECT IF(true, cast(1 as bigint), cast(2 as decimal(10, 0))) FROM t
+-- !query 43 schema
+struct<(IF(true, CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)), CAST(CAST(2 AS DECIMAL(10,0)) AS DECIMAL(20,0)))):decimal(20,0)>
+-- !query 43 output
+1
+
+
+-- !query 44
+SELECT IF(true, cast(1 as bigint), cast(2 as string)) FROM t
+-- !query 44 schema
+struct<(IF(true, CAST(CAST(1 AS BIGINT) AS STRING), CAST(2 AS STRING))):string>
+-- !query 44 output
+1
+
+
+-- !query 45
+SELECT IF(true, cast(1 as bigint), cast('2' as binary)) FROM t
+-- !query 45 schema
+struct<>
+-- !query 45 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS BIGINT), CAST('2' AS BINARY)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS BIGINT), CAST('2' AS BINARY)))' (bigint and binary).; line 1 pos 7
+
+
+-- !query 46
+SELECT IF(true, cast(1 as bigint), cast(2 as boolean)) FROM t
+-- !query 46 schema
+struct<>
+-- !query 46 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS BIGINT), CAST(2 AS BOOLEAN)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS BIGINT), CAST(2 AS BOOLEAN)))' (bigint and boolean).; line 1 pos 7
+
+
+-- !query 47
+SELECT IF(true, cast(1 as bigint), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t
+-- !query 47 schema
+struct<>
+-- !query 47 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS BIGINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS BIGINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' (bigint and timestamp).; line 1 pos 7
+
+
+-- !query 48
+SELECT IF(true, cast(1 as bigint), cast('2017-12-11 09:30:00' as date)) FROM t
+-- !query 48 schema
+struct<>
+-- !query 48 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS BIGINT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS BIGINT), CAST('2017-12-11 09:30:00' AS DATE)))' (bigint and date).; line 1 pos 7
+
+
+-- !query 49
+SELECT IF(true, cast(1 as float), cast(2 as tinyint)) FROM t
+-- !query 49 schema
+struct<(IF(true, CAST(1 AS FLOAT), CAST(CAST(2 AS TINYINT) AS FLOAT))):float>
+-- !query 49 output
+1.0
+
+
+-- !query 50
+SELECT IF(true, cast(1 as float), cast(2 as smallint)) FROM t
+-- !query 50 schema
+struct<(IF(true, CAST(1 AS FLOAT), CAST(CAST(2 AS SMALLINT) AS FLOAT))):float>
+-- !query 50 output
+1.0
+
+
+-- !query 51
+SELECT IF(true, cast(1 as float), cast(2 as int)) FROM t
+-- !query 51 schema
+struct<(IF(true, CAST(1 AS FLOAT), CAST(CAST(2 AS INT) AS FLOAT))):float>
+-- !query 51 output
+1.0
+
+
+-- !query 52
+SELECT IF(true, cast(1 as float), cast(2 as bigint)) FROM t
+-- !query 52 schema
+struct<(IF(true, CAST(1 AS FLOAT), CAST(CAST(2 AS BIGINT) AS FLOAT))):float>
+-- !query 52 output
+1.0
+
+
+-- !query 53
+SELECT IF(true, cast(1 as float), cast(2 as float)) FROM t
+-- !query 53 schema
+struct<(IF(true, CAST(1 AS FLOAT), CAST(2 AS FLOAT))):float>
+-- !query 53 output
+1.0
+
+
+-- !query 54
+SELECT IF(true, cast(1 as float), cast(2 as double)) FROM t
+-- !query 54 schema
+struct<(IF(true, CAST(CAST(1 AS FLOAT) AS DOUBLE), CAST(2 AS DOUBLE))):double>
+-- !query 54 output
+1.0
+
+
+-- !query 55
+SELECT IF(true, cast(1 as float), cast(2 as decimal(10, 0))) FROM t
+-- !query 55 schema
+struct<(IF(true, CAST(CAST(1 AS FLOAT) AS DOUBLE), CAST(CAST(2 AS DECIMAL(10,0)) AS DOUBLE))):double>
+-- !query 55 output
+1.0
+
+
+-- !query 56
+SELECT IF(true, cast(1 as float), cast(2 as string)) FROM t
+-- !query 56 schema
+struct<(IF(true, CAST(CAST(1 AS FLOAT) AS STRING), CAST(2 AS STRING))):string>
+-- !query 56 output
+1.0
+
+
+-- !query 57
+SELECT IF(true, cast(1 as float), cast('2' as binary)) FROM t
+-- !query 57 schema
+struct<>
+-- !query 57 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS FLOAT), CAST('2' AS BINARY)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS FLOAT), CAST('2' AS BINARY)))' (float and binary).; line 1 pos 7
+
+
+-- !query 58
+SELECT IF(true, cast(1 as float), cast(2 as boolean)) FROM t
+-- !query 58 schema
+struct<>
+-- !query 58 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS FLOAT), CAST(2 AS BOOLEAN)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS FLOAT), CAST(2 AS BOOLEAN)))' (float and boolean).; line 1 pos 7
+
+
+-- !query 59
+SELECT IF(true, cast(1 as float), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t
+-- !query 59 schema
+struct<>
+-- !query 59 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS FLOAT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS FLOAT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' (float and timestamp).; line 1 pos 7
+
+
+-- !query 60
+SELECT IF(true, cast(1 as float), cast('2017-12-11 09:30:00' as date)) FROM t
+-- !query 60 schema
+struct<>
+-- !query 60 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS FLOAT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS FLOAT), CAST('2017-12-11 09:30:00' AS DATE)))' (float and date).; line 1 pos 7
+
+
+-- !query 61
+SELECT IF(true, cast(1 as double), cast(2 as tinyint)) FROM t
+-- !query 61 schema
+struct<(IF(true, CAST(1 AS DOUBLE), CAST(CAST(2 AS TINYINT) AS DOUBLE))):double>
+-- !query 61 output
+1.0
+
+
+-- !query 62
+SELECT IF(true, cast(1 as double), cast(2 as smallint)) FROM t
+-- !query 62 schema
+struct<(IF(true, CAST(1 AS DOUBLE), CAST(CAST(2 AS SMALLINT) AS DOUBLE))):double>
+-- !query 62 output
+1.0
+
+
+-- !query 63
+SELECT IF(true, cast(1 as double), cast(2 as int)) FROM t
+-- !query 63 schema
+struct<(IF(true, CAST(1 AS DOUBLE), CAST(CAST(2 AS INT) AS DOUBLE))):double>
+-- !query 63 output
+1.0
+
+
+-- !query 64
+SELECT IF(true, cast(1 as double), cast(2 as bigint)) FROM t
+-- !query 64 schema
+struct<(IF(true, CAST(1 AS DOUBLE), CAST(CAST(2 AS BIGINT) AS DOUBLE))):double>
+-- !query 64 output
+1.0
+
+
+-- !query 65
+SELECT IF(true, cast(1 as double), cast(2 as float)) FROM t
+-- !query 65 schema
+struct<(IF(true, CAST(1 AS DOUBLE), CAST(CAST(2 AS FLOAT) AS DOUBLE))):double>
+-- !query 65 output
+1.0
+
+
+-- !query 66
+SELECT IF(true, cast(1 as double), cast(2 as double)) FROM t
+-- !query 66 schema
+struct<(IF(true, CAST(1 AS DOUBLE), CAST(2 AS DOUBLE))):double>
+-- !query 66 output
+1.0
+
+
+-- !query 67
+SELECT IF(true, cast(1 as double), cast(2 as decimal(10, 0))) FROM t
+-- !query 67 schema
+struct<(IF(true, CAST(1 AS DOUBLE), CAST(CAST(2 AS DECIMAL(10,0)) AS DOUBLE))):double>
+-- !query 67 output
+1.0
+
+
+-- !query 68
+SELECT IF(true, cast(1 as double), cast(2 as string)) FROM t
+-- !query 68 schema
+struct<(IF(true, CAST(CAST(1 AS DOUBLE) AS STRING), CAST(2 AS STRING))):string>
+-- !query 68 output
+1.0
+
+
+-- !query 69
+SELECT IF(true, cast(1 as double), cast('2' as binary)) FROM t
+-- !query 69 schema
+struct<>
+-- !query 69 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS DOUBLE), CAST('2' AS BINARY)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS DOUBLE), CAST('2' AS BINARY)))' (double and binary).; line 1 pos 7
+
+
+-- !query 70
+SELECT IF(true, cast(1 as double), cast(2 as boolean)) FROM t
+-- !query 70 schema
+struct<>
+-- !query 70 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS DOUBLE), CAST(2 AS BOOLEAN)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS DOUBLE), CAST(2 AS BOOLEAN)))' (double and boolean).; line 1 pos 7
+
+
+-- !query 71
+SELECT IF(true, cast(1 as double), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t
+-- !query 71 schema
+struct<>
+-- !query 71 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS DOUBLE), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS DOUBLE), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' (double and timestamp).; line 1 pos 7
+
+
+-- !query 72
+SELECT IF(true, cast(1 as double), cast('2017-12-11 09:30:00' as date)) FROM t
+-- !query 72 schema
+struct<>
+-- !query 72 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS DOUBLE), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS DOUBLE), CAST('2017-12-11 09:30:00' AS DATE)))' (double and date).; line 1 pos 7
+
+
+-- !query 73
+SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as tinyint)) FROM t
+-- !query 73 schema
+struct<(IF(true, CAST(1 AS DECIMAL(10,0)), CAST(CAST(2 AS TINYINT) AS DECIMAL(10,0)))):decimal(10,0)>
+-- !query 73 output
+1
+
+
+-- !query 74
+SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as smallint)) FROM t
+-- !query 74 schema
+struct<(IF(true, CAST(1 AS DECIMAL(10,0)), CAST(CAST(2 AS SMALLINT) AS DECIMAL(10,0)))):decimal(10,0)>
+-- !query 74 output
+1
+
+
+-- !query 75
+SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as int)) FROM t
+-- !query 75 schema
+struct<(IF(true, CAST(1 AS DECIMAL(10,0)), CAST(CAST(2 AS INT) AS DECIMAL(10,0)))):decimal(10,0)>
+-- !query 75 output
+1
+
+
+-- !query 76
+SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as bigint)) FROM t
+-- !query 76 schema
+struct<(IF(true, CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)), CAST(CAST(2 AS BIGINT) AS DECIMAL(20,0)))):decimal(20,0)>
+-- !query 76 output
+1
+
+
+-- !query 77
+SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as float)) FROM t
+-- !query 77 schema
+struct<(IF(true, CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE), CAST(CAST(2 AS FLOAT) AS DOUBLE))):double>
+-- !query 77 output
+1.0
+
+
+-- !query 78
+SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as double)) FROM t
+-- !query 78 schema
+struct<(IF(true, CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE), CAST(2 AS DOUBLE))):double>
+-- !query 78 output
+1.0
+
+
+-- !query 79
+SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as decimal(10, 0))) FROM t
+-- !query 79 schema
+struct<(IF(true, CAST(1 AS DECIMAL(10,0)), CAST(2 AS DECIMAL(10,0)))):decimal(10,0)>
+-- !query 79 output
+1
+
+
+-- !query 80
+SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as string)) FROM t
+-- !query 80 schema
+struct<(IF(true, CAST(CAST(1 AS DECIMAL(10,0)) AS STRING), CAST(2 AS STRING))):string>
+-- !query 80 output
+1
+
+
+-- !query 81
+SELECT IF(true, cast(1 as decimal(10, 0)), cast('2' as binary)) FROM t
+-- !query 81 schema
+struct<>
+-- !query 81 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS DECIMAL(10,0)), CAST('2' AS BINARY)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS DECIMAL(10,0)), CAST('2' AS BINARY)))' (decimal(10,0) and binary).; line 1 pos 7
+
+
+-- !query 82
+SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as boolean)) FROM t
+-- !query 82 schema
+struct<>
+-- !query 82 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS DECIMAL(10,0)), CAST(2 AS BOOLEAN)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS DECIMAL(10,0)), CAST(2 AS BOOLEAN)))' (decimal(10,0) and boolean).; line 1 pos 7
+
+
+-- !query 83
+SELECT IF(true, cast(1 as decimal(10, 0)), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t
+-- !query 83 schema
+struct<>
+-- !query 83 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS DECIMAL(10,0)), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS DECIMAL(10,0)), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' (decimal(10,0) and timestamp).; line 1 pos 7
+
+
+-- !query 84
+SELECT IF(true, cast(1 as decimal(10, 0)), cast('2017-12-11 09:30:00' as date)) FROM t
+-- !query 84 schema
+struct<>
+-- !query 84 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS DECIMAL(10,0)), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS DECIMAL(10,0)), CAST('2017-12-11 09:30:00' AS DATE)))' (decimal(10,0) and date).; line 1 pos 7
+
+
+-- !query 85
+SELECT IF(true, cast(1 as string), cast(2 as tinyint)) FROM t
+-- !query 85 schema
+struct<(IF(true, CAST(1 AS STRING), CAST(CAST(2 AS TINYINT) AS STRING))):string>
+-- !query 85 output
+1
+
+
+-- !query 86
+SELECT IF(true, cast(1 as string), cast(2 as smallint)) FROM t
+-- !query 86 schema
+struct<(IF(true, CAST(1 AS STRING), CAST(CAST(2 AS SMALLINT) AS STRING))):string>
+-- !query 86 output
+1
+
+
+-- !query 87
+SELECT IF(true, cast(1 as string), cast(2 as int)) FROM t
+-- !query 87 schema
+struct<(IF(true, CAST(1 AS STRING), CAST(CAST(2 AS INT) AS STRING))):string>
+-- !query 87 output
+1
+
+
+-- !query 88
+SELECT IF(true, cast(1 as string), cast(2 as bigint)) FROM t
+-- !query 88 schema
+struct<(IF(true, CAST(1 AS STRING), CAST(CAST(2 AS BIGINT) AS STRING))):string>
+-- !query 88 output
+1
+
+
+-- !query 89
+SELECT IF(true, cast(1 as string), cast(2 as float)) FROM t
+-- !query 89 schema
+struct<(IF(true, CAST(1 AS STRING), CAST(CAST(2 AS FLOAT) AS STRING))):string>
+-- !query 89 output
+1
+
+
+-- !query 90
+SELECT IF(true, cast(1 as string), cast(2 as double)) FROM t
+-- !query 90 schema
+struct<(IF(true, CAST(1 AS STRING), CAST(CAST(2 AS DOUBLE) AS STRING))):string>
+-- !query 90 output
+1
+
+
+-- !query 91
+SELECT IF(true, cast(1 as string), cast(2 as decimal(10, 0))) FROM t
+-- !query 91 schema
+struct<(IF(true, CAST(1 AS STRING), CAST(CAST(2 AS DECIMAL(10,0)) AS STRING))):string>
+-- !query 91 output
+1
+
+
+-- !query 92
+SELECT IF(true, cast(1 as string), cast(2 as string)) FROM t
+-- !query 92 schema
+struct<(IF(true, CAST(1 AS STRING), CAST(2 AS STRING))):string>
+-- !query 92 output
+1
+
+
+-- !query 93
+SELECT IF(true, cast(1 as string), cast('2' as binary)) FROM t
+-- !query 93 schema
+struct<>
+-- !query 93 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS STRING), CAST('2' AS BINARY)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS STRING), CAST('2' AS BINARY)))' (string and binary).; line 1 pos 7
+
+
+-- !query 94
+SELECT IF(true, cast(1 as string), cast(2 as boolean)) FROM t
+-- !query 94 schema
+struct<>
+-- !query 94 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS STRING), CAST(2 AS BOOLEAN)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS STRING), CAST(2 AS BOOLEAN)))' (string and boolean).; line 1 pos 7
+
+
+-- !query 95
+SELECT IF(true, cast(1 as string), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t
+-- !query 95 schema
+struct<(IF(true, CAST(1 AS STRING), CAST(CAST(2017-12-11 09:30:00.0 AS TIMESTAMP) AS STRING))):string>
+-- !query 95 output
+1
+
+
+-- !query 96
+SELECT IF(true, cast(1 as string), cast('2017-12-11 09:30:00' as date)) FROM t
+-- !query 96 schema
+struct<(IF(true, CAST(1 AS STRING), CAST(CAST(2017-12-11 09:30:00 AS DATE) AS STRING))):string>
+-- !query 96 output
+1
+
+
+-- !query 97
+SELECT IF(true, cast('1' as binary), cast(2 as tinyint)) FROM t
+-- !query 97 schema
+struct<>
+-- !query 97 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST('1' AS BINARY), CAST(2 AS TINYINT)))' due to data type mismatch: differing types in '(IF(true, CAST('1' AS BINARY), CAST(2 AS TINYINT)))' (binary and tinyint).; line 1 pos 7
+
+
+-- !query 98
+SELECT IF(true, cast('1' as binary), cast(2 as smallint)) FROM t
+-- !query 98 schema
+struct<>
+-- !query 98 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST('1' AS BINARY), CAST(2 AS SMALLINT)))' due to data type mismatch: differing types in '(IF(true, CAST('1' AS BINARY), CAST(2 AS SMALLINT)))' (binary and smallint).; line 1 pos 7
+
+
+-- !query 99
+SELECT IF(true, cast('1' as binary), cast(2 as int)) FROM t
+-- !query 99 schema
+struct<>
+-- !query 99 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST('1' AS BINARY), CAST(2 AS INT)))' due to data type mismatch: differing types in '(IF(true, CAST('1' AS BINARY), CAST(2 AS INT)))' (binary and int).; line 1 pos 7
+
+
+-- !query 100
+SELECT IF(true, cast('1' as binary), cast(2 as bigint)) FROM t
+-- !query 100 schema
+struct<>
+-- !query 100 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST('1' AS BINARY), CAST(2 AS BIGINT)))' due to data type mismatch: differing types in '(IF(true, CAST('1' AS BINARY), CAST(2 AS BIGINT)))' (binary and bigint).; line 1 pos 7
+
+
+-- !query 101
+SELECT IF(true, cast('1' as binary), cast(2 as float)) FROM t
+-- !query 101 schema
+struct<>
+-- !query 101 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST('1' AS BINARY), CAST(2 AS FLOAT)))' due to data type mismatch: differing types in '(IF(true, CAST('1' AS BINARY), CAST(2 AS FLOAT)))' (binary and float).; line 1 pos 7
+
+
+-- !query 102
+SELECT IF(true, cast('1' as binary), cast(2 as double)) FROM t
+-- !query 102 schema
+struct<>
+-- !query 102 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST('1' AS BINARY), CAST(2 AS DOUBLE)))' due to data type mismatch: differing types in '(IF(true, CAST('1' AS BINARY), CAST(2 AS DOUBLE)))' (binary and double).; line 1 pos 7
+
+
+-- !query 103
+SELECT IF(true, cast('1' as binary), cast(2 as decimal(10, 0))) FROM t
+-- !query 103 schema
+struct<>
+-- !query 103 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST('1' AS BINARY), CAST(2 AS DECIMAL(10,0))))' due to data type mismatch: differing types in '(IF(true, CAST('1' AS BINARY), CAST(2 AS DECIMAL(10,0))))' (binary and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 104
+SELECT IF(true, cast('1' as binary), cast(2 as string)) FROM t
+-- !query 104 schema
+struct<>
+-- !query 104 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST('1' AS BINARY), CAST(2 AS STRING)))' due to data type mismatch: differing types in '(IF(true, CAST('1' AS BINARY), CAST(2 AS STRING)))' (binary and string).; line 1 pos 7
+
+
+-- !query 105
+SELECT IF(true, cast('1' as binary), cast('2' as binary)) FROM t
+-- !query 105 schema
+struct<(IF(true, CAST(1 AS BINARY), CAST(2 AS BINARY))):binary>
+-- !query 105 output
+1
+
+
+-- !query 106
+SELECT IF(true, cast('1' as binary), cast(2 as boolean)) FROM t
+-- !query 106 schema
+struct<>
+-- !query 106 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST('1' AS BINARY), CAST(2 AS BOOLEAN)))' due to data type mismatch: differing types in '(IF(true, CAST('1' AS BINARY), CAST(2 AS BOOLEAN)))' (binary and boolean).; line 1 pos 7
+
+
+-- !query 107
+SELECT IF(true, cast('1' as binary), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t
+-- !query 107 schema
+struct<>
+-- !query 107 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST('1' AS BINARY), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: differing types in '(IF(true, CAST('1' AS BINARY), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' (binary and timestamp).; line 1 pos 7
+
+
+-- !query 108
+SELECT IF(true, cast('1' as binary), cast('2017-12-11 09:30:00' as date)) FROM t
+-- !query 108 schema
+struct<>
+-- !query 108 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST('1' AS BINARY), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: differing types in '(IF(true, CAST('1' AS BINARY), CAST('2017-12-11 09:30:00' AS DATE)))' (binary and date).; line 1 pos 7
+
+
+-- !query 109
+SELECT IF(true, cast(1 as boolean), cast(2 as tinyint)) FROM t
+-- !query 109 schema
+struct<>
+-- !query 109 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS TINYINT)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS TINYINT)))' (boolean and tinyint).; line 1 pos 7
+
+
+-- !query 110
+SELECT IF(true, cast(1 as boolean), cast(2 as smallint)) FROM t
+-- !query 110 schema
+struct<>
+-- !query 110 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS SMALLINT)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS SMALLINT)))' (boolean and smallint).; line 1 pos 7
+
+
+-- !query 111
+SELECT IF(true, cast(1 as boolean), cast(2 as int)) FROM t
+-- !query 111 schema
+struct<>
+-- !query 111 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS INT)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS INT)))' (boolean and int).; line 1 pos 7
+
+
+-- !query 112
+SELECT IF(true, cast(1 as boolean), cast(2 as bigint)) FROM t
+-- !query 112 schema
+struct<>
+-- !query 112 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS BIGINT)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS BIGINT)))' (boolean and bigint).; line 1 pos 7
+
+
+-- !query 113
+SELECT IF(true, cast(1 as boolean), cast(2 as float)) FROM t
+-- !query 113 schema
+struct<>
+-- !query 113 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS FLOAT)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS FLOAT)))' (boolean and float).; line 1 pos 7
+
+
+-- !query 114
+SELECT IF(true, cast(1 as boolean), cast(2 as double)) FROM t
+-- !query 114 schema
+struct<>
+-- !query 114 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS DOUBLE)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS DOUBLE)))' (boolean and double).; line 1 pos 7
+
+
+-- !query 115
+SELECT IF(true, cast(1 as boolean), cast(2 as decimal(10, 0))) FROM t
+-- !query 115 schema
+struct<>
+-- !query 115 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS DECIMAL(10,0))))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS DECIMAL(10,0))))' (boolean and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 116
+SELECT IF(true, cast(1 as boolean), cast(2 as string)) FROM t
+-- !query 116 schema
+struct<>
+-- !query 116 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS STRING)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS STRING)))' (boolean and string).; line 1 pos 7
+
+
+-- !query 117
+SELECT IF(true, cast(1 as boolean), cast('2' as binary)) FROM t
+-- !query 117 schema
+struct<>
+-- !query 117 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS BOOLEAN), CAST('2' AS BINARY)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS BOOLEAN), CAST('2' AS BINARY)))' (boolean and binary).; line 1 pos 7
+
+
+-- !query 118
+SELECT IF(true, cast(1 as boolean), cast(2 as boolean)) FROM t
+-- !query 118 schema
+struct<(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS BOOLEAN))):boolean>
+-- !query 118 output
+true
+
+
+-- !query 119
+SELECT IF(true, cast(1 as boolean), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t
+-- !query 119 schema
+struct<>
+-- !query 119 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS BOOLEAN), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS BOOLEAN), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' (boolean and timestamp).; line 1 pos 7
+
+
+-- !query 120
+SELECT IF(true, cast(1 as boolean), cast('2017-12-11 09:30:00' as date)) FROM t
+-- !query 120 schema
+struct<>
+-- !query 120 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST(1 AS BOOLEAN), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS BOOLEAN), CAST('2017-12-11 09:30:00' AS DATE)))' (boolean and date).; line 1 pos 7
+
+
+-- !query 121
+SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as tinyint)) FROM t
+-- !query 121 schema
+struct<>
+-- !query 121 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS TINYINT)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS TINYINT)))' (timestamp and tinyint).; line 1 pos 7
+
+
+-- !query 122
+SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as smallint)) FROM t
+-- !query 122 schema
+struct<>
+-- !query 122 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS SMALLINT)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS SMALLINT)))' (timestamp and smallint).; line 1 pos 7
+
+
+-- !query 123
+SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as int)) FROM t
+-- !query 123 schema
+struct<>
+-- !query 123 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS INT)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS INT)))' (timestamp and int).; line 1 pos 7
+
+
+-- !query 124
+SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as bigint)) FROM t
+-- !query 124 schema
+struct<>
+-- !query 124 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS BIGINT)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS BIGINT)))' (timestamp and bigint).; line 1 pos 7
+
+
+-- !query 125
+SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as float)) FROM t
+-- !query 125 schema
+struct<>
+-- !query 125 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS FLOAT)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS FLOAT)))' (timestamp and float).; line 1 pos 7
+
+
+-- !query 126
+SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as double)) FROM t
+-- !query 126 schema
+struct<>
+-- !query 126 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS DOUBLE)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS DOUBLE)))' (timestamp and double).; line 1 pos 7
+
+
+-- !query 127
+SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as decimal(10, 0))) FROM t
+-- !query 127 schema
+struct<>
+-- !query 127 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS DECIMAL(10,0))))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS DECIMAL(10,0))))' (timestamp and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 128
+SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as string)) FROM t
+-- !query 128 schema
+struct<(IF(true, CAST(CAST(2017-12-12 09:30:00.0 AS TIMESTAMP) AS STRING), CAST(2 AS STRING))):string>
+-- !query 128 output
+2017-12-12 09:30:00
+
+
+-- !query 129
+SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast('2' as binary)) FROM t
+-- !query 129 schema
+struct<>
+-- !query 129 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST('2' AS BINARY)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST('2' AS BINARY)))' (timestamp and binary).; line 1 pos 7
+
+
+-- !query 130
+SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as boolean)) FROM t
+-- !query 130 schema
+struct<>
+-- !query 130 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS BOOLEAN)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS BOOLEAN)))' (timestamp and boolean).; line 1 pos 7
+
+
+-- !query 131
+SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t
+-- !query 131 schema
+struct<(IF(true, CAST(2017-12-12 09:30:00.0 AS TIMESTAMP), CAST(2017-12-11 09:30:00.0 AS TIMESTAMP))):timestamp>
+-- !query 131 output
+2017-12-12 09:30:00
+
+
+-- !query 132
+SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast('2017-12-11 09:30:00' as date)) FROM t
+-- !query 132 schema
+struct<(IF(true, CAST(2017-12-12 09:30:00.0 AS TIMESTAMP), CAST(CAST(2017-12-11 09:30:00 AS DATE) AS TIMESTAMP))):timestamp>
+-- !query 132 output
+2017-12-12 09:30:00
+
+
+-- !query 133
+SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as tinyint)) FROM t
+-- !query 133 schema
+struct<>
+-- !query 133 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS TINYINT)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS TINYINT)))' (date and tinyint).; line 1 pos 7
+
+
+-- !query 134
+SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as smallint)) FROM t
+-- !query 134 schema
+struct<>
+-- !query 134 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS SMALLINT)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS SMALLINT)))' (date and smallint).; line 1 pos 7
+
+
+-- !query 135
+SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as int)) FROM t
+-- !query 135 schema
+struct<>
+-- !query 135 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS INT)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS INT)))' (date and int).; line 1 pos 7
+
+
+-- !query 136
+SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as bigint)) FROM t
+-- !query 136 schema
+struct<>
+-- !query 136 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS BIGINT)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS BIGINT)))' (date and bigint).; line 1 pos 7
+
+
+-- !query 137
+SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as float)) FROM t
+-- !query 137 schema
+struct<>
+-- !query 137 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS FLOAT)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS FLOAT)))' (date and float).; line 1 pos 7
+
+
+-- !query 138
+SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as double)) FROM t
+-- !query 138 schema
+struct<>
+-- !query 138 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS DOUBLE)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS DOUBLE)))' (date and double).; line 1 pos 7
+
+
+-- !query 139
+SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as decimal(10, 0))) FROM t
+-- !query 139 schema
+struct<>
+-- !query 139 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS DECIMAL(10,0))))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS DECIMAL(10,0))))' (date and decimal(10,0)).; line 1 pos 7
+
+
+-- !query 140
+SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as string)) FROM t
+-- !query 140 schema
+struct<(IF(true, CAST(CAST(2017-12-12 09:30:00 AS DATE) AS STRING), CAST(2 AS STRING))):string>
+-- !query 140 output
+2017-12-12
+
+
+-- !query 141
+SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast('2' as binary)) FROM t
+-- !query 141 schema
+struct<>
+-- !query 141 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST('2' AS BINARY)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST('2' AS BINARY)))' (date and binary).; line 1 pos 7
+
+
+-- !query 142
+SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as boolean)) FROM t
+-- !query 142 schema
+struct<>
+-- !query 142 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS BOOLEAN)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS BOOLEAN)))' (date and boolean).; line 1 pos 7
+
+
+-- !query 143
+SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t
+-- !query 143 schema
+struct<(IF(true, CAST(CAST(2017-12-12 09:30:00 AS DATE) AS TIMESTAMP), CAST(2017-12-11 09:30:00.0 AS TIMESTAMP))):timestamp>
+-- !query 143 output
+2017-12-12 00:00:00
+
+
+-- !query 144
+SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast('2017-12-11 09:30:00' as date)) FROM t
+-- !query 144 schema
+struct<(IF(true, CAST(2017-12-12 09:30:00 AS DATE), CAST(2017-12-11 09:30:00 AS DATE))):date>
+-- !query 144 output
+2017-12-12
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/implicitTypeCasts.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/implicitTypeCasts.sql.out
new file mode 100644
index 0000000000000..44fa48e2697b3
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/implicitTypeCasts.sql.out
@@ -0,0 +1,354 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 44
+
+
+-- !query 0
+CREATE TEMPORARY VIEW t AS SELECT 1
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+SELECT 1 + '2' FROM t
+-- !query 1 schema
+struct<(CAST(1 AS DOUBLE) + CAST(2 AS DOUBLE)):double>
+-- !query 1 output
+3.0
+
+
+-- !query 2
+SELECT 1 - '2' FROM t
+-- !query 2 schema
+struct<(CAST(1 AS DOUBLE) - CAST(2 AS DOUBLE)):double>
+-- !query 2 output
+-1.0
+
+
+-- !query 3
+SELECT 1 * '2' FROM t
+-- !query 3 schema
+struct<(CAST(1 AS DOUBLE) * CAST(2 AS DOUBLE)):double>
+-- !query 3 output
+2.0
+
+
+-- !query 4
+SELECT 4 / '2' FROM t
+-- !query 4 schema
+struct<(CAST(4 AS DOUBLE) / CAST(CAST(2 AS DOUBLE) AS DOUBLE)):double>
+-- !query 4 output
+2.0
+
+
+-- !query 5
+SELECT 1.1 + '2' FROM t
+-- !query 5 schema
+struct<(CAST(1.1 AS DOUBLE) + CAST(2 AS DOUBLE)):double>
+-- !query 5 output
+3.1
+
+
+-- !query 6
+SELECT 1.1 - '2' FROM t
+-- !query 6 schema
+struct<(CAST(1.1 AS DOUBLE) - CAST(2 AS DOUBLE)):double>
+-- !query 6 output
+-0.8999999999999999
+
+
+-- !query 7
+SELECT 1.1 * '2' FROM t
+-- !query 7 schema
+struct<(CAST(1.1 AS DOUBLE) * CAST(2 AS DOUBLE)):double>
+-- !query 7 output
+2.2
+
+
+-- !query 8
+SELECT 4.4 / '2' FROM t
+-- !query 8 schema
+struct<(CAST(4.4 AS DOUBLE) / CAST(2 AS DOUBLE)):double>
+-- !query 8 output
+2.2
+
+
+-- !query 9
+SELECT 1.1 + '2.2' FROM t
+-- !query 9 schema
+struct<(CAST(1.1 AS DOUBLE) + CAST(2.2 AS DOUBLE)):double>
+-- !query 9 output
+3.3000000000000003
+
+
+-- !query 10
+SELECT 1.1 - '2.2' FROM t
+-- !query 10 schema
+struct<(CAST(1.1 AS DOUBLE) - CAST(2.2 AS DOUBLE)):double>
+-- !query 10 output
+-1.1
+
+
+-- !query 11
+SELECT 1.1 * '2.2' FROM t
+-- !query 11 schema
+struct<(CAST(1.1 AS DOUBLE) * CAST(2.2 AS DOUBLE)):double>
+-- !query 11 output
+2.4200000000000004
+
+
+-- !query 12
+SELECT 4.4 / '2.2' FROM t
+-- !query 12 schema
+struct<(CAST(4.4 AS DOUBLE) / CAST(2.2 AS DOUBLE)):double>
+-- !query 12 output
+2.0
+
+
+-- !query 13
+SELECT '$' || cast(1 as smallint) || '$' FROM t
+-- !query 13 schema
+struct
+-- !query 13 output
+$1$
+
+
+-- !query 14
+SELECT '$' || 1 || '$' FROM t
+-- !query 14 schema
+struct
+-- !query 14 output
+$1$
+
+
+-- !query 15
+SELECT '$' || cast(1 as bigint) || '$' FROM t
+-- !query 15 schema
+struct
+-- !query 15 output
+$1$
+
+
+-- !query 16
+SELECT '$' || cast(1.1 as float) || '$' FROM t
+-- !query 16 schema
+struct
+-- !query 16 output
+$1.1$
+
+
+-- !query 17
+SELECT '$' || cast(1.1 as double) || '$' FROM t
+-- !query 17 schema
+struct
+-- !query 17 output
+$1.1$
+
+
+-- !query 18
+SELECT '$' || 1.1 || '$' FROM t
+-- !query 18 schema
+struct