Corrected a bug that causes segfaults if UDP packets are received when no client...
[umurmur.git] / src / client.c
1 /* Copyright (C) 2009, Martin Johansson <martin@fatbob.nu>
2    Copyright (C) 2005-2009, Thorvald Natvig <thorvald@natvig.com>
3
4    All rights reserved.
5
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9
10    - Redistributions of source code must retain the above copyright notice,
11      this list of conditions and the following disclaimer.
12    - Redistributions in binary form must reproduce the above copyright notice,
13      this list of conditions and the following disclaimer in the documentation
14      and/or other materials provided with the distribution.
15    - Neither the name of the Developers nor the names of its contributors may
16      be used to endorse or promote products derived from this software without
17      specific prior written permission.
18
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31 #include <sys/poll.h>
32 #include <sys/socket.h>
33 #include <errno.h>
34 #include "log.h"
35 #include "list.h"
36 #include "client.h"
37 #include "ssl.h"
38 #include "messages.h"
39 #include "messagehandler.h"
40 #include "pds.h"
41 #include "conf.h"
42 #include "channel.h"
43
44
45
/* Forward declarations of file-local helpers defined further down. */
static int Client_read(client_t *client);
static int Client_write(client_t *client);
static int Client_voiceMsg(client_t *client, pds_t *pds);
static int Client_send_udp(client_t *client, uint8_t *data, int len);
static void Client_voiceMsg_tunnel(client_t *client, message_t *msg);

/* Global list of connected clients, linked through client_t.node. */
declare_list(clients);
static int clientcount; /* = 0; number of clients currently in `clients` */
static int session = 1; /* next session id handed out by Client_add() */
static int maxBandwidth; /* per-janitor-tick bandwidth refill, set by Client_init() */

