263 lines
8.9 KiB
Python
263 lines
8.9 KiB
Python
import argparse
|
|
import json
|
|
import os
|
|
import threading
|
|
import time
|
|
from typing import List
|
|
|
|
import redis
|
|
import requests
|
|
from flask import jsonify, request, Flask
|
|
from flask_caching import Cache
|
|
from urllib3.exceptions import InsecureRequestWarning
|
|
|
|
from checker.units import filesize
|
|
|
|
# Every OPNsense API call below passes verify=False, so silence urllib3's
# per-request InsecureRequestWarning.
requests.packages.urllib3.disable_warnings(category=InsecureRequestWarning)

# Maximum number of samples kept in each "<interface>:<address>" Redis list;
# the oldest samples are discarded first. The collector polls at most about
# once per second, so this holds at least roughly half an hour of history.
MAX_POINTS_PER_IP = 1900

# OPNsense API base URL and API key/secret pair; all three are required.
OPNSENSE_URL = os.environ.get('OPN_URL')
OPNSENSE_KEY = os.environ.get('OPN_KEY')
OPNSENSE_SECRET = os.environ.get('OPN_SECRET')

_missing_env = [
    name
    for name, value in (('OPN_URL', OPNSENSE_URL), ('OPN_KEY', OPNSENSE_KEY), ('OPN_SECRET', OPNSENSE_SECRET))
    if not value
]
if _missing_env:
    # RuntimeError is an Exception subclass, so existing `except Exception` handlers still work.
    raise RuntimeError(f'Missing environment variables: {", ".join(_missing_env)}')

# Strip slashes from both ends so f'{OPNSENSE_URL}/api/...' never contains '//'.
OPNSENSE_URL = OPNSENSE_URL.strip('/')
|
|
|
|
|
|
class TrafficEntry:
    """One traffic sample for a single address on a single interface.

    Entries are persisted as JSON via ``to_json``/``from_json``. The attribute
    names double as the JSON keys, so renaming an attribute changes the stored
    format.
    """

    def __init__(self, interface: str, address: str, rate_bits_in: int, rate_bits_out: int, rate_bits: int,
                 cumulative_bytes_in: int, cumulative_bytes_out: int, cumulative_bytes: int,
                 connections: dict, timestamp: float):
        self.interface = interface
        self.address = address
        # Rates as reported by the source (bits, per the field names).
        self.rate_bits_in = rate_bits_in
        self.rate_bits_out = rate_bits_out
        self.rate_bits = rate_bits
        # Byte counters as reported by the source.
        self.cumulative_bytes_in = cumulative_bytes_in
        self.cumulative_bytes_out = cumulative_bytes_out
        self.cumulative_bytes = cumulative_bytes
        # Per-connection details; annotated as dict, but only len() is relied on
        # in this file. NOTE(review): confirm the real shape of the API's "details".
        self.connections = connections
        # Epoch seconds at which the sample was taken.
        self.timestamp = timestamp

    def to_json(self) -> dict:
        """Return a JSON-serializable dict of this entry's fields.

        The dict is a shallow copy, so callers may modify it without mutating
        the entry itself.
        """
        return dict(vars(self))

    @classmethod
    def from_json(cls, json_str):
        """Rebuild an entry from a JSON string produced by ``json.dumps(entry.to_json())``.

        Raises ``json.JSONDecodeError`` on invalid JSON and ``TypeError`` when
        the keys do not match the constructor's parameters.
        """
        return cls(**json.loads(json_str))
