diff --git a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java index ade11574c3974..99fc10ed0a4c8 100644 --- a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java +++ b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java @@ -52,16 +52,19 @@ import com.facebook.presto.sql.tree.SubqueryExpression; import com.facebook.presto.sql.tree.Table; import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.HashMultiset; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.ListMultimap; import com.google.common.collect.Multimap; +import com.google.common.collect.Multiset; import javax.annotation.Nullable; import javax.annotation.concurrent.Immutable; import java.util.ArrayDeque; +import java.util.ArrayList; import java.util.Collection; import java.util.Deque; import java.util.HashMap; @@ -70,6 +73,7 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Set; @@ -157,6 +161,12 @@ public class Analysis private final Map, List> groupingOperations = new LinkedHashMap<>(); + private final Multiset rowFilterScopes = HashMultiset.create(); + private final Map, List> rowFilters = new LinkedHashMap<>(); + + private final Multiset columnMaskScopes = HashMultiset.create(); + private final Map, Map> columnMasks = new LinkedHashMap<>(); + // for create table private Optional createTableDestination = Optional.empty(); private Map createTableProperties = ImmutableMap.of(); @@ -994,6 +1004,59 @@ public Map> getInvokedFunctions() return functionMap.entrySet().stream().collect(toImmutableMap(Map.Entry::getKey, entry -> ImmutableSet.copyOf(entry.getValue()))); } + public boolean hasRowFilter(QualifiedObjectName table, String identity) + { + return rowFilterScopes.contains(new RowFilterScopeEntry(table, identity)); + } + + public void registerTableForRowFiltering(QualifiedObjectName table, String identity) + { + rowFilterScopes.add(new RowFilterScopeEntry(table, identity)); + } + + public void unregisterTableForRowFiltering(QualifiedObjectName table, String identity) + { + rowFilterScopes.remove(new RowFilterScopeEntry(table, identity)); + } + + public void addRowFilter(Table table, Expression filter) + { + rowFilters.computeIfAbsent(NodeRef.of(table), node -> new ArrayList<>()) + .add(filter); + } + + public List getRowFilters(Table node) + { + return rowFilters.getOrDefault(NodeRef.of(node), ImmutableList.of()); + } + + public boolean hasColumnMask(QualifiedObjectName table, String column, String identity) + { + return columnMaskScopes.contains(new ColumnMaskScopeEntry(table, column, identity)); + } + + public void registerTableForColumnMasking(QualifiedObjectName table, String column, String identity) + { + columnMaskScopes.add(new ColumnMaskScopeEntry(table, column, identity)); + } + + public void unregisterTableForColumnMasking(QualifiedObjectName table, String column, String identity) + { + columnMaskScopes.remove(new ColumnMaskScopeEntry(table, column, identity)); + } + + public void addColumnMask(Table table, String column, Expression mask) + { + Map masks = columnMasks.computeIfAbsent(NodeRef.of(table), node -> new LinkedHashMap<>()); + checkArgument(!masks.containsKey(column), "Mask already exists for column %s", column); + masks.put(column, mask); + } + + public Map getColumnMasks(Table table) + { + return columnMasks.getOrDefault(NodeRef.of(table), ImmutableMap.of()); + } + @Immutable public static final class Insert { @@ -1177,4 +1240,71 @@ public boolean isFromView() return isFromView; } } + + private static class RowFilterScopeEntry + { + private final QualifiedObjectName table; + private final String identity; + + public RowFilterScopeEntry(QualifiedObjectName table, String identity) + { + this.table = requireNonNull(table, "table is null"); + this.identity = requireNonNull(identity, "identity is null"); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + RowFilterScopeEntry that = (RowFilterScopeEntry) o; + return table.equals(that.table) && + identity.equals(that.identity); + } + + @Override + public int hashCode() + { + return Objects.hash(table, identity); + } + } + + private static class ColumnMaskScopeEntry + { + private final QualifiedObjectName table; + private final String column; + private final String identity; + + public ColumnMaskScopeEntry(QualifiedObjectName table, String column, String identity) + { + this.table = requireNonNull(table, "table is null"); + this.column = requireNonNull(column, "column is null"); + this.identity = requireNonNull(identity, "identity is null"); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ColumnMaskScopeEntry that = (ColumnMaskScopeEntry) o; + return table.equals(that.table) && + column.equals(that.column) && + identity.equals(that.identity); + } + + @Override + public int hashCode() + { + return Objects.hash(table, column, identity); + } + } } diff --git a/presto-docs/src/main/sphinx/develop/system-access-control.rst b/presto-docs/src/main/sphinx/develop/system-access-control.rst index 767cfa8e6c636..774cb782b966f 100644 --- a/presto-docs/src/main/sphinx/develop/system-access-control.rst +++ b/presto-docs/src/main/sphinx/develop/system-access-control.rst @@ -29,6 +29,16 @@ name which is used by the administrator in a Presto configuration. The implementation of ``SystemAccessControl`` and ``SystemAccessControlFactory`` must be wrapped as a plugin and installed on the Presto cluster. +Row Filters and Column Masks +---------------------------- + +The access control implementation can optionally provide row filters and column masks to +control viewing of specific rows or mask sensitive values in columns. The filters +and masks are retrieved per table from the given ``Identity``, schema, table, and +column names. The returned filters and masks will be in the form of a ``ViewExpression`` +that is then applied to the query plan before execution. Filters and masks can also be +supplied at the connector level from a ``ConnectorAccessControl`` implementation. + Configuration ------------- diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/security/LegacyAccessControl.java b/presto-hive/src/main/java/com/facebook/presto/hive/security/LegacyAccessControl.java index d1b77eb6d685d..59c2fe40a9a3a 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/security/LegacyAccessControl.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/security/LegacyAccessControl.java @@ -19,6 +19,7 @@ import com.facebook.presto.hive.TransactionalMetadata; import com.facebook.presto.hive.metastore.MetastoreContext; import com.facebook.presto.hive.metastore.Table; +import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.connector.ConnectorAccessControl; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; @@ -26,9 +27,13 @@ import com.facebook.presto.spi.security.ConnectorIdentity; import com.facebook.presto.spi.security.PrestoPrincipal; import com.facebook.presto.spi.security.Privilege; +import com.facebook.presto.spi.security.ViewExpression; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import javax.inject.Inject; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; @@ -286,4 +291,16 @@ public void checkCanAddConstraint(ConnectorTransactionHandle transaction, Connec denyAddConstraint(tableName.toString()); } } + + @Override + public List getRowFilters(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) + { + return ImmutableList.of(); + } + + @Override + public Map getColumnMasks(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName, List columns) + { + return ImmutableMap.of(); + } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/security/SqlStandardAccessControl.java b/presto-hive/src/main/java/com/facebook/presto/hive/security/SqlStandardAccessControl.java index 682fa482c1ae6..5f225ba0a094c 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/security/SqlStandardAccessControl.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/security/SqlStandardAccessControl.java @@ -21,6 +21,7 @@ import com.facebook.presto.hive.metastore.Database; import com.facebook.presto.hive.metastore.MetastoreContext; import com.facebook.presto.hive.metastore.SemiTransactionalHiveMetastore; +import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.connector.ConnectorAccessControl; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; @@ -30,9 +31,13 @@ import com.facebook.presto.spi.security.PrestoPrincipal; import com.facebook.presto.spi.security.Privilege; import com.facebook.presto.spi.security.RoleGrant; +import com.facebook.presto.spi.security.ViewExpression; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import javax.inject.Inject; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; @@ -702,6 +707,18 @@ public void checkCanShowRoleGrants(ConnectorTransactionHandle transactionHandle, { } + @Override + public List getRowFilters(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) + { + return ImmutableList.of(); + } + + @Override + public Map getColumnMasks(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName, List columns) + { + return ImmutableMap.of(); + } + private boolean isAdmin(ConnectorTransactionHandle transaction, ConnectorIdentity identity, MetastoreContext metastoreContext) { return getMetastore(transaction) diff --git a/presto-main-base/src/main/java/com/facebook/presto/security/AccessControlManager.java b/presto-main-base/src/main/java/com/facebook/presto/security/AccessControlManager.java index 468d7d0ed2f3b..f618ff1fbabe7 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/security/AccessControlManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/security/AccessControlManager.java @@ -20,6 +20,7 @@ import com.facebook.presto.common.Subfield; import com.facebook.presto.common.transaction.TransactionId; import com.facebook.presto.spi.CatalogSchemaTableName; +import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.SchemaTableName; @@ -33,8 +34,10 @@ import com.facebook.presto.spi.security.Privilege; import com.facebook.presto.spi.security.SystemAccessControl; import com.facebook.presto.spi.security.SystemAccessControlFactory; +import com.facebook.presto.spi.security.ViewExpression; import com.facebook.presto.transaction.TransactionManager; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import org.weakref.jmx.Managed; @@ -55,6 +58,7 @@ import java.util.concurrent.atomic.AtomicReference; import static com.facebook.presto.metadata.MetadataUtil.toSchemaTableName; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_COLUMN_MASK; import static com.facebook.presto.spi.StandardErrorCode.SERVER_STARTING_UP; import static com.facebook.presto.util.PropertiesUtil.loadProperties; import static com.google.common.base.Preconditions.checkArgument; @@ -798,6 +802,54 @@ public void checkCanAddConstraints(TransactionId transactionId, Identity identit } } + @Override + public List getRowFilters(TransactionId transactionId, Identity identity, AccessControlContext context, QualifiedObjectName tableName) + { + requireNonNull(transactionId, "transactionId is null"); + requireNonNull(identity, "identity is null"); + requireNonNull(tableName, "catalogName is null"); + + ImmutableList.Builder filters = ImmutableList.builder(); + CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); + if (entry != null) { + entry.getAccessControl().getRowFilters(entry.getTransactionHandle(transactionId), identity.toConnectorIdentity(tableName.getCatalogName()), context, toSchemaTableName(tableName)) + .forEach(filters::add); + } + + systemAccessControl.get().getRowFilters(identity, context, toCatalogSchemaTableName(tableName)) + .forEach(filters::add); + + return filters.build(); + } + + @Override + public Map getColumnMasks(TransactionId transactionId, Identity identity, AccessControlContext context, QualifiedObjectName tableName, List columns) + { + requireNonNull(transactionId, "transactionId is null"); + requireNonNull(identity, "identity is null"); + requireNonNull(tableName, "catalogName is null"); + requireNonNull(columns, "columns is null"); + + ImmutableMap.Builder columnMasksBuilder = ImmutableMap.builder(); + + // connector-provided masks take precedence over global masks + CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); + if (entry != null) { + Map connectorMasks = entry.getAccessControl().getColumnMasks(entry.getTransactionHandle(transactionId), identity.toConnectorIdentity(tableName.getCatalogName()), context, toSchemaTableName(tableName), columns); + columnMasksBuilder.putAll(connectorMasks); + } + + Map systemMasks = systemAccessControl.get().getColumnMasks(identity, context, toCatalogSchemaTableName(tableName), columns); + columnMasksBuilder.putAll(systemMasks); + + try { + return columnMasksBuilder.buildOrThrow(); + } + catch (IllegalArgumentException exception) { + throw new PrestoException(INVALID_COLUMN_MASK, "Multiple masks for the same column found", exception); + } + } + private CatalogAccessControlEntry getConnectorAccessControl(TransactionId transactionId, String catalogName) { return transactionManager.getOptionalCatalogMetadata(transactionId, catalogName) diff --git a/presto-main-base/src/main/java/com/facebook/presto/security/AllowAllSystemAccessControl.java b/presto-main-base/src/main/java/com/facebook/presto/security/AllowAllSystemAccessControl.java index a0e1ea6c315f2..65ae6cda52e7f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/security/AllowAllSystemAccessControl.java +++ b/presto-main-base/src/main/java/com/facebook/presto/security/AllowAllSystemAccessControl.java @@ -15,6 +15,7 @@ import com.facebook.presto.common.CatalogSchemaName; import com.facebook.presto.spi.CatalogSchemaTableName; +import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.security.AccessControlContext; import com.facebook.presto.spi.security.AuthorizedIdentity; @@ -23,6 +24,9 @@ import com.facebook.presto.spi.security.Privilege; import com.facebook.presto.spi.security.SystemAccessControl; import com.facebook.presto.spi.security.SystemAccessControlFactory; +import com.facebook.presto.spi.security.ViewExpression; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import java.security.Principal; import java.security.cert.X509Certificate; @@ -231,4 +235,16 @@ public void checkCanDropConstraint(Identity identity, AccessControlContext conte public void checkCanAddConstraint(Identity identity, AccessControlContext context, CatalogSchemaTableName table) { } + + @Override + public List getRowFilters(Identity identity, AccessControlContext context, CatalogSchemaTableName tableName) + { + return ImmutableList.of(); + } + + @Override + public Map getColumnMasks(Identity identity, AccessControlContext context, CatalogSchemaTableName tableName, List columns) + { + return ImmutableMap.of(); + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/security/FileBasedSystemAccessControl.java b/presto-main-base/src/main/java/com/facebook/presto/security/FileBasedSystemAccessControl.java index 199490f096f29..3e18f55fc289b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/security/FileBasedSystemAccessControl.java +++ b/presto-main-base/src/main/java/com/facebook/presto/security/FileBasedSystemAccessControl.java @@ -19,6 +19,7 @@ import com.facebook.presto.plugin.base.security.SchemaAccessControlRule; import com.facebook.presto.security.CatalogAccessControlRule.AccessMode; import com.facebook.presto.spi.CatalogSchemaTableName; +import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.security.AccessControlContext; @@ -28,7 +29,9 @@ import com.facebook.presto.spi.security.Privilege; import com.facebook.presto.spi.security.SystemAccessControl; import com.facebook.presto.spi.security.SystemAccessControlFactory; +import com.facebook.presto.spi.security.ViewExpression; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import io.airlift.units.Duration; @@ -450,6 +453,18 @@ public void checkCanAddConstraint(Identity identity, AccessControlContext contex } } + @Override + public List getRowFilters(Identity identity, AccessControlContext context, CatalogSchemaTableName tableName) + { + return ImmutableList.of(); + } + + @Override + public Map getColumnMasks(Identity identity, AccessControlContext context, CatalogSchemaTableName tableName, List columns) + { + return ImmutableMap.of(); + } + private boolean isSchemaOwner(Identity identity, CatalogSchemaName schema) { if (!canAccessCatalog(identity, schema.getCatalogName(), ALL)) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java index d52540a1660f0..7199ae33a3581 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java @@ -55,6 +55,7 @@ import com.facebook.presto.spi.security.AllowAllAccessControl; import com.facebook.presto.spi.security.Identity; import com.facebook.presto.spi.security.ViewAccessControl; +import com.facebook.presto.spi.security.ViewExpression; import com.facebook.presto.sql.ExpressionUtils; import com.facebook.presto.sql.MaterializedViewUtils; import com.facebook.presto.sql.SqlFormatterUtil; @@ -205,8 +206,11 @@ import static com.facebook.presto.metadata.MetadataUtil.createQualifiedObjectName; import static com.facebook.presto.metadata.MetadataUtil.getConnectorIdOrThrow; import static com.facebook.presto.metadata.MetadataUtil.toSchemaTableName; +import static com.facebook.presto.spi.StandardErrorCode.DATATYPE_MISMATCH; import static com.facebook.presto.spi.StandardErrorCode.INVALID_ARGUMENTS; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_COLUMN_MASK; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_ROW_FILTER; import static com.facebook.presto.spi.StandardWarningCode.PERFORMANCE_WARNING; import static com.facebook.presto.spi.StandardWarningCode.REDUNDANT_ORDER_BY; import static com.facebook.presto.spi.analyzer.AccessControlRole.TABLE_CREATE; @@ -412,7 +416,16 @@ protected Scope visitInsert(Insert insert, Optional scope) analysis.addAccessControlCheckForTable(TABLE_INSERT, new AccessControlInfoForTable(accessControl, session.getIdentity(), session.getTransactionId(), session.getAccessControlContext(), targetTable)); + if (!accessControl.getRowFilters(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), targetTable).isEmpty()) { + throw new SemanticException(NOT_SUPPORTED, insert, "Insert into table with row filter is not supported"); + } + List columnsMetadata = tableColumnsMetadata.getColumnsMetadata(); + + if (!accessControl.getColumnMasks(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), targetTable, columnsMetadata).isEmpty()) { + throw new SemanticException(NOT_SUPPORTED, insert, "Insert into table with column masks is not supported"); + } + List tableColumns = columnsMetadata.stream() .filter(column -> !column.isHidden()) .map(ColumnMetadata::getName) @@ -600,6 +613,16 @@ protected Scope visitDelete(Delete node, Optional scope) analysis.addAccessControlCheckForTable(TABLE_DELETE, new AccessControlInfoForTable(accessControl, session.getIdentity(), session.getTransactionId(), session.getAccessControlContext(), tableName)); + if (!accessControl.getRowFilters(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), tableName).isEmpty()) { + throw new SemanticException(NOT_SUPPORTED, node, "Delete from table with row filter is not supported"); + } + + TableColumnMetadata tableColumnsMetadata = getTableColumnsMetadata(session, metadataResolver, analysis.getMetadataHandle(), tableName); + List columnsMetadata = tableColumnsMetadata.getColumnsMetadata(); + if (!accessControl.getColumnMasks(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), tableName, columnsMetadata).isEmpty()) { + throw new SemanticException(NOT_SUPPORTED, node, "Delete from table with column mask is not supported"); + } + return createAndAssignScope(node, scope, Field.newUnqualified(node.getLocation(), "rows", BIGINT)); } @@ -1341,6 +1364,7 @@ protected Scope visitTable(Table table, Optional scope) } TableColumnMetadata tableColumnsMetadata = getTableColumnsMetadata(session, metadataResolver, analysis.getMetadataHandle(), name); + List columnsMetadata = tableColumnsMetadata.getColumnsMetadata(); Optional tableHandle = getTableHandle(tableColumnsMetadata, table, name, scope); Map columnHandles = tableColumnsMetadata.getColumnHandles(); @@ -1348,7 +1372,7 @@ protected Scope visitTable(Table table, Optional scope) // TODO: discover columns lazily based on where they are needed (to support connectors that can't enumerate all tables) ImmutableList.Builder fields = ImmutableList.builder(); - for (ColumnMetadata column : tableColumnsMetadata.getColumnsMetadata()) { + for (ColumnMetadata column : columnsMetadata) { Field field = Field.newQualified( Optional.empty(), table.getName(), @@ -1366,6 +1390,23 @@ protected Scope visitTable(Table table, Optional scope) analysis.registerTable(table, tableHandle.get()); + List outputFields = fields.build(); + + Scope accessControlScope = Scope.builder() + .withRelationType(RelationId.anonymous(), new RelationType(outputFields)) + .build(); + + Map masks = accessControl.getColumnMasks(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), name, columnsMetadata); + + for (Map.Entry maskEntry : masks.entrySet()) { + analyzeColumnMask(session.getIdentity().getUser(), table, name, maskEntry.getKey(), accessControlScope, maskEntry.getValue()); + } + + accessControl.getRowFilters(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), name) + .forEach(filter -> analyzeRowFilter(session.getIdentity().getUser(), table, name, accessControlScope, filter)); + + analysis.registerTable(table, tableHandle.get()); + if (statement instanceof RefreshMaterializedView) { Table view = ((RefreshMaterializedView) statement).getTarget(); if (!table.equals(view) && !analysis.hasTableRegisteredForMaterializedView(view, table)) { @@ -1379,7 +1420,7 @@ protected Scope visitTable(Table table, Optional scope) } } - return createAndAssignScope(table, scope, fields.build()); + return createAndAssignScope(table, scope, outputFields); } private Optional getTableHandle(TableColumnMetadata tableColumnsMetadata, Table table, QualifiedObjectName name, Optional scope) @@ -2874,21 +2915,8 @@ private RelationType analyzeView(Query query, QualifiedObjectName name, Optional viewAccessControl = accessControl; } - Session.SessionBuilder viewSessionBuilder = Session.builder(metadata.getSessionPropertyManager()) - .setQueryId(session.getQueryId()) - .setTransactionId(session.getTransactionId().orElse(null)) - .setIdentity(identity) - .setSource(session.getSource().orElse(null)) - .setCatalog(catalog.orElse(null)) - .setSchema(schema.orElse(null)) - .setTimeZoneKey(session.getTimeZoneKey()) - .setLocale(session.getLocale()) - .setRemoteUserAddress(session.getRemoteUserAddress().orElse(null)) - .setUserAgent(session.getUserAgent().orElse(null)) - .setClientInfo(session.getClientInfo().orElse(null)) - .setStartTime(session.getStartTime()); - session.getConnectorProperties().forEach((connectorId, properties) -> properties.forEach((k, v) -> viewSessionBuilder.setConnectionProperty(connectorId, k, v))); - Session viewSession = viewSessionBuilder.build(); + Session viewSession = createViewSession(catalog, schema, identity); + StatementAnalyzer analyzer = new StatementAnalyzer(analysis, metadata, sqlParser, viewAccessControl, viewSession, warningCollector); Scope queryScope = analyzer.analyze(query, Scope.create()); return queryScope.getRelationType().withAlias(name.getObjectName(), null); @@ -2899,6 +2927,25 @@ private RelationType analyzeView(Query query, QualifiedObjectName name, Optional } } + private Session createViewSession(Optional catalog, Optional schema, Identity identity) + { + Session.SessionBuilder viewSessionBuilder = Session.builder(metadata.getSessionPropertyManager()) + .setQueryId(session.getQueryId()) + .setTransactionId(session.getTransactionId().orElse(null)) + .setIdentity(identity) + .setSource(session.getSource().orElse(null)) + .setCatalog(catalog.orElse(null)) + .setSchema(schema.orElse(null)) + .setTimeZoneKey(session.getTimeZoneKey()) + .setLocale(session.getLocale()) + .setRemoteUserAddress(session.getRemoteUserAddress().orElse(null)) + .setUserAgent(session.getUserAgent().orElse(null)) + .setClientInfo(session.getClientInfo().orElse(null)) + .setStartTime(session.getStartTime()); + session.getConnectorProperties().forEach((connectorId, properties) -> properties.forEach((k, v) -> viewSessionBuilder.setConnectionProperty(connectorId, k, v))); + return viewSessionBuilder.build(); + } + private Query parseView(String view, QualifiedObjectName name, Node node) { try { @@ -2941,6 +2988,112 @@ private ExpressionAnalysis analyzeExpression(Expression expression, Scope scope) warningCollector); } + private void analyzeRowFilter(String currentIdentity, Table table, QualifiedObjectName name, Scope scope, ViewExpression filter) + { + if (analysis.hasRowFilter(name, currentIdentity)) { + throw new PrestoException(INVALID_ROW_FILTER, format("Row filter for '%s' is recursive", name), null); + } + + Expression expression; + try { + expression = sqlParser.createExpression(filter.getExpression(), createParsingOptions(session)); + } + catch (ParsingException e) { + throw new PrestoException(INVALID_ROW_FILTER, format("Invalid row filter for '%s': %s", name, e.getErrorMessage()), e); + } + + analysis.registerTableForRowFiltering(name, currentIdentity); + ExpressionAnalysis expressionAnalysis; + try { + expressionAnalysis = ExpressionAnalyzer.analyzeExpression( + createViewSession(filter.getCatalog(), filter.getSchema(), new Identity(filter.getIdentity(), Optional.empty())), // TODO: path should be included in row filter + metadata, + accessControl, + sqlParser, + scope, + analysis, + expression, + warningCollector); + } + catch (PrestoException e) { + throw new PrestoException(e::getErrorCode, format("Invalid row filter for '%s: %s'", name, e.getMessage()), e); + } + finally { + analysis.unregisterTableForRowFiltering(name, currentIdentity); + } + + verifyNoAggregateWindowOrGroupingFunctions(analysis.getFunctionHandles(), functionAndTypeResolver, expression, format("Row filter for '%s'", name)); + + analysis.recordSubqueries(expression, expressionAnalysis); + + Type actualType = expressionAnalysis.getType(expression); + if (!actualType.equals(BOOLEAN)) { + if (!metadata.getFunctionAndTypeManager().canCoerce(actualType, BOOLEAN)) { + throw new PrestoException(DATATYPE_MISMATCH, format("Expected row filter for '%s' to be of type BOOLEAN, but was %s", name, actualType), null); + } + + analysis.addCoercion(expression, BOOLEAN, false); + } + + analysis.addRowFilter(table, expression); + } + + private void analyzeColumnMask(String currentIdentity, Table table, QualifiedObjectName tableName, ColumnMetadata columnMetadata, Scope scope, ViewExpression mask) + { + String column = columnMetadata.getName(); + if (analysis.hasColumnMask(tableName, column, currentIdentity)) { + throw new PrestoException(INVALID_COLUMN_MASK, format("Column mask for '%s.%s' is recursive", tableName, column), null); + } + + Expression expression; + try { + expression = sqlParser.createExpression(mask.getExpression(), createParsingOptions(session)); + } + catch (ParsingException e) { + throw new PrestoException(INVALID_COLUMN_MASK, format("Invalid column mask for '%s.%s': %s", tableName, column, e.getErrorMessage()), e); + } + + ExpressionAnalysis expressionAnalysis; + analysis.registerTableForColumnMasking(tableName, column, currentIdentity); + try { + expressionAnalysis = ExpressionAnalyzer.analyzeExpression( + createViewSession(mask.getCatalog(), mask.getSchema(), new Identity(mask.getIdentity(), Optional.empty())), // TODO: path should be included in row filter + metadata, + accessControl, + sqlParser, + scope, + analysis, + expression, + warningCollector); + } + catch (PrestoException e) { + throw new PrestoException(e::getErrorCode, format("Invalid column mask for '%s.%s: %s'", tableName, column, e.getMessage()), e); + } + finally { + analysis.unregisterTableForColumnMasking(tableName, column, currentIdentity); + } + + verifyNoAggregateWindowOrGroupingFunctions(analysis.getFunctionHandles(), functionAndTypeResolver, expression, format("Column mask for '%s.%s'", table.getName(), column)); + + analysis.recordSubqueries(expression, expressionAnalysis); + + Type expectedType = columnMetadata.getType(); + Type actualType = expressionAnalysis.getType(expression); + if (!actualType.equals(expectedType)) { + if (!metadata.getFunctionAndTypeManager().canCoerce(actualType, columnMetadata.getType())) { + throw new PrestoException(DATATYPE_MISMATCH, format("Expected column mask for '%s.%s' to be of type %s, but was %s", tableName, column, columnMetadata.getType(), actualType), null); + } + + // TODO: this should be "coercion.isTypeOnlyCoercion(actualType, expectedType)", but type-only coercions are broken + // due to the line "changeType(value, returnType)" in SqlToRowExpressionTranslator.visitCast. If there's an expression + // like CAST(CAST(x AS VARCHAR(1)) AS VARCHAR(2)), it determines that the outer cast is type-only and converts the expression + // to CAST(x AS VARCHAR(2)) by changing the type of the inner cast. + analysis.addCoercion(expression, expectedType, false); + } + + analysis.addColumnMask(table, column, expression); + } + private List descriptorToFields(Scope scope) { ImmutableList.Builder builder = ImmutableList.builder(); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java index e260ea6b98d82..7e22aeb2ad5ff 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java @@ -222,7 +222,68 @@ protected RelationPlan visitTable(Table node, SqlPlannerContext context) PlanNode root = new TableScanNode(getSourceLocation(node.getLocation()), idAllocator.getNextId(), handle, outputVariables, columns.build(), tableConstraints, TupleDomain.all(), TupleDomain.all(), Optional.empty()); - return new RelationPlan(root, scope, outputVariables); + RelationPlan tableScan = new RelationPlan(root, scope, outputVariables); + tableScan = addRowFilters(node, tableScan, context); + tableScan = addColumnMasks(node, tableScan, context); + return tableScan; + } + + private RelationPlan addRowFilters(Table node, RelationPlan plan, SqlPlannerContext context) + { + PlanBuilder planBuilder = initializePlanBuilder(plan); + + for (Expression filter : analysis.getRowFilters(node)) { + planBuilder = subqueryPlanner.handleSubqueries(planBuilder, filter, filter, context); + + planBuilder = planBuilder.withNewRoot(new FilterNode( + getSourceLocation(node.getLocation()), + idAllocator.getNextId(), + planBuilder.getRoot(), + rowExpression(planBuilder.rewrite(filter), context))); + } + + return new RelationPlan(planBuilder.getRoot(), plan.getScope(), plan.getFieldMappings()); + } + + private RelationPlan addColumnMasks(Table table, RelationPlan plan, SqlPlannerContext context) + { + Map columnMasks = analysis.getColumnMasks(table); + + PlanBuilder planBuilder = initializePlanBuilder(plan); + List mappings = plan.getFieldMappings(); + ImmutableList.Builder newMappings = ImmutableList.builder(); + + Assignments.Builder assignments = new Assignments.Builder(); + for (VariableReferenceExpression variableReferenceExpression : planBuilder.getRoot().getOutputVariables()) { + assignments.put(variableReferenceExpression, rowExpression(new SymbolReference(variableReferenceExpression.getName()), context)); + } + + for (int i = 0; i < plan.getDescriptor().getAllFieldCount(); i++) { + Field field = plan.getDescriptor().getFieldByIndex(i); + + VariableReferenceExpression fieldMapping; + RowExpression rowExpression; + if (field.getName().isPresent() && columnMasks.containsKey(field.getName().get())) { + Expression mask = columnMasks.get(field.getName().get()); + planBuilder = subqueryPlanner.handleSubqueries(planBuilder, mask, mask, context); + fieldMapping = newVariable(variableAllocator, field); + rowExpression = rowExpression(planBuilder.rewrite(mask), context); + } + else { + fieldMapping = mappings.get(i); + rowExpression = rowExpression(createSymbolReference(fieldMapping), context); + } + + assignments.put(fieldMapping, rowExpression); + newMappings.add(fieldMapping); + } + + planBuilder = planBuilder.withNewRoot(new ProjectNode( + idAllocator.getNextId(), + planBuilder.getRoot(), + assignments.build())); + + return new RelationPlan(planBuilder.getRoot(), plan.getScope(), newMappings.build()); } @Override diff --git a/presto-main-base/src/main/java/com/facebook/presto/testing/TestingAccessControlManager.java b/presto-main-base/src/main/java/com/facebook/presto/testing/TestingAccessControlManager.java index fa2e681c1878d..9a7be061f767b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/testing/TestingAccessControlManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/testing/TestingAccessControlManager.java @@ -19,16 +19,22 @@ import com.facebook.presto.common.transaction.TransactionId; import com.facebook.presto.security.AccessControlManager; import com.facebook.presto.security.AllowAllSystemAccessControl; +import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.security.AccessControlContext; import com.facebook.presto.spi.security.Identity; +import com.facebook.presto.spi.security.ViewExpression; import com.facebook.presto.transaction.TransactionManager; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import javax.inject.Inject; import java.security.Principal; +import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; @@ -89,6 +95,8 @@ public class TestingAccessControlManager extends AccessControlManager { private final Set denyPrivileges = new HashSet<>(); + private final Map> rowFilters = new HashMap<>(); + private final Map columnMasks = new HashMap<>(); @Inject public TestingAccessControlManager(TransactionManager transactionManager) @@ -115,6 +123,19 @@ public void deny(TestingPrivilege... deniedPrivileges) public void reset() { denyPrivileges.clear(); + rowFilters.clear(); + columnMasks.clear(); + } + + public void rowFilter(QualifiedObjectName table, String identity, ViewExpression filter) + { + rowFilters.computeIfAbsent(new RowFilterKey(identity, table), key -> new ArrayList<>()) + .add(filter); + } + + public void columnMask(QualifiedObjectName table, String column, String identity, ViewExpression mask) + { + columnMasks.put(new ColumnMaskKey(identity, table, column), mask); } @Override @@ -378,6 +399,29 @@ public void checkCanAddConstraints(TransactionId transactionId, Identity identit super.checkCanAddConstraints(transactionId, identity, context, tableName); } + @Override + public List getRowFilters(TransactionId transactionId, Identity identity, AccessControlContext context, QualifiedObjectName tableName) + { + return rowFilters.getOrDefault(new RowFilterKey(identity.getUser(), tableName), ImmutableList.of()); + } + + @Override + public Map getColumnMasks(TransactionId transactionId, Identity identity, AccessControlContext context, QualifiedObjectName tableName, List columns) + { + Map superResult = super.getColumnMasks(transactionId, identity, context, tableName, columns); + ImmutableMap.Builder columnMaskBuilder = ImmutableMap.builder(); + for (ColumnMetadata column : columns) { + ColumnMaskKey columnMaskKey = new ColumnMaskKey(identity.getUser(), tableName, column.getName()); + if (columnMasks.containsKey(columnMaskKey)) { + columnMaskBuilder.put(column, columnMasks.get(columnMaskKey)); + } + else if (superResult.containsKey(column)) { + columnMaskBuilder.put(column, superResult.get(column)); + } + } + return columnMaskBuilder.buildOrThrow(); + } + private boolean shouldDenyPrivilege(String userName, String entityName, TestingPrivilegeType type) { TestingPrivilege testPrivilege = privilege(userName, entityName, type); @@ -450,4 +494,71 @@ public String toString() .toString(); } } + + private static class RowFilterKey + { + private final String identity; + private final QualifiedObjectName table; + + public RowFilterKey(String identity, QualifiedObjectName table) + { + this.identity = requireNonNull(identity, "identity is null"); + this.table = requireNonNull(table, "table is null"); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + RowFilterKey that = (RowFilterKey) o; + return identity.equals(that.identity) && + table.equals(that.table); + } + + @Override + public int hashCode() + { + return Objects.hash(identity, table); + } + } + + private static class ColumnMaskKey + { + private final String identity; + private final QualifiedObjectName table; + private final String column; + + public ColumnMaskKey(String identity, QualifiedObjectName table, String column) + { + this.identity = identity; + this.table = table; + this.column = column; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ColumnMaskKey that = (ColumnMaskKey) o; + return identity.equals(that.identity) && + table.equals(that.table) && + column.equals(that.column); + } + + @Override + public int hashCode() + { + return Objects.hash(identity, table, column); + } + } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/security/TestAccessControlManager.java b/presto-main-base/src/test/java/com/facebook/presto/security/TestAccessControlManager.java index 42853893f4c6f..6a1785eaf41ed 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/security/TestAccessControlManager.java +++ b/presto-main-base/src/test/java/com/facebook/presto/security/TestAccessControlManager.java @@ -24,6 +24,7 @@ import com.facebook.presto.metadata.InMemoryNodeManager; import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.spi.CatalogSchemaTableName; +import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.QueryId; @@ -41,6 +42,7 @@ import com.facebook.presto.spi.security.Privilege; import com.facebook.presto.spi.security.SystemAccessControl; import com.facebook.presto.spi.security.SystemAccessControlFactory; +import com.facebook.presto.spi.security.ViewExpression; import com.facebook.presto.testing.TestingConnectorContext; import com.facebook.presto.tpch.TpchConnectorFactory; import com.facebook.presto.transaction.TransactionManager; @@ -51,12 +53,15 @@ import java.security.Principal; import java.util.Collections; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; +import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.spi.ConnectorId.createInformationSchemaConnectorId; import static com.facebook.presto.spi.ConnectorId.createSystemTablesConnectorId; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_COLUMN_MASK; import static com.facebook.presto.spi.security.AccessDeniedException.denyQueryIntegrityCheck; import static com.facebook.presto.spi.security.AccessDeniedException.denySelectColumns; import static com.facebook.presto.spi.security.AccessDeniedException.denySelectTable; @@ -66,6 +71,7 @@ import static java.util.Objects.requireNonNull; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertThrows; +import static org.testng.Assert.expectThrows; import static org.testng.Assert.fail; public class TestAccessControlManager @@ -258,6 +264,78 @@ public void testDenySystemAccessControl() }); } + @Test + public void testColumnMaskOrdering() + { + CatalogManager catalogManager = new CatalogManager(); + TransactionManager transactionManager = createTestTransactionManager(catalogManager); + AccessControlManager accessControlManager = new AccessControlManager(transactionManager); + + accessControlManager.addSystemAccessControlFactory(new SystemAccessControlFactory() { + @Override + public String getName() + { + return "test"; + } + + @Override + public SystemAccessControl create(Map config) + { + return new SystemAccessControl() { + @Override + public Map getColumnMasks(Identity identity, AccessControlContext context, CatalogSchemaTableName tableName, List columns) + { + ImmutableMap.Builder columnMaskBuilder = ImmutableMap.builder(); + for (ColumnMetadata column : columns) { + columnMaskBuilder.put(column, new ViewExpression("user", Optional.empty(), Optional.empty(), "system mask")); + } + return columnMaskBuilder.buildOrThrow(); + } + + @Override + public void checkCanSetUser(Identity identity, AccessControlContext context, Optional principal, String userName) + { + } + + @Override + public void checkQueryIntegrity(Identity identity, AccessControlContext context, String query) + { + } + + @Override + public void checkCanSetSystemSessionProperty(Identity identity, AccessControlContext context, String propertyName) + { + } + }; + } + }); + accessControlManager.setSystemAccessControl("test", ImmutableMap.of()); + + ConnectorId connectorId = registerBogusConnector(catalogManager, transactionManager, accessControlManager, "catalog"); + accessControlManager.addCatalogAccessControl(connectorId, new ConnectorAccessControl() { + @Override + public Map getColumnMasks(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName, List columns) + { + ImmutableMap.Builder columnMaskBuilder = ImmutableMap.builder(); + for (ColumnMetadata column : columns) { + columnMaskBuilder.put(column, new ViewExpression("user", Optional.empty(), Optional.empty(), "connector mask")); + } + return columnMaskBuilder.buildOrThrow(); + } + }); + + PrestoException exception = expectThrows( + PrestoException.class, + () -> transaction(transactionManager, accessControlManager) + .execute(transactionId -> { + accessControlManager.getColumnMasks(transactionId, new Identity(USER_NAME, Optional.of(PRINCIPAL)), + new AccessControlContext(new QueryId(QUERY_ID), Optional.empty(), Collections.emptySet(), Optional.empty(), WarningCollector.NOOP, new RuntimeStats(), Optional.empty(), Optional.empty(), Optional.empty()), new QualifiedObjectName("catalog", "schema", "table"), + ImmutableList.of(ColumnMetadata.builder().setName("column").setType(BIGINT).build())); + })); + assertEquals(exception.getErrorCode(), INVALID_COLUMN_MASK.toErrorCode()); + assertEquals(exception.getMessage(), "Multiple masks for the same column found"); + } + private static ConnectorId registerBogusConnector(CatalogManager catalogManager, TransactionManager transactionManager, AccessControl accessControl, String catalogName) { ConnectorId connectorId = new ConnectorId(catalogName); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestAccessControlFiltersMasks.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestAccessControlFiltersMasks.java new file mode 100644 index 0000000000000..ac5dfb3f453f6 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestAccessControlFiltersMasks.java @@ -0,0 +1,118 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner; + +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.spi.plan.SemiJoinNode; +import com.facebook.presto.spi.plan.TableScanNode; +import com.facebook.presto.spi.security.ViewExpression; +import com.facebook.presto.sql.planner.assertions.BasePlanTest; +import com.facebook.presto.testing.LocalQueryRunner; +import com.facebook.presto.testing.TestingAccessControlManager; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; + +public class TestAccessControlFiltersMasks + extends BasePlanTest +{ + private static final String CATALOG = "local"; + private static final String USER = "user"; + private static final String RUN_AS_USER = "run-as-user"; + + private LocalQueryRunner runner; + private TestingAccessControlManager accessControl; + + @BeforeClass + public void init() + { + runner = getQueryRunner(); + accessControl = getQueryRunner().getAccessControl(); + } + + @Test + public void testBasicRowFilter() + { + executeExclusively(() -> { + accessControl.rowFilter( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + USER, + new ViewExpression(USER, Optional.empty(), Optional.empty(), "orderkey < 10")); + assertPlan("SELECT * FROM orders", + anyTree( + filter("ORDERKEY < 10", + tableScan("orders", ImmutableMap.of("ORDERKEY", "orderkey"))))); + accessControl.reset(); + }); + } + + @Test + public void testMultipleIdentityFilters() + { + executeExclusively(() -> { + accessControl.rowFilter( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + RUN_AS_USER, + new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "orderkey = 1")); + + accessControl.rowFilter( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + USER, + new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "orderkey IN (SELECT orderkey FROM orders)")); + assertPlan("SELECT count(*) FROM orders", + anyTree( + node(SemiJoinNode.class, + anyTree(filter("O_ORDERKEY = 1", tableScan("orders", ImmutableMap.of("O_ORDERKEY", "orderkey")))), + anyTree(filter("S_ORDERKEY = 1", tableScan("orders", ImmutableMap.of("S_ORDERKEY", "orderkey"))))))); + accessControl.reset(); + }); + } + + @Test + public void testBasicColumnMask() + { + executeExclusively(() -> { + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "custkey", + USER, + new ViewExpression(USER, Optional.empty(), Optional.empty(), "NULL")); + assertPlan("SELECT custkey FROM orders WHERE orderkey = 1", + anyTree( + project(ImmutableMap.of("custkey_0", expression("NULL")), + anyTree(node(TableScanNode.class))))); + accessControl.reset(); + }); + } + + protected void executeExclusively(Runnable executionBlock) + { + runner.getExclusiveLock().lock(); + try { + executionBlock.run(); + } + finally { + runner.getExclusiveLock().unlock(); + } + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/query/QueryAssertions.java b/presto-main-base/src/test/java/com/facebook/presto/sql/query/QueryAssertions.java index 8f765ed6010bf..acf3d1d70482b 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/query/QueryAssertions.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/query/QueryAssertions.java @@ -142,4 +142,15 @@ public void close() { runner.close(); } + + protected void executeExclusively(Runnable executionBlock) + { + runner.getExclusiveLock().lock(); + try { + executionBlock.run(); + } + finally { + runner.getExclusiveLock().unlock(); + } + } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/query/TestColumnMask.java b/presto-main-base/src/test/java/com/facebook/presto/sql/query/TestColumnMask.java new file mode 100644 index 0000000000000..f04214fe9584c --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/query/TestColumnMask.java @@ -0,0 +1,382 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.query; + +import com.facebook.presto.Session; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.spi.security.Identity; +import com.facebook.presto.spi.security.ViewExpression; +import com.facebook.presto.testing.LocalQueryRunner; +import com.facebook.presto.testing.TestingAccessControlManager; +import com.facebook.presto.tpch.TpchConnectorFactory; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.facebook.presto.tpch.TpchMetadata.TINY_SCHEMA_NAME; + +public class TestColumnMask +{ + private static final String CATALOG = "local"; + private static final String USER = "user"; + private static final String RUN_AS_USER = "run-as-user"; + + private QueryAssertions assertions; + private TestingAccessControlManager accessControl; + + @BeforeClass + public void init() + { + Session session = testSessionBuilder() + .setCatalog(CATALOG) + .setSchema(TINY_SCHEMA_NAME) + .setIdentity(new Identity(USER, Optional.empty())).build(); + + LocalQueryRunner runner = new LocalQueryRunner(session); + + runner.createCatalog(CATALOG, new TpchConnectorFactory(1), ImmutableMap.of()); + + assertions = new QueryAssertions(runner); + accessControl = assertions.getQueryRunner().getAccessControl(); + } + + @AfterClass(alwaysRun = true) + public void teardown() + { + assertions.close(); + assertions = null; + } + + @Test + public void testSimpleMask() + { + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "custkey", + USER, + new ViewExpression(USER, Optional.empty(), Optional.empty(), "-custkey")); + assertions.assertQuery("SELECT custkey FROM orders WHERE orderkey = 1", "VALUES BIGINT '-370'"); + }); + + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "custkey", + USER, + new ViewExpression(USER, Optional.empty(), Optional.empty(), "NULL")); + assertions.assertQuery("SELECT custkey FROM orders WHERE orderkey = 1", "VALUES CAST(NULL AS BIGINT)"); + }); + } + + @Test + public void testMultipleMasksOnDifferentColumns() + { + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "custkey", + USER, + new ViewExpression(USER, Optional.empty(), Optional.empty(), "-custkey")); + + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "orderstatus", + USER, + new ViewExpression(USER, Optional.empty(), Optional.empty(), "'X'")); + + assertions.assertQuery("SELECT custkey, orderstatus FROM orders WHERE orderkey = 1", "VALUES (BIGINT '-370', 'X')"); + }); + } + + @Test + public void testCoercibleType() + { + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "clerk", + USER, + new ViewExpression(USER, Optional.empty(), Optional.empty(), "CAST(clerk AS VARCHAR(5))")); + assertions.assertQuery("SELECT clerk FROM orders WHERE orderkey = 1", "VALUES CAST('Clerk' AS VARCHAR(15))"); + }); + } + + @Test + public void testSubquery() + { + // uncorrelated + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "clerk", + USER, + new ViewExpression(USER, Optional.of(CATALOG), Optional.of("tiny"), "(SELECT cast(max(name) AS VARCHAR(15)) FROM nation)")); + assertions.assertQuery("SELECT clerk FROM orders WHERE orderkey = 1", "VALUES CAST('VIETNAM' AS VARCHAR(15))"); + }); + + // correlated + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "clerk", + USER, + new ViewExpression(USER, Optional.of(CATALOG), Optional.of("tiny"), "(SELECT cast(max(name) AS VARCHAR(15)) FROM nation WHERE nationkey = orderkey)")); + assertions.assertQuery("SELECT clerk FROM orders WHERE orderkey = 1", "VALUES CAST('ARGENTINA' AS VARCHAR(15))"); + }); + } + + @Test + public void testTableReferenceInWithClause() + { + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "custkey", + USER, + new ViewExpression(USER, Optional.empty(), Optional.empty(), "-custkey")); + assertions.assertQuery("WITH t AS (SELECT custkey FROM orders WHERE orderkey = 1) SELECT * FROM t", "VALUES BIGINT '-370'"); + }); + } + + @Test + public void testOtherSchema() + { + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "orderkey", + USER, + new ViewExpression(USER, Optional.of(CATALOG), Optional.of("sf1"), "(SELECT count(*) FROM customer)")); // count is 15000 only when evaluating against sf1 + assertions.assertQuery("SELECT max(orderkey) FROM orders", "VALUES BIGINT '150000'"); + }); + } + + @Test + public void testDifferentIdentity() + { + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "orderkey", + RUN_AS_USER, + new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "100")); + + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "orderkey", + USER, + new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "(SELECT sum(orderkey) FROM orders)")); + + assertions.assertQuery("SELECT max(orderkey) FROM orders", "VALUES BIGINT '1500000'"); + }); + } + + @Test + public void testRecursion() + { + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "orderkey", + USER, + new ViewExpression(USER, Optional.of(CATALOG), Optional.of("tiny"), "(SELECT orderkey FROM orders)")); + + assertions.assertFails("SELECT orderkey FROM orders", ".*\\QColumn mask for 'local.tiny.orders.orderkey' is recursive\\E.*"); + }); + + // different reference style to same table + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "orderkey", + USER, + new ViewExpression(USER, Optional.of(CATALOG), Optional.of("tiny"), "(SELECT orderkey FROM local.tiny.orders)")); + + assertions.assertFails("SELECT orderkey FROM orders", ".*\\QColumn mask for 'local.tiny.orders.orderkey' is recursive\\E.*"); + }); + + // mutual recursion + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "orderkey", + RUN_AS_USER, + new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "(SELECT orderkey FROM orders)")); + + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "orderkey", + USER, + new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "(SELECT orderkey FROM orders)")); + + assertions.assertFails("SELECT orderkey FROM orders", ".*\\QColumn mask for 'local.tiny.orders.orderkey' is recursive\\E.*"); + }); + } + + @Test + public void testLimitedScope() + { + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "customer"), + "custkey", + USER, + new ViewExpression(USER, Optional.of(CATALOG), Optional.of("tiny"), "orderkey")); + assertions.assertFails( + "SELECT (SELECT min(custkey) FROM customer WHERE customer.custkey = orders.custkey) FROM orders", + "\\Qline 1:1: Column 'orderkey' cannot be resolved\\E"); + }); + } + + @Test + public void testSqlInjection() + { + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "nation"), + "name", + USER, + new ViewExpression(USER, Optional.of(CATALOG), Optional.of("tiny"), "(SELECT name FROM region WHERE regionkey = 0)")); + assertions.assertQuery( + "WITH region(regionkey, name) AS (VALUES (0, 'ASIA'))" + + "SELECT name FROM nation ORDER BY name LIMIT 1", + "VALUES CAST('AFRICA' AS VARCHAR(25))"); // if sql-injection would work then query would return ASIA + }); + } + + @Test + public void testInvalidMasks() + { + // parse error + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "orderkey", + USER, + new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "$$$")); + + assertions.assertFails("SELECT orderkey FROM orders", "\\QInvalid column mask for 'local.tiny.orders.orderkey': mismatched input '$'. Expecting: \\E"); + }); + + // unknown column + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "orderkey", + USER, + new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "unknown_column")); + + assertions.assertFails("SELECT orderkey FROM orders", "\\Qline 1:1: Column 'unknown_column' cannot be resolved\\E"); + }); + + // invalid type + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "orderkey", + USER, + new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "'foo'")); + + assertions.assertFails("SELECT orderkey FROM orders", "\\QExpected column mask for 'local.tiny.orders.orderkey' to be of type bigint, but was varchar(3)\\E"); + }); + + // aggregation + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "orderkey", + USER, + new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "count(*) > 0")); + + assertions.assertFails("SELECT orderkey FROM orders", "\\Qline 1:10: Column mask for 'orders.orderkey' cannot contain aggregations, window functions or grouping operations: [\"count\"(*)]\\E"); + }); + + // window function + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "orderkey", + USER, + new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "row_number() OVER () > 0")); + + assertions.assertFails("SELECT orderkey FROM orders", "\\Qline 1:22: Column mask for 'orders.orderkey' cannot contain aggregations, window functions or grouping operations: [\"row_number\"() OVER ()]\\E"); + }); + + // grouping function + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "orderkey", + USER, + new ViewExpression(USER, Optional.of(CATALOG), Optional.of("tiny"), "grouping(orderkey) = 0")); + + assertions.assertFails("SELECT orderkey FROM orders", "\\Qline 1:20: Column mask for 'orders.orderkey' cannot contain aggregations, window functions or grouping operations: [GROUPING (orderkey)]\\E"); + }); + } + + @Test + public void testInsertWithColumnMasking() + { + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "clerk", + USER, + new ViewExpression(USER, Optional.of(CATALOG), Optional.of("tiny"), "clerk")); + + assertions.assertFails("INSERT INTO orders SELECT * FROM orders", "Insert into table with column masks is not supported"); + }); + } + + @Test + public void testDeleteWithColumnMasking() + { + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.columnMask( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + "clerk", + USER, + new ViewExpression(USER, Optional.of(CATALOG), Optional.of("tiny"), "clerk")); + + assertions.assertFails("DELETE FROM orders", "\\Qline 1:1: Delete from table with column mask is not supported\\E"); + }); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/query/TestRowFilter.java b/presto-main-base/src/test/java/com/facebook/presto/sql/query/TestRowFilter.java new file mode 100644 index 0000000000000..075db8d3d71a0 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/query/TestRowFilter.java @@ -0,0 +1,334 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.query; + +import com.facebook.presto.Session; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.spi.security.Identity; +import com.facebook.presto.spi.security.ViewExpression; +import com.facebook.presto.testing.LocalQueryRunner; +import com.facebook.presto.testing.TestingAccessControlManager; +import com.facebook.presto.tpch.TpchConnectorFactory; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.facebook.presto.tpch.TpchMetadata.TINY_SCHEMA_NAME; + +public class TestRowFilter +{ + private static final String CATALOG = "local"; + private static final String USER = "user"; + private static final String RUN_AS_USER = "run-as-user"; + + private QueryAssertions assertions; + private TestingAccessControlManager accessControl; + + @BeforeClass + public void init() + { + Session session = testSessionBuilder() + .setCatalog(CATALOG) + .setSchema(TINY_SCHEMA_NAME) + .setIdentity(new Identity(USER, Optional.empty())).build(); + + LocalQueryRunner runner = new LocalQueryRunner(session); + + runner.createCatalog(CATALOG, new TpchConnectorFactory(1), ImmutableMap.of()); + + assertions = new QueryAssertions(runner); + accessControl = assertions.getQueryRunner().getAccessControl(); + } + + @AfterClass(alwaysRun = true) + public void teardown() + { + assertions.close(); + assertions = null; + } + + @Test + public void testSimpleFilter() + { + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.rowFilter( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + USER, + new ViewExpression(USER, Optional.empty(), Optional.empty(), "orderkey < 10")); + assertions.assertQuery("SELECT count(*) FROM orders", "VALUES BIGINT '7'"); + }); + + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.rowFilter( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + USER, + new ViewExpression(USER, Optional.empty(), Optional.empty(), "NULL")); + assertions.assertQuery("SELECT count(*) FROM orders", "VALUES BIGINT '0'"); + }); + } + + @Test + public void testMultipleFilters() + { + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.rowFilter( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + USER, + new ViewExpression(USER, Optional.empty(), Optional.empty(), "orderkey < 10")); + + accessControl.rowFilter( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + USER, + new ViewExpression(USER, Optional.empty(), Optional.empty(), "orderkey > 5")); + + assertions.assertQuery("SELECT count(*) FROM orders", "VALUES BIGINT '2'"); + }); + } + + @Test + public void testCorrelatedSubquery() + { + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.rowFilter( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + USER, + new ViewExpression(USER, Optional.of(CATALOG), Optional.of("tiny"), "EXISTS (SELECT 1 FROM nation WHERE nationkey = orderkey)")); + assertions.assertQuery("SELECT count(*) FROM orders", "VALUES BIGINT '7'"); + }); + } + + @Test + public void testTableReferenceInWithClause() + { + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.rowFilter( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + USER, + new ViewExpression(USER, Optional.empty(), Optional.empty(), "orderkey = 1")); + assertions.assertQuery("WITH t AS (SELECT count(*) FROM orders) SELECT * FROM t", "VALUES BIGINT '1'"); + }); + } + + @Test + public void testOtherSchema() + { + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.rowFilter( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + USER, + new ViewExpression(USER, Optional.of(CATALOG), Optional.of("sf1"), "(SELECT count(*) FROM customer) = 150000")); // Filter is TRUE only if evaluating against sf1.customer + assertions.assertQuery("SELECT count(*) FROM orders", "VALUES BIGINT '15000'"); + }); + } + + @Test + public void testDifferentIdentity() + { + // This does not fail the recursive check because the initial filter is added to the subquery with RUN_AS_USER identity, + // then the second filter is added with USER identity, allowing both filters to produce 1 row in the result. + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.rowFilter( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + RUN_AS_USER, + new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "orderkey = 1")); + + accessControl.rowFilter( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + USER, + new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "orderkey IN (SELECT orderkey FROM orders)")); + + assertions.assertQuery("SELECT count(*) FROM orders", "VALUES BIGINT '1'"); + }); + } + + @Test + public void testRecursion() + { + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.rowFilter( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + USER, + new ViewExpression(USER, Optional.of(CATALOG), Optional.of("tiny"), "orderkey IN (SELECT orderkey FROM orders)")); + + assertions.assertFails("SELECT count(*) FROM orders", ".*\\QRow filter for 'local.tiny.orders' is recursive\\E.*"); + }); + + // different reference style to same table + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.rowFilter( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + USER, + new ViewExpression(USER, Optional.of(CATALOG), Optional.of("tiny"), "orderkey IN (SELECT local.tiny.orderkey FROM orders)")); + assertions.assertFails("SELECT count(*) FROM orders", ".*\\QRow filter for 'local.tiny.orders' is recursive\\E.*"); + }); + + // mutual recursion + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.rowFilter( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + RUN_AS_USER, + new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "orderkey IN (SELECT orderkey FROM orders)")); + + accessControl.rowFilter( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + USER, + new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "orderkey IN (SELECT orderkey FROM orders)")); + + assertions.assertFails("SELECT count(*) FROM orders", ".*\\QRow filter for 'local.tiny.orders' is recursive\\E.*"); + }); + } + + @Test + public void testLimitedScope() + { + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.rowFilter( + new QualifiedObjectName(CATALOG, "tiny", "customer"), + USER, + new ViewExpression(USER, Optional.of(CATALOG), Optional.of("tiny"), "orderkey = 1")); + assertions.assertFails( + "SELECT (SELECT min(name) FROM customer WHERE customer.custkey = orders.custkey) FROM orders", + "\\Qline 1:1: Column 'orderkey' cannot be resolved\\E"); + }); + } + + @Test + public void testSqlInjection() + { + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.rowFilter( + new QualifiedObjectName(CATALOG, "tiny", "nation"), + USER, + new ViewExpression(USER, Optional.of(CATALOG), Optional.of("tiny"), "regionkey IN (SELECT regionkey FROM region WHERE name = 'ASIA')")); + assertions.assertQuery( + "WITH region(regionkey, name) AS (VALUES (0, 'ASIA'), (1, 'ASIA'), (2, 'ASIA'), (3, 'ASIA'), (4, 'ASIA'))" + + "SELECT name FROM nation ORDER BY name LIMIT 1", + "VALUES CAST('CHINA' AS VARCHAR(25))"); // if sql-injection would work then query would return ALGERIA + }); + } + + @Test + public void testInvalidFilter() + { + // parse error + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.rowFilter( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + USER, + new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "$$$")); + + assertions.assertFails("SELECT count(*) FROM orders", "\\QInvalid row filter for 'local.tiny.orders': mismatched input '$'. Expecting: \\E"); + }); + + // unknown column + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.rowFilter( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + USER, + new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "unknown_column")); + + assertions.assertFails("SELECT count(*) FROM orders", "line 1:1: Column 'unknown_column' cannot be resolved"); + }); + + // invalid type + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.rowFilter( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + USER, + new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "1")); + + assertions.assertFails("SELECT count(*) FROM orders", "\\QExpected row filter for 'local.tiny.orders' to be of type BOOLEAN, but was integer\\E"); + }); + + // aggregation + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.rowFilter( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + USER, + new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "count(*) > 0")); + + assertions.assertFails("SELECT count(*) FROM orders", "\\Qline 1:10: Row filter for 'local.tiny.orders' cannot contain aggregations, window functions or grouping operations: [\"count\"(*)]\\E"); + }); + + // window function + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.rowFilter( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + USER, + new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "row_number() OVER () > 0")); + + assertions.assertFails("SELECT count(*) FROM orders", "\\Qline 1:22: Row filter for 'local.tiny.orders' cannot contain aggregations, window functions or grouping operations: [\"row_number\"() OVER ()]\\E"); + }); + + // window function + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.rowFilter( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + USER, + new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "grouping(orderkey) = 0")); + + assertions.assertFails("SELECT count(*) FROM orders", "\\Qline 1:20: Row filter for 'local.tiny.orders' cannot contain aggregations, window functions or grouping operations: [GROUPING (orderkey)]\\E"); + }); + } + + @Test + public void testInsertWithRowFiltering() + { + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.rowFilter( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + USER, + new ViewExpression(USER, Optional.of(CATALOG), Optional.of("tiny"), "orderkey < 10")); + + assertions.assertFails("INSERT INTO orders SELECT * FROM orders", "Insert into table with row filter is not supported"); + }); + } + + @Test + public void testDeleteWithRowFiltering() + { + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.rowFilter( + new QualifiedObjectName(CATALOG, "tiny", "orders"), + USER, + new ViewExpression(USER, Optional.of(CATALOG), Optional.of("tiny"), "orderkey < 10")); + + assertions.assertFails("DELETE FROM orders", "\\Qline 1:1: Delete from table with row filter is not supported\\E"); + }); + } +} diff --git a/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/AllowAllAccessControl.java b/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/AllowAllAccessControl.java index 30e99555fe432..8b9fccdaf5940 100644 --- a/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/AllowAllAccessControl.java +++ b/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/AllowAllAccessControl.java @@ -14,6 +14,7 @@ package com.facebook.presto.plugin.base.security; import com.facebook.presto.common.Subfield; +import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.connector.ConnectorAccessControl; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; @@ -21,7 +22,11 @@ import com.facebook.presto.spi.security.ConnectorIdentity; import com.facebook.presto.spi.security.PrestoPrincipal; import com.facebook.presto.spi.security.Privilege; +import com.facebook.presto.spi.security.ViewExpression; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; @@ -210,4 +215,16 @@ public void checkCanDropConstraint(ConnectorTransactionHandle transactionHandle, public void checkCanAddConstraint(ConnectorTransactionHandle transaction, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) { } + + @Override + public List getRowFilters(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) + { + return ImmutableList.of(); + } + + @Override + public Map getColumnMasks(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName, List columns) + { + return ImmutableMap.of(); + } } diff --git a/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/FileBasedAccessControl.java b/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/FileBasedAccessControl.java index e1ee1b874815c..fc1f66c700c6f 100644 --- a/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/FileBasedAccessControl.java +++ b/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/FileBasedAccessControl.java @@ -15,6 +15,7 @@ import com.facebook.presto.common.Subfield; import com.facebook.presto.plugin.base.security.TableAccessControlRule.TablePrivilege; +import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.connector.ConnectorAccessControl; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; @@ -23,6 +24,9 @@ import com.facebook.presto.spi.security.ConnectorIdentity; import com.facebook.presto.spi.security.PrestoPrincipal; import com.facebook.presto.spi.security.Privilege; +import com.facebook.presto.spi.security.ViewExpression; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import javax.inject.Inject; @@ -336,6 +340,18 @@ public void checkCanAddConstraint(ConnectorTransactionHandle transactionHandle, } } + @Override + public List getRowFilters(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) + { + return ImmutableList.of(); + } + + @Override + public Map getColumnMasks(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName, List columns) + { + return ImmutableMap.of(); + } + private boolean canSetSessionProperty(ConnectorIdentity identity, String property) { for (SessionPropertyAccessControlRule rule : sessionPropertyRules) { diff --git a/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/ForwardingConnectorAccessControl.java b/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/ForwardingConnectorAccessControl.java index dd28ee9183d56..0987eaf8640bd 100644 --- a/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/ForwardingConnectorAccessControl.java +++ b/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/ForwardingConnectorAccessControl.java @@ -14,6 +14,7 @@ package com.facebook.presto.plugin.base.security; import com.facebook.presto.common.Subfield; +import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.connector.ConnectorAccessControl; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; @@ -21,7 +22,9 @@ import com.facebook.presto.spi.security.ConnectorIdentity; import com.facebook.presto.spi.security.PrestoPrincipal; import com.facebook.presto.spi.security.Privilege; +import com.facebook.presto.spi.security.ViewExpression; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; @@ -262,4 +265,16 @@ public void checkCanAddConstraint(ConnectorTransactionHandle transactionHandle, { delegate().checkCanAddConstraint(transactionHandle, identity, context, tableName); } + + @Override + public List getRowFilters(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) + { + return delegate().getRowFilters(transactionHandle, identity, context, tableName); + } + + @Override + public Map getColumnMasks(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName, List columns) + { + return delegate().getColumnMasks(transactionHandle, identity, context, tableName, columns); + } } diff --git a/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/ForwardingSystemAccessControl.java b/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/ForwardingSystemAccessControl.java index 6e9cd6a91c715..6f616ad88bd67 100644 --- a/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/ForwardingSystemAccessControl.java +++ b/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/ForwardingSystemAccessControl.java @@ -15,6 +15,7 @@ import com.facebook.presto.common.CatalogSchemaName; import com.facebook.presto.spi.CatalogSchemaTableName; +import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.security.AccessControlContext; import com.facebook.presto.spi.security.AuthorizedIdentity; @@ -22,10 +23,12 @@ import com.facebook.presto.spi.security.PrestoPrincipal; import com.facebook.presto.spi.security.Privilege; import com.facebook.presto.spi.security.SystemAccessControl; +import com.facebook.presto.spi.security.ViewExpression; import java.security.Principal; import java.security.cert.X509Certificate; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.function.Supplier; @@ -253,4 +256,16 @@ public void checkCanAddConstraint(Identity identity, AccessControlContext contex { delegate().checkCanAddConstraint(identity, context, table); } + + @Override + public List getRowFilters(Identity identity, AccessControlContext context, CatalogSchemaTableName tableName) + { + return delegate().getRowFilters(identity, context, tableName); + } + + @Override + public Map getColumnMasks(Identity identity, AccessControlContext context, CatalogSchemaTableName tableName, List columns) + { + return delegate().getColumnMasks(identity, context, tableName, columns); + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/StandardErrorCode.java b/presto-spi/src/main/java/com/facebook/presto/spi/StandardErrorCode.java index 4de55e4aa38d6..df884221f3af0 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/StandardErrorCode.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/StandardErrorCode.java @@ -75,6 +75,9 @@ public enum StandardErrorCode INVALID_LIMIT_CLAUSE(0x0000_0031, USER_ERROR), COLUMN_NOT_FOUND(0x0000_0032, USER_ERROR), UNKNOWN_TYPE(0x0000_0033, USER_ERROR), + INVALID_ROW_FILTER(0x0000_0034, USER_ERROR), + INVALID_COLUMN_MASK(0x0000_0035, USER_ERROR), + DATATYPE_MISMATCH(0x0000_0036, USER_ERROR), GENERIC_INTERNAL_ERROR(0x0001_0000, INTERNAL_ERROR), TOO_MANY_REQUESTS_FAILED(0x0001_0001, INTERNAL_ERROR, true), diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorAccessControl.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorAccessControl.java index e5621d458c0b0..a4d6ac064b408 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorAccessControl.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorAccessControl.java @@ -14,12 +14,16 @@ package com.facebook.presto.spi.connector; import com.facebook.presto.common.Subfield; +import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.security.AccessControlContext; import com.facebook.presto.spi.security.ConnectorIdentity; import com.facebook.presto.spi.security.PrestoPrincipal; import com.facebook.presto.spi.security.Privilege; +import com.facebook.presto.spi.security.ViewExpression; +import java.util.Collections; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; @@ -405,4 +409,29 @@ default void checkCanAddConstraint(ConnectorTransactionHandle transactionHandle, { denyAddConstraint(tableName.toString()); } + + /** + * Get row filters associated with the given table and identity. + *