/* UDP socket shared by all clients; defined elsewhere. */
extern int udpsock;
58
59 void Client_init()
60 {
61         maxBandwidth = getIntConf(MAX_BANDWIDTH);
62 }
63
64 int Client_count()
65 {
66         return clientcount;
67 }
68
69 int Client_getfds(struct pollfd *pollfds)
70 {
71         struct dlist *itr;
72         int i = 0;
73         list_iterate(itr, &clients) {
74                 client_t *c;
75                 c = list_get_entry(itr, client_t, node);
76                 pollfds[i].fd = c->tcpfd;
77                 pollfds[i].events = POLLIN | POLLHUP | POLLERR;
78                 if (c->txsize > 0 || c->readBlockedOnWrite) /* Data waiting to be sent? */
79                         pollfds[i].events |= POLLOUT;
80                 i++;
81         }
82         return i;
83 }
84
85 void Client_janitor()
86 {
87         struct dlist *itr;
88         int bwTop = maxBandwidth + maxBandwidth / 4;
89         list_iterate(itr, &clients) {
90                 client_t *c;
91                 c = list_get_entry(itr, client_t, node);
92                 Log_debug("Client %s BW available %d", c->playerName, c->availableBandwidth);
93                 c->availableBandwidth += maxBandwidth;
94                 if (c->availableBandwidth > bwTop)
95                         c->availableBandwidth = bwTop;
96                 
97                 if (Timer_isElapsed(&c->lastActivity, 1000000LL * INACTICITY_TIMEOUT)) {
98                         /* No activity from client - assume it is lost and close. */
99                         Log_info("Session ID %d timeout - closing", c->sessionId);
100                         Client_free(c);
101                 }
102         }
103 }
104
105 int Client_add(int fd, struct sockaddr_in *remote)
106 {
107         client_t *newclient;
108
109         newclient = malloc(sizeof(client_t));
110         if (newclient == NULL)
111                 Log_fatal("Out of memory");
112         memset(newclient, 0, sizeof(client_t));
113
114         newclient->tcpfd = fd;
115         memcpy(&newclient->remote_tcp, remote, sizeof(struct sockaddr_in));
116         newclient->ssl = SSL_newconnection(newclient->tcpfd, &newclient->SSLready);
117         if (newclient->ssl == NULL) {
118                 Log_warn("SSL negotiation failed");
119                 free(newclient);
120                 return -1;
121         }
122         newclient->availableBandwidth = maxBandwidth;
123         Timer_init(&newclient->lastActivity);
124         newclient->sessionId = session++; /* XXX - more elaborate? */
125         
126         init_list_entry(&newclient->txMsgQueue);
127         init_list_entry(&newclient->chan_node);
128         init_list_entry(&newclient->node);
129         
130         list_add_tail(&newclient->node, &clients);
131         clientcount++;
132         return 0;
133 }
134
135 void Client_free(client_t *client)
136 {
137         struct dlist *itr, *save;
138         message_t *sendmsg;
139
140         Log_info("Disconnect client ID %d addr %s port %d", client->sessionId,
141                          inet_ntoa(client->remote_tcp.sin_addr),
142                          ntohs(client->remote_tcp.sin_port));
143
144         if (client->authenticated) {
145                 sendmsg = Msg_create(ServerLeave);
146                 sendmsg->sessionId = client->sessionId;
147                 Client_send_message_except(client, sendmsg);
148         }
149         list_iterate_safe(itr, save, &client->txMsgQueue) {
150                 list_del(&list_get_entry(itr, message_t, node)->node);
151                 Msg_free(list_get_entry(itr, message_t, node));
152         }
153                 
154         list_del(&client->node);
155         list_del(&client->chan_node);
156         if (client->ssl)
157                 SSL_free(client->ssl);
158         close(client->tcpfd);
159         clientcount--;
160         free(client);
161 }
162
163 void Client_close(client_t *client)
164 {
165         SSL_shutdown(client->ssl);
166         client->shutdown_wait = true;
167 }
168
169 void Client_disconnect_all()
170 {
171         struct dlist *itr, *save;
172         
173         list_iterate_safe(itr, save, &clients) {
174                 Client_free(list_get_entry(itr, client_t, node));
175         }
176 }
177
178 int Client_read_fd(int fd)
179 {
180         struct dlist *itr;
181         client_t *client = NULL;
182         
183         list_iterate(itr, &clients) {
184                 if(fd == list_get_entry(itr, client_t, node)->tcpfd) {
185                         client = list_get_entry(itr, client_t, node);
186                         break;
187                 }
188         }
189         if (client == NULL)
190                 Log_fatal("No client found for fd %d", fd);
191         
192         return Client_read(client);
193 }
194
195 int Client_read(client_t *client)
196 {
197         int rc;
198
199         Timer_restart(&client->lastActivity);
200         
201         if (client->writeBlockedOnRead) {
202                 client->writeBlockedOnRead = false;
203                 Log_debug("Client_read: writeBlockedOnRead == true");
204                 return Client_write(client);
205         }
206         
207         if (client->shutdown_wait) {
208                 Client_free(client);
209                 return 0;
210         }
211         if (!client->SSLready) {
212                 int rc;
213                 rc = SSL_nonblockaccept(client->ssl, &client->SSLready);
214                 if (rc < 0) {
215                         Client_free(client);
216                         return -1;
217                 }
218         }
219
220         do {
221                 errno = 0;
222                 if (!client->msgsize) 
223                         rc = SSL_read(client->ssl, client->rxbuf, 3 - client->rxcount);
224                 else if (client->drainleft > 0)
225                         rc = SSL_read(client->ssl, client->rxbuf, client->drainleft > BUFSIZE ? BUFSIZE : client->drainleft);
226                 else
227                         rc = SSL_read(client->ssl, &client->rxbuf[client->rxcount], client->msgsize);
228                 if (rc > 0) {
229                         message_t *msg;
230                         if (client->drainleft > 0)
231                                 client->drainleft -= rc;
232                         else {
233                                 client->rxcount += rc;
234                                 if (!client->msgsize && rc >= 3)
235                                         client->msgsize = ((client->rxbuf[0] & 0xff) << 16) |
236                                                 ((client->rxbuf[1] & 0xff) << 8) |
237                                                 (client->rxbuf[2] & 0xff);
238                                 if (client->msgsize > BUFSIZE - 3 && client->drainleft == 0) {
239                                         Log_warn("Too big message received (%d). Discarding.", client->msgsize);
240                                         client->rxcount = client->msgsize = 0;
241                                         client->drainleft = client->msgsize;
242                                 }
243                                 else if (client->rxcount == client->msgsize + 3) { /* Got all of the message */
244                                         msg = Msg_networkToMessage(&client->rxbuf[3], client->msgsize);
245                                         /* pass messsage to handler */
246                                         if (msg) {
247                                                 if (msg->messageType == Speex) /* Tunneled voice message */
248                                                         Client_voiceMsg_tunnel(client, msg);
249                                                 else 
250                                                         Mh_handle_message(client, msg);
251                                         }
252                                         client->rxcount = client->msgsize = 0;
253                                 }
254                         }
255                 } else /* rc <= 0 */ {
256                         if (SSL_get_error(client->ssl, rc) == SSL_ERROR_WANT_READ) {
257                                 return 0;
258                         }
259                         else if (SSL_get_error(client->ssl, rc) == SSL_ERROR_WANT_WRITE) {
260                                 client->readBlockedOnWrite = true;
261                                 return 0;
262                         }
263                         else if (SSL_get_error(client->ssl, rc) == SSL_ERROR_ZERO_RETURN) {
264                                 Log_warn("Error: Zero return - closing");
265                                 if (!client->shutdown_wait)
266                                         Client_close(client);
267                         }
268                         else {
269                                 if (SSL_get_error(client->ssl, rc) == SSL_ERROR_SYSCALL) {
270                                         /* Hmm. This is where we end up when the client closes its connection.
271                                          * Kind of strange...
272                                          */
273                                         Log_info("Connection closed by peer");
274                                 }
275                                 else {
276                                         Log_warn("SSL error: %d - Closing connection.", SSL_get_error(client->ssl, rc));
277                                 }
278                                 Client_free(client);
279                                 return -1;
280                         }
281                 }
282         } while (SSL_pending(client->ssl));
283         return 0;       
284 }
285
286 int Client_write_fd(int fd)
287 {
288         struct dlist *itr;
289         client_t *client = NULL;
290         
291         list_iterate(itr, &clients) {
292                 if(fd == list_get_entry(itr, client_t, node)->tcpfd) {
293                         client = list_get_entry(itr, client_t, node);
294                         break;
295                 }
296         }
297         if (client == NULL)
298                 Log_fatal("No client found for fd %d", fd);
299         Client_write(client);
300         return 0;
301 }
302
303 int Client_write(client_t *client)
304 {
305         int rc;
306         
307         if (client->readBlockedOnWrite) {
308                 client->readBlockedOnWrite = false;
309                 Log_debug("Client_write: readBlockedOnWrite == true");
310                 return Client_read(client);
311         }
312         rc = SSL_write(client->ssl, &client->txbuf[client->txcount], client->txsize - client->txcount);
313         if (rc > 0) {
314                 client->txcount += rc;
315                 if (client->txcount == client->txsize)
316                         client->txsize = client->txcount = 0;
317         }
318         else if (rc < 0) {
319                 if (SSL_get_error(client->ssl, rc) == SSL_ERROR_WANT_READ) {
320                         client->writeBlockedOnRead = true;
321                         return 0;
322                 }
323                 else if (SSL_get_error(client->ssl, rc) == SSL_ERROR_WANT_WRITE) {
324                         return 0;
325                 }
326                 else {
327                         if (SSL_get_error(client->ssl, rc) == SSL_ERROR_SYSCALL)
328                                 Log_warn("Client_write: Error: %s  - Closing connection", strerror(errno));
329                         else
330                                 Log_warn("Client_write: SSL error: %d - Closing connection.", SSL_get_error(client->ssl, rc));
331                         Client_free(client);
332                         return -1;
333                 }
334         }
335         if (client->txsize == 0 && !list_empty(&client->txMsgQueue)) {
336                 message_t *msg;
337                 msg = list_get_entry(list_get_first(&client->txMsgQueue), message_t, node);
338                 list_del(list_get_first(&client->txMsgQueue));
339                 client->txQueueCount--;
340                 Client_send_message(client, msg);
341         }
342         return 0;
343 }
344
345 int Client_send_message(client_t *client, message_t *msg)
346 {
347         if (!client->authenticated || !client->SSLready) {
348                 Msg_free(msg);
349                 return 0;
350         }
351         if (client->txsize != 0) {
352                 /* Queue message */
353                 if ((client->txQueueCount > 5 &&  msg->messageType == Speex) ||
354                         client->txQueueCount > 30) {
355                         Msg_free(msg);
356                         return -1;
357                 }
358                 client->txQueueCount++;
359                 list_add_tail(&msg->node, &client->txMsgQueue);
360         } else {
361                 int len;
362                 memset(client->txbuf, 0, BUFSIZE);
363                 len = Msg_messageToNetwork(msg, &client->txbuf[3], BUFSIZE - 3);
364                 doAssert(len < BUFSIZE - 3);
365
366                 client->txbuf[0] =  (len >> 16) & 0xff;
367                 client->txbuf[1] =  (len >> 8) & 0xff;
368                 client->txbuf[2] =  len & 0xff;
369                 client->txsize = len + 3;
370                 client->txcount = 0;
371                 Client_write(client);
372                 Msg_free(msg);
373         }
374         return 0;
375 }
376
377 client_t *Client_iterate(client_t **client_itr)
378 {
379         client_t *c = *client_itr;
380
381         if (list_empty(&clients))
382                 return NULL;
383         
384         if (c == NULL) {
385                 c = list_get_entry(list_get_first(&clients), client_t, node);
386         } else {
387                 if (list_get_next(&c->node) == &clients)
388                         c = NULL;
389                 else
390                         c = list_get_entry(list_get_next(&c->node), client_t, node);
391         }
392         *client_itr = c;
393         return c;
394 }
395
396
397 int Client_send_message_except(client_t *client, message_t *msg)
398 {
399         client_t *itr = NULL;
400         int count = 0;
401         
402         Msg_inc_ref(msg); /* Make sure a reference is held during the whole iteration. */
403         while (Client_iterate(&itr) != NULL) {
404                 if (itr != client) {
405                         if (count++ > 0)
406                                 Msg_inc_ref(msg); /* One extra reference for each new copy */
407                         Log_debug("Msg %d to %s refcount %d",  msg->messageType, itr->playerName, msg->refcount);
408                         Client_send_message(itr, msg);
409                 }
410         }
411         Msg_free(msg); /* Free our reference to the message */
412         
413         if (count == 0)
414                 Msg_free(msg); /* If only 1 client is connected then no message is passed
415                                                 * to Client_send_message(). Free it here. */
416                 
417         return 0;
418 }
419
420 static bool_t checkDecrypt(client_t *client, const uint8_t *encrypted, uint8_t *plain, unsigned int len)
421 {
422         if (CryptState_isValid(&client->cryptState) &&
423                 CryptState_decrypt(&client->cryptState, encrypted, plain, len))
424                 return true;
425
426         if (Timer_elapsed(&client->cryptState.tLastGood) > 5000000ULL) {
427                 if (Timer_elapsed(&client->cryptState.tLastRequest) > 5000000ULL) {
428                         message_t *sendmsg;
429                         Timer_restart(&client->cryptState.tLastRequest);
430                         
431                         sendmsg = Msg_create(CryptSync);
432                         sendmsg->sessionId = client->sessionId;
433                         sendmsg->payload.cryptSync.empty = true;
434                         Log_info("Requesting voice channel crypt resync");
435                         Client_send_message(client, sendmsg);
436                 }
437         }
438         return false;
439 }
440
/* Receive and dispatch one datagram from the shared UDP socket.
 *
 * A datagram is a 4-byte crypt header plus the encrypted body. The decrypted
 * body (len - 4 bytes, in `buffer`) starts with the message type and then the
 * session id, both read as Pds numvals.
 *
 * The sender is located in two passes:
 *  1. A client whose stored `key` matches this source address/port. If its
 *     crypt state cannot decrypt the packet, or the session id differs, the
 *     packet is dropped.
 *  2. Otherwise, every client whose TCP peer address equals the source address
 *     is tried. The first one that decrypts the packet and has a matching
 *     session id is bound to this UDP endpoint (`key` and `remote_udp` set).
 *
 * Ping packets are echoed to the sender through Client_send_udp(), Speex
 * packets go to Client_voiceMsg(), and anything else is dropped.
 *
 * Returns -1 if recvfrom() returned 0 or failed; returns 0 otherwise,
 * including for packets that were dropped.
 */
