Skip to content

Commit 49ee1cd

Browse files
committed
Experiment tree-sitter imperative
1 parent 310aec4 commit 49ee1cd

File tree

2 files changed

+248
-0
lines changed

2 files changed

+248
-0
lines changed
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
from typing import Any, Dict, Optional, Tuple
2+
from tree_sitter import Node, Tree
3+
from utils import parse_code, traverse_tree, rewrite, SOURCE_CODE
4+
5+
6+
# Maps Java SparkConf setter names to the method names used when the call
# chain is rebuilt on `new SparkSession.builder()` (see
# build_spark_session_builder). Also serves as the filter for which
# chained / standalone setter calls are eligible for rewriting.
relevant_builder_method_names_mapping = {
    "setAppName": "appName",
    "setMaster": "master",
    "set": "config",
    "setAll": "all",
    "setIfMissing": "ifMissing",
    "setJars": "jars",
    "setExecutorEnv": "executorEnv",
    "setSparkHome": "sparkHome",
}
16+
17+
18+
def get_initializer_named(tree: Tree, name: str):
    """Return the first `new <name>(...)` node in *tree*, or None.

    Walks the whole tree and matches object-creation expressions whose
    type text equals *name* (e.g. "SparkConf", "JavaSparkContext").
    """
    for candidate in traverse_tree(tree):
        if candidate.type != "object_creation_expression":
            continue
        type_node = candidate.child_by_field_name("type")
        if type_node and type_node.text.decode() == name:
            return candidate
    return None
24+
25+
26+
def get_enclosing_variable_declaration_name_type(
    node: Node,
) -> Tuple[Node | None, str | None, str | None]:
    """Climb from *node* to its enclosing local variable declaration.

    Returns (declaration_node, variable_name, declared_type). Any slot is
    None when the corresponding ancestor or named field is absent — e.g.
    for a plain assignment like `sc1 = new JavaSparkContext(conf1);` the
    declaration node and type are both None.
    """
    declaration, var_name, var_type = None, None, None
    parent = node.parent
    if parent and parent.type == "variable_declarator":
        name_node = parent.child_by_field_name("name")
        if name_node:
            var_name = name_node.text.decode()
        grandparent = parent.parent
        if grandparent and grandparent.type == "local_variable_declaration":
            type_node = grandparent.child_by_field_name("type")
            if type_node:
                var_type = type_node.text.decode()
            declaration = grandparent
    return declaration, var_name, var_type
43+
44+
45+
def all_enclosing_method_invocations(node: Node) -> list[Node]:
    """Collect the chain of `method_invocation` ancestors of *node*.

    Ordered innermost-first; empty when *node*'s parent is not a method
    invocation. Used to gather the `.setX(...)` calls chained onto a
    `new SparkConf()` expression.
    """
    chain = []
    current = node
    while current.parent is not None and current.parent.type == "method_invocation":
        chain.append(current.parent)
        current = current.parent
    return chain
50+
51+
52+
def build_spark_session_builder(builder_mappings: list[tuple[str, Node]]):
    """Assemble the replacement SparkSession builder expression.

    Starts from a builder that pre-sets the legacy untyped-Scala-UDF flag
    and appends one `.method(args)` segment per (builder-method-name,
    argument-list-node) pair, reusing the original argument text verbatim.
    """
    segments = [
        'new SparkSession.builder().config("spark.sql.legacy.allowUntypedScalaUDF", "true")'
    ]
    for method_name, argument_node in builder_mappings:
        segments.append(f".{method_name}{argument_node.text.decode()}")
    return "".join(segments)
