Move the argument handling for each nstool subcommand (hold, pid and stop) into its own cmd_*() function, which receives the subcommand's own argc/argv. This will make it easier to differentiate the options to those commands further in future. Signed-off-by: David Gibson <david(a)gibson.dropbear.id.au> --- test/nstool.c | 102 +++++++++++++++++++++++++++++++++----------------- 1 file changed, 68 insertions(+), 34 deletions(-) diff --git a/test/nstool.c b/test/nstool.c index 7e069b6..9ea7eeb 100644 --- a/test/nstool.c +++ b/test/nstool.c @@ -11,6 +11,7 @@ #include <stdio.h> #include <stdlib.h> #include <string.h> +#include <stdbool.h> #include <errno.h> #include <unistd.h> #include <sys/socket.h> @@ -37,19 +38,55 @@ static void usage(void) " terminate.\n"); } -static void hold(int fd, const struct sockaddr_un *addr) +static int connect_ctl(const char * sockpath, bool wait) { + int fd = socket(AF_UNIX, SOCK_STREAM, PF_UNIX); + struct sockaddr_un addr = { + .sun_family = AF_UNIX, + }; int rc; - rc = bind(fd, (struct sockaddr *)addr, sizeof(*addr)); + if (fd < 0) + die("socket(): %s\n", strerror(errno)); + + strncpy(addr.sun_path, sockpath, UNIX_PATH_MAX); + + do { + rc = connect(fd, (struct sockaddr *)&addr, sizeof(addr)); + if (rc < 0 && + (!wait || (errno != ENOENT && errno != ECONNREFUSED))) + die("connect() to %s: %s\n", sockpath, strerror(errno)); + } while (rc < 0); + + return fd; +} + +static void cmd_hold(int argc, char *argv[]) +{ + int fd = socket(AF_UNIX, SOCK_STREAM, PF_UNIX); + struct sockaddr_un addr = { + .sun_family = AF_UNIX, + }; + const char *sockpath = argv[1]; + int rc; + + if (argc != 2) + usage(); + + if (fd < 0) + die("socket(): %s\n", strerror(errno)); + + strncpy(addr.sun_path, sockpath, UNIX_PATH_MAX); + + rc = bind(fd, (struct sockaddr *)&addr, sizeof(addr)); if (rc < 0) - die("bind(): %s\n", strerror(errno)); + die("bind() to %s: %s\n", sockpath, strerror(errno)); rc = listen(fd, 0); if (rc < 0) - die("listen(): %s\n", strerror(errno)); + die("listen() on %s: %s\n", sockpath, strerror(errno)); - printf("nstool: local PID=%d local UID=%u local GID=%u\n", + printf("nstool hold: local PID=%d 
local UID=%u local GID=%u\n", getpid(), getuid(), getgid()); do { int afd = accept(fd, NULL, NULL); @@ -63,71 +100,68 @@ static void hold(int fd, const struct sockaddr_un *addr) die("read(): %s\n", strerror(errno)); } while (rc == 0); - unlink(addr->sun_path); + unlink(sockpath); } -static void pid(int fd, const struct sockaddr_un *addr) +static void cmd_pid(int argc, char *argv[]) { - int rc; + const char *sockpath = argv[1]; struct ucred peercred; socklen_t optlen = sizeof(peercred); + int fd, rc; - do { - rc = connect(fd, (struct sockaddr *)addr, sizeof(*addr)); - if (rc < 0 && errno != ENOENT && errno != ECONNREFUSED) - die("connect(): %s\n", strerror(errno)); - } while (rc < 0); + if (argc != 2) + usage(); + + fd = connect_ctl(sockpath, true); rc = getsockopt(fd, SOL_SOCKET, SO_PEERCRED, &peercred, &optlen); if (rc < 0) - die("getsockopet(SO_PEERCRED): %s\n", strerror(errno)); + die("getsockopet(SO_PEERCRED) %s: %s\n", + sockpath, strerror(errno)); close(fd); printf("%d\n", peercred.pid); } -static void stop(int fd, const struct sockaddr_un *addr) +static void cmd_stop(int argc, char *argv[]) { - int rc; + const char *sockpath = argv[1]; + int fd, rc; char buf = 'Q'; - rc = connect(fd, (struct sockaddr *)addr, sizeof(*addr)); - if (rc < 0) - die("connect(): %s\n", strerror(errno)); + if (argc != 2) + usage(); + + fd = connect_ctl(sockpath, false); rc = write(fd, &buf, sizeof(buf)); if (rc < 0) - die("write(): %s\n", strerror(errno)); + die("write() to %s: %s\n", sockpath, strerror(errno)); close(fd); } int main(int argc, char *argv[]) { + const char *subcmd = argv[1]; int fd; - const char *sockname; - struct sockaddr_un sockaddr = { - .sun_family = AF_UNIX, - }; - if (argc != 3) + if (argc < 2) usage(); - sockname = argv[2]; - strncpy(sockaddr.sun_path, sockname, UNIX_PATH_MAX); - fd = socket(AF_UNIX, SOCK_STREAM, PF_UNIX); if (fd < 0) die("socket(): %s\n", strerror(errno)); - if (strcmp(argv[1], "hold") == 0) - hold(fd, &sockaddr); - else if (strcmp(argv[1], 
"pid") == 0) - pid(fd, &sockaddr); - else if (strcmp(argv[1], "stop") == 0) - stop(fd, &sockaddr); + if (strcmp(subcmd, "hold") == 0) + cmd_hold(argc - 1, argv + 1); + else if (strcmp(subcmd, "pid") == 0) + cmd_pid(argc - 1, argv + 1); + else if (strcmp(subcmd, "stop") == 0) + cmd_stop(argc - 1, argv + 1); else usage(); -- 2.39.2