Updated buffer tests for windowed channel

This commit is contained in:
Mark Qvist 2023-05-18 23:32:29 +02:00
parent e184861822
commit 4c272aa536
1 changed files with 37 additions and 28 deletions

View File

@ -4,6 +4,7 @@ import subprocess
import shlex import shlex
import threading import threading
import time import time
import random
from unittest import skipIf from unittest import skipIf
import RNS import RNS
import os import os
@ -23,6 +24,8 @@ fixed_keys = [
("08bb35f92b06a0832991165a0d9b4fd91af7b7765ce4572aa6222070b11b767092b61b0fd18b3a59cae6deb9db6d4bfb1c7fcfe076cfd66eea7ddd5f877543b9", "d13712efc45ef87674fb5ac26c37c912"), ("08bb35f92b06a0832991165a0d9b4fd91af7b7765ce4572aa6222070b11b767092b61b0fd18b3a59cae6deb9db6d4bfb1c7fcfe076cfd66eea7ddd5f877543b9", "d13712efc45ef87674fb5ac26c37c912"),
] ]
BUFFER_TEST_TARGET = 32000
def targets_job(caller): def targets_job(caller):
cmd = "python -c \"from tests.link import targets; targets()\"" cmd = "python -c \"from tests.link import targets; targets()\""
print("Opening subprocess for "+str(cmd)+"...", RNS.LOG_VERBOSE) print("Opening subprocess for "+str(cmd)+"...", RNS.LOG_VERBOSE)
@ -455,7 +458,7 @@ class TestLink(unittest.TestCase):
# @skipIf(os.getenv('SKIP_NORMAL_TESTS') != None and os.getenv('RUN_SLOW_TESTS') == None, "Skipping") # @skipIf(os.getenv('SKIP_NORMAL_TESTS') != None and os.getenv('RUN_SLOW_TESTS') == None, "Skipping")
def test_12_buffer_round_trip_big(self, local_bitrate = None): def test_12_buffer_round_trip_big(self, local_bitrate = None):
global c_rns global c_rns, buffer_read_target
init_rns(self) init_rns(self)
print("") print("")
print("Buffer round trip test") print("Buffer round trip test")
@ -490,9 +493,9 @@ class TestLink(unittest.TestCase):
buffer = None buffer = None
received = [] received = []
def handle_data(ready_bytes: int): def handle_data(ready_bytes: int):
# TODO: Remove global received_bytes
RNS.log("Handling data")
data = buffer.read(ready_bytes) data = buffer.read(ready_bytes)
received.append(data) received.append(data)
@ -509,10 +512,11 @@ class TestLink(unittest.TestCase):
if local_interface.bitrate < 1000: if local_interface.bitrate < 1000:
target_bytes = 3000 target_bytes = 3000
else: else:
target_bytes = 16000 target_bytes = BUFFER_TEST_TARGET
random.seed(154889)
message = os.urandom(target_bytes) message = random.randbytes(target_bytes)
buffer_read_target = len(message)
# the return message will have an appendage string " back at you" # the return message will have an appendage string " back at you"
# for every StreamDataMessage that arrives. To verify, we need # for every StreamDataMessage that arrives. To verify, we need
@ -527,35 +531,24 @@ class TestLink(unittest.TestCase):
# since the segments will be received at max length for a # since the segments will be received at max length for a
# StreamDataMessage, the appended text will end up in a # StreamDataMessage, the appended text will end up in a
# separate packet. # separate packet.
expected_chunk_count = ceil(len(message)/StreamDataMessage.MAX_DATA_LEN * 2)-1 print("Sending " + str(len(message)) + " bytes, receiving " + str(len(expected_rx_message)) + " bytes, ")
print("Sending " + str(len(message)) + " bytes, receiving " + str(len(expected_rx_message)) + " bytes, " +
"expecting " + str(expected_chunk_count) + " chunks of " + str(StreamDataMessage.MAX_DATA_LEN) + " bytes")
buffer.write(message) buffer.write(message)
buffer.flush() buffer.flush()
# delay a reasonable time for the send and receive timeout = time.time() + 4
# a chunk each way plus a little more for a proof each way while not time.time() > timeout:
# while time.time() < expected_ready_time and len(received) < expected_chunk_count: time.sleep(1)
# time.sleep(0.1) print(f"Received {len(received)} chunks so far")
# # sleep for at least one more chunk round trip in case there time.sleep(1)
# # are more chunks than expected
# if time.time() < expected_ready_time:
# time.sleep(max(c_rns.MTU * 2 / local_interface.bitrate * 8, 1))
timeout = time.time() + 10
while len(received) < expected_chunk_count and not time.time() > timeout:
time.sleep(2)
print(f"Received {len(received)} out of {expected_chunk_count} chunks so far")
time.sleep(2)
print(f"Received {len(received)} out of {expected_chunk_count} chunks")
data = bytearray() data = bytearray()
for rx in received: for rx in received:
data.extend(rx) data.extend(rx)
rx_message = data rx_message = data
print(f"Received {len(received)} chunks, totalling {len(rx_message)} bytes")
self.assertEqual(len(expected_rx_message), len(rx_message)) self.assertEqual(len(expected_rx_message), len(rx_message))
for i in range(0, len(expected_rx_message)): for i in range(0, len(expected_rx_message)):
self.assertEqual(expected_rx_message[i], rx_message[i]) self.assertEqual(expected_rx_message[i], rx_message[i])
@ -598,7 +591,7 @@ class TestLink(unittest.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
unittest.main(verbosity=1) unittest.main(verbosity=1)
buffer_read_len = 0
def targets(yp=False): def targets(yp=False):
if yp: if yp:
import yappi import yappi
@ -645,10 +638,26 @@ def targets(yp=False):
buffer = None buffer = None
response_data = []
def handle_buffer(ready_bytes: int): def handle_buffer(ready_bytes: int):
global buffer_read_len, BUFFER_TEST_TARGET
data = buffer.read(ready_bytes) data = buffer.read(ready_bytes)
buffer_read_len += len(data)
response_data.append(data)
if data == "Hi there".encode("utf-8"):
RNS.log("Sending response")
for data in response_data:
buffer.write(data + " back at you".encode("utf-8")) buffer.write(data + " back at you".encode("utf-8"))
buffer.flush() buffer.flush()
buffer_read_len = 0
if buffer_read_len == BUFFER_TEST_TARGET:
RNS.log("Sending response")
for data in response_data:
buffer.write(data + " back at you".encode("utf-8"))
buffer.flush()
buffer_read_len = 0
buffer = RNS.Buffer.create_bidirectional_buffer(0, 0, channel, handle_buffer) buffer = RNS.Buffer.create_bidirectional_buffer(0, 0, channel, handle_buffer)