57+
58+
59+
def update_spark_conf_init(
    tree: Tree, src_code: str, state: Dict[str, Any]
) -> Tuple[Tree, str]:
    """Rewrite a `SparkConf conf = new SparkConf()...` declaration into a
    `SparkSession` builder declaration ending in `.getOrCreate();`.

    On success the declared variable name is stored in
    state["spark_conf_name"] so the later passes (context init, setter
    rewrites) can target the same variable. Returns (tree, src_code)
    unchanged when no SparkConf initializer or no enclosing variable
    declaration is found.
    """
    conf_init = get_initializer_named(tree, "SparkConf")
    if not conf_init:
        print("No SparkConf initializer found")
        return tree, src_code

    call_chain = all_enclosing_method_invocations(conf_init)

    # Translate each recognised chained setter into its builder
    # equivalent, carrying the original argument-list node along.
    builder_mappings = []
    for invocation in call_chain:
        name_node = invocation.child_by_field_name("name")
        if not name_node:
            continue
        method_name = name_node.text.decode()
        if method_name in relevant_builder_method_names_mapping.keys():
            builder_mappings.append(
                (
                    relevant_builder_method_names_mapping[method_name],
                    invocation.child_by_field_name("arguments"),
                )
            )

    builder_expr = build_spark_session_builder(builder_mappings)

    # Replace the whole chained expression when one exists, otherwise
    # just the bare `new SparkConf()` node.
    outermost = call_chain[-1] if call_chain else conf_init

    decl_node, var_name, var_type = get_enclosing_variable_declaration_name_type(
        outermost
    )
    if not (decl_node and var_name and var_type):
        print("Not in a variable declaration")
        return tree, src_code

    state["spark_conf_name"] = var_name
    return rewrite(
        decl_node,
        src_code,
        f"SparkSession {var_name} = {builder_expr}.getOrCreate();",
    )
108+
109+
110+
def update_spark_context_init(
    tree: Tree, source_code: str, state: Dict[str, Any]
):
    """Rewrite a `new JavaSparkContext(conf)` initializer to
    `<conf>.sparkContext()` on the migrated SparkSession variable.

    Requires state["spark_conf_name"] (set by update_spark_conf_init).
    Returns the (re-parsed tree, updated source) pair, or the inputs
    unchanged when the precondition or initializer is missing.
    """
    if "spark_conf_name" not in state:
        print("Needs the name of the variable holding the SparkConf")
        return tree, source_code
    spark_conf_name = state["spark_conf_name"]
    init = get_initializer_named(tree, "JavaSparkContext")
    if not init:
        return tree, source_code

    node, name, typ = get_enclosing_variable_declaration_name_type(init)
    if node:
        # Fix: the replaced local_variable_declaration node includes its
        # trailing ';', so the replacement must supply one too (the
        # SparkConf rewrite already does); otherwise the emitted Java is
        # syntactically invalid.
        return rewrite(
            node,
            source_code,
            f"SparkContext {name} = {spark_conf_name}.sparkContext();",
        )
    else:
        # Bare initializer (e.g. `sc1 = new JavaSparkContext(conf1);`):
        # only the expression is swapped, the statement keeps its ';'.
        return rewrite(init, source_code, f"{spark_conf_name}.sparkContext()")
130+
131+
132+
def get_setter_call(variable_name: str, tree: Tree) -> Optional[Node]:
    """Find the first standalone rewritable setter call on *variable_name*.

    Matches method invocations like `conf2.setAppName(...)` whose receiver
    text equals *variable_name* and whose method name is one of the
    recognised SparkConf setters; returns None when there is none.
    """
    for candidate in traverse_tree(tree):
        if candidate.type != "method_invocation":
            continue
        name_node = candidate.child_by_field_name("name")
        receiver_node = candidate.child_by_field_name("object")
        if not (name_node and receiver_node):
            continue
        method_name = name_node.text.decode()
        receiver = receiver_node.text.decode()
        if receiver == variable_name and method_name in relevant_builder_method_names_mapping.keys():
            return candidate
    return None
142+
143+
144+
def update_spark_conf_setters(
    tree: Tree, source_code: str, state: Dict[str, Any]
):
    """Rewrite one standalone setter call (`conf.setX(...)`) into its
    SparkSession-builder equivalent (`conf.x(...)`).

    Rewrites at most one call per invocation; the driver loops until the
    source stops changing. Returns (tree, source_code) unchanged when
    nothing is rewritten.

    Fix: guard the state lookup — previously this raised KeyError when
    update_spark_conf_init never found a SparkConf (the driver calls this
    pass unconditionally); the sibling pass update_spark_context_init
    already guards the same precondition.
    """
    if "spark_conf_name" not in state:
        print("Needs the name of the variable holding the SparkConf")
        return tree, source_code
    setter_call = get_setter_call(state["spark_conf_name"], tree)
    if setter_call:
        rcvr = state["spark_conf_name"]
        invc = setter_call.child_by_field_name("name")
        args = setter_call.child_by_field_name("arguments")
        if rcvr and invc and args:
            new_fn = relevant_builder_method_names_mapping[invc.text.decode()]
            replacement = f"{rcvr}.{new_fn}{args.text.decode()}"
            return rewrite(setter_call, source_code, replacement)
    return tree, source_code
