Use `getClientAddress` instead of `getClientIP`. (#12599)
getClientIP was deprecated in Twisted 18.4.0, which also added getClientAddress. The Synapse minimum version for Twisted is currently 18.9.0, so all supported versions have the new API.
This commit is contained in:
parent
116a4c8340
commit
7fbf42499d
|
@ -0,0 +1 @@
|
||||||
|
Use `getClientAddress` instead of the deprecated `getClientIP`.
|
|
@ -187,7 +187,7 @@ class Auth:
|
||||||
Once get_user_by_req has set up the opentracing span, this does the actual work.
|
Once get_user_by_req has set up the opentracing span, this does the actual work.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
ip_addr = request.getClientIP()
|
ip_addr = request.getClientAddress().host
|
||||||
user_agent = get_request_user_agent(request)
|
user_agent = get_request_user_agent(request)
|
||||||
|
|
||||||
access_token = self.get_access_token_from_request(request)
|
access_token = self.get_access_token_from_request(request)
|
||||||
|
@ -356,7 +356,7 @@ class Auth:
|
||||||
return None, None, None
|
return None, None, None
|
||||||
|
|
||||||
if app_service.ip_range_whitelist:
|
if app_service.ip_range_whitelist:
|
||||||
ip_address = IPAddress(request.getClientIP())
|
ip_address = IPAddress(request.getClientAddress().host)
|
||||||
if ip_address not in app_service.ip_range_whitelist:
|
if ip_address not in app_service.ip_range_whitelist:
|
||||||
return None, None, None
|
return None, None, None
|
||||||
|
|
||||||
|
|
|
@ -551,7 +551,7 @@ class AuthHandler:
|
||||||
await self.store.set_ui_auth_clientdict(sid, clientdict)
|
await self.store.set_ui_auth_clientdict(sid, clientdict)
|
||||||
|
|
||||||
user_agent = get_request_user_agent(request)
|
user_agent = get_request_user_agent(request)
|
||||||
clientip = request.getClientIP()
|
clientip = request.getClientAddress().host
|
||||||
|
|
||||||
await self.store.add_user_agent_ip_to_ui_auth_session(
|
await self.store.add_user_agent_ip_to_ui_auth_session(
|
||||||
session.session_id, user_agent, clientip
|
session.session_id, user_agent, clientip
|
||||||
|
|
|
@ -92,7 +92,7 @@ class IdentityHandler:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
await self._3pid_validation_ratelimiter_ip.ratelimit(
|
await self._3pid_validation_ratelimiter_ip.ratelimit(
|
||||||
None, (medium, request.getClientIP())
|
None, (medium, request.getClientAddress().host)
|
||||||
)
|
)
|
||||||
await self._3pid_validation_ratelimiter_address.ratelimit(
|
await self._3pid_validation_ratelimiter_address.ratelimit(
|
||||||
None, (medium, address)
|
None, (medium, address)
|
||||||
|
|
|
@ -468,7 +468,7 @@ class SsoHandler:
|
||||||
auth_provider_id,
|
auth_provider_id,
|
||||||
remote_user_id,
|
remote_user_id,
|
||||||
get_request_user_agent(request),
|
get_request_user_agent(request),
|
||||||
request.getClientIP(),
|
request.getClientAddress().host,
|
||||||
)
|
)
|
||||||
new_user = True
|
new_user = True
|
||||||
elif self._sso_update_profile_information:
|
elif self._sso_update_profile_information:
|
||||||
|
@ -928,7 +928,7 @@ class SsoHandler:
|
||||||
session.auth_provider_id,
|
session.auth_provider_id,
|
||||||
session.remote_user_id,
|
session.remote_user_id,
|
||||||
get_request_user_agent(request),
|
get_request_user_agent(request),
|
||||||
request.getClientIP(),
|
request.getClientAddress().host,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
|
@ -238,7 +238,7 @@ class SynapseRequest(Request):
|
||||||
request_id,
|
request_id,
|
||||||
request=ContextRequest(
|
request=ContextRequest(
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
ip_address=self.getClientIP(),
|
ip_address=self.getClientAddress().host,
|
||||||
site_tag=self.synapse_site.site_tag,
|
site_tag=self.synapse_site.site_tag,
|
||||||
# The requester is going to be unknown at this point.
|
# The requester is going to be unknown at this point.
|
||||||
requester=None,
|
requester=None,
|
||||||
|
@ -381,7 +381,7 @@ class SynapseRequest(Request):
|
||||||
|
|
||||||
self.synapse_site.access_logger.debug(
|
self.synapse_site.access_logger.debug(
|
||||||
"%s - %s - Received request: %s %s",
|
"%s - %s - Received request: %s %s",
|
||||||
self.getClientIP(),
|
self.getClientAddress().host,
|
||||||
self.synapse_site.site_tag,
|
self.synapse_site.site_tag,
|
||||||
self.get_method(),
|
self.get_method(),
|
||||||
self.get_redacted_uri(),
|
self.get_redacted_uri(),
|
||||||
|
@ -429,7 +429,7 @@ class SynapseRequest(Request):
|
||||||
"%s - %s - {%s}"
|
"%s - %s - {%s}"
|
||||||
" Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
|
" Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
|
||||||
' %sB %s "%s %s %s" "%s" [%d dbevts]',
|
' %sB %s "%s %s %s" "%s" [%d dbevts]',
|
||||||
self.getClientIP(),
|
self.getClientAddress().host,
|
||||||
self.synapse_site.site_tag,
|
self.synapse_site.site_tag,
|
||||||
requester,
|
requester,
|
||||||
processing_time,
|
processing_time,
|
||||||
|
|
|
@ -884,7 +884,7 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
|
||||||
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
|
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
|
||||||
tags.HTTP_METHOD: request.get_method(),
|
tags.HTTP_METHOD: request.get_method(),
|
||||||
tags.HTTP_URL: request.get_redacted_uri(),
|
tags.HTTP_URL: request.get_redacted_uri(),
|
||||||
tags.PEER_HOST_IPV6: request.getClientIP(),
|
tags.PEER_HOST_IPV6: request.getClientAddress().host,
|
||||||
}
|
}
|
||||||
|
|
||||||
request_name = request.request_metrics.name
|
request_name = request.request_metrics.name
|
||||||
|
|
|
@ -112,7 +112,7 @@ class AuthRestServlet(RestServlet):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.auth_handler.add_oob_auth(
|
await self.auth_handler.add_oob_auth(
|
||||||
LoginType.RECAPTCHA, authdict, request.getClientIP()
|
LoginType.RECAPTCHA, authdict, request.getClientAddress().host
|
||||||
)
|
)
|
||||||
except LoginError as e:
|
except LoginError as e:
|
||||||
# Authentication failed, let user try again
|
# Authentication failed, let user try again
|
||||||
|
@ -132,7 +132,7 @@ class AuthRestServlet(RestServlet):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.auth_handler.add_oob_auth(
|
await self.auth_handler.add_oob_auth(
|
||||||
LoginType.TERMS, authdict, request.getClientIP()
|
LoginType.TERMS, authdict, request.getClientAddress().host
|
||||||
)
|
)
|
||||||
except LoginError as e:
|
except LoginError as e:
|
||||||
# Authentication failed, let user try again
|
# Authentication failed, let user try again
|
||||||
|
@ -161,7 +161,9 @@ class AuthRestServlet(RestServlet):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.auth_handler.add_oob_auth(
|
await self.auth_handler.add_oob_auth(
|
||||||
LoginType.REGISTRATION_TOKEN, authdict, request.getClientIP()
|
LoginType.REGISTRATION_TOKEN,
|
||||||
|
authdict,
|
||||||
|
request.getClientAddress().host,
|
||||||
)
|
)
|
||||||
except LoginError as e:
|
except LoginError as e:
|
||||||
html = self.registration_token_template.render(
|
html = self.registration_token_template.render(
|
||||||
|
|
|
@ -176,7 +176,7 @@ class LoginRestServlet(RestServlet):
|
||||||
|
|
||||||
if appservice.is_rate_limited():
|
if appservice.is_rate_limited():
|
||||||
await self._address_ratelimiter.ratelimit(
|
await self._address_ratelimiter.ratelimit(
|
||||||
None, request.getClientIP()
|
None, request.getClientAddress().host
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await self._do_appservice_login(
|
result = await self._do_appservice_login(
|
||||||
|
@ -188,19 +188,25 @@ class LoginRestServlet(RestServlet):
|
||||||
self.jwt_enabled
|
self.jwt_enabled
|
||||||
and login_submission["type"] == LoginRestServlet.JWT_TYPE
|
and login_submission["type"] == LoginRestServlet.JWT_TYPE
|
||||||
):
|
):
|
||||||
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
|
await self._address_ratelimiter.ratelimit(
|
||||||
|
None, request.getClientAddress().host
|
||||||
|
)
|
||||||
result = await self._do_jwt_login(
|
result = await self._do_jwt_login(
|
||||||
login_submission,
|
login_submission,
|
||||||
should_issue_refresh_token=should_issue_refresh_token,
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
)
|
)
|
||||||
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
||||||
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
|
await self._address_ratelimiter.ratelimit(
|
||||||
|
None, request.getClientAddress().host
|
||||||
|
)
|
||||||
result = await self._do_token_login(
|
result = await self._do_token_login(
|
||||||
login_submission,
|
login_submission,
|
||||||
should_issue_refresh_token=should_issue_refresh_token,
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
|
await self._address_ratelimiter.ratelimit(
|
||||||
|
None, request.getClientAddress().host
|
||||||
|
)
|
||||||
result = await self._do_other_login(
|
result = await self._do_other_login(
|
||||||
login_submission,
|
login_submission,
|
||||||
should_issue_refresh_token=should_issue_refresh_token,
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
|
|
|
@ -352,7 +352,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
|
||||||
if self.inhibit_user_in_use_error:
|
if self.inhibit_user_in_use_error:
|
||||||
return 200, {"available": True}
|
return 200, {"available": True}
|
||||||
|
|
||||||
ip = request.getClientIP()
|
ip = request.getClientAddress().host
|
||||||
with self.ratelimiter.ratelimit(ip) as wait_deferred:
|
with self.ratelimiter.ratelimit(ip) as wait_deferred:
|
||||||
await wait_deferred
|
await wait_deferred
|
||||||
|
|
||||||
|
@ -394,7 +394,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
|
async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
|
||||||
await self.ratelimiter.ratelimit(None, (request.getClientIP(),))
|
await self.ratelimiter.ratelimit(None, (request.getClientAddress().host,))
|
||||||
|
|
||||||
if not self.hs.config.registration.enable_registration:
|
if not self.hs.config.registration.enable_registration:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
|
@ -441,7 +441,7 @@ class RegisterRestServlet(RestServlet):
|
||||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
client_addr = request.getClientIP()
|
client_addr = request.getClientAddress().host
|
||||||
|
|
||||||
await self.ratelimiter.ratelimit(None, client_addr, update=False)
|
await self.ratelimiter.ratelimit(None, client_addr, update=False)
|
||||||
|
|
||||||
|
|
|
@ -105,7 +105,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
self.store.get_user_by_access_token = simple_async_mock(None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "127.0.0.1"
|
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
requester = self.get_success(self.auth.get_user_by_req(request))
|
requester = self.get_success(self.auth.get_user_by_req(request))
|
||||||
|
@ -124,7 +124,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
self.store.get_user_by_access_token = simple_async_mock(None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "192.168.10.10"
|
request.getClientAddress.return_value.host = "192.168.10.10"
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
requester = self.get_success(self.auth.get_user_by_req(request))
|
requester = self.get_success(self.auth.get_user_by_req(request))
|
||||||
|
@ -143,7 +143,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
self.store.get_user_by_access_token = simple_async_mock(None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "131.111.8.42"
|
request.getClientAddress.return_value.host = "131.111.8.42"
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
f = self.get_failure(
|
f = self.get_failure(
|
||||||
|
@ -190,7 +190,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
self.store.get_user_by_access_token = simple_async_mock(None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "127.0.0.1"
|
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.args[b"user_id"] = [masquerading_user_id]
|
request.args[b"user_id"] = [masquerading_user_id]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
|
@ -209,7 +209,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
self.store.get_user_by_access_token = simple_async_mock(None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "127.0.0.1"
|
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.args[b"user_id"] = [masquerading_user_id]
|
request.args[b"user_id"] = [masquerading_user_id]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
|
@ -236,7 +236,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.store.get_device = simple_async_mock({"hidden": False})
|
self.store.get_device = simple_async_mock({"hidden": False})
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "127.0.0.1"
|
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.args[b"user_id"] = [masquerading_user_id]
|
request.args[b"user_id"] = [masquerading_user_id]
|
||||||
request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
|
request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
|
||||||
|
@ -268,7 +268,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.store.get_device = simple_async_mock(None)
|
self.store.get_device = simple_async_mock(None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "127.0.0.1"
|
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.args[b"user_id"] = [masquerading_user_id]
|
request.args[b"user_id"] = [masquerading_user_id]
|
||||||
request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
|
request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
|
||||||
|
@ -288,7 +288,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.store.insert_client_ip = simple_async_mock(None)
|
self.store.insert_client_ip = simple_async_mock(None)
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "127.0.0.1"
|
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
self.get_success(self.auth.get_user_by_req(request))
|
self.get_success(self.auth.get_user_by_req(request))
|
||||||
|
@ -305,7 +305,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.store.insert_client_ip = simple_async_mock(None)
|
self.store.insert_client_ip = simple_async_mock(None)
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "127.0.0.1"
|
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
self.get_success(self.auth.get_user_by_req(request))
|
self.get_success(self.auth.get_user_by_req(request))
|
||||||
|
|
|
@ -204,7 +204,7 @@ def _mock_request():
|
||||||
mock = Mock(
|
mock = Mock(
|
||||||
spec=[
|
spec=[
|
||||||
"finish",
|
"finish",
|
||||||
"getClientIP",
|
"getClientAddress",
|
||||||
"getHeader",
|
"getHeader",
|
||||||
"setHeader",
|
"setHeader",
|
||||||
"setResponseCode",
|
"setResponseCode",
|
||||||
|
|
|
@ -1300,7 +1300,7 @@ def _build_callback_request(
|
||||||
"getCookie",
|
"getCookie",
|
||||||
"cookies",
|
"cookies",
|
||||||
"requestHeaders",
|
"requestHeaders",
|
||||||
"getClientIP",
|
"getClientAddress",
|
||||||
"getHeader",
|
"getHeader",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
@ -1310,5 +1310,5 @@ def _build_callback_request(
|
||||||
request.args = {}
|
request.args = {}
|
||||||
request.args[b"code"] = [code.encode("utf-8")]
|
request.args[b"code"] = [code.encode("utf-8")]
|
||||||
request.args[b"state"] = [state.encode("utf-8")]
|
request.args[b"state"] = [state.encode("utf-8")]
|
||||||
request.getClientIP.return_value = ip_address
|
request.getClientAddress.return_value.host = ip_address
|
||||||
return request
|
return request
|
||||||
|
|
|
@ -352,7 +352,7 @@ def _mock_request():
|
||||||
mock = Mock(
|
mock = Mock(
|
||||||
spec=[
|
spec=[
|
||||||
"finish",
|
"finish",
|
||||||
"getClientIP",
|
"getClientAddress",
|
||||||
"getHeader",
|
"getHeader",
|
||||||
"setHeader",
|
"setHeader",
|
||||||
"setResponseCode",
|
"setResponseCode",
|
||||||
|
|
|
@ -154,10 +154,12 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(port, 8765)
|
self.assertEqual(port, 8765)
|
||||||
|
|
||||||
# Set up client side protocol
|
# Set up client side protocol
|
||||||
client_protocol = client_factory.buildProtocol(None)
|
client_address = IPv4Address("TCP", "127.0.0.1", 1234)
|
||||||
|
client_protocol = client_factory.buildProtocol(("127.0.0.1", 1234))
|
||||||
|
|
||||||
# Set up the server side protocol
|
# Set up the server side protocol
|
||||||
channel = self.site.buildProtocol(None)
|
server_address = IPv4Address("TCP", host, port)
|
||||||
|
channel = self.site.buildProtocol((host, port))
|
||||||
|
|
||||||
# hook into the channel's request factory so that we can keep a record
|
# hook into the channel's request factory so that we can keep a record
|
||||||
# of the requests
|
# of the requests
|
||||||
|
@ -173,12 +175,12 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# Connect client to server and vice versa.
|
# Connect client to server and vice versa.
|
||||||
client_to_server_transport = FakeTransport(
|
client_to_server_transport = FakeTransport(
|
||||||
channel, self.reactor, client_protocol
|
channel, self.reactor, client_protocol, server_address, client_address
|
||||||
)
|
)
|
||||||
client_protocol.makeConnection(client_to_server_transport)
|
client_protocol.makeConnection(client_to_server_transport)
|
||||||
|
|
||||||
server_to_client_transport = FakeTransport(
|
server_to_client_transport = FakeTransport(
|
||||||
client_protocol, self.reactor, channel
|
client_protocol, self.reactor, channel, client_address, server_address
|
||||||
)
|
)
|
||||||
channel.makeConnection(server_to_client_transport)
|
channel.makeConnection(server_to_client_transport)
|
||||||
|
|
||||||
|
@ -406,19 +408,21 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(port, repl_port)
|
self.assertEqual(port, repl_port)
|
||||||
|
|
||||||
# Set up client side protocol
|
# Set up client side protocol
|
||||||
client_protocol = client_factory.buildProtocol(None)
|
client_address = IPv4Address("TCP", "127.0.0.1", 1234)
|
||||||
|
client_protocol = client_factory.buildProtocol(("127.0.0.1", 1234))
|
||||||
|
|
||||||
# Set up the server side protocol
|
# Set up the server side protocol
|
||||||
channel = self._hs_to_site[hs].buildProtocol(None)
|
server_address = IPv4Address("TCP", host, port)
|
||||||
|
channel = self._hs_to_site[hs].buildProtocol((host, port))
|
||||||
|
|
||||||
# Connect client to server and vice versa.
|
# Connect client to server and vice versa.
|
||||||
client_to_server_transport = FakeTransport(
|
client_to_server_transport = FakeTransport(
|
||||||
channel, self.reactor, client_protocol
|
channel, self.reactor, client_protocol, server_address, client_address
|
||||||
)
|
)
|
||||||
client_protocol.makeConnection(client_to_server_transport)
|
client_protocol.makeConnection(client_to_server_transport)
|
||||||
|
|
||||||
server_to_client_transport = FakeTransport(
|
server_to_client_transport = FakeTransport(
|
||||||
client_protocol, self.reactor, channel
|
client_protocol, self.reactor, channel, client_address, server_address
|
||||||
)
|
)
|
||||||
channel.makeConnection(server_to_client_transport)
|
channel.makeConnection(server_to_client_transport)
|
||||||
|
|
||||||
|
|
|
@ -181,7 +181,7 @@ class FakeChannel:
|
||||||
self.resource_usage = _self.logcontext.get_resource_usage()
|
self.resource_usage = _self.logcontext.get_resource_usage()
|
||||||
|
|
||||||
def getPeer(self):
|
def getPeer(self):
|
||||||
# We give an address so that getClientIP returns a non null entry,
|
# We give an address so that getClientAddress/getClientIP returns a non null entry,
|
||||||
# causing us to record the MAU
|
# causing us to record the MAU
|
||||||
return address.IPv4Address("TCP", self._ip, 3423)
|
return address.IPv4Address("TCP", self._ip, 3423)
|
||||||
|
|
||||||
|
@ -562,7 +562,10 @@ class FakeTransport:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_peer_address: Optional[IAddress] = attr.ib(default=None)
|
_peer_address: Optional[IAddress] = attr.ib(default=None)
|
||||||
"""The value to be returend by getPeer"""
|
"""The value to be returned by getPeer"""
|
||||||
|
|
||||||
|
_host_address: Optional[IAddress] = attr.ib(default=None)
|
||||||
|
"""The value to be returned by getHost"""
|
||||||
|
|
||||||
disconnecting = False
|
disconnecting = False
|
||||||
disconnected = False
|
disconnected = False
|
||||||
|
@ -571,11 +574,11 @@ class FakeTransport:
|
||||||
producer = attr.ib(default=None)
|
producer = attr.ib(default=None)
|
||||||
autoflush = attr.ib(default=True)
|
autoflush = attr.ib(default=True)
|
||||||
|
|
||||||
def getPeer(self):
|
def getPeer(self) -> Optional[IAddress]:
|
||||||
return self._peer_address
|
return self._peer_address
|
||||||
|
|
||||||
def getHost(self):
|
def getHost(self) -> Optional[IAddress]:
|
||||||
return None
|
return self._host_address
|
||||||
|
|
||||||
def loseConnection(self, reason=None):
|
def loseConnection(self, reason=None):
|
||||||
if not self.disconnecting:
|
if not self.disconnecting:
|
||||||
|
|
Loading…
Reference in New Issue