Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/.codecov.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
coverage:
status:
patch: off
project:
default:
target: 89%
target: 89%
39 changes: 38 additions & 1 deletion algorithm.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cose

import (
"crypto"
"fmt"
"strconv"
)

Expand Down Expand Up @@ -36,10 +37,12 @@ const (

// PureEdDSA by RFC 8152.
AlgorithmEd25519 Algorithm = -8

// An invalid/unrecognised algorithm.
AlgorithmInvalid Algorithm = 0
)

// Algorithm represents an IANA algorithm entry in the COSE Algorithms registry.
// Algorithms with string values are not supported.
//
// # See Also
//
Expand Down Expand Up @@ -72,6 +75,35 @@ func (a Algorithm) String() string {
}
}

// MarshalCBOR marshals the Algorithm as a CBOR int.
func (a Algorithm) MarshalCBOR() ([]byte, error) {
return encMode.Marshal(int64(a))
}
Comment on lines +78 to +81
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this method redundant?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not strictly necessary, however it is ther because there is a custom UnmarshalCBOR, and it seems prudent to provide this for symmetry, even if it is just a fall-through to int64 marshalling.


// UnmarshalCBOR populates the Algorithm from the provided CBOR value (must be
// int or tstr).
func (a *Algorithm) UnmarshalCBOR(data []byte) error {
var raw intOrStr

if err := raw.UnmarshalCBOR(data); err != nil {
return fmt.Errorf("invalid algorithm value: %w", err)
}

if raw.IsString() {
v := algorithmFromString(raw.String())
if v == AlgorithmInvalid {
return fmt.Errorf("unknown algorithm value %q", raw.String())
}

*a = v
} else {
v := raw.Int()
*a = Algorithm(v)
}

return nil
}

// hashFunc returns the hash associated with the algorithm supported by this
// library.
func (a Algorithm) hashFunc() crypto.Hash {
Expand Down Expand Up @@ -103,3 +135,8 @@ func computeHash(h crypto.Hash, data []byte) ([]byte, error) {
}
return hh.Sum(nil), nil
}

