From 5d135a2676de8038b1b3fb827ccb2a89324b28bc Mon Sep 17 00:00:00 2001
From: John Bauer
Date: Tue, 2 Mar 2021 15:13:47 -0800
Subject: [PATCH 1/6] Add a module to put shorts, ints, and longs into a
 ByteArrayOutputStream using bit fiddling

---
 src/edu/stanford/nlp/io/ByteArrayUtils.java    | 60 +++++++++++++
 .../stanford/nlp/io/ByteArrayUtilsTest.java    | 85 +++++++++++++++++++
 2 files changed, 145 insertions(+)
 create mode 100644 src/edu/stanford/nlp/io/ByteArrayUtils.java
 create mode 100644 test/src/edu/stanford/nlp/io/ByteArrayUtilsTest.java

diff --git a/src/edu/stanford/nlp/io/ByteArrayUtils.java b/src/edu/stanford/nlp/io/ByteArrayUtils.java
new file mode 100644
index 0000000000..212637a967
--- /dev/null
+++ b/src/edu/stanford/nlp/io/ByteArrayUtils.java
@@ -0,0 +1,60 @@
+package edu.stanford.nlp.io;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+
+/**
+ * Static methods for putting shorts, ints, and longs into a ByteArrayOutputStream using bit fiddling
+ *
+ * @author John Bauer
+ */
+public class ByteArrayUtils {
+  static public short readShort(ByteArrayInputStream bin) {
+    int high = ((bin.read() & 0x000000FF) << 8);
+    int low = (bin.read() & 0x000000FF);
+    return (short) ((high | low) & 0x0000FFFF);
+  }
+
+  static public void writeShort(ByteArrayOutputStream bout, short val) {
+    bout.write((byte)((val >> 8) & 0xff));
+    bout.write((byte)(val & 0xff));
+  }
+
+  static public int readInt(ByteArrayInputStream bin) {
+    int b24 = ((bin.read() & 0x000000FF) << 24);
+    int b16 = ((bin.read() & 0x000000FF) << 16);
+    int b8 = ((bin.read() & 0x000000FF) << 8);
+    int b0 = (bin.read() & 0x000000FF);
+    return b24 | b16 | b8 | b0;
+  }
+
+  static public void writeInt(ByteArrayOutputStream bout, int val) {
+    bout.write((byte)((val >> 24) & 0xff));
+    bout.write((byte)((val >> 16) & 0xff));
+    bout.write((byte)((val >> 8) & 0xff));
+    bout.write((byte)(val & 0xff));
+  }
+
+  static public long readLong(ByteArrayInputStream bin) {
+    long b56 = ((long) (bin.read() & 0x000000FF)) << 56;
+    long b48 = ((long) (bin.read() & 0x000000FF)) << 48;
+    long b40 = ((long) (bin.read() & 0x000000FF)) << 40;
+    long b32 = ((long) (bin.read() & 0x000000FF)) << 32;
+    long b24 = ((long) (bin.read() & 0x000000FF)) << 24;
+    long b16 = ((long) (bin.read() & 0x000000FF)) << 16;
+    long b8 = ((long) (bin.read() & 0x000000FF)) << 8;
+    long b0 = ((long) (bin.read() & 0x000000FF));
+    return b56 | b48 | b40 | b32 | b24 | b16 | b8 | b0;
+  }
+
+  static public void writeLong(ByteArrayOutputStream bout, long val) {
+    bout.write((byte)((val >> 56) & 0xff));
+    bout.write((byte)((val >> 48) & 0xff));
+    bout.write((byte)((val >> 40) & 0xff));
+    bout.write((byte)((val >> 32) & 0xff));
+    bout.write((byte)((val >> 24) & 0xff));
+    bout.write((byte)((val >> 16) & 0xff));
+    bout.write((byte)((val >> 8) & 0xff));
+    bout.write((byte)(val & 0xff));
+  }
+}
diff --git a/test/src/edu/stanford/nlp/io/ByteArrayUtilsTest.java b/test/src/edu/stanford/nlp/io/ByteArrayUtilsTest.java
new file mode 100644
index 0000000000..3e443d13ec
--- /dev/null
+++ b/test/src/edu/stanford/nlp/io/ByteArrayUtilsTest.java
@@ -0,0 +1,85 @@
+package edu.stanford.nlp.io;
+
+import static org.junit.Assert.*;
+import org.junit.Test;
+
+import java.util.Random;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+
+public class ByteArrayUtilsTest {
+  static final int TEST_LENGTH = 1000;
+
+  @Test
+  public void testShort() {
+    Random random = new Random(1234);
+    ByteArrayOutputStream bout = new ByteArrayOutputStream();
+
+    short[] values = new short[TEST_LENGTH];
+    values[0] = 0;
+    for (int i = 1; i < values.length; ++i) {
+      values[i] = (short) random.nextInt(1 << 16);
+    }
+
+    for (int i = 0; i < values.length; ++i) {
+      ByteArrayUtils.writeShort(bout, values[i]);
+    }
+
+    ByteArrayInputStream bin = new ByteArrayInputStream(bout.toByteArray());
+    short[] output = new short[values.length];
+    for (int i = 0; i < values.length; ++i) {
+      output[i] = ByteArrayUtils.readShort(bin);
+    }
+
+    assertArrayEquals(values, output);
+  }
+
+  @Test
+  public void testInt() {
+    Random random = new Random(1234);
+    ByteArrayOutputStream bout = new ByteArrayOutputStream();
+
+    int[] values = new int[TEST_LENGTH];
+    values[0] = 0;
+    for (int i = 1; i < values.length; ++i) {
+      values[i] = random.nextInt();
+    }
+
+    for (int i = 0; i < values.length; ++i) {
+      ByteArrayUtils.writeInt(bout, values[i]);
+    }
+
+    ByteArrayInputStream bin = new ByteArrayInputStream(bout.toByteArray());
+    int[] output = new int[values.length];
+    for (int i = 0; i < values.length; ++i) {
+      output[i] = ByteArrayUtils.readInt(bin);
+    }
+
+    assertArrayEquals(values, output);
+  }
+
+  @Test
+  public void testLong() {
+    Random random = new Random(1234);
+    ByteArrayOutputStream bout = new ByteArrayOutputStream();
+
+    long[] values = new long[TEST_LENGTH];
+    values[0] = 0;
+    for (int i = 1; i < values.length; ++i) {
+      values[i] = random.nextLong();
+    }
+
+    for (int i = 0; i < values.length; ++i) {
+      ByteArrayUtils.writeLong(bout, values[i]);
+    }
+
+    ByteArrayInputStream bin = new ByteArrayInputStream(bout.toByteArray());
+    long[] output = new long[values.length];
+    for (int i = 0; i < values.length; ++i) {
+      output[i] = ByteArrayUtils.readLong(bin);
+    }
+
+    assertArrayEquals(values, output);
+  }
+}
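The helpers above use the same big-endian byte order as DataOutputStream, so whatever goes into the stream comes back out of the matching read call. A minimal round-trip sketch (illustration only, not part of the patch):

    ByteArrayOutputStream bout = new ByteArrayOutputStream();
    ByteArrayUtils.writeInt(bout, 0xCAFEBABE);
    ByteArrayUtils.writeShort(bout, (short) -2);
    byte[] bytes = bout.toByteArray();            // 6 bytes: CA FE BA BE FF FE

    ByteArrayInputStream bin = new ByteArrayInputStream(bytes);
    int i = ByteArrayUtils.readInt(bin);          // 0xCAFEBABE again
    short s = ByteArrayUtils.readShort(bin);      // -2 again
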
From a5623e7be2683093d37cc746e6b89e39d52394f8 Mon Sep 17 00:00:00 2001
From: John Bauer
Date: Tue, 2 Mar 2021 15:15:41 -0800
Subject: [PATCH 2/6] Write the weights using custom serialization.

The version with short[] saves a bunch of space
---
 .../parser/shiftreduce/PerceptronModel.java  |  8 +-
 .../nlp/parser/shiftreduce/Weight.java       | 22 +++++-
 .../nlp/parser/shiftreduce/WeightMap.java    | 74 +++++++++++++++++++
 .../nlp/parser/shiftreduce/WeightTest.java   | 21 ++++++
 4 files changed, 120 insertions(+), 5 deletions(-)
 create mode 100644 src/edu/stanford/nlp/parser/shiftreduce/WeightMap.java

diff --git a/src/edu/stanford/nlp/parser/shiftreduce/PerceptronModel.java b/src/edu/stanford/nlp/parser/shiftreduce/PerceptronModel.java
index a2f1a031ff..8608cf3b27 100644
--- a/src/edu/stanford/nlp/parser/shiftreduce/PerceptronModel.java
+++ b/src/edu/stanford/nlp/parser/shiftreduce/PerceptronModel.java
@@ -44,13 +44,13 @@ public class PerceptronModel extends BaseModel {
   private float learningRate = 1.0f;
 
-  Map<String, Weight> featureWeights;
+  WeightMap featureWeights;
 
   final FeatureFactory featureFactory;
 
   public PerceptronModel(ShiftReduceOptions op, Index<Transition> transitionIndex,
                         Set<String> knownStates, Set<String> rootStates, Set<String> rootOnlyStates) {
     super(op, transitionIndex, knownStates, rootStates, rootOnlyStates);
-    this.featureWeights = Generics.newHashMap();
+    this.featureWeights = new WeightMap();
 
     String[] classes = op.featureFactoryClass.split(";");
     if (classes.length == 1) {
@@ -74,7 +74,7 @@ public PerceptronModel(PerceptronModel other) {
     super(other);
     this.featureFactory = other.featureFactory;
 
-    this.featureWeights = Generics.newHashMap();
+    this.featureWeights = new WeightMap();
     for (String feature : other.featureWeights.keySet()) {
       featureWeights.put(feature, new Weight(other.featureWeights.get(feature)));
     }
@@ -110,7 +110,7 @@ public void averageModels(Collection<PerceptronModel> models) {
       }
     }
 
-    featureWeights = Generics.newHashMap();
+    featureWeights = new WeightMap();
     for (String feature : features) {
       featureWeights.put(feature, new Weight());
     }
diff --git a/src/edu/stanford/nlp/parser/shiftreduce/Weight.java b/src/edu/stanford/nlp/parser/shiftreduce/Weight.java
index f3d6c4b5f5..46087f6e3e 100644
--- a/src/edu/stanford/nlp/parser/shiftreduce/Weight.java
+++ b/src/edu/stanford/nlp/parser/shiftreduce/Weight.java
@@ -1,9 +1,13 @@
 package edu.stanford.nlp.parser.shiftreduce;
 
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
 import java.io.Serializable;
 
+import edu.stanford.nlp.io.ByteArrayUtils;
 import edu.stanford.nlp.util.ArrayUtils;
 
+
 /**
  * Stores one row of the sparse matrix which makes up the multiclass perceptron.
  *
@@ -240,6 +244,22 @@ public String toString() {
 
   private long[] packed;
 
-  private static final long serialVersionUID = 1;
+  void writeBytes(ByteArrayOutputStream bout) {
+    ByteArrayUtils.writeInt(bout, packed.length);
+    for (int i = 0; i < packed.length; ++i) {
+      ByteArrayUtils.writeLong(bout, packed[i]);
+    }
+  }
+
+  static Weight readBytes(ByteArrayInputStream bin) {
+    int len = ByteArrayUtils.readInt(bin);
+    Weight weight = new Weight();
+    weight.packed = new long[len];
+    for (int i = 0; i < len; ++i) {
+      weight.packed[i] = ByteArrayUtils.readLong(bin);
+    }
+    return weight;
+  }
+  private static final long serialVersionUID = 2;
 
 }
diff --git a/src/edu/stanford/nlp/parser/shiftreduce/WeightMap.java b/src/edu/stanford/nlp/parser/shiftreduce/WeightMap.java
new file mode 100644
index 0000000000..2f3b965c2e
--- /dev/null
+++ b/src/edu/stanford/nlp/parser/shiftreduce/WeightMap.java
@@ -0,0 +1,74 @@
+package edu.stanford.nlp.parser.shiftreduce;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.io.Serializable;
+
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+public class WeightMap implements Serializable {
+  private HashMap<String, Weight> weights = new HashMap<>();
+
+  private void writeObject(ObjectOutputStream out)
+    throws IOException
+  {
+    out.writeObject(weights.size());
+    ByteArrayOutputStream bout = new ByteArrayOutputStream();
+    for (String feature : weights.keySet()) {
+      out.writeObject(feature);
+      weights.get(feature).writeBytes(bout);
+    }
+    out.writeObject(bout.toByteArray());
+  }
+
+  private void readObject(ObjectInputStream in)
+    throws IOException, ClassNotFoundException
+  {
+    Integer size = (Integer) in.readObject();
+
+    List<String> keys = new ArrayList<>();
+    for (int i = 0; i < size; ++i) {
+      String feature = (String) in.readObject();
+      keys.add(feature);
+    }
+    byte[] bytes = (byte[]) in.readObject();
+    ByteArrayInputStream bin = new ByteArrayInputStream(bytes);
+
+    weights = new HashMap<>(size);
+    for (int i = 0; i < size; ++i) {
+      weights.put(keys.get(i), Weight.readBytes(bin));
+    }
+  }
+
+  public Weight get(String key) {
+    return weights.get(key);
+  }
+
+  public void put(String key, Weight weight) {
+    weights.put(key, weight);
+  }
+
+  public int size() {
+    return weights.size();
+  }
+
+  public boolean containsKey(String key) {
+    return weights.containsKey(key);
+  }
+
+  public Set<String> keySet() {
+    return weights.keySet();
+  }
+
+  public Set<Map.Entry<String, Weight>> entrySet() {
+    return weights.entrySet();
+  }
+}
diff --git a/test/src/edu/stanford/nlp/parser/shiftreduce/WeightTest.java b/test/src/edu/stanford/nlp/parser/shiftreduce/WeightTest.java
index 33f67d87de..758a1f7a15 100644
--- a/test/src/edu/stanford/nlp/parser/shiftreduce/WeightTest.java
+++ b/test/src/edu/stanford/nlp/parser/shiftreduce/WeightTest.java
@@ -5,6 +5,9 @@
 
 import java.util.Random;
 
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+
 public class WeightTest {
   @Test
   public void testSize() {
@@ -131,4 +134,22 @@ public void testNaN() {
     w2.score(scores);
     assertEquals(-952.0, scores[232], 0.0001f);
   }
+
+  @Test
+  public void testReadWrite() {
+    Weight w = new Weight();
+    w.updateWeight(232, -431.0f);
+    w.updateWeight(200, -521.0f);
+    w.updateWeight(3, 50.0f);
+
+    ByteArrayOutputStream bout = new ByteArrayOutputStream();
+    w.writeBytes(bout);
+
+    ByteArrayInputStream bin = new ByteArrayInputStream(bout.toByteArray());
+    Weight w2 = Weight.readBytes(bin);
+    assertEquals(3, w2.size());
+    assertEquals(-431.0f, w2.getScore(232), 0.0001f);
+    assertEquals(-521.0f, w2.getScore(200), 0.0001f);
+    assertEquals(50.0f, w2.getScore(3), 0.0001f);
+  }
 }
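WeightMap replaces the plain map of feature weights and takes over its own serialization: writeObject emits the entry count, then each feature name, then one byte[] holding every Weight row back to back, and readObject rebuilds the map in the same order. A round-trip sketch (illustration only; the helper name is made up, the patch itself only changes how the map is written):

    static WeightMap roundTrip(WeightMap map) throws IOException, ClassNotFoundException {
      ByteArrayOutputStream bytes = new ByteArrayOutputStream();
      try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
        out.writeObject(map);                     // goes through WeightMap.writeObject
      }
      try (ObjectInputStream in = new ObjectInputStream(
               new ByteArrayInputStream(bytes.toByteArray()))) {
        return (WeightMap) in.readObject();       // goes through WeightMap.readObject
      }
    }
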
From e7e08c5f954573d8ee72782d69aa0ba7a3f800b0 Mon Sep 17 00:00:00 2001
From: John Bauer
Date: Thu, 28 Jan 2021 22:08:05 -0800
Subject: [PATCH 3/6] Make the hashCode for left & right binary different.

This kills all the existing models...
---
 src/edu/stanford/nlp/parser/shiftreduce/BinaryTransition.java | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/edu/stanford/nlp/parser/shiftreduce/BinaryTransition.java b/src/edu/stanford/nlp/parser/shiftreduce/BinaryTransition.java
index 371368c733..f2f5405268 100644
--- a/src/edu/stanford/nlp/parser/shiftreduce/BinaryTransition.java
+++ b/src/edu/stanford/nlp/parser/shiftreduce/BinaryTransition.java
@@ -232,12 +232,11 @@ public boolean equals(Object o) {
 
   @Override
   public int hashCode() {
-    // TODO: fix the hashcode for the side? would require rebuilding all models
     switch(side) {
     case LEFT:
       return 97197711 ^ label.hashCode();
     case RIGHT:
-      return 97197711 ^ label.hashCode();
+      return 85635467 ^ label.hashCode();
     default:
       throw new IllegalArgumentException("Unknown side " + side);
     }

From b2c14558011d771ff0f42bbd2be6af12680806af Mon Sep 17 00:00:00 2001
From: John Bauer
Date: Tue, 2 Mar 2021 08:16:36 -0800
Subject: [PATCH 4/6] Use 3 shorts instead of 1 long for keeping model weights

---
 .../nlp/parser/shiftreduce/Weight.java | 92 +++++++++++--------
 1 file changed, 53 insertions(+), 39 deletions(-)

diff --git a/src/edu/stanford/nlp/parser/shiftreduce/Weight.java b/src/edu/stanford/nlp/parser/shiftreduce/Weight.java
index 46087f6e3e..e91aadd52b 100644
--- a/src/edu/stanford/nlp/parser/shiftreduce/Weight.java
+++ b/src/edu/stanford/nlp/parser/shiftreduce/Weight.java
@@ -27,13 +27,15 @@
 
  */
 public class Weight implements Serializable {
+  static final short[] EMPTY = {};
+
   public Weight() {
-    packed = null;
+    packed = EMPTY;
   }
 
   public Weight(Weight other) {
     if (other.size() == 0) {
-      packed = null;
+      packed = EMPTY;
       return;
     }
     packed = ArrayUtils.copy(other.packed);
@@ -41,34 +43,41 @@ public Weight(Weight other) {
   }
 
   public int size() {
-    if (packed == null) {
-      return 0;
-    }
-    return packed.length;
+    // TODO: find a fast way of doing this... we know it's a multiple of 3 after all
+    return packed.length / 3;
   }
 
-  private int unpackIndex(int i) {
-    long pack = packed[i];
-    return (int) (pack >>> 32);
+  private short unpackIndex(int i) {
+    return packed[i * 3];
   }
 
   private float unpackScore(int i) {
-    long pack = packed[i];
-    return Float.intBitsToFloat((int) (pack & 0xFFFFFFFF));
+    i = i * 3 + 1;
+    final int high = ((int) packed[i++]) << 16;
+    final int low = packed[i] & 0x0000FFFF;
+    return Float.intBitsToFloat(high | low);
   }
 
-  private static long packedValue(int index, float score) {
-    long pack = ((long) (Float.floatToIntBits(score))) & 0x00000000FFFFFFFFL;
-    pack = pack | (((long) index) << 32);
-    return pack;
-  }
-
-  private static void pack(long[] packed, int i, int index, float score) {
-    packed[i] = packedValue(index, score);
+  private static void pack(short[] packed, int i, int index, float score) {
+    if (i > Short.MAX_VALUE) {
+      throw new ArithmeticException("How did you make an index with 30,000 weights??");
+    }
+    int pos = i * 3;
+    packed[pos++] = (short) index;
+    final int bits = Float.floatToIntBits(score);
+    packed[pos++] = (short) ((bits & 0xFFFF0000) >> 16);
+    packed[pos] = (short) (bits & 0x0000FFFF);
   }
 
   private void pack(int i, int index, float score) {
-    packed[i] = packedValue(index, score);
+    if (i > Short.MAX_VALUE) {
+      throw new ArithmeticException("How did you make an index with 30,000 weights??");
+    }
+    int pos = i * 3;
+    packed[pos++] = (short) index;
+    final int bits = Float.floatToIntBits(score);
+    packed[pos++] = (short) ((bits & 0xFFFF0000) >> 16);
+    packed[pos] = (short) (bits & 0x0000FFFF);
   }
 
   public void score(float[] scores) {
@@ -76,14 +85,17 @@ public void score(float[] scores) {
     if (length > scores.length) {
       throw new AssertionError("Called with an array of scores too small to fit");
     }
-    for (int i = 0; i < length; ++i) {
+    for (int i = 0; i < packed.length; ) {
       // Since this is the critical method, we optimize it even further.
       // We could do this:
       //   int index = unpackIndex; float score = unpackScore;
-      // That results in an extra array lookup
-      final long pack = packed[i];
-      final int index = (int) (pack >>> 32);
-      final float score = Float.intBitsToFloat((int) (pack & 0xFFFFFFFF));
+      // That results in extra operations
+      final short index = packed[i++];
+      final int high = ((int) packed[i++]) << 16;
+      final int low = packed[i++] & 0x0000FFFF;
+      final int bits = high | low;
+      // final int bits = (((int) packed[i++]) << 16) | (packed[i++] & 0x0000FFFF);
+      final float score = Float.intBitsToFloat(bits);
       scores[index] += score;
     }
   }
@@ -102,7 +114,7 @@ public void addScaled(Weight other, float scale) {
   void condense() {
     // threshold is in case floating point math makes a feature we
     // don't care about exist
-    if (packed == null) {
+    if (packed == null || packed.length == 0) {
      return;
     }
 
@@ -115,7 +127,7 @@ void condense() {
     }
 
     if (nonzero == 0) {
-      packed = null;
+      packed = EMPTY;
       return;
     }
 
@@ -123,7 +135,7 @@ void condense() {
       return;
     }
 
-    long[] newPacked = new long[nonzero];
+    short[] newPacked = new short[nonzero * 3];
     int j = 0;
     for (int i = 0; i < length; ++i) {
       if (Math.abs(unpackScore(i)) <= THRESHOLD) {
@@ -156,8 +168,8 @@ public void updateWeight(int index, float increment) {
       return;
     }
 
-    if (packed == null) {
-      packed = new long[1];
+    if (packed == null || packed.length == 0) {
+      packed = new short[3];
       pack(0, index, increment);
       return;
     }
@@ -165,14 +177,14 @@ public void updateWeight(int index, float increment) {
     final int length = size();
     for (int i = 0; i < length; ++i) {
       if (unpackIndex(i) == index) {
-        float score = unpackScore(i);
+        final float score = unpackScore(i);
         pack(i, index, score + increment);
         return;
       }
     }
 
-    long[] newPacked = new long[length + 1];
-    for (int i = 0; i < length; ++i) {
+    short[] newPacked = new short[packed.length + 3];
+    for (int i = 0; i < packed.length; ++i) {
       newPacked[i] = packed[i];
     }
     pack(newPacked, length, index, increment);
@@ -235,31 +247,33 @@ void l2Reg(float reg) {
   public String toString() {
     StringBuilder builder = new StringBuilder();
     final int length = size();
+    builder.append("Weight(");
     for (int i = 0; i < length; ++i) {
-      if (i > 0) builder.append(" ");
+      if (i > 0) builder.append(" ");
       builder.append(unpackIndex(i) + "=" + unpackScore(i));
     }
+    builder.append(")");
     return builder.toString();
   }
 
-  private long[] packed;
+  private short[] packed;
 
   void writeBytes(ByteArrayOutputStream bout) {
     ByteArrayUtils.writeInt(bout, packed.length);
     for (int i = 0; i < packed.length; ++i) {
-      ByteArrayUtils.writeLong(bout, packed[i]);
+      ByteArrayUtils.writeShort(bout, packed[i]);
     }
   }
 
   static Weight readBytes(ByteArrayInputStream bin) {
     int len = ByteArrayUtils.readInt(bin);
     Weight weight = new Weight();
-    weight.packed = new long[len];
+    weight.packed = new short[len];
     for (int i = 0; i < len; ++i) {
-      weight.packed[i] = ByteArrayUtils.readLong(bin);
+      weight.packed[i] = ByteArrayUtils.readShort(bin);
     }
     return weight;
   }
-  private static final long serialVersionUID = 2;
+  private static final long serialVersionUID = 3;
 
 }
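With this change each (transition index, score) entry occupies three consecutive shorts instead of one long: the index itself, then the high and low 16 bits of Float.floatToIntBits(score). That is 6 bytes per entry rather than 8, which is where the space saving comes from. A worked sketch (illustration only; packed and i stand for the field and entry position used above):

    // entry i for transition 7 with weight 1.5f (float bits 0x3FC00000):
    //   packed[3*i]     = 7        (index)
    //   packed[3*i + 1] = 0x3FC0   (high 16 bits of the score)
    //   packed[3*i + 2] = 0x0000   (low 16 bits of the score)
    int bits = (((int) packed[3*i + 1]) << 16) | (packed[3*i + 2] & 0xFFFF);
    float score = Float.intBitsToFloat(bits);     // 1.5f again
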
From fe88eb32bfd0da77bf0a0d498adb0caf7d8909da Mon Sep 17 00:00:00 2001
From: John Bauer
Date: Wed, 3 Mar 2021 07:49:04 -0800
Subject: [PATCH 5/6] Avoid division

---
 src/edu/stanford/nlp/parser/shiftreduce/Weight.java | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/edu/stanford/nlp/parser/shiftreduce/Weight.java b/src/edu/stanford/nlp/parser/shiftreduce/Weight.java
index e91aadd52b..8f32d755b5 100644
--- a/src/edu/stanford/nlp/parser/shiftreduce/Weight.java
+++ b/src/edu/stanford/nlp/parser/shiftreduce/Weight.java
@@ -81,8 +81,7 @@ private void pack(int i, int index, float score) {
   }
 
   public void score(float[] scores) {
-    final int length = size();
-    if (length > scores.length) {
+    if (packed.length > scores.length * 3) {
       throw new AssertionError("Called with an array of scores too small to fit");
     }
     for (int i = 0; i < packed.length; ) {

From 7e9eaf1735f612395ae3410f46696b9467907d40 Mon Sep 17 00:00:00 2001
From: John Bauer
Date: Fri, 12 Mar 2021 00:20:35 -0800
Subject: [PATCH 6/6] Apparently the latest modifications use even more memory

---
 scripts/srparser/Makefile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scripts/srparser/Makefile b/scripts/srparser/Makefile
index f574a21ff3..3845413a38 100644
--- a/scripts/srparser/Makefile
+++ b/scripts/srparser/Makefile
@@ -144,7 +144,7 @@ englishSR.ser.gz:
 englishSR.beam.ser.gz:
 	@echo Training $@
 	@echo Will test on $(ENGLISH_TEST)
-	java -mx50g edu.stanford.nlp.parser.shiftreduce.ShiftReduceParser -trainTreebank $(ENGLISH_TRAIN) -devTreebank $(ENGLISH_DEV) -serializedPath $@ $(DEFAULT_OPTIONS) -preTag -taggerSerializedFile $(ENGLISH_TAGGER) -tlpp $(ENGLISH_TLPP) $(TRAIN_BEAM) $(AUGMENT_LESS) > $@.out 2>&1
+	java -mx80g edu.stanford.nlp.parser.shiftreduce.ShiftReduceParser -trainTreebank $(ENGLISH_TRAIN) -devTreebank $(ENGLISH_DEV) -serializedPath $@ $(DEFAULT_OPTIONS) -preTag -taggerSerializedFile $(ENGLISH_TAGGER) -tlpp $(ENGLISH_TLPP) $(TRAIN_BEAM) $(AUGMENT_LESS) > $@.out 2>&1
 	java -mx5g edu.stanford.nlp.parser.shiftreduce.ShiftReduceParser $(TEST_ARGS) -testTreebank $(ENGLISH_TEST) -serializedPath $@ -preTag -taggerSerializedFile $(ENGLISH_TAGGER) >> $@.out 2>&1
 
 frenchSR.ser.gz: