]> git.proxmox.com Git - mirror_ubuntu-hirsute-kernel.git/blobdiff - arch/x86/net/bpf_jit_comp.c
UBUNTU: SAUCE: bpf, x86: Validate computation of branch displacements for x86-64
[mirror_ubuntu-hirsute-kernel.git] / arch / x86 / net / bpf_jit_comp.c
index 796506dcfc42e86484a8eb27b69ce8a14c2776c0..280357cc9d03a8607e5d69177bbd283966cd8c80 100644 (file)
@@ -1476,7 +1476,16 @@ emit_jmp:
                }
 
                if (image) {
-                       if (unlikely(proglen + ilen > oldproglen)) {
+                       /*
+                        * When populating the image, assert that:
+                        *
+                        *  i) We do not write beyond the allocated space, and
+                        * ii) addrs[i] did not change from the prior run, in order
+                        *     to validate assumptions made for computing branch
+                        *     displacements.
+                        */
+                       if (unlikely(proglen + ilen > oldproglen ||
+                                    proglen + ilen != addrs[i])) {
                                pr_err("bpf_jit: fatal error\n");
                                return -EFAULT;
                        }
@@ -1735,7 +1744,7 @@ static int invoke_bpf_mod_ret(const struct btf_func_model *m, u8 **pprog,
  * add rsp, 8                      // skip eth_type_trans's frame
  * ret                             // return to its caller
  */
-int arch_prepare_bpf_trampoline(void *image, void *image_end,
+int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *image_end,
                                const struct btf_func_model *m, u32 flags,
                                struct bpf_tramp_progs *tprogs,
                                void *orig_call)
@@ -1774,6 +1783,15 @@ int arch_prepare_bpf_trampoline(void *image, void *image_end,
 
        save_regs(m, &prog, nr_args, stack_size);
 
+       if (flags & BPF_TRAMP_F_CALL_ORIG) {
+               /* arg1: mov rdi, im */
+               emit_mov_imm64(&prog, BPF_REG_1, (long) im >> 32, (u32) (long) im);
+               if (emit_call(&prog, __bpf_tramp_enter, prog)) {
+                       ret = -EINVAL;
+                       goto cleanup;
+               }
+       }
+
        if (fentry->nr_progs)
                if (invoke_bpf(m, &prog, fentry, stack_size))
                        return -EINVAL;
@@ -1792,8 +1810,7 @@ int arch_prepare_bpf_trampoline(void *image, void *image_end,
        }
 
        if (flags & BPF_TRAMP_F_CALL_ORIG) {
-               if (fentry->nr_progs || fmod_ret->nr_progs)
-                       restore_regs(m, &prog, nr_args, stack_size);
+               restore_regs(m, &prog, nr_args, stack_size);
 
                /* call original function */
                if (emit_call(&prog, orig_call, prog)) {
@@ -1802,6 +1819,9 @@ int arch_prepare_bpf_trampoline(void *image, void *image_end,
                }
                /* remember return value in a stack for bpf prog to access */
                emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -8);
+               im->ip_after_call = prog;
+               memcpy(prog, ideal_nops[NOP_ATOMIC5], X86_PATCH_SIZE);
+               prog += X86_PATCH_SIZE;
        }
 
        if (fmod_ret->nr_progs) {
@@ -1832,9 +1852,17 @@ int arch_prepare_bpf_trampoline(void *image, void *image_end,
         * the return value is only updated on the stack and still needs to be
         * restored to R0.
         */
-       if (flags & BPF_TRAMP_F_CALL_ORIG)
+       if (flags & BPF_TRAMP_F_CALL_ORIG) {
+               im->ip_epilogue = prog;
+               /* arg1: mov rdi, im */
+               emit_mov_imm64(&prog, BPF_REG_1, (long) im >> 32, (u32) (long) im);
+               if (emit_call(&prog, __bpf_tramp_exit, prog)) {
+                       ret = -EINVAL;
+                       goto cleanup;
+               }
                /* restore original return value back into RAX */
                emit_ldx(&prog, BPF_DW, BPF_REG_0, BPF_REG_FP, -8);
+       }
 
        EMIT1(0x5B); /* pop rbx */
        EMIT1(0xC9); /* leave */