Skip to content

🍒 8885, 8952 - Optimize IAST Vulnerability Detection... #9241

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,27 @@ public IastRequestContext() {
}

public IastRequestContext(final TaintedObjects taintedObjects) {
this(taintedObjects, false);
}

public IastRequestContext(final TaintedObjects taintedObjects, boolean isGlobal) {
this.vulnerabilityBatch = new VulnerabilityBatch();
this.overheadContext =
new OverheadContext(Config.get().getIastVulnerabilitiesPerRequest(), isGlobal);
this.taintedObjects = taintedObjects;
}

/**
* Use this constructor only when you want to create a new context with a fresh overhead context
* (e.g. for testing purposes).
*
* @param taintedObjects the tainted objects to use
* @param overheadContext the overhead context to use
*/
public IastRequestContext(
final TaintedObjects taintedObjects, final OverheadContext overheadContext) {
this.vulnerabilityBatch = new VulnerabilityBatch();
this.overheadContext = new OverheadContext(Config.get().getIastVulnerabilitiesPerRequest());
this.overheadContext = overheadContext;
this.taintedObjects = taintedObjects;
}

Expand Down Expand Up @@ -188,6 +207,7 @@ public void releaseRequestContext(@Nonnull final IastContext context) {
pool.offer(unwrapped);
iastCtx.setTaintedObjects(TaintedObjects.NoOp.INSTANCE);
}
iastCtx.overheadContext.resetMaps();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ private VulnerabilityBatch getOrCreateVulnerabilityBatch(final AgentSpan span) {
private AgentSpan startNewSpan() {
final AgentSpanContext tagContext =
new TagContext()
.withRequestContextDataIast(new IastRequestContext(TaintedObjects.NoOp.INSTANCE));
.withRequestContextDataIast(new IastRequestContext(TaintedObjects.NoOp.INSTANCE, true));
final AgentSpan span =
tracer()
.startSpan("iast", VULNERABILITY_SPAN_NAME, tagContext)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,60 @@
import static datadog.trace.api.iast.IastDetectionMode.UNLIMITED;

import com.datadog.iast.util.NonBlockingSemaphore;
import datadog.trace.api.iast.VulnerabilityTypes;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicIntegerArray;
import java.util.function.Function;
import javax.annotation.Nullable;
import org.jetbrains.annotations.NotNull;

public class OverheadContext {

/** Maximum number of distinct endpoints to remember in the global cache. */
private static final int GLOBAL_MAP_MAX_SIZE = 4096;

/**
* Global concurrent cache mapping each “method + path” key to its historical vulnerabilityCounts
* map. As soon as size() > GLOBAL_MAP_MAX_SIZE, we clear() the whole map.
*/
static final ConcurrentMap<String, AtomicIntegerArray> globalMap =
new ConcurrentHashMap<String, AtomicIntegerArray>() {

@Override
public AtomicIntegerArray computeIfAbsent(
String key,
@NotNull Function<? super String, ? extends AtomicIntegerArray> mappingFunction) {
if (this.size() >= GLOBAL_MAP_MAX_SIZE) {
super.clear();
}
return super.computeIfAbsent(key, mappingFunction);
}
};

// Snapshot of the globalMap for the current request
private @Nullable final Map<String, int[]> copyMap;
// Map of vulnerabilities per endpoint for the current request, needs to use AtomicIntegerArray
// because it's possible to have concurrent updates in the same request
private @Nullable final Map<String, AtomicIntegerArray> requestMap;

private final NonBlockingSemaphore availableVulnerabilities;
private final boolean isGlobal;

public OverheadContext(final int vulnerabilitiesPerRequest) {
this(vulnerabilitiesPerRequest, false);
}

public OverheadContext(final int vulnerabilitiesPerRequest, final boolean isGlobal) {
availableVulnerabilities =
vulnerabilitiesPerRequest == UNLIMITED
? NonBlockingSemaphore.unlimited()
: NonBlockingSemaphore.withPermitCount(vulnerabilitiesPerRequest);
this.isGlobal = isGlobal;
this.requestMap = isGlobal ? null : new ConcurrentHashMap<>();
this.copyMap = isGlobal ? null : new ConcurrentHashMap<>();
}

public int getAvailableQuota() {
Expand All @@ -26,4 +70,52 @@ public boolean consumeQuota(final int delta) {
public void reset() {
availableVulnerabilities.reset();
}

public void resetMaps() {
// If this is a global context, we do not reset the maps
if (isGlobal || requestMap == null || copyMap == null) {
return;
}
Set<String> endpoints = requestMap.keySet();
// If the budget is not consumed, we can reset the maps
if (getAvailableQuota() > 0) {
// clean endpoints from globalMap
endpoints.forEach(globalMap::remove);
return;
}
// If the budget is consumed, we need to merge the requestMap into the globalMap
endpoints.forEach(
endpoint -> {
AtomicIntegerArray countMap = requestMap.get(endpoint);
// should not happen, but just in case
if (countMap == null) {
globalMap.remove(endpoint);
return;
}
// Iterate over the vulnerabilities and update the globalMap
int numberOfVulnerabilities = VulnerabilityTypes.STRINGS.length;
for (int i = 0; i < numberOfVulnerabilities; i++) {
int counter = countMap.get(i);
if (counter > 0) {
AtomicIntegerArray globalCountMap =
globalMap.computeIfAbsent(
endpoint, value -> new AtomicIntegerArray(numberOfVulnerabilities));

globalCountMap.accumulateAndGet(i, counter, Math::max);
}
}
});
}

public boolean isGlobal() {
return isGlobal;
}

public @Nullable Map<String, int[]> getCopyMap() {
return copyMap;
}

public @Nullable Map<String, AtomicIntegerArray> getRequestMap() {
return requestMap;
}
}
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
package com.datadog.iast.overhead;

import static com.datadog.iast.overhead.OverheadContext.globalMap;
import static datadog.trace.api.iast.IastDetectionMode.UNLIMITED;

import com.datadog.iast.IastRequestContext;
import com.datadog.iast.IastSystem;
import com.datadog.iast.model.VulnerabilityType;
import com.datadog.iast.util.NonBlockingSemaphore;
import datadog.trace.api.Config;
import datadog.trace.api.gateway.RequestContext;
import datadog.trace.api.gateway.RequestContextSlot;
import datadog.trace.api.iast.IastContext;
import datadog.trace.api.iast.VulnerabilityTypes;
import datadog.trace.api.telemetry.LogCollector;
import datadog.trace.bootstrap.instrumentation.api.AgentSpan;
import datadog.trace.bootstrap.instrumentation.api.AgentTracer;
import datadog.trace.bootstrap.instrumentation.api.Tags;
import datadog.trace.util.AgentTaskScheduler;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicIntegerArray;
import java.util.concurrent.atomic.AtomicLong;
import javax.annotation.Nullable;
import org.slf4j.Logger;
Expand All @@ -27,9 +32,12 @@ public interface OverheadController {

int releaseRequest();

boolean hasQuota(final Operation operation, @Nullable final AgentSpan span);
boolean hasQuota(Operation operation, @Nullable AgentSpan span);

boolean consumeQuota(final Operation operation, @Nullable final AgentSpan span);
boolean consumeQuota(Operation operation, @Nullable AgentSpan span);

boolean consumeQuota(
Operation operation, @Nullable AgentSpan span, @Nullable VulnerabilityType type);

static OverheadController build(final Config config, final AgentTaskScheduler scheduler) {
return build(
Expand Down Expand Up @@ -100,14 +108,23 @@ public boolean hasQuota(final Operation operation, @Nullable final AgentSpan spa

@Override
public boolean consumeQuota(final Operation operation, @Nullable final AgentSpan span) {
final boolean result = delegate.consumeQuota(operation, span);
return consumeQuota(operation, span, null);
}

@Override
public boolean consumeQuota(
final Operation operation,
@Nullable final AgentSpan span,
@Nullable final VulnerabilityType type) {
final boolean result = delegate.consumeQuota(operation, span, type);
if (LOGGER.isDebugEnabled()) {
LOGGER.debug(
"consumeQuota: operation={}, result={}, availableQuota={}, span={}",
"consumeQuota: operation={}, result={}, availableQuota={}, span={}, type={}",
operation,
result,
getAvailableQuote(span),
span);
span,
type);
}
return result;
}
Expand Down Expand Up @@ -147,7 +164,7 @@ class OverheadControllerImpl implements OverheadController {
private volatile long lastAcquiredTimestamp = Long.MAX_VALUE;

final OverheadContext globalContext =
new OverheadContext(Config.get().getIastVulnerabilitiesPerRequest());
new OverheadContext(Config.get().getIastVulnerabilitiesPerRequest(), true);

public OverheadControllerImpl(
final float requestSampling,
Expand Down Expand Up @@ -192,7 +209,96 @@ public boolean hasQuota(final Operation operation, @Nullable final AgentSpan spa

@Override
public boolean consumeQuota(final Operation operation, @Nullable final AgentSpan span) {
return operation.consumeQuota(getContext(span));
return consumeQuota(operation, span, null);
}

@Override
public boolean consumeQuota(
final Operation operation,
@Nullable final AgentSpan span,
@Nullable final VulnerabilityType type) {

OverheadContext ctx = getContext(span);
if (ctx == null) {
return false;
}
if (ctx.isGlobal()) {
return operation.consumeQuota(ctx);
}
if (operation.hasQuota(ctx)) {
String method = null;
String path = null;
if (span != null) {
AgentSpan rootSpan = span.getLocalRootSpan();
Object methodTag = rootSpan.getTag(Tags.HTTP_METHOD);
method = (methodTag == null) ? "" : methodTag.toString();
Object routeTag = rootSpan.getTag(Tags.HTTP_ROUTE);
path = (routeTag == null) ? "" : routeTag.toString();
}
if (!maybeSkipVulnerability(ctx, type, method, path)) {
return operation.consumeQuota(ctx);
}
}
return false;
}

/**
* Method to be called when a vulnerability of a certain type is detected. Implements the
* RFC-1029 algorithm.
*
* @param ctx the overhead context for the current request
* @param type the type of vulnerability detected
* @param httpMethod the HTTP method of the request (e.g., GET, POST)
* @param httpPath the HTTP path of the request
* @return true if the vulnerability should be skipped, false otherwise
*/
private boolean maybeSkipVulnerability(
@Nullable final OverheadContext ctx,
@Nullable final VulnerabilityType type,
@Nullable final String httpMethod,
@Nullable final String httpPath) {

if (ctx == null || type == null || ctx.getRequestMap() == null || ctx.getCopyMap() == null) {
return false;
}

int numberOfVulnerabilities = VulnerabilityTypes.STRINGS.length;

String currentEndpoint = httpMethod + " " + httpPath;

AtomicIntegerArray requestArray = ctx.getRequestMap().get(currentEndpoint);
int[] copyArray;

if (requestArray == null) {
AtomicIntegerArray globalArray =
globalMap.computeIfAbsent(
currentEndpoint, k -> new AtomicIntegerArray(numberOfVulnerabilities));
copyArray = toIntArray(globalArray);
ctx.getCopyMap().put(currentEndpoint, copyArray);
requestArray =
ctx.getRequestMap()
.computeIfAbsent(
currentEndpoint, k -> new AtomicIntegerArray(numberOfVulnerabilities));
} else {
copyArray = ctx.getCopyMap().get(currentEndpoint);
}

int counter = requestArray.getAndIncrement(type.type());
int storedCounter = 0;
if (copyArray != null) {
storedCounter = copyArray[type.type()];
}

return counter < storedCounter;
}

private static int[] toIntArray(AtomicIntegerArray atomic) {
int length = atomic.length();
int[] result = new int[length];
for (int i = 0; i < length; i++) {
result[i] = atomic.get(i);
}
return result;
}

@Nullable
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.datadog.iast.sink;

import static com.datadog.iast.model.VulnerabilityType.INSECURE_COOKIE;
import static com.datadog.iast.util.HttpHeader.SET_COOKIE;
import static com.datadog.iast.util.HttpHeader.SET_COOKIE2;
import static java.util.Collections.singletonList;
Expand Down Expand Up @@ -65,7 +66,9 @@ private void onCookies(final List<Cookie> cookies) {
return;
}
final AgentSpan span = AgentTracer.activeSpan();
if (!overheadController.consumeQuota(Operations.REPORT_VULNERABILITY, span)) {
if (!overheadController.consumeQuota(
Operations.REPORT_VULNERABILITY, span, INSECURE_COOKIE // we need a type to check quota
)) {
return;
}
final Location location = Location.forSpanAndStack(span, getCurrentStackTrace());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ protected void report(final Vulnerability vulnerability) {
}

protected void report(@Nullable final AgentSpan span, final Vulnerability vulnerability) {
if (!overheadController.consumeQuota(Operations.REPORT_VULNERABILITY, span)) {
if (!overheadController.consumeQuota(
Operations.REPORT_VULNERABILITY, span, vulnerability.getType())) {
return;
}
reporter.report(span, vulnerability);
Expand All @@ -70,7 +71,7 @@ protected void report(final VulnerabilityType type, final Evidence evidence) {

protected void report(
@Nullable final AgentSpan span, final VulnerabilityType type, final Evidence evidence) {
if (!overheadController.consumeQuota(Operations.REPORT_VULNERABILITY, span)) {
if (!overheadController.consumeQuota(Operations.REPORT_VULNERABILITY, span, type)) {
return;
}
final Vulnerability vulnerability =
Expand Down Expand Up @@ -170,7 +171,7 @@ protected final Evidence checkInjection(
}

final AgentSpan span = AgentTracer.activeSpan();
if (!overheadController.consumeQuota(Operations.REPORT_VULNERABILITY, span)) {
if (!overheadController.consumeQuota(Operations.REPORT_VULNERABILITY, span, type)) {
return null;
}

Expand Down Expand Up @@ -251,7 +252,7 @@ protected final Evidence checkInjection(
if (!spanFetched && valueRanges != null && valueRanges.length > 0) {
span = AgentTracer.activeSpan();
spanFetched = true;
if (!overheadController.consumeQuota(Operations.REPORT_VULNERABILITY, span)) {
if (!overheadController.consumeQuota(Operations.REPORT_VULNERABILITY, span, type)) {
return null;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ class IastModuleImplTestBase extends DDSpecification {
return Stub(OverheadController) {
acquireRequest() >> true
consumeQuota(_ as Operation, _) >> true
consumeQuota(_ as Operation, _, _) >> true
}
}
}
Loading
Loading