Skip to content

Commit a579021

Browse files
authored
Add COSE_Key support (#146)
Signed-off-by: Sergei Trofimov <[email protected]>
1 parent 354ac99 commit a579021

File tree

10 files changed

+1774
-53
lines changed

10 files changed

+1774
-53
lines changed

.github/.codecov.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
coverage:
22
status:
3+
patch: off
34
project:
45
default:
5-
target: 89%
6+
target: 89%

algorithm.go

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package cose
22

33
import (
44
"crypto"
5+
"fmt"
56
"strconv"
67
)
78

@@ -36,10 +37,12 @@ const (
3637

3738
// PureEdDSA by RFC 8152.
3839
AlgorithmEd25519 Algorithm = -8
40+
41+
// An invalid/unrecognised algorithm.
42+
AlgorithmInvalid Algorithm = 0
3943
)
4044

4145
// Algorithm represents an IANA algorithm entry in the COSE Algorithms registry.
42-
// Algorithms with string values are not supported.
4346
//
4447
// # See Also
4548
//
@@ -72,6 +75,35 @@ func (a Algorithm) String() string {
7275
}
7376
}
7477

78+
// MarshalCBOR marshals the Algorithm as a CBOR int.
79+
func (a Algorithm) MarshalCBOR() ([]byte, error) {
80+
return encMode.Marshal(int64(a))
81+
}
82+
83+
// UnmarshalCBOR populates the Algorithm from the provided CBOR value (must be
84+
// int or tstr).
85+
func (a *Algorithm) UnmarshalCBOR(data []byte) error {
86+
var raw intOrStr
87+
88+
if err := raw.UnmarshalCBOR(data); err != nil {
89+
return fmt.Errorf("invalid algorithm value: %w", err)
90+
}
91+
92+
if raw.IsString() {
93+
v := algorithmFromString(raw.String())
94+
if v == AlgorithmInvalid {
95+
return fmt.Errorf("unknown algorithm value %q", raw.String())
96+
}
97+
98+
*a = v
99+
} else {
100+
v := raw.Int()
101+
*a = Algorithm(v)
102+
}
103+
104+
return nil
105+
}
106+
75107
// hashFunc returns the hash associated with the algorithm supported by this
76108
// library.
77109
func (a Algorithm) hashFunc() crypto.Hash {
@@ -103,3 +135,8 @@ func computeHash(h crypto.Hash, data []byte) ([]byte, error) {
103135
}
104136
return hh.Sum(nil), nil
105137
}
138+
139+
// NOTE: there are currently no registered string values for an algorithm.
140+
func algorithmFromString(v string) Algorithm {
141+
return AlgorithmInvalid
142+
}

algorithm_test.go

