diff --git a/cwltool/load_tool.py b/cwltool/load_tool.py index 8ab6e029a..2f3d88763 100644 --- a/cwltool/load_tool.py +++ b/cwltool/load_tool.py @@ -7,7 +7,7 @@ import re import urlparse -from schema_salad.ref_resolver import Loader, Fetcher, DefaultFetcher +from schema_salad.ref_resolver import Loader, Fetcher import schema_salad.validate as validate from schema_salad.validate import ValidationException import schema_salad.schema as schema @@ -26,7 +26,7 @@ def fetch_document(argsworkflow, # type: Union[Text, dict[Text, Any]] resolver=None, # type: Callable[[Loader, Union[Text, dict[Text, Any]]], Text] - fetcher_constructor=DefaultFetcher # type: Callable[[Dict[unicode, unicode], requests.sessions.Session], Fetcher] + fetcher_constructor=None # type: Callable[[Dict[unicode, unicode], requests.sessions.Session], Fetcher] ): # type: (...) -> Tuple[Loader, Dict[Text, Any], Text] """Retrieve a CWL document.""" @@ -111,11 +111,19 @@ def validate_document(document_loader, # type: Loader enable_dev=False, # type: bool strict=True, # type: bool preprocess_only=False, # type: bool - fetcher_constructor=DefaultFetcher # type: Callable[[Dict[unicode, unicode], requests.sessions.Session], Fetcher] + fetcher_constructor=None # type: Callable[[Dict[unicode, unicode], requests.sessions.Session], Fetcher] ): # type: (...) -> Tuple[Loader, Names, Union[Dict[Text, Any], List[Dict[Text, Any]]], Dict[Text, Any], Text] """Validate a CWL document.""" + if isinstance(workflowobj, list): + workflowobj = { + "$graph": workflowobj + } + + if not isinstance(workflowobj, dict): + raise ValueError("workflowjobj must be a dict") + jobobj = None if "cwl:tool" in workflowobj: jobobj, _ = document_loader.resolve_all(workflowobj, uri) @@ -123,11 +131,6 @@ def validate_document(document_loader, # type: Loader del cast(dict, jobobj)["https://w3id.org/cwl/cwl#tool"] workflowobj = fetch_document(uri, fetcher_constructor=fetcher_constructor)[1] - if isinstance(workflowobj, list): - workflowobj = { - "$graph": workflowobj - } - fileuri = urlparse.urldefrag(uri)[0] if "cwlVersion" in workflowobj: @@ -235,7 +238,7 @@ def load_tool(argsworkflow, # type: Union[Text, Dict[Text, Any]] enable_dev=False, # type: bool strict=True, # type: bool resolver=None, # type: Callable[[Loader, Union[Text, dict[Text, Any]]], Text] - fetcher_constructor=DefaultFetcher # type: Callable[[Dict[unicode, unicode], requests.sessions.Session], Fetcher] + fetcher_constructor=None # type: Callable[[Dict[unicode, unicode], requests.sessions.Session], Fetcher] ): # type: (...) -> Process diff --git a/cwltool/main.py b/cwltool/main.py index 9e3618377..64068168a 100755 --- a/cwltool/main.py +++ b/cwltool/main.py @@ -18,7 +18,7 @@ from typing import (Union, Any, AnyStr, cast, Callable, Dict, Sequence, Text, Tuple, Type, IO) -from schema_salad.ref_resolver import Loader, Fetcher, DefaultFetcher +from schema_salad.ref_resolver import Loader, Fetcher import schema_salad.validate as validate import schema_salad.jsonld_context import schema_salad.makedoc @@ -516,7 +516,7 @@ def printdeps(obj, document_loader, stdout, relative_deps, uri, basedir=None): "location": uri} # type: Dict[Text, Any] def loadref(b, u): - return document_loader.fetch(urlparse.urljoin(b, u)) + return document_loader.fetch(document_loader.fetcher.urljoin(b, u)) sf = scandeps( basedir if basedir else uri, obj, set(("$import", "run")), @@ -565,7 +565,7 @@ def main(argsl=None, # type: List[str] versionfunc=versionstring, # type: Callable[[], Text] job_order_object=None, # type: Union[Tuple[Dict[Text, Any], Text], int] make_fs_access=StdFsAccess, # type: Callable[[Text], StdFsAccess] - fetcher_constructor=DefaultFetcher, # type: Callable[[Dict[unicode, unicode], requests.sessions.Session], Fetcher] + fetcher_constructor=None, # type: Callable[[Dict[unicode, unicode], requests.sessions.Session], Fetcher] resolver=tool_resolver ): # type: (...) -> int diff --git a/cwltool/process.py b/cwltool/process.py index c092db0ad..b4e9001dd 100644 --- a/cwltool/process.py +++ b/cwltool/process.py @@ -7,7 +7,6 @@ import tempfile import glob import urlparse -import pprint from collections import Iterable import errno import shutil @@ -640,8 +639,8 @@ def mergedirs(listing): r.extend(ents.itervalues()) return r -def scandeps(base, doc, reffields, urlfields, loadref): - # type: (Text, Any, Set[Text], Set[Text], Callable[[Text, Text], Any]) -> List[Dict[Text, Text]] +def scandeps(base, doc, reffields, urlfields, loadref, urljoin=urlparse.urljoin): + # type: (Text, Any, Set[Text], Set[Text], Callable[[Text, Text], Any], Callable[[Text, Text], Text]) -> List[Dict[Text, Text]] r = [] # type: List[Dict[Text, Text]] deps = None # type: Dict[Text, Any] if isinstance(doc, dict): @@ -660,7 +659,7 @@ def scandeps(base, doc, reffields, urlfields, loadref): if u and not u.startswith("_:"): deps = { "class": doc["class"], - "location": urlparse.urljoin(base, u) + "location": urljoin(base, u) } if doc["class"] == "Directory" and "listing" in doc: deps["listing"] = doc["listing"] @@ -670,23 +669,23 @@ def scandeps(base, doc, reffields, urlfields, loadref): r.append(deps) else: if doc["class"] == "Directory" and "listing" in doc: - r.extend(scandeps(base, doc["listing"], reffields, urlfields, loadref)) + r.extend(scandeps(base, doc["listing"], reffields, urlfields, loadref, urljoin=urljoin)) elif doc["class"] == "File" and "secondaryFiles" in doc: - r.extend(scandeps(base, doc["secondaryFiles"], reffields, urlfields, loadref)) + r.extend(scandeps(base, doc["secondaryFiles"], reffields, urlfields, loadref, urljoin=urljoin)) for k, v in doc.iteritems(): if k in reffields: for u in aslist(v): if isinstance(u, dict): - r.extend(scandeps(base, u, reffields, urlfields, loadref)) + r.extend(scandeps(base, u, reffields, urlfields, loadref, urljoin=urljoin)) else: sub = loadref(base, u) - subid = urlparse.urljoin(base, u) + subid = urljoin(base, u) deps = { "class": "File", "location": subid } - sf = scandeps(subid, sub, reffields, urlfields, loadref) + sf = scandeps(subid, sub, reffields, urlfields, loadref, urljoin=urljoin) if sf: deps["secondaryFiles"] = sf deps = nestdir(base, deps) @@ -695,19 +694,20 @@ def scandeps(base, doc, reffields, urlfields, loadref): for u in aslist(v): deps = { "class": "File", - "location": urlparse.urljoin(base, u) + "location": urljoin(base, u) } deps = nestdir(base, deps) r.append(deps) elif k not in ("listing", "secondaryFiles"): - r.extend(scandeps(base, v, reffields, urlfields, loadref)) + r.extend(scandeps(base, v, reffields, urlfields, loadref, urljoin=urljoin)) elif isinstance(doc, list): for d in doc: - r.extend(scandeps(base, d, reffields, urlfields, loadref)) + r.extend(scandeps(base, d, reffields, urlfields, loadref, urljoin=urljoin)) if r: normalizeFilesDirs(r) r = mergedirs(r) + return r def compute_checksums(fs_access, fileobj): diff --git a/cwltool/workflow.py b/cwltool/workflow.py index dfa6c4c89..f2a07d04a 100644 --- a/cwltool/workflow.py +++ b/cwltool/workflow.py @@ -445,7 +445,8 @@ def __init__(self, toolpath_object, pos, **kwargs): self.embedded_tool = load_tool( toolpath_object["run"], kwargs.get("makeTool"), kwargs, enable_dev=kwargs.get("enable_dev"), - strict=kwargs.get("strict")) + strict=kwargs.get("strict"), + fetcher_constructor=kwargs.get("fetcher_constructor")) except validate.ValidationException as v: raise WorkflowException( u"Tool definition %s failed validation:\n%s" % diff --git a/setup.py b/setup.py index 80420336c..3721e874c 100755 --- a/setup.py +++ b/setup.py @@ -49,7 +49,7 @@ 'rdflib-jsonld == 0.3.0', 'html5lib >=0.90, <= 0.9999999', 'shellescape', - 'schema-salad >= 1.20.20161122192122, < 2', + 'schema-salad >= 1.21.20161206181442, < 2', 'typing >= 3.5.2', 'cwltest >= 1.0.20160907111242'], test_suite='tests',