From 0090e472a387c8c8c139a53ee76634aa756769cf Mon Sep 17 00:00:00 2001 From: samyarpotlapalli Date: Thu, 22 May 2025 09:58:55 -0400 Subject: [PATCH 1/8] model versioning files --- src/sasctl/_services/model_repository.py | 134 ++++++++++- tests/unit/test_model_repository.py | 293 +++++++++++++++++++++++ 2 files changed, 421 insertions(+), 6 deletions(-) diff --git a/src/sasctl/_services/model_repository.py b/src/sasctl/_services/model_repository.py index dfbbb95d..36bbb26f 100644 --- a/src/sasctl/_services/model_repository.py +++ b/src/sasctl/_services/model_repository.py @@ -8,8 +8,10 @@ import datetime from warnings import warn +import requests +from requests.exceptions import HTTPError -from ..core import HTTPError, current_session, delete, get, sasctl_command +from ..core import current_session, delete, get, sasctl_command from .service import Service FUNCTIONS = { @@ -612,14 +614,134 @@ def list_model_versions(cls, model): Returns ------- - list + RestObj """ - model = cls.get_model(model) - if cls.get_model_link(model, "modelVersions") is None: - raise ValueError("Unable to retrieve versions for model '%s'" % model) - return cls.request_link(model, "modelVersions") + link = cls.get_model_link(model, "modelHistory") + if link is None: + raise ValueError( + "Cannot find link for version history for model '%s'" % model + ) + + modelHistory = cls.request_link( + link, + "modelHistory", + headers={"Accept": "application/vnd.sas.models.model.version"}, + ) + if modelHistory is None: + return {} + + return modelHistory + + @classmethod + def get_model_version(cls, model, version_id): + + model_history = cls.list_model_versions(model) + model_history_items = model_history.get("items") + + for i, item in enumerate(model_history_items): + if item.get("id") == version_id: + return cls.request_link( + item, + "self", + headers={"Accept": "application/vnd.sas.models.model.version"}, + ) + + raise ValueError("The version id specified could not be found.") + + @classmethod + def get_model_with_versions(cls, model): + if cls.is_uuid(model): + model_id = model + elif isinstance(model, dict) and "id" in model: + model_id = model["id"] + else: + model = cls.get_model(model) + if not model: + raise HTTPError( + "This model may not exist in a project or the model may not exist at all." + ) + model_id = model["id"] + + versions_uri = f"/modelRepository/models/{model_id}/versions" + version_history = cls.get( + versions_uri, + headers={"Accept": "application/vnd.sas.models.model.version"}, + ) + if version_history is None: + return {} + return version_history + + @classmethod + def get_model_or_version(cls, model, version_id): + + if cls.is_uuid(model): + model_id = model + elif isinstance(model, dict) and "id" in model: + model_id = model["id"] + else: + model = cls.get_model(model) + if not model: + raise HTTPError( + "This model may not exist in a project or the model may not exist at all." + ) + model_id = model["id"] + + if model_id == version_id: + return cls.get_model(model) + + version_history = cls.get_model_with_versions(model) + model_versions = version_history.get("modelVersions") + for i, item in enumerate(model_versions): + if item.get("id") == version_id: + return cls.request_link( + item, + "self", + headers={"Accept": "application/vnd.sas.models.model.version"}, + ) + + raise ValueError("The version id specified could not be found.") + + @classmethod + def get_model_version_contents(cls, model, version_id): + model_version = cls.get_model_version(model, version_id) + version_contents = cls.request_link( + model_version, + "contents", + headers={"Accept": "application/vnd.sas.models.model.content"}, + ) + + if version_contents is None: + return {} + return version_contents + + @classmethod + def get_model_version_content_metadata(cls, model, version_id, content_id): + model_version_contents = cls.get_model_version_contents(model, version_id) + + model_version_contents_items = model_version_contents.get("items") + for i, item in enumerate(model_version_contents_items): + if item.get("id") == content_id: + return cls.request_link( + item, + "self", + headers={"Accept": "application/vnd.sas.models.model.content"}, + ) + + raise ValueError("The content id specified could not be found.") + + @classmethod + def get_model_version_content(cls, model, version_id, content_id): + + metadata = cls.get_model_version_content_metadata(model, version_id, content_id) + version_content_file = cls.request_link( + metadata, "content", headers={"Accept": "text/plain"} + ) + + if version_content_file is None: + raise HTTPError("Something went wrong while accessing the metadata file.") + return version_content_file @classmethod def copy_analytic_store(cls, model): diff --git a/tests/unit/test_model_repository.py b/tests/unit/test_model_repository.py index 9232896b..b3a2aa3e 100644 --- a/tests/unit/test_model_repository.py +++ b/tests/unit/test_model_repository.py @@ -13,6 +13,9 @@ from sasctl import current_session from sasctl.services import model_repository as mr +from sasctl.core import RestObj, VersionInfo, request +from requests import HTTPError + def test_create_model(): MODEL_NAME = "Test Model" @@ -230,3 +233,293 @@ def test_add_model_content(): assert post.call_args[1]["files"] == { "files": ("test.pkl", binary_data, "application/image") } + + +def test_list_model_versions(): + with mock.patch( + "sasctl._services.model_repository.ModelRepository.get_model_link" + ) as get_model_link: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.request_link" + ) as request_link: + + get_model_link.return_value = None + with pytest.raises(ValueError): + mr.list_model_versions( + model="12345", + ) + + # Successfully returns an empty list of model versions + get_model_link_mock = { + "method": "GET", + "rel": "modelHistory", + "href": "/modelRepository/models/12345/history", + "uri": "/modelRepository/models/12345/history", + "type": "application/vnd.sas.collection", + "responseItemType": "application/vnd.sas.models.model.version", + } + get_model_link.return_value = get_model_link_mock + request_link.return_value = None + response = mr.list_model_versions(model="12345") + assert response == {} + + request_link.return_value = RestObj + response = mr.list_model_versions(model="12345") + assert response == RestObj + + +def test_get_model_version(): + with mock.patch( + "sasctl._services.model_repository.ModelRepository.list_model_versions" + ) as list_model_versions: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.request_link" + ) as request_link: + + list_model_versions_mock = { + "count": 2, + "items": [ + { + "id": "123", + "links": [ + { + "method": "GET", + "rel": "self", + "href": "/modelRepository/models/abc/history/123", + "uri": "/modelRepository/models/abc/history/123", + "type": "demo", + } + ], + }, + {"id": "345", "links": []}, + ], + } + + # originally wrapped as a list, so this test works assuming dict/restobj returning is allowed + + list_model_versions.return_value = list_model_versions_mock + + with pytest.raises(ValueError): + mr.get_model_version(model="000", version_id="000") + + response = mr.get_model_version(model="000", version_id="123") + request_link.assert_called_once_with( + list_model_versions_mock.get("items")[0], + "self", + headers={"Accept": "application/vnd.sas.models.model.version"}, + ) + + +def test_get_model_with_versions(): + with mock.patch( + "sasctl._services.model_repository.ModelRepository.is_uuid" + ) as is_uuid: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.get_model" + ) as get_model: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.get" + ) as get: + + is_uuid.return_value = True + response = mr.get_model_with_versions(model="12345") + assert response + + is_uuid.return_value = False + get_model.return_value = None + response = mr.get_model_with_versions(model={"id": "12345"}) + assert response + + is_uuid.return_value = False + get_model.return_value = None + with pytest.raises(HTTPError): + mr.get_model_with_versions(model=RestObj) + + is_uuid.return_value = False + get_model.return_value = RestObj + get.return_value = RestObj + response = mr.get_model_with_versions(model=RestObj) + assert response != {} + + get.return_value = None + response = mr.get_model_with_versions(model=RestObj) + assert response == {} + + assert get.call_count == 4 + get.assert_any_call( + "/modelRepository/models/sasctl.core.RestObj['id']/versions", + headers={"Accept": "application/vnd.sas.models.model.version"}, + ) + + get.assert_any_call( + "/modelRepository/models/12345/versions", + headers={"Accept": "application/vnd.sas.models.model.version"}, + ) + + # add test case for returning empty dictionary + + +def test_get_model_or_version(): + with mock.patch( + "sasctl._services.model_repository.ModelRepository.is_uuid" + ) as is_uuid: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.get_model" + ) as get_model: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.get_model_with_versions" + ) as get_model_with_versions: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.request_link" + ) as request_link: + + is_uuid.return_value = True + get_model.return_value = RestObj + response = mr.get_model_or_version( + model="12345", version_id="12345" + ) + assert response + + is_uuid.return_value = False + get_model.return_value = RestObj + response = mr.get_model_or_version( + model={"id": "12345"}, version_id="12345" + ) + assert response + + is_uuid.return_value = False + get_model.return_value = None + with pytest.raises(HTTPError): + mr.get_model_or_version(model=RestObj, version_id="12345") + + is_uuid.return_value = False + get_model_mock = {"id": "12345"} + get_model.side_effect = [ + get_model_mock, + RestObj, + ] + response = mr.get_model_or_version( + model=RestObj, version_id="12345" + ) + assert response + + assert get_model.call_count == 5 + + get_model_with_versions_mock = { + "count": 2, + "items": [{}], + "modelVersions": [ + { + "id": "123", + "links": [ + { + "method": "GET", + "rel": "self", + "href": "/modelRepository/models/abc/history/123", + "uri": "/modelRepository/models/abc/history/123", + "type": "demo", + } + ], + }, + {"id": "345", "links": []}, + ], + } + + is_uuid.return_value = True + get_model_with_versions.return_value = get_model_with_versions_mock + with pytest.raises(ValueError): + mr.get_model_or_version(model="012", version_id="000") + + response = mr.get_model_or_version(model="012", version_id="123") + assert response + request_link.assert_called_once_with( + get_model_with_versions_mock.get("modelVersions")[0], + "self", + headers={"Accept": "application/vnd.sas.models.model.version"}, + ) + + +def test_get_model_version_contents(): + with mock.patch( + "sasctl._services.model_repository.ModelRepository.get_model_version" + ) as get_model_version: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.request_link" + ) as request_link: + + get_model_version.return_value = RestObj + request_link.return_value = None + response = mr.get_model_version_contents(model="12345", version_id="345") + assert response == {} + + request_link.return_value = RestObj + response = mr.get_model_version_contents(model="12345", version_id="345") + assert response == RestObj + + +def test_get_model_version_content_metadata(): + with mock.patch( + "sasctl._services.model_repository.ModelRepository.get_model_version_contents" + ) as get_model_version_contents: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.request_link" + ) as request_link: + + model_version_contents_mock = { + "count": 2, + "items": [ + { + "id": "345", + "links": [ + { + "method": "GET", + "rel": "self", + "href": "/modelRepository/models/abc/history/123/contents/345", + "uri": "/modelRepository/models/abc/history/123/contents/345", + "type": "demo", + } + ], + }, + {"id": "567", "links": []}, + ], + } + + get_model_version_contents.return_value = model_version_contents_mock + with pytest.raises(ValueError): + response = mr.get_model_version_content_metadata( + model="abc", version_id="123", content_id="000" + ) + + response = mr.get_model_version_content_metadata( + model="abc", version_id="123", content_id="345" + ) + assert response + request_link.assert_called_once_with( + model_version_contents_mock.get("items")[0], + "self", + headers={"Accept": "application/vnd.sas.models.model.content"}, + ) + + +def test_get_model_version_content(): + with mock.patch( + "sasctl._services.model_repository.ModelRepository.get_model_version_content_metadata" + ) as get_model_version_content_metadata: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.request_link" + ) as request_link: + + get_model_version_content_metadata.return_value = RestObj + request_link.return_value = None + with pytest.raises(HTTPError): + mr.get_model_version_content( + model="abc", version_id="123", content_id="345" + ) + + request_link.return_value = RestObj + response = mr.get_model_version_content( + model="abc", version_id="123", content_id="345" + ) + assert response + + assert request_link.call_count == 2 From 742f730bf71aa8c6b1a354416597649dc14ba401 Mon Sep 17 00:00:00 2001 From: samyarpotlapalli Date: Thu, 22 May 2025 17:18:24 -0400 Subject: [PATCH 2/8] edited list_model_versions --- src/sasctl/_services/model_repository.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/src/sasctl/_services/model_repository.py b/src/sasctl/_services/model_repository.py index 36bbb26f..b58fef40 100644 --- a/src/sasctl/_services/model_repository.py +++ b/src/sasctl/_services/model_repository.py @@ -10,6 +10,7 @@ from warnings import warn import requests from requests.exceptions import HTTPError +import traceback from ..core import current_session, delete, get, sasctl_command from .service import Service @@ -624,24 +625,31 @@ def list_model_versions(cls, model): "Cannot find link for version history for model '%s'" % model ) - modelHistory = cls.request_link( + + modelHistory = cls.request_link( link, "modelHistory", - headers={"Accept": "application/vnd.sas.models.model.version"}, + headers={"Accept": "application/vnd.sas.collection+json"}, ) - if modelHistory is None: - return {} return modelHistory @classmethod - def get_model_version(cls, model, version_id): + def get_model_version(cls, model, version_id): #check if this now handles a return 1 case model_history = cls.list_model_versions(model) - model_history_items = model_history.get("items") - for i, item in enumerate(model_history_items): - if item.get("id") == version_id: + for item in model_history: + if isinstance(item, str): + if item == 'id' and dict(model_history)[item] == version_id: + return cls.request_link( + model_history, + "self", + headers={"Accept": "application/vnd.sas.models.model.version"}, + ) + continue + + if item["id"] == version_id: return cls.request_link( item, "self", From 8af5b8fe52d3ec4badf5c4e937c7ba89de01e14a Mon Sep 17 00:00:00 2001 From: samyarpotlapalli Date: Fri, 30 May 2025 09:37:47 -0400 Subject: [PATCH 3/8] passing tests and example notebook --- src/sasctl/_services/model_repository.py | 175 ++++++++--- tests/unit/test_model_repository.py | 352 +++++++++++++---------- 2 files changed, 329 insertions(+), 198 deletions(-) diff --git a/src/sasctl/_services/model_repository.py b/src/sasctl/_services/model_repository.py index b58fef40..663fce0e 100644 --- a/src/sasctl/_services/model_repository.py +++ b/src/sasctl/_services/model_repository.py @@ -10,11 +10,15 @@ from warnings import warn import requests from requests.exceptions import HTTPError -import traceback +import urllib -from ..core import current_session, delete, get, sasctl_command +# import traceback +# import sys + +from ..core import current_session, delete, get, sasctl_command, RestObj from .service import Service + FUNCTIONS = { "Analytical", "Classification", @@ -615,7 +619,7 @@ def list_model_versions(cls, model): Returns ------- - RestObj + list """ @@ -625,41 +629,60 @@ def list_model_versions(cls, model): "Cannot find link for version history for model '%s'" % model ) - - modelHistory = cls.request_link( + modelHistory = cls.request_link( link, "modelHistory", headers={"Accept": "application/vnd.sas.collection+json"}, ) + if isinstance(modelHistory, RestObj): + return [modelHistory] return modelHistory @classmethod - def get_model_version(cls, model, version_id): #check if this now handles a return 1 case + def get_model_version(cls, model, version_id): + """Get a specific version of a model. + + Parameters + ---------- + model : str or dict + The name, id, or dictionary representation of a model. + version_id: str + The id of a model version. + + Returns + ------- + RestObj + + """ model_history = cls.list_model_versions(model) for item in model_history: - if isinstance(item, str): - if item == 'id' and dict(model_history)[item] == version_id: - return cls.request_link( - model_history, - "self", - headers={"Accept": "application/vnd.sas.models.model.version"}, - ) - continue - if item["id"] == version_id: return cls.request_link( item, "self", - headers={"Accept": "application/vnd.sas.models.model.version"}, + headers={"Accept": "application/vnd.sas.models.model.version+json"}, ) raise ValueError("The version id specified could not be found.") @classmethod def get_model_with_versions(cls, model): + """Get the current model with its version history. + + Parameters + ---------- + model : str or dict + The name, id, or dictionary representation of a model. + + Returns + ------- + list + + """ + if cls.is_uuid(model): model_id = model elif isinstance(model, dict) and "id" in model: @@ -672,75 +695,130 @@ def get_model_with_versions(cls, model): ) model_id = model["id"] - versions_uri = f"/modelRepository/models/{model_id}/versions" - version_history = cls.get( - versions_uri, - headers={"Accept": "application/vnd.sas.models.model.version"}, - ) - if version_history is None: - return {} + versions_uri = f"/models/{model_id}/versions" + try: + version_history = cls.request( + "GET", + versions_uri, + headers={"Accept": "application/vnd.sas.collection+json"}, + ) + except urllib.error.HTTPError as e: + raise HTTPError( + f"Request failed: Model id may be referencing a non-existing model." + ) from None + + if isinstance(version_history, RestObj): + return [version_history] + return version_history @classmethod def get_model_or_version(cls, model, version_id): + """Get a specific version of a model but if model id and version id are the same, the current model is returned. - if cls.is_uuid(model): - model_id = model - elif isinstance(model, dict) and "id" in model: - model_id = model["id"] - else: - model = cls.get_model(model) - if not model: - raise HTTPError( - "This model may not exist in a project or the model may not exist at all." - ) - model_id = model["id"] + Parameters + ---------- + model : str or dict + The name, id, or dictionary representation of a model. + version_id: str + The id of a model version. + + Returns + ------- + RestObj - if model_id == version_id: - return cls.get_model(model) + """ version_history = cls.get_model_with_versions(model) - model_versions = version_history.get("modelVersions") - for i, item in enumerate(model_versions): - if item.get("id") == version_id: + + for item in version_history: + if item["id"] == version_id: return cls.request_link( item, "self", - headers={"Accept": "application/vnd.sas.models.model.version"}, + headers={ + "Accept": "application/vnd.sas.models.model.version+json, application/vnd.sas.models.model+json" + }, ) raise ValueError("The version id specified could not be found.") @classmethod def get_model_version_contents(cls, model, version_id): + """Get the contents of a model version. + + Parameters + ---------- + model : str or dict + The name, id, or dictionary representation of a model. + version_id: str + The id of a model version. + + Returns + ------- + list + + """ model_version = cls.get_model_version(model, version_id) version_contents = cls.request_link( model_version, "contents", - headers={"Accept": "application/vnd.sas.models.model.content"}, + headers={"Accept": "application/vnd.sas.collection+json"}, ) - if version_contents is None: - return {} + if isinstance(version_contents, RestObj): + return [version_contents] + return version_contents @classmethod def get_model_version_content_metadata(cls, model, version_id, content_id): + """Get the content metadata header information for a model version. + + Parameters + ---------- + model : str or dict + The name, id, or dictionary representation of a model. + version_id: str + The id of a model version. + content_id: str + The id of the content file. + + Returns + ------- + RestObj + + """ model_version_contents = cls.get_model_version_contents(model, version_id) - model_version_contents_items = model_version_contents.get("items") - for i, item in enumerate(model_version_contents_items): - if item.get("id") == content_id: + for item in model_version_contents: + if item["id"] == content_id: return cls.request_link( item, "self", - headers={"Accept": "application/vnd.sas.models.model.content"}, + headers={"Accept": "application/vnd.sas.models.model.content+json"}, ) raise ValueError("The content id specified could not be found.") @classmethod def get_model_version_content(cls, model, version_id, content_id): + """Get the specific content inside the content file for a model version. + + Parameters + ---------- + model : str or dict + The name, id, or dictionary representation of a model. + version_id: str + The id of a model version. + content_id: str + The id of the specific content file. + + Returns + ------- + list + + """ metadata = cls.get_model_version_content_metadata(model, version_id, content_id) version_content_file = cls.request_link( @@ -749,6 +827,9 @@ def get_model_version_content(cls, model, version_id, content_id): if version_content_file is None: raise HTTPError("Something went wrong while accessing the metadata file.") + + if isinstance(version_content_file, RestObj): + return [version_content_file] return version_content_file @classmethod diff --git a/tests/unit/test_model_repository.py b/tests/unit/test_model_repository.py index b3a2aa3e..5e8ac154 100644 --- a/tests/unit/test_model_repository.py +++ b/tests/unit/test_model_repository.py @@ -15,6 +15,7 @@ from sasctl.core import RestObj, VersionInfo, request from requests import HTTPError +import urllib.error def test_create_model(): @@ -235,6 +236,53 @@ def test_add_model_content(): } +def test_create_model_version(): + model_mock = {"id": 12345} + new_model_mock = {"id": 34567} + with mock.patch( + "sasctl._services.model_repository.ModelRepository.get_model", + side_effect=[ + model_mock, + model_mock, + new_model_mock, + model_mock, + new_model_mock, + ], + ) as get_model: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.get_model_link" + ) as get_model_link: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.request_link" + ) as request_link: + get_model_link_mock = { + "method": "GET", + "rel": "modelHistory", + "href": "/modelRepository/models/12345/history", + "uri": "/modelRepository/models/12345/history", + "type": "application/vnd.sas.collection", + "responseItemType": "application/vnd.sas.models.model.version", + } + + get_model_link.return_value = None + with pytest.raises(ValueError): + mr.create_model_version(model=model_mock, minor=False) + + get_model_link.return_value = get_model_link_mock + response = mr.create_model_version(model=model_mock, minor=False) + + request_link.assert_called_with( + model_mock, "addModelVersion", json={"option": "major"} + ) + assert response == new_model_mock + + response = mr.create_model_version(model=model_mock, minor=True) + request_link.assert_called_with( + model_mock, "addModelVersion", json={"option": "minor"} + ) + assert response == new_model_mock + + def test_list_model_versions(): with mock.patch( "sasctl._services.model_repository.ModelRepository.get_model_link" @@ -249,7 +297,6 @@ def test_list_model_versions(): model="12345", ) - # Successfully returns an empty list of model versions get_model_link_mock = { "method": "GET", "rel": "modelHistory", @@ -258,14 +305,22 @@ def test_list_model_versions(): "type": "application/vnd.sas.collection", "responseItemType": "application/vnd.sas.models.model.version", } + get_model_link.return_value = get_model_link_mock - request_link.return_value = None + + response = mr.list_model_versions(model="12345") + assert response + + request_link.return_value = RestObj({"id": "12345"}) response = mr.list_model_versions(model="12345") - assert response == {} + assert isinstance(response, list) - request_link.return_value = RestObj + request_link.return_value = [ + RestObj({"id": "12345"}), + RestObj({"id": "3456"}), + ] response = mr.list_model_versions(model="12345") - assert response == RestObj + assert isinstance(response, list) def test_get_model_version(): @@ -276,26 +331,21 @@ def test_get_model_version(): "sasctl._services.model_repository.ModelRepository.request_link" ) as request_link: - list_model_versions_mock = { - "count": 2, - "items": [ - { - "id": "123", - "links": [ - { - "method": "GET", - "rel": "self", - "href": "/modelRepository/models/abc/history/123", - "uri": "/modelRepository/models/abc/history/123", - "type": "demo", - } - ], - }, - {"id": "345", "links": []}, - ], - } - - # originally wrapped as a list, so this test works assuming dict/restobj returning is allowed + list_model_versions_mock = [ + { + "id": "123", + "links": [ + { + "method": "GET", + "rel": "self", + "href": "/modelRepository/models/abc/history/123", + "uri": "/modelRepository/models/abc/history/123", + "type": "demo", + } + ], + }, + {"id": "345", "links": []}, + ] list_model_versions.return_value = list_model_versions_mock @@ -304,9 +354,9 @@ def test_get_model_version(): response = mr.get_model_version(model="000", version_id="123") request_link.assert_called_once_with( - list_model_versions_mock.get("items")[0], + list_model_versions_mock[0], "self", - headers={"Accept": "application/vnd.sas.models.model.version"}, + headers={"Accept": "application/vnd.sas.models.model.version+json"}, ) @@ -318,8 +368,8 @@ def test_get_model_with_versions(): "sasctl._services.model_repository.ModelRepository.get_model" ) as get_model: with mock.patch( - "sasctl._services.model_repository.ModelRepository.get" - ) as get: + "sasctl._services.model_repository.ModelRepository.request" + ) as request: is_uuid.return_value = True response = mr.get_model_with_versions(model="12345") @@ -337,106 +387,81 @@ def test_get_model_with_versions(): is_uuid.return_value = False get_model.return_value = RestObj - get.return_value = RestObj + request.side_effect = urllib.error.HTTPError( + url="http://demo.sas.com", + code=404, + msg="Not Found", + hdrs=None, + fp=None, + ) + with pytest.raises(HTTPError): + mr.get_model_with_versions(model=RestObj) + + request.side_effect = None + request.return_value = RestObj({"id": "12345"}) response = mr.get_model_with_versions(model=RestObj) - assert response != {} + assert isinstance(response, list) - get.return_value = None + request.return_value = [ + RestObj({"id": "12345"}), + RestObj({"id": "3456"}), + ] response = mr.get_model_with_versions(model=RestObj) - assert response == {} + assert isinstance(response, list) - assert get.call_count == 4 - get.assert_any_call( - "/modelRepository/models/sasctl.core.RestObj['id']/versions", - headers={"Accept": "application/vnd.sas.models.model.version"}, + request.assert_any_call( + "GET", + "/models/sasctl.core.RestObj['id']/versions", + headers={"Accept": "application/vnd.sas.collection+json"}, ) - get.assert_any_call( - "/modelRepository/models/12345/versions", - headers={"Accept": "application/vnd.sas.models.model.version"}, + request.assert_any_call( + "GET", + "/models/12345/versions", + headers={"Accept": "application/vnd.sas.collection+json"}, ) - # add test case for returning empty dictionary - def test_get_model_or_version(): with mock.patch( - "sasctl._services.model_repository.ModelRepository.is_uuid" - ) as is_uuid: + "sasctl._services.model_repository.ModelRepository.get_model_with_versions" + ) as get_model_with_versions: with mock.patch( - "sasctl._services.model_repository.ModelRepository.get_model" - ) as get_model: - with mock.patch( - "sasctl._services.model_repository.ModelRepository.get_model_with_versions" - ) as get_model_with_versions: - with mock.patch( - "sasctl._services.model_repository.ModelRepository.request_link" - ) as request_link: - - is_uuid.return_value = True - get_model.return_value = RestObj - response = mr.get_model_or_version( - model="12345", version_id="12345" - ) - assert response - - is_uuid.return_value = False - get_model.return_value = RestObj - response = mr.get_model_or_version( - model={"id": "12345"}, version_id="12345" - ) - assert response - - is_uuid.return_value = False - get_model.return_value = None - with pytest.raises(HTTPError): - mr.get_model_or_version(model=RestObj, version_id="12345") - - is_uuid.return_value = False - get_model_mock = {"id": "12345"} - get_model.side_effect = [ - get_model_mock, - RestObj, - ] - response = mr.get_model_or_version( - model=RestObj, version_id="12345" - ) - assert response - - assert get_model.call_count == 5 - - get_model_with_versions_mock = { - "count": 2, - "items": [{}], - "modelVersions": [ - { - "id": "123", - "links": [ - { - "method": "GET", - "rel": "self", - "href": "/modelRepository/models/abc/history/123", - "uri": "/modelRepository/models/abc/history/123", - "type": "demo", - } - ], - }, - {"id": "345", "links": []}, - ], - } - - is_uuid.return_value = True - get_model_with_versions.return_value = get_model_with_versions_mock - with pytest.raises(ValueError): - mr.get_model_or_version(model="012", version_id="000") - - response = mr.get_model_or_version(model="012", version_id="123") - assert response - request_link.assert_called_once_with( - get_model_with_versions_mock.get("modelVersions")[0], - "self", - headers={"Accept": "application/vnd.sas.models.model.version"}, - ) + "sasctl._services.model_repository.ModelRepository.request_link" + ) as request_link: + + get_model_with_versions_mock = [ + { + "id": "123", + "links": [ + { + "method": "GET", + "rel": "self", + "href": "/modelRepository/models/abc/history/123", + "uri": "/modelRepository/models/abc/history/123", + "type": "demo", + } + ], + }, + {"id": "345", "links": []}, + ] + + get_model_with_versions.return_value = [] + with pytest.raises(ValueError): + mr.get_model_or_version(model="000", version_id="000") + + get_model_with_versions.return_value = get_model_with_versions_mock + with pytest.raises(ValueError): + mr.get_model_or_version(model="000", version_id="000") + + response = mr.get_model_or_version(model="000", version_id="123") + request_link.assert_called_once_with( + get_model_with_versions_mock[0], + "self", + headers={ + "Accept": "application/vnd.sas.models.model.version+json, application/vnd.sas.models.model+json" + }, + ) def test_get_model_version_contents(): @@ -447,14 +472,23 @@ def test_get_model_version_contents(): "sasctl._services.model_repository.ModelRepository.request_link" ) as request_link: - get_model_version.return_value = RestObj - request_link.return_value = None - response = mr.get_model_version_contents(model="12345", version_id="345") - assert response == {} - - request_link.return_value = RestObj - response = mr.get_model_version_contents(model="12345", version_id="345") - assert response == RestObj + get_model_version.return_value = {"id": "000"} + request_link.return_value = RestObj({"id": "12345"}) + response = mr.get_model_version_contents(model="12345", version_id="3456") + assert isinstance(response, list) + + request_link.return_value = [ + RestObj({"id": "12345"}), + RestObj({"id": "3456"}), + ] + response = mr.get_model_version_contents(model="12345", version_id="3456") + assert isinstance(response, list) + + request_link.assert_any_call( + {"id": "000"}, + "contents", + headers={"Accept": "application/vnd.sas.collection+json"}, + ) def test_get_model_version_content_metadata(): @@ -465,28 +499,31 @@ def test_get_model_version_content_metadata(): "sasctl._services.model_repository.ModelRepository.request_link" ) as request_link: - model_version_contents_mock = { - "count": 2, - "items": [ - { - "id": "345", - "links": [ - { - "method": "GET", - "rel": "self", - "href": "/modelRepository/models/abc/history/123/contents/345", - "uri": "/modelRepository/models/abc/history/123/contents/345", - "type": "demo", - } - ], - }, - {"id": "567", "links": []}, - ], - } + get_model_with_metadata_mock = [ + { + "id": "123", + "links": [ + { + "method": "GET", + "rel": "self", + "href": "/modelRepository/models/abc/history/123", + "uri": "/modelRepository/models/abc/history/123", + "type": "demo", + } + ], + }, + {"id": "345", "links": []}, + ] + + get_model_version_contents.return_value = [] + with pytest.raises(ValueError): + mr.get_model_version_content_metadata( + model="000", version_id="123", content_id="000" + ) - get_model_version_contents.return_value = model_version_contents_mock + get_model_version_contents.return_value = get_model_with_metadata_mock with pytest.raises(ValueError): - response = mr.get_model_version_content_metadata( + mr.get_model_version_content_metadata( model="abc", version_id="123", content_id="000" ) @@ -495,9 +532,9 @@ def test_get_model_version_content_metadata(): ) assert response request_link.assert_called_once_with( - model_version_contents_mock.get("items")[0], + get_model_with_metadata_mock[1], "self", - headers={"Accept": "application/vnd.sas.models.model.content"}, + headers={"Accept": "application/vnd.sas.models.model.content+json"}, ) @@ -509,17 +546,30 @@ def test_get_model_version_content(): "sasctl._services.model_repository.ModelRepository.request_link" ) as request_link: - get_model_version_content_metadata.return_value = RestObj + get_model_version_content_metadata.return_value = {"id": 000} request_link.return_value = None with pytest.raises(HTTPError): mr.get_model_version_content( model="abc", version_id="123", content_id="345" ) - request_link.return_value = RestObj + request_link.return_value = RestObj({"id": "12345"}) response = mr.get_model_version_content( model="abc", version_id="123", content_id="345" ) - assert response + assert isinstance(response, list) + + request_link.return_value = [ + RestObj({"id": "12345"}), + RestObj({"id": "3456"}), + ] + response = mr.get_model_version_content( + model="abc", version_id="123", content_id="345" + ) + assert isinstance(response, list) - assert request_link.call_count == 2 + request_link.assert_any_call( + {"id": 000}, + "content", + headers={"Accept": "text/plain"}, + ) From 0d5fd79b2ac36e836bc152795bcda53aebada6d1 Mon Sep 17 00:00:00 2001 From: djm21 Date: Fri, 30 May 2025 10:33:10 -0700 Subject: [PATCH 4/8] update list_model_versions to accomodate viya 3.5 --- src/sasctl/_services/model_repository.py | 33 ++++++++++++++---------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/src/sasctl/_services/model_repository.py b/src/sasctl/_services/model_repository.py index 663fce0e..74fb6446 100644 --- a/src/sasctl/_services/model_repository.py +++ b/src/sasctl/_services/model_repository.py @@ -623,21 +623,28 @@ def list_model_versions(cls, model): """ - link = cls.get_model_link(model, "modelHistory") - if link is None: - raise ValueError( - "Cannot find link for version history for model '%s'" % model - ) + if current_session().version_info() < 4: + model = cls.get_model(model) + if cls.get_model_link(model, "modelVersions") is None: + raise ValueError("Unable to retrieve versions for model '%s'" % model) - modelHistory = cls.request_link( - link, - "modelHistory", - headers={"Accept": "application/vnd.sas.collection+json"}, - ) + return cls.request_link(model, "modelVersions") + else: + link = cls.get_model_link(model, "modelHistory") + if link is None: + raise ValueError( + "Cannot find link for version history for model '%s'" % model + ) + + modelHistory = cls.request_link( + link, + "modelHistory", + headers={"Accept": "application/vnd.sas.collection+json"}, + ) - if isinstance(modelHistory, RestObj): - return [modelHistory] - return modelHistory + if isinstance(modelHistory, RestObj): + return [modelHistory] + return modelHistory @classmethod def get_model_version(cls, model, version_id): From f61eb04082e6fda8ccfdb0c03ba920f2fd33e185 Mon Sep 17 00:00:00 2001 From: djm21 Date: Fri, 30 May 2025 10:45:10 -0700 Subject: [PATCH 5/8] update model_repository unit tests to accomodate version checking --- tests/unit/test_model_repository.py | 55 +++++++++++++++-------------- 1 file changed, 29 insertions(+), 26 deletions(-) diff --git a/tests/unit/test_model_repository.py b/tests/unit/test_model_repository.py index 5e8ac154..9b199b9d 100644 --- a/tests/unit/test_model_repository.py +++ b/tests/unit/test_model_repository.py @@ -290,37 +290,40 @@ def test_list_model_versions(): with mock.patch( "sasctl._services.model_repository.ModelRepository.request_link" ) as request_link: + with mock.patch( + "sasctl.core.Session.version_info" + ) as version: + version.return_value = VersionInfo(4) + get_model_link.return_value = None + with pytest.raises(ValueError): + mr.list_model_versions( + model="12345", + ) - get_model_link.return_value = None - with pytest.raises(ValueError): - mr.list_model_versions( - model="12345", - ) - - get_model_link_mock = { - "method": "GET", - "rel": "modelHistory", - "href": "/modelRepository/models/12345/history", - "uri": "/modelRepository/models/12345/history", - "type": "application/vnd.sas.collection", - "responseItemType": "application/vnd.sas.models.model.version", - } + get_model_link_mock = { + "method": "GET", + "rel": "modelHistory", + "href": "/modelRepository/models/12345/history", + "uri": "/modelRepository/models/12345/history", + "type": "application/vnd.sas.collection", + "responseItemType": "application/vnd.sas.models.model.version", + } - get_model_link.return_value = get_model_link_mock + get_model_link.return_value = get_model_link_mock - response = mr.list_model_versions(model="12345") - assert response + response = mr.list_model_versions(model="12345") + assert response - request_link.return_value = RestObj({"id": "12345"}) - response = mr.list_model_versions(model="12345") - assert isinstance(response, list) + request_link.return_value = RestObj({"id": "12345"}) + response = mr.list_model_versions(model="12345") + assert isinstance(response, list) - request_link.return_value = [ - RestObj({"id": "12345"}), - RestObj({"id": "3456"}), - ] - response = mr.list_model_versions(model="12345") - assert isinstance(response, list) + request_link.return_value = [ + RestObj({"id": "12345"}), + RestObj({"id": "3456"}), + ] + response = mr.list_model_versions(model="12345") + assert isinstance(response, list) def test_get_model_version(): From 97e021024100fa80708c8faec0d9e149ee3c7c63 Mon Sep 17 00:00:00 2001 From: djm21 Date: Fri, 30 May 2025 10:49:14 -0700 Subject: [PATCH 6/8] black reformatting --- tests/unit/test_model_repository.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/unit/test_model_repository.py b/tests/unit/test_model_repository.py index 9b199b9d..9bab3f9f 100644 --- a/tests/unit/test_model_repository.py +++ b/tests/unit/test_model_repository.py @@ -290,9 +290,7 @@ def test_list_model_versions(): with mock.patch( "sasctl._services.model_repository.ModelRepository.request_link" ) as request_link: - with mock.patch( - "sasctl.core.Session.version_info" - ) as version: + with mock.patch("sasctl.core.Session.version_info") as version: version.return_value = VersionInfo(4) get_model_link.return_value = None with pytest.raises(ValueError): From bcecc0e5667aa321937dbe042ad2fe76a32389f0 Mon Sep 17 00:00:00 2001 From: djm21 Date: Fri, 30 May 2025 12:58:05 -0700 Subject: [PATCH 7/8] update model_repository test to fix python 3.8 tests --- tests/unit/test_model_repository.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_model_repository.py b/tests/unit/test_model_repository.py index 9bab3f9f..c64c8d58 100644 --- a/tests/unit/test_model_repository.py +++ b/tests/unit/test_model_repository.py @@ -387,7 +387,7 @@ def test_get_model_with_versions(): mr.get_model_with_versions(model=RestObj) is_uuid.return_value = False - get_model.return_value = RestObj + get_model.return_value = RestObj({"id": "12345"}) request.side_effect = urllib.error.HTTPError( url="http://demo.sas.com", code=404, From 013443f0baacaddd8cd67dc0f91b29def190405f Mon Sep 17 00:00:00 2001 From: djm21 Date: Fri, 30 May 2025 13:06:46 -0700 Subject: [PATCH 8/8] further testing updates --- tests/unit/test_model_repository.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_model_repository.py b/tests/unit/test_model_repository.py index c64c8d58..bf4f9284 100644 --- a/tests/unit/test_model_repository.py +++ b/tests/unit/test_model_repository.py @@ -387,7 +387,7 @@ def test_get_model_with_versions(): mr.get_model_with_versions(model=RestObj) is_uuid.return_value = False - get_model.return_value = RestObj({"id": "12345"}) + get_model.return_value = RestObj({"id": "123456"}) request.side_effect = urllib.error.HTTPError( url="http://demo.sas.com", code=404, @@ -412,7 +412,7 @@ def test_get_model_with_versions(): request.assert_any_call( "GET", - "/models/sasctl.core.RestObj['id']/versions", + "/models/123456/versions", headers={"Accept": "application/vnd.sas.collection+json"}, )