int Client_read_udp()
{
	int len;
	struct sockaddr_in from;
	socklen_t fromlen = sizeof(struct sockaddr_in);
	uint64_t key;
	client_t *itr;
	int msgType = 0;
	uint32_t sessionId = 0;
	pds_t *pds;

#if defined(__LP64__)
	/* NOTE(review): the 4-byte offset presumably aligns the data following
	 * the 4-byte crypt header on an 8-byte boundary; TODO confirm. */
	uint8_t encbuff[512 + 8];
	uint8_t *encrypted = encbuff + 4;
#else
	uint8_t encrypted[512];
#endif
	uint8_t buffer[512];

	/* With MSG_TRUNC, recvfrom() reports the real datagram length (on Linux)
	 * even when it exceeds 512, so oversized packets are rejected below. */
	len = recvfrom(udpsock, encrypted, 512, MSG_TRUNC, (struct sockaddr *)&from, &fromlen);
	if (len == 0) {
		return -1;
	} else if (len < 0) {
		return -1;
	} else if (len < 6) {
		// 4 bytes crypt header + type + session
		return 0;
	} else if (len > 512) {
		return 0;
	}

	/* Endpoint key combining the source address and port; compared against client->key. */
	key = (((uint64_t)from.sin_addr.s_addr) << 16) ^ from.sin_port;
	/* Parser over the decrypted body, which has no crypt header (len - 4 bytes). */
	pds = Pds_create(buffer, len - 4);
	itr = NULL;

	/* Pass 1: a client already bound to this UDP endpoint. */
	while (Client_iterate(&itr) != NULL) {
		if (itr->key == key) {
			if (!checkDecrypt(itr, encrypted, buffer, len))
				goto out;
			msgType = Pds_get_numval(pds);
			sessionId = Pds_get_numval(pds);
			if (itr->sessionId != sessionId)
				goto out;
			break;
		}
	}
	/* Client_iterate() leaves itr NULL when it runs off the end of the list. */
	if (itr == NULL) { /* Unknown peer */
		/* Pass 2: bind the endpoint to a same-address client whose crypt
		 * state decrypts the packet and whose session id matches.
		 * NOTE(review): a failed checkDecrypt() here may send a CryptSync
		 * request to a client that merely shares the address. */
		while (Client_iterate(&itr) != NULL) {
			pds->offset = 0; /* re-parse from the start for each candidate */
			if (itr->remote_tcp.sin_addr.s_addr == from.sin_addr.s_addr) {
				if (checkDecrypt(itr, encrypted, buffer, len)) {
					msgType = Pds_get_numval(pds);
					sessionId = Pds_get_numval(pds);
					if (itr->sessionId == sessionId) { /* Found matching client */
						itr->key = key;
						Log_info("New UDP connection from %s port %d sessionId %d", inet_ntoa(from.sin_addr), ntohs(from.sin_port), sessionId);
						memcpy(&itr->remote_udp, &from, sizeof(struct sockaddr_in));
						break;
					}
				}
				else Log_warn("Bad cryptstate from peer");
			}
		} /* while */
	}
	if (itr == NULL) {
		goto out;
	}
	/* From here on len is the plaintext length (crypt header removed). */
	len -= 4;
	if (msgType != Speex && msgType != Ping)
		goto out;

	if (msgType == Ping) {
		/* Echo the decrypted ping back to the sender. */
		Client_send_udp(itr, buffer, len);
	}
	else {
		Client_voiceMsg(itr, pds);
	}

out:
	Pds_free(pds);
	return 0;
}
523
524 static void Client_voiceMsg_tunnel(client_t *client, message_t *msg)
525 {
526         uint8_t buf[512];
527         pds_t *pds = Pds_create(buf, 512);
528
529         Pds_add_numval(pds, msg->messageType);
530         Pds_add_numval(pds, msg->sessionId);
531         Pds_add_numval(pds, msg->payload.speex.seq);
532         Pds_append_data_nosize(pds, msg->payload.speex.data, msg->payload.speex.size);
533         
534         Msg_free(msg);
535         
536         if (!pds->bOk)
537                 Log_warn("Large Speex message from TCP"); /* XXX - pds resize? */
538         pds->maxsize = pds->offset;
539         Client_voiceMsg(client, pds);
540         
541         Pds_free(pds);
542 }
543
544 static int Client_voiceMsg(client_t *client, pds_t *pds)
545 {
546         int seq, flags, msgType, sessionId, packetsize;
547         channel_t *ch = (channel_t *)client->channel;
548         struct dlist *itr;
549         
550         if (!client->authenticated || client->mute)
551                 return 0;
552
553         
554         pds->offset = 0;
555         msgType = Pds_get_numval(pds);
556         sessionId = Pds_get_numval(pds);
557         seq = Pds_get_numval(pds);
558         flags = Pds_get_numval(pds);
559
560         packetsize = 20 + 8 + 4 + pds->maxsize - pds->offset;
561         if (client->availableBandwidth - packetsize < 0)
562                 return 0; /* Discard */
563         
564         client->availableBandwidth -= packetsize;
565         
566         pds->offset = 0;
567         
568         if (flags & LoopBack) {
569                 Client_send_udp(client, pds->data, pds->maxsize);
570                 return 0;
571         }
572         if (ch == NULL)
573                 return 0;
574         
575         list_iterate(itr, &ch->clients) {
576                 client_t *c;
577                 c = list_get_entry(itr, client_t, chan_node);
578                 if (c != client && !c->deaf) {
579                         Client_send_udp(c, pds->data, pds->maxsize);
580                 }
581         }
582         return 0;
583 }
584
585
586 static int Client_send_udp(client_t *client, uint8_t *data, int len)
587 {
588         uint8_t *buf, *mbuf;
589         message_t *sendmsg;
590
591         if (client->remote_udp.sin_port != 0 && CryptState_isValid(&client->cryptState)) {
592 #if defined(__LP64__)
593                 buf = mbuf = malloc(len + 4 + 16);
594                 buf += 4;
595 #else
596                 mbuf = buf = malloc(len + 4);
597 #endif
598                 if (mbuf == NULL)
599                         Log_fatal("Out of memory");
600                 
601                 CryptState_encrypt(&client->cryptState, data, buf, len);
602                 
603                 sendto(udpsock, buf, len + 4, 0, (struct sockaddr *)&client->remote_udp, sizeof(struct sockaddr_in));
604                 
605                 free(mbuf);
606         } else {
607                 pds_t *pds = Pds_create(data, len);
608                 
609                 sendmsg = Msg_create(Pds_get_numval(pds));
610                 sendmsg->sessionId = Pds_get_numval(pds);
611                 
612                 if (sendmsg->messageType == Speex || sendmsg->messageType == Ping) {
613                         if (sendmsg->messageType == Speex) {
614                                 sendmsg->payload.speex.seq = Pds_get_numval(pds);
615                                 sendmsg->payload.speex.size = pds->maxsize - pds->offset;
616                                 doAssert(pds->maxsize - pds->offset <= SPEEX_DATA_SIZE);
617                                 memcpy(sendmsg->payload.speex.data, data + pds->offset, pds->maxsize - pds->offset);
618                         } else { /* Ping */
619                                 sendmsg->payload.ping.timestamp = Pds_get_numval(pds);
620                         }
621                         Client_send_message(client, sendmsg);
622                 } else {
623                         Log_warn("TCP fallback: Unsupported message type %d", sendmsg->messageType);
624                         Msg_free(sendmsg);
625                 }
626                 Pds_free(pds);
627         }
628         return 0;
629 }