Source code for cbf_sdp.transmitters.spead2_transmitters

# -*- coding: utf-8 -*-
"""
Implementation for the SPEAD2 network transport

This module contains the logic to take ICD Payloads and transmit them using
the SPEAD protocol.
"""
import asyncio
import logging
import math
import time
from contextlib import AbstractAsyncContextManager

import numpy as np
import spead2.send.asyncio
from overrides import overrides
from realtime.receive.core.icd import Items, Payload

# True when the installed spead2 uses the ``endpoints``-based stream API
# (spead2 3.x and later major versions such as 4.x). spead2 4.x removed the
# older hostname/port/loop form, so this must not be an exact "== 3" check.
# The name is kept unchanged for backwards compatibility.
IS_SPEAD3 = int(spead2.__version__.split(".")[0]) >= 3

logger = logging.getLogger(__name__)


async def create_stream(
    thread_pool, target_host, port, config, buffer_size, transport_proto
):
    """
    Create a spead2 asyncio send stream towards a single endpoint.

    When ``IS_SPEAD3`` is set the destination is given to spead2 through the
    ``endpoints`` argument; otherwise the older ``hostname``/``port``/``loop``
    arguments are used.

    :param thread_pool: the spead2 thread pool performing the I/O
    :param target_host: destination host name or address
    :param port: destination port
    :param config: the ``spead2.send.StreamConfig`` for the stream
    :param buffer_size: buffer size passed through to spead2
    :param transport_proto: ``"udp"`` or ``"tcp"`` (case-insensitive)
    :return: a connected ``TcpStream`` for tcp, otherwise a ``UdpStream``
    :raises ValueError: if ``transport_proto`` is neither udp nor tcp
    """
    # Explicit check instead of ``assert``, which is stripped under ``-O``
    transport_proto = transport_proto.lower()
    if transport_proto not in ("udp", "tcp"):
        raise ValueError(
            f"transport_proto should be udp or tcp, not {transport_proto!r}"
        )
    kwargs = {
        "thread_pool": thread_pool,
        "config": config,
        "buffer_size": buffer_size,
    }
    if IS_SPEAD3:
        kwargs["endpoints"] = ((target_host, port),)
    else:
        kwargs["hostname"] = target_host
        kwargs["port"] = port
        kwargs["loop"] = asyncio.get_running_loop()
    if transport_proto == "tcp":
        # TCP streams must complete the connection before they can be used
        return await spead2.send.asyncio.TcpStream.connect(**kwargs)
    return spead2.send.asyncio.UdpStream(**kwargs)


def parse_endpoints(endpoints_spec):
    """
    Parse endpoint specifications.

    Each endpoint is a colon-separated host and port pair, and multiple
    endpoints are separated by commas. A port can be a single number or a
    range specified as "start-end", both inclusive. Whitespace around each
    endpoint and around its host and port is ignored. The port is taken from
    after the *last* colon, so hosts that themselves contain colons are
    accepted.

    Example: ``"10.0.0.1:41000-41001,10.0.0.2:41000"`` yields
    ``[("10.0.0.1", 41000), ("10.0.0.1", 41001), ("10.0.0.2", 41000)]``.

    :param endpoints_spec: the endpoint specification string
    :return: a list of ``(host, port)`` tuples, in specification order
    :raises ValueError: if an endpoint has no ``:`` separator, a port is not
        an integer, or a port range has its start after its end
    """
    endpoints = []
    for endpoint in endpoints_spec.split(","):
        host, sep, port_spec = endpoint.strip().rpartition(":")
        if not sep:
            raise ValueError(
                f"invalid endpoint {endpoint!r}: expected host:port"
            )
        host = host.strip()
        port_spec = port_spec.strip()
        if "-" in port_spec:
            start, end = map(int, port_spec.split("-"))
            if start > end:
                raise ValueError(f"invalid port range: {start} > {end}")
            endpoints.extend((host, port) for port in range(start, end + 1))
        else:
            endpoints.append((host, int(port_spec)))
    return endpoints
