Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit d9a12ff

Browse files
authored
Merge pull request #650 from juanjux/base64-functions
Added from_base64 and to_base64
2 parents 89c0734 + 689788f commit d9a12ff

File tree

6 files changed

+218
-0
lines changed

6 files changed

+218
-0
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ We support and actively test against certain third-party clients to ensure compa
105105
|`SUBSTR(str, pos, [len])`|Return a substring from the provided string starting at `pos` with a length of `len` characters. If no `len` is provided, all characters from `pos` until the end will be taken.|
106106
|`SUBSTRING(str, pos, [len])`|Return a substring from the provided string starting at `pos` with a length of `len` characters. If no `len` is provided, all characters from `pos` until the end will be taken.|
107107
|`SUM(expr)`|Returns the sum of expr in all rows.|
108+
|`TO_BASE64(str)`|Encodes the string str in base64 format.|
109+
|`FROM_BASE64(str)`|Decodes the base64-encoded string str.|
108110
|`TRIM(str)`|Returns the string str with all spaces removed.|
109111
|`UPPER(str)`|Returns the string str with all characters in upper case.|
110112
|`WEEKDAY(date)`|Returns the weekday of the given date.|

SUPPORTED.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@
110110
- LOG2
111111
- LOG10
112112
- SLEEP
113+
- TO_BASE64
114+
- FROM_BASE64
113115

114116
## Time functions
115117
- DAY

engine_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -901,6 +901,14 @@ var queries = []struct {
901901
"SELECT SLEEP(0.5)",
902902
[]sql.Row{{int(0)}},
903903
},
904+
{
905+
"SELECT TO_BASE64('foo')",
906+
[]sql.Row{{string("Zm9v")}},
907+
},
908+
{
909+
"SELECT FROM_BASE64('YmFy')",
910+
[]sql.Row{{string("bar")}},
911+
},
904912
}
905913

