Commit 7e33baa
Add persistence to fetched / selected models
This has been a long time coming.
1 parent f1f14f4 commit 7e33baa

7 files changed: +252 −68 lines changed
lua/gptmodels/store.lua

Lines changed: 98 additions & 5 deletions
@@ -96,13 +96,29 @@ end
 ---@field get_job fun(self: Store): Job | nil
 ---@field clear_job fun(self: Store)
 ---@field get_models fun(self: Store, provider: Provider): string[]
----@field set_models fun(self: Store, provider: Provider, models: string[])
+---@field set_models fun(self: Store, provider: Provider, models: string[], skip_persistence?: boolean)
 ---@field get_model fun(self: Store): { provider: string, model: string }
 ---@field set_model fun(self: Store, provider: Provider, model: string)
 ---@field cycle_model_forward fun(self: Store)
 ---@field cycle_model_backward fun(self: Store)
 ---@field llm_model_strings fun(self: Store): string[]
 ---@field correct_potentially_missing_current_model fun(self: Store)
+---@field save_persisted_state fun(self: Store)
+---@field load_persisted_state fun(self: Store)
+
+-- Get the path to the persistence file
+---@return string
+local function get_persistence_file_path()
+  local data_path = vim.fn.stdpath("data")
+  return data_path .. "/gptmodels/state.json"
+end
+
+-- Get the directory for persistence files
+---@return string
+local function get_persistence_dir()
+  local data_path = vim.fn.stdpath("data")
+  return data_path .. "/gptmodels"
+end

 ---@return StrPane
 local function build_strpane()
@@ -140,6 +156,7 @@ local Store = {
   set_model = function(self, provider, model)
     self._llm_provider = provider
     self._llm_model = model
+    self:save_persisted_state()
   end,

   -- get all models for a provider
@@ -148,8 +165,11 @@ local Store = {
   end,

   -- set all models for a provider, overwriting previous values
-  set_models = function(self, provider, models)
+  set_models = function(self, provider, models, skip_persistence)
     self._llm_models[provider] = models
+    if not skip_persistence then
+      self:save_persisted_state()
+    end
   end,

   cycle_model_forward = function(self)
@@ -240,9 +260,10 @@ local Store = {
   clear = function(self)
     self.code:clear()
     self.chat:clear()
-    self:set_models("ollama", {})
-    self:set_models("openai", {})
-    -- TODO Need to clear default model as well?
+    self._llm_models.ollama = {}
+    self._llm_models.openai = {}
+    self._llm_provider = ""
+    self._llm_model = ""
   end,

   code = {
@@ -311,6 +332,78 @@ local Store = {
   clear_job = function(self)
     self._job = nil
   end,
+
+  -- Save current state to disk
+  save_persisted_state = function(self)
+    local persistence_dir = get_persistence_dir()
+    local persistence_file = get_persistence_file_path()
+
+    -- Create directory if it doesn't exist
+    vim.fn.mkdir(persistence_dir, "p")
+
+    -- Prepare data to save
+    local data = {
+      current_provider = self._llm_provider,
+      current_model = self._llm_model,
+      models = {
+        openai = self._llm_models.openai,
+        ollama = self._llm_models.ollama,
+      },
+    }
+
+    -- Write to file
+    local encoded = vim.json.encode(data)
+    local file = io.open(persistence_file, "w")
+    if file then
+      file:write(encoded)
+      file:close()
+    end
+  end,
+
+  -- Load persisted state from disk
+  load_persisted_state = function(self)
+    local persistence_file = get_persistence_file_path()
+
+    -- Check if file exists
+    if vim.fn.filereadable(persistence_file) == 0 then
+      return
+    end
+
+    -- Read and parse file
+    local file = io.open(persistence_file, "r")
+    if not file then
+      return
+    end
+
+    local content = file:read("*a")
+    file:close()
+
+    -- Handle empty file
+    if not content or content == "" then
+      return
+    end
+
+    -- Parse JSON safely
+    local ok, data = pcall(vim.json.decode, content)
+    if not ok or not data then
+      return
+    end
+
+    -- Restore state
+    if data.current_provider and data.current_model then
+      self._llm_provider = data.current_provider
+      self._llm_model = data.current_model
+    end
+
+    if data.models then
+      if data.models.openai then
+        self._llm_models.openai = data.models.openai
+      end
+      if data.models.ollama then
+        self._llm_models.ollama = data.models.ollama
+      end
+    end
+  end,
 }

 return Store
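
Note: with these functions, the persisted state is a single JSON document under Neovim's data directory (typically ~/.local/share/nvim/gptmodels/state.json on Linux). A minimal sketch of the same save/load round trip, assuming it runs inside Neovim (vim.json and vim.fn are Neovim APIs) and with illustrative model values rather than plugin defaults:

-- Sketch of the round trip; paths and table shape mirror the diff above.
local state_file = vim.fn.stdpath("data") .. "/gptmodels/state.json"

-- Save: mkdir -p the parent directory, then write one JSON document.
vim.fn.mkdir(vim.fn.fnamemodify(state_file, ":h"), "p")
local data = {
  current_provider = "ollama",          -- example values, not defaults
  current_model = "llama3.1:latest",
  models = { ollama = { "llama3.1:latest" }, openai = {} },
}
local f = assert(io.open(state_file, "w"))
f:write(vim.json.encode(data))
f:close()

-- Load: tolerate a missing or corrupt file by falling back to nil.
local ok, loaded = pcall(function()
  local rf = assert(io.open(state_file, "r"))
  local content = rf:read("*a")
  rf:close()
  return vim.json.decode(content)
end)
print(ok and loaded.current_model or "no persisted state")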

lua/gptmodels/windows/chat.lua

Lines changed: 5 additions & 1 deletion
@@ -126,6 +126,10 @@ function M.build_and_mount(selection)
   Store.chat.chat.popup = chat
   Store.chat.input.popup = input

+  -- Load any persisted state first (model selection and cached models)
+  Store:load_persisted_state()
+  com.set_window_title(chat, WINDOW_TITLE_PREFIX .. com.model_display_name())
+
   -- Fetch all models so user can work with what they have on their system
   com.trigger_models_etl(function()
     local has_buf_and_win = chat.bufnr and chat.winid
@@ -141,7 +145,7 @@ function M.build_and_mount(selection)

     -- all providers, but especially openai, can have the etl finish after a window has been closed,
     -- if it opens then closes real fast
-    com.set_window_title(chat, "Chat w/ " .. com.model_display_name())
+    com.set_window_title(chat, WINDOW_TITLE_PREFIX .. com.model_display_name())
   end)

   -- Input window is text with no syntax
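
Note: the ordering here is a cache-then-refresh pattern: render the persisted model in the title immediately, then let the async ETL overwrite it once fresh results land. A self-contained sketch of that shape, runnable inside Neovim; all names below are hypothetical stand-ins, not the plugin's API:

-- Sketch: show cached state right away, refresh it asynchronously.
local function load_cached() return "llama3.1:latest (cached)" end

local function fetch_async(cb)
  -- vim.defer_fn stands in for the real network ETL finishing later
  vim.defer_fn(function() cb("llama3.1:latest") end, 100)
end

local function render_title(model) print("Chat w/ " .. model) end

render_title(load_cached())   -- user sees the last-known model instantly
fetch_async(render_title)     -- fresh result overwrites it when it lands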

lua/gptmodels/windows/code.lua

Lines changed: 5 additions & 1 deletion
@@ -171,7 +171,11 @@ function M.build_and_mount(selection)
   Store.code.left.popup = left
   Store.code.input.popup = input

-  -- Fetch all models so user can work with what they have on their system
+  -- Load any persisted state first (model selection and cached models)
+  Store:load_persisted_state()
+  com.set_window_title(right, com.model_display_name())
+
+  -- Fetch all models so user can work with what they have access to
   com.trigger_models_etl(function()
     local has_buf_and_win = right.bufnr and right.winid
     if not has_buf_and_win then

lua/gptmodels/windows/common.lua

Lines changed: 15 additions & 13 deletions
@@ -214,27 +214,29 @@ end
 -- Triggers the fetching / saving of available models from the ollama and openai servers
 ---@param on_complete fun(): nil
 M.trigger_models_etl = function(on_complete)
+  local completed_providers = {}
+  local total_providers = 2
+
   ---@param err string | nil
   ---@param models string[] | nil
   ---@param provider Provider
   local function handle_models_fetch(err, models, provider)
     -- If there's an error fetching, assume we have no models
-    -- TODO We still need to inform the user somehow that their ollama models
-    -- fetching didn't work. Just not if we earlier detected a missing ollama
-    -- executable. Store.detected_missing_ollama_exe = true?
-    -- BUT DO WE? Maybe the models not appearing is sufficient feedback!
-    -- I think passing the err back is a good idea, because that can include
-    -- provider information
     if err or not models or #models == 0 then
-      Store:set_models(provider, {})
-      Store:correct_potentially_missing_current_model()
-      return on_complete()
+      Store:set_models(provider, {}, true) -- skip persistence during fetching
+    else
+      Store:set_models(provider, models, true) -- skip persistence during fetching
     end

-    Store:set_models(provider, models)
-    -- TODO Test that this gets called
-    Store:correct_potentially_missing_current_model()
-    on_complete()
+    -- Mark this provider as completed
+    completed_providers[provider] = true
+
+    -- Save state after ALL providers have completed
+    if vim.tbl_count(completed_providers) == total_providers then
+      -- Save the final state with all fetched models
+      Store:save_persisted_state()
+      on_complete()
+    end
   end

   -- Fetch models from ollama server
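
Note: counting completed providers like this is a small join/barrier over concurrent callbacks: persist once, after the last fetch lands, rather than once per provider. A self-contained sketch of the pattern, runnable inside Neovim (the names and fake latencies are illustrative):

-- Sketch: run N async tasks, fire `on_all_done` exactly once after the
-- last one completes. Mirrors the completed_providers counting above.
local function join(providers, fetch, on_all_done)
  local done = {}
  for _, provider in ipairs(providers) do
    fetch(provider, function()
      done[provider] = true -- keyed table makes duplicate callbacks harmless
      if vim.tbl_count(done) == #providers then
        on_all_done()
      end
    end)
  end
end

-- Usage with fake async fetches (vim.defer_fn simulates network latency):
join({ "ollama", "openai" }, function(_, cb)
  vim.defer_fn(cb, math.random(50, 150))
end, function()
  print("all providers fetched; safe to persist once")
end)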

tests/gptmodels/persistence_spec.lua

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
---@diagnostic disable: undefined-global

local assert = require("luassert")
local stub = require('luassert.stub')
local Store = require("gptmodels.store")

describe("Store | persistence", function()
  local mock_data_path = "/tmp/nvim-test-data"
  local vim_fn_stdpath_stub

  before_each(function()
    Store:clear()
    -- Mock vim.fn.stdpath to return our test data path
    vim_fn_stdpath_stub = stub(vim.fn, 'stdpath')
    vim_fn_stdpath_stub.returns(mock_data_path)
  end)

  after_each(function()
    vim_fn_stdpath_stub:revert()
    -- Clean up test files
    vim.fn.delete(mock_data_path .. "/gptmodels", "rf")
  end)

  describe("load_persisted_state", function()
    it("loads previously saved model selection", function()
      -- Setup: Save a model selection
      Store:set_model("openai", "gpt-4o")
      Store:save_persisted_state()

      -- Clear the store and load persisted state
      Store:clear()
      Store:load_persisted_state()

      local model_info = Store:get_model()
      assert.equal("openai", model_info.provider)
      assert.equal("gpt-4o", model_info.model)
    end)

    it("loads previously saved fetched models", function()
      -- Setup: Save fetched models
      Store:set_models("openai", {"gpt-4o", "gpt-4o-mini"})
      Store:set_models("ollama", {"llama3.1:latest", "deepseek-v2:latest"})
      Store:save_persisted_state()

      -- Clear the store and load persisted state
      Store:clear()
      Store:load_persisted_state()

      assert.same({"gpt-4o", "gpt-4o-mini"}, Store:get_models("openai"))
      assert.same({"llama3.1:latest", "deepseek-v2:latest"}, Store:get_models("ollama"))
    end)

    it("handles missing persistence file gracefully", function()
      -- Ensure no persistence file exists
      vim.fn.delete(mock_data_path .. "/gptmodels", "rf")

      Store:load_persisted_state()

      -- Should not crash and should have empty state
      local model_info = Store:get_model()
      assert.equal("", model_info.provider)
      assert.equal("", model_info.model)
      assert.same({}, Store:get_models("openai"))
      assert.same({}, Store:get_models("ollama"))
    end)

    it("handles corrupted persistence file gracefully", function()
      -- Create corrupted file
      vim.fn.mkdir(mock_data_path .. "/gptmodels", "p")
      local file = io.open(mock_data_path .. "/gptmodels/state.json", "w")
      file:write("invalid json {")
      file:close()

      Store:load_persisted_state()

      -- Should not crash and should have empty state
      local model_info = Store:get_model()
      assert.equal("", model_info.provider)
      assert.equal("", model_info.model)
    end)
  end)

  describe("save_persisted_state", function()
    it("creates data directory if it doesn't exist", function()
      -- Ensure directory doesn't exist
      vim.fn.delete(mock_data_path, "rf")

      Store:set_model("openai", "gpt-4o")
      Store:save_persisted_state()

      -- Verify directory and file were created
      assert.equal(1, vim.fn.isdirectory(mock_data_path .. "/gptmodels"))
      assert.equal(1, vim.fn.filereadable(mock_data_path .. "/gptmodels/state.json"))
    end)
  end)

  describe("auto persistence", function()
    it("automatically saves state when model is changed", function()
      Store:set_model("openai", "gpt-4o")

      -- Should automatically save
      local file = io.open(mock_data_path .. "/gptmodels/state.json", "r")
      assert.is_not_nil(file)
      local content = file:read("*a")
      file:close()

      local data = vim.json.decode(content)
      assert.equal("openai", data.current_provider)
      assert.equal("gpt-4o", data.current_model)
    end)

    it("automatically saves state when models are updated", function()
      Store:set_models("ollama", {"llama3.1:latest", "deepseek-v2:latest"})

      -- Should automatically save
      local file = io.open(mock_data_path .. "/gptmodels/state.json", "r")
      assert.is_not_nil(file)
      local content = file:read("*a")
      file:close()

      local data = vim.json.decode(content)
      assert.same({"llama3.1:latest", "deepseek-v2:latest"}, data.models.ollama)
    end)
  end)
end)

tests/gptmodels/spec_helpers.lua

Lines changed: 4 additions & 0 deletions
@@ -43,6 +43,10 @@ M.hook_reset_state = function()
   -- stubbing cmd.exec prevents the llm call from happening
   stub(cmd, "exec")

+  -- Mock persistence functions to prevent loading real saved data during tests
+  stub(Store, "load_persisted_state")
+  stub(Store, "save_persisted_state")
+
   Store:clear()
   snapshot = assert:snapshot()
 end)
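
Note: stubbing the two persistence methods in the shared reset hook keeps every other spec from reading or writing a real state.json. A minimal self-contained illustration of the same luassert stub mechanics (the `Widget` table is a hypothetical stand-in for Store):

local stub = require("luassert.stub")

-- Hypothetical object whose method would otherwise touch the filesystem
local Widget = {
  save = function() error("would touch the filesystem") end,
}

local s = stub(Widget, "save")  -- replace with a no-op recording stub
Widget.save()                   -- safe: the real body never runs
assert(s.calls and #s.calls == 1)
s:revert()                      -- restore the original method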
