# Source code: tgintegration.collector

"""
Standalone `collector` utilities.
"""
import asyncio
import logging
from contextlib import asynccontextmanager
from datetime import datetime
from datetime import timedelta
from typing import AsyncContextManager
from typing import TYPE_CHECKING

from pyrogram.errors import InternalServerError
from pyrogram.filters import Filter
from pyrogram.handlers import MessageHandler

from tgintegration.expectation import Expectation
from tgintegration.handler_utils import add_handler_transient
from tgintegration.timeout_settings import TimeoutSettings

if TYPE_CHECKING:
    from tgintegration.botcontroller import BotController

from tgintegration.containers.responses import InvalidResponseError, Response
from tgintegration.update_recorder import MessageRecorder

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


@asynccontextmanager
async def collect(
    controller: "BotController",
    filters: Filter = None,
    expectation: Expectation = None,
    timeouts: TimeoutSettings = None,
) -> AsyncContextManager[Response]:
    """
    Record messages arriving on the controller's client around a user-defined interaction.

    A transient message handler is attached to ``controller.client`` and a
    `Response` backed by a fresh `MessageRecorder` is yielded to the ``async with``
    body. Once the body finishes, this keeps waiting for messages until the
    expectation reports itself as sufficient or ``timeouts.max_wait`` seconds have
    passed, then calls ``expectation.verify`` on the recorded messages.

    Args:
        controller: Controller whose ``client`` must already be connected.
        filters: Filters passed to the `MessageHandler` that feeds the recorder.
        expectation: Decides when enough messages were received
            (``is_sufficient``) and validates the result (``verify``).
            Defaults to ``Expectation()``.
        timeouts: Provides ``max_wait``, ``wait_consecutive`` and
            ``raise_on_timeout``. Defaults to ``TimeoutSettings()``.

    Yields:
        The `Response` wrapping the recorder that collects incoming messages.

    Raises:
        AssertionError: If the controller's client is not connected.
        InvalidResponseError: If waiting for a message times out and
            ``timeouts.raise_on_timeout`` is truthy.
    """
    expectation = expectation or Expectation()
    timeouts = timeouts or TimeoutSettings()

    recorder = MessageRecorder()
    handler = MessageHandler(recorder.record_message, filters=filters)

    assert controller.client.is_connected, "The controller's client is not connected."

    async with add_handler_transient(controller.client, handler):
        response = Response(controller, recorder)

        logger.debug("Collector set up. Executing user-defined interaction...")
        yield response  # Start user-defined interaction
        logger.debug("interaction complete.")

        num_received = 0
        # TODO: work with the message's timestamp instead of utcnow()
        timeout_end = datetime.utcnow() + timedelta(seconds=timeouts.max_wait)

        try:
            seconds_remaining = (timeout_end - datetime.utcnow()).total_seconds()

            while True:
                if seconds_remaining > 0:
                    # Wait until we receive any message or time out
                    logger.debug("Waiting for message #%s", num_received + 1)
                    await asyncio.wait_for(
                        recorder.wait_until(
                            lambda msgs: expectation.is_sufficient(msgs)
                            or len(msgs) > num_received
                        ),
                        timeout=seconds_remaining,
                    )

                # Must be refreshed here: the lambda below reads `num_received`
                # when it is called, and needs the count *before* the consecutive wait.
                num_received = len(recorder.messages)

                if timeouts.wait_consecutive:
                    # Always wait for at least `wait_consecutive` seconds for another message
                    logger.debug(
                        "Checking for consecutive message to #%s...", num_received
                    )
                    try:
                        await asyncio.wait_for(
                            recorder.wait_until(lambda msgs: len(msgs) > num_received),
                            # The consecutive end may go over the max wait timeout,
                            # which is a design decision.
                            timeout=timeouts.wait_consecutive,
                        )
                    except asyncio.TimeoutError:
                        # `asyncio.TimeoutError` only became an alias of the builtin
                        # `TimeoutError` in Python 3.11. Catching the builtin here let
                        # this expected timeout escape to the outer handler on older
                        # versions, which reported it as "peer did not reply".
                        logger.debug("none received.")
                    else:
                        logger.debug("received 1.")

                num_received = len(recorder.messages)

                if expectation.is_sufficient(recorder.messages):
                    expectation.verify(recorder.messages, timeouts)
                    return

                seconds_remaining = (timeout_end - datetime.utcnow()).total_seconds()

                if seconds_remaining <= 0:
                    # Out of time: let the expectation judge what was collected.
                    expectation.verify(recorder.messages, timeouts)
                    return

        except InternalServerError as e:
            logger.warning(e)
            await asyncio.sleep(60)  # Internal Telegram error
        except asyncio.TimeoutError as te:
            # Only the primary wait above can get here: nothing new arrived in time.
            if timeouts.raise_on_timeout:
                raise InvalidResponseError() from te
            else:
                # TODO: better warning message
                logger.warning("Peer did not reply.")
        finally:
            recorder.stop()