]> git.proxmox.com Git - mirror_ubuntu-bionic-kernel.git/blobdiff - net/rxrpc/af_rxrpc.c
rxrpc: Replace conn->trans->{local,peer} with conn->params.{local,peer}
[mirror_ubuntu-bionic-kernel.git] / net / rxrpc / af_rxrpc.c
index e45e94ca030f3b1e3a6881d3178b021b4ce670ad..48b45a0280c09bfd718ceae65310a905111ab609 100644 (file)
@@ -9,6 +9,8 @@
  * 2 of the License, or (at your option) any later version.
  */
 
+#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
+
 #include <linux/module.h>
 #include <linux/kernel.h>
 #include <linux/net.h>
@@ -31,8 +33,6 @@ unsigned int rxrpc_debug; // = RXRPC_DEBUG_KPROTO;
 module_param_named(debug, rxrpc_debug, uint, S_IWUSR | S_IRUGO);
 MODULE_PARM_DESC(debug, "RxRPC debugging mask");
 
-static int sysctl_rxrpc_max_qlen __read_mostly = 10;
-
 static struct proto rxrpc_proto;
 static const struct proto_ops rxrpc_rpc_ops;
 
@@ -97,11 +97,13 @@ static int rxrpc_validate_address(struct rxrpc_sock *rx,
            srx->transport_len > len)
                return -EINVAL;
 
