X-Git-Url: https://git.proxmox.com/?a=blobdiff_plain;f=drivers%2Finfiniband%2Fcore%2Fcm.c;h=94613275edccb4e6965d067a95a85cb9829bde5f;hb=4b4e586ebe37c8c7e2a4bf46dc4b742756fd788d;hp=0ead0d223154011532402d16eb7a8e7b443e1cf6;hpb=7ac86b3dca1b00f5391d346fdea3ac010d230667;p=mirror_ubuntu-jammy-kernel.git diff --git a/drivers/infiniband/core/cm.c b/drivers/infiniband/core/cm.c index 0ead0d223154..94613275edcc 100644 --- a/drivers/infiniband/core/cm.c +++ b/drivers/infiniband/core/cm.c @@ -305,8 +305,7 @@ static inline void cm_deref_id(struct cm_id_private *cm_id_priv) complete(&cm_id_priv->comp); } -static int cm_alloc_msg(struct cm_id_private *cm_id_priv, - struct ib_mad_send_buf **msg) +static struct ib_mad_send_buf *cm_alloc_msg(struct cm_id_private *cm_id_priv) { struct ib_mad_agent *mad_agent; struct ib_mad_send_buf *m; @@ -359,12 +358,42 @@ static int cm_alloc_msg(struct cm_id_private *cm_id_priv, m->retries = cm_id_priv->max_cm_retries; refcount_inc(&cm_id_priv->refcount); + spin_unlock_irqrestore(&cm.state_lock, flags2); m->context[0] = cm_id_priv; - *msg = m; + return m; out: spin_unlock_irqrestore(&cm.state_lock, flags2); - return ret; + return ERR_PTR(ret); +} + +static struct ib_mad_send_buf * +cm_alloc_priv_msg(struct cm_id_private *cm_id_priv) +{ + struct ib_mad_send_buf *msg; + + lockdep_assert_held(&cm_id_priv->lock); + + msg = cm_alloc_msg(cm_id_priv); + if (IS_ERR(msg)) + return msg; + cm_id_priv->msg = msg; + return msg; +} + +static void cm_free_priv_msg(struct ib_mad_send_buf *msg) +{ + struct cm_id_private *cm_id_priv = msg->context[0]; + + lockdep_assert_held(&cm_id_priv->lock); + + if (!WARN_ON(cm_id_priv->msg != msg)) + cm_id_priv->msg = NULL; + + if (msg->ah) + rdma_destroy_ah(msg->ah, 0); + cm_deref_id(cm_id_priv); + ib_free_send_mad(msg); } static struct ib_mad_send_buf *cm_alloc_response_msg_no_ah(struct cm_port *port, @@ -413,7 +442,7 @@ static int cm_alloc_response_msg(struct cm_port *port, ret = cm_create_response_msg_ah(port, mad_recv_wc, m); if (ret) { - cm_free_msg(m); + ib_free_send_mad(m); return ret; } @@ -421,6 +450,13 @@ static int cm_alloc_response_msg(struct cm_port *port, return 0; } +static void cm_free_response_msg(struct ib_mad_send_buf *msg) +{ + if (msg->ah) + rdma_destroy_ah(msg->ah, 0); + ib_free_send_mad(msg); +} + static void *cm_copy_private_data(const void *private_data, u8 private_data_len) { void *data; @@ -1501,6 +1537,7 @@ int ib_send_cm_req(struct ib_cm_id *cm_id, struct ib_cm_req_param *param) { struct cm_id_private *cm_id_priv; + struct ib_mad_send_buf *msg; struct cm_req_msg *req_msg; unsigned long flags; int ret; @@ -1552,33 +1589,36 @@ int ib_send_cm_req(struct ib_cm_id *cm_id, cm_id_priv->pkey = param->primary_path->pkey; cm_id_priv->qp_type = param->qp_type; - ret = cm_alloc_msg(cm_id_priv, &cm_id_priv->msg); - if (ret) - goto out; + spin_lock_irqsave(&cm_id_priv->lock, flags); + msg = cm_alloc_priv_msg(cm_id_priv); + if (IS_ERR(msg)) { + ret = PTR_ERR(msg); + goto out_unlock; + } - req_msg = (struct cm_req_msg *) cm_id_priv->msg->mad; + req_msg = (struct cm_req_msg *)msg->mad; cm_format_req(req_msg, cm_id_priv, param); cm_id_priv->tid = req_msg->hdr.tid; - cm_id_priv->msg->timeout_ms = cm_id_priv->timeout_ms; - cm_id_priv->msg->context[1] = (void *) (unsigned long) IB_CM_REQ_SENT; + msg->timeout_ms = cm_id_priv->timeout_ms; + msg->context[1] = (void *)(unsigned long)IB_CM_REQ_SENT; cm_id_priv->local_qpn = cpu_to_be32(IBA_GET(CM_REQ_LOCAL_QPN, req_msg)); cm_id_priv->rq_psn = cpu_to_be32(IBA_GET(CM_REQ_STARTING_PSN, req_msg)); trace_icm_send_req(&cm_id_priv->id); - spin_lock_irqsave(&cm_id_priv->lock, flags); - ret = ib_post_send_mad(cm_id_priv->msg, NULL); - if (ret) { - spin_unlock_irqrestore(&cm_id_priv->lock, flags); - goto error2; - } + ret = ib_post_send_mad(msg, NULL); + if (ret) + goto out_free; BUG_ON(cm_id->state != IB_CM_IDLE); cm_id->state = IB_CM_REQ_SENT; spin_unlock_irqrestore(&cm_id_priv->lock, flags); return 0; - -error2: cm_free_msg(cm_id_priv->msg); -out: return ret; +out_free: + cm_free_priv_msg(msg); +out_unlock: + spin_unlock_irqrestore(&cm_id_priv->lock, flags); +out: + return ret; } EXPORT_SYMBOL(ib_send_cm_req); @@ -1618,7 +1658,7 @@ static int cm_issue_rej(struct cm_port *port, IBA_GET(CM_REJ_REMOTE_COMM_ID, rcv_msg)); ret = ib_post_send_mad(msg, NULL); if (ret) - cm_free_msg(msg); + cm_free_response_msg(msg); return ret; } @@ -1974,7 +2014,7 @@ static void cm_dup_req_handler(struct cm_work *work, return; unlock: spin_unlock_irq(&cm_id_priv->lock); -free: cm_free_msg(msg); +free: cm_free_response_msg(msg); } static struct cm_id_private *cm_match_req(struct cm_work *work, @@ -2283,9 +2323,11 @@ int ib_send_cm_rep(struct ib_cm_id *cm_id, goto out; } - ret = cm_alloc_msg(cm_id_priv, &msg); - if (ret) + msg = cm_alloc_priv_msg(cm_id_priv); + if (IS_ERR(msg)) { + ret = PTR_ERR(msg); goto out; + } rep_msg = (struct cm_rep_msg *) msg->mad; cm_format_rep(rep_msg, cm_id_priv, param); @@ -2294,14 +2336,10 @@ int ib_send_cm_rep(struct ib_cm_id *cm_id, trace_icm_send_rep(cm_id); ret = ib_post_send_mad(msg, NULL); - if (ret) { - spin_unlock_irqrestore(&cm_id_priv->lock, flags); - cm_free_msg(msg); - return ret; - } + if (ret) + goto out_free; cm_id->state = IB_CM_REP_SENT; - cm_id_priv->msg = msg; cm_id_priv->initiator_depth = param->initiator_depth; cm_id_priv->responder_resources = param->responder_resources; cm_id_priv->rq_psn = cpu_to_be32(IBA_GET(CM_REP_STARTING_PSN, rep_msg)); @@ -2309,8 +2347,13 @@ int ib_send_cm_rep(struct ib_cm_id *cm_id, "IBTA declares QPN to be 24 bits, but it is 0x%X\n", param->qp_num); cm_id_priv->local_qpn = cpu_to_be32(param->qp_num & 0xFFFFFF); + spin_unlock_irqrestore(&cm_id_priv->lock, flags); + return 0; -out: spin_unlock_irqrestore(&cm_id_priv->lock, flags); +out_free: + cm_free_priv_msg(msg); +out: + spin_unlock_irqrestore(&cm_id_priv->lock, flags); return ret; } EXPORT_SYMBOL(ib_send_cm_rep); @@ -2357,9 +2400,11 @@ int ib_send_cm_rtu(struct ib_cm_id *cm_id, goto error; } - ret = cm_alloc_msg(cm_id_priv, &msg); - if (ret) + msg = cm_alloc_msg(cm_id_priv); + if (IS_ERR(msg)) { + ret = PTR_ERR(msg); goto error; + } cm_format_rtu((struct cm_rtu_msg *) msg->mad, cm_id_priv, private_data, private_data_len); @@ -2453,7 +2498,7 @@ static void cm_dup_rep_handler(struct cm_work *work) goto deref; unlock: spin_unlock_irq(&cm_id_priv->lock); -free: cm_free_msg(msg); +free: cm_free_response_msg(msg); deref: cm_deref_id(cm_id_priv); } @@ -2657,10 +2702,10 @@ static int cm_send_dreq_locked(struct cm_id_private *cm_id_priv, cm_id_priv->id.lap_state == IB_CM_MRA_LAP_RCVD) ib_cancel_mad(cm_id_priv->av.port->mad_agent, cm_id_priv->msg); - ret = cm_alloc_msg(cm_id_priv, &msg); - if (ret) { + msg = cm_alloc_priv_msg(cm_id_priv); + if (IS_ERR(msg)) { cm_enter_timewait(cm_id_priv); - return ret; + return PTR_ERR(msg); } cm_format_dreq((struct cm_dreq_msg *) msg->mad, cm_id_priv, @@ -2672,12 +2717,11 @@ static int cm_send_dreq_locked(struct cm_id_private *cm_id_priv, ret = ib_post_send_mad(msg, NULL); if (ret) { cm_enter_timewait(cm_id_priv); - cm_free_msg(msg); + cm_free_priv_msg(msg); return ret; } cm_id_priv->id.state = IB_CM_DREQ_SENT; - cm_id_priv->msg = msg; return 0; } @@ -2732,9 +2776,9 @@ static int cm_send_drep_locked(struct cm_id_private *cm_id_priv, cm_set_private_data(cm_id_priv, private_data, private_data_len); cm_enter_timewait(cm_id_priv); - ret = cm_alloc_msg(cm_id_priv, &msg); - if (ret) - return ret; + msg = cm_alloc_msg(cm_id_priv); + if (IS_ERR(msg)) + return PTR_ERR(msg); cm_format_drep((struct cm_drep_msg *) msg->mad, cm_id_priv, private_data, private_data_len); @@ -2794,7 +2838,7 @@ static int cm_issue_drep(struct cm_port *port, IBA_GET(CM_DREQ_REMOTE_COMM_ID, dreq_msg)); ret = ib_post_send_mad(msg, NULL); if (ret) - cm_free_msg(msg); + cm_free_response_msg(msg); return ret; } @@ -2853,7 +2897,7 @@ static int cm_dreq_handler(struct cm_work *work) if (cm_create_response_msg_ah(work->port, work->mad_recv_wc, msg) || ib_post_send_mad(msg, NULL)) - cm_free_msg(msg); + cm_free_response_msg(msg); goto deref; case IB_CM_DREQ_RCVD: atomic_long_inc(&work->port->counter_group[CM_RECV_DUPLICATES]. @@ -2927,9 +2971,9 @@ static int cm_send_rej_locked(struct cm_id_private *cm_id_priv, case IB_CM_REP_RCVD: case IB_CM_MRA_REP_SENT: cm_reset_to_idle(cm_id_priv); - ret = cm_alloc_msg(cm_id_priv, &msg); - if (ret) - return ret; + msg = cm_alloc_msg(cm_id_priv); + if (IS_ERR(msg)) + return PTR_ERR(msg); cm_format_rej((struct cm_rej_msg *)msg->mad, cm_id_priv, reason, ari, ari_length, private_data, private_data_len, state); @@ -2937,9 +2981,9 @@ static int cm_send_rej_locked(struct cm_id_private *cm_id_priv, case IB_CM_REP_SENT: case IB_CM_MRA_REP_RCVD: cm_enter_timewait(cm_id_priv); - ret = cm_alloc_msg(cm_id_priv, &msg); - if (ret) - return ret; + msg = cm_alloc_msg(cm_id_priv); + if (IS_ERR(msg)) + return PTR_ERR(msg); cm_format_rej((struct cm_rej_msg *)msg->mad, cm_id_priv, reason, ari, ari_length, private_data, private_data_len, state); @@ -3117,13 +3161,15 @@ int ib_send_cm_mra(struct ib_cm_id *cm_id, default: trace_icm_send_mra_unknown_err(&cm_id_priv->id); ret = -EINVAL; - goto error1; + goto error_unlock; } if (!(service_timeout & IB_CM_MRA_FLAG_DELAY)) { - ret = cm_alloc_msg(cm_id_priv, &msg); - if (ret) - goto error1; + msg = cm_alloc_msg(cm_id_priv); + if (IS_ERR(msg)) { + ret = PTR_ERR(msg); + goto error_unlock; + } cm_format_mra((struct cm_mra_msg *) msg->mad, cm_id_priv, msg_response, service_timeout, @@ -3131,7 +3177,7 @@ int ib_send_cm_mra(struct ib_cm_id *cm_id, trace_icm_send_mra(cm_id); ret = ib_post_send_mad(msg, NULL); if (ret) - goto error2; + goto error_free_msg; } cm_id->state = cm_state; @@ -3141,13 +3187,11 @@ int ib_send_cm_mra(struct ib_cm_id *cm_id, spin_unlock_irqrestore(&cm_id_priv->lock, flags); return 0; -error1: spin_unlock_irqrestore(&cm_id_priv->lock, flags); - kfree(data); - return ret; - -error2: spin_unlock_irqrestore(&cm_id_priv->lock, flags); - kfree(data); +error_free_msg: cm_free_msg(msg); +error_unlock: + spin_unlock_irqrestore(&cm_id_priv->lock, flags); + kfree(data); return ret; } EXPORT_SYMBOL(ib_send_cm_mra); @@ -3343,7 +3387,7 @@ static int cm_lap_handler(struct cm_work *work) if (cm_create_response_msg_ah(work->port, work->mad_recv_wc, msg) || ib_post_send_mad(msg, NULL)) - cm_free_msg(msg); + cm_free_response_msg(msg); goto deref; case IB_CM_LAP_RCVD: atomic_long_inc(&work->port->counter_group[CM_RECV_DUPLICATES]. @@ -3483,38 +3527,41 @@ int ib_send_cm_sidr_req(struct ib_cm_id *cm_id, &cm_id_priv->av, cm_id_priv); if (ret) - goto out; + return ret; cm_id->service_id = param->service_id; cm_id->service_mask = ~cpu_to_be64(0); cm_id_priv->timeout_ms = param->timeout_ms; cm_id_priv->max_cm_retries = param->max_cm_retries; - ret = cm_alloc_msg(cm_id_priv, &msg); - if (ret) - goto out; - - cm_format_sidr_req((struct cm_sidr_req_msg *) msg->mad, cm_id_priv, - param); - msg->timeout_ms = cm_id_priv->timeout_ms; - msg->context[1] = (void *) (unsigned long) IB_CM_SIDR_REQ_SENT; spin_lock_irqsave(&cm_id_priv->lock, flags); - if (cm_id->state == IB_CM_IDLE) { - trace_icm_send_sidr_req(&cm_id_priv->id); - ret = ib_post_send_mad(msg, NULL); - } else { + if (cm_id->state != IB_CM_IDLE) { ret = -EINVAL; + goto out_unlock; } - if (ret) { - spin_unlock_irqrestore(&cm_id_priv->lock, flags); - cm_free_msg(msg); - goto out; + msg = cm_alloc_priv_msg(cm_id_priv); + if (IS_ERR(msg)) { + ret = PTR_ERR(msg); + goto out_unlock; } + + cm_format_sidr_req((struct cm_sidr_req_msg *)msg->mad, cm_id_priv, + param); + msg->timeout_ms = cm_id_priv->timeout_ms; + msg->context[1] = (void *)(unsigned long)IB_CM_SIDR_REQ_SENT; + + trace_icm_send_sidr_req(&cm_id_priv->id); + ret = ib_post_send_mad(msg, NULL); + if (ret) + goto out_free; cm_id->state = IB_CM_SIDR_REQ_SENT; - cm_id_priv->msg = msg; spin_unlock_irqrestore(&cm_id_priv->lock, flags); -out: + return 0; +out_free: + cm_free_priv_msg(msg); +out_unlock: + spin_unlock_irqrestore(&cm_id_priv->lock, flags); return ret; } EXPORT_SYMBOL(ib_send_cm_sidr_req); @@ -3661,9 +3708,9 @@ static int cm_send_sidr_rep_locked(struct cm_id_private *cm_id_priv, if (cm_id_priv->id.state != IB_CM_SIDR_REQ_RCVD) return -EINVAL; - ret = cm_alloc_msg(cm_id_priv, &msg); - if (ret) - return ret; + msg = cm_alloc_msg(cm_id_priv); + if (IS_ERR(msg)) + return PTR_ERR(msg); cm_format_sidr_rep((struct cm_sidr_rep_msg *) msg->mad, cm_id_priv, param);