Add the Client_find_by_session() function.
[umurmur.git] / src / client.c
index b3b12b75030d3f76821bcaaeb2ef1995156d41e8..b69da172cc51324bf87701bbe698d035fa321b55 100644 (file)
@@ -36,6 +36,7 @@
 #include <stdlib.h>
 #include <string.h>
 #include "log.h"
+#include "memory.h"
 #include "list.h"
 #include "client.h"
 #include "ssl.h"
@@ -53,6 +54,7 @@ extern char system_string[], version_string[];
 static int Client_read(client_t *client);
 static int Client_write(client_t *client);
 static int Client_send_udp(client_t *client, uint8_t *data, int len);
+static client_t *Client_find_by_fd(int fd);
 void Client_free(client_t *client);
 
 declare_list(clients);
@@ -115,9 +117,7 @@ void Client_janitor()
 
 void Client_codec_add(client_t *client, int codec)
 {
-       codec_t *cd = malloc(sizeof(codec_t));
-       if (cd == NULL)
-               Log_fatal("Out of memory");
+       codec_t *cd = Memory_safeMalloc(1, sizeof(codec_t));
        init_list_entry(&cd->node);
        cd->codec = codec;
        list_add_tail(&cd->node, &client->codecs);
@@ -157,9 +157,7 @@ void Client_token_add(client_t *client, char *token_string)
 
        if (client->tokencount >= MAX_TOKENS)
                return;
-       token = malloc(sizeof(token_t));
-       if (token == NULL)
-               Log_fatal("Out of memory");
+       token = Memory_safeMalloc(1, sizeof(token_t));
        init_list_entry(&token->node);
        token->token = strdup(token_string);
        if (token->token == NULL)
@@ -227,10 +225,7 @@ void recheckCodecVersions(client_t *connectingClient)
                                }
                        }
                        if (!found) {
-                               cd = malloc(sizeof(codec_t));
-                               if (!cd)
-                                       Log_fatal("Out of memory");
-                               memset(cd, 0, sizeof(codec_t));
+                               cd = Memory_safeCalloc(1, sizeof(codec_t));
                                init_list_entry(&cd->node);
                                cd->codec = codec_itr->codec;
                                cd->count = 1;
@@ -323,20 +318,24 @@ int Client_add(int fd, struct sockaddr_storage *remote)
 {
        client_t* newclient;
        message_t *sendmsg;
+       char* addressString = NULL;
 
        if (Ban_isBannedAddr(remote)) {
-               Log_info("Address %s banned. Disconnecting", Util_addressToString(remote));
+               addressString = Util_addressToString(remote);
+               Log_info("Address %s banned. Disconnecting", addressString);
+               free(addressString);
                return -1;
        }
 
-       if ((newclient = calloc(1, sizeof(client_t))) == NULL)
-               Log_fatal("(%s:%s): Out of memory while allocating %d bytes.", __FILE__, __LINE__, sizeof(client_t));
+       newclient = Memory_safeCalloc(1, sizeof(client_t));
 
        newclient->tcpfd = fd;
        memcpy(&newclient->remote_tcp, remote, sizeof(struct sockaddr_storage));
        newclient->ssl = SSLi_newconnection(&newclient->tcpfd, &newclient->SSLready);
        if (newclient->ssl == NULL) {
-               Log_warn("SSL negotiation failed with %s on port %d", Util_addressToString(remote), Util_addressToPort(remote));
+               addressString = Util_addressToString(remote);
+               Log_warn("SSL negotiation failed with %s on port %d", addressString, Util_addressToPort(remote));
+               free(addressString);
                free(newclient);
                return -1;
        }
@@ -401,16 +400,12 @@ void Client_free(client_t *client)
                SSLi_free(client->ssl);
        close(client->tcpfd);
        clientcount--;
-       if (client->release)
-               free(client->release);
-       if (client->os)
-               free(client->os);
-       if (client->os_version)
-               free(client->os_version);
-       if (client->username)
-               free(client->username);
-       if (client->context)
-               free(client->context);
+
+       free(client->release);
+       free(client->os);
+       free(client->os_version);
+       free(client->username);
+       free(client->context);
        free(client);
 
        if (authenticatedLeft)
@@ -432,17 +427,42 @@ void Client_disconnect_all()
        }
 }
 
