Skip to content

Commit ac6cdd4

Browse files
authored
Allow to create a shared or multiplexed connection when the backend is ssh (#7) (#8)
Summary If you have submodules and you use some kind of 2FA it can get very annoying to enter X times your 2FA or press X times hardware tokens Test plan Used it with our internal repo where we have (too ?) many submodules and I'm not asked to enter my 2FA multiple times Reviewers: tudor
1 parent 05e6e3f commit ac6cdd4

File tree

1 file changed

+129
-29
lines changed

1 file changed

+129
-29
lines changed

src/stacky/stacky.py

Lines changed: 129 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import shlex
3030
import subprocess
3131
import sys
32+
import time
3233
from argparse import ArgumentParser
3334
from typing import List, Optional
3435

@@ -38,6 +39,8 @@
3839

3940
_LOGGING_FORMAT = "%(asctime)s %(module)s %(levelname)s: %(message)s"
4041

42+
# 2 minutes ought to be enough for anybody ;-)
43+
MAX_SSH_MUX_LIFETIME = 120
4144
COLOR_STDOUT = os.isatty(1)
4245
COLOR_STDERR = os.isatty(2)
4346
IS_TERMINAL = os.isatty(1) and os.isatty(2)
@@ -61,6 +64,7 @@ class StackyConfig:
6164
skip_confirm: bool = False
6265
change_to_main: bool = False
6366
change_to_adopted: bool = False
67+
share_ssh_session: bool = False
6468

6569
def read_one_config(self, config_path: str):
6670
rawconfig = configparser.ConfigParser()
@@ -75,6 +79,9 @@ def read_one_config(self, config_path: str):
7579
self.change_to_adopted = rawconfig.get(
7680
"UI", "change_to_adopted", fallback=self.change_to_adopted
7781
)
82+
self.share_ssh_session = rawconfig.get(
83+
"UI", "share_ssh_session", fallback=self.share_ssh_session
84+
)
7885

7986

8087
def read_config() -> StackyConfig:
@@ -123,7 +130,21 @@ def __init__(self, fmt, *args, **kwargs):
123130
super().__init__(fmt.format(*args, **kwargs))
124131

125132

133+
def stop_muxed_ssh(remote: str = "origin"):
134+
if CONFIG.share_ssh_session:
135+
hostish = get_remote_type(remote)
136+
if hostish is not None:
137+
cmd = gen_ssh_mux_cmd()
138+
cmd.append("-O")
139+
cmd.append("exit")
140+
cmd.append(hostish)
141+
subprocess.Popen(cmd, stderr=subprocess.DEVNULL)
142+
143+
126144
def die(*args, **kwargs):
145+
# We are taking a wild guess at what is the remote ...
146+
# TODO (mpatou) fix the assumption about the remote
147+
stop_muxed_ssh()
127148
raise ExitException(*args, **kwargs)
128149

129150

@@ -726,6 +747,7 @@ def create_gh_pr(b, prefix):
726747

727748

728749
def do_push(forest, *, force=False, pr=False, remote_name="origin"):
750+
start_muxed_ssh(remote_name)
729751
if pr:
730752
load_pr_info_for_forest(forest)
731753
print_forest(forest)
@@ -833,9 +855,16 @@ def do_push(forest, *, force=False, pr=False, remote_name="origin"):
833855
elif pr_action == PR_CREATE:
834856
create_gh_pr(b, prefix)
835857

858+
stop_muxed_ssh(remote_name)
859+
836860

837861
def cmd_stack_push(stack, args):
838-
do_push(get_current_stack_as_forest(stack), force=args.force, pr=args.pr)
862+
do_push(
863+
get_current_stack_as_forest(stack),
864+
force=args.force,
865+
pr=args.pr,
866+
remote_name=args.remote_name,
867+
)
839868

840869

841870
def do_sync(forest):
@@ -980,7 +1009,12 @@ def cmd_upstack_info(stack, args):
9801009

9811010

9821011
def cmd_upstack_push(stack, args):
983-
do_push(get_current_upstack_as_forest(stack), force=args.force, pr=args.pr)
1012+
do_push(
1013+
get_current_upstack_as_forest(stack),
1014+
force=args.force,
1015+
pr=args.pr,
1016+
remote_name=args.remote_name,
1017+
)
9841018

9851019

9861020
def cmd_upstack_sync(stack, args):
@@ -1024,7 +1058,12 @@ def cmd_downstack_info(stack, args):
10241058

10251059

10261060
def cmd_downstack_push(stack, args):
1027-
do_push(get_current_downstack_as_forest(stack), force=args.force, pr=args.pr)
1061+
do_push(
1062+
get_current_downstack_as_forest(stack),
1063+
force=args.force,
1064+
pr=args.pr,
1065+
remote_name=args.remote_name,
1066+
)
10281067

10291068

10301069
def cmd_downstack_sync(stack, args):
@@ -1038,34 +1077,57 @@ def get_bottom_level_branches_as_forest(stack):
10381077
]
10391078

10401079

