]> git.proxmox.com Git - mirror_ubuntu-hirsute-kernel.git/blobdiff - net/tls/tls_sw.c
Merge tag 'batadv-next-for-davem-20190201' of git://git.open-mesh.org/linux-merge
[mirror_ubuntu-hirsute-kernel.git] / net / tls / tls_sw.c
index 11cdc8f7db63c7d84d1a6befbafb7f4f491eb8c6..3f2a6af27e6287756032ebd1908b1f0b2fd9120f 100644 (file)
@@ -124,6 +124,7 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
 {
        struct aead_request *aead_req = (struct aead_request *)req;
        struct scatterlist *sgout = aead_req->dst;
+       struct scatterlist *sgin = aead_req->src;
        struct tls_sw_context_rx *ctx;
        struct tls_context *tls_ctx;
        struct scatterlist *sg;
@@ -134,12 +135,16 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
        skb = (struct sk_buff *)req->data;
        tls_ctx = tls_get_ctx(skb->sk);
        ctx = tls_sw_ctx_rx(tls_ctx);
-       pending = atomic_dec_return(&ctx->decrypt_pending);
 
        /* Propagate if there was an err */
        if (err) {
                ctx->async_wait.err = err;
                tls_err_abort(skb->sk, err);
+       } else {
+               struct strp_msg *rxm = strp_msg(skb);
+
+               rxm->offset += tls_ctx->rx.prepend_size;
+               rxm->full_len -= tls_ctx->rx.overhead_size;
        }
 
        /* After using skb->sk to propagate sk through crypto async callback
@@ -147,18 +152,21 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
         */
        skb->sk = NULL;
 
-       /* Release the skb, pages and memory allocated for crypto req */
-       kfree_skb(skb);
 
-       /* Skip the first S/G entry as it points to AAD */
-       for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
-               if (!sg)
-                       break;
-               put_page(sg_page(sg));
+       /* Free the destination pages if skb was not decrypted inplace */
+       if (sgout != sgin) {
+               /* Skip the first S/G entry as it points to AAD */
+               for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
+                       if (!sg)
+                               break;
+                       put_page(sg_page(sg));
+               }
        }
 
        kfree(aead_req);
 
+       pending = atomic_dec_return(&ctx->decrypt_pending);
+
        if (!pending && READ_ONCE(ctx->async_notify))
                complete(&ctx->async_wait.completion);
 }
@@ -439,6 +447,8 @@ static int tls_do_encryption(struct sock *sk,
        struct scatterlist *sge = sk_msg_elem(msg_en, start);
        int rc;
 
+       memcpy(rec->iv_data, tls_ctx->tx.iv, sizeof(rec->iv_data));
+
        sge->offset += tls_ctx->tx.prepend_size;
        sge->length -= tls_ctx->tx.prepend_size;
 
@@ -448,7 +458,7 @@ static int tls_do_encryption(struct sock *sk,
        aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
        aead_request_set_crypt(aead_req, rec->sg_aead_in,
                               rec->sg_aead_out,
-                              data_len, tls_ctx->tx.iv);
+                              data_len, rec->iv_data);
 
        aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
                                  tls_encrypt_done, sk);
@@ -1020,8 +1030,8 @@ send_end:
        return copied ? copied : ret;
 }
 
-int tls_sw_do_sendpage(struct sock *sk, struct page *page,
-                      int offset, size_t size, int flags)
+static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
+                             int offset, size_t size, int flags)
 {
        long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
        struct tls_context *tls_ctx = tls_get_ctx(sk);
@@ -1143,16 +1153,6 @@ sendpage_end:
        return copied ? copied : ret;
 }
 
-int tls_sw_sendpage_locked(struct sock *sk, struct page *page,
-                          int offset, size_t size, int flags)
-{
-       if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
-                     MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY))
-               return -ENOTSUPP;
-
-       return tls_sw_do_sendpage(sk, page, offset, size, flags);
-}
-
 int tls_sw_sendpage(struct sock *sk, struct page *page,
                    int offset, size_t size, int flags)
 {
@@ -1281,7 +1281,7 @@ out:
 static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
                            struct iov_iter *out_iov,
                            struct scatterlist *out_sg,
-                           int *chunk, bool *zc)
+                           int *chunk, bool *zc, bool async)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
@@ -1381,13 +1381,13 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 fallback_to_reg_recv:
                sgout = sgin;
                pages = 0;