157+
158+
# Fixed-point driver: repeatedly rewrite SparkConf / JavaSparkContext
# declarations until the source stops changing, then rewrite remaining
# standalone setter calls the same way. Each pass re-parses, because
# every rewrite invalidates the previous tree's byte offsets.
state = {}  # shared across passes; carries "spark_conf_name"
no_change = False
while not no_change:
    TREE: Tree = parse_code("java", SOURCE_CODE)
    original_code = SOURCE_CODE
    TREE, SOURCE_CODE = update_spark_conf_init(TREE, SOURCE_CODE, state)
    TREE, SOURCE_CODE = update_spark_context_init(TREE, SOURCE_CODE, state)
    # Converged when a full pass produced no textual change.
    no_change = SOURCE_CODE == original_code
# Second fixed point: one setter call is rewritten per iteration
# (update_spark_conf_setters rewrites at most one match).
no_setter_found = False
while not no_setter_found:
    b4_code = SOURCE_CODE
    TREE, SOURCE_CODE = update_spark_conf_setters(TREE, SOURCE_CODE, state)
    no_setter_found = SOURCE_CODE == b4_code
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
2+
from tree_sitter import Node, Tree
3+
from tree_sitter_languages import get_parser
4+
5+
6+
SOURCE_CODE = """package com.piranha;
7+
8+
import org.apache.spark.SparkConf;
9+
import org.apache.spark.api.java.JavaSparkContext;
10+
11+
public class Sample {
12+
public static void main(String[] args) {
13+
SparkConf conf = new SparkConf()
14+
.setAppName("Sample App");
15+
16+
JavaSparkContext sc = new JavaSparkContext(conf);
17+
18+
SparkConf conf1 = new SparkConf()
19+
.setSparkHome(sparkHome)
20+
.setExecutorEnv("spark.executor.extraClassPath", "test")
21+
.setAppName(appName)
22+
.setMaster(master)
23+
.set("spark.driver.allowMultipleContexts", "true");
24+
25+
sc1 = new JavaSparkContext(conf1);
26+
27+
28+
29+
var conf2 = new SparkConf();
30+
conf2.set("spark.driver.instances:", "100");
31+
conf2.setAppName(appName);
32+
conf2.setSparkHome(sparkHome);
33+
34+
sc2 = new JavaSparkContext(conf2);
35+
36+
37+
}
38+
}
39+
"""
40+
41+
42+
def parse_code(language: str, source_code: str) -> Tree:
    """Parse *source_code* with the tree-sitter grammar for *language*.

    Returns the resulting tree; the source is encoded as UTF-8 bytes,
    which is what the parser (and node byte offsets) operate on.
    """
    parser = get_parser(language)
    return parser.parse(source_code.encode("utf8"))
47+
48+
def traverse_tree(tree: Tree):
    """Yield every node of *tree* in pre-order via a cursor walk."""
    cursor = tree.walk()
    while True:
        yield cursor.node

        # Depth-first: descend when possible, otherwise move sideways.
        if cursor.goto_first_child():
            continue
        if cursor.goto_next_sibling():
            continue

        # Backtrack to the nearest ancestor that still has an unvisited
        # sibling; reaching the root ends the traversal.
        while True:
            if not cursor.goto_parent():
                return
            if cursor.goto_next_sibling():
                break
70+
71+
def rewrite(node: Node, source_code: str, replacement: str):
    """Splice *replacement* over *node*'s span and re-parse the result.

    Returns (new_tree, new_source). The rewritten source is printed so
    each pass of the experiment can be inspected.

    Fix: `node.start_byte` / `node.end_byte` are BYTE offsets into the
    UTF-8 encoding of the source, but the original sliced the `str`
    directly — correct only while the source is pure ASCII. Splice on the
    encoded bytes so non-ASCII sources are not corrupted.
    """
    encoded = source_code.encode("utf8")
    new_source_code = (
        encoded[: node.start_byte]
        + replacement.encode("utf8")
        + encoded[node.end_byte :]
    ).decode("utf8")
    print(new_source_code)
    return parse_code("java", new_source_code), new_source_code

0 commit comments

Comments
 (0)