Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions internal/handshake/block.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package handshake

import (
"bytes"
"encoding/binary"
"io"
)
Expand All @@ -16,13 +17,28 @@ type Block struct {
Transactions []Transaction
}

func (b *Block) Decode(r io.Reader) error {
func NewBlockFromReader(r io.Reader) (*Block, error) {
// Read entire input into a bytes.Buffer
tmpData, err := io.ReadAll(r)
if err != nil {
return nil, err
}
buf := bytes.NewBuffer(tmpData)
// Decode block
var tmpBlock Block
if err := tmpBlock.Decode(buf); err != nil {
return nil, err
}
return &tmpBlock, err
}

func (b *Block) Decode(r *bytes.Buffer) error {
// Decode header
if err := b.Header.Decode(r); err != nil {
return err
}
// Transactions
txCount, err := binary.ReadUvarint(r.(io.ByteReader))
txCount, err := binary.ReadUvarint(r)
if err != nil {
return err
}
Expand Down
3 changes: 1 addition & 2 deletions internal/handshake/block_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,7 @@ func TestDecodeHandshakeBlock(t *testing.T) {
t.Fatalf("unexpected error: %s", err)
}
br := bytes.NewReader(testBlockBytes)
var block handshake.Block
err = block.Decode(br)
block, err := handshake.NewBlockFromReader(br)
if err != nil {
t.Fatalf("unexpected error deserializing block: %s", err)
}
Expand Down
84 changes: 71 additions & 13 deletions internal/handshake/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,49 @@
package handshake

import (
"bytes"
"encoding/binary"
"errors"
"io"

"golang.org/x/crypto/blake2b"
)

type Transaction struct {
Version uint32
Inputs []TransactionInput
Outputs []TransactionOutput
LockTime uint32
Version uint32
Inputs []TransactionInput
Outputs []TransactionOutput
LockTime uint32
hash []byte
witnessHash []byte
}

func (t *Transaction) Decode(r io.Reader) error {
var err error
if err = binary.Read(r, binary.LittleEndian, &t.Version); err != nil {
func NewTransactionFromReader(r io.Reader) (*Transaction, error) {
// Read entire input into a bytes.Buffer
tmpData, err := io.ReadAll(r)
if err != nil {
return nil, err
}
buf := bytes.NewBuffer(tmpData)
// Decode TX
var tmpTransaction Transaction
if err := tmpTransaction.Decode(buf); err != nil {
return nil, err
}
return &tmpTransaction, err
}

func (t *Transaction) Decode(r *bytes.Buffer) error {
// Save original buffer
// This is needed to capture TX bytes
origData := make([]byte, r.Len())
copy(origData, r.Bytes())
// Version
if err := binary.Read(r, binary.LittleEndian, &t.Version); err != nil {
return err
}
// Inputs
inCount, err := binary.ReadUvarint(r.(io.ByteReader))
inCount, err := binary.ReadUvarint(r)
if err != nil {
return err
}
Expand All @@ -37,7 +61,7 @@ func (t *Transaction) Decode(r io.Reader) error {
t.Inputs = append(t.Inputs, tmpInput)
}
// Outputs
outCount, err := binary.ReadUvarint(r.(io.ByteReader))
outCount, err := binary.ReadUvarint(r)
if err != nil {
return err
}
Expand All @@ -52,22 +76,56 @@ func (t *Transaction) Decode(r io.Reader) error {
if err := binary.Read(r, binary.LittleEndian, &t.LockTime); err != nil {
return err
}
// Capture original TX bytes
txBytes := origData[:len(origData)-r.Len()]
// Generate TX hash
tmpHash := blake2b.Sum256(txBytes)
t.hash = make([]byte, len(tmpHash))
copy(t.hash, tmpHash[:])
// Save remaining data
// This is needed for capturing the witness data bytes
origData = make([]byte, r.Len())
copy(origData, r.Bytes())
// Witnesses
for i := uint64(0); i < inCount; i++ {
if err := t.Inputs[i].DecodeWitness(r); err != nil {
return err
}
}
// Capture original bytes for witness data
witnessDataBytes := origData[:len(origData)-r.Len()]
// Generate witness data hash
witnessDataHash := blake2b.Sum256(witnessDataBytes)
// Generate TX hash with witness data
h, err := blake2b.New256(nil)
if err != nil {
return err
}
h.Write(t.hash)
h.Write(witnessDataHash[:])
t.witnessHash = h.Sum(nil)
return nil
}

func (t *Transaction) Hash() []byte {
ret := make([]byte, len(t.hash))
copy(ret, t.hash)
return ret
}

func (t *Transaction) WitnessHash() []byte {
ret := make([]byte, len(t.witnessHash))
copy(ret, t.witnessHash)
return ret
}

type TransactionInput struct {
PrevOutpoint Outpoint
Sequence uint32
Witness [][]byte
}

func (i *TransactionInput) Decode(r io.Reader) error {
func (i *TransactionInput) Decode(r *bytes.Buffer) error {
if err := i.PrevOutpoint.Decode(r); err != nil {
return err
}
Expand Down Expand Up @@ -102,7 +160,7 @@ type TransactionOutput struct {
Covenant GenericCovenant
}

func (o *TransactionOutput) Decode(r io.Reader) error {
func (o *TransactionOutput) Decode(r *bytes.Buffer) error {
if err := binary.Read(r, binary.LittleEndian, &o.Value); err != nil {
return err
}
Expand All @@ -120,7 +178,7 @@ type Outpoint struct {
Index uint32
}

func (o *Outpoint) Decode(r io.Reader) error {
func (o *Outpoint) Decode(r *bytes.Buffer) error {
return binary.Read(r, binary.LittleEndian, o)
}

Expand All @@ -129,7 +187,7 @@ type Address struct {
Hash []byte
}

func (a *Address) Decode(r io.Reader) error {
func (a *Address) Decode(r *bytes.Buffer) error {
if err := binary.Read(r, binary.LittleEndian, &a.Version); err != nil {
return err
}
Expand Down
Loading