]> git.proxmox.com Git - mirror_ubuntu-bionic-kernel.git/blobdiff - arch/s390/crypto/aes_s390.c
s390/crypto: fix gcm-aes-s390 selftest failures
[mirror_ubuntu-bionic-kernel.git] / arch / s390 / crypto / aes_s390.c
index ad47abd086308d3040c12807c54174f84e2e35d6..39c850b39582a1c3fe65bead997873234033232c 100644 (file)
@@ -826,19 +826,45 @@ static int gcm_aes_setauthsize(struct crypto_aead *tfm, unsigned int authsize)
        return 0;
 }
 
-static void gcm_sg_walk_start(struct gcm_sg_walk *gw, struct scatterlist *sg,
-                             unsigned int len)
+static void gcm_walk_start(struct gcm_sg_walk *gw, struct scatterlist *sg,
+                          unsigned int len)
 {
        memset(gw, 0, sizeof(*gw));
        gw->walk_bytes_remain = len;
        scatterwalk_start(&gw->walk, sg);
 }
 
-static int gcm_sg_walk_go(struct gcm_sg_walk *gw, unsigned int minbytesneeded)
+static inline unsigned int _gcm_sg_clamp_and_map(struct gcm_sg_walk *gw)
+{
+       struct scatterlist *nextsg;
+
+       gw->walk_bytes = scatterwalk_clamp(&gw->walk, gw->walk_bytes_remain);
+       while (!gw->walk_bytes) {
+               nextsg = sg_next(gw->walk.sg);
+               if (!nextsg)
+                       return 0;
+               scatterwalk_start(&gw->walk, nextsg);
+               gw->walk_bytes = scatterwalk_clamp(&gw->walk,
+                                                  gw->walk_bytes_remain);
+       }
+       gw->walk_ptr = scatterwalk_map(&gw->walk);
+       return gw->walk_bytes;
+}
+
+static inline void _gcm_sg_unmap_and_advance(struct gcm_sg_walk *gw,
+                                            unsigned int nbytes)
+{
+       gw->walk_bytes_remain -= nbytes;
+       scatterwalk_unmap(&gw->walk);
+       scatterwalk_advance(&gw->walk, nbytes);
+       scatterwalk_done(&gw->walk, 0, gw->walk_bytes_remain);
+       gw->walk_ptr = NULL;
+}
+
+static int gcm_in_walk_go(struct gcm_sg_walk *gw, unsigned int minbytesneeded)
 {
        int n;
 
-       /* minbytesneeded <= AES_BLOCK_SIZE */
        if (gw->buf_bytes && gw->buf_bytes >= minbytesneeded) {
                gw->ptr = gw->buf;
                gw->nbytes = gw->buf_bytes;
@@ -851,13 +877,11 @@ static int gcm_sg_walk_go(struct gcm_sg_walk *gw, unsigned int minbytesneeded)
                goto out;
        }
 
-       gw->walk_bytes = scatterwalk_clamp(&gw->walk, gw->walk_bytes_remain);
-       if (!gw->walk_bytes) {
-               scatterwalk_start(&gw->walk, sg_next(gw->walk.sg));
-               gw->walk_bytes = scatterwalk_clamp(&gw->walk,
-                                                  gw->walk_bytes_remain);
+       if (!_gcm_sg_clamp_and_map(gw)) {
+               gw->ptr = NULL;
+               gw->nbytes = 0;
+               goto out;
        }
-       gw->walk_ptr = scatterwalk_map(&gw->walk);
 
        if (!gw->buf_bytes && gw->walk_bytes >= minbytesneeded) {
                gw->ptr = gw->walk_ptr;
@@ -869,51 +893,90 @@ static int gcm_sg_walk_go(struct gcm_sg_walk *gw, unsigned int minbytesneeded)
                n = min(gw->walk_bytes, AES_BLOCK_SIZE - gw->buf_bytes);
                memcpy(gw->buf + gw->buf_bytes, gw->walk_ptr, n);
                gw->buf_bytes += n;
-               gw->walk_bytes_remain -= n;
-               scatterwalk_unmap(&gw->walk);
-               scatterwalk_advance(&gw->walk, n);
-               scatterwalk_done(&gw->walk, 0, gw->walk_bytes_remain);
-
+               _gcm_sg_unmap_and_advance(gw, n);
                if (gw->buf_bytes >= minbytesneeded) {
                        gw->ptr = gw->buf;
                        gw->nbytes = gw->buf_bytes;
                        goto out;
                }
-
-               gw->walk_bytes = scatterwalk_clamp(&gw->walk,
-                                                  gw->walk_bytes_remain);
-               if (!gw->walk_bytes) {
-                       scatterwalk_start(&gw->walk, sg_next(gw->walk.sg));
-                       gw->walk_bytes = scatterwalk_clamp(&gw->walk,
-                                                       gw->walk_bytes_remain);
+               if (!_gcm_sg_clamp_and_map(gw)) {
+                       gw->ptr = NULL;
+                       gw->nbytes = 0;
+                       goto out;
                }
-               gw->walk_ptr = scatterwalk_map(&gw->walk);
        }
 
 out:
        return gw->nbytes;
 }
 
-static void gcm_sg_walk_done(struct gcm_sg_walk *gw, unsigned int bytesdone)
+static int gcm_out_walk_go(struct gcm_sg_walk *gw, unsigned int minbytesneeded)
 {
-       int n;
+       if (gw->walk_bytes_remain == 0) {
+               gw->ptr = NULL;
+               gw->nbytes = 0;
+               goto out;
+       }
 
+       if (!_gcm_sg_clamp_and_map(gw)) {
+               gw->ptr = NULL;
+               gw->nbytes = 0;
+               goto out;
+       }
+
+       if (gw->walk_bytes >= minbytesneeded) {
+               gw->ptr = gw->walk_ptr;
+               gw->nbytes = gw->walk_bytes;
+               goto out;
+       }
+
+       scatterwalk_unmap(&gw->walk);
+       gw->walk_ptr = NULL;
+
+       gw->ptr = gw->buf;
+       gw->nbytes = sizeof(gw->buf);
+
+out:
+       return gw->nbytes;
+}
+
+static int gcm_in_walk_done(struct gcm_sg_walk *gw, unsigned int bytesdone)
+{
        if (gw->ptr == NULL)
-               return;
+               return 0;
 
        if (gw->ptr == gw->buf) {
-               n = gw->buf_bytes - bytesdone;
+               int n = gw->buf_bytes - bytesdone;
                if (n > 0) {
                        memmove(gw->buf, gw->buf + bytesdone, n);
-                       gw->buf_bytes -= n;
+                       gw->buf_bytes = n;
                } else
                        gw->buf_bytes = 0;
-       } else {
-               gw->walk_bytes_remain -= bytesdone;
-               scatterwalk_unmap(&gw->walk);
-               scatterwalk_advance(&gw->walk, bytesdone);
-               scatterwalk_done(&gw->walk, 0, gw->walk_bytes_remain);
-       }
+       } else
+               _gcm_sg_unmap_and_advance(gw, bytesdone);
+
+       return bytesdone;
+}
+
+static int gcm_out_walk_done(struct gcm_sg_walk *gw, unsigned int bytesdone)
+{
+       int i, n;
+
+       if (gw->ptr == NULL)
+               return 0;
+
+       if (gw->ptr == gw->buf) {
+               for (i = 0; i < bytesdone; i += n) {
+                       if (!_gcm_sg_clamp_and_map(gw))
+                               return i;
+                       n = min(gw->walk_bytes, bytesdone - i);
+                       memcpy(gw->walk_ptr, gw->buf + i, n);
+                       _gcm_sg_unmap_and_advance(gw, n);
+               }
+       } else
+               _gcm_sg_unmap_and_advance(gw, bytesdone);
+
+       return bytesdone;
 }
 
 static int gcm_aes_crypt(struct aead_request *req, unsigned int flags)
@@ -926,7 +989,7 @@ static int gcm_aes_crypt(struct aead_request *req, unsigned int flags)
        unsigned int pclen = req->cryptlen;
        int ret = 0;
 
-       unsigned int len, in_bytes, out_bytes,
+       unsigned int n, len, in_bytes, out_bytes,
                     min_bytes, bytes, aad_bytes, pc_bytes;
        struct gcm_sg_walk gw_in, gw_out;
        u8 tag[GHASH_DIGEST_SIZE];
@@ -963,14 +1026,14 @@ static int gcm_aes_crypt(struct aead_request *req, unsigned int flags)
        *(u32 *)(param.j0 + ivsize) = 1;
        memcpy(param.k, ctx->key, ctx->key_len);
 
-       gcm_sg_walk_start(&gw_in, req->src, len);
-       gcm_sg_walk_start(&gw_out, req->dst, len);
+       gcm_walk_start(&gw_in, req->src, len);
+       gcm_walk_start(&gw_out, req->dst, len);
 
        do {
                min_bytes = min_t(unsigned int,
                                  aadlen > 0 ? aadlen : pclen, AES_BLOCK_SIZE);
-               in_bytes = gcm_sg_walk_go(&gw_in, min_bytes);
-               out_bytes = gcm_sg_walk_go(&gw_out, min_bytes);
+               in_bytes = gcm_in_walk_go(&gw_in, min_bytes);
+               out_bytes = gcm_out_walk_go(&gw_out, min_bytes);
                bytes = min(in_bytes, out_bytes);
 
                if (aadlen + pclen <= bytes) {
@@ -997,8 +1060,11 @@ static int gcm_aes_crypt(struct aead_request *req, unsigned int flags)
                          gw_in.ptr + aad_bytes, pc_bytes,
                          gw_in.ptr, aad_bytes);
 
-               gcm_sg_walk_done(&gw_in, aad_bytes + pc_bytes);
-               gcm_sg_walk_done(&gw_out, aad_bytes + pc_bytes);
+               n = aad_bytes + pc_bytes;
+               if (gcm_in_walk_done(&gw_in, n) != n)
+                       return -ENOMEM;
+               if (gcm_out_walk_done(&gw_out, n) != n)
+                       return -ENOMEM;
                aadlen -= aad_bytes;
                pclen -= pc_bytes;
        } while (aadlen + pclen > 0);