#include <net/scm.h>
#define Nprintk(a...)
+#define NLGRPSZ(x) (ALIGN(x, sizeof(unsigned long) * 8) / 8)
struct netlink_sock {
/* struct sock has to be the first member of netlink_sock */
struct sock sk;
u32 pid;
- unsigned int groups;
u32 dst_pid;
u32 dst_group;
+ u32 flags;
+ u32 subscriptions;
+ u32 ngroups;
+ unsigned long *groups;
unsigned long state;
wait_queue_head_t wait;
struct netlink_callback *cb;
spinlock_t cb_lock;
void (*data_ready)(struct sock *sk, int bytes);
struct module *module;
- u32 flags;
};
#define NETLINK_KERNEL_SOCKET 0x1
struct nl_pid_hash hash;
struct hlist_head mc_list;
unsigned int nl_nonroot;
+ unsigned int groups;
struct module *module;
int registered;
};
BUG_TRAP(!atomic_read(&sk->sk_rmem_alloc));
BUG_TRAP(!atomic_read(&sk->sk_wmem_alloc));
BUG_TRAP(!nlk_sk(sk)->cb);
+ BUG_TRAP(!nlk_sk(sk)->groups);
}
/* This lock without WQ_FLAG_EXCLUSIVE is good on UP and it is _very_ bad on SMP.
netlink_table_grab();
if (sk_del_node_init(sk))
nl_table[sk->sk_protocol].hash.entries--;
- if (nlk_sk(sk)->groups)
+ if (nlk_sk(sk)->subscriptions)
__sk_del_bind_node(sk);
netlink_table_ungrab();
}
static int netlink_create(struct socket *sock, int protocol)
{
struct module *module = NULL;
+ struct netlink_sock *nlk;
+ unsigned int groups;
int err = 0;
sock->state = SS_UNCONNECTED;
module = nl_table[protocol].module;
else
err = -EPROTONOSUPPORT;
+ groups = nl_table[protocol].groups;
netlink_unlock_table();
- if (err)
- goto out;
+ if (err || (err = __netlink_create(sock, protocol) < 0))
+ goto out_module;
+
+ nlk = nlk_sk(sock->sk);
- if ((err = __netlink_create(sock, protocol) < 0))
+ nlk->groups = kmalloc(NLGRPSZ(groups), GFP_KERNEL);
+ if (nlk->groups == NULL) {
+ err = -ENOMEM;
goto out_module;
+ }
+ memset(nlk->groups, 0, NLGRPSZ(groups));
+ nlk->ngroups = groups;
- nlk_sk(sock->sk)->module = module;
+ nlk->module = module;
out:
return err;
skb_queue_purge(&sk->sk_write_queue);
- if (nlk->pid && !nlk->groups) {
+ if (nlk->pid && !nlk->subscriptions) {
struct netlink_notify n = {
.protocol = sk->sk_protocol,
.pid = nlk->pid,
netlink_table_ungrab();
}
+ kfree(nlk->groups);
+ nlk->groups = NULL;
+
sock_put(sk);
return 0;
}
capable(CAP_NET_ADMIN);
}
+static void
+netlink_update_subscriptions(struct sock *sk, unsigned int subscriptions)
+{
+ struct netlink_sock *nlk = nlk_sk(sk);
+
+ if (nlk->subscriptions && !subscriptions)
+ __sk_del_bind_node(sk);
+ else if (!nlk->subscriptions && subscriptions)
+ sk_add_bind_node(sk, &nl_table[sk->sk_protocol].mc_list);
+ nlk->subscriptions = subscriptions;
+}
+
static int netlink_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
{
struct sock *sk = sock->sk;
return err;
}
- if (!nladdr->nl_groups && !nlk->groups)
+ if (!nladdr->nl_groups && !(u32)nlk->groups[0])
return 0;
netlink_table_grab();
- if (nlk->groups && !nladdr->nl_groups)
- __sk_del_bind_node(sk);
- else if (!nlk->groups && nladdr->nl_groups)
- sk_add_bind_node(sk, &nl_table[sk->sk_protocol].mc_list);
- nlk->groups = nladdr->nl_groups;
+ netlink_update_subscriptions(sk, nlk->subscriptions +
+ hweight32(nladdr->nl_groups) -
+ hweight32(nlk->groups[0]));
+ nlk->groups[0] = (nlk->groups[0] & ~0xffffffffUL) | nladdr->nl_groups;
netlink_table_ungrab();
return 0;
nladdr->nl_groups = netlink_group_mask(nlk->dst_group);
} else {
nladdr->nl_pid = nlk->pid;
- nladdr->nl_groups = nlk->groups;
+ nladdr->nl_groups = nlk->groups[0];
}
return 0;
}
if (p->exclude_sk == sk)
goto out;
- if (nlk->pid == p->pid || !(nlk->groups & netlink_group_mask(p->group)))
+ if (nlk->pid == p->pid || p->group - 1 >= nlk->ngroups ||
+ !test_bit(p->group - 1, nlk->groups))
goto out;
if (p->failure) {
if (sk == p->exclude_sk)
goto out;
- if (nlk->pid == p->pid || !(nlk->groups & netlink_group_mask(p->group)))
+ if (nlk->pid == p->pid || p->group - 1 >= nlk->ngroups ||
+ !test_bit(p->group - 1, nlk->groups))
goto out;
sk->sk_err = p->code;
nlk->flags |= NETLINK_KERNEL_SOCKET;
netlink_table_grab();
+ nl_table[unit].groups = 32;
nl_table[unit].module = module;
nl_table[unit].registered = 1;
netlink_table_ungrab();
s,
s->sk_protocol,
nlk->pid,
- nlk->groups,
+ nlk->flags & NETLINK_KERNEL_SOCKET ?
+ 0 : (unsigned int)nlk->groups[0],
atomic_read(&s->sk_rmem_alloc),
atomic_read(&s->sk_wmem_alloc),
nlk->cb,