|
|
|
|
|
|
class OpnsenseTraffic:
|
|
def __init__(self):
|
|
self.redis = redis.Redis(host='localhost', port=6379, db=1)
|
|
|
|
def flush(self):
|
|
self.redis.flushdb()
|
|
|
|
def add_entry(self, item: TrafficEntry):
|
|
# TODO: kick out the oldest item
|
|
|
|
key = f"{item.interface}:{item.address}"
|
|
if self.redis.llen(key) >= MAX_POINTS_PER_IP:
|
|
self.redis.lpop(key)
|
|
|
|
self.redis.rpush(key, json.dumps(item.to_json()))
|
|
|
|
def get_address(self, input_address: str) -> List[TrafficEntry]:
|
|
keys = self.redis.keys(f"*:{input_address}")
|
|
data = []
|
|
for key in keys:
|
|
entries = self.redis.lrange(key, 0, -1)
|
|
data.extend([TrafficEntry.from_json(entry.decode()) for entry in entries])
|
|
return data
|
|
|
|
def get_entries(self, input_address: str):
|
|
keys = self.redis.keys()
|
|
data = {}
|
|
for key in keys:
|
|
try:
|
|
interface, address = key.decode().split(":")
|
|
except ValueError:
|
|
# Can get things like "opt2:::"
|
|
continue
|
|
if address != input_address:
|
|
continue
|
|
entries = self.redis.lrange(key, 0, -1)
|
|
if interface not in data:
|
|
data[interface] = {}
|
|
data[interface][address] = [TrafficEntry.from_json(entry.decode()).to_json() for entry in entries]
|
|
return data
|
|
|
|
def get_traffic(self, address: str, minus_seconds: int = 0, human: bool = False):
|
|
max_rate_in = 0
|
|
max_rate_out = 0
|
|
bytes_in = 0
|
|
bytes_out = 0
|
|
connections = 0
|
|
|
|
if minus_seconds == 0:
|
|
minus_sec_diff = 0
|
|
else:
|
|
minus_sec_diff = int(time.time()) - minus_seconds
|
|
|
|
address_traffic = self.get_address(address)
|
|
for entry in address_traffic:
|
|
if entry.timestamp >= minus_sec_diff:
|
|
max_rate_in = max(max_rate_in, entry.rate_bits_in)
|
|
max_rate_out = max(max_rate_out, entry.rate_bits_out)
|
|
bytes_in += entry.cumulative_bytes_in
|
|
bytes_out += entry.cumulative_bytes_out
|
|
connections += len(entry.connections)
|
|
|
|
if human:
|
|
return filesize(max_rate_in), filesize(max_rate_out), filesize(bytes_in), filesize(bytes_out), connections
|
|
else:
|
|
return max_rate_in, max_rate_out, bytes_in, bytes_out, connections
|
|
|
|
|
|
def _decode_interfaces(raw):
    """Decode the JSON interface list stored in Redis; return [] when absent, malformed or not a list."""
    if raw is None:
        return []
    try:
        interfaces = json.loads(raw)
    except ValueError:  # includes JSONDecodeError and UnicodeDecodeError
        return []
    return interfaces if isinstance(interfaces, list) else []


def get_interfaces():
    """Return the interface names cached under the ``interfaces`` key in Redis db 2.

    Best-effort: returns [] when Redis is unreachable, the key is not set yet,
    or its value is not a JSON list. Callers poll this until it is non-empty.
    """
    r = redis.Redis(host='localhost', port=6379, db=2)
    try:
        raw = r.get('interfaces')
    except redis.RedisError:
        return []
    return _decode_interfaces(raw)
|
|
|
|
|
|
def get_interface_names():
    """Refresh the cached interface list in Redis db 2 every 60 seconds, forever.

    Each cycle does three things:
      - GET ``/api/diagnostics/traffic/interface`` from OPNsense;
      - take the keys of the returned ``interfaces`` object, dropping a
        literal ``'interface'`` key;
      - store them as a JSON list under the Redis key ``interfaces``.

    Intended as a thread target. Network, HTTP, JSON and Redis errors are
    reported and retried on the next cycle. Without this, one failure would
    end the thread and leave the cached list stale.
    """
    r = redis.Redis(host='localhost', port=6379, db=2)
    while True:
        try:
            response = requests.get(
                f'{OPNSENSE_URL}/api/diagnostics/traffic/interface',
                headers={'Accept': 'application/json'},
                auth=(OPNSENSE_KEY, OPNSENSE_SECRET),
                verify=False,  # TLS verification is disabled for the OPNsense API throughout this module
                timeout=30,    # never hang the refresher on a stalled connection
            )
            response.raise_for_status()
            interfaces = [name for name in response.json()['interfaces'] if name != 'interface']
            r.set('interfaces', json.dumps(interfaces))
        except (requests.RequestException, ValueError, KeyError, TypeError, redis.RedisError) as exc:
            print(f'Failed to refresh interface list: {exc}')
        time.sleep(60)
