Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions capi/include/yara_x.h
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,9 @@ enum YRX_RESULT yrx_compiler_ban_module(struct YRX_COMPILER *compiler,
enum YRX_RESULT yrx_compiler_new_namespace(struct YRX_COMPILER *compiler,
const char *namespace_);

// Collects a string of a hashmap of all current loaded global vars
const char *yrx_compiler_get_globals(struct YRX_COMPILER *compiler);

// Defines a global variable of string type and sets its initial value.
enum YRX_RESULT yrx_compiler_define_global_str(struct YRX_COMPILER *compiler,
const char *ident,
Expand Down
19 changes: 18 additions & 1 deletion capi/src/compiler.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::ffi::{c_char, CStr};
use std::ffi::{c_char, CStr, CString};
use std::mem;
use std::mem::ManuallyDrop;

Expand Down Expand Up @@ -387,6 +387,23 @@ unsafe fn yrx_compiler_define_global<
}
}

/// Collects a string of a hashmap of all current loaded global vars
#[no_mangle]
pub unsafe extern "C" fn yrx_compiler_get_globals(
compiler: *mut YRX_COMPILER,
) -> *const c_char {
let compiler = if let Some(compiler) = compiler.as_mut() {
compiler
} else {
return CString::new("Could not access the compiler").unwrap().into_raw();
};

let globals = compiler.inner.show_globals();
let json = serde_json::to_string(&globals).unwrap();

CString::new(json).unwrap().into_raw()
}