// NOTE: there are currently no registered string values for an algorithm.
func algorithmFromString(v string) Algorithm {
return AlgorithmInvalid
}
52 changes: 17 additions & 35 deletions algorithm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,41 +16,6 @@ func TestAlgorithm_String(t *testing.T) {
alg Algorithm
want string
}{
{
name: "PS256",
alg: AlgorithmPS256,
want: "PS256",
},
{
name: "PS384",
alg: AlgorithmPS384,
want: "PS384",
},
{
name: "PS512",
alg: AlgorithmPS512,
want: "PS512",
},
{
name: "ES256",
alg: AlgorithmES256,
want: "ES256",
},
{
name: "ES384",
alg: AlgorithmES384,
want: "ES384",
},
{
name: "ES512",
alg: AlgorithmES512,
want: "ES512",
},
{
name: "Ed25519",
alg: AlgorithmEd25519,
want: "EdDSA",
},
{
name: "unknown algorithm",
alg: 0,
Expand All @@ -66,6 +31,23 @@ func TestAlgorithm_String(t *testing.T) {
}
}

func TestAlgorithm_CBOR(t *testing.T) {
tvs2 := []struct {
Data []byte
ExpectedError string
}{
{[]byte{0x63, 0x66, 0x6f, 0x6f}, "unknown algorithm value \"foo\""},
{[]byte{0x40}, "invalid algorithm value: must be int or string, found []uint8"},
}

for _, tv := range tvs2 {
var a Algorithm

err := a.UnmarshalCBOR(tv.Data)
assertEqualError(t, err, tv.ExpectedError)
}
}

func TestAlgorithm_computeHash(t *testing.T) {
// run tests
data := []byte("hello world")
Expand Down
96 changes: 96 additions & 0 deletions common.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package cose

import (
"errors"
"fmt"
)

// intOrStr is a value that can be either an int or a tstr when serialized to
// CBOR.
type intOrStr struct {
intVal int64
strVal string
isString bool
}

func newIntOrStr(v interface{}) *intOrStr {
var ios intOrStr
if err := ios.Set(v); err != nil {
return nil
}
return &ios
}

func (ios intOrStr) Int() int64 {
return ios.intVal
}

func (ios intOrStr) String() string {
if ios.IsString() {
return ios.strVal
}
return fmt.Sprint(ios.intVal)
}

func (ios intOrStr) IsInt() bool {
return !ios.isString
}

func (ios intOrStr) IsString() bool {
return ios.isString
}

func (ios intOrStr) Value() interface{} {
if ios.IsInt() {
return ios.intVal
}

return ios.strVal
}

func (ios *intOrStr) Set(v interface{}) error {
switch t := v.(type) {
case int64:
ios.intVal = t
ios.strVal = ""
ios.isString = false
case int:
ios.intVal = int64(t)
ios.strVal = ""
ios.isString = false
case string:
ios.strVal = t
ios.intVal = 0
ios.isString = true
default:
return fmt.Errorf("must be int or string, found %T", t)
}

return nil
}

// MarshalCBOR returns the encoded CBOR representation of the intOrString, as
// either int or tstr, depending on the value. If no value has been set,
// intOrStr is encoded as a zero-length tstr.
func (ios intOrStr) MarshalCBOR() ([]byte, error) {
if ios.IsInt() {
return encMode.Marshal(ios.intVal)
}

return encMode.Marshal(ios.strVal)
}

// UnmarshalCBOR unmarshals the provided CBOR encoded data (must be an int,
// uint, or tstr).
func (ios *intOrStr) UnmarshalCBOR(data []byte) error {
if len(data) == 0 {
return errors.New("zero length buffer")
}

var val interface{}
if err := decMode.Unmarshal(data, &val); err != nil {
return err
}

return ios.Set(val)
}
140 changes: 140 additions & 0 deletions common_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package cose

import (
"bytes"
"reflect"
"testing"

"github.com/fxamacker/cbor/v2"
)

func Test_intOrStr(t *testing.T) {
ios := newIntOrStr(3)
assertEqual(t, true, ios.IsInt())
assertEqual(t, false, ios.IsString())
assertEqual(t, 3, ios.Int())
assertEqual(t, "3", ios.String())

ios = newIntOrStr("foo")
assertEqual(t, false, ios.IsInt())
assertEqual(t, true, ios.IsString())
assertEqual(t, 0, ios.Int())
assertEqual(t, "foo", ios.String())

ios = newIntOrStr(3.5)
if ios != nil {
t.Errorf("Expected nil, got %v", ios)
}
}

func Test_intOrStr_CBOR(t *testing.T) {
ios := newIntOrStr(3)
data, err := ios.MarshalCBOR()
requireNoError(t, err)
assertEqual(t, []byte{0x03}, data)

ios = &intOrStr{}
err = ios.UnmarshalCBOR(data)
requireNoError(t, err)
assertEqual(t, true, ios.IsInt())
assertEqual(t, 3, ios.Int())

ios = newIntOrStr("foo")
data, err = ios.MarshalCBOR()
requireNoError(t, err)
assertEqual(t, []byte{0x63, 0x66, 0x6f, 0x6f}, data)

ios = &intOrStr{}
err = ios.UnmarshalCBOR(data)
requireNoError(t, err)
assertEqual(t, true, ios.IsString())
assertEqual(t, "foo", ios.String())

// empty value as field
s := struct {
Field1 intOrStr `cbor:"1,keyasint"`
Field2 int `cbor:"2,keyasint"`
}{Field1: intOrStr{}, Field2: 7}

data, err = cbor.Marshal(s)
requireNoError(t, err)
assertEqual(t, []byte{0xa2, 0x1, 0x00, 0x2, 0x7}, data)

ios = &intOrStr{}
data = []byte{0x22}
err = ios.UnmarshalCBOR(data)
requireNoError(t, err)
assertEqual(t, true, ios.IsInt())
assertEqual(t, -3, ios.Int())

data = []byte{}
err = ios.UnmarshalCBOR(data)
assertEqualError(t, err, "zero length buffer")

data = []byte{0x40}
err = ios.UnmarshalCBOR(data)
assertEqualError(t, err, "must be int or string, found []uint8")

data = []byte{0xff, 0xff}
err = ios.UnmarshalCBOR(data)
assertEqualError(t, err, "cbor: unexpected \"break\" code")
}

func requireNoError(t *testing.T, err error) {
if err != nil {
t.Errorf("Unexpected error: %q", err)
t.Fail()
}
}

func assertEqualError(t *testing.T, err error, expected string) {
if err == nil || err.Error() != expected {
t.Errorf("Unexpected error: want %q, got %q", expected, err)
}
}

func assertEqual(t *testing.T, expected, actual interface{}) {
if !objectsAreEqualValues(expected, actual) {
t.Errorf("Unexpected value: want %v, got %v", expected, actual)
}
}

// taken from github.com/stretchr/testify
func objectsAreEqualValues(expected, actual interface{}) bool {
if objectsAreEqual(expected, actual) {
return true
}

actualType := reflect.TypeOf(actual)
if actualType == nil {
return false
}
expectedValue := reflect.ValueOf(expected)
if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) {
// Attempt comparison after type conversion
return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual)
}

return false
}

// taken from github.com/stretchr/testify
func objectsAreEqual(expected, actual interface{}) bool {
if expected == nil || actual == nil {
return expected == actual
}

exp, ok := expected.([]byte)
if !ok {
return reflect.DeepEqual(expected, actual)
}

act, ok := actual.([]byte)
if !ok {
return false
}
if exp == nil || act == nil {
return exp == nil && act == nil
}
return bytes.Equal(exp, act)
}
4 changes: 4 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,8 @@ var (
ErrUnavailableHashFunc = errors.New("hash function is not available")
ErrVerification = errors.New("verification error")
ErrInvalidPubKey = errors.New("invalid public key")
ErrInvalidPrivKey = errors.New("invalid private key")
ErrNotPrivKey = errors.New("not a private key")
ErrSignOpNotSupported = errors.New("sign key_op not supported by key")
ErrVerifyOpNotSupported = errors.New("verify key_op not supported by key")
)
Loading