|
|
|
|
|
|
def _poll_traffic(traffic_data, interfaces):
    """Fetch one traffic snapshot for *interfaces* and store every record in *traffic_data*.

    All records from one response share a single ``time.time()`` timestamp
    taken after the response arrives.
    """
    interface_req = ','.join(interfaces)
    response = requests.get(
        f'{OPNSENSE_URL}/api/diagnostics/traffic/top/{interface_req}',
        headers={'Accept': 'application/json'},
        auth=(OPNSENSE_KEY, OPNSENSE_SECRET),
        verify=False,
        timeout=30,
    )
    response.raise_for_status()
    timestamp = time.time()

    for interface, data in response.json().items():
        # A missing or null "records" list means no traffic for that interface.
        for item in data.get('records') or []:
            traffic_data.add_entry(
                TrafficEntry(
                    address=item['address'],
                    interface=interface,
                    rate_bits=item['rate_bits'],
                    rate_bits_in=item['rate_bits_in'],
                    rate_bits_out=item['rate_bits_out'],
                    cumulative_bytes=item['cumulative_bytes'],
                    cumulative_bytes_in=item['cumulative_bytes_in'],
                    cumulative_bytes_out=item['cumulative_bytes_out'],
                    connections=item['details'],
                    timestamp=timestamp,
                )
            )


def background_thread():
    """Collect OPNsense traffic samples at most about once per second, forever.

    Flushes the traffic store (Redis db 1 by default) once at start-up. After
    that, each cycle polls the traffic endpoint for the currently cached
    interfaces. A failed cycle (network/HTTP error, bad JSON, missing field,
    Redis error) is reported and skipped instead of killing the thread.
    """
    traffic_data = OpnsenseTraffic()
    traffic_data.flush()

    while True:
        started = time.monotonic()  # monotonic: immune to wall-clock jumps when measuring elapsed time
        interfaces = get_interfaces()
        if interfaces:
            try:
                _poll_traffic(traffic_data, interfaces)
            except (requests.RequestException, ValueError, KeyError, redis.RedisError) as exc:
                print(f'Failed to fetch traffic data: {exc}')
        # Aim for one poll per second; never sleep a negative duration.
        elapsed = time.monotonic() - started
        time.sleep(max(1 - elapsed, 0))
|
|
|
|
|
|
# Store read by the Flask routes below (OpnsenseTraffic's default Redis db).
flask_traffic_data = OpnsenseTraffic()

app = Flask(__name__)

# Response cache for the routes below, backed by Redis at 127.0.0.1:6379.
cache = Cache(app, config={
    "CACHE_TYPE": "RedisCache",
    "CACHE_REDIS_HOST": "127.0.0.1",
    # Flask-Caching reads the port from CACHE_REDIS_PORT; a bare "port" key is ignored.
    "CACHE_REDIS_PORT": 6379,
})
|
|
|
|
|
|
def _parse_bool_arg(value):
    """Interpret a query-string flag: '1', 'true', 'yes' or 'on' (any case) mean True.

    ``request.args.get(..., type=bool)`` is unsuitable because
    ``bool('false')`` and ``bool('0')`` are both True.
    """
    if value is None:
        return False
    return value.strip().lower() in ('1', 'true', 'yes', 'on')


