return move_fd(fd);
}
-int lxc_abstract_unix_send_fds_iov(int fd, int *sendfds, int num_sendfds,
+/*
+ * Send @num_sendfds file descriptors from @sendfds over @fd. Only the setup
+ * of the control message buffer is visible in this hunk; the part of the body
+ * that fills in the message and sends it is unchanged context that is not
+ * shown here (presumably it uses @iov/@iovlen as the payload - confirm).
+ *
+ * The control buffer is sized with CMSG_SPACE() for @num_sendfds integers and
+ * is freed automatically via __do_free. Allocation failure is reported with
+ * ret_errno(ENOMEM), using a positive errno constant like the other
+ * ret_errno() callers in this file.
+ */
+int lxc_abstract_unix_send_fds_iov(int fd, const int *sendfds, int num_sendfds,
struct iovec *iov, size_t iovlen)
{
__do_free char *cmsgbuf = NULL;
int ret;
- struct msghdr msg;
+ struct msghdr msg = {};
struct cmsghdr *cmsg = NULL;
size_t cmsgbufsize = CMSG_SPACE(num_sendfds * sizeof(int));
- memset(&msg, 0, sizeof(msg));
-
cmsgbuf = malloc(cmsgbufsize);
- if (!cmsgbuf) {
- errno = ENOMEM;
- return -1;
- }
+ if (!cmsgbuf)
+ return ret_errno(ENOMEM);
msg.msg_control = cmsgbuf;
msg.msg_controllen = cmsgbufsize;
return ret;
}
-int lxc_abstract_unix_send_fds(int fd, int *sendfds, int num_sendfds,
+/*
+ * Single-buffer convenience wrapper around lxc_abstract_unix_send_fds_iov().
+ *
+ * When @data is NULL a one byte dummy buffer is used as the payload instead,
+ * so the iovec handed down always points at valid storage. The return value
+ * is whatever lxc_abstract_unix_send_fds_iov() returns.
+ */
+int lxc_abstract_unix_send_fds(int fd, const int *sendfds, int num_sendfds,
void *data, size_t size)
{
- char buf[1] = {0};
+ char buf[1] = {};
struct iovec iov = {
- .iov_base = data ? data : buf,
- .iov_len = data ? size : sizeof(buf),
+ .iov_base = data ? data : buf,
+ .iov_len = data ? size : sizeof(buf),
};
return lxc_abstract_unix_send_fds_iov(fd, sendfds, num_sendfds, &iov, 1);
}
return lxc_abstract_unix_send_fds(fd, sendfds, num_sendfds, data, size);
}
-static int lxc_abstract_unix_recv_fds_iov(int fd, int *recvfds, int num_recvfds,
- struct iovec *iov, size_t iovlen)
+/*
+ * Receive one message from @fd into @ret_iov and collect any file descriptors
+ * passed along with it (SCM_RIGHTS) into @ret_fds.
+ *
+ * @ret_fds->fd_count_max is the number of descriptors the caller expects and
+ * must not exceed KERNEL_SCM_MAX_FD, the capacity of @ret_fds->fd; larger
+ * values are rejected with EINVAL. Before receiving, the first fd_count_max
+ * slots are set to -EBADF and fd_count_ret to 0, so that a message carrying
+ * no descriptors never leaves caller-initialized (e.g. zeroed) slots looking
+ * like valid file descriptors.
+ *
+ * On return @ret_fds->fd_count_ret holds the number of descriptors stored.
+ * Descriptors beyond fd_count_max are closed; if more than KERNEL_SCM_MAX_FD
+ * arrive, all of them are closed and an EFBIG error is returned.
+ *
+ * Returns the value of recvmsg() (0 when it returned 0) or a negative value
+ * on failure. recvmsg() is retried when interrupted by a signal (EINTR).
+ */
+static ssize_t lxc_abstract_unix_recv_fds_iov(int fd,
+ struct unix_fds *ret_fds,
+ struct iovec *ret_iov,
+ size_t size_ret_iov)
{
__do_free char *cmsgbuf = NULL;
- int ret;
- struct msghdr msg;
+ ssize_t ret;
+ struct msghdr msg = {};
+ struct cmsghdr *cmsg = NULL;
size_t cmsgbufsize = CMSG_SPACE(sizeof(struct ucred)) +
- CMSG_SPACE(num_recvfds * sizeof(int));
+ CMSG_SPACE(ret_fds->fd_count_max * sizeof(int));
- memset(&msg, 0, sizeof(msg));
-
- cmsgbuf = malloc(cmsgbufsize);
+ /* The fd array in struct unix_fds cannot hold more than this. */
+ if (ret_fds->fd_count_max > KERNEL_SCM_MAX_FD)
+ return ret_errno(EINVAL);
+
+ /*
+ * Start from a well-defined "nothing received" state so callers can
+ * inspect the fd slots even if no SCM_RIGHTS message shows up.
+ */
+ for (__u32 slot = 0; slot < ret_fds->fd_count_max; slot++)
+ ret_fds->fd[slot] = -EBADF;
+ ret_fds->fd_count_ret = 0;
+
+ cmsgbuf = zalloc(cmsgbufsize);
if (!cmsgbuf)
return ret_errno(ENOMEM);
- msg.msg_control = cmsgbuf;
- msg.msg_controllen = cmsgbufsize;
+ msg.msg_control = cmsgbuf;
+ msg.msg_controllen = cmsgbufsize;
- msg.msg_iov = iov;
- msg.msg_iovlen = iovlen;
+ msg.msg_iov = ret_iov;
+ msg.msg_iovlen = size_ret_iov;
- do {
- ret = recvmsg(fd, &msg, MSG_CMSG_CLOEXEC);
- } while (ret < 0 && errno == EINTR);
- if (ret < 0 || ret == 0)
- return ret;
+again:
+ ret = recvmsg(fd, &msg, MSG_CMSG_CLOEXEC);
+ if (ret < 0) {
+ if (errno == EINTR)
+ goto again;
- /*
- * If SO_PASSCRED is set we will always get a ucred message.
- */
- for (struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
- if (cmsg->cmsg_type != SCM_RIGHTS)
- continue;
-
- memset(recvfds, -1, num_recvfds * sizeof(int));
- if (cmsg &&
- cmsg->cmsg_len == CMSG_LEN(num_recvfds * sizeof(int)) &&
- cmsg->cmsg_level == SOL_SOCKET)
- memcpy(recvfds, CMSG_DATA(cmsg), num_recvfds * sizeof(int));
- break;
+ return syserrno(-errno, "Failed to receive response");
+ }
+ if (ret == 0)
+ return 0;
+
+ /* If SO_PASSCRED is set we will always get a ucred message. */
+ for (cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
+ if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
+ __u32 idx;
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wcast-align"
+ int *fds_raw = (int *)CMSG_DATA(cmsg);
+#pragma GCC diagnostic pop
+ __u32 num_raw = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
+
+ /*
+ * We received an insane amount of file descriptors
+ * which exceeds the kernel limit we know about so
+ * close them and return an error.
+ */
+ if (num_raw > KERNEL_SCM_MAX_FD) {
+ for (idx = 0; idx < num_raw; idx++)
+ close(fds_raw[idx]);
+
+ return syserrno_set(-EFBIG, "Received excessive number of file descriptors");
+ }
+
+ if (ret_fds->fd_count_max > num_raw) {
+ /*
+ * Make sure any excess entries in the fd array
+ * are set to -EBADF so our cleanup functions
+ * can safely be called.
+ */
+ for (idx = num_raw; idx < ret_fds->fd_count_max; idx++)
+ ret_fds->fd[idx] = -EBADF;
+
+ WARN("Received fewer file descriptors than we expected %u != %u", ret_fds->fd_count_max, num_raw);
+ } else if (ret_fds->fd_count_max < num_raw) {
+ /* Make sure we close any excess fds we received. */
+ for (idx = ret_fds->fd_count_max; idx < num_raw; idx++)
+ close(fds_raw[idx]);
+
+ WARN("Received more file descriptors than we expected %u != %u", ret_fds->fd_count_max, num_raw);
+
+ /* Cap the number of received file descriptors. */
+ num_raw = ret_fds->fd_count_max;
+ }
+
+ memcpy(ret_fds->fd, CMSG_DATA(cmsg), num_raw * sizeof(int));
+ ret_fds->fd_count_ret = num_raw;
+ break;
+ }
}
return ret;
}
-int lxc_abstract_unix_recv_fds(int fd, int *recvfds, int num_recvfds,
- void *data, size_t size)
+/*
+ * Single-buffer wrapper around lxc_abstract_unix_recv_fds_iov().
+ *
+ * Payload bytes are received into @ret_data (@size_ret_data bytes); when
+ * @ret_data is NULL a one byte dummy buffer is used instead. Received file
+ * descriptors are reported through @ret_fds. The result of
+ * lxc_abstract_unix_recv_fds_iov() is passed through unchanged.
+ */
+ssize_t lxc_abstract_unix_recv_fds(int fd, struct unix_fds *ret_fds,
+ void *ret_data, size_t size_ret_data)
{
- char buf[1] = {0};
+ char buf[1] = {};
+ struct iovec iov = {
+ .iov_base = ret_data ? ret_data : buf,
+ .iov_len = ret_data ? size_ret_data : sizeof(buf),
+ };
+
+ return lxc_abstract_unix_recv_fds_iov(fd, ret_fds, &iov, 1);
+}
+
+/*
+ * Receive a message that is expected to carry exactly one file descriptor.
+ *
+ * Payload bytes go into @ret_data (@size_ret_data bytes), or into a one byte
+ * dummy buffer when @ret_data is NULL.
+ *
+ * Returns a negative value if receiving failed and ret_errno(ENODATA) if the
+ * receive call returned 0. Otherwise the receive call's return value is
+ * passed on and *@ret_fd is set to the received descriptor, or to -EBADF
+ * when the number of descriptors received was not exactly one.
+ */
+ssize_t lxc_abstract_unix_recv_one_fd(int fd, int *ret_fd, void *ret_data,
+ size_t size_ret_data)
+{
+ /* put_unix_fds() runs on scope exit for whatever is still held in fds. */
+ call_cleaner(put_unix_fds) struct unix_fds *fds = NULL;
+ char buf[1] = {};
struct iovec iov = {
- .iov_base = data ? data : buf,
- .iov_len = data ? size : sizeof(buf),
+ .iov_base = ret_data ? ret_data : buf,
+ .iov_len = ret_data ? size_ret_data : sizeof(buf),
+ };
+ ssize_t ret;
+
+ /* Unnamed fields of the compound literal are zero-initialized. */
+ fds = &(struct unix_fds){
+ .fd_count_max = 1,
};
- return lxc_abstract_unix_recv_fds_iov(fd, recvfds, num_recvfds, &iov, 1);
+
+ ret = lxc_abstract_unix_recv_fds_iov(fd, fds, &iov, 1);
+ if (ret < 0)
+ return ret;
+
+ if (ret == 0)
+ return ret_errno(ENODATA);
+
+ /* NOTE(review): move_fd() presumably hands ownership to the caller so the cleaner does not close it - confirm. */
+ if (fds->fd_count_ret != fds->fd_count_max)
+ *ret_fd = -EBADF;
+ else
+ *ret_fd = move_fd(fds->fd[0]);
+
+ return ret;
+}
+
+/*
+ * Receive a message that is expected to carry exactly two file descriptors.
+ * The payload is a single dummy byte that is discarded.
+ *
+ * @ret_fd must have room for two entries. Both are set to the received
+ * descriptors, or both to -EBADF when the number of descriptors received was
+ * not exactly two.
+ *
+ * Returns a negative value if receiving failed, ret_errno(ENODATA) if the
+ * receive call returned 0, and 0 otherwise. Note that success is reported
+ * as 0 here, not as the number of bytes received.
+ */
+ssize_t lxc_abstract_unix_recv_two_fds(int fd, int *ret_fd)
+{
+ /* put_unix_fds() runs on scope exit for whatever is still held in fds. */
+ call_cleaner(put_unix_fds) struct unix_fds *fds = NULL;
+ char buf[1] = {};
+ struct iovec iov = {
+ .iov_base = buf,
+ .iov_len = sizeof(buf),
+ };
+ ssize_t ret;
+
+ /* Unnamed fields of the compound literal are zero-initialized. */
+ fds = &(struct unix_fds){
+ .fd_count_max = 2,
+ };
+
+ ret = lxc_abstract_unix_recv_fds_iov(fd, fds, &iov, 1);
+ if (ret < 0)
+ return ret;
+
+ if (ret == 0)
+ return ret_errno(ENODATA);
+
+ /* A partial set of descriptors is treated the same as none at all. */
+ if (fds->fd_count_ret != fds->fd_count_max) {
+ ret_fd[0] = -EBADF;
+ ret_fd[1] = -EBADF;
+ } else {
+ ret_fd[0] = move_fd(fds->fd[0]);
+ ret_fd[1] = move_fd(fds->fd[1]);
+ }
+
+ return 0;
}
int lxc_abstract_unix_send_credential(int fd, void *data, size_t size)
#include <stdio.h>
#include <sys/socket.h>
+#include <stddef.h>
#include <sys/un.h>
#include "compiler.h"
+#include "macro.h"
+#include "memory_utils.h"
+
+/*
+ * Technically 253 is the kernel limit but we want the size of
+ * struct unix_fds to be a multiple of 8.
+ */
+#define KERNEL_SCM_MAX_FD 252
+
+/*
+ * Container for file descriptors received over an AF_UNIX socket.
+ *
+ * The caller fills in fd_count_max before receiving; the receive helpers
+ * store the descriptors in fd[] and record how many in fd_count_ret.
+ * put_unix_fds() closes the first fd_count_ret entries.
+ */
+struct unix_fds {
+ /* Number of file descriptors the caller expects to receive. */
+ __u32 fd_count_max;
+ /* Number of file descriptors actually stored in fd[]. */
+ __u32 fd_count_ret;
+ /* Received file descriptors; the receive code marks unused slots with -EBADF. */
+ __s32 fd[KERNEL_SCM_MAX_FD];
+} __attribute__((aligned(8)));
/* does not enforce \0-termination */
__hidden extern int lxc_abstract_unix_open(const char *path, int type, int flags);
/* does not enforce \0-termination */
__hidden extern int lxc_abstract_unix_connect(const char *path);
-__hidden extern int lxc_abstract_unix_send_fds(int fd, int *sendfds, int num_sendfds, void *data,
- size_t size) __access_r(2, 3) __access_r(4, 5);
+/* Send file descriptors together with a single, optional data buffer. */
+__hidden extern int lxc_abstract_unix_send_fds(int fd, const int *sendfds,
+ int num_sendfds, void *data,
+ size_t size) __access_r(2, 3)
+ __access_r(4, 5);
+
+/* Send file descriptors together with a caller supplied iovec array. */
+__hidden extern int lxc_abstract_unix_send_fds_iov(int fd, const int *sendfds,
+ int num_sendfds,
+ struct iovec *iov,
+ size_t iovlen)
+ __access_r(2, 3);
+
+/*
+ * Receive up to ret_fds->fd_count_max file descriptors plus payload data.
+ * NOTE(review): ret_data is filled in by the receive path, yet it is
+ * annotated with __access_r (presumably a read-only access attribute) here
+ * and on lxc_abstract_unix_recv_one_fd() - confirm this is intended.
+ */
+__hidden extern ssize_t lxc_abstract_unix_recv_fds(int fd,
+ struct unix_fds *ret_fds,
+ void *ret_data,
+ size_t size_ret_data)
+ __access_r(3, 4);
-__hidden extern int lxc_abstract_unix_send_fds_iov(int fd, int *sendfds, int num_sendfds,
- struct iovec *iov, size_t iovlen) __access_r(2, 3);
+/* Receive exactly one file descriptor; *ret_fd is -EBADF if that fails. */
+__hidden extern ssize_t lxc_abstract_unix_recv_one_fd(int fd, int *ret_fd,
+ void *ret_data,
+ size_t size_ret_data)
+ __access_r(3, 4);
-__hidden extern int lxc_abstract_unix_recv_fds(int fd, int *recvfds, int num_recvfds, void *data,
- size_t size) __access_r(2, 3) __access_r(4, 5);
+/* Receive exactly two file descriptors into ret_fd[0] and ret_fd[1]. */
+__hidden extern ssize_t lxc_abstract_unix_recv_two_fds(int fd, int *ret_fd);
__hidden extern int lxc_unix_send_fds(int fd, int *sendfds, int num_sendfds, void *data, size_t size);
__hidden extern int lxc_unix_connect_type(struct sockaddr_un *addr, int type);
__hidden extern int lxc_socket_set_timeout(int fd, int rcv_timeout, int snd_timeout);
+/*
+ * Close every file descriptor recorded in @fds, i.e. the first
+ * fds->fd_count_ret entries of fds->fd. NULL and error pointers are ignored.
+ * Registered below as a cleanup function for use with call_cleaner().
+ */
+static inline void put_unix_fds(struct unix_fds *fds)
+{
+ if (IS_ERR_OR_NULL(fds))
+ return;
+
+ for (size_t i = 0; i < fds->fd_count_ret; i++)
+ close_prot_errno_disarm(fds->fd[i]);
+}
+define_cleanup_function(struct unix_fds *, put_unix_fds);
+
#endif /* __LXC_AF_UNIX_H */
*/
static int lxc_cmd_rsp_recv(int sock, struct lxc_cmd_rr *cmd)
{
- __do_close int fd_rsp = -EBADF;
+ call_cleaner(put_unix_fds) struct unix_fds *fds = NULL;
int ret;
struct lxc_cmd_rsp *rsp = &cmd->rsp;
- ret = lxc_abstract_unix_recv_fds(sock, &fd_rsp, 1, rsp, sizeof(*rsp));
+ fds = &(struct unix_fds){
+ .fd_count_max = 1,
+ };
+
+ ret = lxc_abstract_unix_recv_fds(sock, fds, rsp, sizeof(*rsp));
if (ret < 0)
return log_warn_errno(-1,
errno, "Failed to receive response for command \"%s\"",
ENOMEM, "Failed to receive response for command \"%s\"",
lxc_cmd_str(cmd->req.cmd));
- rspdata->ptxfd = move_fd(fd_rsp);
+ rspdata->ptxfd = move_fd(fds->fd[0]);
rspdata->ttynum = PTR_TO_INT(rsp->data);
rsp->data = rspdata;
}
if (cmd->req.cmd == LXC_CMD_GET_CGROUP2_FD ||
- cmd->req.cmd == LXC_CMD_GET_LIMITING_CGROUP2_FD)
- {
- int cgroup2_fd = move_fd(fd_rsp);
+ cmd->req.cmd == LXC_CMD_GET_LIMITING_CGROUP2_FD) {
+ int cgroup2_fd = move_fd(fds->fd[0]);
rsp->data = INT_TO_PTR(cgroup2_fd);
}
if (cmd->req.cmd == LXC_CMD_GET_INIT_PIDFD) {
- int init_pidfd = move_fd(fd_rsp);
+ int init_pidfd = move_fd(fds->fd[0]);
rsp->data = INT_TO_PTR(init_pidfd);
}
if (cmd->req.cmd == LXC_CMD_GET_DEVPTS_FD) {
- int devpts_fd = move_fd(fd_rsp);
+ int devpts_fd = move_fd(fds->fd[0]);
rsp->data = INT_TO_PTR(devpts_fd);
}
if (cmd->req.cmd == LXC_CMD_GET_SECCOMP_NOTIFY_FD) {
- int seccomp_notify_fd = move_fd(fd_rsp);
+ int seccomp_notify_fd = move_fd(fds->fd[0]);
rsp->data = INT_TO_PTR(seccomp_notify_fd);
}
int ret;
__do_close int recv_fd = -EBADF;
- ret = lxc_abstract_unix_recv_fds(fd, &recv_fd, 1, NULL, 0);
+ ret = lxc_abstract_unix_recv_one_fd(fd, &recv_fd, NULL, 0);
if (ret <= 0) {
rsp.ret = -errno;
goto out;