/// Defines a global variable of string type and sets its initial value.
#[no_mangle]
pub unsafe extern "C" fn yrx_compiler_define_global_str(
Expand Down
7 changes: 7 additions & 0 deletions go/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,13 @@ func (c *Compiler) DefineGlobal(ident string, value interface{}) error {
return nil
}

// Returns a String of a hashmap of all of the currently loaded global variables in the compiler
func (c *Compiler) GetGlobals() string {
cStr := C.yrx_compiler_get_globals(c.cCompiler)
defer C.free(unsafe.Pointer(cStr))
return C.GoString(cStr)
}

// Errors that occurred during the compilation, across multiple calls to
// [Compiler.AddSource].
func (c *Compiler) Errors() []CompileError {
Expand Down
29 changes: 29 additions & 0 deletions go/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"testing"
"os"
"io/ioutil"
"encoding/json"

"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -106,6 +107,34 @@ func TestRelaxedReSyntax(t *testing.T) {
assert.Len(t, scanResults.MatchingRules(), 1)
}

func TestGetGlobals(t *testing.T) {
c, err := NewCompiler()
assert.NoError(t, err)

x := map[string]interface{}{"a": map[string]interface{}{"a": "a"}, "b": "d"}

c.DefineGlobal("A", "B")
c.DefineGlobal("B", 1.5)
c.DefineGlobal("C", x)
c.DefineGlobal("D", true)

var globals map[string]interface{}

// Unmarshal the JSON string into the map
err = json.Unmarshal([]byte(c.GetGlobals()), &globals)
assert.NoError(t, err)

assert.Equal(t, globals["A"], "B")
assert.Equal(t, globals["B"], 1.5)
assert.Equal(t, globals["C"], x)
assert.Equal(t, globals["D"], true)

c, err = NewCompiler()
assert.NoError(t, err)

assert.Equal(t, c.GetGlobals(), "{}")
}

func TestConditionOptimization(t *testing.T) {
_, err := Compile(`
rule test { condition: true }`,
Expand Down
6 changes: 6 additions & 0 deletions lib/src/compiler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,12 @@ impl<'a> Compiler<'a> {
Ok(self)
}


/// Shows all current gloval variables of the compiler
pub fn show_globals(&mut self) -> serde_json::Value {
self.global_symbols.borrow_mut().show_globals()
}

/// Creates a new namespace.
///
/// Further calls to [`Compiler::add_source`] will put the rules under the
Expand Down
14 changes: 14 additions & 0 deletions lib/src/symbols/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,20 @@ impl SymbolTable {
self.map.insert(ident.into(), symbol)
}

/// Shows all currently loaded global vars inside the symbol table
pub fn show_globals(&self) -> serde_json::Value {
self.map
.iter()
.filter_map(|(k, v)| {
if let Symbol::Field { type_value, .. } = v {
Some((k.clone(), type_value.value_as_json()))
} else {
None
}
})
.collect()
}

/// Returns true if the symbol table already contains a symbol with
/// the given identifier.
#[inline]
Expand Down
74 changes: 74 additions & 0 deletions lib/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::rc::Rc;
use std::{mem, ptr};
use walrus::ir::InstrSeqType;
use walrus::ValType;
use serde_json::{Map as JsonMap, Number as JsonNumber, Value as JsonValue};

use crate::modules::protos::yara::enum_value_options::Value as EnumValue;
use crate::symbols::{Symbol, SymbolLookup, SymbolTable};
Expand Down Expand Up @@ -617,6 +618,79 @@ impl TypeValue {
constraints: Some(constraints.into()),
}
}

pub fn value_as_json(&self) -> JsonValue {
match self {
Self::Unknown => JsonValue::Null,
Self::Bool { value } => value.extract().cloned().map(JsonValue::Bool).unwrap_or(JsonValue::Null),
Self::Integer { value, .. } => {
if let Some(i) = value.extract().cloned() {
JsonValue::Number(JsonNumber::from(i))
} else {
JsonValue::Null
}
}
Self::Float {value} => {
if let Some(f) = value.extract().cloned() {
JsonNumber::from_f64(f).map(JsonValue::Number).unwrap_or(JsonValue::Null)
} else {
JsonValue::Null
}
}
Self::String {value, ..} => {
if let Some(s) = value.extract().cloned() {
let s_str = String::from_utf8_lossy(s.as_slice()).into_owned();
JsonValue::String(s_str)
} else {
JsonValue::Null
}
}
Self::Regexp(r) => {
if let Some(re) = r {
JsonValue::String(re.as_str().to_string())
} else {
JsonValue::Null
}
}
Self::Struct(s) => {
let mut obj = JsonMap::new();
for (key, field) in s.fields().iter() {
obj.insert(key.clone(), field.type_value.value_as_json());
}
JsonValue::Object(obj)
}
Self::Array(a) => match a.as_ref() {
Array::Integers(items) => JsonValue::Array(items.iter().map(|i| JsonValue::Number(JsonNumber::from(*i))).collect()),
Array::Floats(items) => JsonValue::Array(items.iter().map(|f| JsonNumber::from_f64(*f).map(JsonValue::Number).unwrap_or(JsonValue::Null)).collect()),
Array::Bools(items) => JsonValue::Array(items.iter().map(|b| JsonValue::Bool(*b)).collect()),
Array::Strings(items) => JsonValue::Array(items.iter().map(|s| JsonValue::String(String::from_utf8_lossy(s.as_slice()).into_owned())).collect()),
Array::Structs(items) => JsonValue::Array(items.iter().map(|st| {
let mut obj = JsonMap::new();
for (key, field) in st.fields().iter() {
obj.insert(key.clone(), field.type_value.value_as_json());
}
JsonValue::Object(obj)
}).collect()),
}
Self::Map(m) => match m.as_ref() {
Map::IntegerKeys { map, .. } => {
let mut obj = JsonMap::new();
for (k, v) in map.iter() {
obj.insert(k.to_string(), v.value_as_json());
}
JsonValue::Object(obj)
}
Map::StringKeys { map, .. } => {
let mut obj = JsonMap::new();
for (k, v) in map.iter() {
obj.insert(String::from_utf8_lossy(k.as_slice()).into_owned(), v.value_as_json());
}
JsonValue::Object(obj)
}
}
Self::Func(_) => JsonValue::Null,
}
}
}

impl Display for TypeValue {
Expand Down
5 changes: 5 additions & 0 deletions lib/src/types/structure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,11 @@ impl Struct {
self.protobuf_type_name.as_deref()
}

/// Returns the fields in this structure.
pub fn fields(&self) -> &IndexMap<String, StructField> {
&self.fields
}

/// Adds a new field to the structure.
///
/// The field name may be a dot-separated sequence of field names, like
Expand Down
25 changes: 25 additions & 0 deletions py/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ matches = rules.scan(b'some dummy data')

#![deny(missing_docs)]
use std::borrow::Cow;
use std::collections::HashMap;
use std::io::{Read, Write};
use std::marker::PhantomPinned;
use std::ops::Deref;
Expand Down Expand Up @@ -54,6 +55,22 @@ fn dict_to_json(dict: Bound<PyAny>) -> PyResult<serde_json::Value> {
.map_err(|err| PyValueError::new_err(err.to_string()))
}

fn json_to_dict<'py>(
py: Python<'py>,
json: &serde_json::Value,
) -> PyResult<Py<PyAny>> {
let json_str = serde_json::to_string(json)
.map_err(|err| PyValueError::new_err(err.to_string()))?;

static JSON_LOADS: OnceLock<Py<PyAny>> = OnceLock::new();

let json_loads = JSON_LOADS.get_or_init(|| {
let json_mod = PyModule::import(py, "json").unwrap().unbind();
json_mod.getattr(py, "loads").unwrap()
});
json_loads.call1(py,(json_str,))
}

#[derive(Debug, Clone, Display, EnumString, PartialEq)]
#[strum(ascii_case_insensitive)]
enum SupportedModules {
Expand Down Expand Up @@ -521,6 +538,14 @@ impl Compiler {
Ok(())
}

fn show_globals(
&mut self,
py: Python,
) -> Py<PyAny> {
json_to_dict(py, &self.inner.show_globals()).unwrap()
}


/// Creates a new namespace.
///
/// Further calls to [`Compiler::add_source`] will put the rules under the
Expand Down
13 changes: 13 additions & 0 deletions py/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,19 @@ def test_dict_globals():
matching_rules = scanner.scan(b'').matching_rules
assert len(matching_rules) == 1

def test_show_globals():
compiler = yara_x.Compiler()
compiler.define_global('some_dict', {"foo": "bar"})
compiler.define_global("A", "B")
compiler.define_global("B", 1)
compiler.add_source('rule test {condition: some_dict.foo == "bar"}')
x = compiler.show_globals()

assert(x['some_dict'] == {'foo': 'bar'})
assert(x['A'] == "B")
assert(x['B'] == 1)


def test_namespaces():
compiler = yara_x.Compiler()
compiler.new_namespace('foo')
Expand Down
5 changes: 5 additions & 0 deletions py/yara_x.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ class Compiler:
"""
...

def show_globals(self) -> dict:
r"""
Retrives a dict where the keys are the currently loaded global variable names and the values are their values
"""

def new_namespace(self, namespace: str) -> None:
r"""
Creates a new namespace.
Expand Down