906914
func TestQueries(t *testing.T) {

sql/expression/function/registry.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,4 +76,6 @@ var Defaults = []sql.Function{
7676
sql.Function2{Name: "nullif", Fn: NewNullIf},
7777
sql.Function0{Name: "now", Fn: NewNow},
7878
sql.Function1{Name: "sleep", Fn: NewSleep},
79+
sql.Function1{Name: "to_base64", Fn: NewToBase64},
80+
sql.Function1{Name: "from_base64", Fn: NewFromBase64},
7981
}
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
package function
2+
3+
import (
4+
"encoding/base64"
5+
"fmt"
6+
"reflect"
7+
"strings"
8+
9+
"gopkg.in/src-d/go-mysql-server.v0/sql"
10+
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
11+
)
12+
13+
// ToBase64 is a function to encode a string to the Base64 format
14+
// using the same dialect that MySQL's TO_BASE64 uses
15+
type ToBase64 struct {
16+
expression.UnaryExpression
17+
}
18+
19+
// NewToBase64 creates a new ToBase64 expression.
20+
func NewToBase64(e sql.Expression) sql.Expression {
21+
return &ToBase64{expression.UnaryExpression{Child: e}}
22+
}
23+
24+
// Eval implements the Expression interface.
25+
func (t *ToBase64) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
26+
str, err := t.Child.Eval(ctx, row)
27+
28+
if err != nil {
29+
return nil, err
30+
}
31+
32+
if str == nil {
33+
return nil, nil
34+
}
35+
36+
str, err = sql.Text.Convert(str)
37+
if err != nil {
38+
return nil, sql.ErrInvalidType.New(reflect.TypeOf(str))
39+
}
40+
41+
encoded := base64.StdEncoding.EncodeToString([]byte(str.(string)))
42+
43+
lenEncoded := len(encoded)
44+
if lenEncoded <= 76 {
45+
return encoded, nil
46+
}
47+
48+
// Split into max 76 chars lines
49+
var out strings.Builder
50+
start := 0
51+
end := 76
52+
for {
53+
out.WriteString(encoded[start:end] + "\n")
54+
start += 76
55+
end += 76
56+
if end >= lenEncoded {
57+
out.WriteString(encoded[start:lenEncoded])
58+
break
59+
}
60+
}
61+
62+
return out.String(), nil
63+
}
64+
65+
// String implements the Stringer interface.
66+
func (t *ToBase64) String() string {
67+
return fmt.Sprintf("TO_BASE64(%s)", t.Child)
68+
}
69+
70+
// IsNullable implements the Expression interface.
71+
func (t *ToBase64) IsNullable() bool {
72+
return t.Child.IsNullable()
73+
}
74+
75+
// TransformUp implements the Expression interface.
76+
func (t *ToBase64) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) {
77+
child, err := t.Child.TransformUp(f)
78+
if err != nil {
79+
return nil, err
80+
}
81+
return f(NewToBase64(child))
82+
}
83+
84+
// Type implements the Expression interface.
85+
func (t *ToBase64) Type() sql.Type {
86+
return sql.Text
87+
}
88+
89+
90+
// FromBase64 is a function to decode a Base64-formatted string
91+
// using the same dialect that MySQL's FROM_BASE64 uses
92+
type FromBase64 struct {
93+
expression.UnaryExpression
94+
}
95+
96+
// NewFromBase64 creates a new FromBase64 expression.
97+
func NewFromBase64(e sql.Expression) sql.Expression {
98+
return &FromBase64{expression.UnaryExpression{Child: e}}
99+
}
100+
101+
// Eval implements the Expression interface.
102+
func (t *FromBase64) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
103+
str, err := t.Child.Eval(ctx, row)
104+
105+
if err != nil {
106+
return nil, err
107+
}
108+
109+
if str == nil {
110+
return nil, nil
111+
}
112+
113+
str, err = sql.Text.Convert(str)
114+
if err != nil {
115+
return nil, sql.ErrInvalidType.New(reflect.TypeOf(str))
116+
}
117+
118+
decoded, err := base64.StdEncoding.DecodeString(str.(string))
119+
if err != nil {
120+
return nil, err
121+
}
122+
123+
return string(decoded), nil
124+
}
125+
126+
// String implements the Stringer interface.
127+
func (t *FromBase64) String() string {
128+
return fmt.Sprintf("FROM_BASE64(%s)", t.Child)
129+
}
130+
131+
// IsNullable implements the Expression interface.
132+
func (t *FromBase64) IsNullable() bool {
133+
return t.Child.IsNullable()
134+
}
135+
136+
// TransformUp implements the Expression interface.
137+
func (t *FromBase64) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) {
138+
child, err := t.Child.TransformUp(f)
139+
if err != nil {
140+
return nil, err
141+
}
142+
return f(NewFromBase64(child))
143+
}
144+
145+
// Type implements the Expression interface.
146+
func (t *FromBase64) Type() sql.Type {
147+
return sql.Text
148+
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package function
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
"gopkg.in/src-d/go-mysql-server.v0/sql"
8+
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
9+
)
10+
11+
func TestBase64(t *testing.T) {
12+
fTo := NewToBase64(expression.NewGetField(0, sql.Text, "", false))
13+
fFrom := NewFromBase64(expression.NewGetField(0, sql.Text, "", false))
14+
15+
testCases := []struct {
16+
name string
17+
row sql.Row
18+
expected interface{}
19+
err bool
20+
}{
21+
// Use a MySQL server to get expected values if updating/adding to this!
22+
{"null input", sql.NewRow(nil), nil, false},
23+
{"single_line", sql.NewRow("foo"), string("Zm9v"), false},
24+
{"multi_line", sql.NewRow(
25+
"Gallia est omnis divisa in partes tres, quarum unam " +
26+
"incolunt Belgae, aliam Aquitani, tertiam qui ipsorum lingua Celtae, " +
27+
"nostra Galli appellantur"),
28+
"R2FsbGlhIGVzdCBvbW5pcyBkaXZpc2EgaW4gcGFydGVzIHRyZXMsIHF1YXJ1bSB1bmFtIGluY29s\n" +
29+
"dW50IEJlbGdhZSwgYWxpYW0gQXF1aXRhbmksIHRlcnRpYW0gcXVpIGlwc29ydW0gbGluZ3VhIENl\n" +
30+
"bHRhZSwgbm9zdHJhIEdhbGxpIGFwcGVsbGFudHVy", false},
31+
{"empty_input", sql.NewRow(""), string(""), false},
32+
{"symbols", sql.NewRow("!@#$% %^&*()_+\r\n\t{};"), string("IUAjJCUgJV4mKigpXysNCgl7fTs="),
33+
false},
34+
}
35+
36+
for _, tt := range testCases {
37+
t.Run(tt.name, func(t *testing.T) {
38+
t.Helper()
39+
require := require.New(t)
40+
ctx := sql.NewEmptyContext()
41+
v, err := fTo.Eval(ctx, tt.row)
42+
43+
if tt.err {
44+
require.Error(err)
45+
} else {
46+
require.NoError(err)
47+
require.Equal(tt.expected, v)
48+
49+
ctx = sql.NewEmptyContext()
50+
v2, err := fFrom.Eval(ctx, sql.NewRow(v))
51+
require.NoError(err)
52+
require.Equal(sql.NewRow(v2), tt.row)
53+
}
54+
})
55+
}
56+
}

0 commit comments

Comments
 (0)