-int Client_read_fd(int fd)
+client_t *Client_find_by_session(int session_id)
 {
        struct dlist *itr;
-       client_t *client = NULL;
 
        list_iterate(itr, &clients) {
-               if (fd == list_get_entry(itr, client_t, node)->tcpfd) {
-                       client = list_get_entry(itr, client_t, node);
-                       break;
+               client_t *client = list_get_entry(itr, client_t, node);
+
+               if (client->sessionId == session_id) {
+                       return client;
                }
        }
+
+       return NULL;
+}
+
+client_t *Client_find_by_fd(int fd)
+{
+       struct dlist *itr;
+
+       list_iterate(itr, &clients) {
+               client_t *client = list_get_entry(itr, client_t, node);
+
+               if (client->tcpfd == fd) {
+                       return client;
+               }
+       }
+
+       return NULL;
+}
+
+int Client_read_fd(int fd)
+{
+       client_t *client;
+
+       client = Client_find_by_fd(fd);
+
        if (client != NULL)
                return Client_read(client);
        else
@@ -544,15 +564,10 @@ int Client_read(client_t *client)
 
 int Client_write_fd(int fd)
 {
-       struct dlist *itr;
-       client_t *client = NULL;
+       client_t *client;
+
+       client = Client_find_by_fd(fd);
 
-       list_iterate(itr, &clients) {
-               if(fd == list_get_entry(itr, client_t, node)->tcpfd) {
-                       client = list_get_entry(itr, client_t, node);
-                       break;
-               }
-       }
        if (client != NULL)
                return Client_write(client);
        else
@@ -664,24 +679,24 @@ client_t *Client_iterate(client_t **client_itr)
        return c;
 }
 
-void Client_textmessage(client_t *client, char *text)
+void Client_textmessage(client_t *client, const char *text)
 {
        char *message;
        uint32_t *tree_id;
        message_t *sendmsg = NULL;
 
-       message = malloc(strlen(text) + 1);
-       if (!message)
-               Log_fatal("Out of memory");
-       tree_id = malloc(sizeof(uint32_t));
-       if (!tree_id)
+       message = strdup(text);
+
+       if (message == NULL)
                Log_fatal("Out of memory");
+
+       tree_id = Memory_safeMalloc(1, sizeof(uint32_t));
        *tree_id = 0;
        sendmsg = Msg_create(TextMessage);
        sendmsg->payload.textMessage->message = message;
        sendmsg->payload.textMessage->n_tree_id = 1;
        sendmsg->payload.textMessage->tree_id = tree_id;
-       strcpy(message, text);
+
        Client_send_message(client, sendmsg);
 }
 
@@ -689,22 +704,15 @@ void Client_textmessage(client_t *client, char *text)
 int Client_send_message_except(client_t *client, message_t *msg)
 {
        client_t *itr = NULL;
-       int count = 0;
 
-       Msg_inc_ref(msg); /* Make sure a reference is held during the whole iteration. */
        while (Client_iterate(&itr) != NULL) {
                if (itr != client) {
-                       if (count++ > 0)
-                               Msg_inc_ref(msg); /* One extra reference for each new copy */
+                       Msg_inc_ref(msg); /* One extra reference for each new copy */
                        Log_debug("Msg %d to %s refcount %d",  msg->messageType, itr->username, msg->refcount);
                        Client_send_message(itr, msg);
                }
        }
-       Msg_free(msg); /* Free our reference to the message */
-
-       if (count == 0)
-               Msg_free(msg); /* If only 1 client is connected then no message is passed
-                                               * to Client_send_message(). Free it here. */
+       Msg_free(msg); /* Consume caller's reference. */
 
        return 0;
 }
@@ -712,22 +720,15 @@ int Client_send_message_except(client_t *client, message_t *msg)
 int Client_send_message_except_ver(client_t *client, message_t *msg, uint32_t version)
 {
        client_t *itr = NULL;
-       int count = 0;
 
-       Msg_inc_ref(msg); /* Make sure a reference is held during the whole iteration. */
        while (Client_iterate(&itr) != NULL) {
                if (itr != client) {
-                       if (count++ > 0)
-                               Msg_inc_ref(msg); /* One extra reference for each new copy */
+                       Msg_inc_ref(msg); /* One extra reference for each new copy */
                        Log_debug("Msg %d to %s refcount %d",  msg->messageType, itr->username, msg->refcount);
                        Client_send_message_ver(itr, msg, version);
                }
        }
-       Msg_free(msg); /* Free our reference to the message */
-
-       if (count == 0)
-               Msg_free(msg); /* If only 1 client is connected then no message is passed
-                                               * to Client_send_message(). Free it here. */
+       Msg_free(msg); /* Consume caller's reference. */
 
        return 0;
 }
@@ -901,7 +902,7 @@ int Client_voiceMsg(client_t *client, uint8_t *data, int len)
        int offset, packetsize;
        voicetarget_t *vt;
 
-       channel_t *ch = (channel_t *)client->channel;
+       channel_t *ch = client->channel;
        struct dlist *itr;
 
        if (!client->authenticated || client->mute || client->self_mute || ch->silent)
@@ -997,7 +998,7 @@ int Client_voiceMsg(client_t *client, uint8_t *data, int len)
                }
                /* Sessions */
                for (i = 0; i < TARGET_MAX_SESSIONS && vt->sessions[i] != -1; i++) {
-                       client_t *c;
+                       client_t *c = NULL;
                        buffer[0] = (uint8_t) (type | 2);
                        Log_debug("Whisper session %d", vt->sessions[i]);
                        while (Client_iterate(&c) != NULL) {
@@ -1024,14 +1025,11 @@ static int Client_send_udp(client_t *client, uint8_t *data, int len)
        if (Util_clientAddressToPortUDP(client) != 0 && CryptState_isValid(&client->cryptState) &&
                client->bUDP) {
 #if defined(__LP64__)
-               buf = mbuf = malloc(len + 4 + 16);
+               buf = mbuf = Memory_safeMalloc(1, len + 4 + 16);
                buf += 4;
 #else
-               mbuf = buf = malloc(len + 4);
+               mbuf = buf = Memory_safeMalloc(1, len + 4);
 #endif
-               if (mbuf == NULL)
-                       Log_fatal("Out of memory");
-
                CryptState_encrypt(&client->cryptState, data, buf, len);
 
 #if defined(__NetBSD__) || defined(__FreeBSD__) || defined(__OpenBSD__) || defined(__APPLE__)