Fix TCP mode memory leak
[umurmur.git] / src / client.c
1 /* Copyright (C) 2009, Martin Johansson <martin@fatbob.nu>
2    Copyright (C) 2005-2009, Thorvald Natvig <thorvald@natvig.com>
3
4    All rights reserved.
5
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9
10    - Redistributions of source code must retain the above copyright notice,
11      this list of conditions and the following disclaimer.
12    - Redistributions in binary form must reproduce the above copyright notice,
13      this list of conditions and the following disclaimer in the documentation
14      and/or other materials provided with the distribution.
15    - Neither the name of the Developers nor the names of its contributors may
16      be used to endorse or promote products derived from this software without
17      specific prior written permission.
18
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31 #include <sys/poll.h>
32 #include <sys/socket.h>
33 #include <errno.h>
34 #include "log.h"
35 #include "list.h"
36 #include "client.h"
37 #include "ssl.h"
38 #include "messages.h"
39 #include "messagehandler.h"
40 #include "pds.h"
41 #include "conf.h"
42 #include "channel.h"
43
44
45
46 static int Client_read(client_t *client);
47 static int Client_write(client_t *client);
48 static int Client_voiceMsg(client_t *client, pds_t *pds);
49 static int Client_send_udp(client_t *client, uint8_t *data, int len);
50 static void Client_voiceMsg_tunnel(client_t *client, message_t *msg);
51
52 declare_list(clients);
53 static int clientcount; /* = 0 */
54 static int session = 1;
55 static int maxBandwidth;
56
57 extern int udpsock;
58
59 void Client_init()
60 {
61         maxBandwidth = getIntConf(MAX_BANDWIDTH);
62 }
63
64 int Client_count()
65 {
66         return clientcount;
67 }
68
69 int Client_getfds(struct pollfd *pollfds)
70 {
71         struct dlist *itr;
72         int i = 0;
73         list_iterate(itr, &clients) {
74                 client_t *c;
75                 c = list_get_entry(itr, client_t, node);
76                 pollfds[i].fd = c->tcpfd;
77                 pollfds[i].events = POLLIN | POLLHUP | POLLERR;
78                 if (c->txsize > 0 || c->readBlockedOnWrite) /* Data waiting to be sent? */
79                         pollfds[i].events |= POLLOUT;
80                 i++;
81         }
82         return i;
83 }
84
85 void Client_janitor()
86 {
87         struct dlist *itr;
88         int bwTop = maxBandwidth + maxBandwidth / 4;
89         list_iterate(itr, &clients) {
90                 client_t *c;
91                 c = list_get_entry(itr, client_t, node);
92                 Log_debug("Client %s BW available %d", c->playerName, c->availableBandwidth);
93                 c->availableBandwidth += maxBandwidth;
94                 if (c->availableBandwidth > bwTop)
95                         c->availableBandwidth = bwTop;
96                 
97                 if (Timer_isElapsed(&c->lastActivity, 1000000LL * INACTICITY_TIMEOUT)) {
98                         /* No activity from client - assume it is lost and close. */
99                         Log_info("Session ID %d timeout - closing", c->sessionId);
100                         Client_free(c);
101                 }
102         }
103 }
104
105 int Client_add(int fd, struct sockaddr_in *remote)
106 {
107         client_t *newclient;
108
109         newclient = malloc(sizeof(client_t));
110         if (newclient == NULL)
111                 Log_fatal("Out of memory");
112         memset(newclient, 0, sizeof(client_t));
113
114         newclient->tcpfd = fd;
115         memcpy(&newclient->remote_tcp, remote, sizeof(struct sockaddr_in));
116         newclient->ssl = SSL_newconnection(newclient->tcpfd, &newclient->SSLready);
117         if (newclient->ssl == NULL) {
118                 Log_warn("SSL negotiation failed");
119                 free(newclient);
120                 return -1;
121         }
122         newclient->availableBandwidth = maxBandwidth;
123         Timer_init(&newclient->lastActivity);
124         newclient->sessionId = session++; /* XXX - more elaborate? */
125         
126         init_list_entry(&newclient->txMsgQueue);
127         init_list_entry(&newclient->chan_node);
128         init_list_entry(&newclient->node);
129         
130         list_add_tail(&newclient->node, &clients);
131         clientcount++;
132         return 0;
133 }
134
135 void Client_free(client_t *client)
136 {
137         struct dlist *itr, *save;
138         message_t *sendmsg;
139
140         Log_info("Disconnect client ID %d addr %s port %d", client->sessionId,
141                          inet_ntoa(client->remote_tcp.sin_addr),
142                          ntohs(client->remote_tcp.sin_port));
143
144         if (client->authenticated) {
145                 sendmsg = Msg_create(ServerLeave);
146                 sendmsg->sessionId = client->sessionId;
147                 Client_send_message_except(client, sendmsg);
148         }
149         list_iterate_safe(itr, save, &client->txMsgQueue) {
150                 list_del(&list_get_entry(itr, message_t, node)->node);
151                 Msg_free(list_get_entry(itr, message_t, node));
152         }
153                 
154         list_del(&client->node);
155         list_del(&client->chan_node);
156         if (client->ssl)
157                 SSL_free(client->ssl);
158         close(client->tcpfd);
159         clientcount--;
160         free(client);
161 }
162
163 void Client_close(client_t *client)
164 {
165         SSL_shutdown(client->ssl);
166         client->shutdown_wait = true;
167 }
168
169 void Client_disconnect_all()
170 {
171         struct dlist *itr, *save;
172         
173         list_iterate_safe(itr, save, &clients) {
174                 Client_free(list_get_entry(itr, client_t, node));
175         }
176 }
177
178 int Client_read_fd(int fd)
179 {
180         struct dlist *itr;
181         client_t *client = NULL;
182         
183         list_iterate(itr, &clients) {
184                 if(fd == list_get_entry(itr, client_t, node)->tcpfd) {
185                         client = list_get_entry(itr, client_t, node);
186                         break;
187                 }
188         }
189         if (client == NULL)
190                 Log_fatal("No client found for fd %d", fd);
191         
192         return Client_read(client);
193 }
194
195 int Client_read(client_t *client)
196 {
197         int rc;
198
199         Timer_restart(&client->lastActivity);
200         
201         if (client->writeBlockedOnRead) {
202                 client->writeBlockedOnRead = false;
203                 Log_debug("Client_read: writeBlockedOnRead == true");
204                 return Client_write(client);
205         }
206         
207         if (client->shutdown_wait) {
208                 Client_free(client);
209                 return 0;
210         }
211         if (!client->SSLready) {
212                 int rc;
213                 rc = SSL_nonblockaccept(client->ssl, &client->SSLready);
214                 if (rc < 0) {
215                         Client_free(client);
216                         return -1;
217                 }
218         }
219
220         do {
221                 errno = 0;
222                 if (!client->msgsize) 
223                         rc = SSL_read(client->ssl, client->rxbuf, 3 - client->rxcount);
224                 else if (client->drainleft > 0)
225                         rc = SSL_read(client->ssl, client->rxbuf, client->drainleft > BUFSIZE ? BUFSIZE : client->drainleft);
226                 else
227                         rc = SSL_read(client->ssl, &client->rxbuf[client->rxcount], client->msgsize);
228                 if (rc > 0) {
229                         message_t *msg;
230                         if (client->drainleft > 0)
231                                 client->drainleft -= rc;
232                         else {
233                                 client->rxcount += rc;
234                                 if (!client->msgsize && rc >= 3)
235                                         client->msgsize = ((client->rxbuf[0] & 0xff) << 16) |
236                                                 ((client->rxbuf[1] & 0xff) << 8) |
237                                                 (client->rxbuf[2] & 0xff);
238                                 if (client->msgsize > BUFSIZE - 3 && client->drainleft == 0) {
239                                         Log_warn("Too big message received (%d). Discarding.", client->msgsize);
240                                         client->rxcount = client->msgsize = 0;
241                                         client->drainleft = client->msgsize;
242                                 }
243                                 else if (client->rxcount == client->msgsize + 3) { /* Got all of the message */
244                                         msg = Msg_networkToMessage(&client->rxbuf[3], client->msgsize);
245                                         /* pass messsage to handler */
246                                         if (msg) {
247                                                 if (msg->messageType == Speex) /* Tunneled voice message */
248                                                         Client_voiceMsg_tunnel(client, msg);
249                                                 else 
250                                                         Mh_handle_message(client, msg);
251                                         }
252                                         client->rxcount = client->msgsize = 0;
253                                 }
254                         }
255                 } else /* rc <= 0 */ {
256                         if (SSL_get_error(client->ssl, rc) == SSL_ERROR_WANT_READ) {
257                                 return 0;
258                         }
259                         else if (SSL_get_error(client->ssl, rc) == SSL_ERROR_WANT_WRITE) {
260                                 client->readBlockedOnWrite = true;
261                                 return 0;
262                         }
263                         else if (SSL_get_error(client->ssl, rc) == SSL_ERROR_ZERO_RETURN) {
264                                 Log_warn("Error: Zero return - closing");
265                                 if (!client->shutdown_wait)
266                                         Client_close(client);
267                         }
268                         else {
269                                 if (SSL_get_error(client->ssl, rc) == SSL_ERROR_SYSCALL) {
270                                         /* Hmm. This is where we end up when the client closes its connection.
271                                          * Kind of strange...
272                                          */
273                                         Log_info("Connection closed by peer");
274                                 }
275                                 else {
276                                         Log_warn("SSL error: %d - Closing connection.", SSL_get_error(client->ssl, rc));
277                                 }
278                                 Client_free(client);
279                                 return -1;
280                         }
281                 }
282         } while (SSL_pending(client->ssl));
283         return 0;       
284 }
285
286 int Client_write_fd(int fd)
287 {
288         struct dlist *itr;
289         client_t *client = NULL;
290         
291         list_iterate(itr, &clients) {
292                 if(fd == list_get_entry(itr, client_t, node)->tcpfd) {
293                         client = list_get_entry(itr, client_t, node);
294                         break;
295                 }
296         }
297         if (client == NULL)
298                 Log_fatal("No client found for fd %d", fd);
299         Client_write(client);
300         return 0;
301 }
302
303 int Client_write(client_t *client)
304 {
305         int rc;
306         
307         if (client->readBlockedOnWrite) {
308                 client->readBlockedOnWrite = false;
309                 Log_debug("Client_write: readBlockedOnWrite == true");
310                 return Client_read(client);
311         }
312         rc = SSL_write(client->ssl, &client->txbuf[client->txcount], client->txsize - client->txcount);
313         if (rc > 0) {
314                 client->txcount += rc;
315                 if (client->txcount == client->txsize)
316                         client->txsize = client->txcount = 0;
317         }
318         else if (rc < 0) {
319                 if (SSL_get_error(client->ssl, rc) == SSL_ERROR_WANT_READ) {
320                         client->writeBlockedOnRead = true;
321                         return 0;
322                 }
323                 else if (SSL_get_error(client->ssl, rc) == SSL_ERROR_WANT_WRITE) {
324                         return 0;
325                 }
326                 else {
327                         if (SSL_get_error(client->ssl, rc) == SSL_ERROR_SYSCALL)
328                                 Log_warn("Client_write: Error: %s  - Closing connection", strerror(errno));
329                         else
330                                 Log_warn("Client_write: SSL error: %d - Closing connection.", SSL_get_error(client->ssl, rc));
331                         Client_free(client);
332                         return -1;
333                 }
334         }
335         if (client->txsize == 0 && !list_empty(&client->txMsgQueue)) {
336                 message_t *msg;
337                 msg = list_get_entry(list_get_first(&client->txMsgQueue), message_t, node);
338                 list_del(list_get_first(&client->txMsgQueue));
339                 client->txQueueCount--;
340                 Client_send_message(client, msg);
341         }
342         return 0;
343 }
344
345 int Client_send_message(client_t *client, message_t *msg)
346 {
347         if (!client->authenticated || !client->SSLready) {
348                 Msg_free(msg);
349                 return 0;
350         }
351         if (client->txsize != 0) {
352                 /* Queue message */
353                 if ((client->txQueueCount > 5 &&  msg->messageType == Speex) ||
354                         client->txQueueCount > 30) {
355                         Msg_free(msg);
356                         return -1;
357                 }
358                 client->txQueueCount++;
359                 list_add_tail(&msg->node, &client->txMsgQueue);
360         } else {
361                 int len;
362                 memset(client->txbuf, 0, BUFSIZE);
363                 len = Msg_messageToNetwork(msg, &client->txbuf[3], BUFSIZE - 3);
364                 doAssert(len < BUFSIZE - 3);
365
366                 client->txbuf[0] =  (len >> 16) & 0xff;
367                 client->txbuf[1] =  (len >> 8) & 0xff;
368                 client->txbuf[2] =  len & 0xff;
369                 client->txsize = len + 3;
370                 client->txcount = 0;
371                 Client_write(client);
372                 Msg_free(msg);
373         }
374         return 0;
375 }
376
377 client_t *Client_iterate(client_t **client_itr)
378 {
379         client_t *c = *client_itr;
380         
381         if (c == NULL && !list_empty(&clients)) {
382                 c = list_get_entry(list_get_first(&clients), client_t, node);
383         } else {
384                 if (list_get_next(&c->node) == &clients)
385                         c = NULL;
386                 else
387                         c = list_get_entry(list_get_next(&c->node), client_t, node);
388         }
389         *client_itr = c;
390         return c;
391 }
392
393
394 int Client_send_message_except(client_t *client, message_t *msg)
395 {
396         client_t *itr = NULL;
397         int count = 0;
398         
399         Msg_inc_ref(msg); /* Make sure a reference is held during the whole iteration. */
400         while (Client_iterate(&itr) != NULL) {
401                 if (itr != client) {
402                         if (count++ > 0)
403                                 Msg_inc_ref(msg); /* One extra reference for each new copy */
404                         Log_debug("Msg %d to %s refcount %d",  msg->messageType, itr->playerName, msg->refcount);
405                         Client_send_message(itr, msg);
406                 }
407         }
408         Msg_free(msg); /* Free our reference to the message */
409         
410         if (count == 0)
411                 Msg_free(msg); /* If only 1 client is connected then no message is passed
412                                                 * to Client_send_message(). Free it here. */
413                 
414         return 0;
415 }
416
417 static bool_t checkDecrypt(client_t *client, const uint8_t *encrypted, uint8_t *plain, unsigned int len)
418 {
419         if (CryptState_isValid(&client->cryptState) &&
420                 CryptState_decrypt(&client->cryptState, encrypted, plain, len))
421                 return true;
422
423         if (Timer_elapsed(&client->cryptState.tLastGood) > 5000000ULL) {
424                 if (Timer_elapsed(&client->cryptState.tLastRequest) > 5000000ULL) {
425                         message_t *sendmsg;
426                         Timer_restart(&client->cryptState.tLastRequest);
427                         
428                         sendmsg = Msg_create(CryptSync);
429                         sendmsg->sessionId = client->sessionId;
430                         sendmsg->payload.cryptSync.empty = true;
431                         Log_info("Requesting voice channel crypt resync");
432                         Client_send_message(client, sendmsg);
433                 }
434         }
435         return false;
436 }
437
438 int Client_read_udp()
439 {
440         int len;
441         struct sockaddr_in from;
442         socklen_t fromlen = sizeof(struct sockaddr_in);
443         uint64_t key;
444         client_t *itr;
445         int msgType = 0;
446         uint32_t sessionId = 0;
447         pds_t *pds;
448         
449 #if defined(__LP64__)
450         uint8_t encbuff[512 + 8];
451         uint8_t *encrypted = encbuff + 4;
452 #else
453         uint8_t encrypted[512];
454 #endif
455         uint8_t buffer[512];
456         
457         len = recvfrom(udpsock, encrypted, 512, MSG_TRUNC, (struct sockaddr *)&from, &fromlen);
458         if (len == 0) {
459                 return -1;
460         } else if (len < 0) {
461                 return -1;
462         } else if (len < 6) {
463                 // 4 bytes crypt header + type + session
464                 return 0;
465         } else if (len > 512) {
466                 return 0;
467         }
468         
469         key = (((uint64_t)from.sin_addr.s_addr) << 16) ^ from.sin_port;
470         pds = Pds_create(buffer, len - 4);
471         itr = NULL;
472         
473         while (Client_iterate(&itr) != NULL) {
474                 if (itr->key == key) {
475                         if (!checkDecrypt(itr, encrypted, buffer, len))
476                                 goto out;
477                         msgType = Pds_get_numval(pds);
478                         sessionId = Pds_get_numval(pds);
479                         if (itr->sessionId != sessionId)
480                                 goto out;
481                         break;
482                 }
483         }       
484         if (itr == NULL) { /* Unknown peer */
485                 while (Client_iterate(&itr) != NULL) {
486                         pds->offset = 0;
487                         if (itr->remote_tcp.sin_addr.s_addr == from.sin_addr.s_addr) {
488                                 if (checkDecrypt(itr, encrypted, buffer, len)) {
489                                         msgType = Pds_get_numval(pds);
490                                         sessionId = Pds_get_numval(pds);
491                                         if (itr->sessionId == sessionId) { /* Found matching client */
492                                                 itr->key = key;
493                                                 Log_info("New UDP connection from %s port %d sessionId %d", inet_ntoa(from.sin_addr), ntohs(from.sin_port), sessionId);
494                                                 memcpy(&itr->remote_udp, &from, sizeof(struct sockaddr_in));
495                                                 break;
496                                         }
497                                 }
498                                 else Log_warn("Bad cryptstate from peer");
499                         }
500                 } /* while */
501         }
502         if (itr == NULL) {
503                 goto out;
504         }
505         len -= 4;
506         if (msgType != Speex && msgType != Ping)
507                 goto out;
508         
509         if (msgType == Ping) {
510                 Client_send_udp(itr, buffer, len);
511         }
512         else {
513                 Client_voiceMsg(itr, pds);
514         }
515         
516 out:
517         Pds_free(pds);
518         return 0;
519 }
520
521 static void Client_voiceMsg_tunnel(client_t *client, message_t *msg)
522 {
523         uint8_t buf[512];
524         pds_t *pds = Pds_create(buf, 512);
525
526         Pds_add_numval(pds, msg->messageType);
527         Pds_add_numval(pds, msg->sessionId);
528         Pds_add_numval(pds, msg->payload.speex.seq);
529         Pds_append_data_nosize(pds, msg->payload.speex.data, msg->payload.speex.size);
530         
531         Msg_free(msg);
532         
533         if (!pds->bOk)
534                 Log_warn("Large Speex message from TCP"); /* XXX - pds resize? */
535         pds->maxsize = pds->offset;
536         Client_voiceMsg(client, pds);
537         
538         Pds_free(pds);
539 }
540
541 static int Client_voiceMsg(client_t *client, pds_t *pds)
542 {
543         int seq, flags, msgType, sessionId, packetsize;
544         channel_t *ch = (channel_t *)client->channel;
545         struct dlist *itr;
546         
547         if (!client->authenticated || client->mute)
548                 return 0;
549
550         
551         pds->offset = 0;
552         msgType = Pds_get_numval(pds);
553         sessionId = Pds_get_numval(pds);
554         seq = Pds_get_numval(pds);
555         flags = Pds_get_numval(pds);
556
557         packetsize = 20 + 8 + 4 + pds->maxsize - pds->offset;
558         if (client->availableBandwidth - packetsize < 0)
559                 return 0; /* Discard */
560         
561         client->availableBandwidth -= packetsize;
562         
563         pds->offset = 0;
564         
565         if (flags & LoopBack) {
566                 Client_send_udp(client, pds->data, pds->maxsize);
567                 return 0;
568         }
569         if (ch == NULL)
570                 return 0;
571         
572         list_iterate(itr, &ch->clients) {
573                 client_t *c;
574                 c = list_get_entry(itr, client_t, chan_node);
575                 if (c != client && !c->deaf) {
576                         Client_send_udp(c, pds->data, pds->maxsize);
577                 }
578         }
579         return 0;
580 }
581
582
583 static int Client_send_udp(client_t *client, uint8_t *data, int len)
584 {
585         uint8_t *buf, *mbuf;
586         message_t *sendmsg;
587
588         if (client->remote_udp.sin_port != 0 && CryptState_isValid(&client->cryptState)) {
589 #if defined(__LP64__)
590                 buf = mbuf = malloc(len + 4 + 16);
591                 buf += 4;
592 #else
593                 mbuf = buf = malloc(len + 4);
594 #endif
595                 if (mbuf == NULL)
596                         Log_fatal("Out of memory");
597                 
598                 CryptState_encrypt(&client->cryptState, data, buf, len);
599                 
600                 sendto(udpsock, buf, len + 4, 0, (struct sockaddr *)&client->remote_udp, sizeof(struct sockaddr_in));
601                 
602                 free(mbuf);
603         } else {
604                 pds_t *pds = Pds_create(data, len);
605                 
606                 sendmsg = Msg_create(Pds_get_numval(pds));
607                 sendmsg->sessionId = Pds_get_numval(pds);
608                 
609                 if (sendmsg->messageType == Speex || sendmsg->messageType == Ping) {
610                         if (sendmsg->messageType == Speex) {
611                                 sendmsg->payload.speex.seq = Pds_get_numval(pds);
612                                 sendmsg->payload.speex.size = pds->maxsize - pds->offset;
613                                 doAssert(pds->maxsize - pds->offset <= SPEEX_DATA_SIZE);
614                                 memcpy(sendmsg->payload.speex.data, data + pds->offset, pds->maxsize - pds->offset);
615                         } else { /* Ping */
616                                 sendmsg->payload.ping.timestamp = Pds_get_numval(pds);
617                         }
618                         Client_send_message(client, sendmsg);
619                 } else {
620                         Log_warn("TCP fallback: Unsupported message type %d", sendmsg->messageType);
621                         Msg_free(sendmsg);
622                 }
623                 Pds_free(pds);
624         }
625         return 0;
626 }