Commit 10ea9225 authored by Stein Welberg's avatar Stein Welberg Committed by Clement Bois
Browse files

chore: extract interaction with Dependency Track API to separate class

parent 6392b3a1
Loading
Loading
Loading
Loading
+42 −47
Original line number Diff line number Diff line
@@ -243,6 +243,29 @@ class Version:
            return 0, alpha


class ApiClient:
    def __init__(self, base_api_url: str, api_key: str, verify_ssl: bool):
        self.base_api_url = base_api_url
        self.session = requests.Session()
        self.session.headers.update({
            "X-API-Key": api_key,
            "accept": MIME_APPLICATION_JSON,
        })
        self.session.verify = verify_ssl

    def get(self, path, **kwargs):
        url = f"{self.base_api_url}{path}"
        return self.session.get(url, **kwargs)

    def post(self, path, **kwargs):
        url = f"{self.base_api_url}{path}"
        return self.session.post(url, **kwargs)

    def put(self, path, **kwargs):
        url = f"{self.base_api_url}{path}"
        return self.session.put(url, **kwargs)


class Scanner:
    def __init__(
        self,
@@ -263,14 +286,12 @@ class Scanner:
        merged_vex_file = None,
        **_: None,
    ):
        self.base_api_url = base_api_url
        self.api_key = api_key
        self.api_client = ApiClient(base_api_url, api_key, verify_ssl)
        self.project_path = project_path
        self.path_separator = path_separator
        self._purl_max_len = purl_max_len
        self.merge = merge
        self.merge_output = merge_output
        self.verify_ssl = verify_ssl
        self.show_findings = show_findings
        self.risk_score_threshold = risk_score_threshold
        self.tags = list(filter(None, map(str.strip, tags.split(",")))) if tags else []
@@ -286,11 +307,7 @@ class Scanner:
    def dt_version(self) -> Version:
        """Determines the DT server version."""
        return Version(
            requests.get(
                f"{self.base_api_url}/version",
                headers={"accept": MIME_APPLICATION_JSON},
                verify=self.verify_ssl,
            ).json()["version"]
            self.api_client.get("/version").json()["version"]
        )

    @property
@@ -328,11 +345,7 @@ class Scanner:
    def get_permissions(self) -> list[DtPermission]:
        return [
            permission["name"]
            for permission in requests.get(
                f"{self.base_api_url}/v1/team/self",
                headers={"X-API-Key": self.api_key, "accept": MIME_APPLICATION_JSON},
                verify=self.verify_ssl,
            ).json()["permissions"]
            for permission in self.api_client.get("/v1/team/self").json()["permissions"]
        ]

    def has_permission(self, perm: DtPermission) -> bool:
@@ -353,11 +366,9 @@ class Scanner:
            return project_def.uuid

        # project is defined by name/version...
        resp = requests.get(
            f"{self.base_api_url}/v1/project",
            headers={"X-API-Key": self.api_key, "accept": MIME_APPLICATION_JSON},
        resp = self.api_client.get(
            "/v1/project",
            params={"name": project_def.name},
            verify=self.verify_ssl,
        )
        resp.raise_for_status()
        # find project with matching name/version
@@ -389,11 +400,9 @@ class Scanner:
                f"- {AnsiColors.YELLOW}{project_path}{AnsiColors.RESET} found sibling (version: {name_match.get('version')}): {name_match['uuid']}..."
            )
            # now create a clone of the project
            resp = requests.put(
                f"{self.base_api_url}/v1/project/clone",
            resp = self.api_client.put(
                "/v1/project/clone",
                headers={
                    "X-API-Key": self.api_key,
                    "accept": MIME_APPLICATION_JSON,
                    "content-type": MIME_APPLICATION_JSON,
                },
                json={
@@ -406,19 +415,16 @@ class Scanner:
                    "includeAuditHistory": True,
                    "includeACL": True,
                },
                verify=self.verify_ssl,
            )
            try:
                resp.raise_for_status()
                # TODO: clone doesn't return UUID :(
                resp = requests.get(
                    f"{self.base_api_url}/v1/project/lookup",
                resp = self.api_client.get(
                    "/v1/project/lookup",
                    headers={
                        "X-API-Key": self.api_key,
                        "accept": MIME_APPLICATION_JSON,
                    },
                    params={"name": project_def.name, "version": project_def.version},
                    verify=self.verify_ssl,
                )
                resp.raise_for_status()
                # retrieve UUID from response and return
@@ -474,15 +480,12 @@ class Scanner:
        print(
            f"- {AnsiColors.YELLOW}{project_path}{AnsiColors.RESET} not found: create with params {AnsiColors.HGRAY}{json.dumps(data)}{AnsiColors.RESET}..."
        )
        resp = requests.put(
            f"{self.base_api_url}/v1/project",
        resp = self.api_client.put(
            "/v1/project",
            headers={
                "X-API-Key": self.api_key,
                "accept": MIME_APPLICATION_JSON,
                "content-type": MIME_APPLICATION_JSON,
            },
            json=data,
            verify=self.verify_ssl,
        )
        try:
            resp.raise_for_status()
@@ -556,12 +559,10 @@ class Scanner:
        print(
            f"- publish params: {AnsiColors.HGRAY}{json.dumps(params)}{AnsiColors.RESET}..."
        )
        resp = requests.post(
            f"{self.base_api_url}/v1/bom",
            headers={"X-API-Key": self.api_key, "accept": MIME_APPLICATION_JSON},
        resp = self.api_client.post(
            "/v1/bom",
            files={"bom": sbom_json},
            data=params,
            verify=self.verify_ssl,
        )
        try:
            resp.raise_for_status()
@@ -619,12 +620,10 @@ class Scanner:

        with open(vex_file_path, "r") as vex_file:
            params = project_def.params
            resp = requests.post(
                f"{self.base_api_url}/v1/vex",
                headers={"X-API-Key": self.api_key, "accept": MIME_APPLICATION_JSON},
            resp = self.api_client.post(
                "/v1/vex",
                files={"vex": vex_file},
                data=params,
                verify=self.verify_ssl,
            )
            try:
                resp.raise_for_status()
@@ -648,11 +647,9 @@ class Scanner:
            params["name"] = project_def.name
            if project_def.version:
                params["version"] = project_def.version
            resp = requests.get(
                f"{self.base_api_url}/v1/project/lookup",
                headers={"X-API-Key": self.api_key, "accept": MIME_APPLICATION_JSON},
            resp = self.api_client.get(
                "/v1/project/lookup",
                params=params,
                verify=self.verify_ssl,
            )
            project_id = resp.json().get("uuid")

@@ -693,10 +690,8 @@ class Scanner:
    def wait_for_event_processing(self, event_id: str):
        for n in range(8):  # ~5 minutes
            sleep(2**n)
            resp = requests.get(
                f"{self.base_api_url}/v1/{self.event_token_path}/{event_id}",
                headers={"X-API-Key": self.api_key, "accept": MIME_APPLICATION_JSON},
                verify=self.verify_ssl,
            resp = self.api_client.get(
                f"/v1/{self.event_token_path}/{event_id}",
            )
            if not resp.json().get("processing", False):
                break