Commit 2cd6995b authored by TheTechRobo's avatar TheTechRobo
Browse files

Currently broken: Get stream coercing working

parent 8a9701f2
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -5,3 +5,4 @@ nohup.out
static/ign
config.py
config.yml
tmptest.py
 No newline at end of file
+79 −7
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@ import copy
import dataclasses
import time
import typing_extensions as typing
import inspect
import re

import asyncio
@@ -19,8 +20,6 @@ with open('config.yml', 'r') as file:
    config_yml = yaml.safe_load(file)
    methods = config_yml["methods"]

# (this name is fine)

@dataclasses.dataclass
class Service(JSONDataclass):
    """
@@ -114,6 +113,72 @@ class YouTubeService(Service): # pylint: disable=abstract-method
class InvalidVideoIdError(ValueError):
    pass

class TargetAPIVersionTooLowError(ValueError):
    """
    Raised when `coerce_to_api_version` is called with an unsupported API version.
    """
    pass

class TargetAPIVersionTooHighError(ValueError):
    """
    Raised when `coerce_to_api_version` is called with an unsupported API version.
    """
    pass

# The current API version.
API_VERSION = 4

# Thin wrapper around a generator that allows us to define a `coerce_to_api_version` method,
# just like with the non-streaming response type.
class YouTubeStreamResponse:
    """
    A streamed response, as an iterable. It is not recommended to use the iterable directly;
    instead, write code around a specific API version and use `coerce_to_api_version` to ensure
    that you are using it.
    
    Note that when streaming, the minimum API version is 4.
    """
    api_version: int = API_VERSION

    # Initialises the iterator.
    # `gen` should be the generator
    def __init__(self, gen):
        self.gen = gen

    # Iterator Code
    def __aiter__(self):
        return self

    async def __anext__(self):
        return anext(self.gen)

    async def coerce_to_api_version(self, targetVersion):
        """
        Wraps the iterator, converting all messages to the target API version.
        The minimum version for streamed responses is 4.

        Note: If this function raises an exception other than TargetAPIVersionTooHighError or
        TargetAPIVersionTooLowError, the generator may be unusable and you should restart the process.
        """
        if targetVersion > self.api_version:
            raise TargetAPIVersionTooHighError(targetVersion)
        # The function calls a direct current to target rather than current to current-1
        # because otherwise we can't be 100% sure that we can downgrade to the correct API version
        # and we can't "un-get" the item from the generator.
        arrOfNamesFunction = getattr(self, f"_convert_narr_v{self.api_version}_to_v{targetVersion}", None)
        serviceObjectFunction = getattr(self, f"_convert_service_v{self.api_version}_to_v{targetVersion}", None)
        verdictObjectFunction = getattr(self, f"_convert_verdict_v{self.api_version}_to_v{targetVersion}", None)
        if not arrOfNamesFunction or not serviceObjectFunction or not verdictObjectFunction:
            raise TargetAPIVersionTooLowError(targetVersion)
        arrayOfNames = await anext(self.gen)
        yield arrOfNamesFunction(arrayOfNames)
        async for item in self.gen:
            yield serviceObjectFunction(item)
            if item is None:
                break
        yield verdictObjectFunction(await anext(self.gen))


@dataclasses.dataclass
class YouTubeResponse(JSONDataclass):
    """
@@ -130,7 +195,7 @@ class YouTubeResponse(JSONDataclass):
    status: str
    keys: list[YouTubeService]
    verdict: dict
    api_version: int = 4
    api_version: int = API_VERSION

    def coerce_to_api_version(selfNEW, targetVersion): # pylint: disable=no-self-argument
        """
@@ -139,15 +204,17 @@ class YouTubeResponse(JSONDataclass):

        Arguments:
            targetVersion (int): The target API version. Must be lower than self.api_version

        Raises either TargetAPIVersionTooHighError or TargetAPIVersionTooLowError if the target is unsupported.
        """
        self = copy.deepcopy(selfNEW)
        currentApiVersion = self.api_version
        if currentApiVersion < targetVersion:
            raise ValueError("cannot upgrade api version")
            raise TargetAPIVersionTooHighError("cannot upgrade api version")
        while self.api_version != targetVersion:
            fname = f"_convert_v{self.api_version}_to_v{self.api_version-1}"
            if not hasattr(self, fname):
                raise ValueError("cannot downgrade any further")
                raise TargetAPIVersionTooLowError("cannot downgrade any further")
            self = getattr(self, fname)()
        assert self.api_version == targetVersion
        return self
@@ -203,7 +270,7 @@ class YouTubeResponse(JSONDataclass):
        return verdict

    @classmethod
    async def generateStream(cls, id: str, includeRaw=False):
    async def _generateStream(cls, id: str, includeRaw=False):
        """
        Runs all the Services but as a generator.
        First item is a list of all the service names.
@@ -236,9 +303,14 @@ class YouTubeResponse(JSONDataclass):
        any_archived['human_friendly'] = verdict
        yield any_archived

    @classmethod
    async def generateStream(cls, id: str, includeRaw=False):
        gen = cls._generateStream(id, includeRaw=includeRaw)
        return YouTubeStreamResponse(gen)

    @classmethod
    async def generate(cls, id: str, includeRaw=False):
        generator = cls.generateStream(id, includeRaw)
        generator = await cls.generateStream(id, includeRaw)
        # ignore the list of names as that is redundant in this case
        await anext(generator)
        results = []