class Spead2SenderPayload(Payload):
    """
    SPEAD2 payload following the CSP-SDP interface document.

    Every instance owns its own spead2 ``ItemGroup`` with one item per entry
    of ``Items``. The correlator output item is a 2-D array of shape
    ``(num_channels, num_baselines)`` whose ``TCI``, ``FD`` and ``VIS``
    fields are refreshed from this payload's attributes by `get_heap`.
    """

    def __init__(self, num_baselines=None, num_channels=None):
        super().__init__()
        flavour = spead2.Flavour(4, 64, 48, 0)
        self._item_group = spead2.send.ItemGroup(flavour=flavour)
        self._add_items(num_baselines, num_channels)
        self.baseline_count = num_baselines
        self.channel_count = num_channels
        self.correlated_data_fraction = np.ones([num_channels, num_baselines])

    def _add_items(self, num_baselines, num_channels):
        """
        Register every ICD item in the item group.

        :param num_baselines: number of baselines in the HEAP - used for sizing
        :param num_channels: number of channels in the HEAP - used for sizing
        """
        data_shape = (num_channels, num_baselines)
        for item in Items:
            desc = item.value
            # Only the correlator output carries an array; everything else
            # is a scalar item
            is_output = item == Items.CORRELATOR_OUTPUT_DATA
            self._item_group.add_item(
                id=desc.id,
                name=desc.name,
                description="",
                shape=data_shape if is_output else (),
                format=None,
                dtype=desc.dtype,
            )
        output_desc = Items.CORRELATOR_OUTPUT_DATA.value
        self._item_group[output_desc.id].value = np.zeros(
            shape=data_shape,
            dtype=output_desc.dtype,
        )

    def get_heap(self):
        """
        Copy the payload's attributes into the item group and build a data
        heap from it (values only, no descriptors).
        """
        scalar_items = (
            (Items.BASELINE_COUNT, "baseline_count"),
            (Items.CHANNEL_COUNT, "channel_count"),
            (Items.CHANNEL_ID, "channel_id"),
            (Items.HARDWARE_ID, "hardware_id"),
            (Items.PHASE_BIN_ID, "phase_bin_id"),
            (Items.PHASE_BIN_COUNT, "phase_bin_count"),
            (Items.POLARISATION_ID, "polarisation_id"),
            (Items.SCAN_ID, "scan_id"),
            (Items.TIMESTAMP_COUNT, "timestamp_count"),
            (Items.TIMESTAMP_FRACTION, "timestamp_fraction"),
        )
        for item, attr_name in scalar_items:
            self._item_group[item.value.id].value = getattr(self, attr_name)

        output = self._item_group[Items.CORRELATOR_OUTPUT_DATA.value.id].value
        array_fields = (
            ("TCI", "time_centroid_indices"),
            ("FD", "correlated_data_fraction"),
            ("VIS", "visibilities"),
        )
        for field_name, attr_name in array_fields:
            data = getattr(self, attr_name)
            # len() rather than truthiness: multi-element numpy arrays have
            # no unambiguous truth value
            if len(data):
                output[field_name] = data
        return self._item_group.get_heap(descriptors="none", data="all")

    def get_start_heap(self):
        """Build a start-of-stream heap carrying all item descriptors."""
        heap = self._item_group.get_start()
        self._item_group.add_to_heap(heap, descriptors="all", data="none")
        return heap

    def get_end_heap(self):
        """Build an end-of-stream heap."""
        return self._item_group.get_end()
