# -*- coding: utf-8 -*-
"""
Implementation for the SPEAD2 network transport
This module contains the logic to take ICD Payloads and transmit them using
the SPEAD protocol.
"""
import asyncio
import logging
import math
import time
from contextlib import AbstractAsyncContextManager
import numpy as np
import spead2.send.asyncio
from overrides import overrides
from realtime.receive.core.icd import Items, Payload
# True when the installed spead2 is major version 3 or later. Those versions
# build send streams from a list of ``endpoints``; the single hostname/port
# (plus ``loop``) keyword form is only used for older spead2 releases. The
# name is kept for backward compatibility even though it now also covers
# later major versions.
IS_SPEAD3 = int(spead2.__version__.split(".")[0]) >= 3
logger = logging.getLogger(__name__)
async def create_stream(
    thread_pool, target_host, port, config, buffer_size, transport_proto
):
    """
    Create a spead2 asyncio send stream to a single endpoint.

    :param thread_pool: spead2 thread pool used by the stream
    :param target_host: destination host name or address
    :param port: destination port
    :param config: the ``spead2.send.StreamConfig`` to use for the stream
    :param buffer_size: buffer size passed through to the spead2 stream
    :param transport_proto: either ``"udp"`` or ``"tcp"``
    :return: a connected ``TcpStream`` for TCP, otherwise a ``UdpStream``
    :raises ValueError: if ``transport_proto`` is neither ``"udp"`` nor
        ``"tcp"``
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``
    if transport_proto not in ("udp", "tcp"):
        raise ValueError(
            f"transport_proto should be udp or tcp, got {transport_proto!r}"
        )
    kwargs = {
        "thread_pool": thread_pool,
        "config": config,
        "buffer_size": buffer_size,
    }
    if IS_SPEAD3:
        kwargs["endpoints"] = ((target_host, port),)
    else:
        # Older spead2 API: single hostname/port and an explicit event loop
        kwargs["hostname"] = target_host
        kwargs["port"] = port
        kwargs["loop"] = asyncio.get_running_loop()
    if transport_proto == "tcp":
        # TCP streams must establish a connection before they can be used
        return await spead2.send.asyncio.TcpStream.connect(**kwargs)
    return spead2.send.asyncio.UdpStream(**kwargs)
def parse_endpoints(endpoints_spec):
    """
    Parse endpoint specifications.

    Each endpoint is a colon-separated host and port pair, and multiple
    endpoints are separated by commas. A port can be a single number or a range
    specified as "start-end", both inclusive. Whitespace around endpoints,
    hosts and ports is ignored. The host/port split happens at the last colon,
    so hosts that themselves contain colons (e.g. ``::1:41000``) are accepted.

    :param endpoints_spec: the specification string, e.g.
        ``"10.0.0.1:41000-41003,10.0.0.2:41000"``
    :return: a list of ``(host, port)`` tuples in specification order
    :raises ValueError: if an endpoint has no port, a port is not an integer,
        or a port range's start is greater than its end
    """
    endpoints = []
    for endpoint in endpoints_spec.split(","):
        host, separator, port_spec = endpoint.strip().rpartition(":")
        if not separator:
            raise ValueError(
                f"invalid endpoint (expected host:port): {endpoint!r}"
            )
        host = host.strip()
        port_spec = port_spec.strip()
        if "-" in port_spec:
            start, end = map(int, port_spec.split("-"))
            if start > end:
                raise ValueError(f"invalid port range: {start} > {end}")
            endpoints.extend((host, port) for port in range(start, end + 1))
        else:
            endpoints.append((host, int(port_spec)))
    return endpoints
class Spead2SenderPayload(Payload):
    """
    SPEAD2 payload following the CSP-SDP interface document.

    Owns a spead2 send ``ItemGroup`` with one item per ICD ``Items`` entry;
    the correlator output data item is sized for the given number of channels
    and baselines.
    """

    def __init__(self, num_baselines=None, num_channels=None):
        super().__init__()
        self._item_group = spead2.send.ItemGroup(
            flavour=spead2.Flavour(4, 64, 48, 0)
        )
        self._add_items(num_baselines, num_channels)
        self.baseline_count = num_baselines
        self.channel_count = num_channels
        # Initialised to ones for every (channel, baseline) pair
        self.correlated_data_fraction = np.ones([num_channels, num_baselines])

    def _add_items(self, num_baselines, num_channels):
        """
        Register every ICD item with the underlying ItemGroup.

        :param num_baselines: number of baselines in the HEAP - used for sizing
        :param num_channels: number of channels in the HEAP - used for sizing
        """
        vis_shape = (num_channels, num_baselines)
        vis_item = Items.CORRELATOR_OUTPUT_DATA
        for item in Items:
            descriptor = item.value
            self._item_group.add_item(
                id=descriptor.id,
                name=descriptor.name,
                description="",
                shape=vis_shape if item == vis_item else (),
                format=None,
                dtype=descriptor.dtype,
            )
        # Pre-allocate a zeroed correlator output buffer; get_heap() later
        # fills its fields in place.
        self._item_group[vis_item.value.id].value = np.zeros(
            shape=vis_shape,
            dtype=vis_item.value.dtype,
        )

    def get_heap(self):
        """
        Build a data heap from the payload's current attribute values.

        The scalar ICD items are copied from the matching attributes
        (presumably defined/defaulted by the ``Payload`` base class - TODO
        confirm), and the non-empty array attributes are written into the
        TCI/FD/VIS fields of the correlator output data item.

        :return: a heap with no descriptors and all item data
        """
        group = self._item_group
        for item, attribute in (
            (Items.BASELINE_COUNT, "baseline_count"),
            (Items.CHANNEL_COUNT, "channel_count"),
            (Items.CHANNEL_ID, "channel_id"),
            (Items.HARDWARE_ID, "hardware_id"),
            (Items.PHASE_BIN_ID, "phase_bin_id"),
            (Items.PHASE_BIN_COUNT, "phase_bin_count"),
            (Items.POLARISATION_ID, "polarisation_id"),
            (Items.SCAN_ID, "scan_id"),
            (Items.TIMESTAMP_COUNT, "timestamp_count"),
            (Items.TIMESTAMP_FRACTION, "timestamp_fraction"),
        ):
            group[item.value.id].value = getattr(self, attribute)

        output = group[Items.CORRELATOR_OUTPUT_DATA.value.id].value
        for field, attribute in (
            ("TCI", "time_centroid_indices"),
            ("FD", "correlated_data_fraction"),
            ("VIS", "visibilities"),
        ):
            source = getattr(self, attribute)
            # Empty arrays leave the previously stored field contents intact
            if len(source):
                output[field] = source

        return group.get_heap(descriptors="none", data="all")

    def get_start_heap(self):
        """Return a start-of-stream heap carrying all item descriptors."""
        heap = self._item_group.get_start()
        self._item_group.add_to_heap(heap, descriptors="all", data="none")
        return heap

    def get_end_heap(self):
        """Return an end-of-stream heap for this payload's ItemGroup."""
        return self._item_group.get_end()
