Commit fd418be2 authored by Apoorva Srinivas Appadoo's avatar Apoorva Srinivas Appadoo
Browse files

feat: add ssm endpoint


Squashed commit of the following:

commit ee1ef6598b3fcb6e3a28d1fa55e3c28a5eb8d2e2
Author: Apoorva Srinivas Appadoo <apoorva-srinivas.appadoo@etu.univ-cotedazur.fr>
Date:   Mon Oct 6 08:02:30 2025 +0200

    fix: dockerfile

    Signed-off-by: default avatarApoorva Srinivas Appadoo <apoorva-srinivas.appadoo@etu.univ-cotedazur.fr>

commit 271e4caf6c16b70b04035b59c8e7b7337961d027
Author: Apoorva Srinivas Appadoo <apoorva-srinivas.appadoo@etu.univ-cotedazur.fr>
Date:   Sun Oct 5 13:24:54 2025 +0200

    refactor: change SSM port forward endpoint to use SsmPortForwardResponse and change response type

    Signed-off-by: default avatarApoorva Srinivas Appadoo <apoorva-srinivas.appadoo@etu.univ-cotedazur.fr>

commit 14ffe5392dc1cf8920c9c8e70f559e665381dc7c
Author: Apoorva Srinivas Appadoo <apoorva-srinivas.appadoo@etu.univ-cotedazur.fr>
Date:   Sat Oct 4 18:31:38 2025 +0200

    refactor: remove redundant configure_boto function parameter and update usage

    Signed-off-by: default avatarApoorva Srinivas Appadoo <apoorva-srinivas.appadoo@etu.univ-cotedazur.fr>

commit 617743d952c891d282c5c6ebfa08985622c0a5c5
Author: Apoorva Srinivas Appadoo <apoorva-srinivas.appadoo@etu.univ-cotedazur.fr>
Date:   Sat Oct 4 18:21:23 2025 +0200

    docs: fix readme.md

    Signed-off-by: default avatarApoorva Srinivas Appadoo <apoorva-srinivas.appadoo@etu.univ-cotedazur.fr>

commit 46bb7002297c481eccc43d7a37cf7dffcbb645f7
Author: Apoorva Srinivas Appadoo <apoorva-srinivas.appadoo@etu.univ-cotedazur.fr>
Date:   Sat Oct 4 17:55:48 2025 +0200

    feat: add ssm proxy

    Signed-off-by: default avatarApoorva Srinivas Appadoo <apoorva-srinivas.appadoo@etu.univ-cotedazur.fr>

commit d612a93a01630832e8c801052e5ae35635a36b37
Author: Apoorva Srinivas Appadoo <apoorva-srinivas.appadoo@etu.univ-cotedazur.fr>
Date:   Sat Oct 4 17:55:04 2025 +0200

    refactor: move utils to utils.py

    Signed-off-by: default avatarApoorva Srinivas Appadoo <apoorva-srinivas.appadoo@etu.univ-cotedazur.fr>
parent 026513d2
Loading
Loading
Loading
Loading
+20 −2
Original line number Diff line number Diff line
@@ -6,8 +6,26 @@ WORKDIR /code
COPY pyproject.toml poetry.lock README.md /code/
COPY ./aws_auth_provider /code/aws_auth_provider

# Use ash and enable pipefail so commands using | fail the RUN step on errors (fix DL4006)
SHELL ["/bin/ash", "-o", "pipefail", "-c"]

# Install dependencies and Session Manager plugin
RUN apk upgrade --no-cache \
    && pip install --no-cache-dir .
    && apk add --no-cache \
        curl=8.14.1-r2 \
        gcompat=1.1.0-r4 \
        rpm=4.19.1.1-r2 \
        cpio=2.15-r0 \
    # Install pip packages
    && pip install --no-cache-dir . \
    # Install Session Manager plugin (extract binary directly from RPM)
    && curl "https://s3.amazonaws.com/session-manager-downloads/plugin/latest/linux_64bit/session-manager-plugin.rpm" -o "session-manager-plugin.rpm" \
    && rpm2cpio session-manager-plugin.rpm | cpio -idmv \
    && mv usr/local/sessionmanagerplugin/bin/session-manager-plugin /usr/local/bin/ \
    && chmod +x /usr/local/bin/session-manager-plugin \
    && rm -rf usr session-manager-plugin.rpm \
    # Verify installation
    && session-manager-plugin --version

