fix address comparison
[umurmur.git] / src / ban.c
index b02d92116a3518c0da418ab74d50d9fc188a8ca2..667bdc8ecec353114b7e8a65e9db2c2b39f45795 100644 (file)
--- a/src/ban.c
+++ b/src/ban.c
@@ -8,7 +8,7 @@
    are met:
 
    - Redistributions of source code must retain the above copyright notice,
    are met:
 
    - Redistributions of source code must retain the above copyright notice,
-     this list of conditions and the following disclaimer.
+        this list of conditions and the following disclaimer.
    - Redistributions in binary form must reproduce the above copyright notice,
      this list of conditions and the following disclaimer in the documentation
      and/or other materials provided with the distribution.
    - Redistributions in binary form must reproduce the above copyright notice,
      this list of conditions and the following disclaimer in the documentation
      and/or other materials provided with the distribution.
    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-*/
+   */
 
 #include <stdlib.h>
 #include <time.h>
 
 #include <stdlib.h>
 #include <time.h>
+#include <string.h>
 #include "log.h"
 #include "list.h"
 #include "ban.h"
 #include "conf.h"
 #include "ssl.h"
 #include "log.h"
 #include "list.h"
 #include "ban.h"
 #include "conf.h"
 #include "ssl.h"
+#include "util.h"
 
 static void Ban_saveBanFile(void);
 static void Ban_readBanFile(void);
 
 static void Ban_saveBanFile(void);
 static void Ban_readBanFile(void);
@@ -56,10 +58,10 @@ void Ban_init(void)
 
 void Ban_deinit(void)
 {
 
 void Ban_deinit(void)
 {
-       /* Save banlist */      
+       /* Save banlist */
        if (getStrConf(BANFILE) != NULL)
                Ban_saveBanFile();
        if (getStrConf(BANFILE) != NULL)
                Ban_saveBanFile();
-               
+
        Ban_clearBanList();
 }
 
        Ban_clearBanList();
 }
 
@@ -68,14 +70,14 @@ void Ban_UserBan(client_t *client, char *reason)
        ban_t *ban;
        char hexhash[41];
 
        ban_t *ban;
        char hexhash[41];
 
-       ban = malloc(sizeof(ban_t));
+       ban = calloc(1, sizeof(ban_t));
        if (ban == NULL)
                Log_fatal("Out of memory");
        if (ban == NULL)
                Log_fatal("Out of memory");
-       memset(ban, 0, sizeof(ban_t));
-       
+
        memcpy(ban->hash, client->hash, 20);
        memcpy(ban->hash, client->hash, 20);
-       memcpy(&ban->address, &client->remote_tcp.sin_addr, sizeof(in_addr_t));
-       ban->mask = 128;
+
+       ban->address = client->remote_tcp;
+       ban->mask = (ban->address.ss_family == AF_INET) ? 32 : 128;
        ban->reason = strdup(reason);
        ban->name = strdup(client->username);
        ban->time = time(NULL);
        ban->reason = strdup(reason);
        ban->name = strdup(client->username);
        ban->time = time(NULL);
@@ -86,10 +88,11 @@ void Ban_UserBan(client_t *client, char *reason)
        banlist_changed = true;
        if(getBoolConf(SYNC_BANFILE))
                Ban_saveBanFile();
        banlist_changed = true;
        if(getBoolConf(SYNC_BANFILE))
                Ban_saveBanFile();
-       
+
        SSLi_hash2hex(ban->hash, hexhash);
        SSLi_hash2hex(ban->hash, hexhash);
+
        Log_info_client(client, "User kickbanned. Reason: '%s' Hash: %s IP: %s Banned for: %d seconds",
        Log_info_client(client, "User kickbanned. Reason: '%s' Hash: %s IP: %s Banned for: %d seconds",
-                       ban->reason, hexhash, inet_ntoa(*((struct in_addr *)&ban->address)), ban->duration);
+               ban->reason, hexhash, Util_clientAddressToString(client), ban->duration);
 }
 
 
 }
 
 