class transmitter(AbstractAsyncContextManager):
    """
    SPEAD2 transmitter

    This class uses the spead2 library to transmit visibilities over multiple
    spead2 streams. Each visibility set given to this class' `send` method is
    broken down by channel range (depending on the configuration parameters),
    and each channel range is sent through a different stream.

    `prepare` must be awaited before `send`. When used as an async context
    manager, `close` is awaited on exit.
    """

    def __init__(self, config):
        """
        :param config: configuration mapping supporting ``get``,
            ``getfloat`` and ``in`` (presumably a ``configparser`` section -
            TODO confirm)
        """
        self.config = config
        max_packet_size = int(config.get("max_packet_size", 1472))
        logger.info(
            "Creating StreamConfig with max_packet_size=%d", max_packet_size
        )
        self.stream_config = spead2.send.StreamConfig(
            max_packet_size=max_packet_size,
            rate=int(config.getfloat("rate", 1024 * 1024 * 1024)),
            burst_size=10,
            max_heaps=int(config.get("max_heaps", 1)),
        )
        self.channels_per_stream = int(config.get("channels_per_stream", 0))
        self.sender_threads = int(config.get("sender_threads", 1))
        self.num_streams = 0  # set by prepare()
        self.bytes_sent = 0
        self.heaps_sent = 0
        self.streams = []
        # One payload per stream, filled in by prepare(). Initialised here so
        # close() (e.g. via __aexit__) works even if prepare() never ran.
        self.payloads = []
        # Number of channels carried by each stream, filled in by prepare()
        self._stream_channel_counts = []
        self._start_heaps_sent = False
        self._delay_start_of_stream_heaps = int(
            config.get("delay_start_of_stream_heaps", 0)
        )

    async def prepare(self, num_baselines, num_channels):
        """
        Create the sending SPEAD streams

        :param num_baselines: number of baselines in each visibility set
        :param num_channels: total number of channels in each visibility set
        """
        start_time = time.time()
        if self.channels_per_stream == 0:
            self.num_streams = 1
            self.channels_per_stream = num_channels
        else:
            self.num_streams = math.ceil(
                num_channels / self.channels_per_stream
            )
        # When num_channels is not a multiple of channels_per_stream the last
        # stream only carries the remaining channels; its payload must be
        # sized accordingly or the visibility slice will not fit into it.
        self._stream_channel_counts = [
            min(
                self.channels_per_stream,
                num_channels - i * self.channels_per_stream,
            )
            for i in range(self.num_streams)
        ]
        # Each stream uses a separate ItemGroup because Heaps created out of
        # ItemGroups can point to memory held by the ItemGroup; and since we
        # want different heaps sent through each of the streams we then need
        # independent ItemGroups
        self.payloads = [
            Spead2SenderPayload(num_baselines, stream_channels)
            for stream_channels in self._stream_channel_counts
        ]

        # Create the streams; they still share a single I/O threadpool
        thread_pool = spead2.ThreadPool(threads=self.sender_threads)
        config = self.config
        if "endpoints" in config:

            def endpoints():
                endpoints = parse_endpoints(config["endpoints"])
                if len(endpoints) < self.num_streams:
                    raise ValueError("missing endpoints for number of streams")
                yield from endpoints[: self.num_streams]

        else:

            def endpoints():
                target_host = config.get("target_host", "127.0.0.1")
                target_port = int(config.get("target_port_start", 41000))
                for i in range(self.num_streams):
                    yield (target_host, target_port + i)

        buffer_size = int(
            config.get(
                "buffer_size",
                spead2.send.asyncio.UdpStream.DEFAULT_BUFFER_SIZE,
            )
        )
        transport_proto = config.get("transport_protocol", "udp").lower()
        if transport_proto not in ("udp", "tcp"):
            raise ValueError("transport_protocol should be udp or tcp")
        for i, endpoint in enumerate(endpoints()):
            host, port = endpoint
            logger.debug("Sending stream %d to %s:%d", i, host, port)
            stream = await create_stream(
                thread_pool,
                host,
                port,
                self.stream_config,
                buffer_size,
                transport_proto,
            )
            self.streams.append(stream)
        logger.info(
            "Created %d %s spead2 streams to send data for "
            "%d channels in %.3f [ms]",
            self.num_streams,
            transport_proto.upper(),
            num_channels,
            (time.time() - start_time) * 1000,
        )

    def _should_send_start_of_stream_heaps(self):
        """Whether the (one-off) start-of-stream heaps are due now."""
        if self._delay_start_of_stream_heaps < 0 or self._start_heaps_sent:
            return False
        if not self.streams:
            # prepare() has not created any streams yet: nothing to send to,
            # and don't mark the start heaps as sent prematurely
            return False
        if self._delay_start_of_stream_heaps == 0:
            return True
        data_heaps_sent_per_stream = self.heaps_sent // len(self.streams)
        return data_heaps_sent_per_stream >= self._delay_start_of_stream_heaps

    async def _maybe_send_start_of_stream_heap(self):
        """Send the start-of-stream heaps on every stream if they are due."""
        if self._should_send_start_of_stream_heaps():
            await self._send_heaps(
                [payload.get_start_heap() for payload in self.payloads]
            )
            self._start_heaps_sent = True

    async def _send_heaps(self, heaps):
        """Send one heap per stream concurrently and update the counters."""
        assert len(heaps) == len(self.streams)
        send_operations = [
            stream.async_send_heap(heap)
            for heap, stream in zip(heaps, self.streams)
        ]
        results = await asyncio.gather(*send_operations)
        self.bytes_sent += sum(results)
        self.heaps_sent += len(heaps)

    async def send(
        self, scan_id: int, ts: int, ts_fraction: int, vis: np.ndarray
    ):
        """
        Send a visibility set through all SPEAD2 streams

        :param scan_id: the scan id
        :param ts: the integer part of the visibilities' timestamp
        :param ts_fraction: the fractional part of the visibilities' timestamp
        :param vis: the visibilities; its first axis is split into the
            per-stream channel ranges
        """
        await self._maybe_send_start_of_stream_heap()
        logger.debug("Sending heaps to %d spead2 streams", len(self.streams))
        heaps = []
        assert len(self.payloads) == len(self.streams)
        for i, (payload, stream_channels) in enumerate(
            zip(self.payloads, self._stream_channel_counts)
        ):
            # When sending to multiple streams (e.g., 96) we could spend a long
            # time in this loop without yielding control back, so let's do that
            if i % 20 == 0:
                await asyncio.sleep(0)
            first_chan = self.channels_per_stream * i
            last_chan = first_chan + stream_channels
            payload.scan_id = int(scan_id)
            payload.timestamp_count = ts
            payload.timestamp_fraction = ts_fraction
            payload.visibilities = vis[first_chan:last_chan]
            payload.channel_id = first_chan
            payload.channel_count = stream_channels
            heaps.append(payload.get_heap())
        await self._send_heaps(heaps)

    async def close(self):
        """
        Sends the end-of-stream message

        Any start-of-stream heaps that are due are sent first.
        """
        await self._maybe_send_start_of_stream_heap()
        await self._send_heaps(
            [payload.get_end_heap() for payload in self.payloads]
        )

    @overrides
    async def __aenter__(self):
        return self

    @overrides
    async def __aexit__(self, ext_type, exc, tb):
        await self.close()