Skip to content

Commit 36aea5a

Browse files
committed
Fix streaming pull.
1 parent 34c683b commit 36aea5a

File tree

3 files changed

+20
-6
lines changed

3 files changed

+20
-6
lines changed

bake/async/ollama.rb

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,7 @@ def models
1717

1818
# Pulls the specified models from the Ollama API. If no models are specified, but there is a default model, it will pull that one.
1919
# @parameter models [Array(String)] The names of the models to pull.
20-
def pull(models)
21-
if models.empty?
22-
models = [Async::Ollama::Client.default_model]
23-
end
24-
20+
def pull(models: [Async::Ollama::MODEL])
2521
Async::Ollama::Client.open do |client|
2622
models.each do |model|
2723
client.pull(model) do |response|

lib/async/ollama/generate.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
module Async
1010
module Ollama
1111
# Represents a generated response from the Ollama API.
12-
class Generate < Async::REST::Representation[Wrapper]
12+
class Generate < Async::REST::Representation[GenerateWrapper]
1313
# @returns [String | nil] The generated response, or nil if not present.
1414
def response
1515
self.value[:response]

lib/async/ollama/wrapper.rb

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,24 @@ def parser_for(response)
139139
content_type = response.headers["content-type"]
140140
media_type = content_type.split(";").first
141141

142+
case media_type
143+
when APPLICATION_JSON
144+
return Async::REST::Wrapper::JSON::Parser
145+
when APPLICATION_JSON_STREAM
146+
return StreamingParser
147+
end
148+
end
149+
end
150+
151+
# Wraps generate-specific HTTP responses for the Ollama API, selecting the appropriate parser.
152+
class GenerateWrapper < Wrapper
153+
# Selects the appropriate parser for the generate HTTP response.
154+
# @parameter response [Protocol::HTTP::Response] The HTTP response object.
155+
# @returns [Class] The parser class to use.
156+
def parser_for(response)
157+
content_type = response.headers["content-type"]
158+
media_type = content_type.split(";").first
159+
142160
case media_type
143161
when APPLICATION_JSON
144162
return Async::REST::Wrapper::JSON::Parser

0 commit comments

Comments
 (0)