]> git.proxmox.com Git - mirror_ubuntu-jammy-kernel.git/blobdiff - drivers/infiniband/core/cm.c
IB/cm: Split cm_alloc_msg()
[mirror_ubuntu-jammy-kernel.git] / drivers / infiniband / core / cm.c
index 0ead0d223154011532402d16eb7a8e7b443e1cf6..94613275edccb4e6965d067a95a85cb9829bde5f 100644 (file)
@@ -305,8 +305,7 @@ static inline void cm_deref_id(struct cm_id_private *cm_id_priv)
                complete(&cm_id_priv->comp);
 }
 
-static int cm_alloc_msg(struct cm_id_private *cm_id_priv,
-                       struct ib_mad_send_buf **msg)
+static struct ib_mad_send_buf *cm_alloc_msg(struct cm_id_private *cm_id_priv)
 {
        struct ib_mad_agent *mad_agent;
        struct ib_mad_send_buf *m;
@@ -359,12 +358,42 @@ static int cm_alloc_msg(struct cm_id_private *cm_id_priv,
        m->retries = cm_id_priv->max_cm_retries;
 
        refcount_inc(&cm_id_priv->refcount);
+       spin_unlock_irqrestore(&cm.state_lock, flags2);
        m->context[0] = cm_id_priv;
-       *msg = m;
+       return m;
 
 out:
        spin_unlock_irqrestore(&cm.state_lock, flags2);
-       return ret;
+       return ERR_PTR(ret);
+}
+
+static struct ib_mad_send_buf *
+cm_alloc_priv_msg(struct cm_id_private *cm_id_priv)
+{
+       struct ib_mad_send_buf *msg;
+
+       lockdep_assert_held(&cm_id_priv->lock);
+
+       msg = cm_alloc_msg(cm_id_priv);
+       if (IS_ERR(msg))
+               return msg;
+       cm_id_priv->msg = msg;
+       return msg;
+}
+
+static void cm_free_priv_msg(struct ib_mad_send_buf *msg)
+{
+       struct cm_id_private *cm_id_priv = msg->context[0];
+
+       lockdep_assert_held(&cm_id_priv->lock);
+
+       if (!WARN_ON(cm_id_priv->msg != msg))
+               cm_id_priv->msg = NULL;
+
+       if (msg->ah)
+               rdma_destroy_ah(msg->ah, 0);
+       cm_deref_id(cm_id_priv);
+       ib_free_send_mad(msg);
 }
 
 static struct ib_mad_send_buf *cm_alloc_response_msg_no_ah(struct cm_port *port,
@@ -413,7 +442,7 @@ static int cm_alloc_response_msg(struct cm_port *port,
 
        ret = cm_create_response_msg_ah(port, mad_recv_wc, m);
        if (ret) {
-               cm_free_msg(m);
+               ib_free_send_mad(m);
                return ret;
        }
 
@@ -421,6 +450,13 @@ static int cm_alloc_response_msg(struct cm_port *port,
        return 0;
 }
 
+static void cm_free_response_msg(struct ib_mad_send_buf *msg)
+{
+       if (msg->ah)
+               rdma_destroy_ah(msg->ah, 0);
+       ib_free_send_mad(msg);
+}
+
 static void *cm_copy_private_data(const void *private_data, u8 private_data_len)
 {
        void *data;
@@ -1501,6 +1537,7 @@ int ib_send_cm_req(struct ib_cm_id *cm_id,
                   struct ib_cm_req_param *param)
 {
        struct cm_id_private *cm_id_priv;
+       struct ib_mad_send_buf *msg;
        struct cm_req_msg *req_msg;
        unsigned long flags;
        int ret;
@@ -1552,33 +1589,36 @@ int ib_send_cm_req(struct ib_cm_id *cm_id,
        cm_id_priv->pkey = param->primary_path->pkey;
        cm_id_priv->qp_type = param->qp_type;
 
-       ret = cm_alloc_msg(cm_id_priv, &cm_id_priv->msg);
-       if (ret)
-               goto out;
+       spin_lock_irqsave(&cm_id_priv->lock, flags);
+       msg = cm_alloc_priv_msg(cm_id_priv);
+       if (IS_ERR(msg)) {
+               ret = PTR_ERR(msg);
+               goto out_unlock;
+       }
 
-       req_msg = (struct cm_req_msg *) cm_id_priv->msg->mad;
+       req_msg = (struct cm_req_msg *)msg->mad;
        cm_format_req(req_msg, cm_id_priv, param);
        cm_id_priv->tid = req_msg->hdr.tid;
-       cm_id_priv->msg->timeout_ms = cm_id_priv->timeout_ms;
-       cm_id_priv->msg->context[1] = (void *) (unsigned long) IB_CM_REQ_SENT;
+       msg->timeout_ms = cm_id_priv->timeout_ms;
+       msg->context[1] = (void *)(unsigned long)IB_CM_REQ_SENT;
 
        cm_id_priv->local_qpn = cpu_to_be32(IBA_GET(CM_REQ_LOCAL_QPN, req_msg));
        cm_id_priv->rq_psn = cpu_to_be32(IBA_GET(CM_REQ_STARTING_PSN, req_msg));
 
        trace_icm_send_req(&cm_id_priv->id);
-       spin_lock_irqsave(&cm_id_priv->lock, flags);
-       ret = ib_post_send_mad(cm_id_priv->msg, NULL);
-       if (ret) {
-               spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-               goto error2;
-       }
+       ret = ib_post_send_mad(msg, NULL);
+       if (ret)
+               goto out_free;
        BUG_ON(cm_id->state != IB_CM_IDLE);
        cm_id->state = IB_CM_REQ_SENT;
        spin_unlock_irqrestore(&cm_id_priv->lock, flags);
        return 0;
-
-error2:        cm_free_msg(cm_id_priv->msg);
-out:   return ret;
+out_free:
+       cm_free_priv_msg(msg);
+out_unlock:
+       spin_unlock_irqrestore(&cm_id_priv->lock, flags);
+out:
+       return ret;
 }
 EXPORT_SYMBOL(ib_send_cm_req);
 
@@ -1618,7 +1658,7 @@ static int cm_issue_rej(struct cm_port *port,
                IBA_GET(CM_REJ_REMOTE_COMM_ID, rcv_msg));
        ret = ib_post_send_mad(msg, NULL);
        if (ret)
-               cm_free_msg(msg);
+               cm_free_response_msg(msg);
 
        return ret;
 }
@@ -1974,7 +2014,7 @@ static void cm_dup_req_handler(struct cm_work *work,
        return;
 
 unlock:        spin_unlock_irq(&cm_id_priv->lock);
-free:  cm_free_msg(msg);
+free:  cm_free_response_msg(msg);
 }
 
 static struct cm_id_private *cm_match_req(struct cm_work *work,
@@ -2283,9 +2323,11 @@ int ib_send_cm_rep(struct ib_cm_id *cm_id,
                goto out;
        }
 
-       ret = cm_alloc_msg(cm_id_priv, &msg);
-       if (ret)
+       msg = cm_alloc_priv_msg(cm_id_priv);
+       if (IS_ERR(msg)) {
+               ret = PTR_ERR(msg);
                goto out;
+       }
 
        rep_msg = (struct cm_rep_msg *) msg->mad;
        cm_format_rep(rep_msg, cm_id_priv, param);
@@ -2294,14 +2336,10 @@ int ib_send_cm_rep(struct ib_cm_id *cm_id,
 
        trace_icm_send_rep(cm_id);
        ret = ib_post_send_mad(msg, NULL);
-       if (ret) {
-               spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-               cm_free_msg(msg);
-               return ret;
-       }
+       if (ret)
+               goto out_free;
 
        cm_id->state = IB_CM_REP_SENT;
-       cm_id_priv->msg = msg;
        cm_id_priv->initiator_depth = param->initiator_depth;
        cm_id_priv->responder_resources = param->responder_resources;
        cm_id_priv->rq_psn = cpu_to_be32(IBA_GET(CM_REP_STARTING_PSN, rep_msg));
@@ -2309,8 +2347,13 @@ int ib_send_cm_rep(struct ib_cm_id *cm_id,
                  "IBTA declares QPN to be 24 bits, but it is 0x%X\n",
                  param->qp_num);
        cm_id_priv->local_qpn = cpu_to_be32(param->qp_num & 0xFFFFFF);
+       spin_unlock_irqrestore(&cm_id_priv->lock, flags);
+       return 0;
 
-out:   spin_unlock_irqrestore(&cm_id_priv->lock, flags);
+out_free:
+       cm_free_priv_msg(msg);
+out:
+       spin_unlock_irqrestore(&cm_id_priv->lock, flags);
        return ret;
 }
 EXPORT_SYMBOL(ib_send_cm_rep);
@@ -2357,9 +2400,11 @@ int ib_send_cm_rtu(struct ib_cm_id *cm_id,
                goto error;
        }
 
-       ret = cm_alloc_msg(cm_id_priv, &msg);
-       if (ret)
+       msg = cm_alloc_msg(cm_id_priv);
+       if (IS_ERR(msg)) {
+               ret = PTR_ERR(msg);
                goto error;
+       }
 
        cm_format_rtu((struct cm_rtu_msg *) msg->mad, cm_id_priv,
                      private_data, private_data_len);
@@ -2453,7 +2498,7 @@ static void cm_dup_rep_handler(struct cm_work *work)
        goto deref;
 
 unlock:        spin_unlock_irq(&cm_id_priv->lock);
-free:  cm_free_msg(msg);
+free:  cm_free_response_msg(msg);
 deref: cm_deref_id(cm_id_priv);
 }
 
@@ -2657,10 +2702,10 @@ static int cm_send_dreq_locked(struct cm_id_private *cm_id_priv,
            cm_id_priv->id.lap_state == IB_CM_MRA_LAP_RCVD)
                ib_cancel_mad(cm_id_priv->av.port->mad_agent, cm_id_priv->msg);
 
-       ret = cm_alloc_msg(cm_id_priv, &msg);
-       if (ret) {
+       msg = cm_alloc_priv_msg(cm_id_priv);
+       if (IS_ERR(msg)) {
                cm_enter_timewait(cm_id_priv);
-               return ret;
+               return PTR_ERR(msg);
        }
 
        cm_format_dreq((struct cm_dreq_msg *) msg->mad, cm_id_priv,
@@ -2672,12 +2717,11 @@ static int cm_send_dreq_locked(struct cm_id_private *cm_id_priv,
        ret = ib_post_send_mad(msg, NULL);
        if (ret) {
                cm_enter_timewait(cm_id_priv);
-               cm_free_msg(msg);
+               cm_free_priv_msg(msg);
                return ret;
        }
 
        cm_id_priv->id.state = IB_CM_DREQ_SENT;
-       cm_id_priv->msg = msg;
        return 0;
 }
 
@@ -2732,9 +2776,9 @@ static int cm_send_drep_locked(struct cm_id_private *cm_id_priv,
        cm_set_private_data(cm_id_priv, private_data, private_data_len);
        cm_enter_timewait(cm_id_priv);
 
-       ret = cm_alloc_msg(cm_id_priv, &msg);
-       if (ret)
-               return ret;
+       msg = cm_alloc_msg(cm_id_priv);
+       if (IS_ERR(msg))
+               return PTR_ERR(msg);
 
        cm_format_drep((struct cm_drep_msg *) msg->mad, cm_id_priv,
                       private_data, private_data_len);
@@ -2794,7 +2838,7 @@ static int cm_issue_drep(struct cm_port *port,
                IBA_GET(CM_DREQ_REMOTE_COMM_ID, dreq_msg));
        ret = ib_post_send_mad(msg, NULL);
        if (ret)
-               cm_free_msg(msg);
+               cm_free_response_msg(msg);
 
        return ret;
 }
@@ -2853,7 +2897,7 @@ static int cm_dreq_handler(struct cm_work *work)
 
                if (cm_create_response_msg_ah(work->port, work->mad_recv_wc, msg) ||
                    ib_post_send_mad(msg, NULL))
-                       cm_free_msg(msg);
+                       cm_free_response_msg(msg);
                goto deref;
        case IB_CM_DREQ_RCVD:
                atomic_long_inc(&work->port->counter_group[CM_RECV_DUPLICATES].
@@ -2927,9 +2971,9 @@ static int cm_send_rej_locked(struct cm_id_private *cm_id_priv,
        case IB_CM_REP_RCVD:
        case IB_CM_MRA_REP_SENT:
                cm_reset_to_idle(cm_id_priv);
-               ret = cm_alloc_msg(cm_id_priv, &msg);
-               if (ret)
-                       return ret;
+               msg = cm_alloc_msg(cm_id_priv);
+               if (IS_ERR(msg))
+                       return PTR_ERR(msg);
                cm_format_rej((struct cm_rej_msg *)msg->mad, cm_id_priv, reason,
                              ari, ari_length, private_data, private_data_len,
                              state);
@@ -2937,9 +2981,9 @@ static int cm_send_rej_locked(struct cm_id_private *cm_id_priv,
        case IB_CM_REP_SENT:
        case IB_CM_MRA_REP_RCVD:
                cm_enter_timewait(cm_id_priv);
-               ret = cm_alloc_msg(cm_id_priv, &msg);
-               if (ret)
-                       return ret;
+               msg = cm_alloc_msg(cm_id_priv);
+               if (IS_ERR(msg))
+                       return PTR_ERR(msg);
                cm_format_rej((struct cm_rej_msg *)msg->mad, cm_id_priv, reason,
                              ari, ari_length, private_data, private_data_len,
                              state);
@@ -3117,13 +3161,15 @@ int ib_send_cm_mra(struct ib_cm_id *cm_id,
        default:
                trace_icm_send_mra_unknown_err(&cm_id_priv->id);
                ret = -EINVAL;
-               goto error1;
+               goto error_unlock;
        }
 
        if (!(service_timeout & IB_CM_MRA_FLAG_DELAY)) {
-               ret = cm_alloc_msg(cm_id_priv, &msg);
-               if (ret)
-                       goto error1;
+               msg = cm_alloc_msg(cm_id_priv);
+               if (IS_ERR(msg)) {
+                       ret = PTR_ERR(msg);
+                       goto error_unlock;
+               }
 
                cm_format_mra((struct cm_mra_msg *) msg->mad, cm_id_priv,
                              msg_response, service_timeout,
@@ -3131,7 +3177,7 @@ int ib_send_cm_mra(struct ib_cm_id *cm_id,
                trace_icm_send_mra(cm_id);
                ret = ib_post_send_mad(msg, NULL);
                if (ret)
-                       goto error2;
+                       goto error_free_msg;
        }
 
        cm_id->state = cm_state;
@@ -3141,13 +3187,11 @@ int ib_send_cm_mra(struct ib_cm_id *cm_id,
        spin_unlock_irqrestore(&cm_id_priv->lock, flags);
        return 0;
 
-error1:        spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-       kfree(data);
-       return ret;
-
-error2:        spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-       kfree(data);
+error_free_msg:
        cm_free_msg(msg);
+error_unlock:
+       spin_unlock_irqrestore(&cm_id_priv->lock, flags);
+       kfree(data);
        return ret;
 }
 EXPORT_SYMBOL(ib_send_cm_mra);
@@ -3343,7 +3387,7 @@ static int cm_lap_handler(struct cm_work *work)
 
                if (cm_create_response_msg_ah(work->port, work->mad_recv_wc, msg) ||
                    ib_post_send_mad(msg, NULL))
-                       cm_free_msg(msg);
+                       cm_free_response_msg(msg);
                goto deref;
        case IB_CM_LAP_RCVD:
                atomic_long_inc(&work->port->counter_group[CM_RECV_DUPLICATES].
@@ -3483,38 +3527,41 @@ int ib_send_cm_sidr_req(struct ib_cm_id *cm_id,
                                 &cm_id_priv->av,
                                 cm_id_priv);
        if (ret)
-               goto out;
+               return ret;
 
        cm_id->service_id = param->service_id;
        cm_id->service_mask = ~cpu_to_be64(0);
        cm_id_priv->timeout_ms = param->timeout_ms;
        cm_id_priv->max_cm_retries = param->max_cm_retries;
-       ret = cm_alloc_msg(cm_id_priv, &msg);
-       if (ret)
-               goto out;
-
-       cm_format_sidr_req((struct cm_sidr_req_msg *) msg->mad, cm_id_priv,
-                          param);
-       msg->timeout_ms = cm_id_priv->timeout_ms;
-       msg->context[1] = (void *) (unsigned long) IB_CM_SIDR_REQ_SENT;
 
        spin_lock_irqsave(&cm_id_priv->lock, flags);
-       if (cm_id->state == IB_CM_IDLE) {
-               trace_icm_send_sidr_req(&cm_id_priv->id);
-               ret = ib_post_send_mad(msg, NULL);
-       } else {
+       if (cm_id->state != IB_CM_IDLE) {
                ret = -EINVAL;
+               goto out_unlock;
        }
 
-       if (ret) {
-               spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-               cm_free_msg(msg);
-               goto out;
+       msg = cm_alloc_priv_msg(cm_id_priv);
+       if (IS_ERR(msg)) {
+               ret = PTR_ERR(msg);
+               goto out_unlock;
        }
+
+       cm_format_sidr_req((struct cm_sidr_req_msg *)msg->mad, cm_id_priv,
+                          param);
+       msg->timeout_ms = cm_id_priv->timeout_ms;
+       msg->context[1] = (void *)(unsigned long)IB_CM_SIDR_REQ_SENT;
+
+       trace_icm_send_sidr_req(&cm_id_priv->id);
+       ret = ib_post_send_mad(msg, NULL);
+       if (ret)
+               goto out_free;
        cm_id->state = IB_CM_SIDR_REQ_SENT;
-       cm_id_priv->msg = msg;
        spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-out:
+       return 0;
+out_free:
+       cm_free_priv_msg(msg);
+out_unlock:
+       spin_unlock_irqrestore(&cm_id_priv->lock, flags);
        return ret;
 }
 EXPORT_SYMBOL(ib_send_cm_sidr_req);
@@ -3661,9 +3708,9 @@ static int cm_send_sidr_rep_locked(struct cm_id_private *cm_id_priv,
        if (cm_id_priv->id.state != IB_CM_SIDR_REQ_RCVD)
                return -EINVAL;
 
-       ret = cm_alloc_msg(cm_id_priv, &msg);
-       if (ret)
-               return ret;
+       msg = cm_alloc_msg(cm_id_priv);
+       if (IS_ERR(msg))
+               return PTR_ERR(msg);
 
        cm_format_sidr_rep((struct cm_sidr_rep_msg *) msg->mad, cm_id_priv,
                           param);