2020-07-16 03:55:51 -06:00
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
import abc
|
|
|
|
import threading
|
2020-08-25 10:32:30 -06:00
|
|
|
from typing import Callable, List, Optional
|
2020-07-16 03:55:51 -06:00
|
|
|
|
|
|
|
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
|
|
|
|
from synapse.storage.types import Cursor
|
|
|
|
|
|
|
|
|
|
|
|
class SequenceGenerator(metaclass=abc.ABCMeta):
    """A class which generates a unique sequence of integers"""

    @abc.abstractmethod
    def get_next_id_txn(self, txn: Cursor) -> int:
        """Gets the next ID in the sequence"""
        ...

    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
        """Gets the next `n` IDs in the sequence.

        This default implementation simply calls `get_next_id_txn` once per ID
        (so `n <= 0` yields an empty list). Subclasses may override it to
        allocate the IDs more efficiently, e.g. in a single query.

        Args:
            txn: the database cursor to allocate the IDs with.
            n: how many IDs to allocate.

        Returns:
            The allocated IDs, in the order they were generated.
        """
        # Deliberately not abstract, so that existing subclasses which only
        # implement get_next_id_txn remain instantiable.
        return [self.get_next_id_txn(txn) for _ in range(n)]
|
|
|
|
|
|
|
|
|
|
|
|
class PostgresSequenceGenerator(SequenceGenerator):
    """SequenceGenerator backed by a postgres sequence.

    Every ID is produced by calling postgres `nextval()` on the sequence whose
    name is given at construction time.
    """

    def __init__(self, sequence_name: str):
        # The sequence name is always sent as a bound query parameter rather
        # than being interpolated into the SQL text.
        self._sequence_name = sequence_name

    def get_next_id_txn(self, txn: Cursor) -> int:
        """Allocate a single ID from the sequence and return it."""
        txn.execute("SELECT nextval(?)", (self._sequence_name,))
        row = txn.fetchone()
        return row[0]

    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
        """Allocate `n` IDs from the sequence using one query.

        Each row produced by `generate_series(1, n)` triggers one `nextval()`
        call. The IDs are returned in the order the rows are read back from
        the cursor.
        """
        query = "SELECT nextval(?) FROM generate_series(1, ?)"
        params = (self._sequence_name, n)
        txn.execute(query, params)
        return [seq_id for (seq_id,) in txn]
|
|
|
|
|
2020-07-16 03:55:51 -06:00
|
|
|
|
|
|
|
# Type of the callback given to LocalSequenceGenerator: it is called with the
# transaction cursor and must return the current maximum ID as an int.
GetFirstCallbackType = Callable[[Cursor], int]
|
|
|
|
|
|
|
|
|
|
|
|
class LocalSequenceGenerator(SequenceGenerator):
    """An implementation of SequenceGenerator which uses local locking

    This only works reliably if there are no other worker processes generating IDs at
    the same time.
    """

    def __init__(self, get_first_callback: GetFirstCallbackType):
        """
        Args:
            get_first_callback: a callback which is called on the first call to
                get_next_id_txn (or get_next_mult_txn); should return the
                current maximum id
        """
        # the callback. this is cleared after it is called, so that it can be GCed.
        self._callback = get_first_callback  # type: Optional[GetFirstCallbackType]

        # The current max value, or None if we haven't looked in the DB yet.
        self._current_max_id = None  # type: Optional[int]
        self._lock = threading.Lock()

    def _load_current_max_id_txn(self, txn: Cursor) -> int:
        """Return the current max ID, asking the callback for it on first use.

        Must be called with `self._lock` held.
        """
        if self._current_max_id is None:
            assert self._callback is not None
            # If the callback raises, neither attribute is modified, so a
            # later call will simply retry it.
            self._current_max_id = self._callback(txn)
            self._callback = None
        return self._current_max_id

    def get_next_id_txn(self, txn: Cursor) -> int:
        # We do application locking here since if we're using sqlite then
        # we are a single process synapse.
        with self._lock:
            self._current_max_id = self._load_current_max_id_txn(txn) + 1
            return self._current_max_id

    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
        """Gets the next `n` IDs in the sequence.

        The IDs are reserved as one contiguous, ascending run while holding
        the lock, so concurrent callers in this process never interleave.

        Args:
            txn: the database cursor, passed to the callback on first use.
            n: how many IDs to allocate; `n <= 0` allocates nothing and
                returns an empty list.

        Returns:
            The allocated IDs in ascending order.
        """
        with self._lock:
            current_max = self._load_current_max_id_txn(txn)
            if n <= 0:
                # Don't move the counter backwards for a non-positive count.
                return []
            self._current_max_id = current_max + n
            return list(range(current_max + 1, current_max + n + 1))
|
|
|
|
|
|
|
|
|
|
|
|
def build_sequence_generator(
    database_engine: BaseDatabaseEngine,
    get_first_callback: GetFirstCallbackType,
    sequence_name: str,
) -> SequenceGenerator:
    """Get the best impl of SequenceGenerator available

    This uses PostgresSequenceGenerator on postgres, and a locally-locked impl on
    sqlite.

    Args:
        database_engine: the database engine we are connected to
        get_first_callback: a callback which returns the current maximum ID
            (the first ID handed out will be one higher). Only used if we are
            not on postgres.
        sequence_name: the name of a postgres sequence to use. Only used if we
            are on postgres.

    Returns:
        A PostgresSequenceGenerator when `database_engine` is a PostgresEngine,
        otherwise a LocalSequenceGenerator.
    """
    if isinstance(database_engine, PostgresEngine):
        return PostgresSequenceGenerator(sequence_name)

    # Any other engine (sqlite) gets the in-process, lock-based generator,
    # which is only safe when a single process is allocating IDs.
    return LocalSequenceGenerator(get_first_callback)
|