diff --git a/dd-trace-api/src/main/java/datadog/trace/api/config/GeneralConfig.java b/dd-trace-api/src/main/java/datadog/trace/api/config/GeneralConfig.java index 9a5caf5c4a2..a882a097e59 100644 --- a/dd-trace-api/src/main/java/datadog/trace/api/config/GeneralConfig.java +++ b/dd-trace-api/src/main/java/datadog/trace/api/config/GeneralConfig.java @@ -94,6 +94,7 @@ public final class GeneralConfig { public static final String AGENTLESS_LOG_SUBMISSION_LEVEL = "agentless.log.submission.level"; public static final String AGENTLESS_LOG_SUBMISSION_URL = "agentless.log.submission.url"; public static final String APM_TRACING_ENABLED = "apm.tracing.enabled"; + public static final String JDK_SOCKET_ENABLED = "jdk.socket.enabled"; private GeneralConfig() {} } diff --git a/gradle/java_no_deps.gradle b/gradle/java_no_deps.gradle index 4a05431c074..95a87f0e8ed 100644 --- a/gradle/java_no_deps.gradle +++ b/gradle/java_no_deps.gradle @@ -59,9 +59,15 @@ if (project.hasProperty('minJavaVersionForTests') && project.getProperty('minJav targetCompatibility = version } + // "socket-utils" is only set to compileOnly because the implementation dependency incorrectly adds Java17 classes to all jar prefixes. + // This causes the AgentJarIndex to search for other non-Java17 classes in the wrong prefix location and fail to resolve class names. dependencies { - compileOnly files(project.sourceSets."main_$name".compileClasspath) - implementation files(project.sourceSets."main_$name".output) + if ("${project.projectDir}".endsWith("socket-utils")) { + compileOnly files(project.sourceSets."main_$name".output) + } else { + compileOnly files(project.sourceSets."main_$name".compileClasspath) + implementation files(project.sourceSets."main_$name".output) + } } jar { diff --git a/internal-api/src/main/java/datadog/trace/api/Config.java b/internal-api/src/main/java/datadog/trace/api/Config.java index 004a609898e..6554f6517d9 100644 --- a/internal-api/src/main/java/datadog/trace/api/Config.java +++ b/internal-api/src/main/java/datadog/trace/api/Config.java @@ -557,6 +557,8 @@ public static String getHostName() { private final boolean apmTracingEnabled; + private final boolean jdkSocketEnabled; + // Read order: System Properties -> Env Variables, [-> properties file], [-> default value] private Config() { this(ConfigProvider.createDefault()); @@ -1924,6 +1926,8 @@ PROFILING_DATADOG_PROFILER_ENABLED, isDatadogProfilerSafeInCurrentEnvironment()) this.apmTracingEnabled = configProvider.getBoolean(GeneralConfig.APM_TRACING_ENABLED, true); + this.jdkSocketEnabled = configProvider.getBoolean(JDK_SOCKET_ENABLED, true); + log.debug("New instance: {}", this); } @@ -3466,6 +3470,10 @@ public boolean isApmTracingEnabled() { return apmTracingEnabled; } + public boolean isJdkSocketEnabled() { + return jdkSocketEnabled; + } + /** @return A map of tags to be applied only to the local application root span. */ public Map getLocalRootSpanTags() { final Map runtimeTags = getRuntimeTags(); @@ -4703,6 +4711,8 @@ public String toString() { + dataJobsCommandPattern + ", apmTracingEnabled=" + apmTracingEnabled + + ", jdkSocketEnabled=" + + jdkSocketEnabled + ", cloudRequestPayloadTagging=" + cloudRequestPayloadTagging + ", cloudResponsePayloadTagging=" diff --git a/internal-api/src/test/groovy/datadog/trace/api/ConfigTest.groovy b/internal-api/src/test/groovy/datadog/trace/api/ConfigTest.groovy index b3e2a3d9b09..a553488676b 100644 --- a/internal-api/src/test/groovy/datadog/trace/api/ConfigTest.groovy +++ b/internal-api/src/test/groovy/datadog/trace/api/ConfigTest.groovy @@ -50,6 +50,7 @@ import static datadog.trace.api.config.GeneralConfig.GLOBAL_TAGS import static datadog.trace.api.config.GeneralConfig.HEALTH_METRICS_ENABLED import static datadog.trace.api.config.GeneralConfig.HEALTH_METRICS_STATSD_HOST import static datadog.trace.api.config.GeneralConfig.HEALTH_METRICS_STATSD_PORT +import static datadog.trace.api.config.GeneralConfig.JDK_SOCKET_ENABLED import static datadog.trace.api.config.GeneralConfig.PERF_METRICS_ENABLED import static datadog.trace.api.config.GeneralConfig.SERVICE_NAME import static datadog.trace.api.config.GeneralConfig.SITE @@ -257,6 +258,7 @@ class ConfigTest extends DDSpecification { prop.setProperty(DYNAMIC_INSTRUMENTATION_EXCLUDE_FILES, "exclude file") prop.setProperty(EXCEPTION_REPLAY_ENABLED, "true") prop.setProperty(TRACE_X_DATADOG_TAGS_MAX_LENGTH, "128") + prop.setProperty(JDK_SOCKET_ENABLED, "false") when: Config config = Config.get(prop) @@ -348,6 +350,7 @@ class ConfigTest extends DDSpecification { config.dynamicInstrumentationInstrumentTheWorld == true config.dynamicInstrumentationExcludeFiles == "exclude file" config.debuggerExceptionEnabled == true + config.jdkSocketEnabled == false config.xDatadogTagsMaxLength == 128 } diff --git a/utils/socket-utils/build.gradle b/utils/socket-utils/build.gradle index eb09cca849f..d2e05364414 100644 --- a/utils/socket-utils/build.gradle +++ b/utils/socket-utils/build.gradle @@ -1,8 +1,31 @@ +ext { + minJavaVersionForTests = JavaVersion.VERSION_17 +} + apply from: "$rootDir/gradle/java.gradle" +apply plugin: "idea" + +[compileMain_java17Java, compileTestJava].each { + it.configure { + setJavaVersion(it, 17) + sourceCompatibility = JavaVersion.VERSION_1_8 + targetCompatibility = JavaVersion.VERSION_1_8 + } +} dependencies { implementation libs.slf4j implementation project(':internal-api') + implementation libs.jnr.unixsocket + testImplementation files(sourceSets.main_java17.output) +} + +forbiddenApisMain_java17 { + failOnMissingClasses = false +} - implementation group: 'com.github.jnr', name: 'jnr-unixsocket', version: libs.versions.jnr.unixsocket.get() +idea { + module { + jdkName = '17' + } } diff --git a/utils/socket-utils/src/main/java/datadog/common/socket/UnixDomainSocketFactory.java b/utils/socket-utils/src/main/java/datadog/common/socket/UnixDomainSocketFactory.java index bb1929f369b..8e141b939f2 100644 --- a/utils/socket-utils/src/main/java/datadog/common/socket/UnixDomainSocketFactory.java +++ b/utils/socket-utils/src/main/java/datadog/common/socket/UnixDomainSocketFactory.java @@ -3,6 +3,7 @@ import static java.util.concurrent.TimeUnit.MINUTES; import datadog.trace.api.Config; +import datadog.trace.api.Platform; import datadog.trace.relocate.api.RatelimitedLogger; import java.io.File; import java.io.IOException; @@ -24,6 +25,8 @@ public final class UnixDomainSocketFactory extends SocketFactory { private static final Logger log = LoggerFactory.getLogger(UnixDomainSocketFactory.class); + private static final boolean JDK_SUPPORTS_UDS = Platform.isJavaVersionAtLeast(16); + private final RatelimitedLogger rlLog = new RatelimitedLogger(log, 5, MINUTES); private final File path; @@ -35,8 +38,14 @@ public UnixDomainSocketFactory(final File path) { @Override public Socket createSocket() throws IOException { try { - final UnixSocketChannel channel = UnixSocketChannel.open(); - return new TunnelingUnixSocket(path, channel); + if (JDK_SUPPORTS_UDS && Config.get().isJdkSocketEnabled()) { + try { + return new TunnelingJdkSocket(path.toPath()); + } catch (Throwable ignore) { + // fall back to jnr-unixsocket library + } + } + return new TunnelingUnixSocket(path, UnixSocketChannel.open()); } catch (Throwable e) { if (Config.get().isAgentConfiguredUsingDefault()) { // fall back to port if we previously auto-discovered this socket file diff --git a/utils/socket-utils/src/main/java17/datadog/common/socket/TunnelingJdkSocket.java b/utils/socket-utils/src/main/java17/datadog/common/socket/TunnelingJdkSocket.java new file mode 100644 index 00000000000..6db94aa15ac --- /dev/null +++ b/utils/socket-utils/src/main/java17/datadog/common/socket/TunnelingJdkSocket.java @@ -0,0 +1,255 @@ +package datadog.common.socket; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.net.SocketAddress; +import java.net.SocketException; +import java.net.UnixDomainSocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.nio.file.Path; +import java.util.Iterator; +import java.util.Set; + +/** + * Subtype UNIX socket for a higher-fidelity impersonation of TCP sockets. This is named "tunneling" + * because it assumes the ultimate destination has a hostname and port. + * + *

Bsed on {@link TunnelingUnixSocket}; adapted to use the built-in UDS support added in Java 16. + */ +final class TunnelingJdkSocket extends Socket { + private final SocketAddress unixSocketAddress; + private InetSocketAddress inetSocketAddress; + + private SocketChannel unixSocketChannel; + + private int timeout; + private boolean shutIn; + private boolean shutOut; + private boolean closed; + + TunnelingJdkSocket(final Path path) { + this.unixSocketAddress = UnixDomainSocketAddress.of(path); + } + + TunnelingJdkSocket(final Path path, final InetSocketAddress address) { + this(path); + inetSocketAddress = address; + } + + @Override + public boolean isConnected() { + return null != unixSocketChannel; + } + + @Override + public boolean isInputShutdown() { + return shutIn; + } + + @Override + public boolean isOutputShutdown() { + return shutOut; + } + + @Override + public boolean isClosed() { + return closed; + } + + @Override + public synchronized void setSoTimeout(int timeout) throws SocketException { + if (isClosed()) { + throw new SocketException("Socket is closed"); + } + if (timeout < 0) { + throw new IllegalArgumentException("Socket timeout can't be negative"); + } + this.timeout = timeout; + } + + @Override + public synchronized int getSoTimeout() throws SocketException { + if (isClosed()) { + throw new SocketException("Socket is closed"); + } + return timeout; + } + + @Override + public void connect(final SocketAddress endpoint) throws IOException { + if (isClosed()) { + throw new SocketException("Socket is closed"); + } + if (isConnected()) { + throw new SocketException("Socket is already connected"); + } + inetSocketAddress = (InetSocketAddress) endpoint; + unixSocketChannel = SocketChannel.open(unixSocketAddress); + } + + // `timeout` is intentionally ignored here, like in the jnr-unixsocket implementation. + // See: + // https://github.com/jnr/jnr-unixsocket/blob/master/src/main/java/jnr/unixsocket/UnixSocket.java#L89-L97 + @Override + public void connect(final SocketAddress endpoint, final int timeout) throws IOException { + if (isClosed()) { + throw new SocketException("Socket is closed"); + } + if (isConnected()) { + throw new SocketException("Socket is already connected"); + } + inetSocketAddress = (InetSocketAddress) endpoint; + unixSocketChannel = SocketChannel.open(unixSocketAddress); + } + + @Override + public SocketChannel getChannel() { + return unixSocketChannel; + } + + @Override + public InputStream getInputStream() throws IOException { + if (isClosed()) { + throw new SocketException("Socket is closed"); + } + if (!isConnected()) { + throw new SocketException("Socket is not connected"); + } + if (isInputShutdown()) { + throw new SocketException("Socket input is shutdown"); + } + + return new InputStream() { + private final ByteBuffer buffer = ByteBuffer.allocate(8192); + private final Selector selector = Selector.open(); + + { + unixSocketChannel.configureBlocking(false); + unixSocketChannel.register(selector, SelectionKey.OP_READ); + } + + @Override + public int read() throws IOException { + byte[] nextByte = new byte[1]; + return (read(nextByte, 0, 1) == -1) ? -1 : (nextByte[0] & 0xFF); + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + buffer.clear(); + + int readyChannels = selector.select(timeout); + if (readyChannels == 0) { + System.out.println("Timeout (" + timeout + "ms) while waiting for data."); + return 0; + } + + Set selectedKeys = selector.selectedKeys(); + Iterator keyIterator = selectedKeys.iterator(); + while (keyIterator.hasNext()) { + SelectionKey key = keyIterator.next(); + keyIterator.remove(); + if (key.isReadable()) { + int r = unixSocketChannel.read(buffer); + if (r == -1) { + return -1; + } + buffer.flip(); + len = Math.min(r, len); + buffer.get(b, off, len); + return len; + } + } + return 0; + } + + @Override + public void close() throws IOException { + selector.close(); + } + }; + } + + @Override + public OutputStream getOutputStream() throws IOException { + if (isClosed()) { + throw new SocketException("Socket is closed"); + } + if (!isConnected()) { + throw new SocketException("Socket is not connected"); + } + if (isInputShutdown()) { + throw new SocketException("Socket output is shutdown"); + } + + return new OutputStream() { + @Override + public void write(int b) throws IOException { + byte[] array = ByteBuffer.allocate(4).putInt(b).array(); + write(array, 0, 4); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + ByteBuffer buffer = ByteBuffer.wrap(b, off, len); + + while (buffer.hasRemaining()) { + unixSocketChannel.write(buffer); + } + } + }; + } + + @Override + public void shutdownInput() throws IOException { + if (isClosed()) { + throw new SocketException("Socket is closed"); + } + if (!isConnected()) { + throw new SocketException("Socket is not connected"); + } + if (isInputShutdown()) { + throw new SocketException("Socket input is already shutdown"); + } + unixSocketChannel.shutdownInput(); + shutIn = true; + } + + @Override + public void shutdownOutput() throws IOException { + if (isClosed()) { + throw new SocketException("Socket is closed"); + } + if (!isConnected()) { + throw new SocketException("Socket is not connected"); + } + if (isOutputShutdown()) { + throw new SocketException("Socket output is already shutdown"); + } + unixSocketChannel.shutdownOutput(); + shutOut = true; + } + + @Override + public InetAddress getInetAddress() { + return inetSocketAddress.getAddress(); + } + + @Override + public void close() throws IOException { + if (isClosed()) { + return; + } + if (null != unixSocketChannel) { + unixSocketChannel.close(); + } + closed = true; + } +} diff --git a/utils/socket-utils/src/test/java/datadog/common/socket/TunnelingJdkSocketTest.java b/utils/socket-utils/src/test/java/datadog/common/socket/TunnelingJdkSocketTest.java new file mode 100644 index 00000000000..05cf96e94d8 --- /dev/null +++ b/utils/socket-utils/src/test/java/datadog/common/socket/TunnelingJdkSocketTest.java @@ -0,0 +1,105 @@ +package datadog.common.socket; + +import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; +import static org.junit.jupiter.api.Assertions.fail; + +import datadog.trace.api.Config; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.StandardProtocolFamily; +import java.net.UnixDomainSocketAddress; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.nio.file.Files; +import java.nio.file.Path; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicBoolean; +import org.junit.jupiter.api.Test; + +public class TunnelingJdkSocketTest { + + private static final AtomicBoolean isServerRunning = new AtomicBoolean(false); + + @Test + public void testTimeout() throws Exception { + if (!Config.get().isJdkSocketEnabled()) { + System.out.println( + "TunnelingJdkSocket usage is disabled. Enable it by setting the property 'JDK_SOCKET_ENABLED' to 'true'."); + return; + } + + int testTimeout = 3000; + Path socketPath = getSocketPath(); + UnixDomainSocketAddress socketAddress = UnixDomainSocketAddress.of(socketPath); + startServer(socketAddress); + TunnelingJdkSocket clientSocket = createClient(socketPath); + + // Test that the socket unblocks when timeout is set to >0 + clientSocket.setSoTimeout(1000); + assertTimeoutPreemptively( + Duration.ofMillis(testTimeout), () -> clientSocket.getInputStream().read()); + + // Test that the socket blocks indefinitely when timeout is set to 0, per + // https://docs.oracle.com/en/java/javase/16/docs/api//java.base/java/net/Socket.html#setSoTimeout(int). + clientSocket.setSoTimeout(0); + boolean infiniteTimeOut = false; + try { + assertTimeoutPreemptively( + Duration.ofMillis(testTimeout), () -> clientSocket.getInputStream().read()); + } catch (AssertionError e) { + infiniteTimeOut = true; + } + if (!infiniteTimeOut) { + fail("Test failed: Expected infinite blocking when timeout is set to 0."); + } + + clientSocket.close(); + isServerRunning.set(false); + } + + private Path getSocketPath() throws IOException { + Path socketPath = Files.createTempFile("testSocket", null); + Files.delete(socketPath); + socketPath.toFile().deleteOnExit(); + return socketPath; + } + + private static void startServer(UnixDomainSocketAddress socketAddress) { + Thread serverThread = + new Thread( + () -> { + try (ServerSocketChannel serverChannel = + ServerSocketChannel.open(StandardProtocolFamily.UNIX)) { + serverChannel.bind(socketAddress); + isServerRunning.set(true); + + synchronized (isServerRunning) { + isServerRunning.notifyAll(); + } + + while (isServerRunning.get()) { + SocketChannel clientChannel = serverChannel.accept(); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + serverThread.start(); + + synchronized (isServerRunning) { + while (!isServerRunning.get()) { + try { + isServerRunning.wait(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + } + } + + private TunnelingJdkSocket createClient(Path socketPath) throws IOException { + TunnelingJdkSocket clientSocket = new TunnelingJdkSocket(socketPath); + clientSocket.connect(new InetSocketAddress("localhost", 0)); + return clientSocket; + } +}