-               *chunk = 0;
+               *chunk = data_len;
                *zc = false;
        }
 
        /* Prepare and submit AEAD request */
        err = tls_do_decryption(sk, skb, sgin, sgout, iv,
-                               data_len, aead_req, *zc);
+                               data_len, aead_req, async);
        if (err == -EINPROGRESS)
                return err;
 
@@ -1400,7 +1400,8 @@ fallback_to_reg_recv:
 }
 
 static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
-                             struct iov_iter *dest, int *chunk, bool *zc)
+                             struct iov_iter *dest, int *chunk, bool *zc,
+                             bool async)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
@@ -1413,7 +1414,7 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
                return err;
 #endif
        if (!ctx->decrypted) {
-               err = decrypt_internal(sk, skb, dest, NULL, chunk, zc);
+               err = decrypt_internal(sk, skb, dest, NULL, chunk, zc, async);
                if (err < 0) {
                        if (err == -EINPROGRESS)
                                tls_advance_record_sn(sk, &tls_ctx->rx);
@@ -1439,7 +1440,7 @@ int decrypt_skb(struct sock *sk, struct sk_buff *skb,
        bool zc = true;
        int chunk;
 
-       return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc);
+       return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc, false);
 }
 
 static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
@@ -1466,6 +1467,77 @@ static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
        return true;
 }
 