Lines changed: 17 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -16,41 +16,6 @@ func TestAlgorithm_String(t *testing.T) {
1616
alg Algorithm
1717
want string
1818
}{
19-
{
20-
name: "PS256",
21-
alg: AlgorithmPS256,
22-
want: "PS256",
23-
},
24-
{
25-
name: "PS384",
26-
alg: AlgorithmPS384,
27-
want: "PS384",
28-
},
29-
{
30-
name: "PS512",
31-
alg: AlgorithmPS512,
32-
want: "PS512",
33-
},
34-
{
35-
name: "ES256",
36-
alg: AlgorithmES256,
37-
want: "ES256",
38-
},
39-
{
40-
name: "ES384",
41-
alg: AlgorithmES384,
42-
want: "ES384",
43-
},
44-
{
45-
name: "ES512",
46-
alg: AlgorithmES512,
47-
want: "ES512",
48-
},
49-
{
50-
name: "Ed25519",
51-
alg: AlgorithmEd25519,
52-
want: "EdDSA",
53-
},
5419
{
5520
name: "unknown algorithm",
5621
alg: 0,
@@ -66,6 +31,23 @@ func TestAlgorithm_String(t *testing.T) {
6631
}
6732
}
6833

34+
func TestAlgorithm_CBOR(t *testing.T) {
35+
tvs2 := []struct {
36+
Data []byte
37+
ExpectedError string
38+
}{
39+
{[]byte{0x63, 0x66, 0x6f, 0x6f}, "unknown algorithm value \"foo\""},
40+
{[]byte{0x40}, "invalid algorithm value: must be int or string, found []uint8"},
41+
}
42+
43+
for _, tv := range tvs2 {
44+
var a Algorithm
45+
46+
err := a.UnmarshalCBOR(tv.Data)
47+
assertEqualError(t, err, tv.ExpectedError)
48+
}
49+
}
50+
6951
func TestAlgorithm_computeHash(t *testing.T) {
7052
// run tests
7153
data := []byte("hello world")

common.go

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
package cose
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
)
7+
8+
// intOrStr is a value that can be either an int or a tstr when serialized to
9+
// CBOR.
10+
type intOrStr struct {
11+
intVal int64
12+
strVal string
13+
isString bool
14+
}
15+
16+
func newIntOrStr(v interface{}) *intOrStr {
17+
var ios intOrStr
18+
if err := ios.Set(v); err != nil {
19+
return nil
20+
}
21+
return &ios
22+
}
23+
24+
func (ios intOrStr) Int() int64 {
25+
return ios.intVal
26+
}
27+
28+
func (ios intOrStr) String() string {
29+
if ios.IsString() {
30+
return ios.strVal
31+
}
32+
return fmt.Sprint(ios.intVal)
33+
}
34+
35+
func (ios intOrStr) IsInt() bool {
36+
return !ios.isString
37+
}
38+
39+
func (ios intOrStr) IsString() bool {
40+
return ios.isString
41+
}
42+
43+
func (ios intOrStr) Value() interface{} {
44+
if ios.IsInt() {
45+
return ios.intVal
46+
}
47+
48+
return ios.strVal
49+
}
50+
51+
func (ios *intOrStr) Set(v interface{}) error {
52+
switch t := v.(type) {
53+
case int64:
54+
ios.intVal = t
55+
ios.strVal = ""
56+
ios.isString = false
57+
case int:
58+
ios.intVal = int64(t)
59+
ios.strVal = ""
60+
ios.isString = false
61+
case string:
62+
ios.strVal = t
63+
ios.intVal = 0
64+
ios.isString = true
65+
default:
66+
return fmt.Errorf("must be int or string, found %T", t)
67+
}
68+
69+
return nil
70+
}
71+
72+
// MarshalCBOR returns the encoded CBOR representation of the intOrString, as
73+
// either int or tstr, depending on the value. If no value has been set,
74+
// intOrStr is encoded as a zero-length tstr.
75+
func (ios intOrStr) MarshalCBOR() ([]byte, error) {
76+
if ios.IsInt() {
77+
return encMode.Marshal(ios.intVal)
78+
}
79+
80+
return encMode.Marshal(ios.strVal)
81+
}
82+
83+
// UnmarshalCBOR unmarshals the provided CBOR encoded data (must be an int,
84+
// uint, or tstr).
85+
func (ios *intOrStr) UnmarshalCBOR(data []byte) error {
86+
if len(data) == 0 {
87+
return errors.New("zero length buffer")
88+
}
89+
90+
var val interface{}
91+
if err := decMode.Unmarshal(data, &val); err != nil {
92+
return err
93+
}
94+
95+
return ios.Set(val)
96+
}

common_test.go

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
package cose
2+
3+
import (
4+
"bytes"
5+
"reflect"
6+
"testing"
7+
8+
"github.com/fxamacker/cbor/v2"
9+
)
10+
11+
func Test_intOrStr(t *testing.T) {
12+
ios := newIntOrStr(3)
13+
assertEqual(t, true, ios.IsInt())
14+
assertEqual(t, false, ios.IsString())
15+
assertEqual(t, 3, ios.Int())
16+
assertEqual(t, "3", ios.String())
17+
18+
ios = newIntOrStr("foo")
19+
assertEqual(t, false, ios.IsInt())
20+
assertEqual(t, true, ios.IsString())
21+
assertEqual(t, 0, ios.Int())
22+
assertEqual(t, "foo", ios.String())
23+
24+
ios = newIntOrStr(3.5)
25+
if ios != nil {
26+
t.Errorf("Expected nil, got %v", ios)
27+
}
28+
}
29+
30+
func Test_intOrStr_CBOR(t *testing.T) {
31+
ios := newIntOrStr(3)
32+
data, err := ios.MarshalCBOR()
33+
requireNoError(t, err)
34+
assertEqual(t, []byte{0x03}, data)
35+
36+
ios = &intOrStr{}
37+
err = ios.UnmarshalCBOR(data)
38+
requireNoError(t, err)
39+
assertEqual(t, true, ios.IsInt())
40+
assertEqual(t, 3, ios.Int())
41+
42+
ios = newIntOrStr("foo")
43+
data, err = ios.MarshalCBOR()
44+
requireNoError(t, err)
45+
assertEqual(t, []byte{0x63, 0x66, 0x6f, 0x6f}, data)
46+
47+
ios = &intOrStr{}
48+
err = ios.UnmarshalCBOR(data)
49+
requireNoError(t, err)
50+
assertEqual(t, true, ios.IsString())
51+
assertEqual(t, "foo", ios.String())
52+
53+
// empty value as field
54+
s := struct {
55+
Field1 intOrStr `cbor:"1,keyasint"`
56+
Field2 int `cbor:"2,keyasint"`
57+
}{Field1: intOrStr{}, Field2: 7}
58+
59+
data, err = cbor.Marshal(s)
60+
requireNoError(t, err)
61+
assertEqual(t, []byte{0xa2, 0x1, 0x00, 0x2, 0x7}, data)
62+
63+
ios = &intOrStr{}
64+
data = []byte{0x22}
65+
err = ios.UnmarshalCBOR(data)
66+
requireNoError(t, err)
67+
assertEqual(t, true, ios.IsInt())
68+
assertEqual(t, -3, ios.Int())
69+
70+
data = []byte{}
71+
err = ios.UnmarshalCBOR(data)
72+
assertEqualError(t, err, "zero length buffer")
73+
74+
data = []byte{0x40}
75+
err = ios.UnmarshalCBOR(data)
76+
assertEqualError(t, err, "must be int or string, found []uint8")
77+
78+
data = []byte{0xff, 0xff}
79+
err = ios.UnmarshalCBOR(data)
80+
assertEqualError(t, err, "cbor: unexpected \"break\" code")
81+
}
82+
83+
func requireNoError(t *testing.T, err error) {
84+
if err != nil {
85+
t.Errorf("Unexpected error: %q", err)
86+
t.Fail()
87+
}
88+
}
89+
90+
func assertEqualError(t *testing.T, err error, expected string) {
91+
if err == nil || err.Error() != expected {
92+
t.Errorf("Unexpected error: want %q, got %q", expected, err)
93+
}
94+
}
95+
96+
func assertEqual(t *testing.T, expected, actual interface{}) {
97+
if !objectsAreEqualValues(expected, actual) {
98+
t.Errorf("Unexpected value: want %v, got %v", expected, actual)
99+
}
100+
}
101+
102+
// taken from github.com/stretchr/testify
103+
func objectsAreEqualValues(expected, actual interface{}) bool {
104+
if objectsAreEqual(expected, actual) {
105+
return true
106+
}
107+
108+
actualType := reflect.TypeOf(actual)
109+
if actualType == nil {
110+
return false
111+
}
112+
expectedValue := reflect.ValueOf(expected)
113+
if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) {
114+
// Attempt comparison after type conversion
115+
return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual)
116+
}
117+
118+
return false
119+
}
120+
121+
// taken from github.com/stretchr/testify
122+
func objectsAreEqual(expected, actual interface{}) bool {
123+
if expected == nil || actual == nil {
124+
return expected == actual
125+
}
126+
127+
exp, ok := expected.([]byte)
128+
if !ok {
129+
return reflect.DeepEqual(expected, actual)
130+
}
131+
132+
act, ok := actual.([]byte)
133+
if !ok {
134+
return false
135+
}
136+
if exp == nil || act == nil {
137+
return exp == nil && act == nil
138+
}
139+
return bytes.Equal(exp, act)
140+
}

errors.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,8 @@ var (
1414
ErrUnavailableHashFunc = errors.New("hash function is not available")
1515
ErrVerification = errors.New("verification error")
1616
ErrInvalidPubKey = errors.New("invalid public key")
17+
ErrInvalidPrivKey = errors.New("invalid private key")
18+
ErrNotPrivKey = errors.New("not a private key")
19+
ErrSignOpNotSupported = errors.New("sign key_op not supported by key")
20+
ErrVerifyOpNotSupported = errors.New("verify key_op not supported by key")
1721
)

0 commit comments

Comments
 (0)