29
29
import shlex
30
30
import subprocess
31
31
import sys
32
+ import time
32
33
from argparse import ArgumentParser
33
34
from typing import List , Optional
34
35
38
39
39
40
_LOGGING_FORMAT = "%(asctime)s %(module)s %(levelname)s: %(message)s"
40
41
42
+ # 2 minutes ought to be enough for anybody ;-)
43
+ MAX_SSH_MUX_LIFETIME = 120
41
44
COLOR_STDOUT = os .isatty (1 )
42
45
COLOR_STDERR = os .isatty (2 )
43
46
IS_TERMINAL = os .isatty (1 ) and os .isatty (2 )
@@ -61,6 +64,7 @@ class StackyConfig:
61
64
skip_confirm : bool = False
62
65
change_to_main : bool = False
63
66
change_to_adopted : bool = False
67
+ share_ssh_session : bool = False
64
68
65
69
def read_one_config (self , config_path : str ):
66
70
rawconfig = configparser .ConfigParser ()
@@ -75,6 +79,9 @@ def read_one_config(self, config_path: str):
75
79
self .change_to_adopted = rawconfig .get (
76
80
"UI" , "change_to_adopted" , fallback = self .change_to_adopted
77
81
)
82
+ self .share_ssh_session = rawconfig .get (
83
+ "UI" , "share_ssh_session" , fallback = self .share_ssh_session
84
+ )
78
85
79
86
80
87
def read_config () -> StackyConfig :
@@ -123,7 +130,21 @@ def __init__(self, fmt, *args, **kwargs):
123
130
super ().__init__ (fmt .format (* args , ** kwargs ))
124
131
125
132
133
+ def stop_muxed_ssh (remote : str = "origin" ):
134
+ if CONFIG .share_ssh_session :
135
+ hostish = get_remote_type (remote )
136
+ if hostish is not None :
137
+ cmd = gen_ssh_mux_cmd ()
138
+ cmd .append ("-O" )
139
+ cmd .append ("exit" )
140
+ cmd .append (hostish )
141
+ subprocess .Popen (cmd , stderr = subprocess .DEVNULL )
142
+
143
+
126
144
def die (* args , ** kwargs ):
145
+ # We are taking a wild guess at what is the remote ...
146
+ # TODO (mpatou) fix the assumption about the remote
147
+ stop_muxed_ssh ()
127
148
raise ExitException (* args , ** kwargs )
128
149
129
150
@@ -726,6 +747,7 @@ def create_gh_pr(b, prefix):
726
747
727
748
728
749
def do_push (forest , * , force = False , pr = False , remote_name = "origin" ):
750
+ start_muxed_ssh (remote_name )
729
751
if pr :
730
752
load_pr_info_for_forest (forest )
731
753
print_forest (forest )
@@ -833,9 +855,16 @@ def do_push(forest, *, force=False, pr=False, remote_name="origin"):
833
855
elif pr_action == PR_CREATE :
834
856
create_gh_pr (b , prefix )
835
857
858
+ stop_muxed_ssh (remote_name )
859
+
836
860
837
861
def cmd_stack_push (stack , args ):
838
- do_push (get_current_stack_as_forest (stack ), force = args .force , pr = args .pr )
862
+ do_push (
863
+ get_current_stack_as_forest (stack ),
864
+ force = args .force ,
865
+ pr = args .pr ,
866
+ remote_name = args .remote_name ,
867
+ )
839
868
840
869
841
870
def do_sync (forest ):
@@ -980,7 +1009,12 @@ def cmd_upstack_info(stack, args):
980
1009
981
1010
982
1011
def cmd_upstack_push (stack , args ):
983
- do_push (get_current_upstack_as_forest (stack ), force = args .force , pr = args .pr )
1012
+ do_push (
1013
+ get_current_upstack_as_forest (stack ),
1014
+ force = args .force ,
1015
+ pr = args .pr ,
1016
+ remote_name = args .remote_name ,
1017
+ )
984
1018
985
1019
986
1020
def cmd_upstack_sync (stack , args ):
@@ -1024,7 +1058,12 @@ def cmd_downstack_info(stack, args):
1024
1058
1025
1059
1026
1060
def cmd_downstack_push (stack , args ):
1027
- do_push (get_current_downstack_as_forest (stack ), force = args .force , pr = args .pr )
1061
+ do_push (
1062
+ get_current_downstack_as_forest (stack ),
1063
+ force = args .force ,
1064
+ pr = args .pr ,
1065
+ remote_name = args .remote_name ,
1066
+ )
1028
1067
1029
1068
1030
1069
def cmd_downstack_sync (stack , args ):
@@ -1038,34 +1077,57 @@ def get_bottom_level_branches_as_forest(stack):
1038
1077
]
1039
1078
1040
1079
1041
- def cmd_update (stack , args ):
1042
- remote = "origin"
1043
- info ("Fetching from {}" , remote )
1044
- run (["git" , "fetch" , remote ])
1045
-
1046
- # TODO(tudor): We should rebase instead of silently dropping
1047
- # everything you have on local master. Oh well.
1048
- global CURRENT_BRANCH
1049
- for b in stack .bottoms :
1050
- run (
1051
- [
1052
- "git" ,
1053
- "update-ref" ,
1054
- "refs/heads/{}" .format (b .name ),
1055
- "refs/remotes/{}/{}" .format (remote , b .remote_branch ),
1056
- ]
1080
+ def get_remote_type (remote : str = "origin" ) -> Optional [str ]:
1081
+ out = run (["git" , "remote" , "-v" ])
1082
+ for l in out .split ("\n " ):
1083
+ match = re .match (
1084
+ r"^{}\s+(?:ssh://)?([^/]*):(?!//).*\s+\(push\)$" .format (remote ), l
1057
1085
)
1058
- if b .name == CURRENT_BRANCH :
1059
- run (["git" , "reset" , "--hard" , "HEAD" ])
1086
+ if match :
1087
+ sshish_host = match .group (1 )
1088
+ return sshish_host
1089
+
1090
+
1091
+ def gen_ssh_mux_cmd () -> List [str ]:
1092
+ args = [
1093
+ "ssh" ,
1094
+ "-o" ,
1095
+ "ControlMaster=auto" ,
1096
+ "-o" ,
1097
+ f"ControlPersist={ MAX_SSH_MUX_LIFETIME } " ,
1098
+ "-o" ,
1099
+ "ControlPath=~/.ssh/stacky-%C" ,
1100
+ ]
1060
1101
1061
- # We treat origin as the source of truth for bottom branches (master), and
1062
- # the local repo as the source of truth for everything else. So we can only
1063
- # track PR closure for branches that are direct descendants of master.
1102
+ return args
1064
1103
1065
- info ("Checking if any PRs have been merged and can be deleted" )
1066
- forest = get_bottom_level_branches_as_forest (stack )
1067
- load_pr_info_for_forest (forest )
1068
1104
1105
+ def start_muxed_ssh (remote : str = "origin" ):
1106
+ if not CONFIG .share_ssh_session :
1107
+ return
1108
+ hostish = get_remote_type (remote )
1109
+ if hostish is not None :
1110
+ info ("Creating a muxed ssh connection" )
1111
+ cmd = gen_ssh_mux_cmd ()
1112
+ os .environ ["GIT_SSH_COMMAND" ] = " " .join (cmd )
1113
+ cmd .append ("-MNf" )
1114
+ cmd .append (hostish )
1115
+ # We don't want to use the run() wrapper because
1116
+ # we don't want to wait for the process to finish
1117
+
1118
+ p = subprocess .Popen (cmd , stderr = subprocess .PIPE )
1119
+ # Wait a little bit for the connection to establish
1120
+ # before carrying on
1121
+ while p .poll () is None :
1122
+ time .sleep (1 )
1123
+ if p .returncode != 0 :
1124
+ error = p .stderr .read ()
1125
+ die (
1126
+ f"Failed to start ssh muxed connection, error was: { error .decode ('utf-8' ).strip ()} "
1127
+ )
1128
+
1129
+
1130
+ def get_branches_to_delete (forest ):
1069
1131
deletes = []
1070
1132
for b in depth_first (forest ):
1071
1133
if not b .parent or b .open_pr_info :
@@ -1087,10 +1149,11 @@ def cmd_update(stack, args):
1087
1149
b .parent .name ,
1088
1150
)
1089
1151
break
1152
+ return deletes
1090
1153
1091
- if deletes and not args .force :
1092
- confirm ()
1093
1154
1155
+ def delete_branches (stack , deletes ):
1156
+ global CURRENT_BRANCH
1094
1157
# Make sure we're not trying to delete the current branch
1095
1158
for b in deletes :
1096
1159
for c in b .children :
@@ -1106,6 +1169,43 @@ def cmd_update(stack, args):
1106
1169
run (["git" , "branch" , "-D" , b .name ])
1107
1170
1108
1171
1172
+ def cmd_update (stack , args ):
1173
+ remote = "origin"
1174
+ start_muxed_ssh (remote )
1175
+ info ("Fetching from {}" , remote )
1176
+ run (["git" , "fetch" , remote ])
1177
+
1178
+ # TODO(tudor): We should rebase instead of silently dropping
1179
+ # everything you have on local master. Oh well.
1180
+ global CURRENT_BRANCH
1181
+ for b in stack .bottoms :
1182
+ run (
1183
+ [
1184
+ "git" ,
1185
+ "update-ref" ,
1186
+ "refs/heads/{}" .format (b .name ),
1187
+ "refs/remotes/{}/{}" .format (remote , b .remote_branch ),
1188
+ ]
1189
+ )
1190
+ if b .name == CURRENT_BRANCH :
1191
+ run (["git" , "reset" , "--hard" , "HEAD" ])
1192
+
1193
+ # We treat origin as the source of truth for bottom branches (master), and
1194
+ # the local repo as the source of truth for everything else. So we can only
1195
+ # track PR closure for branches that are direct descendants of master.
1196
+
1197
+ info ("Checking if any PRs have been merged and can be deleted" )
1198
+ forest = get_bottom_level_branches_as_forest (stack )
1199
+ load_pr_info_for_forest (forest )
1200
+
1201
+ deletes = get_branches_to_delete (forest )
1202
+ if deletes and not args .force :
1203
+ confirm ()
1204
+
1205
+ delete_branches (stack , deletes )
1206
+ stop_muxed_ssh (remote )
1207
+
1208
+
1109
1209
def cmd_import (stack , args ):
1110
1210
# Importing has to happen based on PR info, rather than local branch
1111
1211
# relationships, as that's the only place Graphite populates.
0 commit comments