Use Memory_safeCalloc() to allocate zeroed memory.
[umurmur.git] / src / messages.c
index 109f5eb22af6c8844262678e5519e676ef78f73f..d99548563c042cb04e417930ece49d2a6cc9db59 100644 (file)
@@ -1,5 +1,5 @@
-/* Copyright (C) 2009-2010, Martin Johansson <martin@fatbob.nu>
-   Copyright (C) 2005-2010, Thorvald Natvig <thorvald@natvig.com>
+/* Copyright (C) 2009-2014, Martin Johansson <martin@fatbob.nu>
+   Copyright (C) 2005-2014, Thorvald Natvig <thorvald@natvig.com>
 
    All rights reserved.
 
 #include "client.h"
 #include "pds.h"
 #include "log.h"
+#include "memory.h"
 
 #define PREAMBLE_SIZE 6
 
-static void dumpmsg(uint8_t *data, int size);
 static message_t *Msg_create_nopayload(messageType_t messageType);
 
 static void Msg_addPreamble(uint8_t *buffer, uint16_t type, uint32_t len)
 {
-       type = htons(type);
-       len = htonl(len);
-       
-       buffer[0] = (type) & 0xff;
-       buffer[1] = (type >> 8) & 0xff;
-       
-       buffer[2] = (len) & 0xff;
-       buffer[3] = (len >> 8) & 0xff;
-       buffer[4] = (len >> 16) & 0xff;
-       buffer[5] = (len >> 24) & 0xff; 
+       buffer[1] = (type) & 0xff;
+       buffer[0] = (type >> 8) & 0xff;
+
+       buffer[5] = (len) & 0xff;
+       buffer[4] = (len >> 8) & 0xff;
+       buffer[3] = (len >> 16) & 0xff;
+       buffer[2] = (len >> 24) & 0xff;
 }
 
 static void Msg_getPreamble(uint8_t *buffer, int *type, int *len)
 {
        uint16_t msgType;
        uint32_t msgLen;
-       
-       msgType = buffer[0] | (buffer[1] << 8);
-       msgLen = buffer[2] | (buffer[3] << 8) | (buffer[4] << 16) | (buffer[5] << 24);
-       *type = (int)ntohs(msgType);
-       *len = (int)ntohl(msgLen);
+
+       msgType = buffer[1] | (buffer[0] << 8);
+       msgLen = buffer[5] | (buffer[4] << 8) | (buffer[3] << 16) | (buffer[2] << 24);
+       *type = (int)msgType;
+       *len = (int)msgLen;
 }
 
 #define MAX_MSGSIZE (BUFSIZE - PREAMBLE_SIZE)
@@ -74,7 +71,7 @@ int Msg_messageToNetwork(message_t *msg, uint8_t *buffer)
 {
        int len;
        uint8_t *bufptr = buffer + PREAMBLE_SIZE;
-               
+
        Log_debug("To net: msg type %d", msg->messageType);
        switch (msg->messageType) {
        case Version:
@@ -230,7 +227,37 @@ int Msg_messageToNetwork(message_t *msg, uint8_t *buffer)
                Msg_addPreamble(buffer, msg->messageType, len);
                mumble_proto__channel_remove__pack(msg->payload.channelRemove, bufptr);
                break;
+       case UserStats:
+       {
+               len = mumble_proto__user_stats__get_packed_size(msg->payload.userStats);
+               if (len > MAX_MSGSIZE) {
+                       Log_warn("Too big tx message. Discarding");
+                       break;
+                       }
+               Msg_addPreamble(buffer, msg->messageType, len);
+               mumble_proto__user_stats__pack(msg->payload.userStats, bufptr);
+               break;
+       }
+       case ServerConfig:
+               len = mumble_proto__server_config__get_packed_size(msg->payload.serverConfig);
+               if (len > MAX_MSGSIZE) {
+                       Log_warn("Too big tx message. Discarding");
+                       break;
+                       }
+               Msg_addPreamble(buffer, msg->messageType, len);
+               mumble_proto__server_config__pack(msg->payload.serverConfig, bufptr);
+               break;
 
+       case BanList:
+               len = mumble_proto__ban_list__get_packed_size(msg->payload.banList);
+               if (len > MAX_MSGSIZE) {
+                       Log_warn("Too big tx message. Discarding");
+                       break;
+                       }
+               Msg_addPreamble(buffer, msg->messageType, len);
+               Log_debug("Msg_MessageToNetwork: BanList size %d", len);
+               mumble_proto__ban_list__pack(msg->payload.banList, bufptr);
+               break;
        default:
                Log_warn("Msg_MessageToNetwork: Unsupported message %d", msg->messageType);
                return 0;
@@ -240,11 +267,8 @@ int Msg_messageToNetwork(message_t *msg, uint8_t *buffer)
 
 static message_t *Msg_create_nopayload(messageType_t messageType)
 {
-       message_t *msg = malloc(sizeof(message_t));
+       message_t *msg = Memory_safeCalloc(1, sizeof(message_t));
 
-       if (msg == NULL)
-               Log_fatal("Out of memory");
-       memset(msg, 0, sizeof(message_t));
        msg->refcount = 1;
        msg->messageType = messageType;
        init_list_entry(&msg->node);
@@ -254,76 +278,95 @@ static message_t *Msg_create_nopayload(messageType_t messageType)
 message_t *Msg_create(messageType_t messageType)
 {
        message_t *msg = Msg_create_nopayload(messageType);
-       
+       int i;
+
        switch (messageType) {
        case Version:
-               msg->payload.version = malloc(sizeof(MumbleProto__Version));
+               msg->payload.version = Memory_safeMalloc(1, sizeof(MumbleProto__Version));
                mumble_proto__version__init(msg->payload.version);
                break;
        case UDPTunnel:
-               msg->payload.UDPTunnel = malloc(sizeof(MumbleProto__UDPTunnel));
+               msg->payload.UDPTunnel = Memory_safeMalloc(1, sizeof(MumbleProto__UDPTunnel));
                mumble_proto__udptunnel__init(msg->payload.UDPTunnel);
                break;
        case Authenticate:
-               msg->payload.authenticate = malloc(sizeof(MumbleProto__Authenticate));
+               msg->payload.authenticate = Memory_safeMalloc(1, sizeof(MumbleProto__Authenticate));
                mumble_proto__authenticate__init(msg->payload.authenticate);
                break;
        case Ping:
-               msg->payload.ping = malloc(sizeof(MumbleProto__Ping));
+               msg->payload.ping = Memory_safeMalloc(1, sizeof(MumbleProto__Ping));
                mumble_proto__ping__init(msg->payload.ping);
                break;
        case Reject:
-               msg->payload.reject = malloc(sizeof(MumbleProto__Reject));
+               msg->payload.reject = Memory_safeMalloc(1, sizeof(MumbleProto__Reject));
                mumble_proto__reject__init(msg->payload.reject);
                break;
        case ServerSync:
-               msg->payload.serverSync = malloc(sizeof(MumbleProto__ServerSync));
+               msg->payload.serverSync = Memory_safeMalloc(1, sizeof(MumbleProto__ServerSync));
                mumble_proto__server_sync__init(msg->payload.serverSync);
                break;
        case TextMessage:
-               msg->payload.textMessage = malloc(sizeof(MumbleProto__TextMessage));
+               msg->payload.textMessage = Memory_safeMalloc(1, sizeof(MumbleProto__TextMessage));
                mumble_proto__text_message__init(msg->payload.textMessage);
                break;
        case PermissionDenied:
-               msg->payload.permissionDenied = malloc(sizeof(MumbleProto__PermissionDenied));
+               msg->payload.permissionDenied = Memory_safeMalloc(1, sizeof(MumbleProto__PermissionDenied));
                mumble_proto__permission_denied__init(msg->payload.permissionDenied);
                break;
        case CryptSetup:
-               msg->payload.cryptSetup = malloc(sizeof(MumbleProto__CryptSetup));
+               msg->payload.cryptSetup = Memory_safeMalloc(1, sizeof(MumbleProto__CryptSetup));
                mumble_proto__crypt_setup__init(msg->payload.cryptSetup);
                break;
        case UserList:
-               msg->payload.userList = malloc(sizeof(MumbleProto__UserList));
+               msg->payload.userList = Memory_safeMalloc(1, sizeof(MumbleProto__UserList));
                mumble_proto__user_list__init(msg->payload.userList);
                break;
        case UserState:
-               msg->payload.userState = malloc(sizeof(MumbleProto__UserState));
+               msg->payload.userState = Memory_safeMalloc(1, sizeof(MumbleProto__UserState));
                mumble_proto__user_state__init(msg->payload.userState);
                break;
        case ChannelState:
-               msg->payload.channelState = malloc(sizeof(MumbleProto__ChannelState));
+               msg->payload.channelState = Memory_safeMalloc(1, sizeof(MumbleProto__ChannelState));
                mumble_proto__channel_state__init(msg->payload.channelState);
                break;
        case UserRemove:
-               msg->payload.userRemove = malloc(sizeof(MumbleProto__UserRemove));
+               msg->payload.userRemove = Memory_safeMalloc(1, sizeof(MumbleProto__UserRemove));
                mumble_proto__user_remove__init(msg->payload.userRemove);
                break;
        case VoiceTarget:
-               msg->payload.voiceTarget = malloc(sizeof(MumbleProto__VoiceTarget));
+               msg->payload.voiceTarget = Memory_safeMalloc(1, sizeof(MumbleProto__VoiceTarget));
                mumble_proto__voice_target__init(msg->payload.voiceTarget);
                break;
        case CodecVersion:
-               msg->payload.codecVersion = malloc(sizeof(MumbleProto__CodecVersion));
+               msg->payload.codecVersion = Memory_safeMalloc(1, sizeof(MumbleProto__CodecVersion));
                mumble_proto__codec_version__init(msg->payload.codecVersion);
                break;
        case PermissionQuery:
-               msg->payload.permissionQuery = malloc(sizeof(MumbleProto__PermissionQuery));
+               msg->payload.permissionQuery = Memory_safeMalloc(1, sizeof(MumbleProto__PermissionQuery));
                mumble_proto__permission_query__init(msg->payload.permissionQuery);
                break;
        case ChannelRemove:
-               msg->payload.channelRemove = malloc(sizeof(MumbleProto__ChannelRemove));
+               msg->payload.channelRemove = Memory_safeMalloc(1, sizeof(MumbleProto__ChannelRemove));
                mumble_proto__channel_remove__init(msg->payload.channelRemove);
                break;
+       case UserStats:
+               msg->payload.userStats = Memory_safeMalloc(1, sizeof(MumbleProto__UserStats));
+               mumble_proto__user_stats__init(msg->payload.userStats);
+
+               msg->payload.userStats->from_client = Memory_safeMalloc(1, sizeof(MumbleProto__UserStats__Stats));
+               mumble_proto__user_stats__stats__init(msg->payload.userStats->from_client);
+
+               msg->payload.userStats->from_server = Memory_safeMalloc(1, sizeof(MumbleProto__UserStats__Stats));
+               mumble_proto__user_stats__stats__init(msg->payload.userStats->from_server);
+
+               msg->payload.userStats->version = Memory_safeMalloc(1, sizeof(MumbleProto__Version));
+               mumble_proto__version__init(msg->payload.userStats->version);
+
+               break;
+       case ServerConfig:
+               msg->payload.serverConfig = Memory_safeMalloc(1, sizeof(MumbleProto__ServerConfig));
+               mumble_proto__server_config__init(msg->payload.serverConfig);
+               break;
 
        default:
                Log_warn("Msg_create: Unsupported message %d", msg->messageType);
@@ -333,6 +376,64 @@ message_t *Msg_create(messageType_t messageType)
        return msg;
 }
 
+message_t *Msg_banList_create(int n_bans)
+{
+       message_t *msg = Msg_create_nopayload(BanList);
+       int i;
+
+       msg->payload.banList = Memory_safeCalloc(1, sizeof(MumbleProto__BanList));
+       mumble_proto__ban_list__init(msg->payload.banList);
+       msg->payload.banList->n_bans = n_bans;
+       msg->payload.banList->bans = Memory_safeMalloc(1, sizeof(MumbleProto__BanList__BanEntry *) * n_bans);
+       for (i = 0; i < n_bans; i++) {
+               msg->payload.banList->bans[i] = Memory_safeCalloc(1, sizeof(MumbleProto__BanList__BanEntry));
+               mumble_proto__ban_list__ban_entry__init(msg->payload.banList->bans[i]);
+       }
+       return msg;
+}
+
+void Msg_banList_addEntry(message_t *msg, int index, uint8_t *address, uint32_t mask,
+                          char *name, char *hash, char *reason, char *start, uint32_t duration)
+{
+       MumbleProto__BanList__BanEntry *entry = msg->payload.banList->bans[index];
+
+       entry->address.data = Memory_safeMalloc(1, 16);
+       memcpy(entry->address.data, address, 16);
+       entry->address.len = 16;
+       entry->mask = mask;
+       entry->name = strdup(name);
+       entry->hash = strdup(hash);
+       entry->reason = strdup(reason);
+       entry->start = strdup(start);
+       if (!entry->name || !entry->hash || !entry->reason || !entry->start)
+               Log_fatal("Out of memory");
+
+       if (duration > 0) {
+               entry->duration = duration;
+               entry->has_duration = true;
+       }
+       Log_debug("Msg_banList_addEntry: %s %s %s %s %s",
+               entry->name, entry->hash, entry->address.data, entry->reason, entry->start);
+}
+
+void Msg_banList_getEntry(message_t *msg, int index, uint8_t **address, uint32_t *mask,
+                          char **name, char **hash, char **reason, char **start, uint32_t *duration)
+{
+       MumbleProto__BanList__BanEntry *entry = msg->payload.banList->bans[index];
+
+       *address =  entry->address.data;
+       *mask = entry->mask;
+       *name = entry->name;
+       *hash = entry->hash;
+       *reason = entry->reason;
+       *start = entry->start;
+       if (entry->has_duration)
+               *duration = entry->duration;
+       else
+               *duration = 0;
+}
+
+
 void Msg_inc_ref(message_t *msg)
 {
        msg->refcount++;
@@ -340,6 +441,8 @@ void Msg_inc_ref(message_t *msg)
 
 void Msg_free(message_t *msg)
 {
+       int i;
+
        if (msg->refcount) msg->refcount--;
        if (msg->refcount > 0)
                return;
@@ -399,6 +502,14 @@ void Msg_free(message_t *msg)
                if (msg->unpacked)
                        mumble_proto__text_message__free_unpacked(msg->payload.textMessage, NULL);
                else {
+                       if (msg->payload.textMessage->message)
+                               free(msg->payload.textMessage->message);
+                       if (msg->payload.textMessage->session)
+                               free(msg->payload.textMessage->session);
+                       if (msg->payload.textMessage->channel_id)
+                               free(msg->payload.textMessage->channel_id);
+                       if (msg->payload.textMessage->tree_id)
+                               free(msg->payload.textMessage->tree_id);
                        free(msg->payload.textMessage);
                }
                break;
@@ -436,8 +547,12 @@ void Msg_free(message_t *msg)
                if (msg->unpacked)
                        mumble_proto__channel_state__free_unpacked(msg->payload.channelState, NULL);
                else {
-                       free(msg->payload.channelState->name);
-                       free(msg->payload.channelState->description);
+                       if (msg->payload.channelState->name)
+                               free(msg->payload.channelState->name);
+                       if (msg->payload.channelState->description)
+                               free(msg->payload.channelState->description);
+                       if (msg->payload.channelState->links)
+                               free(msg->payload.channelState->links);
                        free(msg->payload.channelState);
                }
                break;
@@ -476,6 +591,60 @@ void Msg_free(message_t *msg)
                        free(msg->payload.channelRemove);
                }
                break;
+       case UserStats:
+               if (msg->unpacked)
+                       mumble_proto__user_stats__free_unpacked(msg->payload.userStats, NULL);
+               else {
+                       if (msg->payload.userStats->from_client)
+                               free(msg->payload.userStats->from_client);
+                       if (msg->payload.userStats->from_server)
+                               free(msg->payload.userStats->from_server);
+                       if (msg->payload.userStats->version) {
+                               if (msg->payload.userStats->version->release)
+                                       free(msg->payload.userStats->version->release);
+                               if (msg->payload.userStats->version->os)
+                                       free(msg->payload.userStats->version->os);
+                               if (msg->payload.userStats->version->os_version)
+                                       free(msg->payload.userStats->version->os_version);
+
+                               free(msg->payload.userStats->version);
+                       }
+                       if (msg->payload.userStats->celt_versions)
+                               free(msg->payload.userStats->celt_versions);
+                       if (msg->payload.userStats->certificates) {
+                               if (msg->payload.userStats->certificates->data)
+                                       free(msg->payload.userStats->certificates->data);
+                               free(msg->payload.userStats->certificates);
+                       }
+                       if (msg->payload.userStats->address.data)
+                               free(msg->payload.userStats->address.data);
+
+                       free(msg->payload.userStats);
+               }
+               break;
+       case ServerConfig:
+               if (msg->unpacked)
+                       mumble_proto__server_config__free_unpacked(msg->payload.serverConfig, NULL);
+               else {
+                       free(msg->payload.serverConfig);
+               }
+               break;
+       case BanList:
+               if (msg->unpacked)
+                       mumble_proto__ban_list__free_unpacked(msg->payload.banList, NULL);
+               else {
+                       for (i = 0; i < msg->payload.banList->n_bans; i++) {
+                               free(msg->payload.banList->bans[i]->address.data);
+                               free(msg->payload.banList->bans[i]->name);
+                               free(msg->payload.banList->bans[i]->hash);
+                               free(msg->payload.banList->bans[i]->reason);
+                               free(msg->payload.banList->bans[i]->start);
+                               free(msg->payload.banList->bans[i]);
+                       }
+                       free(msg->payload.banList->bans);
+                       free(msg->payload.banList);
+               }
+               break;
 
        default:
                Log_warn("Msg_free: Unsupported message %d", msg->messageType);
@@ -484,35 +653,14 @@ void Msg_free(message_t *msg)
        free(msg);
 }
 
-void dumpmsg(uint8_t *data, int size)
-{
-       int i, r = 0, offset = 0;
-       char buf[512];
-       
-       while (r * 8 + i < size) {
-               for (i = 0; i < 8 && r * 8 + i < size; i++) {
-                       offset += sprintf(buf + offset, "%x ", data[r * 8 + i]);
-               }
-               sprintf(buf + offset, "\n");
-               printf(buf);
-               offset = 0;
-               r++;
-               i = 0;
-       } 
-}
-
 message_t *Msg_CreateVoiceMsg(uint8_t *data, int size)
 {
        message_t *msg = NULL;
-       
+
        msg = Msg_create_nopayload(UDPTunnel);
        msg->unpacked = false;
-       msg->payload.UDPTunnel = malloc(sizeof(struct _MumbleProto__UDPTunnel));
-       if (msg->payload.UDPTunnel == NULL)
-               Log_fatal("Out of memory");
-       msg->payload.UDPTunnel->packet.data = malloc(size);
-       if (msg->payload.UDPTunnel->packet.data == NULL)
-               Log_fatal("Out of memory");
+       msg->payload.UDPTunnel = Memory_safeMalloc(1, sizeof(struct _MumbleProto__UDPTunnel));
+       msg->payload.UDPTunnel->packet.data = Memory_safeMalloc(1, size);
        memcpy(msg->payload.UDPTunnel->packet.data, data, size);
        msg->payload.UDPTunnel->packet.len = size;
        return msg;
@@ -528,7 +676,7 @@ message_t *Msg_networkToMessage(uint8_t *data, int size)
 
        Log_debug("Message type %d size %d", messageType, msgLen);
        //dumpmsg(data, size);
-       
+
        switch (messageType) {
        case Version:
        {
@@ -670,13 +818,40 @@ message_t *Msg_networkToMessage(uint8_t *data, int size)
                        goto err_out;
                break;
        }
+       case UserStats:
+       {
+               msg = Msg_create_nopayload(UserStats);
+               msg->unpacked = true;
+               msg->payload.userStats = mumble_proto__user_stats__unpack(NULL, msgLen, msgData);
+               if (msg->payload.userStats == NULL)
+                       goto err_out;
+               break;
+       }
+       case UserRemove:
+       {
+               msg = Msg_create_nopayload(UserRemove);
+               msg->unpacked = true;
+               msg->payload.userRemove = mumble_proto__user_remove__unpack(NULL, msgLen, msgData);
+               if (msg->payload.userRemove == NULL)
+                       goto err_out;
+               break;
+       }
+       case BanList:
+       {
+               msg = Msg_create_nopayload(BanList);
+               msg->unpacked = true;
+               msg->payload.banList = mumble_proto__ban_list__unpack(NULL, msgLen, msgData);
+               if (msg->payload.banList == NULL)
+                       goto err_out;
+               break;
+       }
 
        default:
-               Log_warn("Unsupported message %d", messageType);
+               Log_warn("Msg_networkToMessage: Unsupported message %d", messageType);
                break;
        }
        return msg;
-       
+
 err_out:
        free(msg);
        return NULL;