]> git.proxmox.com Git - mirror_frr.git/blobdiff - lib/mgmt_msg.c
Merge pull request #13649 from donaldsharp/unlock_the_node_or_else
[mirror_frr.git] / lib / mgmt_msg.c
index 2fab03bc5472b39bc7a8a77f4f4cb86194ca1b32..0d9802a2b364b29f08cdcbdda6fc85b024bdf0d3 100644 (file)
@@ -7,10 +7,11 @@
  * Copyright (c) 2023, LabN Consulting, L.L.C.
  */
 #include <zebra.h>
+#include "debug.h"
 #include "network.h"
 #include "sockopt.h"
 #include "stream.h"
-#include "thread.h"
+#include "frrevent.h"
 #include "mgmt_msg.h"
 
 
@@ -22,7 +23,9 @@
        } while (0)
 
 #define MGMT_MSG_ERR(ms, fmt, ...)                                             \
-       zlog_err("%s: %s: " fmt, ms->idtag, __func__, ##__VA_ARGS__)
+       zlog_err("%s: %s: " fmt, (ms)->idtag, __func__, ##__VA_ARGS__)
+
+DEFINE_MTYPE(LIB, MSG_CONN, "msg connection state");
 
 /**
  * Read data from a socket into streams containing 1 or more full msgs headed by
@@ -81,7 +84,7 @@ enum mgmt_msg_rsched mgmt_msg_read(struct mgmt_msg_state *ms, int fd,
        left = stream_get_endp(ms->ins);
        while (left > (long)sizeof(struct mgmt_msg_hdr)) {
                mhdr = (struct mgmt_msg_hdr *)(STREAM_DATA(ms->ins) + total);
-               if (mhdr->marker != MGMT_MSG_MARKER) {
+               if (!MGMT_MSG_IS_MARKER(mhdr->marker)) {
                        MGMT_MSG_DBG(dbgtag, "recv corrupt buffer, disconnect");
                        return MSR_DISCONNECT;
                }
@@ -127,8 +130,8 @@ enum mgmt_msg_rsched mgmt_msg_read(struct mgmt_msg_state *ms, int fd,
  *     true if more to process (so reschedule) else false
  */
 bool mgmt_msg_procbufs(struct mgmt_msg_state *ms,
-                      void (*handle_msg)(void *user, uint8_t *msg,
-                                         size_t msglen),
+                      void (*handle_msg)(uint8_t version, uint8_t *msg,
+                                         size_t msglen, void *user),
                       void *user, bool debug)
 {
        const char *dbgtag = debug ? ms->idtag : NULL;
@@ -153,11 +156,13 @@ bool mgmt_msg_procbufs(struct mgmt_msg_state *ms,
                     left -= mhdr->len, data += mhdr->len) {
                        mhdr = (struct mgmt_msg_hdr *)data;
 
-                       assert(mhdr->marker == MGMT_MSG_MARKER);
+                       assert(MGMT_MSG_IS_MARKER(mhdr->marker));
                        assert(left >= mhdr->len);
 
-                       handle_msg(user, (uint8_t *)(mhdr + 1),
-                                  mhdr->len - sizeof(struct mgmt_msg_hdr));
+                       handle_msg(MGMT_MSG_MARKER_VERSION(mhdr->marker),
+                                  (uint8_t *)(mhdr + 1),
+                                  mhdr->len - sizeof(struct mgmt_msg_hdr),
+                                  user);
                        ms->nrxm++;
                        nproc++;
                }
@@ -251,7 +256,7 @@ enum mgmt_msg_wsched mgmt_msg_write(struct mgmt_msg_state *ms, int fd,
                        dbgtag,
                        "reached %zu buffer writes, pausing with %zu streams left",
                        ms->max_write_buf, ms->outq.count);
-               return MSW_SCHED_WRITES_OFF;
+               return MSW_SCHED_STREAM;
        }
        MGMT_MSG_DBG(dbgtag, "flushed all streams from output q");
        return MSW_SCHED_NONE;
@@ -264,15 +269,19 @@ enum mgmt_msg_wsched mgmt_msg_write(struct mgmt_msg_state *ms, int fd,
  *
  * Args:
  *     ms: mgmt_msg_state for this process.
- *     fd: socket/file to read data from.
+ *     version: version of this message, will be given to receiving side.
+ *     msg: the message to be sent.
+ *     len: the length of the message.
+ *     packf: a function to pack the message.
  *     debug: true to enable debug logging.
  *
  * Returns:
  *      0 on success, otherwise -1 on failure. The only failure mode is if a
  *      the message exceeds the maximum message size configured on init.
  */
-int mgmt_msg_send_msg(struct mgmt_msg_state *ms, void *msg, size_t len,
-                     mgmt_msg_packf packf, bool debug)
+int mgmt_msg_send_msg(struct mgmt_msg_state *ms, uint8_t version, void *msg,
+                     size_t len, size_t (*packf)(void *msg, void *buf),
+                     bool debug)
 {
        const char *dbgtag = debug ? ms->idtag : NULL;
        struct mgmt_msg_hdr *mhdr;
@@ -308,12 +317,17 @@ int mgmt_msg_send_msg(struct mgmt_msg_state *ms, void *msg, size_t len,
 
        /* We have a stream with space, pack the message into it. */
        mhdr = (struct mgmt_msg_hdr *)(STREAM_DATA(s) + s->endp);
-       mhdr->marker = MGMT_MSG_MARKER;
+       mhdr->marker = MGMT_MSG_MARKER(version);
        mhdr->len = mlen;
        stream_forward_endp(s, sizeof(*mhdr));
        endp = stream_get_endp(s);
        dstbuf = STREAM_DATA(s) + endp;
-       n = packf(msg, dstbuf);
+       if (packf)
+               n = packf(msg, dstbuf);
+       else {
+               memcpy(dstbuf, msg, len);
+               n = len;
+       }
        stream_set_endp(s, endp + n);
        ms->ntxm++;
 
@@ -392,6 +406,7 @@ size_t mgmt_msg_reset_writes(struct mgmt_msg_state *ms)
        return nproc;
 }
 
+
 void mgmt_msg_init(struct mgmt_msg_state *ms, size_t max_read_buf,
                   size_t max_write_buf, size_t max_msg_sz, const char *idtag)
 {
@@ -412,3 +427,494 @@ void mgmt_msg_destroy(struct mgmt_msg_state *ms)
                stream_free(ms->ins);
        free(ms->idtag);
 }
+
+/*
+ * Connections
+ */
+
+#define MSG_CONN_DEFAULT_CONN_RETRY_MSEC 250
+#define MSG_CONN_SEND_BUF_SIZE (1u << 16)
+#define MSG_CONN_RECV_BUF_SIZE (1u << 16)
+
+static void msg_client_sched_connect(struct msg_client *client,
+                                    unsigned long msec);
+
+static void msg_conn_sched_proc_msgs(struct msg_conn *conn);
+static void msg_conn_sched_read(struct msg_conn *conn);
+static void msg_conn_sched_write(struct msg_conn *conn);
+
+static void msg_conn_write(struct event *thread)
+{
+       struct msg_conn *conn = EVENT_ARG(thread);
+       enum mgmt_msg_wsched rv;
+
+       rv = mgmt_msg_write(&conn->mstate, conn->fd, conn->debug);
+       if (rv == MSW_SCHED_STREAM)
+               msg_conn_sched_write(conn);
+       else if (rv == MSW_DISCONNECT)
+               msg_conn_disconnect(conn, conn->is_client);
+       else
+               assert(rv == MSW_SCHED_NONE);
+}
+
+static void msg_conn_read(struct event *thread)
+{
+       struct msg_conn *conn = EVENT_ARG(thread);
+       enum mgmt_msg_rsched rv;
+
+       rv = mgmt_msg_read(&conn->mstate, conn->fd, conn->debug);
+       if (rv == MSR_DISCONNECT) {
+               msg_conn_disconnect(conn, conn->is_client);
+               return;
+       }
+       if (rv == MSR_SCHED_BOTH)
+               msg_conn_sched_proc_msgs(conn);
+       msg_conn_sched_read(conn);
+}
+
+/* collapse this into mgmt_msg_procbufs */
+static void msg_conn_proc_msgs(struct event *thread)
+{
+       struct msg_conn *conn = EVENT_ARG(thread);
+
+       if (mgmt_msg_procbufs(&conn->mstate,
+                             (void (*)(uint8_t, uint8_t *, size_t,
+                                       void *))conn->handle_msg,
+                             conn, conn->debug))
+               /* there's more, schedule handling more */
+               msg_conn_sched_proc_msgs(conn);
+}
+
+static void msg_conn_sched_read(struct msg_conn *conn)
+{
+       event_add_read(conn->loop, msg_conn_read, conn, conn->fd,
+                      &conn->read_ev);
+}
+
+static void msg_conn_sched_write(struct msg_conn *conn)
+{
+       event_add_write(conn->loop, msg_conn_write, conn, conn->fd,
+                       &conn->write_ev);
+}
+
+static void msg_conn_sched_proc_msgs(struct msg_conn *conn)
+{
+       event_add_event(conn->loop, msg_conn_proc_msgs, conn, 0,
+                       &conn->proc_msg_ev);
+}
+
+
+void msg_conn_disconnect(struct msg_conn *conn, bool reconnect)
+{
+
+       /* disconnect short-circuit if present */
+       if (conn->remote_conn) {
+               conn->remote_conn->remote_conn = NULL;
+               conn->remote_conn = NULL;
+       }
+
+       if (conn->fd != -1) {
+               close(conn->fd);
+               conn->fd = -1;
+
+               /* Notify client through registered callback (if any) */
+               if (conn->notify_disconnect)
+                       (void)(*conn->notify_disconnect)(conn);
+       }
+
+       if (reconnect) {
+               assert(conn->is_client);
+               msg_client_sched_connect(
+                       container_of(conn, struct msg_client, conn),
+                       MSG_CONN_DEFAULT_CONN_RETRY_MSEC);
+       }
+}
+
+int msg_conn_send_msg(struct msg_conn *conn, uint8_t version, void *msg,
+                     size_t mlen, size_t (*packf)(void *, void *),
+                     bool short_circuit_ok)
+{
+       const char *dbgtag = conn->debug ? conn->mstate.idtag : NULL;
+
+       if (conn->fd == -1) {
+               MGMT_MSG_ERR(&conn->mstate,
+                            "can't send message on closed connection");
+               return -1;
+       }
+
+       /* immediately handle the message if short-circuit is present */
+       if (conn->remote_conn && short_circuit_ok) {
+               uint8_t *buf = msg;
+               size_t n = mlen;
+
+               if (packf) {
+                       buf = XMALLOC(MTYPE_TMP, mlen);
+                       n = packf(msg, buf);
+               }
+
+               MGMT_MSG_DBG(dbgtag, "SC send: depth %u msg: %p",
+                            ++conn->short_circuit_depth, msg);
+
+               conn->remote_conn->handle_msg(version, buf, n,
+                                             conn->remote_conn);
+
+               MGMT_MSG_DBG(dbgtag, "SC return from depth: %u msg: %p",
+                            conn->short_circuit_depth--, msg);
+
+               if (packf)
+                       XFREE(MTYPE_TMP, buf);
+               return 0;
+       }
+
+       int rv = mgmt_msg_send_msg(&conn->mstate, version, msg, mlen, packf,
+                                  conn->debug);
+
+       msg_conn_sched_write(conn);
+
+       return rv;
+}
+
+void msg_conn_cleanup(struct msg_conn *conn)
+{
+       struct mgmt_msg_state *ms = &conn->mstate;
+
+       /* disconnect short-circuit if present */
+       if (conn->remote_conn) {
+               conn->remote_conn->remote_conn = NULL;
+               conn->remote_conn = NULL;
+       }
+
+       if (conn->fd != -1) {
+               close(conn->fd);
+               conn->fd = -1;
+       }
+
+       EVENT_OFF(conn->read_ev);
+       EVENT_OFF(conn->write_ev);
+       EVENT_OFF(conn->proc_msg_ev);
+
+       mgmt_msg_destroy(ms);
+}
+
+/*
+ * Client Connections
+ */
+
+DECLARE_LIST(msg_server_list, struct msg_server, link);
+
+static struct msg_server_list_head msg_servers;
+
+static void msg_client_connect(struct msg_client *conn);
+
+static void msg_client_connect_timer(struct event *thread)
+{
+       msg_client_connect(EVENT_ARG(thread));
+}
+
+static void msg_client_sched_connect(struct msg_client *client,
+                                    unsigned long msec)
+{
+       struct msg_conn *conn = &client->conn;
+       const char *dbgtag = conn->debug ? conn->mstate.idtag : NULL;
+
+       MGMT_MSG_DBG(dbgtag, "connection retry in %lu msec", msec);
+       if (msec)
+               event_add_timer_msec(conn->loop, msg_client_connect_timer,
+                                    client, msec, &client->conn_retry_tmr);
+       else
+               event_add_event(conn->loop, msg_client_connect_timer, client, 0,
+                               &client->conn_retry_tmr);
+}
+
+static bool msg_client_connect_short_circuit(struct msg_client *client)
+{
+       struct msg_conn *server_conn;
+       struct msg_server *server;
+       const char *dbgtag =
+               client->conn.debug ? client->conn.mstate.idtag : NULL;
+       union sockunion su = {0};
+       int sockets[2];
+
+       frr_each (msg_server_list, &msg_servers, server)
+               if (!strcmp(server->sopath, client->sopath))
+                       break;
+       if (!server) {
+               MGMT_MSG_DBG(dbgtag,
+                            "no short-circuit connection available for %s",
+                            client->sopath);
+
+               return false;
+       }
+
+       if (socketpair(AF_UNIX, SOCK_STREAM, 0, sockets)) {
+               MGMT_MSG_ERR(
+                       &client->conn.mstate,
+                       "socketpair failed trying to short-circuit connection on %s: %s",
+                       client->sopath, safe_strerror(errno));
+               return false;
+       }
+
+       /* client side */
+       client->conn.fd = sockets[0];
+       set_nonblocking(sockets[0]);
+       setsockopt_so_sendbuf(sockets[0], client->conn.mstate.max_write_buf);
+       setsockopt_so_recvbuf(sockets[0], client->conn.mstate.max_read_buf);
+       client->conn.is_short_circuit = true;
+
+       /* server side */
+       memset(&su, 0, sizeof(union sockunion));
+       server_conn = server->create(sockets[1], &su);
+       server_conn->is_short_circuit = true;
+
+       client->conn.remote_conn = server_conn;
+       server_conn->remote_conn = &client->conn;
+
+       MGMT_MSG_DBG(
+               dbgtag,
+               "short-circuit connection on %s server %s:%d to client %s:%d",
+               client->sopath, server_conn->mstate.idtag, server_conn->fd,
+               client->conn.mstate.idtag, client->conn.fd);
+
+       MGMT_MSG_DBG(
+               server_conn->debug ? server_conn->mstate.idtag : NULL,
+               "short-circuit connection on %s client %s:%d to server %s:%d",
+               client->sopath, client->conn.mstate.idtag, client->conn.fd,
+               server_conn->mstate.idtag, server_conn->fd);
+
+       return true;
+}
+
+
+/* Connect and start reading from the socket */
+static void msg_client_connect(struct msg_client *client)
+{
+       struct msg_conn *conn = &client->conn;
+       const char *dbgtag = conn->debug ? conn->mstate.idtag : NULL;
+
+       if (!client->short_circuit_ok ||
+           !msg_client_connect_short_circuit(client))
+               conn->fd =
+                       mgmt_msg_connect(client->sopath, MSG_CONN_SEND_BUF_SIZE,
+                                        MSG_CONN_RECV_BUF_SIZE, dbgtag);
+
+       if (conn->fd == -1)
+               /* retry the connection */
+               msg_client_sched_connect(client,
+                                        MSG_CONN_DEFAULT_CONN_RETRY_MSEC);
+       else if (client->notify_connect && client->notify_connect(client))
+               /* client connect notify failed */
+               msg_conn_disconnect(conn, true);
+       else
+               /* start reading */
+               msg_conn_sched_read(conn);
+}
+
+void msg_client_init(struct msg_client *client, struct event_loop *tm,
+                    const char *sopath,
+                    int (*notify_connect)(struct msg_client *client),
+                    int (*notify_disconnect)(struct msg_conn *client),
+                    void (*handle_msg)(uint8_t version, uint8_t *data,
+                                       size_t len, struct msg_conn *client),
+                    size_t max_read_buf, size_t max_write_buf,
+                    size_t max_msg_sz, bool short_circuit_ok,
+                    const char *idtag, bool debug)
+{
+       struct msg_conn *conn = &client->conn;
+       memset(client, 0, sizeof(*client));
+
+       conn->loop = tm;
+       conn->fd = -1;
+       conn->handle_msg = handle_msg;
+       conn->notify_disconnect = notify_disconnect;
+       conn->is_client = true;
+       conn->debug = debug;
+       client->short_circuit_ok = short_circuit_ok;
+       client->sopath = strdup(sopath);
+       client->notify_connect = notify_connect;
+
+       mgmt_msg_init(&conn->mstate, max_read_buf, max_write_buf, max_msg_sz,
+                     idtag);
+
+       /* XXX maybe just have client kick this off */
+       /* Start trying to connect to server */
+       msg_client_sched_connect(client, 0);
+}
+
+void msg_client_cleanup(struct msg_client *client)
+{
+       assert(client->conn.is_client);
+
+       EVENT_OFF(client->conn_retry_tmr);
+       free(client->sopath);
+
+       msg_conn_cleanup(&client->conn);
+}
+
+
+/*
+ * Server-side connections
+ */
+
+static void msg_server_accept(struct event *event)
+{
+       struct msg_server *server = EVENT_ARG(event);
+       int fd;
+       union sockunion su;
+
+       if (server->fd < 0)
+               return;
+
+       /* We continue hearing server listen socket. */
+       event_add_read(server->loop, msg_server_accept, server, server->fd,
+                      &server->listen_ev);
+
+       memset(&su, 0, sizeof(union sockunion));
+
+       /* We can handle IPv4 or IPv6 socket. */
+       fd = sockunion_accept(server->fd, &su);
+       if (fd < 0) {
+               zlog_err("Failed to accept %s client connection: %s",
+                        server->idtag, safe_strerror(errno));
+               return;
+       }
+       set_nonblocking(fd);
+       set_cloexec(fd);
+
+       DEBUGD(server->debug, "Accepted new %s connection", server->idtag);
+
+       server->create(fd, &su);
+}
+
+int msg_server_init(struct msg_server *server, const char *sopath,
+                   struct event_loop *loop,
+                   struct msg_conn *(*create)(int fd, union sockunion *su),
+                   const char *idtag, struct debug *debug)
+{
+       int ret;
+       int sock;
+       struct sockaddr_un addr;
+       mode_t old_mask;
+
+       memset(server, 0, sizeof(*server));
+       server->fd = -1;
+
+       sock = socket(AF_UNIX, SOCK_STREAM, PF_UNSPEC);
+       if (sock < 0) {
+               zlog_err("Failed to create %s server socket: %s", server->idtag,
+                        safe_strerror(errno));
+               goto fail;
+       }
+
+       addr.sun_family = AF_UNIX,
+       strlcpy(addr.sun_path, sopath, sizeof(addr.sun_path));
+       unlink(addr.sun_path);
+       old_mask = umask(0077);
+       ret = bind(sock, (struct sockaddr *)&addr, sizeof(addr));
+       if (ret < 0) {
+               zlog_err("Failed to bind %s server socket to '%s': %s",
+                        server->idtag, addr.sun_path, safe_strerror(errno));
+               umask(old_mask);
+               goto fail;
+       }
+       umask(old_mask);
+
+       ret = listen(sock, MGMTD_MAX_CONN);
+       if (ret < 0) {
+               zlog_err("Failed to listen on %s server socket: %s",
+                        server->idtag, safe_strerror(errno));
+               goto fail;
+       }
+
+       server->fd = sock;
+       server->loop = loop;
+       server->sopath = strdup(sopath);
+       server->idtag = strdup(idtag);
+       server->create = create;
+       server->debug = debug;
+
+       msg_server_list_add_head(&msg_servers, server);
+
+       event_add_read(server->loop, msg_server_accept, server, server->fd,
+                      &server->listen_ev);
+
+
+       DEBUGD(debug, "Started %s server, listening on %s", idtag, sopath);
+       return 0;
+
+fail:
+       if (sock >= 0)
+               close(sock);
+       server->fd = -1;
+       return -1;
+}
+
+void msg_server_cleanup(struct msg_server *server)
+{
+       DEBUGD(server->debug, "Closing %s server", server->idtag);
+
+       if (server->listen_ev)
+               EVENT_OFF(server->listen_ev);
+
+       msg_server_list_del(&msg_servers, server);
+
+       if (server->fd >= 0)
+               close(server->fd);
+       free((char *)server->sopath);
+       free((char *)server->idtag);
+
+       memset(server, 0, sizeof(*server));
+       server->fd = -1;
+}
+
+/*
+ * Initialize and start reading from the accepted socket
+ *
+ *     notify_connect - only called for disconnect i.e., connected == false
+ */
+void msg_conn_accept_init(struct msg_conn *conn, struct event_loop *tm, int fd,
+                         int (*notify_disconnect)(struct msg_conn *conn),
+                         void (*handle_msg)(uint8_t version, uint8_t *data,
+                                            size_t len, struct msg_conn *conn),
+                         size_t max_read, size_t max_write, size_t max_size,
+                         const char *idtag)
+{
+       conn->loop = tm;
+       conn->fd = fd;
+       conn->notify_disconnect = notify_disconnect;
+       conn->handle_msg = handle_msg;
+       conn->is_client = false;
+
+       mgmt_msg_init(&conn->mstate, max_read, max_write, max_size, idtag);
+
+       /* start reading */
+       msg_conn_sched_read(conn);
+
+       /* Make socket non-blocking.  */
+       set_nonblocking(conn->fd);
+       setsockopt_so_sendbuf(conn->fd, MSG_CONN_SEND_BUF_SIZE);
+       setsockopt_so_recvbuf(conn->fd, MSG_CONN_RECV_BUF_SIZE);
+}
+
+struct msg_conn *
+msg_server_conn_create(struct event_loop *tm, int fd,
+                      int (*notify_disconnect)(struct msg_conn *conn),
+                      void (*handle_msg)(uint8_t version, uint8_t *data,
+                                         size_t len, struct msg_conn *conn),
+                      size_t max_read, size_t max_write, size_t max_size,
+                      void *user, const char *idtag)
+{
+       struct msg_conn *conn = XMALLOC(MTYPE_MSG_CONN, sizeof(*conn));
+       memset(conn, 0, sizeof(*conn));
+       msg_conn_accept_init(conn, tm, fd, notify_disconnect, handle_msg,
+                            max_read, max_write, max_size, idtag);
+       conn->user = user;
+       return conn;
+}
+
+void msg_server_conn_delete(struct msg_conn *conn)
+{
+       if (!conn)
+               return;
+       msg_conn_cleanup(conn);
+       XFREE(MTYPE_MSG_CONN, conn);
+}