]> git.proxmox.com Git - mirror_ubuntu-jammy-kernel.git/blobdiff - drivers/infiniband/core/cm.c
IB/cm: Split cm_alloc_msg()
[mirror_ubuntu-jammy-kernel.git] / drivers / infiniband / core / cm.c
index 3d194bb608405de1f8f925395311d13c46b9ba17..94613275edccb4e6965d067a95a85cb9829bde5f 100644 (file)
@@ -202,7 +202,7 @@ static struct attribute *cm_counter_default_attrs[] = {
 struct cm_port {
        struct cm_device *cm_dev;
        struct ib_mad_agent *mad_agent;
-       u8 port_num;
+       u32 port_num;
        struct list_head cm_priv_prim_list;
        struct list_head cm_priv_altr_list;
        struct cm_counter_group counter_group[CM_COUNTER_GROUPS];
@@ -255,7 +255,8 @@ struct cm_id_private {
        struct completion comp;
        refcount_t refcount;
        /* Number of clients sharing this ib_cm_id. Only valid for listeners.
-        * Protected by the cm.lock spinlock. */
+        * Protected by the cm.lock spinlock.
+        */
        int listen_sharecount;
        struct rcu_head rcu;
 
@@ -304,8 +305,7 @@ static inline void cm_deref_id(struct cm_id_private *cm_id_priv)
                complete(&cm_id_priv->comp);
 }
 
-static int cm_alloc_msg(struct cm_id_private *cm_id_priv,
-                       struct ib_mad_send_buf **msg)
+static struct ib_mad_send_buf *cm_alloc_msg(struct cm_id_private *cm_id_priv)
 {
        struct ib_mad_agent *mad_agent;
        struct ib_mad_send_buf *m;
@@ -358,12 +358,42 @@ static int cm_alloc_msg(struct cm_id_private *cm_id_priv,
        m->retries = cm_id_priv->max_cm_retries;
 
        refcount_inc(&cm_id_priv->refcount);
+       spin_unlock_irqrestore(&cm.state_lock, flags2);
        m->context[0] = cm_id_priv;
-       *msg = m;
+       return m;
 
 out:
        spin_unlock_irqrestore(&cm.state_lock, flags2);
-       return ret;
+       return ERR_PTR(ret);
+}
+
+static struct ib_mad_send_buf *
+cm_alloc_priv_msg(struct cm_id_private *cm_id_priv)
+{
+       struct ib_mad_send_buf *msg;
+
+       lockdep_assert_held(&cm_id_priv->lock);
+
+       msg = cm_alloc_msg(cm_id_priv);
+       if (IS_ERR(msg))
+               return msg;
+       cm_id_priv->msg = msg;
+       return msg;
+}
+
+static void cm_free_priv_msg(struct ib_mad_send_buf *msg)
+{
+       struct cm_id_private *cm_id_priv = msg->context[0];
+
+       lockdep_assert_held(&cm_id_priv->lock);
+
+       if (!WARN_ON(cm_id_priv->msg != msg))
+               cm_id_priv->msg = NULL;
+
+       if (msg->ah)
+               rdma_destroy_ah(msg->ah, 0);
+       cm_deref_id(cm_id_priv);
+       ib_free_send_mad(msg);
 }
 
 static struct ib_mad_send_buf *cm_alloc_response_msg_no_ah(struct cm_port *port,
@@ -412,7 +442,7 @@ static int cm_alloc_response_msg(struct cm_port *port,
 
        ret = cm_create_response_msg_ah(port, mad_recv_wc, m);
        if (ret) {
-               cm_free_msg(m);
+               ib_free_send_mad(m);
                return ret;
        }
 
@@ -420,8 +450,14 @@ static int cm_alloc_response_msg(struct cm_port *port,
        return 0;
 }
 
-static void * cm_copy_private_data(const void *private_data,
-                                  u8 private_data_len)
+static void cm_free_response_msg(struct ib_mad_send_buf *msg)
+{
+       if (msg->ah)
+               rdma_destroy_ah(msg->ah, 0);
+       ib_free_send_mad(msg);
+}
+
+static void *cm_copy_private_data(const void *private_data, u8 private_data_len)
 {
        void *data;
 
@@ -680,8 +716,8 @@ static struct cm_id_private *cm_insert_listen(struct cm_id_private *cm_id_priv,
        return cm_id_priv;
 }
 
-static struct cm_id_private * cm_find_listen(struct ib_device *device,
-                                            __be64 service_id)
+static struct cm_id_private *cm_find_listen(struct ib_device *device,
+                                           __be64 service_id)
 {
        struct rb_node *node = cm.listen_service_table.rb_node;
        struct cm_id_private *cm_id_priv;
@@ -708,8 +744,8 @@ static struct cm_id_private * cm_find_listen(struct ib_device *device,
        return NULL;
 }
 
-static struct cm_timewait_info * cm_insert_remote_id(struct cm_timewait_info
-                                                    *timewait_info)
+static struct cm_timewait_info *
+cm_insert_remote_id(struct cm_timewait_info *timewait_info)
 {
        struct rb_node **link = &cm.remote_id_table.rb_node;
        struct rb_node *parent = NULL;
@@ -767,8 +803,8 @@ static struct cm_id_private *cm_find_remote_id(__be64 remote_ca_guid,
        return res;
 }
 
-static struct cm_timewait_info * cm_insert_remote_qpn(struct cm_timewait_info
-                                                     *timewait_info)
+static struct cm_timewait_info *
+cm_insert_remote_qpn(struct cm_timewait_info *timewait_info)
 {
        struct rb_node **link = &cm.remote_qp_table.rb_node;
        struct rb_node *parent = NULL;
@@ -797,8 +833,8 @@ static struct cm_timewait_info * cm_insert_remote_qpn(struct cm_timewait_info
        return NULL;
 }
 
-static struct cm_id_private * cm_insert_remote_sidr(struct cm_id_private
-                                                   *cm_id_priv)
+static struct cm_id_private *
+cm_insert_remote_sidr(struct cm_id_private *cm_id_priv)
 {
        struct rb_node **link = &cm.remote_sidr_table.rb_node;
        struct rb_node *parent = NULL;
@@ -897,7 +933,7 @@ struct ib_cm_id *ib_create_cm_id(struct ib_device *device,
 }
 EXPORT_SYMBOL(ib_create_cm_id);
 
-static struct cm_work * cm_dequeue_work(struct cm_id_private *cm_id_priv)
+static struct cm_work *cm_dequeue_work(struct cm_id_private *cm_id_priv)
 {
        struct cm_work *work;
 
@@ -986,7 +1022,7 @@ static void cm_remove_remote(struct cm_id_private *cm_id_priv)
        }
 }
 
-static struct cm_timewait_info * cm_create_timewait_info(__be32 local_id)
+static struct cm_timewait_info *cm_create_timewait_info(__be32 local_id)
 {
        struct cm_timewait_info *timewait_info;
 
@@ -1501,6 +1537,7 @@ int ib_send_cm_req(struct ib_cm_id *cm_id,
                   struct ib_cm_req_param *param)
 {
        struct cm_id_private *cm_id_priv;
+       struct ib_mad_send_buf *msg;
        struct cm_req_msg *req_msg;
        unsigned long flags;
        int ret;
@@ -1552,33 +1589,36 @@ int ib_send_cm_req(struct ib_cm_id *cm_id,
        cm_id_priv->pkey = param->primary_path->pkey;
        cm_id_priv->qp_type = param->qp_type;
 
-       ret = cm_alloc_msg(cm_id_priv, &cm_id_priv->msg);
-       if (ret)
-               goto out;
+       spin_lock_irqsave(&cm_id_priv->lock, flags);
+       msg = cm_alloc_priv_msg(cm_id_priv);
+       if (IS_ERR(msg)) {
+               ret = PTR_ERR(msg);
+               goto out_unlock;
+       }
 
-       req_msg = (struct cm_req_msg *) cm_id_priv->msg->mad;
+       req_msg = (struct cm_req_msg *)msg->mad;
        cm_format_req(req_msg, cm_id_priv, param);
        cm_id_priv->tid = req_msg->hdr.tid;
-       cm_id_priv->msg->timeout_ms = cm_id_priv->timeout_ms;
-       cm_id_priv->msg->context[1] = (void *) (unsigned long) IB_CM_REQ_SENT;
+       msg->timeout_ms = cm_id_priv->timeout_ms;
+       msg->context[1] = (void *)(unsigned long)IB_CM_REQ_SENT;
 
        cm_id_priv->local_qpn = cpu_to_be32(IBA_GET(CM_REQ_LOCAL_QPN, req_msg));
        cm_id_priv->rq_psn = cpu_to_be32(IBA_GET(CM_REQ_STARTING_PSN, req_msg));
 
        trace_icm_send_req(&cm_id_priv->id);
-       spin_lock_irqsave(&cm_id_priv->lock, flags);
-       ret = ib_post_send_mad(cm_id_priv->msg, NULL);
-       if (ret) {
-               spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-               goto error2;
-       }
+       ret = ib_post_send_mad(msg, NULL);
+       if (ret)
+               goto out_free;
        BUG_ON(cm_id->state != IB_CM_IDLE);
        cm_id->state = IB_CM_REQ_SENT;
        spin_unlock_irqrestore(&cm_id_priv->lock, flags);
        return 0;
-
-error2:        cm_free_msg(cm_id_priv->msg);
-out:   return ret;
+out_free:
+       cm_free_priv_msg(msg);
+out_unlock:
+       spin_unlock_irqrestore(&cm_id_priv->lock, flags);
+out:
+       return ret;
 }
 EXPORT_SYMBOL(ib_send_cm_req);
 
@@ -1618,7 +1658,7 @@ static int cm_issue_rej(struct cm_port *port,
                IBA_GET(CM_REJ_REMOTE_COMM_ID, rcv_msg));
        ret = ib_post_send_mad(msg, NULL);
        if (ret)
-               cm_free_msg(msg);
+               cm_free_response_msg(msg);
 
        return ret;
 }
@@ -1631,7 +1671,7 @@ static bool cm_req_has_alt_path(struct cm_req_msg *req_msg)
                                               req_msg))));
 }
 
-static void cm_path_set_rec_type(struct ib_device *ib_device, u8 port_num,
+static void cm_path_set_rec_type(struct ib_device *ib_device, u32 port_num,
                                 struct sa_path_rec *path, union ib_gid *gid)
 {
        if (ib_is_opa_gid(gid) && rdma_cap_opa_ah(ib_device, port_num))
@@ -1750,7 +1790,7 @@ static void cm_format_paths_from_req(struct cm_req_msg *req_msg,
 static u16 cm_get_bth_pkey(struct cm_work *work)
 {
        struct ib_device *ib_dev = work->port->cm_dev->ib_device;
-       u8 port_num = work->port->port_num;
+       u32 port_num = work->port->port_num;
        u16 pkey_index = work->mad_recv_wc->wc->pkey_index;
        u16 pkey;
        int ret;
@@ -1778,7 +1818,7 @@ static void cm_opa_to_ib_sgid(struct cm_work *work,
                              struct sa_path_rec *path)
 {
        struct ib_device *dev = work->port->cm_dev->ib_device;
-       u8 port_num = work->port->port_num;
+       u32 port_num = work->port->port_num;
 
        if (rdma_cap_opa_ah(dev, port_num) &&
            (ib_is_opa_gid(&path->sgid))) {
@@ -1974,11 +2014,11 @@ static void cm_dup_req_handler(struct cm_work *work,
        return;
 
 unlock:        spin_unlock_irq(&cm_id_priv->lock);
-free:  cm_free_msg(msg);
+free:  cm_free_response_msg(msg);
 }
 
-static struct cm_id_private * cm_match_req(struct cm_work *work,
-                                          struct cm_id_private *cm_id_priv)
+static struct cm_id_private *cm_match_req(struct cm_work *work,
+                                         struct cm_id_private *cm_id_priv)
 {
        struct cm_id_private *listen_cm_id_priv, *cur_cm_id_priv;
        struct cm_timewait_info *timewait_info;
@@ -2138,20 +2178,17 @@ static int cm_req_handler(struct cm_work *work)
                goto destroy;
        }
 
-       cm_process_routed_req(req_msg, work->mad_recv_wc->wc);
-
        memset(&work->path[0], 0, sizeof(work->path[0]));
        if (cm_req_has_alt_path(req_msg))
                memset(&work->path[1], 0, sizeof(work->path[1]));
        grh = rdma_ah_read_grh(&cm_id_priv->av.ah_attr);
        gid_attr = grh->sgid_attr;
 
-       if (gid_attr &&
-           rdma_protocol_roce(work->port->cm_dev->ib_device,
-                              work->port->port_num)) {
+       if (cm_id_priv->av.ah_attr.type == RDMA_AH_ATTR_TYPE_ROCE) {
                work->path[0].rec_type =
                        sa_conv_gid_to_pathrec_type(gid_attr->gid_type);
        } else {
+               cm_process_routed_req(req_msg, work->mad_recv_wc->wc);
                cm_path_set_rec_type(
                        work->port->cm_dev->ib_device, work->port->port_num,
                        &work->path[0],
@@ -2286,9 +2323,11 @@ int ib_send_cm_rep(struct ib_cm_id *cm_id,
                goto out;
        }
 
-       ret = cm_alloc_msg(cm_id_priv, &msg);
-       if (ret)
+       msg = cm_alloc_priv_msg(cm_id_priv);
+       if (IS_ERR(msg)) {
+               ret = PTR_ERR(msg);
                goto out;
+       }
 
        rep_msg = (struct cm_rep_msg *) msg->mad;
        cm_format_rep(rep_msg, cm_id_priv, param);
@@ -2297,14 +2336,10 @@ int ib_send_cm_rep(struct ib_cm_id *cm_id,
 
        trace_icm_send_rep(cm_id);
        ret = ib_post_send_mad(msg, NULL);
-       if (ret) {
-               spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-               cm_free_msg(msg);
-               return ret;
-       }
+       if (ret)
+               goto out_free;
 
        cm_id->state = IB_CM_REP_SENT;
-       cm_id_priv->msg = msg;
        cm_id_priv->initiator_depth = param->initiator_depth;
        cm_id_priv->responder_resources = param->responder_resources;
        cm_id_priv->rq_psn = cpu_to_be32(IBA_GET(CM_REP_STARTING_PSN, rep_msg));
@@ -2312,8 +2347,13 @@ int ib_send_cm_rep(struct ib_cm_id *cm_id,
                  "IBTA declares QPN to be 24 bits, but it is 0x%X\n",
                  param->qp_num);
        cm_id_priv->local_qpn = cpu_to_be32(param->qp_num & 0xFFFFFF);
+       spin_unlock_irqrestore(&cm_id_priv->lock, flags);
+       return 0;
 
-out:   spin_unlock_irqrestore(&cm_id_priv->lock, flags);
+out_free:
+       cm_free_priv_msg(msg);
+out:
+       spin_unlock_irqrestore(&cm_id_priv->lock, flags);
        return ret;
 }
 EXPORT_SYMBOL(ib_send_cm_rep);
@@ -2360,9 +2400,11 @@ int ib_send_cm_rtu(struct ib_cm_id *cm_id,
                goto error;
        }
 
-       ret = cm_alloc_msg(cm_id_priv, &msg);
-       if (ret)
+       msg = cm_alloc_msg(cm_id_priv);
+       if (IS_ERR(msg)) {
+               ret = PTR_ERR(msg);
                goto error;
+       }
 
        cm_format_rtu((struct cm_rtu_msg *) msg->mad, cm_id_priv,
                      private_data, private_data_len);
@@ -2456,7 +2498,7 @@ static void cm_dup_rep_handler(struct cm_work *work)
        goto deref;
 
 unlock:        spin_unlock_irq(&cm_id_priv->lock);
-free:  cm_free_msg(msg);
+free:  cm_free_response_msg(msg);
 deref: cm_deref_id(cm_id_priv);
 }
 
@@ -2660,10 +2702,10 @@ static int cm_send_dreq_locked(struct cm_id_private *cm_id_priv,
            cm_id_priv->id.lap_state == IB_CM_MRA_LAP_RCVD)
                ib_cancel_mad(cm_id_priv->av.port->mad_agent, cm_id_priv->msg);
 
-       ret = cm_alloc_msg(cm_id_priv, &msg);
-       if (ret) {
+       msg = cm_alloc_priv_msg(cm_id_priv);
+       if (IS_ERR(msg)) {
                cm_enter_timewait(cm_id_priv);
-               return ret;
+               return PTR_ERR(msg);
        }
 
        cm_format_dreq((struct cm_dreq_msg *) msg->mad, cm_id_priv,
@@ -2675,12 +2717,11 @@ static int cm_send_dreq_locked(struct cm_id_private *cm_id_priv,
        ret = ib_post_send_mad(msg, NULL);
        if (ret) {
                cm_enter_timewait(cm_id_priv);
-               cm_free_msg(msg);
+               cm_free_priv_msg(msg);
                return ret;
        }
 
        cm_id_priv->id.state = IB_CM_DREQ_SENT;
-       cm_id_priv->msg = msg;
        return 0;
 }
 
@@ -2735,9 +2776,9 @@ static int cm_send_drep_locked(struct cm_id_private *cm_id_priv,
        cm_set_private_data(cm_id_priv, private_data, private_data_len);
        cm_enter_timewait(cm_id_priv);
 
-       ret = cm_alloc_msg(cm_id_priv, &msg);
-       if (ret)
-               return ret;
+       msg = cm_alloc_msg(cm_id_priv);
+       if (IS_ERR(msg))
+               return PTR_ERR(msg);
 
        cm_format_drep((struct cm_drep_msg *) msg->mad, cm_id_priv,
                       private_data, private_data_len);
@@ -2797,7 +2838,7 @@ static int cm_issue_drep(struct cm_port *port,
                IBA_GET(CM_DREQ_REMOTE_COMM_ID, dreq_msg));
        ret = ib_post_send_mad(msg, NULL);
        if (ret)
-               cm_free_msg(msg);
+               cm_free_response_msg(msg);
 
        return ret;
 }
@@ -2856,7 +2897,7 @@ static int cm_dreq_handler(struct cm_work *work)
 
                if (cm_create_response_msg_ah(work->port, work->mad_recv_wc, msg) ||
                    ib_post_send_mad(msg, NULL))
-                       cm_free_msg(msg);
+                       cm_free_response_msg(msg);
                goto deref;
        case IB_CM_DREQ_RCVD:
                atomic_long_inc(&work->port->counter_group[CM_RECV_DUPLICATES].
@@ -2930,9 +2971,9 @@ static int cm_send_rej_locked(struct cm_id_private *cm_id_priv,
        case IB_CM_REP_RCVD:
        case IB_CM_MRA_REP_SENT:
                cm_reset_to_idle(cm_id_priv);
-               ret = cm_alloc_msg(cm_id_priv, &msg);
-               if (ret)
-                       return ret;
+               msg = cm_alloc_msg(cm_id_priv);
+               if (IS_ERR(msg))
+                       return PTR_ERR(msg);
                cm_format_rej((struct cm_rej_msg *)msg->mad, cm_id_priv, reason,
                              ari, ari_length, private_data, private_data_len,
                              state);
@@ -2940,9 +2981,9 @@ static int cm_send_rej_locked(struct cm_id_private *cm_id_priv,
        case IB_CM_REP_SENT:
        case IB_CM_MRA_REP_RCVD:
                cm_enter_timewait(cm_id_priv);
-               ret = cm_alloc_msg(cm_id_priv, &msg);
-               if (ret)
-                       return ret;
+               msg = cm_alloc_msg(cm_id_priv);
+               if (IS_ERR(msg))
+                       return PTR_ERR(msg);
                cm_format_rej((struct cm_rej_msg *)msg->mad, cm_id_priv, reason,
                              ari, ari_length, private_data, private_data_len,
                              state);
@@ -2993,7 +3034,7 @@ static void cm_format_rej_event(struct cm_work *work)
                IBA_GET_MEM_PTR(CM_REJ_PRIVATE_DATA, rej_msg);
 }
 
-static struct cm_id_private * cm_acquire_rejected_id(struct cm_rej_msg *rej_msg)
+static struct cm_id_private *cm_acquire_rejected_id(struct cm_rej_msg *rej_msg)
 {
        struct cm_id_private *cm_id_priv;
        __be32 remote_id;
@@ -3098,7 +3139,7 @@ int ib_send_cm_mra(struct ib_cm_id *cm_id,
        cm_id_priv = container_of(cm_id, struct cm_id_private, id);
 
        spin_lock_irqsave(&cm_id_priv->lock, flags);
-       switch(cm_id_priv->id.state) {
+       switch (cm_id_priv->id.state) {
        case IB_CM_REQ_RCVD:
                cm_state = IB_CM_MRA_REQ_SENT;
                lap_state = cm_id->lap_state;
@@ -3120,13 +3161,15 @@ int ib_send_cm_mra(struct ib_cm_id *cm_id,
        default:
                trace_icm_send_mra_unknown_err(&cm_id_priv->id);
                ret = -EINVAL;
-               goto error1;
+               goto error_unlock;
        }
 
        if (!(service_timeout & IB_CM_MRA_FLAG_DELAY)) {
-               ret = cm_alloc_msg(cm_id_priv, &msg);
-               if (ret)
-                       goto error1;
+               msg = cm_alloc_msg(cm_id_priv);
+               if (IS_ERR(msg)) {
+                       ret = PTR_ERR(msg);
+                       goto error_unlock;
+               }
 
                cm_format_mra((struct cm_mra_msg *) msg->mad, cm_id_priv,
                              msg_response, service_timeout,
@@ -3134,7 +3177,7 @@ int ib_send_cm_mra(struct ib_cm_id *cm_id,
                trace_icm_send_mra(cm_id);
                ret = ib_post_send_mad(msg, NULL);
                if (ret)
-                       goto error2;
+                       goto error_free_msg;
        }
 
        cm_id->state = cm_state;
@@ -3144,18 +3187,16 @@ int ib_send_cm_mra(struct ib_cm_id *cm_id,
        spin_unlock_irqrestore(&cm_id_priv->lock, flags);
        return 0;
 
-error1:        spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-       kfree(data);
-       return ret;
-
-error2:        spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-       kfree(data);
+error_free_msg:
        cm_free_msg(msg);
+error_unlock:
+       spin_unlock_irqrestore(&cm_id_priv->lock, flags);
+       kfree(data);
        return ret;
 }
 EXPORT_SYMBOL(ib_send_cm_mra);
 
-static struct cm_id_private * cm_acquire_mraed_id(struct cm_mra_msg *mra_msg)
+static struct cm_id_private *cm_acquire_mraed_id(struct cm_mra_msg *mra_msg)
 {
        switch (IBA_GET(CM_MRA_MESSAGE_MRAED, mra_msg)) {
        case CM_MSG_RESPONSE_REQ:
@@ -3346,7 +3387,7 @@ static int cm_lap_handler(struct cm_work *work)
 
                if (cm_create_response_msg_ah(work->port, work->mad_recv_wc, msg) ||
                    ib_post_send_mad(msg, NULL))
-                       cm_free_msg(msg);
+                       cm_free_response_msg(msg);
                goto deref;
        case IB_CM_LAP_RCVD:
                atomic_long_inc(&work->port->counter_group[CM_RECV_DUPLICATES].
@@ -3486,38 +3527,41 @@ int ib_send_cm_sidr_req(struct ib_cm_id *cm_id,
                                 &cm_id_priv->av,
                                 cm_id_priv);
        if (ret)
-               goto out;
+               return ret;
 
        cm_id->service_id = param->service_id;
        cm_id->service_mask = ~cpu_to_be64(0);
        cm_id_priv->timeout_ms = param->timeout_ms;
        cm_id_priv->max_cm_retries = param->max_cm_retries;
-       ret = cm_alloc_msg(cm_id_priv, &msg);
-       if (ret)
-               goto out;
-
-       cm_format_sidr_req((struct cm_sidr_req_msg *) msg->mad, cm_id_priv,
-                          param);
-       msg->timeout_ms = cm_id_priv->timeout_ms;
-       msg->context[1] = (void *) (unsigned long) IB_CM_SIDR_REQ_SENT;
 
        spin_lock_irqsave(&cm_id_priv->lock, flags);
-       if (cm_id->state == IB_CM_IDLE) {
-               trace_icm_send_sidr_req(&cm_id_priv->id);
-               ret = ib_post_send_mad(msg, NULL);
-       } else {
+       if (cm_id->state != IB_CM_IDLE) {
                ret = -EINVAL;
+               goto out_unlock;
        }
 
-       if (ret) {
-               spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-               cm_free_msg(msg);
-               goto out;
+       msg = cm_alloc_priv_msg(cm_id_priv);
+       if (IS_ERR(msg)) {
+               ret = PTR_ERR(msg);
+               goto out_unlock;
        }
+
+       cm_format_sidr_req((struct cm_sidr_req_msg *)msg->mad, cm_id_priv,
+                          param);
+       msg->timeout_ms = cm_id_priv->timeout_ms;
+       msg->context[1] = (void *)(unsigned long)IB_CM_SIDR_REQ_SENT;
+
+       trace_icm_send_sidr_req(&cm_id_priv->id);
+       ret = ib_post_send_mad(msg, NULL);
+       if (ret)
+               goto out_free;
        cm_id->state = IB_CM_SIDR_REQ_SENT;
-       cm_id_priv->msg = msg;
        spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-out:
+       return 0;
+out_free:
+       cm_free_priv_msg(msg);
+out_unlock:
+       spin_unlock_irqrestore(&cm_id_priv->lock, flags);
        return ret;
 }
 EXPORT_SYMBOL(ib_send_cm_sidr_req);
@@ -3664,9 +3708,9 @@ static int cm_send_sidr_rep_locked(struct cm_id_private *cm_id_priv,
        if (cm_id_priv->id.state != IB_CM_SIDR_REQ_RCVD)
                return -EINVAL;
 
-       ret = cm_alloc_msg(cm_id_priv, &msg);
-       if (ret)
-               return ret;
+       msg = cm_alloc_msg(cm_id_priv);
+       if (IS_ERR(msg))
+               return PTR_ERR(msg);
 
        cm_format_sidr_rep((struct cm_sidr_rep_msg *) msg->mad, cm_id_priv,
                           param);
@@ -3917,8 +3961,7 @@ static int cm_establish(struct ib_cm_id *cm_id)
 
        cm_id_priv = container_of(cm_id, struct cm_id_private, id);
        spin_lock_irqsave(&cm_id_priv->lock, flags);
-       switch (cm_id->state)
-       {
+       switch (cm_id->state) {
        case IB_CM_REP_SENT:
        case IB_CM_MRA_REP_RCVD:
                cm_id->state = IB_CM_ESTABLISHED;
@@ -4334,7 +4377,7 @@ static int cm_add_one(struct ib_device *ib_device)
        unsigned long flags;
        int ret;
        int count = 0;
-       unsigned int i;
+       u32 i;
 
        cm_dev = kzalloc(struct_size(cm_dev, port, ib_device->phys_port_cnt),
                         GFP_KERNEL);
@@ -4432,7 +4475,7 @@ static void cm_remove_one(struct ib_device *ib_device, void *client_data)
                .clr_port_cap_mask = IB_PORT_CM_SUP
        };
        unsigned long flags;
-       unsigned int i;
+       u32 i;
 
        write_lock_irqsave(&cm.device_lock, flags);
        list_del(&cm_dev->list);