def _summarize_records(records, minus_seconds, human):
    """Aggregate serialized TrafficEntry dicts with the same rules as OpnsenseTraffic.get_traffic.

    Returns ``(max_rate_in, max_rate_out, bytes_in, bytes_out, connections)``.
    Keep this in sync with OpnsenseTraffic.get_traffic.
    """
    cutoff = 0 if minus_seconds == 0 else int(time.time()) - minus_seconds
    max_rate_in = max_rate_out = bytes_in = bytes_out = connections = 0
    for record in records:
        if record['timestamp'] < cutoff:
            continue
        max_rate_in = max(max_rate_in, record['rate_bits_in'])
        max_rate_out = max(max_rate_out, record['rate_bits_out'])
        bytes_in += record['cumulative_bytes_in']
        bytes_out += record['cumulative_bytes_out']
        connections += len(record['connections'])
    if human:
        return filesize(max_rate_in), filesize(max_rate_out), filesize(bytes_in), filesize(bytes_out), connections
    return max_rate_in, max_rate_out, bytes_in, bytes_out, connections


@app.route('/traffic/<address>', methods=['GET'])
@app.route('/traffic/<address>/<interface>', methods=['GET'])
@cache.cached(timeout=10, query_string=True)
def get_traffic(address, interface=None):
    """Return aggregated traffic for *address*, optionally limited to one *interface*.

    Query args:
      - ``seconds``: only count samples from the last N seconds (0 = all).
      - ``human``: format rates/bytes with ``filesize``.

    ``entries`` in the response is the number of stored samples, which is
    not filtered by ``seconds``. Responds 404 when *interface* has no
    samples for *address*.
    """
    minus_seconds = request.args.get('seconds', default=0, type=int)
    human = _parse_bool_arg(request.args.get('human'))
    entries = flask_traffic_data.get_entries(address)
    stat_names = ('max_rate_in', 'max_rate_out', 'bytes_in', 'bytes_out', 'connections')

    if interface:
        if interface not in entries:
            return 'Interface not found for address', 404
        # Aggregate only this interface's samples, and count samples, not addresses.
        records = [record for samples in entries[interface].values() for record in samples]
        stats = _summarize_records(records, minus_seconds, human)
        return jsonify({
            interface: dict(zip(stat_names, stats)),
            'entries': len(records)
        })

    stats = flask_traffic_data.get_traffic(address, minus_seconds, human)
    num_entries = sum(len(samples) for by_address in entries.values() for samples in by_address.values())
    return jsonify({
        'data': dict(zip(stat_names, stats)),
        'entries': num_entries
    })
|
|
|
|
|
|
@app.route('/data/<ip>', methods=['GET'])
@cache.cached(timeout=10)
def get_entries(ip):
    """Return every stored sample for *ip*, grouped by interface only.

    The store nests samples as ``{interface: {address: [sample, ...]}}``.
    This view drops the address level and concatenates each interface's
    per-address lists in their original order.
    """
    nested = flask_traffic_data.get_entries(ip)
    flattened = {
        iface: [sample for samples in by_address.values() for sample in samples]
        for iface, by_address in nested.items()
    }
    return jsonify(flattened)
|
|
|
|
|
|
@app.route('/interfaces', methods=['GET'])
@cache.cached(timeout=10)
def flask_get_interfaces():
    """Serve the cached interface list as a JSON array (response cached for 10 s)."""
    interface_list = get_interfaces()
    return jsonify(interface_list)
|
|
|
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--daemon', action='store_true', help='Run the background daemon.')
    parser.add_argument('--debug', action='store_true',
                        help='Run the Flask development server in debug mode. '
                             'Never combine with a publicly reachable host.')
    args = parser.parse_args()

    if args.daemon:
        t1 = threading.Thread(target=get_interface_names, daemon=True)
        t1.start()

        # Flush so the partial line is visible while we block below.
        print('Fetching interface list... ', end='', flush=True)
        while not get_interfaces():
            time.sleep(2)
        print('Done!')

        t2 = threading.Thread(target=background_thread, daemon=True)
        t2.start()

        print('Threads started!')

        # Both workers are daemon threads; keep the main thread alive for them.
        while True:
            time.sleep(10000)

    else:
        # Debug mode enables the interactive Werkzeug debugger. Together with
        # host='0.0.0.0' that exposes it on every interface, so it is opt-in.
        app.run(host='0.0.0.0', debug=args.debug)
|