import json
import re
import threading
import time
import uuid
from collections import defaultdict
from typing import Set
from urllib.parse import urlencode
from requests import HTTPError
from notion.block.collection.basic import CollectionBlock
from notion.logger import logger
from notion.record import Record
from notion.settings import MESSAGE_STORE_URL
class Monitor:
    """
    Monitor class for automatic data polling of records.

    Opens a polling session against `root_url`, registers subscriptions
    for records and asks the client to refresh the records for which
    a change notification was received.
    """

    # Background thread started by `poll_async()`; None until then.
    thread = None

    # Primus-level ping, e.g. `21:4"primus::ping::1234"`.
    _PING_PATTERN = re.compile(r'\d+:\d+"primus::ping::\d+"')
    # `<number>:<number>` prefix that directly precedes a JSON object.
    _PAYLOAD_PREFIX_PATTERN = re.compile(r"\d+:\d+(?=\{)")
    # Notification keys, mirroring the keys registered in `subscribe()`.
    _VERSIONS_KEY_PATTERN = re.compile(r"versions/([^:]+):(.+)")
    _COLLECTION_KEY_PATTERN = re.compile(r"collection/(.+)")

    def __init__(self, client, root_url: str = MESSAGE_STORE_URL):
        """
        Create Monitor object.

        Note that this immediately calls `initialize()`, which performs
        a GET request through `client.session`.


        Arguments
        ---------
        client : NotionClient
            Client to use.

        root_url : str, optional
            Root URL for polling message stats.
            Defaults to valid notion message store URL.
        """
        self.sid = None
        self.client = client
        self.root_url = root_url
        self.session_id = str(uuid.uuid4())
        self._subscriptions = set()
        self.initialize()

    @staticmethod
    def _encode_numbered_json_thing(data: list) -> bytes:
        """
        Encode objects into the length-prefixed format sent by `post_data()`.

        Every object becomes `<len(packet)>:<packet>` where the packet
        is `str(len(obj))` followed by the compact JSON dump of the object.


        Arguments
        ---------
        data : list
            Objects to encode.


        Returns
        -------
        bytes
            Concatenated packets; empty bytes for empty `data`.
        """
        packets = []

        for obj in data:
            # NOTE(review): the packet is prefixed with len(obj) (the number
            # of keys for a dict), not with a fixed code. Presumably this
            # relies on the 4-key dicts built in `subscribe()` - TODO confirm
            # against the wire protocol before passing differently sized dicts.
            packet = str(len(obj)) + json.dumps(obj, separators=(",", ":"))
            packets.append(f"{len(packet)}:{packet}")

        return "".join(packets).encode()

    def _decode_numbered_json_thing(self, thing: bytes) -> list:
        """
        Decode a raw polling response into a list of JSON objects.

        Every ping found in the response is answered with a matching pong
        via `post_data()`. Payloads are located by their numeric prefix and
        parsed one by one, so several concatenated payloads are returned
        as separate objects. A payload that is not valid JSON is logged
        and skipped.


        Arguments
        ---------
        thing : bytes
            Raw response body.


        Returns
        -------
        list
            Decoded JSON objects in order of appearance, possibly empty.
        """
        text = thing.decode().strip()

        for ping in self._PING_PATTERN.findall(text):
            logger.debug(f"Received ping: {ping}")
            # "ping" and "pong" have the same length,
            # so the numeric prefix of the message stays valid
            self.post_data(ping.replace("::ping::", "::pong::").encode())

        results = []
        decoder = json.JSONDecoder()
        position = 0

        while True:
            prefix = self._PAYLOAD_PREFIX_PATTERN.search(text, position)
            if not prefix:
                break

            try:
                # raw_decode stops at the end of the first complete JSON
                # document, so following payloads are left for next iteration
                blob, position = decoder.raw_decode(text, prefix.end())
            except json.JSONDecodeError as e:
                logger.debug(f"Skipping malformed monitoring payload: {e}")
                position = prefix.end()
                continue

            results.append(blob)

        if text and not results and "::ping::" not in text:
            logger.debug(f"Could not parse monitoring response: {text}")

        return results

    def _refresh_updated_records(self, events: list):
        """
        Refresh the records mentioned in notification events.

        Only dicts whose "type" is "notification" are processed.
        A `versions/<record_id>:<table>` key schedules the record for
        refresh when the event's "value" is greater than the version
        currently held by the client's store. A `collection/<id>` key
        refreshes the collection rows and schedules all of them for refresh.
        `client.refresh_records()` is called once at the end.


        Arguments
        ---------
        events : list
            Decoded events, e.g. result of `_decode_numbered_json_thing()`.
        """
        records_to_refresh = defaultdict(list)

        for event in events:
            if not isinstance(event, dict):
                continue

            if event.get("type", "") != "notification":
                continue

            logger.debug(f"Received the following event from notion: {event}")

            key = event.get("key")
            if not isinstance(key, str):
                # notification without usable key can't be mapped to a record
                continue

            if key.startswith("versions/"):
                match = self._VERSIONS_KEY_PATTERN.match(key)
                if not match:
                    continue

                record_id, record_table = match.groups()
                name = f"{record_table}/{record_id}"
                new = event["value"]
                old = self.client._store.get_current_version(
                    table=record_table, record_id=record_id
                )

                if new > old:
                    logger.debug(
                        f"Record {name} has changed; refreshing to update "
                        f"from version {old} to version {new}"
                    )
                    records_to_refresh[record_table].append(record_id)
                else:
                    logger.debug(
                        f"Record {name} already at version {old} "
                        f"not trying to update to version {new}"
                    )

            elif key.startswith("collection/"):
                match = self._COLLECTION_KEY_PATTERN.match(key)
                if not match:
                    continue

                collection_id = match.groups()[0]
                self.client.refresh_collection_rows(collection_id)
                row_ids = self.client._store.get_collection_rows(collection_id)

                logger.debug(
                    f"Something inside collection '{collection_id}' "
                    f"has changed. Refreshing all {row_ids} rows inside it"
                )
                records_to_refresh["block"] += row_ids

        self.client.refresh_records(**records_to_refresh)

    def url(self, **kwargs) -> str:
        """
        Build polling URL with passed query parameters.

        `b64` is always set to 1 while `transport` and `sessionId`
        are filled in only when they were not passed explicitly.


        Arguments
        ---------
        kwargs : dict
            Query parameters to add to the URL.


        Returns
        -------
        str
            `root_url` followed by the URL-encoded query string.
        """
        kwargs["b64"] = 1
        kwargs.setdefault("transport", "polling")
        kwargs.setdefault("sessionId", self.session_id)
        return f"{self.root_url}?{urlencode(kwargs)}"

    def initialize(self):
        """
        Initialize the monitoring session.

        Requests a new `sid` and re-registers all subscriptions
        that were known before the call.


        Raises
        ------
        IndexError
            When the response contains no decodable payload.

        KeyError
            When the first payload has no "sid" entry.
        """
        logger.debug("Initializing new monitoring session.")

        content = self.client.session.get(self.url(EIO=3)).content
        self.sid = self._decode_numbered_json_thing(content)[0]["sid"]

        logger.debug(f"New monitoring session ID is: {self.sid}")

        # resubscribe to any existing subscriptions if we're reconnecting
        old_subscriptions = self._subscriptions
        self._subscriptions = set()
        self.subscribe(old_subscriptions)

    def subscribe(self, records: Set[Record]):
        """
        Subscribe to changes of passed records.

        Records that are already subscribed are skipped.


        Arguments
        ---------
        records : set of Record
            Set of `Record` objects to subscribe to.
            A list, tuple or frozenset of records is converted to a set
            and any other single object is treated as one record.
        """
        if isinstance(records, (list, tuple, frozenset)):
            records = set(records)

        # TODO: how to describe that you can also pass
        #       record explicitly or should we block it?
        if not isinstance(records, set):
            records = {records}

        sub_data = []

        for record in records.difference(self._subscriptions):
            key = f"{record.id}:{record._table}"
            logger.debug(f"Subscribing new record: {key}")

            # save it in case we're disconnected
            self._subscriptions.add(record)

            # TODO: hide that dict generation in Record class
            sub_data.append(
                {
                    "type": "/api/v1/registerSubscription",
                    "requestId": str(uuid.uuid4()),
                    "key": f"versions/{key}",
                    "version": record.get("version", -1),
                }
            )

            # if it's a collection, subscribe to changes to its children too
            if isinstance(record, CollectionBlock):
                sub_data.append(
                    {
                        "type": "/api/v1/registerSubscription",
                        "requestId": str(uuid.uuid4()),
                        "key": f"collection/{record.id}",
                        "version": -1,
                    }
                )

        # post_data() ignores empty data so nothing
        # is sent when there were no new records
        self.post_data(self._encode_numbered_json_thing(sub_data))

    def post_data(self, data: bytes):
        """
        Send monitoring requests to Notion.

        Does nothing when `data` is empty.


        Arguments
        ---------
        data : bytes
            Form encoded request data.
        """
        if not data:
            return

        logger.debug(f"Posting monitoring data: {data}")
        self.client.session.post(self.url(sid=self.sid), data=data)

    def poll(self, retries: int = 10):
        """
        Poll for changes.

        Sends one long-poll request and processes its response.
        A failed request is retried until it succeeds or until
        `retries` requests in total were made. The session is
        restarted once when 3 attempts are left.


        Arguments
        ---------
        retries : int, optional
            Number of times to retry request if it fails.
            Should be bigger than 5.
            No request is made for values lower than 1.
            Defaults to 10.


        Raises
        ------
        HTTPError
            When GET request fails for `retries` times.
        """
        logger.debug("Starting new long-poll request")

        attempts_left = retries
        response = None

        while attempts_left > 0:
            attempts_left -= 1

            try:
                # URL is rebuilt for every attempt because
                # initialize() below replaces self.sid
                response = self.client.session.get(
                    self.url(EIO=3, sid=self.sid)
                )
                response.raise_for_status()

            except HTTPError as e:
                if response is None:
                    message = str(e)
                else:
                    message = f"{response.content} / {e}"

                logger.warning(
                    "Problem with submitting poll request: "
                    f"{message} (will retry {attempts_left} more times)"
                )

                time.sleep(0.1)
                if attempts_left <= 0:
                    raise

                if attempts_left <= 5:
                    logger.error(
                        "Persistent error submitting poll request: "
                        f"{message} (will retry {attempts_left} more times)"
                    )

                if attempts_left == 3:
                    # if we're close to giving up, try to restart the session
                    self.initialize()

            else:
                # request succeeded, no need to use remaining attempts
                break

        if response is None:
            logger.debug("No poll request was made, nothing to process")
            return

        self._refresh_updated_records(
            self._decode_numbered_json_thing(response.content)
        )

    def poll_async(self):
        """
        Start `poll_forever()` in a background daemon thread.

        Does nothing when a polling thread
        started earlier is still running.
        """
        if self.thread is not None and self.thread.is_alive():
            # Already polling async; no need to have two threads
            return

        logger.debug("Starting new thread for async polling")
        self.thread = threading.Thread(target=self.poll_forever, daemon=True)
        self.thread.start()

    def poll_forever(self):
        """
        Call `poll()` in never-ending loop with small time intervals in-between.

        Any exception raised by `poll()` is logged and followed by
        one second of sleep before the next attempt.

        This function is blocking, it never returns!
        """
        while True:
            try:
                self.poll()
            except Exception as e:
                logger.error("Encountered error during polling!")
                logger.error(e, exc_info=True)
                time.sleep(1)