pthread_rwlock_unlock(&server.lock);
}
-static int get_fd(const char *mad, int *fd, __be64 *gid_ifid)
+static int get_fd(const char *mad, int umad_len, int *fd, __be64 *gid_ifid)
{
struct umad_hdr *hdr = (struct umad_hdr *)mad;
char *data = (char *)hdr + sizeof(*hdr);
uint16_t attr_id = be16toh(hdr->attr_id);
int rc = 0;
+ if (umad_len <= sizeof(*hdr)) {
+ rc = -EINVAL;
+ syslog(LOG_DEBUG, "Ignoring MAD packets with header only\n");
+ goto out;
+ }
+
switch (attr_id) {
case UMAD_CM_ATTR_REQ:
+ if (unlikely(umad_len < sizeof(*hdr) + CM_REQ_DGID_POS +
+ sizeof(*gid_ifid))) {
+ rc = -EINVAL;
+ syslog(LOG_WARNING,
+ "Invalid MAD packet size (%d) for attr_id 0x%x\n", umad_len,
+ attr_id);
+ goto out;
+ }
memcpy(gid_ifid, data + CM_REQ_DGID_POS, sizeof(*gid_ifid));
rc = hash_tbl_search_fd_by_ifid(fd, gid_ifid);
break;
case UMAD_CM_ATTR_SIDR_REQ:
+ if (unlikely(umad_len < sizeof(*hdr) + CM_SIDR_REQ_DGID_POS +
+ sizeof(*gid_ifid))) {
+ rc = -EINVAL;
+ syslog(LOG_WARNING,
+ "Invalid MAD packet size (%d) for attr_id 0x%x\n", umad_len,
+ attr_id);
+ goto out;
+ }
memcpy(gid_ifid, data + CM_SIDR_REQ_DGID_POS, sizeof(*gid_ifid));
rc = hash_tbl_search_fd_by_ifid(fd, gid_ifid);
break;
data += sizeof(comm_id);
/* Fall through */
case UMAD_CM_ATTR_SIDR_REP:
+ if (unlikely(umad_len < sizeof(*hdr) + sizeof(comm_id))) {
+ rc = -EINVAL;
+ syslog(LOG_WARNING,
+ "Invalid MAD packet size (%d) for attr_id 0x%x\n", umad_len,
+ attr_id);
+ goto out;
+ }
memcpy(&comm_id, data, sizeof(comm_id));
if (comm_id) {
rc = hash_tbl_search_fd_by_comm_id(comm_id, fd, gid_ifid);
syslog(LOG_DEBUG, "mad_to_vm: %d 0x%x 0x%x\n", *fd, attr_id, comm_id);
+out:
return rc;
}
} while (rc && server.run);
if (server.run) {
- rc = get_fd(msg.umad.mad, &fd, &msg.hdr.sgid.global.interface_id);
+ rc = get_fd(msg.umad.mad, msg.umad_len, &fd,
+ &msg.hdr.sgid.global.interface_id);
if (rc) {
continue;
}