]> git.proxmox.com Git - mirror_qemu.git/blobdiff - util/oslib-win32.c
win32: replace closesocket() with close() wrapper
[mirror_qemu.git] / util / oslib-win32.c
index 528c9ee156d99aed44e7efa5f4678371f43ee348..29a667ae3d9232b3e42f1ef026587b00b43d4540 100644 (file)
@@ -180,7 +180,7 @@ static int socket_error(void)
 void qemu_socket_set_block(int fd)
 {
     unsigned long opt = 0;
-    WSAEventSelect(fd, NULL, 0);
+    qemu_socket_unselect(fd, NULL);
     ioctlsocket(fd, FIONBIO, &opt);
 }
 
@@ -283,12 +283,45 @@ char *qemu_get_pid_name(pid_t pid)
 }
 
 
+bool qemu_socket_select(int sockfd, WSAEVENT hEventObject,
+                        long lNetworkEvents, Error **errp)
+{
+    SOCKET s = _get_osfhandle(sockfd);
+
+    if (errp == NULL) {
+        errp = &error_warn;
+    }
+
+    if (s == INVALID_SOCKET) {
+        error_setg(errp, "invalid socket fd=%d", sockfd);
+        return false;
+    }
+
+    if (WSAEventSelect(s, hEventObject, lNetworkEvents) != 0) {
+        error_setg_win32(errp, WSAGetLastError(), "failed to WSAEventSelect()");
+        return false;
+    }
+
+    return true;
+}
+
+bool qemu_socket_unselect(int sockfd, Error **errp)
+{
+    return qemu_socket_select(sockfd, NULL, 0, errp);
+}
+
 #undef connect
 int qemu_connect_wrap(int sockfd, const struct sockaddr *addr,
                       socklen_t addrlen)
 {
     int ret;
-    ret = connect(sockfd, addr, addrlen);
+    SOCKET s = _get_osfhandle(sockfd);
+
+    if (s == INVALID_SOCKET) {
+        return -1;
+    }
+
+    ret = connect(s, addr, addrlen);
     if (ret < 0) {
         if (WSAGetLastError() == WSAEWOULDBLOCK) {
             errno = EINPROGRESS;
@@ -304,7 +337,13 @@ int qemu_connect_wrap(int sockfd, const struct sockaddr *addr,
 int qemu_listen_wrap(int sockfd, int backlog)
 {
     int ret;
-    ret = listen(sockfd, backlog);
+    SOCKET s = _get_osfhandle(sockfd);
+
+    if (s == INVALID_SOCKET) {
+        return -1;
+    }
+
+    ret = listen(s, backlog);
     if (ret < 0) {
         errno = socket_error();
     }
@@ -317,7 +356,13 @@ int qemu_bind_wrap(int sockfd, const struct sockaddr *addr,
                    socklen_t addrlen)
 {
     int ret;
-    ret = bind(sockfd, addr, addrlen);
+    SOCKET s = _get_osfhandle(sockfd);
+
+    if (s == INVALID_SOCKET) {
+        return -1;
+    }
+
+    ret = bind(s, addr, addrlen);
     if (ret < 0) {
         errno = socket_error();
     }
@@ -325,15 +370,82 @@ int qemu_bind_wrap(int sockfd, const struct sockaddr *addr,
 }
 
 
+#undef close
+int qemu_close_wrap(int fd)
+{
+    int ret;
+    DWORD flags = 0;
+    SOCKET s = INVALID_SOCKET;
+
+    if (fd_is_socket(fd)) {
+        s = _get_osfhandle(fd);
+
+        /*
+         * If we were to just call _close on the descriptor, it would close the
+         * HANDLE, but it wouldn't free any of the resources associated to the
+         * SOCKET, and we can't call _close after calling closesocket, because
+         * closesocket has already closed the HANDLE, and _close would attempt to
+         * close the HANDLE again, resulting in a double free. We can however
+         * protect the HANDLE from actually being closed long enough to close the
+         * file descriptor, then close the socket itself.
+         */
+        if (!GetHandleInformation((HANDLE)s, &flags)) {
+            errno = EACCES;
+            return -1;
+        }
+
+        if (!SetHandleInformation((HANDLE)s, HANDLE_FLAG_PROTECT_FROM_CLOSE, HANDLE_FLAG_PROTECT_FROM_CLOSE)) {
+            errno = EACCES;
+            return -1;
+        }
+    }
+
+    ret = close(fd);
+
+    if (s != INVALID_SOCKET && !SetHandleInformation((HANDLE)s, flags, flags)) {
+        errno = EACCES;
+        return -1;
+    }
+
+    /*
+     * close() returns EBADF since we PROTECT_FROM_CLOSE the underlying handle,
+     * but the FD is actually freed
+     */
+    if (ret < 0 && (s == INVALID_SOCKET || errno != EBADF)) {
+        return ret;
+    }
+
+    if (s != INVALID_SOCKET) {
+        ret = closesocket(s);
+        if (ret < 0) {
+            errno = socket_error();
+        }
+    }
+
+    return ret;
+}
+
+
 #undef socket
 int qemu_socket_wrap(int domain, int type, int protocol)
 {
-    int ret;
-    ret = socket(domain, type, protocol);
-    if (ret < 0) {
+    SOCKET s;
+    int fd;
+
+    s = socket(domain, type, protocol);
+    if (s == -1) {
         errno = socket_error();
+        return -1;
     }
-    return ret;
+
+    fd = _open_osfhandle(s, _O_BINARY);
+    if (fd < 0) {
+        closesocket(s);
+        /* _open_osfhandle may not set errno, and closesocket() may override it */
+        errno = ENOMEM;
+    }
+
+    return fd;
 }
 
 
@@ -341,12 +453,27 @@ int qemu_socket_wrap(int domain, int type, int protocol)
 int qemu_accept_wrap(int sockfd, struct sockaddr *addr,
                      socklen_t *addrlen)
 {
-    int ret;
-    ret = accept(sockfd, addr, addrlen);
-    if (ret < 0) {
+    int fd;
+    SOCKET s = _get_osfhandle(sockfd);
+
+    if (s == INVALID_SOCKET) {
+        return -1;
+    }
+
+    s = accept(s, addr, addrlen);
+    if (s == -1) {
         errno = socket_error();
+        return -1;
     }
-    return ret;
+
+    fd = _open_osfhandle(s, _O_BINARY);
+    if (fd < 0) {
+        closesocket(s);
+        /* _open_osfhandle may not set errno, and closesocket() may override it */
+        errno = ENOMEM;
+    }
+
+    return fd;
 }
 
 
@@ -354,7 +481,13 @@ int qemu_accept_wrap(int sockfd, struct sockaddr *addr,
 int qemu_shutdown_wrap(int sockfd, int how)
 {
     int ret;
-    ret = shutdown(sockfd, how);
+    SOCKET s = _get_osfhandle(sockfd);
+
+    if (s == INVALID_SOCKET) {
+        return -1;
+    }
+
+    ret = shutdown(s, how);
     if (ret < 0) {
         errno = socket_error();
     }
@@ -366,19 +499,13 @@ int qemu_shutdown_wrap(int sockfd, int how)
 int qemu_ioctlsocket_wrap(int fd, int req, void *val)
 {
     int ret;
-    ret = ioctlsocket(fd, req, val);
-    if (ret < 0) {
-        errno = socket_error();
-    }
-    return ret;
-}
+    SOCKET s = _get_osfhandle(fd);
 
+    if (s == INVALID_SOCKET) {
+        return -1;
+    }
 
-#undef closesocket
-int qemu_closesocket_wrap(int fd)
-{
-    int ret;
-    ret = closesocket(fd);
+    ret = ioctlsocket(s, req, val);
     if (ret < 0) {
         errno = socket_error();
     }
@@ -391,7 +518,13 @@ int qemu_getsockopt_wrap(int sockfd, int level, int optname,
                          void *optval, socklen_t *optlen)
 {
     int ret;
-    ret = getsockopt(sockfd, level, optname, optval, optlen);
+    SOCKET s = _get_osfhandle(sockfd);
+
+    if (s == INVALID_SOCKET) {
+        return -1;
+    }
+
+    ret = getsockopt(s, level, optname, optval, optlen);
     if (ret < 0) {
         errno = socket_error();
     }
@@ -404,7 +537,13 @@ int qemu_setsockopt_wrap(int sockfd, int level, int optname,
                          const void *optval, socklen_t optlen)
 {
     int ret;
-    ret = setsockopt(sockfd, level, optname, optval, optlen);
+    SOCKET s = _get_osfhandle(sockfd);
+
+    if (s == INVALID_SOCKET) {
+        return -1;
+    }
+
+    ret = setsockopt(s, level, optname, optval, optlen);
     if (ret < 0) {
         errno = socket_error();
     }
@@ -417,7 +556,13 @@ int qemu_getpeername_wrap(int sockfd, struct sockaddr *addr,
                           socklen_t *addrlen)
 {
     int ret;
-    ret = getpeername(sockfd, addr, addrlen);
+    SOCKET s = _get_osfhandle(sockfd);
+
+    if (s == INVALID_SOCKET) {
+        return -1;
+    }
+
+    ret = getpeername(s, addr, addrlen);
     if (ret < 0) {
         errno = socket_error();
     }
@@ -430,7 +575,13 @@ int qemu_getsockname_wrap(int sockfd, struct sockaddr *addr,
                           socklen_t *addrlen)
 {
     int ret;
-    ret = getsockname(sockfd, addr, addrlen);
+    SOCKET s = _get_osfhandle(sockfd);
+
+    if (s == INVALID_SOCKET) {
+        return -1;
+    }
+
+    ret = getsockname(s, addr, addrlen);
     if (ret < 0) {
         errno = socket_error();
     }
@@ -442,7 +593,13 @@ int qemu_getsockname_wrap(int sockfd, struct sockaddr *addr,
 ssize_t qemu_send_wrap(int sockfd, const void *buf, size_t len, int flags)
 {
     int ret;
-    ret = send(sockfd, buf, len, flags);
+    SOCKET s = _get_osfhandle(sockfd);
+
+    if (s == INVALID_SOCKET) {
+        return -1;
+    }
+
+    ret = send(s, buf, len, flags);
     if (ret < 0) {
         errno = socket_error();
     }
@@ -455,7 +612,13 @@ ssize_t qemu_sendto_wrap(int sockfd, const void *buf, size_t len, int flags,
                          const struct sockaddr *addr, socklen_t addrlen)
 {
     int ret;
-    ret = sendto(sockfd, buf, len, flags, addr, addrlen);
+    SOCKET s = _get_osfhandle(sockfd);
+
+    if (s == INVALID_SOCKET) {
+        return -1;
+    }
+
+    ret = sendto(s, buf, len, flags, addr, addrlen);
     if (ret < 0) {
         errno = socket_error();
     }
@@ -467,7 +630,13 @@ ssize_t qemu_sendto_wrap(int sockfd, const void *buf, size_t len, int flags,
 ssize_t qemu_recv_wrap(int sockfd, void *buf, size_t len, int flags)
 {
     int ret;
-    ret = recv(sockfd, buf, len, flags);
+    SOCKET s = _get_osfhandle(sockfd);
+
+    if (s == INVALID_SOCKET) {
+        return -1;
+    }
+
+    ret = recv(s, buf, len, flags);
     if (ret < 0) {
         errno = socket_error();
     }
@@ -480,7 +649,13 @@ ssize_t qemu_recvfrom_wrap(int sockfd, void *buf, size_t len, int flags,
                            struct sockaddr *addr, socklen_t *addrlen)
 {
     int ret;
-    ret = recvfrom(sockfd, buf, len, flags, addr, addrlen);
+    SOCKET s = _get_osfhandle(sockfd);
+
+    if (s == INVALID_SOCKET) {
+        return -1;
+    }
+
+    ret = recvfrom(s, buf, len, flags, addr, addrlen);
     if (ret < 0) {
         errno = socket_error();
     }