|
5 | 5 | import re |
6 | 6 | import sys |
7 | 7 | from functools import lru_cache |
8 | | -from typing import Final, List, Match, Pattern |
| 8 | +from typing import Final, List, Match, Pattern, Tuple |
9 | 9 |
|
10 | 10 | from black._width_table import WIDTH_TABLE |
11 | 11 | from blib2to3.pytree import Leaf |
@@ -169,8 +169,7 @@ def _cached_compile(pattern: str) -> Pattern[str]: |
169 | 169 | def normalize_string_quotes(s: str) -> str: |
170 | 170 | """Prefer double quotes but only if it doesn't cause more escaping. |
171 | 171 |
|
172 | | - Adds or removes backslashes as appropriate. Doesn't parse and fix |
173 | | - strings nested in f-strings. |
| 172 | + Adds or removes backslashes as appropriate. |
174 | 173 | """ |
175 | 174 | value = s.lstrip(STRING_PREFIX_CHARS) |
176 | 175 | if value[:3] == '"""': |
@@ -211,6 +210,7 @@ def normalize_string_quotes(s: str) -> str: |
211 | 210 | s = f"{prefix}{orig_quote}{body}{orig_quote}" |
212 | 211 | new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body) |
213 | 212 | new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body) |
| 213 | + |
214 | 214 | if "f" in prefix.casefold(): |
215 | 215 | matches = re.findall( |
216 | 216 | r""" |
@@ -240,6 +240,71 @@ def normalize_string_quotes(s: str) -> str: |
240 | 240 | return f"{prefix}{new_quote}{new_body}{new_quote}" |
241 | 241 |
|
242 | 242 |
|
| 243 | +def normalize_fstring_quotes( |
| 244 | + quote: str, |
| 245 | + middles: List[Leaf], |
| 246 | + is_raw_fstring: bool, |
| 247 | +) -> Tuple[List[Leaf], str]: |
| 248 | + """Prefer double quotes but only if it doesn't cause more escaping. |
| 249 | +
|
| 250 | + Adds or removes backslashes as appropriate. |
| 251 | + """ |
| 252 | + if quote == '"""': |
| 253 | + return middles, quote |
| 254 | + |
| 255 | + elif quote == "'''": |
| 256 | + new_quote = '"""' |
| 257 | + elif quote == '"': |
| 258 | + new_quote = "'" |
| 259 | + else: |
| 260 | + new_quote = '"' |
| 261 | + |
| 262 | + unescaped_new_quote = _cached_compile(rf"(([^\\]|^)(\\\\)*){new_quote}") |
| 263 | + escaped_new_quote = _cached_compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}") |
| 264 | + escaped_orig_quote = _cached_compile(rf"([^\\]|^)\\((?:\\\\)*){quote}") |
| 265 | + if is_raw_fstring: |
| 266 | + for middle in middles: |
| 267 | + if unescaped_new_quote.search(middle.value): |
| 268 | + # There's at least one unescaped new_quote in this raw string |
| 269 | + # so converting is impossible |
| 270 | + return middles, quote |
| 271 | + |
| 272 | + # Do not introduce or remove backslashes in raw strings, just use double quote |
| 273 | + return middles, '"' |
| 274 | + |
| 275 | + new_segments = [] |
| 276 | + for middle in middles: |
| 277 | + segment = middle.value |
| 278 | + # remove unnecessary escapes |
| 279 | + new_segment = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", segment) |
| 280 | + if segment != new_segment: |
| 281 | + # Consider the string without unnecessary escapes as the original |
| 282 | + middle.value = new_segment |
| 283 | + |
| 284 | + new_segment = sub_twice(escaped_orig_quote, rf"\1\2{quote}", new_segment) |
| 285 | + new_segment = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_segment) |
| 286 | + new_segments.append(new_segment) |
| 287 | + |
| 288 | + if new_quote == '"""' and new_segments[-1].endswith('"'): |
| 289 | + # edge case: |
| 290 | + new_segments[-1] = new_segments[-1][:-1] + '\\"' |
| 291 | + |
| 292 | + for middle, new_segment in zip(middles, new_segments): |
| 293 | + orig_escape_count = middle.value.count("\\") |
| 294 | + new_escape_count = new_segment.count("\\") |
| 295 | + |
| 296 | + if new_escape_count > orig_escape_count: |
| 297 | + return middles, quote # Do not introduce more escaping |
| 298 | + |
| 299 | + if new_escape_count == orig_escape_count and quote == '"': |
| 300 | + return middles, quote # Prefer double quotes |
| 301 | + |
| 302 | + for middle, new_segment in zip(middles, new_segments): |
| 303 | + middle.value = new_segment |
| 304 | + |
| 305 | + return middles, new_quote |
| 306 | + |
| 307 | + |
243 | 308 | def normalize_unicode_escape_sequences(leaf: Leaf) -> None: |
244 | 309 | """Replace hex codes in Unicode escape sequences with lowercase representation.""" |
245 | 310 | text = leaf.value |
|
0 commit comments