#include #include #include #include #include #include #include #include #include #include #include #include #include #include #ifdef USE_SHA1 # include #else # define SHA_DIGEST_LENGTH 20 unsigned char *SHA1 (const unsigned char *d, unsigned long n, unsigned char *md) { d = d; n = n; memset(md, 0, SHA_DIGEST_LENGTH); return md; } #endif #define PROTO_VERSION 1 struct helo { uint32_t version; uint32_t counter; uint32_t hosts; char hostname[64]; unsigned char sha1[20]; } __attribute__ ((packed)); #if defined(__linux__) || defined(__APPLE__) int adjust_carp (uint32_t ho, uint32_t *oho, long long ms, long long *oms) { ho = ho; oho = oho; ms = ms; oms = oms; return 0; } #define setproctitle(fmt, args ...) do { \ int _spt_len; \ char **_spt_c = (envp ? envp : argv); \ for (; *_spt_c; ++_spt_c) if (!*(_spt_c + 1)) break; \ while (**_spt_c) ++*_spt_c; \ _spt_len = *_spt_c - *argv; \ memset(*argv, 0, _spt_len); \ snprintf(*argv, _spt_len - 1, "mhelod: " fmt, ## args); \ } while (0) #else int adjust_carp (uint32_t ho, uint32_t *oho, long long ms, long long *oms) { static const char *carp_hosts = "net.inet.carp.iplb_hosts", *carp_host = "net.inet.carp.iplb_host"; if (ho != *oho) { syslog(LOG_NOTICE, "adjusting hosts %u -> %u via sysctl\n", *oho, ho); *oho = ho; if (sysctlbyname(carp_hosts, NULL, 0, &ho, sizeof(ho)) < 0) return -1; } if (ms != *oms) { syslog(LOG_NOTICE, "adjusting myself %lli -> %lli via sysctl\n", *oms, ms); *oms = ms; if (sysctlbyname(carp_host, NULL, 0, &ms, sizeof(ms)) < 0) return -1; } return 0; } #endif /* opens udp socket, sets mcast loop + ttl */ int mcast_socket () { int s, opt; if ((s = socket(AF_INET, SOCK_DGRAM, 0)) < 0) return -1; opt = 1; if (setsockopt(s, IPPROTO_IP, IP_MULTICAST_LOOP, &opt, sizeof(opt))) return -1; opt = 1; if (setsockopt(s, IPPROTO_IP, IP_MULTICAST_TTL, &opt, sizeof(opt))) return -1; return s; } /* fills h, computes SHA1 with pass as hidden salt */ void make_helo (struct helo *h, uint32_t counter, char *hostname, uint32_t hosts, char *pass) { struct { struct helo h; char pass[64]; } sec; unsigned char md[SHA_DIGEST_LENGTH]; memset(h, 0, sizeof(*h)); h->counter = counter; strncpy(h->hostname, hostname, sizeof(h->hostname)); h->hosts = hosts; h->version = htonl(PROTO_VERSION); memset(&sec, 0, sizeof(sec)); memcpy(&sec.h, h, sizeof(*h)); strncpy(sec.pass, pass, sizeof(sec.pass)); SHA1((unsigned char*) &sec, sizeof(sec), md); memcpy(h->sha1, md, sizeof(md)); } int running = 1; void kill_handler (int signum) { signum = signum; syslog(LOG_INFO, "%s", "received sigkill, going down..."); running = 0; } int main (int argc, char **argv, char **envp) { int ch, helo_interval = 1, dead_interval = 3, stats_interval = 600; /* sec */ char hostname[64], pass[64] = "test"; uint32_t counter = 0, hosts = 0, old_hosts = 0; long long myself, old_myself = -1; struct helo h, h2; struct list { struct list *next; time_t lasthelo; struct helo h; } *peers = NULL, *walk, *prev, *add_pos_prev, *add_pos_next; int client_socket, server_socket, found; time_t lasthelosend = 0, now, laststatsprint, starttime; struct sockaddr_in sin, recv_sin; struct ip_mreq ipm; ssize_t ret; struct timeval tv; fd_set rfd; pid_t pid; struct sigaction sa; envp = envp; /* mcast address */ memset(&sin, 0, sizeof(sin)); sin.sin_addr.s_addr = inet_addr("224.66.66.66"); sin.sin_port = htons(6666); sin.sin_family = PF_INET; /* parse command line */ while ((ch = getopt(argc, argv, "a:p:x:e:d:s:h")) != -1) { switch (ch) { case 'a': sin.sin_addr.s_addr = inet_addr(optarg); break; case 'p': sin.sin_port = htons(atoi(optarg)); break; case 'x': strncpy(pass, optarg, sizeof(pass)); break; case 'e': helo_interval = atoi(optarg); break; case 'd': dead_interval = atoi(optarg); break; case 's': stats_interval = atoi(optarg); break; default: fputs("usage: mhelod [-a
] [-p ] [-x ]\n" " [-e ] [-d ] [-s ]\n" " [-h]\n\n", stderr); return -1; } } /* overwrite commandline */ setproctitle("%s:%i", inet_ntoa(sin.sin_addr), sin.sin_port); /* be honest, save descriptors */ close(0); close(1); /* server */ memset(&ipm, 0, sizeof(ipm)); ipm.imr_multiaddr.s_addr = sin.sin_addr.s_addr; ipm.imr_interface.s_addr = htonl(INADDR_ANY); if ((server_socket = mcast_socket()) < 0) { perror("server_socket"); return -1; } if (bind(server_socket, (struct sockaddr*) &sin, sizeof(sin))) { perror("bind()"); return -1; } if (setsockopt(server_socket, IPPROTO_IP, IP_ADD_MEMBERSHIP, &ipm, sizeof(ipm))) { perror("cannot join multicast group"); return -1; } /* client */ if ((client_socket = mcast_socket()) < 0) { perror("client_socket"); return -1; } if (gethostname(hostname, sizeof(hostname)) < 0) { perror("gethostname()"); return -1; } hostname[sizeof(hostname) - 1] = '\0'; /* daemonize */ pid = fork(); if (pid < 0) { perror("fork()"); return -1; } if (pid > 0) return 0; // TODO: pidfile memset(&sa, 0, sizeof(sa)); sa.sa_handler = kill_handler; if (sigaction(SIGTERM, &sa, NULL) < 0) { perror("sigaction()"); return -1; } if (chdir("/") < 0) { perror("chdir()"); return -1; } if (setsid() < 0) { perror("setsid()"); return -1; } close(2); openlog("mhelod", LOG_PID, LOG_DAEMON); syslog(LOG_INFO, "%s", "starting..."); starttime = laststatsprint = time(NULL); while (running) { now = time(NULL); if (now >= lasthelosend + helo_interval) { /* send helo */ lasthelosend = now; make_helo(&h, htonl(counter), hostname, htonl(hosts), pass); ret = sendto(client_socket, &h, sizeof(h), 0, (struct sockaddr*) &sin, sizeof(sin)); if (ret < 0) { syslog(LOG_ERR, "%s(): %s\n", "sendto", strerror(errno)); break; } if ((unsigned) ret != sizeof(h)) { syslog(LOG_ERR, "sendto() sent %i of %i bytes\n", ret, sizeof(h)); break; } counter++; } /* listen for multicast */ FD_ZERO(&rfd); FD_SET(server_socket, &rfd); tv.tv_sec = helo_interval; tv.tv_usec = 0; ret = select(server_socket + 1, &rfd, NULL, NULL, &tv); if (ret < 0) { syslog(LOG_ERR, "%s(): %s\n", "select", strerror(errno)); break; } if (ret > 1) { syslog(LOG_ERR, "%s\n", "select() madness"); break; } if (ret != 1) continue; /* receive packet */ socklen_t len = sizeof(recv_sin); ret = recvfrom(server_socket, &h, sizeof(h), 0, (struct sockaddr*) &recv_sin, &len); if (ret < 0) { syslog(LOG_ERR, "%s(): %s\n", "recvfrom", strerror(errno)); break; } if ((unsigned) ret > sizeof(h)) { syslog(LOG_ERR, "%s\n", "recvfrom() madness"); break; } if ((unsigned) ret < sizeof(h)) { syslog(LOG_NOTICE, "recv %i, not %i bytes from %s\n", ret, sizeof(h), inet_ntoa(recv_sin.sin_addr)); continue; } /* check protocol version */ if (ntohl(h.version) != PROTO_VERSION) { syslog(LOG_NOTICE, "proto mismatch: %s is using version %i, not %i\n", inet_ntoa(recv_sin.sin_addr), ntohl(h.version), PROTO_VERSION); continue; } /* process packet */ make_helo(&h2, h.counter, h.hostname, h.hosts, pass); if (memcmp(&h, &h2, sizeof(h))) { syslog(LOG_NOTICE, "auth mismatch from %s\n", inet_ntoa(recv_sin.sin_addr)); continue; } h.counter = ntohl(h.counter); h.hosts = ntohl(h.hosts); h.hostname[sizeof(h.hostname) - 1] = '\0'; if (h.hosts != hosts) syslog(LOG_WARNING, "%s (%s) sees %i instead of %i hosts", inet_ntoa(recv_sin.sin_addr), h.hostname, h.hosts, hosts); /* check peers */ walk = peers; prev = NULL; add_pos_next = peers; add_pos_prev = NULL; found = 0; hosts = 0; myself = -1; while (walk) { if (walk->lasthelo + dead_interval < now) { /* delete peer if no action for 3 seconds */ syslog(LOG_NOTICE, "peer %s: timeout\n", walk->h.hostname); if (prev == NULL) { peers = walk->next; free(walk); walk = peers; } else { prev->next = walk->next; free(walk); walk = prev->next; } continue; } ret = strcmp(walk->h.hostname, h.hostname); if (!ret) { /* already known peer */ if (walk->h.counter + 1 != h.counter) syslog(LOG_NOTICE, "peer %s: expected sequence %i, got %i\n", h.hostname, walk->h.counter + 1, h.counter); walk->lasthelo = now; memcpy(&walk->h, &h, sizeof(h)); found = 1; } /* mcast loop is enabled by default, so i get my own stuff, too */ if (!strcmp(walk->h.hostname, hostname)) { if (myself < 0) myself = hosts; else syslog(LOG_WARNING, "some host is using my hostname (%s)", hostname); } hosts++; prev = walk; walk = walk->next; if (ret < 0) { add_pos_prev = prev; add_pos_next = walk; } } if (!found) { /* received packet contains unknown peer, add it to the list */ struct list *n = malloc(sizeof(struct list)); if (!n) { syslog(LOG_ERR, "%s", "out of memory!"); break; } n->next = add_pos_next; n->lasthelo = now; memcpy(&n->h, &h, sizeof(h)); if (add_pos_prev) add_pos_prev->next = n; else peers = n; if (!strcmp(n->h.hostname, hostname)) myself = hosts; hosts++; syslog(LOG_NOTICE, "new peer: %s\n", h.hostname); } /* sanity checks */ if (myself < 0) syslog(LOG_WARNING, "%s", "don't know myself"); if (hosts <= 0) syslog(LOG_WARNING, "%s", "don't know any hosts"); /* sysctl */ if (adjust_carp(hosts, &old_hosts, myself, &old_myself) < 0) { syslog(LOG_ERR, "%s(): %s", "adjust_carp", strerror(errno)); break; } if (now >= laststatsprint + stats_interval) { /* print some stats to syslog */ laststatsprint = now; syslog(LOG_INFO, "running for %li seconds, %u hosts, i am host #%lli", (long)(now - starttime), hosts, myself); } } /* got here by break or signal */ if (adjust_carp(0, &old_hosts, -1, &old_myself) < 0) syslog(LOG_WARNING, "%s(): %s", "adjust_carp", strerror(errno)); if (setsockopt(server_socket, IPPROTO_IP, IP_DROP_MEMBERSHIP, &ipm, sizeof(ipm)) < 0) syslog(LOG_WARNING, "cannot leave multicast group: %s", strerror(errno)); close(server_socket); close(client_socket); syslog(LOG_INFO, "%s", "exiting..."); return 0; }