From 848fd6225ffafcf017e2190ff83bf49bf2df0efc Mon Sep 17 00:00:00 2001 From: Aurora Gaffney Date: Thu, 17 Apr 2025 13:55:36 -0400 Subject: [PATCH] feat: generate Handshake TX hash Signed-off-by: Aurora Gaffney --- internal/handshake/block.go | 20 +++++++- internal/handshake/block_test.go | 3 +- internal/handshake/transaction.go | 84 ++++++++++++++++++++++++++----- 3 files changed, 90 insertions(+), 17 deletions(-) diff --git a/internal/handshake/block.go b/internal/handshake/block.go index 825d5a4..2bbbe9f 100644 --- a/internal/handshake/block.go +++ b/internal/handshake/block.go @@ -7,6 +7,7 @@ package handshake import ( + "bytes" "encoding/binary" "io" ) @@ -16,13 +17,28 @@ type Block struct { Transactions []Transaction } -func (b *Block) Decode(r io.Reader) error { +func NewBlockFromReader(r io.Reader) (*Block, error) { + // Read entire input into a bytes.Buffer + tmpData, err := io.ReadAll(r) + if err != nil { + return nil, err + } + buf := bytes.NewBuffer(tmpData) + // Decode block + var tmpBlock Block + if err := tmpBlock.Decode(buf); err != nil { + return nil, err + } + return &tmpBlock, err +} + +func (b *Block) Decode(r *bytes.Buffer) error { // Decode header if err := b.Header.Decode(r); err != nil { return err } // Transactions - txCount, err := binary.ReadUvarint(r.(io.ByteReader)) + txCount, err := binary.ReadUvarint(r) if err != nil { return err } diff --git a/internal/handshake/block_test.go b/internal/handshake/block_test.go index 10a9b07..6830582 100644 --- a/internal/handshake/block_test.go +++ b/internal/handshake/block_test.go @@ -120,8 +120,7 @@ func TestDecodeHandshakeBlock(t *testing.T) { t.Fatalf("unexpected error: %s", err) } br := bytes.NewReader(testBlockBytes) - var block handshake.Block - err = block.Decode(br) + block, err := handshake.NewBlockFromReader(br) if err != nil { t.Fatalf("unexpected error deserializing block: %s", err) } diff --git a/internal/handshake/transaction.go b/internal/handshake/transaction.go index 534dc11..924a909 100644 --- a/internal/handshake/transaction.go +++ b/internal/handshake/transaction.go @@ -7,25 +7,49 @@ package handshake import ( + "bytes" "encoding/binary" "errors" "io" + + "golang.org/x/crypto/blake2b" ) type Transaction struct { - Version uint32 - Inputs []TransactionInput - Outputs []TransactionOutput - LockTime uint32 + Version uint32 + Inputs []TransactionInput + Outputs []TransactionOutput + LockTime uint32 + hash []byte + witnessHash []byte } -func (t *Transaction) Decode(r io.Reader) error { - var err error - if err = binary.Read(r, binary.LittleEndian, &t.Version); err != nil { +func NewTransactionFromReader(r io.Reader) (*Transaction, error) { + // Read entire input into a bytes.Buffer + tmpData, err := io.ReadAll(r) + if err != nil { + return nil, err + } + buf := bytes.NewBuffer(tmpData) + // Decode TX + var tmpTransaction Transaction + if err := tmpTransaction.Decode(buf); err != nil { + return nil, err + } + return &tmpTransaction, err +} + +func (t *Transaction) Decode(r *bytes.Buffer) error { + // Save original buffer + // This is needed to capture TX bytes + origData := make([]byte, r.Len()) + copy(origData, r.Bytes()) + // Version + if err := binary.Read(r, binary.LittleEndian, &t.Version); err != nil { return err } // Inputs - inCount, err := binary.ReadUvarint(r.(io.ByteReader)) + inCount, err := binary.ReadUvarint(r) if err != nil { return err } @@ -37,7 +61,7 @@ func (t *Transaction) Decode(r io.Reader) error { t.Inputs = append(t.Inputs, tmpInput) } // Outputs - outCount, err := binary.ReadUvarint(r.(io.ByteReader)) + outCount, err := binary.ReadUvarint(r) if err != nil { return err } @@ -52,22 +76,56 @@ func (t *Transaction) Decode(r io.Reader) error { if err := binary.Read(r, binary.LittleEndian, &t.LockTime); err != nil { return err } + // Capture original TX bytes + txBytes := origData[:len(origData)-r.Len()] + // Generate TX hash + tmpHash := blake2b.Sum256(txBytes) + t.hash = make([]byte, len(tmpHash)) + copy(t.hash, tmpHash[:]) + // Save remaining data + // This is needed for capturing the witness data bytes + origData = make([]byte, r.Len()) + copy(origData, r.Bytes()) // Witnesses for i := uint64(0); i < inCount; i++ { if err := t.Inputs[i].DecodeWitness(r); err != nil { return err } } + // Capture original bytes for witness data + witnessDataBytes := origData[:len(origData)-r.Len()] + // Generate witness data hash + witnessDataHash := blake2b.Sum256(witnessDataBytes) + // Generate TX hash with witness data + h, err := blake2b.New256(nil) + if err != nil { + return err + } + h.Write(t.hash) + h.Write(witnessDataHash[:]) + t.witnessHash = h.Sum(nil) return nil } +func (t *Transaction) Hash() []byte { + ret := make([]byte, len(t.hash)) + copy(ret, t.hash) + return ret +} + +func (t *Transaction) WitnessHash() []byte { + ret := make([]byte, len(t.witnessHash)) + copy(ret, t.witnessHash) + return ret +} + type TransactionInput struct { PrevOutpoint Outpoint Sequence uint32 Witness [][]byte } -func (i *TransactionInput) Decode(r io.Reader) error { +func (i *TransactionInput) Decode(r *bytes.Buffer) error { if err := i.PrevOutpoint.Decode(r); err != nil { return err } @@ -102,7 +160,7 @@ type TransactionOutput struct { Covenant GenericCovenant } -func (o *TransactionOutput) Decode(r io.Reader) error { +func (o *TransactionOutput) Decode(r *bytes.Buffer) error { if err := binary.Read(r, binary.LittleEndian, &o.Value); err != nil { return err } @@ -120,7 +178,7 @@ type Outpoint struct { Index uint32 } -func (o *Outpoint) Decode(r io.Reader) error { +func (o *Outpoint) Decode(r *bytes.Buffer) error { return binary.Read(r, binary.LittleEndian, o) } @@ -129,7 +187,7 @@ type Address struct { Hash []byte } -func (a *Address) Decode(r io.Reader) error { +func (a *Address) Decode(r *bytes.Buffer) error { if err := binary.Read(r, binary.LittleEndian, &a.Version); err != nil { return err }