-       if (srx->transport.family != rx->proto)
+       if (srx->transport.family != rx->family)
                return -EAFNOSUPPORT;
 
        switch (srx->transport.family) {
        case AF_INET:
+               if (srx->transport_len < sizeof(struct sockaddr_in))
+                       return -EINVAL;
                _debug("INET: %x @ %pI4",
                       ntohs(srx->transport.sin.sin_port),
                       &srx->transport.sin.sin_addr);
@@ -137,33 +139,33 @@ static int rxrpc_bind(struct socket *sock, struct sockaddr *saddr, int len)
 
        lock_sock(&rx->sk);
 
-       if (rx->sk.sk_state != RXRPC_UNCONNECTED) {
+       if (rx->sk.sk_state != RXRPC_UNBOUND) {
                ret = -EINVAL;
                goto error_unlock;
        }
 
        memcpy(&rx->srx, srx, sizeof(rx->srx));
 
-       /* Find or create a local transport endpoint to use */
        local = rxrpc_lookup_local(&rx->srx);
        if (IS_ERR(local)) {
                ret = PTR_ERR(local);
                goto error_unlock;
        }
 
-       rx->local = local;
-       if (srx->srx_service) {
+       if (rx->srx.srx_service) {
                write_lock_bh(&local->services_lock);
                list_for_each_entry(prx, &local->services, listen_link) {
-                       if (prx->srx.srx_service == srx->srx_service)
+                       if (prx->srx.srx_service == rx->srx.srx_service)
                                goto service_in_use;
                }
 
+               rx->local = local;
                list_add_tail(&rx->listen_link, &local->services);
                write_unlock_bh(&local->services_lock);
 
                rx->sk.sk_state = RXRPC_SERVER_BOUND;
        } else {
+               rx->local = local;
                rx->sk.sk_state = RXRPC_CLIENT_BOUND;
        }
 
@@ -172,8 +174,9 @@ static int rxrpc_bind(struct socket *sock, struct sockaddr *saddr, int len)
        return 0;
 
 service_in_use:
-       ret = -EADDRINUSE;
        write_unlock_bh(&local->services_lock);
+       rxrpc_put_local(local);
+       ret = -EADDRINUSE;
 error_unlock:
        release_sock(&rx->sk);
 error:
@@ -188,6 +191,7 @@ static int rxrpc_listen(struct socket *sock, int backlog)
 {
        struct sock *sk = sock->sk;
        struct rxrpc_sock *rx = rxrpc_sk(sk);
+       unsigned int max;
        int ret;
 
        _enter("%p,%d", rx, backlog);
@@ -195,20 +199,24 @@ static int rxrpc_listen(struct socket *sock, int backlog)
        lock_sock(&rx->sk);
 
        switch (rx->sk.sk_state) {
-       case RXRPC_UNCONNECTED:
+       case RXRPC_UNBOUND:
                ret = -EADDRNOTAVAIL;
                break;
-       case RXRPC_CLIENT_BOUND:
-       case RXRPC_CLIENT_CONNECTED:
-       default:
-               ret = -EBUSY;
-               break;
        case RXRPC_SERVER_BOUND:
                ASSERT(rx->local != NULL);
+               max = READ_ONCE(rxrpc_max_backlog);
+               ret = -EINVAL;
+               if (backlog == INT_MAX)
+                       backlog = max;
+               else if (backlog < 0 || backlog > max)
+                       break;
                sk->sk_max_ack_backlog = backlog;
                rx->sk.sk_state = RXRPC_SERVER_LISTENING;
                ret = 0;
                break;
+       default:
+               ret = -EBUSY;
+               break;
        }
 
        release_sock(&rx->sk);
@@ -219,34 +227,30 @@ static int rxrpc_listen(struct socket *sock, int backlog)
 /*
  * find a transport by address
  */
-static struct rxrpc_transport *rxrpc_name_to_transport(struct socket *sock,
-                                                      struct sockaddr *addr,
-                                                      int addr_len, int flags,
-                                                      gfp_t gfp)
+struct rxrpc_transport *
+rxrpc_name_to_transport(struct rxrpc_conn_parameters *cp,
+                       struct sockaddr *addr,
+                       int addr_len,
+                       gfp_t gfp)
 {
        struct sockaddr_rxrpc *srx = (struct sockaddr_rxrpc *) addr;
        struct rxrpc_transport *trans;
-       struct rxrpc_sock *rx = rxrpc_sk(sock->sk);
-       struct rxrpc_peer *peer;
-
-       _enter("%p,%p,%d,%d", rx, addr, addr_len, flags);
 
-       ASSERT(rx->local != NULL);
-       ASSERT(rx->sk.sk_state > RXRPC_UNCONNECTED);
+       _enter("%p,%d", addr, addr_len);
 
-       if (rx->srx.transport_type != srx->transport_type)
+       if (cp->local->srx.transport_type != srx->transport_type)
                return ERR_PTR(-ESOCKTNOSUPPORT);
-       if (rx->srx.transport.family != srx->transport.family)
+       if (cp->local->srx.transport.family != srx->transport.family)
                return ERR_PTR(-EAFNOSUPPORT);
 
        /* find a remote transport endpoint from the local one */
-       peer = rxrpc_get_peer(srx, gfp);
-       if (IS_ERR(peer))
-               return ERR_CAST(peer);
+       cp->peer = rxrpc_lookup_peer(cp->local, srx, gfp);
+       if (!cp->peer)
+               return ERR_PTR(-ENOMEM);
 
        /* find a transport */
-       trans = rxrpc_get_transport(rx->local, peer, gfp);
-       rxrpc_put_peer(peer);
+       trans = rxrpc_get_transport(cp->local, cp->peer, gfp);
+       rxrpc_put_peer(cp->peer);
        _leave(" = %p", trans);
        return trans;
 }
@@ -254,7 +258,7 @@ static struct rxrpc_transport *rxrpc_name_to_transport(struct socket *sock,
 /**
  * rxrpc_kernel_begin_call - Allow a kernel service to begin a call
  * @sock: The socket on which to make the call
- * @srx: The address of the peer to contact (defaults to socket setting)
+ * @srx: The address of the peer to contact
  * @key: The security context to use (defaults to socket setting)
  * @user_call_ID: The ID to use
  *
@@ -271,6 +275,7 @@ struct rxrpc_call *rxrpc_kernel_begin_call(struct socket *sock,
                                           unsigned long user_call_ID,
                                           gfp_t gfp)
 {
+       struct rxrpc_conn_parameters cp;
        struct rxrpc_conn_bundle *bundle;
        struct rxrpc_transport *trans;
        struct rxrpc_call *call;
@@ -280,38 +285,34 @@ struct rxrpc_call *rxrpc_kernel_begin_call(struct socket *sock,
 
        lock_sock(&rx->sk);
 
-       if (srx) {
-               trans = rxrpc_name_to_transport(sock, (struct sockaddr *) srx,
-                                               sizeof(*srx), 0, gfp);
-               if (IS_ERR(trans)) {
-                       call = ERR_CAST(trans);
-                       trans = NULL;
-                       goto out_notrans;
-               }
-       } else {
-               trans = rx->trans;
-               if (!trans) {
-                       call = ERR_PTR(-ENOTCONN);
-                       goto out_notrans;
-               }
-               atomic_inc(&trans->usage);
-       }
-
-       if (!srx)
-               srx = &rx->srx;
        if (!key)
                key = rx->key;
        if (key && !key->payload.data[0])
                key = NULL; /* a no-security key */
 
+       memset(&cp, 0, sizeof(cp));
+       cp.local                = rx->local;
+       cp.key                  = key;
+       cp.security_level       = 0;
+       cp.exclusive            = false;
+       cp.service_id           = srx->srx_service;
+
+       trans = rxrpc_name_to_transport(&cp, (struct sockaddr *)srx,
+                                       sizeof(*srx), gfp);
+       if (IS_ERR(trans)) {
+               call = ERR_CAST(trans);
+               trans = NULL;
+               goto out_notrans;
+       }
+       cp.peer = trans->peer;
+
        bundle = rxrpc_get_bundle(rx, trans, key, srx->srx_service, gfp);
        if (IS_ERR(bundle)) {
                call = ERR_CAST(bundle);
                goto out;
        }
 
-       call = rxrpc_get_client_call(rx, trans, bundle, user_call_ID, true,
-                                    gfp);
+       call = rxrpc_new_client_call(rx, &cp, trans, bundle, user_call_ID, gfp);
        rxrpc_put_bundle(trans, bundle);
 out:
        rxrpc_put_transport(trans);
@@ -367,11 +368,8 @@ EXPORT_SYMBOL(rxrpc_kernel_intercept_rx_messages);
 static int rxrpc_connect(struct socket *sock, struct sockaddr *addr,
                         int addr_len, int flags)
 {
-       struct sockaddr_rxrpc *srx = (struct sockaddr_rxrpc *) addr;
-       struct sock *sk = sock->sk;
-       struct rxrpc_transport *trans;
-       struct rxrpc_local *local;
-       struct rxrpc_sock *rx = rxrpc_sk(sk);
+       struct sockaddr_rxrpc *srx = (struct sockaddr_rxrpc *)addr;
+       struct rxrpc_sock *rx = rxrpc_sk(sock->sk);
        int ret;
 
        _enter("%p,%p,%d,%d", rx, addr, addr_len, flags);
@@ -384,45 +382,28 @@ static int rxrpc_connect(struct socket *sock, struct sockaddr *addr,
 
        lock_sock(&rx->sk);
 
+       ret = -EISCONN;
+       if (test_bit(RXRPC_SOCK_CONNECTED, &rx->flags))
+               goto error;
+
        switch (rx->sk.sk_state) {
-       case RXRPC_UNCONNECTED:
-               /* find a local transport endpoint if we don't have one already */
-               ASSERTCMP(rx->local, ==, NULL);
-               rx->srx.srx_family = AF_RXRPC;
-               rx->srx.srx_service = 0;
-               rx->srx.transport_type = srx->transport_type;
-               rx->srx.transport_len = sizeof(sa_family_t);
-               rx->srx.transport.family = srx->transport.family;
-               local = rxrpc_lookup_local(&rx->srx);
-               if (IS_ERR(local)) {
-                       release_sock(&rx->sk);
-                       return PTR_ERR(local);
-               }
-               rx->local = local;
-               rx->sk.sk_state = RXRPC_CLIENT_BOUND;
+       case RXRPC_UNBOUND:
+               rx->sk.sk_state = RXRPC_CLIENT_UNBOUND;
+       case RXRPC_CLIENT_UNBOUND:
        case RXRPC_CLIENT_BOUND:
                break;
-       case RXRPC_CLIENT_CONNECTED:
-               release_sock(&rx->sk);
-               return -EISCONN;
        default:
-               release_sock(&rx->sk);
-               return -EBUSY; /* server sockets can't connect as well */
-       }
-
-       trans = rxrpc_name_to_transport(sock, addr, addr_len, flags,
-                                       GFP_KERNEL);
-       if (IS_ERR(trans)) {
-               release_sock(&rx->sk);
-               _leave(" = %ld", PTR_ERR(trans));
-               return PTR_ERR(trans);
+               ret = -EBUSY;
+               goto error;
        }
 
-       rx->trans = trans;
-       rx->sk.sk_state = RXRPC_CLIENT_CONNECTED;
+       rx->connect_srx = *srx;
+       set_bit(RXRPC_SOCK_CONNECTED, &rx->flags);
+       ret = 0;
 
+error:
        release_sock(&rx->sk);
-       return 0;
+       return ret;
 }
 
 /*
@@ -436,7 +417,7 @@ static int rxrpc_connect(struct socket *sock, struct sockaddr *addr,
  */
 static int rxrpc_sendmsg(struct socket *sock, struct msghdr *m, size_t len)
 {
-       struct rxrpc_transport *trans;
+       struct rxrpc_local *local;
        struct rxrpc_sock *rx = rxrpc_sk(sock->sk);
        int ret;
 
@@ -453,48 +434,38 @@ static int rxrpc_sendmsg(struct socket *sock, struct msghdr *m, size_t len)
                }
        }
 
-       trans = NULL;
        lock_sock(&rx->sk);
 
-       if (m->msg_name) {
-               ret = -EISCONN;
-               trans = rxrpc_name_to_transport(sock, m->msg_name,
-                                               m->msg_namelen, 0, GFP_KERNEL);
-               if (IS_ERR(trans)) {
-                       ret = PTR_ERR(trans);
-                       trans = NULL;
-                       goto out;
-               }
-       } else {
-               trans = rx->trans;
-               if (trans)
-                       atomic_inc(&trans->usage);
-       }
-
        switch (rx->sk.sk_state) {
-       case RXRPC_SERVER_LISTENING:
-               if (!m->msg_name) {
-                       ret = rxrpc_server_sendmsg(rx, m, len);
-                       break;
+       case RXRPC_UNBOUND:
+               local = rxrpc_lookup_local(&rx->srx);
+               if (IS_ERR(local)) {
+                       ret = PTR_ERR(local);
+                       goto error_unlock;
                }
-       case RXRPC_SERVER_BOUND:
+
+               rx->local = local;
+               rx->sk.sk_state = RXRPC_CLIENT_UNBOUND;
+               /* Fall through */
+
+       case RXRPC_CLIENT_UNBOUND:
        case RXRPC_CLIENT_BOUND:
-               if (!m->msg_name) {
-                       ret = -ENOTCONN;
-                       break;
+               if (!m->msg_name &&
+                   test_bit(RXRPC_SOCK_CONNECTED, &rx->flags)) {
+                       m->msg_name = &rx->connect_srx;
+                       m->msg_namelen = sizeof(rx->connect_srx);
                }
-       case RXRPC_CLIENT_CONNECTED:
-               ret = rxrpc_client_sendmsg(rx, trans, m, len);
+       case RXRPC_SERVER_BOUND:
+       case RXRPC_SERVER_LISTENING:
+               ret = rxrpc_do_sendmsg(rx, m, len);
                break;
        default:
-               ret = -ENOTCONN;
+               ret = -EINVAL;
                break;
        }
 
-out:
+error_unlock:
        release_sock(&rx->sk);
-       if (trans)
-               rxrpc_put_transport(trans);
        _leave(" = %d", ret);
        return ret;
 }
@@ -521,7 +492,7 @@ static int rxrpc_setsockopt(struct socket *sock, int level, int optname,
                        if (optlen != 0)
                                goto error;
                        ret = -EISCONN;
-                       if (rx->sk.sk_state != RXRPC_UNCONNECTED)
+                       if (rx->sk.sk_state != RXRPC_UNBOUND)
                                goto error;
                        set_bit(RXRPC_SOCK_EXCLUSIVE_CONN, &rx->flags);
                        goto success;
@@ -531,7 +502,7 @@ static int rxrpc_setsockopt(struct socket *sock, int level, int optname,
                        if (rx->key)
                                goto error;
                        ret = -EISCONN;
-                       if (rx->sk.sk_state != RXRPC_UNCONNECTED)
+                       if (rx->sk.sk_state != RXRPC_UNBOUND)
                                goto error;
                        ret = rxrpc_request_key(rx, optval, optlen);
                        goto error;
@@ -541,7 +512,7 @@ static int rxrpc_setsockopt(struct socket *sock, int level, int optname,
                        if (rx->key)
                                goto error;
                        ret = -EISCONN;
-                       if (rx->sk.sk_state != RXRPC_UNCONNECTED)
+                       if (rx->sk.sk_state != RXRPC_UNBOUND)
                                goto error;
                        ret = rxrpc_server_keyring(rx, optval, optlen);
                        goto error;
@@ -551,7 +522,7 @@ static int rxrpc_setsockopt(struct socket *sock, int level, int optname,
                        if (optlen != sizeof(unsigned int))
                                goto error;
                        ret = -EISCONN;
-                       if (rx->sk.sk_state != RXRPC_UNCONNECTED)
+                       if (rx->sk.sk_state != RXRPC_UNBOUND)
                                goto error;
                        ret = get_user(min_sec_level,
                                       (unsigned int __user *) optval);
@@ -630,13 +601,13 @@ static int rxrpc_create(struct net *net, struct socket *sock, int protocol,
                return -ENOMEM;
 
        sock_init_data(sock, sk);
-       sk->sk_state            = RXRPC_UNCONNECTED;
+       sk->sk_state            = RXRPC_UNBOUND;
        sk->sk_write_space      = rxrpc_write_space;
-       sk->sk_max_ack_backlog  = sysctl_rxrpc_max_qlen;
+       sk->sk_max_ack_backlog  = 0;
        sk->sk_destruct         = rxrpc_sock_destructor;
 
        rx = rxrpc_sk(sk);
-       rx->proto = protocol;
+       rx->family = protocol;
        rx->calls = RB_ROOT;
 
        INIT_LIST_HEAD(&rx->listen_link);
@@ -703,14 +674,6 @@ static int rxrpc_release_sock(struct sock *sk)
                rx->conn = NULL;
        }
 
-       if (rx->bundle) {
-               rxrpc_put_bundle(rx->trans, rx->bundle);
-               rx->bundle = NULL;
-       }
-       if (rx->trans) {
-               rxrpc_put_transport(rx->trans);
-               rx->trans = NULL;
-       }
        if (rx->local) {
                rxrpc_put_local(rx->local);
                rx->local = NULL;
@@ -796,49 +759,49 @@ static int __init af_rxrpc_init(void)
                "rxrpc_call_jar", sizeof(struct rxrpc_call), 0,
                SLAB_HWCACHE_ALIGN, NULL);
        if (!rxrpc_call_jar) {
-               printk(KERN_NOTICE "RxRPC: Failed to allocate call jar\n");
+               pr_notice("Failed to allocate call jar\n");
                goto error_call_jar;
        }
 
        rxrpc_workqueue = alloc_workqueue("krxrpcd", 0, 1);
        if (!rxrpc_workqueue) {
-               printk(KERN_NOTICE "RxRPC: Failed to allocate work queue\n");
+               pr_notice("Failed to allocate work queue\n");
                goto error_work_queue;
        }
 
        ret = rxrpc_init_security();
        if (ret < 0) {
-               printk(KERN_CRIT "RxRPC: Cannot initialise security\n");
+               pr_crit("Cannot initialise security\n");
                goto error_security;
        }
 
        ret = proto_register(&rxrpc_proto, 1);
        if (ret < 0) {
-               printk(KERN_CRIT "RxRPC: Cannot register protocol\n");
+               pr_crit("Cannot register protocol\n");
                goto error_proto;
        }
 
        ret = sock_register(&rxrpc_family_ops);
        if (ret < 0) {
-               printk(KERN_CRIT "RxRPC: Cannot register socket family\n");
+               pr_crit("Cannot register socket family\n");
                goto error_sock;
        }
 
        ret = register_key_type(&key_type_rxrpc);
        if (ret < 0) {
-               printk(KERN_CRIT "RxRPC: Cannot register client key type\n");
+               pr_crit("Cannot register client key type\n");
                goto error_key_type;
        }
 
        ret = register_key_type(&key_type_rxrpc_s);
        if (ret < 0) {
-               printk(KERN_CRIT "RxRPC: Cannot register server key type\n");
+               pr_crit("Cannot register server key type\n");
                goto error_key_type_s;
        }
 
        ret = rxrpc_sysctl_init();
        if (ret < 0) {
-               printk(KERN_CRIT "RxRPC: Cannot register sysctls\n");
+               pr_crit("Cannot register sysctls\n");
                goto error_sysctls;
        }
 
@@ -881,13 +844,27 @@ static void __exit af_rxrpc_exit(void)
        rxrpc_destroy_all_calls();
        rxrpc_destroy_all_connections();
        rxrpc_destroy_all_transports();
-       rxrpc_destroy_all_peers();
-       rxrpc_destroy_all_locals();
 
        ASSERTCMP(atomic_read(&rxrpc_n_skbs), ==, 0);
 
+       /* We need to flush the scheduled work twice because the local endpoint
+        * records involve a work item in their destruction as they can only be
+        * destroyed from process context.  However, a connection may have a
+        * work item outstanding - and this will pin the local endpoint record
+        * until the connection goes away.
+        *
+        * Peers don't pin locals and calls pin sockets - which prevents the
+        * module from being unloaded - so we should only need two flushes.
+        */
        _debug("flush scheduled work");
        flush_workqueue(rxrpc_workqueue);
+       _debug("flush scheduled work 2");
+       flush_workqueue(rxrpc_workqueue);
+       _debug("synchronise RCU");
+       rcu_barrier();
+       _debug("destroy locals");
+       rxrpc_destroy_all_locals();
+
        remove_proc_entry("rxrpc_conns", init_net.proc_net);
        remove_proc_entry("rxrpc_calls", init_net.proc_net);
        destroy_workqueue(rxrpc_workqueue);