99
1010import static com .google .common .truth .Truth .assertThat ;
1111import static com .google .common .truth .Truth .assertWithMessage ;
12+ import static com .google .protobuf .WireFormat .FIXED64_SIZE ;
1213import static org .junit .Assert .assertArrayEquals ;
1314import static org .junit .Assert .assertThrows ;
1415
@@ -149,7 +150,7 @@ private static final class SmallBlockInputStream extends FilterInputStream {
149150 private int readCalls ;
150151
151152 public SmallBlockInputStream (byte [] data , int blockSize ) {
152- super (new ByteArrayInputStream (data ));
153+ super (new ByteArrayInputStreamMatchingZeroLengthReadSemantics (data ));
153154 this .blockSize = blockSize ;
154155 }
155156
@@ -217,7 +218,7 @@ private void assertReadVarint(byte[] data, long value) throws Exception {
217218 // array first.
218219 byte [] longerData = new byte [data .length + 1 ];
219220 System .arraycopy (data , 0 , longerData , 0 , data .length );
220- InputStream rawInput = new ByteArrayInputStream (longerData );
221+ InputStream rawInput = new ByteArrayInputStreamMatchingZeroLengthReadSemantics (longerData );
221222 assertThat (CodedInputStream .readRawVarint32 (rawInput )).isEqualTo ((int ) value );
222223 assertThat (rawInput .available ()).isEqualTo (1 );
223224 }
@@ -253,7 +254,8 @@ private void assertReadVarintFailure(InvalidProtocolBufferException expected, by
253254
254255 // Make sure we get the same error when reading direct from an InputStream.
255256 try {
256- CodedInputStream .readRawVarint32 (new ByteArrayInputStream (data ));
257+ CodedInputStream .readRawVarint32 (
258+ new ByteArrayInputStreamMatchingZeroLengthReadSemantics (data ));
257259 assertWithMessage ("Should have thrown an exception." ).fail ();
258260 } catch (InvalidProtocolBufferException e ) {
259261 assertThat (e ).hasMessageThat ().isEqualTo (expected .getMessage ());
@@ -800,13 +802,16 @@ public void testMaliciousSGroupTagsWithMapField_fromInputStream() throws Excepti
800802 InvalidProtocolBufferException .class ,
801803 () ->
802804 MapContainer .parseFrom (
803- new ByteArrayInputStream (NESTING_SGROUP_WITH_INITIAL_BYTES )));
805+ new ByteArrayInputStreamMatchingZeroLengthReadSemantics (
806+ NESTING_SGROUP_WITH_INITIAL_BYTES )));
804807 Throwable mergeFromThrown =
805808 assertThrows (
806809 InvalidProtocolBufferException .class ,
807810 () ->
808811 MapContainer .newBuilder ()
809- .mergeFrom (new ByteArrayInputStream (NESTING_SGROUP_WITH_INITIAL_BYTES )));
812+ .mergeFrom (
813+ new ByteArrayInputStreamMatchingZeroLengthReadSemantics (
814+ NESTING_SGROUP_WITH_INITIAL_BYTES )));
810815
811816 assertThat (parseFromThrown )
812817 .hasMessageThat ()
@@ -818,7 +823,8 @@ public void testMaliciousSGroupTagsWithMapField_fromInputStream() throws Excepti
818823
819824 @ Test
820825 public void testMaliciousSGroupTags_inputStream_skipMessage () throws Exception {
821- ByteArrayInputStream inputSteam = new ByteArrayInputStream (NESTING_SGROUP );
826+ InputStream inputSteam =
827+ new ByteArrayInputStreamMatchingZeroLengthReadSemantics (NESTING_SGROUP );
822828 CodedInputStream input = CodedInputStream .newInstance (inputSteam );
823829 CodedOutputStream output = CodedOutputStream .newInstance (new byte [NESTING_SGROUP .length ]);
824830
@@ -986,7 +992,9 @@ public void testRefillBufferWithCorrectSize() throws Exception {
986992 inputStreamBufferLength <= rawInput .length + 1 ;
987993 inputStreamBufferLength ++) {
988994 CodedInputStream input =
989- CodedInputStream .newInstance (new ByteArrayInputStream (rawInput ), inputStreamBufferLength );
995+ CodedInputStream .newInstance (
996+ new ByteArrayInputStreamMatchingZeroLengthReadSemantics (rawInput ),
997+ inputStreamBufferLength );
990998 input .setSizeLimit (rawInput .length - 1 );
991999 input .readString ();
9921000 input .readString ();
@@ -1001,7 +1009,9 @@ public void testRefillBufferWithCorrectSize() throws Exception {
10011009
10021010 @ Test
10031011 public void testIsAtEnd () throws Exception {
1004- CodedInputStream input = CodedInputStream .newInstance (new ByteArrayInputStream (new byte [5 ]));
1012+ CodedInputStream input =
1013+ CodedInputStream .newInstance (
1014+ new ByteArrayInputStreamMatchingZeroLengthReadSemantics (new byte [5 ]));
10051015 try {
10061016 for (int i = 0 ; i < 5 ; i ++) {
10071017 assertThat (input .isAtEnd ()).isFalse ();
@@ -1026,7 +1036,9 @@ public void testCurrentLimitExceeded() throws Exception {
10261036 output .flush ();
10271037
10281038 byte [] rawInput = rawOutput .toByteArray ();
1029- CodedInputStream input = CodedInputStream .newInstance (new ByteArrayInputStream (rawInput ));
1039+ CodedInputStream input =
1040+ CodedInputStream .newInstance (
1041+ new ByteArrayInputStreamMatchingZeroLengthReadSemantics (rawInput ));
10301042 // The length of the whole rawInput
10311043 input .setSizeLimit (11 );
10321044 // Some number that is smaller than the rawInput's length
@@ -1260,7 +1272,7 @@ public void testReadLargeByteStringFromInputStream() throws Exception {
12601272
12611273 CodedInputStream input =
12621274 CodedInputStream .newInstance (
1263- new ByteArrayInputStream (data ) {
1275+ new ByteArrayInputStreamMatchingZeroLengthReadSemantics (data ) {
12641276 @ Override
12651277 public synchronized int available () {
12661278 return 0 ;
@@ -1285,7 +1297,7 @@ public void testReadLargeByteArrayFromInputStream() throws Exception {
12851297
12861298 CodedInputStream input =
12871299 CodedInputStream .newInstance (
1288- new ByteArrayInputStream (data ) {
1300+ new ByteArrayInputStreamMatchingZeroLengthReadSemantics (data ) {
12891301 @ Override
12901302 public synchronized int available () {
12911303 return 0 ;
@@ -1569,7 +1581,9 @@ public void testSkipInvalidEndGroup_nested(@TestParameter InputType inputType) t
15691581 @ Test
15701582 public void testSkipPastEndOfByteArrayInput () throws Exception {
15711583 try {
1572- CodedInputStream .newInstance (new ByteArrayInputStream (new byte [100 ])).skipRawBytes (101 );
1584+ CodedInputStream .newInstance (
1585+ new ByteArrayInputStreamMatchingZeroLengthReadSemantics (new byte [100 ]))
1586+ .skipRawBytes (101 );
15731587 assertWithMessage ("Should have thrown an exception" ).fail ();
15741588 } catch (InvalidProtocolBufferException e ) {
15751589 // Expected
@@ -1580,11 +1594,11 @@ public void testSkipPastEndOfByteArrayInput() throws Exception {
15801594 public void testMaliciousInputStream () throws Exception {
15811595 ByteArrayOutputStream outputStream = new ByteArrayOutputStream ();
15821596 CodedOutputStream codedOutputStream = CodedOutputStream .newInstance (outputStream );
1583- codedOutputStream .writeByteArrayNoTag (new byte [] {0x0 , 0x1 , 0x2 , 0x3 , 0x4 , 0x5 });
1597+ codedOutputStream .writeByteArrayNoTag (new byte [] {0x0 , 0x1 , 0x2 , 0x3 , 0x4 , 0x5 , 0x6 , 0x7 , 0x8 });
15841598 codedOutputStream .flush ();
15851599 final List <byte []> maliciousCapture = new ArrayList <>();
15861600 InputStream inputStream =
1587- new ByteArrayInputStream (outputStream .toByteArray ()) {
1601+ new ByteArrayInputStreamMatchingZeroLengthReadSemantics (outputStream .toByteArray ()) {
15881602 @ Override
15891603 public synchronized int read (byte [] b , int off , int len ) {
15901604 maliciousCapture .add (b );
@@ -1680,4 +1694,66 @@ public void testCodedInputStreamWithEmptyBuffers_isAtEndAfterRead() throws Excep
16801694 cis .readRawBytes (4096 );
16811695 assertThat (cis .isAtEnd ()).isTrue ();
16821696 }
1697+
1698+ @ Test
1699+ public void testStreamDecoderReadFixed64_inputTooSmall (@ TestParameter boolean bufferTooSmall )
1700+ throws Exception {
1701+ byte [] data = bytes (0xde , 0xbc , 0x9a , 0x78 , 0x56 , 0x34 , 0x12 );
1702+ InputStream input = new ByteArrayInputStreamMatchingZeroLengthReadSemantics (data );
1703+ CodedInputStream cis =
1704+ CodedInputStream .newInstance (input , FIXED64_SIZE - (bufferTooSmall ? 1 : 0 ));
1705+ try {
1706+ cis .readFixed64 ();
1707+ assertWithMessage ("Should have thrown an exception" ).fail ();
1708+ } catch (InvalidProtocolBufferException expected ) {
1709+ assertThat (expected )
1710+ .hasMessageThat ()
1711+ .isEqualTo (InvalidProtocolBufferException .truncatedMessage ().getMessage ());
1712+ }
1713+ }
1714+
1715+ @ Test
1716+ public void testStreamDecoderReadFixed64_bufferBounds (@ TestParameter boolean bufferTooSmall )
1717+ throws Exception {
1718+ byte [] data = bytes (0xf0 , 0xde , 0xbc , 0x9a , 0x78 , 0x56 , 0x34 , 0x12 );
1719+ InputStream input = new ByteArrayInputStreamMatchingZeroLengthReadSemantics (data );
1720+ CodedInputStream cis =
1721+ CodedInputStream .newInstance (input , FIXED64_SIZE - (bufferTooSmall ? 1 : 0 ));
1722+ assertThat (cis .readFixed64 ()).isEqualTo (0x123456789abcdef0L );
1723+ }
1724+
1725+ /**
1726+ * A {@link ByteArrayInputStream} that matches the behavior of {@link
1727+ * InputStream#read(byte[],int,int)} when the requested length is 0.
1728+ */
1729+ private static class ByteArrayInputStreamMatchingZeroLengthReadSemantics
1730+ extends ByteArrayInputStream {
1731+ private ByteArrayInputStreamMatchingZeroLengthReadSemantics (byte [] data ) {
1732+ super (data );
1733+ }
1734+
1735+ @ Override
1736+ public synchronized int read (byte [] b , int off , int len ) {
1737+ // Inline Objects.checkFromIndexSize() which is API 30+.
1738+ if ((b .length | off | len ) < 0 || len > b .length - off ) {
1739+ throw new IndexOutOfBoundsException ();
1740+ }
1741+ // Eagerly return 0 if the requested length is 0 to match InputStream behavior.
1742+ if (len == 0 ) {
1743+ return 0 ;
1744+ }
1745+
1746+ if (pos >= count ) {
1747+ return -1 ;
1748+ }
1749+
1750+ int avail = count - pos ;
1751+ if (len > avail ) {
1752+ len = avail ;
1753+ }
1754+ System .arraycopy (buf , pos , b , off , len );
1755+ pos += len ;
1756+ return len ;
1757+ }
1758+ }
16831759}
0 commit comments