@@ -97,17 +100,16 @@ void Ban_pruneBanned()
 {
        struct dlist *itr;
        ban_t *ban;
 {
        struct dlist *itr;
        ban_t *ban;
-       char hexhash[41];
        uint64_t bantime_long;
        uint64_t bantime_long;
-               
+
        list_iterate(itr, &banlist) {
                ban = list_get_entry(itr, ban_t, node);
                bantime_long = ban->duration * 1000000LL;
 #ifdef DEBUG
                SSLi_hash2hex(ban->hash, hexhash);
                Log_debug("BL: User %s Reason: '%s' Hash: %s IP: %s Time left: %d",
        list_iterate(itr, &banlist) {
                ban = list_get_entry(itr, ban_t, node);
                bantime_long = ban->duration * 1000000LL;
 #ifdef DEBUG
                SSLi_hash2hex(ban->hash, hexhash);
                Log_debug("BL: User %s Reason: '%s' Hash: %s IP: %s Time left: %d",
-                         ban->name, ban->reason, hexhash, inet_ntoa(*((struct in_addr *)&ban->address)),
-                         bantime_long / 1000000LL - Timer_elapsed(&ban->startTime) / 1000000LL);
+                       ban->name, ban->reason, hexhash, Util_addressToString(&ban->address)),
+                       bantime_long / 1000000LL - Timer_elapsed(&ban->startTime) / 1000000LL);
 #endif
                /* Duration of 0 = forever */
                if (ban->duration != 0 && Timer_isElapsed(&ban->startTime, bantime_long)) {
 #endif
                /* Duration of 0 = forever */
                if (ban->duration != 0 && Timer_isElapsed(&ban->startTime, bantime_long)) {
@@ -129,32 +131,54 @@ bool_t Ban_isBanned(client_t *client)
        ban_t *ban;
        list_iterate(itr, &banlist) {
                ban = list_get_entry(itr, ban_t, node);
        ban_t *ban;
        list_iterate(itr, &banlist) {
                ban = list_get_entry(itr, ban_t, node);
-               if (memcmp(ban->hash, client->hash, 20) == 0) 
+               if (memcmp(ban->hash, client->hash, 20) == 0)
                        return true;
        }
        return false;
                        return true;
        }
        return false;
-       
+
 }
 
 }
 
-bool_t Ban_isBannedAddr(in_addr_t *addr)
+bool_t Ban_isBannedAddr(struct sockaddr_storage *address)
 {
        struct dlist *itr;
        ban_t *ban;
 {
        struct dlist *itr;
        ban_t *ban;
-       int mask;
-       in_addr_t tempaddr1, tempaddr2;
-       
+       uint64_t clientAddressBytes[2] = {0};
+       uint64_t banAddressBytes[2] = {0};
+       uint64_t banMaskBits[2] = {UINT64_MAX};
+
+       if (address->ss_family == AF_INET) {
+               memcpy(clientAddressBytes, &((struct sockaddr_in *)address)->sin_addr, 4);
+       } else {
+               memcpy(clientAddressBytes, &((struct sockaddr_in6 *)address)->sin6_addr, 16);
+       }
+
        list_iterate(itr, &banlist) {
                ban = list_get_entry(itr, ban_t, node);
        list_iterate(itr, &banlist) {
                ban = list_get_entry(itr, ban_t, node);
-               mask = ban->mask - 96;
-               if (mask < 32) { /* XXX - only ipv4 support */
-                       memcpy(&tempaddr1, addr, sizeof(in_addr_t));
-                       memcpy(&tempaddr2, &ban->address, sizeof(in_addr_t));
-                       tempaddr1 &= (2 ^ mask) - 1;
-                       tempaddr2 &= (2 ^ mask) - 1;
+
+               if(address->ss_len == ban->address.ss_family) {
+                       if (ban->address.ss_family == AF_INET) {
+                               memcpy(banAddressBytes, &((struct sockaddr_in *)&ban->address)->sin_addr, 4);
+                       } else {
+                               memcpy(banAddressBytes, &((struct sockaddr_in6 *)&ban->address)->sin6_addr, 16);
+                       }
+
+                       banMaskBits[0] <<= (ban->mask >= 64) ? 0 : 64 - ban->mask;
+                       banMaskBits[1] <<= (ban->mask > 64) ? 128 - ban->mask : 64;
+
+                       clientAddressBytes[0] &= banMaskBits[0];
+                       clientAddressBytes[1] &= banMaskBits[1];
+
+                       banAddressBytes[0] &= banMaskBits[0];
+                       banAddressBytes[1] &= banMaskBits[1];
+
+                       if (memcmp(clientAddressBytes, banAddressBytes, 16) == 0) {
+                               return true;
+                       }
+
                }
                }
-               if (memcmp(&tempaddr1, &tempaddr2, sizeof(in_addr_t)) == 0) 
-                       return true;
+
        }
        }
+
        return false;
 }
 
        return false;
 }
 
@@ -173,19 +197,24 @@ message_t *Ban_getBanList(void)
        char timestr[32];
        char hexhash[41];
        uint8_t address[16];
        char timestr[32];
        char hexhash[41];
        uint8_t address[16];
-       
+
        msg = Msg_banList_create(bancount);
        list_iterate(itr, &banlist) {
                ban = list_get_entry(itr, ban_t, node);
                gmtime_r(&ban->time, &timespec);
                strftime(timestr, 32, "%Y-%m-%dT%H:%M:%S", &timespec);
                SSLi_hash2hex(ban->hash, hexhash);
        msg = Msg_banList_create(bancount);
        list_iterate(itr, &banlist) {
                ban = list_get_entry(itr, ban_t, node);
                gmtime_r(&ban->time, &timespec);
                strftime(timestr, 32, "%Y-%m-%dT%H:%M:%S", &timespec);
                SSLi_hash2hex(ban->hash, hexhash);
-               /* ipv4 representation as ipv6 address. */
                memset(address, 0, 16);
                memset(address, 0, 16);
-               memcpy(&address[12], &ban->address, 4);
-               memset(&address[10], 0xff, 2); /* IPv4 */
-               Msg_banList_addEntry(msg, i++, address, ban->mask, ban->name,
-                                    hexhash, ban->reason, timestr, ban->duration);
+
+               if(ban->address.ss_family == AF_INET) {
+                       memcpy(&address[12], &((struct sockaddr_in *)&ban->address)->sin_addr, 4);
+                       memset(&address[10], 0xff, 2);
+                       Msg_banList_addEntry(msg, i++, address, ban->mask + 96, ban->name, hexhash, ban->reason, timestr, ban->duration);
+               } else {
+                       memcpy(&address, &((struct sockaddr_in6 *)&ban->address)->sin6_addr, 16);
+                       Msg_banList_addEntry(msg, i++, address, ban->mask, ban->name, hexhash, ban->reason, timestr, ban->duration);
+               }
+
        }
        return msg;
 }
        }
        return msg;
 }
@@ -212,15 +241,26 @@ void Ban_putBanList(message_t *msg, int n_bans)
        char *hexhash, *name, *reason, *start;
        uint32_t duration, mask;
        uint8_t *address;
        char *hexhash, *name, *reason, *start;
        uint32_t duration, mask;
        uint8_t *address;
-       
+       char mappedBytes[12] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff};
+
        for (i = 0; i < n_bans; i++) {
                Msg_banList_getEntry(msg, i, &address, &mask, &name, &hexhash, &reason, &start, &duration);
                ban = malloc(sizeof(ban_t));
                if (ban == NULL)
                        Log_fatal("Out of memory");
        for (i = 0; i < n_bans; i++) {
                Msg_banList_getEntry(msg, i, &address, &mask, &name, &hexhash, &reason, &start, &duration);
                ban = malloc(sizeof(ban_t));
                if (ban == NULL)
                        Log_fatal("Out of memory");
-               memset(ban, 0, sizeof(ban_t));
                SSLi_hex2hash(hexhash, ban->hash);
                SSLi_hex2hash(hexhash, ban->hash);
-               memcpy(&ban->address, &address[12], 4);
+
+               if(memcmp(address, mappedBytes, 12) == 0) {
+                       memcpy(&((struct sockaddr_in *)&ban->address)->sin_addr, &address[12], 4);
+                       ban->address.ss_family = AF_INET;
+                       if (mask > 32) {
+                               mask = 32;
+                       }
+               } else {
+                       memcpy(&((struct sockaddr_in6 *)&ban->address)->sin6_addr, address, 16);
+                       ban->address.ss_family = AF_INET6;
+               }
+
                ban->mask = mask;
                ban->reason = strdup(reason);
                ban->name = strdup(name);
                ban->mask = mask;
                ban->reason = strdup(reason);
                ban->name = strdup(name);
@@ -253,8 +293,8 @@ static void Ban_saveBanFile(void)
        list_iterate(itr, &banlist) {
                ban = list_get_entry(itr, ban_t, node);
                SSLi_hash2hex(ban->hash, hexhash);
        list_iterate(itr, &banlist) {
                ban = list_get_entry(itr, ban_t, node);
                SSLi_hash2hex(ban->hash, hexhash);
-               fprintf(file, "%s,%s,%d,%ld,%d,%s,%s\n", hexhash, inet_ntoa(*((struct in_addr *)&ban->address)),
-                       ban->mask, (long int)ban->time, ban->duration, ban->name, ban->reason);
+
+               fprintf(file, "%s,%s,%d,%ld,%d,%s,%s\n", hexhash, Util_addressToString(&ban->address),ban->mask, (long int)ban->time, ban->duration, ban->name, ban->reason);
        }
        fclose(file);
        banlist_changed = false;
        }
        fclose(file);
        banlist_changed = false;
@@ -263,7 +303,6 @@ static void Ban_saveBanFile(void)
 
 static void Ban_readBanFile(void)
 {
 
 static void Ban_readBanFile(void)
 {
-       struct dlist *itr;
        ban_t *ban;
        char line[1024], *hexhash, *address, *name, *reason;
        uint32_t mask, duration;
        ban_t *ban;
        char line[1024], *hexhash, *address, *name, *reason;
        uint32_t mask, duration;
@@ -297,13 +336,21 @@ static void Ban_readBanFile(void)
                p = strtok(NULL, "\n");
                if (p == NULL) break;
                reason = p;
                p = strtok(NULL, "\n");
                if (p == NULL) break;
                reason = p;
-               
+
                ban = malloc(sizeof(ban_t));
                if (ban == NULL)
                        Log_fatal("Out of memory");
                memset(ban, 0, sizeof(ban_t));
                SSLi_hex2hash(hexhash, ban->hash);
                ban = malloc(sizeof(ban_t));
                if (ban == NULL)
                        Log_fatal("Out of memory");
                memset(ban, 0, sizeof(ban_t));
                SSLi_hex2hash(hexhash, ban->hash);
-               inet_aton(address, (struct in_addr *)&ban->address);
+               if (inet_pton(AF_INET, address, &ban->address) == 0) {
+                       if (inet_pton(AF_INET6, address, &ban->address) == 0) {
+                               Log_warn("Address \"%s\" is illegal!", address);
+                       } else {
+                               ban->address.ss_family = AF_INET6;
+                       }
+               } else {
+                       ban->address.ss_family = AF_INET;
+               }
                ban->name = strdup(name);
                ban->reason = strdup(reason);
                if (ban->name == NULL || ban->reason == NULL)
                ban->name = strdup(name);
                ban->reason = strdup(reason);
                if (ban->name == NULL || ban->reason == NULL)