Skip to content

Commit e043a6c

Browse files
committed
✅ Initial tests
Unit style tests with Swift Testing
1 parent be1268e commit e043a6c

14 files changed

+379
-166
lines changed

DraftPatch/DraftPatchApp.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ struct DraftPatchApp: App {
3535
}
3636

3737
let ctx = ModelContext(self.modelContainer)
38-
_viewModel = StateObject(wrappedValue: DraftPatchViewModel(context: ctx))
38+
let repository = SwiftDataChatThreadRepository(context: ctx)
39+
_viewModel = StateObject(wrappedValue: DraftPatchViewModel(repository: repository))
3940

4041
// Request accessibility permissions for drafting
4142
DraftingService.shared.checkAccessibilityPermission()

DraftPatch/DraftPatchViewModel.swift

Lines changed: 67 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ import SwiftUI
1010

1111
@MainActor
1212
class DraftPatchViewModel: ObservableObject {
13-
private var context: ModelContext
13+
private var repository: ChatThreadRepository
14+
private var llmManager: LLMManager
1415

1516
@Published var chatThreads: [ChatThread] = []
1617
@Published var selectedThread: ChatThread? {
@@ -38,8 +39,9 @@ class DraftPatchViewModel: ObservableObject {
3839
@Published var settings: Settings? = nil
3940
@Published var errorMessage: String? = nil
4041

41-
init(context: ModelContext) {
42-
self.context = context
42+
init(repository: ChatThreadRepository, llmManager: LLMManager = .shared) {
43+
self.repository = repository
44+
self.llmManager = llmManager
4345

4446
loadSettings()
4547
loadThreads()
@@ -49,6 +51,46 @@ class DraftPatchViewModel: ObservableObject {
4951
}
5052
}
5153

54+
func loadThreads() {
55+
do {
56+
chatThreads = try repository.fetchThreads()
57+
selectedThread = chatThreads.first
58+
} catch {
59+
print("Error loading threads: \(error)")
60+
chatThreads = []
61+
selectedThread = nil
62+
}
63+
}
64+
65+
func loadSettings() {
66+
do {
67+
settings = try repository.fetchSettings()
68+
69+
if settings != nil, let ollamaEndpontURL = settings?.ollamaConfig?.endpointURL {
70+
OllamaService.shared.endpointURL = ollamaEndpontURL
71+
}
72+
} catch {
73+
print("Error loading settings: \(error)")
74+
}
75+
}
76+
77+
func deleteThread(_ thread: ChatThread) {
78+
do {
79+
try repository.deleteThread(thread)
80+
try repository.save()
81+
82+
if let index = chatThreads.firstIndex(where: { $0.id == thread.id }) {
83+
chatThreads.remove(at: index)
84+
}
85+
86+
if selectedThread == thread {
87+
selectedThread = chatThreads.first
88+
}
89+
} catch {
90+
print("Error deleting thread: \(error)")
91+
}
92+
}
93+
5294
func toggleDrafting() {
5395
isDraftingEnabled.toggle()
5496

@@ -122,35 +164,6 @@ class DraftPatchViewModel: ObservableObject {
122164
}
123165
}
124166

125-
private func loadThreads() {
126-
let descriptor = FetchDescriptor<ChatThread>(
127-
sortBy: [SortDescriptor(\.updatedAt, order: .reverse)]
128-
)
129-
130-
do {
131-
chatThreads = try context.fetch(descriptor)
132-
selectedThread = chatThreads.first
133-
} catch {
134-
print("Error loading threads: \(error)")
135-
chatThreads = []
136-
selectedThread = nil
137-
}
138-
}
139-
140-
private func loadSettings() {
141-
let descriptor = FetchDescriptor<Settings>()
142-
143-
do {
144-
settings = try context.fetch(descriptor).first
145-
146-
if settings != nil, let ollamaEndpontURL = settings?.ollamaConfig?.endpointURL {
147-
OllamaService.shared.endpointURL = ollamaEndpontURL
148-
}
149-
} catch {
150-
print("Error loading settings: \(error)")
151-
}
152-
}
153-
154167
func toggleDraftWithLastApp() {
155168
if let lastAppDraftedWith = settings?.lastAppDraftedWith {
156169
if isDraftingEnabled {
@@ -217,10 +230,11 @@ class DraftPatchViewModel: ObservableObject {
217230
return
218231
}
219232

233+
// If we're working with a draft thread, persist it
220234
if let draftThread, draftThread == thread {
221-
context.insert(thread)
222235
do {
223-
try context.save()
236+
try repository.insertThread(thread)
237+
try repository.save()
224238
chatThreads.insert(thread, at: 0)
225239
} catch {
226240
print("Error saving new thread: \(error)")
@@ -238,7 +252,11 @@ class DraftPatchViewModel: ObservableObject {
238252

239253
if let tokenStream = getTokenStream(for: thread, with: messagesPayload) {
240254
thinking = true
241-
saveContext()
255+
do {
256+
try repository.save()
257+
} catch {
258+
print("Error saving context: \(error)")
259+
}
242260

243261
let assistantMsg = ChatMessage(text: "", role: .assistant, streaming: true)
244262
thread.messages.append(assistantMsg)
@@ -257,13 +275,21 @@ class DraftPatchViewModel: ObservableObject {
257275
}
258276

259277
assistantMsg.streaming = false
260-
saveContext()
278+
do {
279+
try repository.save()
280+
} catch {
281+
print("Error saving context: \(error)")
282+
}
261283

262284
if thread.title == "New Conversation" {
263285
do {
264286
let title = try await generateTitle(for: messageText, using: thread.model)
265287
thread.title = title
266-
saveContext()
288+
do {
289+
try repository.save()
290+
} catch {
291+
print("Error saving context: \(error)")
292+
}
267293
} catch {
268294
print("Error generating thread title: \(error)")
269295
}
@@ -277,72 +303,17 @@ class DraftPatchViewModel: ObservableObject {
277303
}
278304
}
279305

280-
private func saveContext() {
281-
do {
282-
try context.save()
283-
objectWillChange.send()
284-
} catch {
285-
print("Error saving context: \(error)")
286-
}
287-
}
288-
289-
func deleteThread(_ thread: ChatThread) {
290-
context.delete(thread)
291-
292-
do {
293-
try context.save()
294-
295-
if let index = chatThreads.firstIndex(where: { $0.id == thread.id }) {
296-
chatThreads.remove(at: index)
297-
}
298-
299-
if selectedThread == thread {
300-
selectedThread = chatThreads.first
301-
}
302-
} catch {
303-
print("Error deleting thread: \(error)")
304-
}
305-
}
306-
307306
/// Determines the correct service and returns a token stream.
308307
private func getTokenStream(for thread: ChatThread, with messages: [ChatMessagePayload])
309308
-> AsyncThrowingStream<String, Error>?
310309
{
311-
switch thread.model.provider {
312-
case .ollama:
313-
return OllamaService.shared.streamChat(
314-
messages: messages,
315-
modelName: thread.model.name
316-
)
317-
case .openai:
318-
return OpenAIService.shared.streamChat(
319-
messages: messages,
320-
modelName: thread.model.name
321-
)
322-
case .gemini:
323-
return GeminiService.shared.streamChat(
324-
messages: messages,
325-
modelName: thread.model.name
326-
)
327-
case .anthropic:
328-
return ClaudeService.shared.streamChat(
329-
messages: messages,
330-
modelName: thread.model.name
331-
)
332-
}
310+
return llmManager.getService(for: thread.model.provider)
311+
.streamChat(messages: messages, modelName: thread.model.name)
333312
}
334313

335314
/// Calls the appropriate service to generate a title based on the provider.
336315
private func generateTitle(for text: String, using model: ChatModel) async throws -> String {
337-
switch model.provider {
338-
case .ollama:
339-
return try await OllamaService.shared.generateTitle(for: text, modelName: model.name)
340-
case .openai:
341-
return try await OpenAIService.shared.generateTitle(for: text, modelName: model.name)
342-
case .gemini:
343-
return try await GeminiService.shared.generateTitle(for: text, modelName: model.name)
344-
case .anthropic:
345-
return try await ClaudeService.shared.generateTitle(for: text, modelName: model.name)
346-
}
316+
return try await llmManager.getService(for: model.provider)
317+
.generateTitle(for: text, modelName: model.name)
347318
}
348319
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
//
2+
// ChatThreadRepository.swift
3+
// DraftPatch
4+
//
5+
// Created by Robert DeLuca on 3/11/25.
6+
//
7+
8+
/// Abstraction over chat-thread and settings persistence, allowing the
/// view model to be tested with an in-memory implementation instead of
/// a live SwiftData `ModelContext`.
protocol ChatThreadRepository {
  /// Returns all persisted chat threads.
  func fetchThreads() throws -> [ChatThread]
  /// Returns the persisted `Settings`, or `nil` when none have been saved.
  func fetchSettings() throws -> Settings?
  /// Registers a new thread with the backing store (not persisted until `save()`).
  func insertThread(_ thread: ChatThread) throws
  /// Flushes pending changes to the backing store.
  func save() throws
  /// Removes a thread from the backing store.
  func deleteThread(_ thread: ChatThread) throws
}

DraftPatch/Services/ClaudeService.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import Foundation
99

1010
struct ClaudeService: LLMService {
11-
@MainActor static let shared = ClaudeService()
11+
static let shared = ClaudeService()
1212

1313
let endpointURL = URL(string: "https://api.anthropic.com/v1")!
1414
let apiKey: String? = KeychainHelper.shared.load(for: "anthropic_api_key")

DraftPatch/Services/GeminiService.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import Foundation
99

1010
struct GeminiService: LLMService {
11-
@MainActor static let shared = GeminiService()
11+
static let shared = GeminiService()
1212

1313
let endpointURL = URL(string: "https://generativelanguage.googleapis.com/v1beta/models")!
1414
let apiKey: String? = KeychainHelper.shared.load(for: "gemini_api_key") ?? ""

DraftPatch/Services/LLMManager.swift

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//
2+
// LLMManager.swift
3+
// DraftPatch
4+
//
5+
// Created by Robert DeLuca on 3/11/25.
6+
//
7+
8+
/// Maps an LLM provider to its concrete service implementation, replacing
/// per-call-site `switch` statements over `ChatModel.LLMProvider`.
/// `final` because nothing subclasses it and the singleton is not designed
/// for inheritance; injectable (see `DraftPatchViewModel.init`) so tests can
/// substitute a manager returning mocks.
final class LLMManager {
  /// Process-wide default instance used when no manager is injected.
  static let shared = LLMManager()

  /// Returns the shared service for `provider`.
  /// Exhaustive over the enum so adding a provider is a compile error here.
  func getService(for provider: ChatModel.LLMProvider) -> LLMService {
    switch provider {
    case .ollama:
      return OllamaService.shared
    case .openai:
      return OpenAIService.shared
    case .gemini:
      return GeminiService.shared
    case .anthropic:
      return ClaudeService.shared
    }
  }
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
//
2+
// MockLLMService.swift
3+
// DraftPatch
4+
//
5+
// Created by Robert DeLuca on 3/11/25.
6+
//
7+
8+
import Foundation
9+
10+
/// Test double for `LLMService` returning canned, deterministic values —
/// no network access. `final` for consistency with `final class OllamaService`
/// and because the mock is not designed for subclassing.
final class MockLLMService: LLMService {
  /// Convenience shared instance mirroring the real services' API.
  static let shared = MockLLMService()

  var endpointURL: URL
  var apiKey: String?

  /// - Parameters:
  ///   - endpointURL: Placeholder URL; never contacted by the mock.
  ///   - apiKey: Optional placeholder credential.
  init(endpointURL: URL = URL(string: "http://example.com")!, apiKey: String? = nil) {
    self.endpointURL = endpointURL
    self.apiKey = apiKey
  }

  /// Returns a fixed model list.
  func fetchAvailableModels() async throws -> [String] {
    return ["MockModel1", "MockModel2", "MockModel3"]
  }

  /// Streams two fixed tokens ("Hello ", "world!") and finishes immediately.
  func streamChat(
    messages: [ChatMessagePayload],
    modelName: String
  ) -> AsyncThrowingStream<String, Error> {
    AsyncThrowingStream { continuation in
      let tokens = ["Hello ", "world!"]
      for token in tokens {
        continuation.yield(token)
      }

      continuation.finish()
    }
  }

  /// Echoes the input message inside a fixed completion string.
  func singleChatCompletion(
    message: String,
    modelName: String,
    systemPrompt: String? = nil
  ) async throws -> String {
    return "Mock single completion for message: \(message)"
  }

  /// Returns a fixed title regardless of input.
  func generateTitle(
    for message: String,
    modelName: String
  ) async throws -> String {
    // Return a mock title.
    return "Mock Title"
  }
}

DraftPatch/Services/OllamaService.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import Foundation
99

1010
final class OllamaService: LLMService {
11-
@MainActor static let shared = OllamaService()
11+
static let shared = OllamaService()
1212

1313
var endpointURL = URL(string: "http://localhost:11434")!
1414
var apiKey: String? = nil

DraftPatch/Services/OpenAIService.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import Foundation
99

1010
struct OpenAIService: LLMService {
11-
@MainActor static let shared = OpenAIService()
11+
static let shared = OpenAIService()
1212

1313
let endpointURL: URL = URL(string: "https://api.openai.com/v1")!
1414
let apiKey: String? = KeychainHelper.shared.load(for: "openai_api_key") ?? ""

0 commit comments

Comments
 (0)