Add type annotations to `synapse.metrics` (#10847)

This commit is contained in:
Sean Quah 2021-11-17 19:07:02 +00:00 committed by GitHub
parent d993c3bb1e
commit 84fac0f814
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 174 additions and 86 deletions

1
changelog.d/10847.misc Normal file
View File

@ -0,0 +1 @@
Add type annotations to `synapse.metrics`.

View File

@ -160,6 +160,9 @@ disallow_untyped_defs = True
[mypy-synapse.handlers.*]
disallow_untyped_defs = True
[mypy-synapse.metrics.*]
disallow_untyped_defs = True
[mypy-synapse.push.*]
disallow_untyped_defs = True

View File

@ -402,7 +402,7 @@ async def start(hs: "HomeServer") -> None:
if hasattr(signal, "SIGHUP"):
@wrap_as_background_process("sighup")
def handle_sighup(*args: Any, **kwargs: Any) -> None:
async def handle_sighup(*args: Any, **kwargs: Any) -> None:
# Tell systemd our state, if we're using it. This will silently fail if
# we're not using systemd.
sdnotify(b"RELOADING=1")

View File

@ -40,6 +40,8 @@ from typing import TYPE_CHECKING, Optional, Tuple
from signedjson.sign import sign_json
from twisted.internet.defer import Deferred
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import JsonDict, get_domain_from_id
@ -166,7 +168,7 @@ class GroupAttestionRenewer:
return {}
def _start_renew_attestations(self) -> None:
def _start_renew_attestations(self) -> "Deferred[None]":
return run_as_background_process("renew_attestations", self._renew_attestations)
async def _renew_attestations(self) -> None:

View File

@ -90,7 +90,7 @@ class FollowerTypingHandler:
self.wheel_timer = WheelTimer(bucket_size=5000)
@wrap_as_background_process("typing._handle_timeouts")
def _handle_timeouts(self) -> None:
async def _handle_timeouts(self) -> None:
logger.debug("Checking for typing timeouts")
now = self.clock.time_msec()

View File

@ -20,10 +20,25 @@ import os
import platform
import threading
import time
from typing import Callable, Dict, Iterable, Mapping, Optional, Tuple, Union
from typing import (
Any,
Callable,
Dict,
Generic,
Iterable,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
)
import attr
from prometheus_client import Counter, Gauge, Histogram
from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram, Metric
from prometheus_client.core import (
REGISTRY,
CounterMetricFamily,
@ -32,6 +47,7 @@ from prometheus_client.core import (
)
from twisted.internet import reactor
from twisted.internet.base import ReactorBase
from twisted.python.threadpool import ThreadPool
import synapse
@ -54,7 +70,7 @@ HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
class RegistryProxy:
@staticmethod
def collect():
def collect() -> Iterable[Metric]:
for metric in REGISTRY.collect():
if not metric.name.startswith("__"):
yield metric
@ -74,7 +90,7 @@ class LaterGauge:
]
)
def collect(self):
def collect(self) -> Iterable[Metric]:
g = GaugeMetricFamily(self.name, self.desc, labels=self.labels)
@ -93,10 +109,10 @@ class LaterGauge:
yield g
def __attrs_post_init__(self):
def __attrs_post_init__(self) -> None:
self._register()
def _register(self):
def _register(self) -> None:
if self.name in all_gauges.keys():
logger.warning("%s already registered, reregistering" % (self.name,))
REGISTRY.unregister(all_gauges.pop(self.name))
@ -105,7 +121,12 @@ class LaterGauge:
all_gauges[self.name] = self
class InFlightGauge:
# `MetricsEntry` only makes sense when it is a `Protocol`,
# but `Protocol` can't be used as a `TypeVar` bound.
MetricsEntry = TypeVar("MetricsEntry")
class InFlightGauge(Generic[MetricsEntry]):
"""Tracks number of things (e.g. requests, Measure blocks, etc) in flight
at any given time.
@ -115,14 +136,19 @@ class InFlightGauge:
callbacks.
Args:
name (str)
desc (str)
labels (list[str])
sub_metrics (list[str]): A list of sub metrics that the callbacks
will update.
name
desc
labels
sub_metrics: A list of sub metrics that the callbacks will update.
"""
def __init__(self, name, desc, labels, sub_metrics):
def __init__(
self,
name: str,
desc: str,
labels: Sequence[str],
sub_metrics: Sequence[str],
):
self.name = name
self.desc = desc
self.labels = labels
@ -130,19 +156,25 @@ class InFlightGauge:
# Create a class which has the sub_metrics values as attributes, which
# default to 0 on initialization. Used to pass to registered callbacks.
self._metrics_class = attr.make_class(
self._metrics_class: Type[MetricsEntry] = attr.make_class(
"_MetricsEntry", attrs={x: attr.ib(0) for x in sub_metrics}, slots=True
)
# Counts number of in flight blocks for a given set of label values
self._registrations: Dict = {}
self._registrations: Dict[
Tuple[str, ...], Set[Callable[[MetricsEntry], None]]
] = {}
# Protects access to _registrations
self._lock = threading.Lock()
self._register_with_collector()
def register(self, key, callback):
def register(
self,
key: Tuple[str, ...],
callback: Callable[[MetricsEntry], None],
) -> None:
"""Registers that we've entered a new block with labels `key`.
`callback` gets called each time the metrics are collected. The same
@ -158,13 +190,17 @@ class InFlightGauge:
with self._lock:
self._registrations.setdefault(key, set()).add(callback)
def unregister(self, key, callback):
def unregister(
self,
key: Tuple[str, ...],
callback: Callable[[MetricsEntry], None],
) -> None:
"""Registers that we've exited a block with labels `key`."""
with self._lock:
self._registrations.setdefault(key, set()).discard(callback)
def collect(self):
def collect(self) -> Iterable[Metric]:
"""Called by prometheus client when it reads metrics.
Note: may be called by a separate thread.
@ -200,7 +236,7 @@ class InFlightGauge:
gauge.add_metric(key, getattr(metrics, name))
yield gauge
def _register_with_collector(self):
def _register_with_collector(self) -> None:
if self.name in all_gauges.keys():
logger.warning("%s already registered, reregistering" % (self.name,))
REGISTRY.unregister(all_gauges.pop(self.name))
@ -230,7 +266,7 @@ class GaugeBucketCollector:
name: str,
documentation: str,
buckets: Iterable[float],
registry=REGISTRY,
registry: CollectorRegistry = REGISTRY,
):
"""
Args:
@ -257,12 +293,12 @@ class GaugeBucketCollector:
registry.register(self)
def collect(self):
def collect(self) -> Iterable[Metric]:
# Don't report metrics unless we've already collected some data
if self._metric is not None:
yield self._metric
def update_data(self, values: Iterable[float]):
def update_data(self, values: Iterable[float]) -> None:
"""Update the data to be reported by the metric
The existing data is cleared, and each measurement in the input is assigned
@ -304,7 +340,7 @@ class GaugeBucketCollector:
class CPUMetrics:
def __init__(self):
def __init__(self) -> None:
ticks_per_sec = 100
try:
# Try and get the system config
@ -314,7 +350,7 @@ class CPUMetrics:
self.ticks_per_sec = ticks_per_sec
def collect(self):
def collect(self) -> Iterable[Metric]:
if not HAVE_PROC_SELF_STAT:
return
@ -364,7 +400,7 @@ gc_time = Histogram(
class GCCounts:
def collect(self):
def collect(self) -> Iterable[Metric]:
cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"])
for n, m in enumerate(gc.get_count()):
cm.add_metric([str(n)], m)
@ -382,7 +418,7 @@ if not running_on_pypy:
class PyPyGCStats:
def collect(self):
def collect(self) -> Iterable[Metric]:
# @stats is a pretty-printer object with __str__() returning a nice table,
# plus some fields that contain data from that table.
@ -565,7 +601,7 @@ def register_threadpool(name: str, threadpool: ThreadPool) -> None:
class ReactorLastSeenMetric:
def collect(self):
def collect(self) -> Iterable[Metric]:
cm = GaugeMetricFamily(
"python_twisted_reactor_last_seen",
"Seconds since the Twisted reactor was last seen",
@ -584,9 +620,12 @@ MIN_TIME_BETWEEN_GCS = (1.0, 10.0, 30.0)
_last_gc = [0.0, 0.0, 0.0]
def runUntilCurrentTimer(reactor, func):
F = TypeVar("F", bound=Callable[..., Any])
def runUntilCurrentTimer(reactor: ReactorBase, func: F) -> F:
@functools.wraps(func)
def f(*args, **kwargs):
def f(*args: Any, **kwargs: Any) -> Any:
now = reactor.seconds()
num_pending = 0
@ -649,7 +688,7 @@ def runUntilCurrentTimer(reactor, func):
return ret
return f
return cast(F, f)
try:
@ -677,5 +716,5 @@ __all__ = [
"start_http_server",
"LaterGauge",
"InFlightGauge",
"BucketCollector",
"GaugeBucketCollector",
]

View File

@ -25,27 +25,25 @@ import math
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer
from socketserver import ThreadingMixIn
from typing import Dict, List
from typing import Any, Dict, List, Type, Union
from urllib.parse import parse_qs, urlparse
from prometheus_client import REGISTRY
from prometheus_client import REGISTRY, CollectorRegistry
from prometheus_client.core import Sample
from twisted.web.resource import Resource
from twisted.web.server import Request
from synapse.util import caches
CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8"
INF = float("inf")
MINUS_INF = float("-inf")
def floatToGoString(d):
def floatToGoString(d: Union[int, float]) -> str:
d = float(d)
if d == INF:
if d == math.inf:
return "+Inf"
elif d == MINUS_INF:
elif d == -math.inf:
return "-Inf"
elif math.isnan(d):
return "NaN"
@ -60,7 +58,7 @@ def floatToGoString(d):
return s
def sample_line(line, name):
def sample_line(line: Sample, name: str) -> str:
if line.labels:
labelstr = "{{{0}}}".format(
",".join(
@ -82,7 +80,7 @@ def sample_line(line, name):
return "{}{} {}{}\n".format(name, labelstr, floatToGoString(line.value), timestamp)
def generate_latest(registry, emit_help=False):
def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> bytes:
# Trigger the cache metrics to be rescraped; they update the common
# metrics but do not produce metrics themselves
@ -187,7 +185,7 @@ class MetricsHandler(BaseHTTPRequestHandler):
registry = REGISTRY
def do_GET(self):
def do_GET(self) -> None:
registry = self.registry
params = parse_qs(urlparse(self.path).query)
@ -207,11 +205,11 @@ class MetricsHandler(BaseHTTPRequestHandler):
self.end_headers()
self.wfile.write(output)
def log_message(self, format, *args):
def log_message(self, format: str, *args: Any) -> None:
"""Log nothing."""
@classmethod
def factory(cls, registry):
def factory(cls, registry: CollectorRegistry) -> Type:
"""Returns a dynamic MetricsHandler class tied
to the passed registry.
"""
@ -236,7 +234,9 @@ class _ThreadingSimpleServer(ThreadingMixIn, HTTPServer):
daemon_threads = True
def start_http_server(port, addr="", registry=REGISTRY):
def start_http_server(
port: int, addr: str = "", registry: CollectorRegistry = REGISTRY
) -> None:
"""Starts an HTTP server for prometheus metrics as a daemon thread"""
CustomMetricsHandler = MetricsHandler.factory(registry)
httpd = _ThreadingSimpleServer((addr, port), CustomMetricsHandler)
@ -252,10 +252,10 @@ class MetricsResource(Resource):
isLeaf = True
def __init__(self, registry=REGISTRY):
def __init__(self, registry: CollectorRegistry = REGISTRY):
self.registry = registry
def render_GET(self, request):
def render_GET(self, request: Request) -> bytes:
request.setHeader(b"Content-Type", CONTENT_TYPE_LATEST.encode("ascii"))
response = generate_latest(self.registry)
request.setHeader(b"Content-Length", str(len(response)))

View File

@ -15,19 +15,37 @@
import logging
import threading
from functools import wraps
from typing import TYPE_CHECKING, Dict, Optional, Set, Union
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Iterable,
Optional,
Set,
Type,
TypeVar,
Union,
cast,
)
from prometheus_client import Metric
from prometheus_client.core import REGISTRY, Counter, Gauge
from twisted.internet import defer
from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.logging.context import (
ContextResourceUsage,
LoggingContext,
PreserveLoggingContext,
)
from synapse.logging.opentracing import (
SynapseTags,
noop_context_manager,
start_active_span,
)
from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING:
import resource
@ -116,7 +134,7 @@ class _Collector:
before they are returned.
"""
def collect(self):
def collect(self) -> Iterable[Metric]:
global _background_processes_active_since_last_scrape
# We swap out the _background_processes set with an empty one so that
@ -144,12 +162,12 @@ REGISTRY.register(_Collector())
class _BackgroundProcess:
def __init__(self, desc, ctx):
def __init__(self, desc: str, ctx: LoggingContext):
self.desc = desc
self._context = ctx
self._reported_stats = None
self._reported_stats: Optional[ContextResourceUsage] = None
def update_metrics(self):
def update_metrics(self) -> None:
"""Updates the metrics with values from this process."""
new_stats = self._context.get_resource_usage()
if self._reported_stats is None:
@ -169,7 +187,16 @@ class _BackgroundProcess:
)
def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwargs):
R = TypeVar("R")
def run_as_background_process(
desc: str,
func: Callable[..., Awaitable[Optional[R]]],
*args: Any,
bg_start_span: bool = True,
**kwargs: Any,
) -> "defer.Deferred[Optional[R]]":
"""Run the given function in its own logcontext, with resource metrics
This should be used to wrap processes which are fired off to run in the
@ -189,11 +216,13 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
args: positional args for func
kwargs: keyword args for func
Returns: Deferred which returns the result of func, but note that it does not
follow the synapse logcontext rules.
Returns:
Deferred which returns the result of func, or `None` if func raises.
Note that the returned Deferred does not follow the synapse logcontext
rules.
"""
async def run():
async def run() -> Optional[R]:
with _bg_metrics_lock:
count = _background_process_counts.get(desc, 0)
_background_process_counts[desc] = count + 1
@ -210,12 +239,13 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
else:
ctx = noop_context_manager()
with ctx:
return await maybe_awaitable(func(*args, **kwargs))
return await func(*args, **kwargs)
except Exception:
logger.exception(
"Background process '%s' threw an exception",
desc,
)
return None
finally:
_background_process_in_flight_count.labels(desc).dec()
@ -225,19 +255,24 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
return defer.ensureDeferred(run())
def wrap_as_background_process(desc):
F = TypeVar("F", bound=Callable[..., Awaitable[Optional[Any]]])
def wrap_as_background_process(desc: str) -> Callable[[F], F]:
"""Decorator that wraps a function that gets called as a background
process.
Equivalent of calling the function with `run_as_background_process`
Equivalent to calling the function with `run_as_background_process`
"""
def wrap_as_background_process_inner(func):
def wrap_as_background_process_inner(func: F) -> F:
@wraps(func)
def wrap_as_background_process_inner_2(*args, **kwargs):
def wrap_as_background_process_inner_2(
*args: Any, **kwargs: Any
) -> "defer.Deferred[Optional[R]]":
return run_as_background_process(desc, func, *args, **kwargs)
return wrap_as_background_process_inner_2
return cast(F, wrap_as_background_process_inner_2)
return wrap_as_background_process_inner
@ -265,7 +300,7 @@ class BackgroundProcessLoggingContext(LoggingContext):
super().__init__("%s-%s" % (name, instance_id))
self._proc = _BackgroundProcess(name, self)
def start(self, rusage: "Optional[resource.struct_rusage]"):
def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
"""Log context has started running (again)."""
super().start(rusage)
@ -276,7 +311,12 @@ class BackgroundProcessLoggingContext(LoggingContext):
with _bg_metrics_lock:
_background_processes_active_since_last_scrape.add(self._proc)
def __exit__(self, type, value, traceback) -> None:
def __exit__(
self,
type: Optional[Type[BaseException]],
value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
"""Log context has finished."""
super().__exit__(type, value, traceback)

View File

@ -16,14 +16,16 @@ import ctypes
import logging
import os
import re
from typing import Optional
from typing import Iterable, Optional
from prometheus_client import Metric
from synapse.metrics import REGISTRY, GaugeMetricFamily
logger = logging.getLogger(__name__)
def _setup_jemalloc_stats():
def _setup_jemalloc_stats() -> None:
"""Checks to see if jemalloc is loaded, and hooks up a collector to record
statistics exposed by jemalloc.
"""
@ -135,7 +137,7 @@ def _setup_jemalloc_stats():
class JemallocCollector:
"""Metrics for internal jemalloc stats."""
def collect(self):
def collect(self) -> Iterable[Metric]:
_jemalloc_refresh_stats()
g = GaugeMetricFamily(
@ -185,7 +187,7 @@ def _setup_jemalloc_stats():
logger.debug("Added jemalloc stats")
def setup_jemalloc_stats():
def setup_jemalloc_stats() -> None:
"""Try to setup jemalloc stats, if jemalloc is loaded."""
try:

View File

@ -188,7 +188,7 @@ class LoggingDatabaseConnection:
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
_CallbackListEntry = Tuple[Callable[..., None], Iterable[Any], Dict[str, Any]]
_CallbackListEntry = Tuple[Callable[..., object], Iterable[Any], Dict[str, Any]]
R = TypeVar("R")
@ -235,7 +235,7 @@ class LoggingTransaction:
self.after_callbacks = after_callbacks
self.exception_callbacks = exception_callbacks
def call_after(self, callback: Callable[..., None], *args: Any, **kwargs: Any):
def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any):
"""Call the given callback on the main twisted thread after the
transaction has finished. Used to invalidate the caches on the
correct thread.
@ -247,7 +247,7 @@ class LoggingTransaction:
self.after_callbacks.append((callback, args, kwargs))
def call_on_exception(
self, callback: Callable[..., None], *args: Any, **kwargs: Any
self, callback: Callable[..., object], *args: Any, **kwargs: Any
):
# if self.exception_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that

View File

@ -159,7 +159,7 @@ class ExpiringCache(Generic[KT, VT]):
self[key] = value
return value
def _prune_cache(self) -> None:
async def _prune_cache(self) -> None:
if not self._expiry_ms:
# zero expiry time means don't expire. This should never get called
# since we have this check in start too.

View File

@ -56,14 +56,6 @@ block_db_sched_duration = Counter(
"synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"]
)
# Tracks the number of blocks currently active
in_flight = InFlightGauge(
"synapse_util_metrics_block_in_flight",
"",
labels=["block_name"],
sub_metrics=["real_time_max", "real_time_sum"],
)
# This is dynamically created in InFlightGauge.__init__.
class _InFlightMetric(Protocol):
@ -71,6 +63,15 @@ class _InFlightMetric(Protocol):
real_time_sum: float
# Tracks the number of blocks currently active
in_flight: InFlightGauge[_InFlightMetric] = InFlightGauge(
"synapse_util_metrics_block_in_flight",
"",
labels=["block_name"],
sub_metrics=["real_time_max", "real_time_sum"],
)
T = TypeVar("T", bound=Callable[..., Any])