Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/
package com.hierynomus.sshj.transport.verification;

import net.schmizz.sshj.common.Base64DecodingException;
import net.schmizz.sshj.common.Base64Decoder;
import net.schmizz.sshj.common.IOUtils;
import net.schmizz.sshj.common.SSHException;
import net.schmizz.sshj.transport.mac.MAC;
Expand All @@ -26,9 +28,13 @@
import java.util.regex.Pattern;

import com.hierynomus.sshj.transport.mac.Macs;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class KnownHostMatchers {

private static final Logger log = LoggerFactory.getLogger(KnownHostMatchers.class);

public static HostMatcher createMatcher(String hostEntry) throws SSHException {
if (hostEntry.contains(",")) {
return new AnyHostMatcher(hostEntry);
Expand Down Expand Up @@ -80,17 +86,22 @@ private static class HashedHostMatcher implements HostMatcher {

@Override
public boolean match(String hostname) throws IOException {
return hash.equals(hashHost(hostname));
try {
return hash.equals(hashHost(hostname));
} catch (Base64DecodingException err) {
log.warn("Hostname [{}] not matched: salt decoding failed", hostname, err);
return false;
}
}

private String hashHost(String host) throws IOException {
private String hashHost(String host) throws IOException, Base64DecodingException {
sha1.init(getSaltyBytes());
return "|1|" + salt + "|" + Base64.getEncoder().encodeToString(sha1.doFinal(host.getBytes(IOUtils.UTF8)));
}

private byte[] getSaltyBytes() {
private byte[] getSaltyBytes() throws IOException, Base64DecodingException {
if (saltyBytes == null) {
saltyBytes = Base64.getDecoder().decode(salt);
saltyBytes = Base64Decoder.decode(salt);
}
return saltyBytes;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/
package com.hierynomus.sshj.userauth.keyprovider;

import net.schmizz.sshj.common.Base64DecodingException;
import net.schmizz.sshj.common.Base64Decoder;
import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.common.KeyType;

Expand All @@ -23,7 +25,6 @@
import java.io.IOException;
import java.io.Reader;
import java.security.PublicKey;
import java.util.Base64;

public class OpenSSHKeyFileUtil {
private OpenSSHKeyFileUtil() {
Expand Down Expand Up @@ -54,16 +55,19 @@ public static ParsedPubKey initPubKey(Reader publicKey) throws IOException {
if (!keydata.isEmpty()) {
String[] parts = keydata.trim().split("\\s+");
if (parts.length >= 2) {
byte[] decodedPublicKey = Base64Decoder.decode(parts[1]);
return new ParsedPubKey(
KeyType.fromString(parts[0]),
new Buffer.PlainBuffer(Base64.getDecoder().decode(parts[1])).readPublicKey()
new Buffer.PlainBuffer(decodedPublicKey).readPublicKey()
);
} else {
throw new IOException("Got line with only one column");
}
}
}
throw new IOException("Public key file is blank");
} catch (Base64DecodingException err) {
throw new IOException("Public key decoding failed", err);
} finally {
br.close();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,8 @@
import net.i2p.crypto.eddsa.EdDSAPrivateKey;
import net.i2p.crypto.eddsa.spec.EdDSANamedCurveTable;
import net.i2p.crypto.eddsa.spec.EdDSAPrivateKeySpec;
import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.common.*;
import net.schmizz.sshj.common.Buffer.PlainBuffer;
import net.schmizz.sshj.common.ByteArrayUtils;
import net.schmizz.sshj.common.IOUtils;
import net.schmizz.sshj.common.KeyType;
import net.schmizz.sshj.common.SSHRuntimeException;
import net.schmizz.sshj.common.SecurityUtils;
import net.schmizz.sshj.transport.cipher.Cipher;
import net.schmizz.sshj.userauth.keyprovider.BaseFileKeyProvider;
import net.schmizz.sshj.userauth.keyprovider.FileKeyProvider;
Expand All @@ -55,7 +50,6 @@
import java.security.spec.ECPrivateKeySpec;
import java.security.spec.RSAPrivateCrtKeySpec;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;

Expand Down Expand Up @@ -124,7 +118,7 @@ protected KeyPair readKeyPair() throws IOException {
try {
if (checkHeader(reader)) {
final String encodedPrivateKey = readEncodedKey(reader);
byte[] decodedPrivateKey = Base64.getDecoder().decode(encodedPrivateKey);
byte[] decodedPrivateKey = Base64Decoder.decode(encodedPrivateKey);
final PlainBuffer bufferedPrivateKey = new PlainBuffer(decodedPrivateKey);
return readDecodedKeyPair(bufferedPrivateKey);
} else {
Expand All @@ -133,6 +127,8 @@ protected KeyPair readKeyPair() throws IOException {
}
} catch (final GeneralSecurityException e) {
throw new SSHRuntimeException("Read OpenSSH Version 1 Key failed", e);
} catch (Base64DecodingException e) {
throw new SSHRuntimeException("Private Key decoding failed", e);
} finally {
IOUtils.closeQuietly(reader);
}
Expand Down
47 changes: 47 additions & 0 deletions src/main/java/net/schmizz/sshj/common/Base64Decoder.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright (C)2009 - SSHJ Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package net.schmizz.sshj.common;

import java.io.IOException;
import java.util.Base64;

/**
* <p>Wraps {@link java.util.Base64.Decoder} in order to wrap unchecked {@code IllegalArgumentException} thrown by
* the default Java Base64 decoder here and there.</p>
*
* <p>Please use this class instead of {@link java.util.Base64.Decoder}.</p>
*/
public class Base64Decoder {
private Base64Decoder() {
}

public static byte[] decode(byte[] source) throws Base64DecodingException {
try {
return Base64.getDecoder().decode(source);
} catch (IllegalArgumentException err) {
throw new Base64DecodingException(err);
}
}

public static byte[] decode(String src) throws Base64DecodingException {
try {
return Base64.getDecoder().decode(src);
} catch (IllegalArgumentException err) {
throw new Base64DecodingException(err);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright (C)2009 - SSHJ Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package net.schmizz.sshj.common;

/**
* A checked wrapper for all {@link IllegalArgumentException}, thrown by {@link java.util.Base64.Decoder}.
*
* @see Base64Decoder
*/
public class Base64DecodingException extends Exception {
public Base64DecodingException(IllegalArgumentException cause) {
super("Failed to decode base64: " + cause.getMessage(), cause);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,7 @@
import com.hierynomus.sshj.common.KeyAlgorithm;
import com.hierynomus.sshj.transport.verification.KnownHostMatchers;
import com.hierynomus.sshj.userauth.certificate.Certificate;
import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.common.IOUtils;
import net.schmizz.sshj.common.KeyType;
import net.schmizz.sshj.common.LoggerFactory;
import net.schmizz.sshj.common.SSHException;
import net.schmizz.sshj.common.SSHRuntimeException;
import net.schmizz.sshj.common.SecurityUtils;
import net.schmizz.sshj.common.*;
import org.slf4j.Logger;

import java.io.BufferedOutputStream;
Expand Down Expand Up @@ -290,9 +284,9 @@ public KnownHostEntry parseEntry(String line)
if (type != KeyType.UNKNOWN) {
final String sKey = split[i++];
try {
byte[] keyBytes = Base64.getDecoder().decode(sKey);
byte[] keyBytes = Base64Decoder.decode(sKey);
key = new Buffer.PlainBuffer(keyBytes).readPublicKey();
} catch (IOException | IllegalArgumentException exception) {
} catch (IOException | Base64DecodingException exception) {
log.warn("Error decoding Base64 key bytes", exception);
return new BadHostEntry(line);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@
import net.i2p.crypto.eddsa.spec.EdDSANamedCurveTable;
import net.i2p.crypto.eddsa.spec.EdDSAPrivateKeySpec;
import net.i2p.crypto.eddsa.spec.EdDSAPublicKeySpec;
import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.common.KeyType;
import net.schmizz.sshj.common.SecurityUtils;
import net.schmizz.sshj.common.*;
import net.schmizz.sshj.userauth.password.PasswordUtils;
import org.bouncycastle.asn1.nist.NISTNamedCurves;
import org.bouncycastle.asn1.x9.X9ECParameters;
Expand All @@ -42,7 +40,6 @@
import java.security.*;
import java.security.spec.*;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;

Expand Down Expand Up @@ -240,29 +237,34 @@ protected void parseKeyPair() throws IOException {
if (this.keyFileVersion == null) {
throw new IOException("Invalid key file format: missing \"PuTTY-User-Key-File-?\" entry");
}
// Retrieve keys from payload
publicKey = Base64.getDecoder().decode(payload.get("Public-Lines"));
if (this.isEncrypted()) {
final char[] passphrase;
if (pwdf != null) {
passphrase = pwdf.reqPassword(resource);
} else {
passphrase = "".toCharArray();
}
try {
privateKey = this.decrypt(Base64.getDecoder().decode(payload.get("Private-Lines")), passphrase);
Mac mac;
if (this.keyFileVersion <= 2) {
mac = this.prepareVerifyMacV2(passphrase);
try {
// Retrieve keys from payload
publicKey = Base64Decoder.decode(payload.get("Public-Lines"));
if (this.isEncrypted()) {
final char[] passphrase;
if (pwdf != null) {
passphrase = pwdf.reqPassword(resource);
} else {
mac = this.prepareVerifyMacV3();
passphrase = "".toCharArray();
}
try {
privateKey = this.decrypt(Base64Decoder.decode(payload.get("Private-Lines")), passphrase);
Mac mac;
if (this.keyFileVersion <= 2) {
mac = this.prepareVerifyMacV2(passphrase);
} else {
mac = this.prepareVerifyMacV3();
}
this.verify(mac);
} finally {
PasswordUtils.blankOut(passphrase);
}
this.verify(mac);
} finally {
PasswordUtils.blankOut(passphrase);
} else {
privateKey = Base64Decoder.decode(payload.get("Private-Lines"));
}
} else {
privateKey = Base64.getDecoder().decode(payload.get("Private-Lines"));
}
catch (Base64DecodingException e) {
throw new IOException("PuTTY key decoding failed", e);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,16 @@
*/
package com.hierynomus.sshj.transport.verification;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.assertj.core.api.Assertions.*;
import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.common.SecurityUtils;
import net.schmizz.sshj.transport.verification.OpenSSHKnownHosts;
import net.schmizz.sshj.util.KeyUtil;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

import java.io.File;
import java.io.IOException;
Expand All @@ -29,17 +34,8 @@
import java.util.Base64;
import java.util.stream.Stream;

import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.common.SecurityUtils;
import net.schmizz.sshj.transport.verification.OpenSSHKnownHosts;
import net.schmizz.sshj.util.KeyUtil;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.*;

public class OpenSSHKnownHostsTest {
@TempDir
Expand Down Expand Up @@ -118,6 +114,24 @@ public void shouldNotFailOnMalformedBase64String() throws IOException {
assertThat(ohk.entries().get(0)).isInstanceOf(OpenSSHKnownHosts.BadHostEntry.class);
}

@Test
public void shouldNotFailOnMalformeSaltBase64String() throws IOException {
// A record with broken base64 inside the salt part of the hash.
// No matter how it could be generated, such broken strings must not cause unexpected errors.
String hostName = "example.com";
File knownHosts = knownHosts(
"|1|2gujgGa6gJnK7wGPCX8zuGttvCMXX|Oqkbjtxd9RFxKQv6y3l3GIxLNiU= ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBGVVnyoAD5/uWiiuTSM3RuW8dEWRrqOXYobAMKHhAA6kuOBoPK+LoAYyUcN26bdMiCxg+VOaLHxPNWv5SlhbMWw=\n"
);
OpenSSHKnownHosts ohk = new OpenSSHKnownHosts(knownHosts);
assertEquals(1, ohk.entries().size());

// Some random valid public key. It doesn't matter for the test if it matches the broken host key record or not.
PublicKey k = new Buffer.PlainBuffer(Base64.getDecoder().decode(
"AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBLTjA7hduYGmvV9smEEsIdGLdghSPD7kL8QarIIOkeXmBh+LTtT/T1K+Ot/rmXCZsP8hoUXxbvN+Tks440Ci0ck="))
.readPublicKey();
assertFalse(ohk.verify(hostName, 22, k));
}

@Test
public void shouldMarkBadLineAndNotFail() throws Exception {
File knownHosts = knownHosts(
Expand Down
Loading