diff --git a/adsbus/incoming.c b/adsbus/incoming.c index 845d656..b6ef250 100644 --- a/adsbus/incoming.c +++ b/adsbus/incoming.c @@ -8,6 +8,7 @@ #include #include +#include "buf.h" #include "list.h" #include "peer.h" #include "resolve.h" @@ -26,6 +27,7 @@ struct incoming { const char *error; uint32_t attempt; incoming_connection_handler handler; + incoming_get_hello hello; void *passthrough; uint32_t *count; struct list_head incoming_list; @@ -67,6 +69,20 @@ static void incoming_handler(struct peer *peer) { socket_connected_init(fd); + { + struct buf buf = BUF_INIT, *buf_ptr = &buf; + if (incoming->hello) { + incoming->hello(&buf_ptr, incoming->passthrough); + } + if (buf_ptr->length) { + if (write(fd, buf_at(buf_ptr, 0), buf_ptr->length) != (ssize_t) buf_ptr->length) { + fprintf(stderr, "I %s: Error writing greeting\n", incoming->id); + assert(!close(fd)); + return; + } + } + } + incoming->handler(fd, incoming->passthrough, NULL); } @@ -146,7 +162,7 @@ void incoming_cleanup() { } } -void incoming_new(char *node, char *service, incoming_connection_handler handler, void *passthrough, uint32_t *count) { +void incoming_new(char *node, char *service, incoming_connection_handler handler, incoming_get_hello hello, void *passthrough, uint32_t *count) { (*count)++; struct incoming *incoming = malloc(sizeof(*incoming)); @@ -156,6 +172,7 @@ void incoming_new(char *node, char *service, incoming_connection_handler handler incoming->service = strdup(service); incoming->attempt = 0; incoming->handler = handler; + incoming->hello = hello; incoming->passthrough = passthrough; incoming->count = count; diff --git a/adsbus/incoming.h b/adsbus/incoming.h index 72f8e14..a6978ca 100644 --- a/adsbus/incoming.h +++ b/adsbus/incoming.h @@ -2,8 +2,10 @@ #include +struct buf; struct peer; void incoming_cleanup(void); typedef void (*incoming_connection_handler)(int fd, void *, struct peer *); -void incoming_new(char *, char *, incoming_connection_handler, void *, uint32_t *); +typedef void (*incoming_get_hello)(struct buf **, void *); +void incoming_new(char *, char *, incoming_connection_handler, incoming_get_hello, void *, uint32_t *); diff --git a/adsbus/opts.c b/adsbus/opts.c index 8e2c183..06f9df3 100644 --- a/adsbus/opts.c +++ b/adsbus/opts.c @@ -20,13 +20,13 @@ static char *opts_split(char **arg, char delim) { return ret; } -static void opts_add_listen(char *host_port, incoming_connection_handler handler, void *passthrough, uint32_t *count) { +static void opts_add_listen(char *host_port, incoming_connection_handler handler, incoming_get_hello hello, void *passthrough, uint32_t *count) { char *host = opts_split(&host_port, '/'); if (host) { - incoming_new(host, host_port, handler, passthrough, count); + incoming_new(host, host_port, handler, hello, passthrough, count); free(host); } else { - incoming_new(NULL, host_port, handler, passthrough, count); + incoming_new(NULL, host_port, handler, hello, passthrough, count); } } @@ -51,7 +51,7 @@ bool opts_add_connect_receive(char *arg) { return false; } - outgoing_new(host, arg, receive_new, NULL, &peer_count_in); + outgoing_new(host, arg, receive_new, NULL, NULL, &peer_count_in); free(host); return true; } @@ -67,13 +67,13 @@ bool opts_add_connect_send(char *arg) { return false; } - outgoing_new(host, arg, send_new_wrapper, serializer, &peer_count_out); + outgoing_new(host, arg, send_new_wrapper, send_hello, serializer, &peer_count_out); free(host); return true; } bool opts_add_listen_receive(char *arg) { - opts_add_listen(arg, receive_new, NULL, &peer_count_in); + opts_add_listen(arg, receive_new, NULL, NULL, &peer_count_in); return true; } @@ -83,7 +83,7 @@ bool opts_add_listen_send(char *arg) { return false; } - opts_add_listen(arg, send_new_wrapper, serializer, &peer_count_out); + opts_add_listen(arg, send_new_wrapper, send_hello, serializer, &peer_count_out); return true; } diff --git a/adsbus/outgoing.c b/adsbus/outgoing.c index 9bd4e06..e5736ec 100644 --- a/adsbus/outgoing.c +++ b/adsbus/outgoing.c @@ -8,6 +8,7 @@ #include #include +#include "buf.h" #include "list.h" #include "peer.h" #include "resolve.h" @@ -27,6 +28,7 @@ struct outgoing { const char *error; uint32_t attempt; outgoing_connection_handler handler; + outgoing_get_hello hello; void *passthrough; uint32_t *count; struct list_head outgoing_list; @@ -60,9 +62,14 @@ static void outgoing_connect_next(struct outgoing *outgoing) { outgoing->peer.fd = socket(outgoing->addr->ai_family, outgoing->addr->ai_socktype | SOCK_NONBLOCK | SOCK_CLOEXEC, outgoing->addr->ai_protocol); assert(outgoing->peer.fd >= 0); - char buf[1]; - int result = (int) sendto(outgoing->peer.fd, buf, 0, MSG_FASTOPEN, outgoing->addr->ai_addr, outgoing->addr->ai_addrlen); - outgoing_connect_result(outgoing, result == 0 ? result : errno); + { + struct buf buf = BUF_INIT, *buf_ptr = &buf; + if (outgoing->hello) { + outgoing->hello(&buf_ptr, outgoing->passthrough); + } + int result = (int) sendto(outgoing->peer.fd, buf_at(buf_ptr, 0), buf_ptr->length, MSG_FASTOPEN, outgoing->addr->ai_addr, outgoing->addr->ai_addrlen); + outgoing_connect_result(outgoing, result == 0 ? result : errno); + } } static void outgoing_connect_handler(struct peer *peer) { @@ -155,7 +162,7 @@ void outgoing_cleanup() { } } -void outgoing_new(char *node, char *service, outgoing_connection_handler handler, void *passthrough, uint32_t *count) { +void outgoing_new(char *node, char *service, outgoing_connection_handler handler, outgoing_get_hello hello, void *passthrough, uint32_t *count) { (*count)++; struct outgoing *outgoing = malloc(sizeof(*outgoing)); @@ -164,6 +171,7 @@ void outgoing_new(char *node, char *service, outgoing_connection_handler handler outgoing->service = strdup(service); outgoing->attempt = 0; outgoing->handler = handler; + outgoing->hello = hello; outgoing->passthrough = passthrough; outgoing->count = count; diff --git a/adsbus/outgoing.h b/adsbus/outgoing.h index 190dfb3..a669b94 100644 --- a/adsbus/outgoing.h +++ b/adsbus/outgoing.h @@ -1,7 +1,9 @@ #pragma once +struct buf; struct peer; void outgoing_cleanup(void); typedef void (*outgoing_connection_handler)(int fd, void *, struct peer *); -void outgoing_new(char *, char *, outgoing_connection_handler, void *, uint32_t *); +typedef void (*outgoing_get_hello)(struct buf **, void *); +void outgoing_new(char *, char *, outgoing_connection_handler, outgoing_get_hello, void *, uint32_t *); diff --git a/adsbus/send.c b/adsbus/send.c index de05ca2..3f87f87 100644 --- a/adsbus/send.c +++ b/adsbus/send.c @@ -78,18 +78,6 @@ static void send_del_wrapper(struct peer *peer) { send_del((struct send *) peer); } -static bool send_hello(int fd, struct serializer *serializer) { - struct buf buf = BUF_INIT; - serializer->serialize(NULL, &buf); - if (buf.length == 0) { - return true; - } - if (write(fd, buf_at(&buf, 0), buf.length) != (ssize_t) buf.length) { - return false; - } - return true; -} - void send_init() { assert(signal(SIGPIPE, SIG_IGN) != SIG_ERR); for (size_t i = 0; i < NUM_SERIALIZERS; i++) { @@ -134,18 +122,18 @@ void send_new(int fd, struct serializer *serializer, struct peer *on_close) { peer_epoll_add((struct peer *) send, 0); fprintf(stderr, "S %s (%s): New send connection\n", send->id, serializer->name); - - if (!send_hello(fd, serializer)) { - fprintf(stderr, "S %s: Failed to write hello\n", send->id); - send_del(send); - return; - } } void send_new_wrapper(int fd, void *passthrough, struct peer *on_close) { send_new(fd, (struct serializer *) passthrough, on_close); } +void send_hello(struct buf **buf_pp, void *passthrough) { + struct serializer *serializer = (struct serializer *) passthrough; + // TODO: change API to avoid special-case NULL packet*, and to allow static greetings. + serializer->serialize(NULL, *buf_pp); +} + void send_write(struct packet *packet) { packet_sanity_check(packet); for (size_t i = 0; i < NUM_SERIALIZERS; i++) { diff --git a/adsbus/send.h b/adsbus/send.h index cf49ae3..fdc3d67 100644 --- a/adsbus/send.h +++ b/adsbus/send.h @@ -1,5 +1,6 @@ #pragma once +struct buf; struct packet; struct peer; @@ -8,5 +9,6 @@ void send_cleanup(void); struct serializer *send_get_serializer(char *); void send_new(int, struct serializer *, struct peer *); void send_new_wrapper(int, void *, struct peer *); +void send_hello(struct buf **, void *); void send_write(struct packet *); void send_print_usage(void);