+/* This function traverses the rx_list in tls receive context to copy the
+ * decrypted data records into the buffer provided by the caller if zero
+ * copy is not true. Further, the records are removed from the rx_list if
+ * it is not a peek case and the record has been consumed completely.
+ */
+static int process_rx_list(struct tls_sw_context_rx *ctx,
+                          struct msghdr *msg,
+                          size_t skip,
+                          size_t len,
+                          bool zc,
+                          bool is_peek)
+{
+       struct sk_buff *skb = skb_peek(&ctx->rx_list);
+       ssize_t copied = 0;
+
+       while (skip && skb) {
+               struct strp_msg *rxm = strp_msg(skb);
+
+               if (skip < rxm->full_len)
+                       break;
+
+               skip = skip - rxm->full_len;
+               skb = skb_peek_next(skb, &ctx->rx_list);
+       }
+
+       while (len && skb) {
+               struct sk_buff *next_skb;
+               struct strp_msg *rxm = strp_msg(skb);
+               int chunk = min_t(unsigned int, rxm->full_len - skip, len);
+
+               if (!zc || (rxm->full_len - skip) > len) {
+                       int err = skb_copy_datagram_msg(skb, rxm->offset + skip,
+                                                   msg, chunk);
+                       if (err < 0)
+                               return err;
+               }
+
+               len = len - chunk;
+               copied = copied + chunk;
+
+               /* Consume the data from record if it is non-peek case */
+               if (!is_peek) {
+                       rxm->offset = rxm->offset + chunk;
+                       rxm->full_len = rxm->full_len - chunk;
+
+                       /* Return if there is unconsumed data in the record */
+                       if (rxm->full_len - skip)
+                               break;
+               }
+
+               /* The remaining skip-bytes must lie in 1st record in rx_list.
+                * So from the 2nd record, 'skip' should be 0.
+                */
+               skip = 0;
+
+               if (msg)
+                       msg->msg_flags |= MSG_EOR;
+
+               next_skb = skb_peek_next(skb, &ctx->rx_list);
+
+               if (!is_peek) {
+                       skb_unlink(skb, &ctx->rx_list);
+                       kfree_skb(skb);
+               }
+
+               skb = next_skb;
+       }
+
+       return copied;
+}
+
 int tls_sw_recvmsg(struct sock *sk,
                   struct msghdr *msg,
                   size_t len,
@@ -1476,7 +1548,8 @@ int tls_sw_recvmsg(struct sock *sk,
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
        struct sk_psock *psock;
-       unsigned char control;
+       unsigned char control = 0;
+       ssize_t decrypted = 0;
        struct strp_msg *rxm;
        struct sk_buff *skb;
        ssize_t copied = 0;
@@ -1484,6 +1557,7 @@ int tls_sw_recvmsg(struct sock *sk,
        int target, err = 0;
        long timeo;
        bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
+       bool is_peek = flags & MSG_PEEK;
        int num_async = 0;
 
        flags |= nonblock;
@@ -1494,11 +1568,28 @@ int tls_sw_recvmsg(struct sock *sk,
        psock = sk_psock_get(sk);
        lock_sock(sk);
 
-       target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
-       timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
+       /* Process pending decrypted records. It must be non-zero-copy */
+       err = process_rx_list(ctx, msg, 0, len, false, is_peek);
+       if (err < 0) {
+               tls_err_abort(sk, err);
+               goto end;
+       } else {
+               copied = err;
+       }
+
+       len = len - copied;
+       if (len) {
+               target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
+               timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
+       } else {
+               goto recv_end;
+       }
+
        do {
-               bool zc = false;
+               bool retain_skb = false;
                bool async = false;
+               bool zc = false;
+               int to_decrypt;
                int chunk = 0;
 
                skb = tls_wait_data(sk, psock, flags, timeo, &err);
@@ -1508,7 +1599,7 @@ int tls_sw_recvmsg(struct sock *sk,
                                                            msg, len, flags);
 
                                if (ret > 0) {
-                                       copied += ret;
+                                       decrypted += ret;
                                        len -= ret;
                                        continue;
                                }
@@ -1535,70 +1626,70 @@ int tls_sw_recvmsg(struct sock *sk,
                        goto recv_end;
                }
 
-               if (!ctx->decrypted) {
-                       int to_copy = rxm->full_len - tls_ctx->rx.overhead_size;
-
-                       if (!is_kvec && to_copy <= len &&
-                           likely(!(flags & MSG_PEEK)))
-                               zc = true;
-
-                       err = decrypt_skb_update(sk, skb, &msg->msg_iter,
-                                                &chunk, &zc);
-                       if (err < 0 && err != -EINPROGRESS) {
-                               tls_err_abort(sk, EBADMSG);
-                               goto recv_end;
-                       }
+               to_decrypt = rxm->full_len - tls_ctx->rx.overhead_size;
 
-                       if (err == -EINPROGRESS) {
-                               async = true;
-                               num_async++;
-                               goto pick_next_record;
-                       }
+               if (to_decrypt <= len && !is_kvec && !is_peek)
+                       zc = true;
 
-                       ctx->decrypted = true;
+               err = decrypt_skb_update(sk, skb, &msg->msg_iter,
+                                        &chunk, &zc, ctx->async_capable);
+               if (err < 0 && err != -EINPROGRESS) {
+                       tls_err_abort(sk, EBADMSG);
+                       goto recv_end;
                }
 
-               if (!zc) {
-                       chunk = min_t(unsigned int, rxm->full_len, len);
+               if (err == -EINPROGRESS) {
+                       async = true;
+                       num_async++;
+                       goto pick_next_record;
+               } else {
+                       if (!zc) {
+                               if (rxm->full_len > len) {
+                                       retain_skb = true;
+                                       chunk = len;
+                               } else {
+                                       chunk = rxm->full_len;
+                               }
 
-                       err = skb_copy_datagram_msg(skb, rxm->offset, msg,
-                                                   chunk);
-                       if (err < 0)
-                               goto recv_end;
+                               err = skb_copy_datagram_msg(skb, rxm->offset,
+                                                           msg, chunk);
+                               if (err < 0)
+                                       goto recv_end;
+
+                               if (!is_peek) {
+                                       rxm->offset = rxm->offset + chunk;
+                                       rxm->full_len = rxm->full_len - chunk;
+                               }
+                       }
                }
 
 pick_next_record:
-               copied += chunk;
+               if (chunk > len)
+                       chunk = len;
+
+               decrypted += chunk;
                len -= chunk;
-               if (likely(!(flags & MSG_PEEK))) {
-                       u8 control = ctx->control;
-
-                       /* For async, drop current skb reference */
-                       if (async)
-                               skb = NULL;
-
-                       if (tls_sw_advance_skb(sk, skb, chunk)) {
-                               /* Return full control message to
-                                * userspace before trying to parse
-                                * another message type
-                                */
-                               msg->msg_flags |= MSG_EOR;
-                               if (control != TLS_RECORD_TYPE_DATA)
-                                       goto recv_end;
-                       } else {
-                               break;
-                       }
-               } else {
-                       /* MSG_PEEK right now cannot look beyond current skb
-                        * from strparser, meaning we cannot advance skb here
-                        * and thus unpause strparser since we'd loose original
-                        * one.
+
+               /* For async or peek case, queue the current skb */
+               if (async || is_peek || retain_skb) {
+                       skb_queue_tail(&ctx->rx_list, skb);
+                       skb = NULL;
+               }
+
+               if (tls_sw_advance_skb(sk, skb, chunk)) {
+                       /* Return full control message to
+                        * userspace before trying to parse
+                        * another message type
                         */
+                       msg->msg_flags |= MSG_EOR;
+                       if (ctx->control != TLS_RECORD_TYPE_DATA)
+                               goto recv_end;
+               } else {
                        break;
                }
 
                /* If we have a new message from strparser, continue now. */
-               if (copied >= target && !ctx->recv_pkt)
+               if (decrypted >= target && !ctx->recv_pkt)
                        break;
        } while (len);
 
@@ -1612,13 +1703,33 @@ recv_end:
                                /* one of async decrypt failed */
                                tls_err_abort(sk, err);
                                copied = 0;
+                               decrypted = 0;
+                               goto end;
                        }
                } else {
                        reinit_completion(&ctx->async_wait.completion);
                }
                WRITE_ONCE(ctx->async_notify, false);
+
+               /* Drain records from the rx_list & copy if required */
+               if (is_peek || is_kvec)
+                       err = process_rx_list(ctx, msg, copied,
+                                             decrypted, false, is_peek);
+               else
+                       err = process_rx_list(ctx, msg, 0,
+                                             decrypted, true, is_peek);
+               if (err < 0) {
+                       tls_err_abort(sk, err);
+                       copied = 0;
+                       goto end;
+               }
+
+               WARN_ON(decrypted != err);
        }
 
+       copied += decrypted;
+
+end:
        release_sock(sk);
        if (psock)
                sk_psock_put(sk, psock);
@@ -1655,7 +1766,7 @@ ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
        }
 
        if (!ctx->decrypted) {
-               err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc);
+               err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false);
 
                if (err < 0) {
                        tls_err_abort(sk, EBADMSG);
@@ -1792,7 +1903,9 @@ void tls_sw_free_resources_tx(struct sock *sk)
        if (atomic_read(&ctx->encrypt_pending))
                crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
 
+       release_sock(sk);
        cancel_delayed_work_sync(&ctx->tx_work.work);
+       lock_sock(sk);
 
        /* Tx whatever records we can transmit and abandon the rest */
        tls_tx_records(sk, -1);
@@ -1842,6 +1955,7 @@ void tls_sw_release_resources_rx(struct sock *sk)
        if (ctx->aead_recv) {
                kfree_skb(ctx->recv_pkt);
                ctx->recv_pkt = NULL;
+               skb_queue_purge(&ctx->rx_list);
                crypto_free_aead(ctx->aead_recv);
                strp_stop(&ctx->strp);
                write_lock_bh(&sk->sk_callback_lock);
@@ -1891,6 +2005,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
        struct crypto_aead **aead;
        struct strp_callbacks cb;
        u16 nonce_size, tag_size, iv_size, rec_seq_size;
+       struct crypto_tfm *tfm;
        char *iv, *rec_seq;
        int rc = 0;
 
@@ -1937,6 +2052,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
                crypto_init_wait(&sw_ctx_rx->async_wait);
                crypto_info = &ctx->crypto_recv.info;
                cctx = &ctx->rx;
+               skb_queue_head_init(&sw_ctx_rx->rx_list);
                aead = &sw_ctx_rx->aead_recv;
        }
 
@@ -2004,6 +2120,10 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
                goto free_aead;
 
        if (sw_ctx_rx) {
+               tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv);
+               sw_ctx_rx->async_capable =
+                       tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC;
+
                /* Set up strparser */
                memset(&cb, 0, sizeof(cb));
                cb.rcv_msg = tls_queue;