EXPOSE ${PORT}
# hadolint ignore=DL3025
+22 −0
Original line number Diff line number Diff line
@@ -118,6 +118,13 @@ This API retrieves the [CodeArtifact repository endpoint](https://docs.aws.amazo
| `region`   | AWS region to use                                             | no _(can be retrieved from env)_ |
| `env_ctx`  | the [environment context to consider](#the-notion-of-env_ctx) | no _(can be guessed)_ |

### `GET /ssm/port-forward`

This API starts an [AWS Systems Manager (SSM) port forwarding session](https://docs.aws.amazon.com/systems-manager/latest/userguide/session-manager-working-with-sessions-start.html#sessions-start-port-forwarding) to tunnel traffic from gitlab-ci to or through an EC2 instance

**Returns:** Plain text URL to access the forwarded port (e.g., `http://aws-auth-provider:14240`)

#### Query Parameters
### `GET /kubeconfig`

This API generates a complete kubeconfig file for an AWS EKS cluster with a valid authentication token.
@@ -190,7 +197,22 @@ metadata:
- Ensure your AWS credentials have `eks:DescribeCluster` permission for the cluster and permissions to access the cluster


| Name          | Description                                                   | Required              |
|---------------|---------------------------------------------------------------|-----------------------|
| `instance_id` | EC2 instance ID to connect to                                 | yes                   |
| `remote_port` | Port on the remote instance                                   | yes                   |
| `remote_host` | Remote host for advanced port forwarding scenarios           | no                    |
| `local_port`  | Local port to bind (auto-allocated if not specified)          | no                    |
| `protocol`    | URL protocol (`http` or `https`, default: `http`)            | no                    |
| `region`      | AWS region to use                                             | no _(can be retrieved from env)_ |
| `env_ctx`     | the [environment context to consider](#the-notion-of-env_ctx) | no _(can be guessed)_ |
| `role_arn`    | AWS IAM role ARN to assume                                    | no _(can be retrieved from env)_ |


**Environment Variables:**
- `API_HOST`: Host for generated URLs (default: `aws-auth-provider`)
- `PORT_RANGE_START`: Minimum port for auto-allocation (default: `10000`)
- `PORT_RANGE_END`: Maximum port for auto-allocation (default: `20000`)

## Use in GitLab CI

+34 −89
Original line number Diff line number Diff line
@@ -7,7 +7,7 @@ import tempfile
from datetime import datetime, timedelta
from functools import cache
from http import HTTPStatus
from typing import Literal, Optional
from typing import Literal

import boto3
import yaml
@@ -18,96 +18,16 @@ from fastapi.exceptions import RequestValidationError
from fastapi.responses import PlainTextResponse
from starlette.exceptions import HTTPException as StarletteHTTPException

logger = logging.getLogger("aws-auth-provider")
app = FastAPI()


# manages the different AWS authentication methods
def configure_boto(env_ctx: str = None, region: str = None, role_arn: str = None):
    # auto-determine env type
    if not env_ctx:
        env_ctx = guess_env_ctx()

    # set region
    if region is None:
        region = (
            getenv_cleared(f"AWS_{env_ctx}_REGION")
            or getenv_cleared("AWS_REGION")
            or getenv_cleared("AWS_DEFAULT_REGION")
        )
    if not region:
        logger.error("AWS region not found")
        raise HTTPException(status_code=400, detail="AWS region not found")
    os.environ["AWS_DEFAULT_REGION"] = region

    # determine auth method
    jwt_token = os.environ.get("AWS_JWT")
    if role_arn is None:
        role_arn = getenv_cleared(f"AWS_{env_ctx}_OIDC_ROLE_ARN") or getenv_cleared(
            "AWS_OIDC_ROLE_ARN"
        )
    if jwt_token and role_arn:
        # Assume Role with Web Identity Provider
        # see: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html#assume-role-with-web-identity-provider
        logger.info("Auth method: STS Assume Role with Web Identity Provider")
        with tempfile.NamedTemporaryFile(
            mode="w", encoding="utf-8", delete=False
        ) as token_file:
            token_file.write(jwt_token)
            token_file.close()
        os.environ["AWS_ROLE_ARN"] = role_arn
        os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = token_file.name
        os.environ["AWS_ROLE_SESSION_NAME"] = (
            f"GitLabRunner-{os.getenv('CI_PROJECT_ID')}-{os.getenv('CI_PIPELINE_ID')}"
        )
        return
from . import ssm
from .utils import configure_boto, getenv_cleared

    access_key_id = getenv_cleared(f"AWS_{env_ctx}_ACCESS_KEY_ID") or getenv_cleared(
        "AWS_DEFAULT_ACCESS_KEY_ID"
logger = logging.getLogger("aws-auth-provider")
setting_level = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(
    level=setting_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
    secret_access_key = getenv_cleared(
        f"AWS_{env_ctx}_SECRET_ACCESS_KEY"
    ) or getenv_cleared("AWS_DEFAULT_SECRET_ACCESS_KEY")
    if access_key_id and secret_access_key:
        # see: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html#environment-variables
        logger.info("Auth method: basic (access key ID & secret access key)")
        os.environ["AWS_ACCESS_KEY_ID"] = access_key_id
        os.environ["AWS_SECRET_ACCESS_KEY"] = secret_access_key
        return

    logger.error("Auth method not found: no credentials available")
    raise HTTPException(status_code=401, detail="AWS credentials not found")


def guess_env_ctx() -> str:
    # guess from GitLab CI predefined vars
    ref_name = os.getenv("CI_COMMIT_REF_NAME", "-")
    prod_ref = os.getenv("PROD_REF", "/^(master|main)$/").strip("/")
    if re.match(prod_ref, ref_name):
        # could be staging or prod
        if os.getenv("CI_JOB_STAGE", "-") in [
            "publish",
            "infra-prod",
            "production",
            ".post",
        ]:
            return "PROD"
        return "STAGING"

    integ_ref = os.getenv("INTEG_REF", "/^develop$/").strip("/")
    if re.match(integ_ref, ref_name):
        return "INTEG"

    return "REVIEW"


# Workaround the GitLab bug with forced exposed variables:
# variables:
#   SOMEVAR: "$SOMEVAR"
# os.getenv("SOMEVAR") may have value '$SOMEVAR' if the variable is not defined as a project variable
def getenv_cleared(name: str) -> Optional[str]:
    value = os.getenv(name)
    return None if value == f"${name}" else value
logger.setLevel(setting_level)
app = FastAPI()


@app.get("/health", response_class=PlainTextResponse)
@@ -275,6 +195,31 @@ def get_codeartifact_repository_endpoint(
    return response["repositoryEndpoint"]


@app.get("/ssm/port-forward", response_class=PlainTextResponse)
def start_ssm_port_forward_endpoint(
    env_ctx: str = Query(default=None, alias="env_ctx"),
    region: str = Query(default=None, alias="region"),
    role_arn: str = Query(default=None, alias="role_arn"),
    instance_id: str = Query(alias="instance_id"),
    remote_port: int = Query(alias="remote_port"),
    remote_host: str = Query(default=None, alias="remote_host"),
    local_port: int = Query(default=None, alias="local_port"),
    protocol: Literal["http", "https"] = Query(default="http", alias="protocol"),
) -> PlainTextResponse:
    """Start an SSM port forwarding session."""
    result = ssm.start_ssm_port_forward(
        env_ctx=env_ctx,
        region=region,
        role_arn=role_arn,
        instance_id=instance_id,
        remote_port=remote_port,
        remote_host=remote_host,
        local_port=local_port,
        protocol=protocol,
    )
    return PlainTextResponse(content=result.url, status_code=result.status_code)


def get_eks_token(cluster_name, ttl_minutes=15) -> str:
    """
    Generate an EKS authentication token that matches the format of 'aws eks get-token'.
+182 −0
Original line number Diff line number Diff line
import logging
import socket
import threading
from typing import Optional

logger = logging.getLogger("aws-auth-provider.proxy")


class TcpProxy:
    """
    A simple TCP proxy that forwards traffic from one port to another.
    Allows binding to 0.0.0.0 while the session-manager-plugin binds to 127.0.0.1.
    """

    def __init__(
        self, listen_host: str, listen_port: int, target_host: str, target_port: int
    ):
        """
        Initialize the TCP proxy.

        Args:
            listen_host: Host to listen on (typically 0.0.0.0)
            listen_port: Port to listen on
            target_host: Target host to forward to (typically 127.0.0.1)
            target_port: Target port to forward to
        """
        self.listen_host = listen_host
        self.listen_port = listen_port
        self.target_host = target_host
        self.target_port = target_port
        self.server_socket: Optional[socket.socket] = None
        self.running = False
        self.server_thread: Optional[threading.Thread] = None

    def _forward_data(self, source: socket.socket, destination: socket.socket):
        """
        Forward data from source socket to destination socket.

        Args:
            source: Source socket to read from
            destination: Destination socket to write to
        """
        try:
            while self.running:
                data = source.recv(4096)
                if not data:
                    break
                destination.sendall(data)
        except Exception as e:
            logger.debug(f"Connection closed: {e}")
        finally:
            try:
                source.close()
            except Exception:
                pass
            try:
                destination.close()
            except Exception:
                pass

    def _handle_client(self, client_socket: socket.socket, client_address):
        """
        Handle a client connection by creating a connection to the target
        and forwarding data bidirectionally.

        Args:
            client_socket: Client socket
            client_address: Client address
        """
        target_socket = None
        try:
            logger.debug(f"New connection from {client_address}")

            # Connect to target
            target_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            target_socket.connect((self.target_host, self.target_port))
            logger.debug(f"Connected to target {self.target_host}:{self.target_port}")

            # Create two threads for bidirectional forwarding
            client_to_target = threading.Thread(
                target=self._forward_data,
                args=(client_socket, target_socket),
                daemon=True,
            )
            target_to_client = threading.Thread(
                target=self._forward_data,
                args=(target_socket, client_socket),
                daemon=True,
            )

            client_to_target.start()
            target_to_client.start()

            # Wait for both threads to complete
            client_to_target.join()
            target_to_client.join()

        except Exception as e:
            logger.error(f"Error handling client: {e}")
        finally:
            if target_socket:
                try:
                    target_socket.close()
                except Exception:
                    pass
            try:
                client_socket.close()
            except Exception:
                pass

    def start(self):
        """Start the proxy server."""
        if self.running:
            logger.warning("Proxy is already running")
            return

        try:
            self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self.server_socket.bind((self.listen_host, self.listen_port))
            self.server_socket.listen(5)
            self.running = True

            logger.info(
                f"Proxy listening on {self.listen_host}:{self.listen_port}, forwarding to {self.target_host}:{self.target_port}"
            )

            def accept_connections():
                while self.running:
                    try:
                        self.server_socket.settimeout(1.0)
                        try:
                            client_socket, client_address = self.server_socket.accept()
                        except socket.timeout:
                            continue

                        # Handle each client in a separate thread
                        client_thread = threading.Thread(
                            target=self._handle_client,
                            args=(client_socket, client_address),
                            daemon=True,
                        )
                        client_thread.start()
                    except Exception as e:
                        if self.running:
                            logger.error(f"Error accepting connection: {e}")

            self.server_thread = threading.Thread(
                target=accept_connections, daemon=True
            )
            self.server_thread.start()

        except Exception as e:
            logger.error(f"Failed to start proxy: {e}")
            self.running = False
            if self.server_socket:
                try:
                    self.server_socket.close()
                except Exception:
                    pass
            raise

    def stop(self):
        """Stop the proxy server."""
        if not self.running:
            return

        logger.info(f"Stopping proxy on {self.listen_host}:{self.listen_port}")
        self.running = False

        if self.server_socket:
            try:
                self.server_socket.close()
            except Exception:
                pass

        if self.server_thread:
            self.server_thread.join(timeout=2)

    def is_running(self) -> bool:
        """Check if the proxy is currently running."""
        return self.running
+368 −0

File added.

Preview size limit exceeded, changes collapsed.

Loading