1041-
def cmd_update(stack, args):
1042-
remote = "origin"
1043-
info("Fetching from {}", remote)
1044-
run(["git", "fetch", remote])
1045-
1046-
# TODO(tudor): We should rebase instead of silently dropping
1047-
# everything you have on local master. Oh well.
1048-
global CURRENT_BRANCH
1049-
for b in stack.bottoms:
1050-
run(
1051-
[
1052-
"git",
1053-
"update-ref",
1054-
"refs/heads/{}".format(b.name),
1055-
"refs/remotes/{}/{}".format(remote, b.remote_branch),
1056-
]
1080+
def get_remote_type(remote: str = "origin") -> Optional[str]:
1081+
out = run(["git", "remote", "-v"])
1082+
for l in out.split("\n"):
1083+
match = re.match(
1084+
r"^{}\s+(?:ssh://)?([^/]*):(?!//).*\s+\(push\)$".format(remote), l
10571085
)
1058-
if b.name == CURRENT_BRANCH:
1059-
run(["git", "reset", "--hard", "HEAD"])
1086+
if match:
1087+
sshish_host = match.group(1)
1088+
return sshish_host
1089+
1090+
1091+
def gen_ssh_mux_cmd() -> List[str]:
1092+
args = [
1093+
"ssh",
1094+
"-o",
1095+
"ControlMaster=auto",
1096+
"-o",
1097+
f"ControlPersist={MAX_SSH_MUX_LIFETIME}",
1098+
"-o",
1099+
"ControlPath=~/.ssh/stacky-%C",
1100+
]
10601101

1061-
# We treat origin as the source of truth for bottom branches (master), and
1062-
# the local repo as the source of truth for everything else. So we can only
1063-
# track PR closure for branches that are direct descendants of master.
1102+
return args
10641103

1065-
info("Checking if any PRs have been merged and can be deleted")
1066-
forest = get_bottom_level_branches_as_forest(stack)
1067-
load_pr_info_for_forest(forest)
10681104

1105+
def start_muxed_ssh(remote: str = "origin"):
1106+
if not CONFIG.share_ssh_session:
1107+
return
1108+
hostish = get_remote_type(remote)
1109+
if hostish is not None:
1110+
info("Creating a muxed ssh connection")
1111+
cmd = gen_ssh_mux_cmd()
1112+
os.environ["GIT_SSH_COMMAND"] = " ".join(cmd)
1113+
cmd.append("-MNf")
1114+
cmd.append(hostish)
1115+
# We don't want to use the run() wrapper because
1116+
# we don't want to wait for the process to finish
1117+
1118+
p = subprocess.Popen(cmd, stderr=subprocess.PIPE)
1119+
# Wait a little bit for the connection to establish
1120+
# before carrying on
1121+
while p.poll() is None:
1122+
time.sleep(1)
1123+
if p.returncode != 0:
1124+
error = p.stderr.read()
1125+
die(
1126+
f"Failed to start ssh muxed connection, error was: {error.decode('utf-8').strip()}"
1127+
)
1128+
1129+
1130+
def get_branches_to_delete(forest):
10691131
deletes = []
10701132
for b in depth_first(forest):
10711133
if not b.parent or b.open_pr_info:
@@ -1087,10 +1149,11 @@ def cmd_update(stack, args):
10871149
b.parent.name,
10881150
)
10891151
break
1152+
return deletes
10901153

1091-
if deletes and not args.force:
1092-
confirm()
10931154

1155+
def delete_branches(stack, deletes):
1156+
global CURRENT_BRANCH
10941157
# Make sure we're not trying to delete the current branch
10951158
for b in deletes:
10961159
for c in b.children:
@@ -1106,6 +1169,43 @@ def cmd_update(stack, args):
11061169
run(["git", "branch", "-D", b.name])
11071170

11081171

1172+
def cmd_update(stack, args):
1173+
remote = "origin"
1174+
start_muxed_ssh(remote)
1175+
info("Fetching from {}", remote)
1176+
run(["git", "fetch", remote])
1177+
1178+
# TODO(tudor): We should rebase instead of silently dropping
1179+
# everything you have on local master. Oh well.
1180+
global CURRENT_BRANCH
1181+
for b in stack.bottoms:
1182+
run(
1183+
[
1184+
"git",
1185+
"update-ref",
1186+
"refs/heads/{}".format(b.name),
1187+
"refs/remotes/{}/{}".format(remote, b.remote_branch),
1188+
]
1189+
)
1190+
if b.name == CURRENT_BRANCH:
1191+
run(["git", "reset", "--hard", "HEAD"])
1192+
1193+
# We treat origin as the source of truth for bottom branches (master), and
1194+
# the local repo as the source of truth for everything else. So we can only
1195+
# track PR closure for branches that are direct descendants of master.
1196+
1197+
info("Checking if any PRs have been merged and can be deleted")
1198+
forest = get_bottom_level_branches_as_forest(stack)
1199+
load_pr_info_for_forest(forest)
1200+
1201+
deletes = get_branches_to_delete(forest)
1202+
if deletes and not args.force:
1203+
confirm()
1204+
1205+
delete_branches(stack, deletes)
1206+
stop_muxed_ssh(remote)
1207+
1208+
11091209
def cmd_import(stack, args):
11101210
# Importing has to happen based on PR info, rather than local branch
11111211
# relationships, as that's the only place Graphite populates.

0 commit comments

Comments
 (0)