diff --git a/Sources/Auth/AuthClient.swift b/Sources/Auth/AuthClient.swift index 2c2f0105..771c55f3 100644 --- a/Sources/Auth/AuthClient.swift +++ b/Sources/Auth/AuthClient.swift @@ -595,7 +595,7 @@ public final class AuthClient: Sendable { let params = extractParams(from: url) if isPKCEFlow(url: url) { - guard let code = params.first(where: { $0.name == "code" })?.value else { + guard let code = params["code"] else { throw AuthError.pkce(.codeVerifierNotFound) } @@ -603,24 +603,22 @@ public final class AuthClient: Sendable { return session } - if let errorDescription = params.first(where: { $0.name == "error_description" })?.value { + if let errorDescription = params["error_description"] { throw AuthError.api(.init(errorDescription: errorDescription)) } guard - let accessToken = params.first(where: { $0.name == "access_token" })?.value, - let expiresIn = params.first(where: { $0.name == "expires_in" }).map(\.value) - .flatMap(TimeInterval.init), - let refreshToken = params.first(where: { $0.name == "refresh_token" })?.value, - let tokenType = params.first(where: { $0.name == "token_type" })?.value + let accessToken = params["access_token"], + let expiresIn = params["expires_in"].flatMap(TimeInterval.init), + let refreshToken = params["refresh_token"], + let tokenType = params["token_type"] else { throw URLError(.badURL) } - let expiresAt = params.first(where: { $0.name == "expires_at" }).map(\.value) - .flatMap(TimeInterval.init) - let providerToken = params.first(where: { $0.name == "provider_token" })?.value - let providerRefreshToken = params.first(where: { $0.name == "provider_refresh_token" })?.value + let expiresAt = params["expires_at"].flatMap(TimeInterval.init) + let providerToken = params["provider_token"] + let providerRefreshToken = params["provider_refresh_token"] let user = try await api.execute( .init( @@ -644,7 +642,7 @@ public final class AuthClient: Sendable { try await sessionManager.update(session) eventEmitter.emit(.signedIn, session: session) - if let type = params.first(where: { $0.name == "type" })?.value, type == "recovery" { + if let type = params["type"], type == "recovery" { eventEmitter.emit(.passwordRecovery, session: session) } @@ -1060,15 +1058,13 @@ public final class AuthClient: Sendable { private func isImplicitGrantFlow(url: URL) -> Bool { let fragments = extractParams(from: url) - return fragments.contains { - $0.name == "access_token" || $0.name == "error_description" - } + return fragments["access_token"] != nil || fragments["error_description"] != nil } private func isPKCEFlow(url: URL) -> Bool { let fragments = extractParams(from: url) let currentCodeVerifier = codeVerifierStorage.get() - return fragments.contains(where: { $0.name == "code" }) && currentCodeVerifier != nil + return fragments["code"] != nil && currentCodeVerifier != nil } private func getURLForProvider( diff --git a/Sources/Auth/Internal/Helpers.swift b/Sources/Auth/Internal/Helpers.swift index 33e0df1d..991bad8f 100644 --- a/Sources/Auth/Internal/Helpers.swift +++ b/Sources/Auth/Internal/Helpers.swift @@ -1,29 +1,30 @@ import Foundation -struct Params: Hashable { - var name: String - var value: String -} - -func extractParams(from url: URL) -> [Params] { +/// Extracts parameters encoded in the URL both in the query and fragment. +func extractParams(from url: URL) -> [String: String] { guard let components = URLComponents(url: url, resolvingAgainstBaseURL: false) else { - return [] + return [:] } + var result: [String: String] = [:] + if let fragment = components.fragment { - return extractParams(from: fragment) + let items = extractParams(from: fragment) + for item in items { + result[item.name] = item.value + } } - if let queryItems = components.queryItems { - return queryItems.map { - Params(name: $0.name, value: $0.value ?? "") + if let items = components.queryItems { + for item in items { + result[item.name] = item.value } } - return [] + return result } -func extractParams(from fragment: String) -> [Params] { +private func extractParams(from fragment: String) -> [URLQueryItem] { let components = fragment .split(separator: "&") @@ -33,7 +34,7 @@ func extractParams(from fragment: String) -> [Params] { components .compactMap { $0.count == 2 - ? Params(name: String($0[0]), value: String($0[1])) + ? URLQueryItem(name: String($0[0]), value: String($0[1])) : nil } } diff --git a/Tests/AuthTests/ExtractParamsTests.swift b/Tests/AuthTests/ExtractParamsTests.swift index 0585ddbc..4ced833e 100644 --- a/Tests/AuthTests/ExtractParamsTests.swift +++ b/Tests/AuthTests/ExtractParamsTests.swift @@ -13,13 +13,26 @@ final class ExtractParamsTests: XCTestCase { let code = UUID().uuidString let url = URL(string: "io.supabase.flutterquickstart://login-callback/?code=\(code)")! let params = extractParams(from: url) - XCTAssertEqual(params, [Params(name: "code", value: code)]) + XCTAssertEqual(params, ["code": code]) } func testExtractParamsInFragment() { let code = UUID().uuidString let url = URL(string: "io.supabase.flutterquickstart://login-callback/#code=\(code)")! let params = extractParams(from: url) - XCTAssertEqual(params, [Params(name: "code", value: code)]) + XCTAssertEqual(params, ["code": code]) + } + + func testExtractParamsInBothFragmentAndQuery() { + let code = UUID().uuidString + let url = URL(string: "io.supabase.flutterquickstart://login-callback/?code=\(code)#message=abc")! + let params = extractParams(from: url) + XCTAssertEqual(params, ["code": code, "message": "abc"]) + } + + func testExtractParamsQueryTakesPrecedence() { + let url = URL(string: "io.supabase.flutterquickstart://login-callback/?code=123#code=abc")! + let params = extractParams(from: url) + XCTAssertEqual(params, ["code": "123"]) } }