class transmitter(AbstractAsyncContextManager):
    """
    SPEAD2 transmitter

    This class uses the spead2 library to transmit visibilities over multiple
    spead2 streams. Each visibility set given to this class' `send` method is
    broken down by channel range (depending on the configuration parameters),
    and each channel range is sent through a different stream. When the total
    number of channels is not a multiple of ``channels_per_stream`` the last
    stream carries the remaining, smaller, channel range.

    `prepare` must be called before `send`. Leaving an ``async with`` block
    calls `close`, which sends the end-of-stream heaps.
    """

    def __init__(self, config):
        self.config = config
        max_packet_size = int(config.get("max_packet_size", 1472))
        logger.info(
            "Creating StreamConfig with max_packet_size=%d", max_packet_size
        )
        self.stream_config = spead2.send.StreamConfig(
            max_packet_size=max_packet_size,
            rate=int(config.getfloat("rate", 1024 * 1024 * 1024)),
            burst_size=10,
            max_heaps=int(config.get("max_heaps", 1)),
        )
        self.channels_per_stream = int(config.get("channels_per_stream", 0))
        self.sender_threads = int(config.get("sender_threads", 1))
        self.num_streams = 0  # set by prepare()
        self.bytes_sent = 0
        self.heaps_sent = 0
        self.streams = []
        # One payload per stream, and the number of channels each stream
        # carries; both are populated by prepare()
        self.payloads = []
        self._stream_channel_counts = []
        # Becomes True once prepare() has created every stream
        self._prepared = False
        self._start_heaps_sent = False
        self._delay_start_of_stream_heaps = int(
            config.get("delay_start_of_stream_heaps", 0)
        )

    async def prepare(self, num_baselines, num_channels):
        """
        Create the sending SPEAD streams

        :param num_baselines: number of baselines in each visibility set
        :param num_channels: total number of channels in each visibility set
        :raises ValueError: if ``transport_protocol`` is not udp or tcp, or if
            fewer ``endpoints`` are configured than streams are needed
        """
        start_time = time.time()
        if self.channels_per_stream == 0:
            self.num_streams = 1
            self.channels_per_stream = num_channels
        else:
            self.num_streams = math.ceil(
                num_channels / self.channels_per_stream
            )

        # Every stream carries channels_per_stream channels except possibly
        # the last one, which carries whatever is left over
        self._stream_channel_counts = [
            min(
                self.channels_per_stream,
                num_channels - i * self.channels_per_stream,
            )
            for i in range(self.num_streams)
        ]

        # Each stream uses a separate ItemGroup because Heaps created out of
        # ItemGroups can point to memory held by the ItemGroup; and since we
        # want different heaps sent through each of the streams we then need
        # independent ItemGroups. Each one is sized for its own channel count.
        self.payloads = [
            Spead2SenderPayload(num_baselines, channel_count)
            for channel_count in self._stream_channel_counts
        ]

        # Create the streams; they still share a single I/O threadpool
        thread_pool = spead2.ThreadPool(threads=self.sender_threads)
        config = self.config
        if "endpoints" in config:

            def endpoints():
                endpoints = parse_endpoints(config["endpoints"])
                if len(endpoints) < self.num_streams:
                    raise ValueError("missing endpoints for number of streams")
                yield from endpoints[: self.num_streams]

        else:

            def endpoints():
                target_host = config.get("target_host", "127.0.0.1")
                target_port = int(config.get("target_port_start", 41000))
                for i in range(self.num_streams):
                    yield (target_host, target_port + i)

        buffer_size = int(
            config.get(
                "buffer_size",
                spead2.send.asyncio.UdpStream.DEFAULT_BUFFER_SIZE,
            )
        )
        transport_proto = config.get("transport_protocol", "udp").lower()
        if transport_proto not in ("udp", "tcp"):
            raise ValueError("transport_protocol should be udp or tcp")
        for i, (host, port) in enumerate(endpoints()):
            logger.debug("Sending stream %d to %s:%d", i, host, port)
            stream = await create_stream(
                thread_pool,
                host,
                port,
                self.stream_config,
                buffer_size,
                transport_proto,
            )
            self.streams.append(stream)
        self._prepared = True
        logger.info(
            "Created %d %s spead2 streams to send data for "
            "%d channels in %.3f [ms]",
            self.num_streams,
            transport_proto.upper(),
            num_channels,
            (time.time() - start_time) * 1000,
        )

    def _should_send_start_of_stream_heaps(self):
        """
        Whether the start-of-stream heaps should be sent now.

        Never when the configured delay is negative or they were already
        sent; immediately when the delay is zero; otherwise once at least
        ``delay_start_of_stream_heaps`` heaps per stream have been sent.
        """
        if self._delay_start_of_stream_heaps < 0 or self._start_heaps_sent:
            return False
        elif self._delay_start_of_stream_heaps == 0:
            return True
        data_heaps_sent_per_stream = self.heaps_sent // len(self.streams)
        return data_heaps_sent_per_stream >= self._delay_start_of_stream_heaps

    async def _maybe_send_start_of_stream_heap(self):
        """Send one start-of-stream heap per stream if it is due."""
        if self._should_send_start_of_stream_heaps():
            await self._send_heaps(
                [payload.get_start_heap() for payload in self.payloads]
            )
            self._start_heaps_sent = True

    async def _send_heaps(self, heaps):
        """
        Send ``heaps[i]`` through ``self.streams[i]`` concurrently and update
        the byte and heap counters.
        """
        assert len(heaps) == len(self.streams)
        send_operations = [
            stream.async_send_heap(heap)
            for heap, stream in zip(heaps, self.streams)
        ]
        results = await asyncio.gather(*send_operations)
        self.bytes_sent += sum(results)
        self.heaps_sent += len(heaps)

    async def send(
        self, scan_id: int, ts: int, ts_fraction: int, vis: np.ndarray
    ):
        """
        Send a visibility set through all SPEAD2 streams

        :param scan_id: the scan id
        :param ts: the integer part of the visibilities' timestamp
        :param ts_fraction: the fractional part of the visibilities' timestamp
        :param vis: the visibilities; sliced along its first axis into one
            channel range per stream
        :raises RuntimeError: if `prepare` has not completed
        """
        if not self._prepared:
            raise RuntimeError("prepare() must be called before send()")
        await self._maybe_send_start_of_stream_heap()
        logger.debug("Sending heaps to %d spead2 streams", len(self.streams))
        heaps = []
        assert len(self.payloads) == len(self.streams)
        for i, (payload, channel_count) in enumerate(
            zip(self.payloads, self._stream_channel_counts)
        ):
            # When sending to multiple streams (e.g., 96) we could spend a long
            # time in this loop without yielding control back, so let's do that
            if i % 20 == 0:
                await asyncio.sleep(0)
            first_chan = self.channels_per_stream * i
            payload.scan_id = int(scan_id)
            payload.timestamp_count = ts
            payload.timestamp_fraction = ts_fraction
            payload.visibilities = vis[first_chan : first_chan + channel_count]
            payload.channel_id = first_chan
            payload.channel_count = channel_count
            heaps.append(payload.get_heap())
        await self._send_heaps(heaps)

    async def close(self):
        """
        Sends the end-of-stream message

        Start-of-stream heaps that are still due are sent first. Does nothing
        if `prepare` did not complete, so that leaving an ``async with`` block
        after a missing or failed `prepare` does not mask the original error.
        """
        if not self._prepared:
            logger.debug(
                "Transmitter not prepared, not sending end-of-stream heaps"
            )
            return
        await self._maybe_send_start_of_stream_heap()
        await self._send_heaps(
            [payload.get_end_heap() for payload in self.payloads]
        )

    @overrides
    async def __aenter__(self):
        return self

    @overrides
    async def __aexit__(self, ext_type, exc, tb):
        await self.close()