+ * Each filter must be a scalar SQL expression of boolean type over the columns in the table. + * + * @return the list of filters, or empty list if not applicable + */ + default List getRowFilters(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) + { + return Collections.emptyList(); + } + + /** + * Bulk method for getting column masks for a subset of columns in a table. + *

+ * Each mask must be a scalar SQL expression of a type coercible to the type of the column being masked. The expression + * must be written in terms of columns in the table. + * + * @return a mapping from columns to masks, or an empty map if not applicable. The keys of the return Map are a subset of {@code columns}. + */ + default Map getColumnMasks(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName, List columns) + { + return Collections.emptyMap(); + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/security/AccessControl.java b/presto-spi/src/main/java/com/facebook/presto/spi/security/AccessControl.java index 131c0af1577d2..8041a8ed82397 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/security/AccessControl.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/security/AccessControl.java @@ -17,10 +17,12 @@ import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.Subfield; import com.facebook.presto.common.transaction.TransactionId; +import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.SchemaTableName; import java.security.Principal; import java.security.cert.X509Certificate; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; @@ -322,4 +324,29 @@ default AuthorizedIdentity selectAuthorizedIdentity(Identity identity, AccessCon * @throws com.facebook.presto.spi.security.AccessDeniedException if not allowed */ void checkCanAddConstraints(TransactionId transactionId, Identity identity, AccessControlContext context, QualifiedObjectName constraintName); + + /** + * Get row filters associated with the given table and identity. + *

+ * Each filter must be a scalar SQL expression of boolean type over the columns in the table. + * + * @return the list of filters, or empty list if not applicable + */ + default List getRowFilters(TransactionId transactionId, Identity identity, AccessControlContext context, QualifiedObjectName tableName) + { + return Collections.emptyList(); + } + + /** + * Bulk method for getting column masks for a subset of columns in a table. + *

+ * Each mask must be a scalar SQL expression of a type coercible to the type of the column being masked. The expression + * must be written in terms of columns in the table. + * + * @return a mapping from columns to masks, or an empty map if not applicable. The keys of the return Map are a subset of {@code columns}. + */ + default Map getColumnMasks(TransactionId transactionId, Identity identity, AccessControlContext context, QualifiedObjectName tableName, List columns) + { + return Collections.emptyMap(); + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/security/SystemAccessControl.java b/presto-spi/src/main/java/com/facebook/presto/spi/security/SystemAccessControl.java index fc5e8cbf6805f..f0601674c7257 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/security/SystemAccessControl.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/security/SystemAccessControl.java @@ -15,12 +15,14 @@ import com.facebook.presto.common.CatalogSchemaName; import com.facebook.presto.spi.CatalogSchemaTableName; +import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.SchemaTableName; import java.security.Principal; import java.security.cert.X509Certificate; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; @@ -380,4 +382,29 @@ default void checkCanAddConstraint(Identity identity, AccessControlContext conte { denyAddConstraint(table.toString()); } + + /** + * Get row filters associated with the given table and identity. + *

+ * Each filter must be a scalar SQL expression of boolean type over the columns in the table. + * + * @return a list of filters, or empty list if not applicable + */ + default List getRowFilters(Identity identity, AccessControlContext context, CatalogSchemaTableName tableName) + { + return Collections.emptyList(); + } + + /** + * Bulk method for getting column masks for a subset of columns in a table. + *

+ * Each mask must be a scalar SQL expression of a type coercible to the type of the column being masked. The expression + * must be written in terms of columns in the table. + * + * @return a mapping from columns to masks, or an empty map if not applicable. The keys of the return Map are a subset of {@code columns}. + */ + default Map getColumnMasks(Identity identity, AccessControlContext context, CatalogSchemaTableName tableName, List columns) + { + return Collections.emptyMap(); + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/security/ViewExpression.java b/presto-spi/src/main/java/com/facebook/presto/spi/security/ViewExpression.java new file mode 100644 index 0000000000000..53a253ebcf138 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/security/ViewExpression.java @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.security; + +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public class ViewExpression +{ + private final String identity; + private final Optional catalog; + private final Optional schema; + private final String expression; + + public ViewExpression(String identity, Optional catalog, Optional schema, String expression) + { + this.identity = requireNonNull(identity, "identity is null"); + this.catalog = requireNonNull(catalog, "catalog is null"); + this.schema = requireNonNull(schema, "schema is null"); + this.expression = requireNonNull(expression, "expression is null"); + + if (!catalog.isPresent() && schema.isPresent()) { + throw new IllegalArgumentException("catalog must be present if schema is present"); + } + } + + public String getIdentity() + { + return identity; + } + + public Optional getCatalog() + { + return catalog; + } + + public Optional getSchema() + { + return schema; + } + + public String getExpression() + { + return expression; + } +}