root/plugin/ipc/ssh/dmtcp_ssh.cpp

/* [<][>][^][v][top][bottom][index][help] */

DEFINITIONS

This source file includes the following definitions.
  1. getport
  2. createStdioFds
  3. openListenSocket
  4. signal_handler
  5. waitForConnection
  6. main

#include <unistd.h>
#include <sys/stat.h>
#include <sys/fcntl.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <sys/errno.h>
#include <sys/socket.h>
#include <linux/limits.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <dirent.h>
#include <errno.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>
#include "ssh.h"
  16 
  17 static int listenSock = -1;
  18 static int noStrictHostKeyChecking = 0;
  19 
  20 extern "C" void dmtcp_get_local_ip_addr(struct in_addr *addr) __attribute((weak));
  21 
  22 //static bool strEndsWith(const char *str, const char *pattern)
  23 //{
  24 //  assert(str != NULL && pattern != NULL);
  25 //  int len1 = strlen(str);
  26 //  int len2 = strlen(pattern);
  27 //  if (len1 >= len2) {
  28 //    size_t idx = len1 - len2;
  29 //    return strncmp(str+idx, pattern, len2) == 0;
  30 //  }
  31 //  return false;
  32 //}
  33 
  34 static int getport(int fd)
  35 {
  36   struct sockaddr_in addr;
  37   socklen_t addrlen = sizeof(addr);
  38   if (getsockname(fd, (struct sockaddr *)&addr, &addrlen) == -1) {
  39     return -1;
  40   }
  41   return (int)ntohs(addr.sin_port);
  42 }
  43 
  44 static void createStdioFds(int *in, int *out, int *err)
  45 {
  46   struct stat buf;
  47   if (fstat(STDIN_FILENO,  &buf)  == -1) {
  48     int fd = open("/dev/null", O_RDWR);
  49     if (fd != STDIN_FILENO) {
  50       dup2(fd, STDIN_FILENO);
  51       close(fd);
  52     }
  53   }
  54   if (fstat(STDOUT_FILENO,  &buf)  == -1) {
  55     int fd = open("/dev/null", O_RDWR);
  56     if (fd != STDOUT_FILENO) {
  57       dup2(fd, STDOUT_FILENO);
  58       close(fd);
  59     }
  60   }
  61   if (fstat(STDERR_FILENO,  &buf)  == -1) {
  62     int fd = open("/dev/null", O_RDWR);
  63     if (fd != STDERR_FILENO) {
  64       dup2(fd, STDERR_FILENO);
  65       close(fd);
  66     }
  67   }
  68 
  69   // Close all open file descriptors
  70   int maxfd = sysconf(_SC_OPEN_MAX);
  71   for (int i = 3; i < maxfd; i++) {
  72     close(i);
  73   }
  74 
  75   if (pipe(in) != 0) {
  76     perror("Error creating pipe: ");
  77   }
  78   if (pipe(out) != 0) {
  79     perror("Error creating pipe: ");
  80   }
  81   if (pipe(err) != 0) {
  82     perror("Error creating pipe: ");
  83   }
  84 }
  85 
  86 static int openListenSocket()
  87 {
  88   struct sockaddr_in saddr;
  89   int sock = socket(AF_INET, SOCK_STREAM, 0);
  90   if (sock == -1) {
  91     perror("Error creating socket: ");
  92   }
  93   memset(&saddr, 0, sizeof(saddr));
  94   saddr.sin_family = AF_INET;
  95   saddr.sin_addr.s_addr = INADDR_ANY;
  96   saddr.sin_port = 0;
  97   if (bind(sock, (struct sockaddr*) &saddr, sizeof saddr) == -1) {
  98     perror("Error binding socket");
  99   }
 100 
 101   if (listen(sock, 1) == -1) {
 102     perror("Error binding socket");
 103   }
 104   return sock;
 105 }
 106 
 107 static void signal_handler(int sig)
 108 {
 109   if (sig == SIGCHLD) {
 110     int status;
 111     wait(&status);
 112     exit(status);
 113   }
 114 }
 115 
 116 static int waitForConnection(int listenSock)
 117 {
 118   int fd = accept(listenSock, NULL, NULL);
 119   if (fd == -1) {
 120     perror("accept failed:");
 121     abort();
 122     exit(0);
 123   }
 124   close(listenSock);
 125   return fd;
 126 }
 127 
 128 int main(int argc, char *argv[], char *envp[])
 129 {
 130   int in[2], out[2], err[2];
 131   int status;
 132   int ssh_stdinfd, ssh_stdoutfd, ssh_stderrfd;
 133 
 134   if (argc < 2) {
 135     printf("***ERROR: This program shouldn't be used directly.\n");
 136     exit(1);
 137   }
 138 
 139   if (strcmp(argv[1], "--noStrictHostKeyChecking") == 0) {
 140     noStrictHostKeyChecking = 1;
 141     argv++;
 142   }
 143 
 144   createStdioFds(in, out, err);
 145   listenSock = openListenSocket();
 146   signal(SIGCHLD, signal_handler);
 147 
 148   pid_t sshChildPid = fork();
 149   if (sshChildPid == 0) {
 150     char buf[PATH_MAX + 80];
 151     char hostname[80];
 152     int port = getport(listenSock);
 153     close(listenSock);
 154 
 155     close(in[1]);
 156     close(out[0]);
 157     close(err[0]);
 158     dup2(in[0], STDIN_FILENO);
 159     dup2(out[1], STDOUT_FILENO);
 160     dup2(err[1], STDERR_FILENO);
 161 
 162     unsetenv("LD_PRELOAD");
 163 
 164     // Replace dmtcp_sshd replace with "dmtcp_sshd --host <host> --port <port>"
 165     struct in_addr saddr;
 166     if (dmtcp_get_local_ip_addr == NULL) {
 167       printf("ERROR: Unable to find dmtcp_get_local_ip_addr.\n");
 168       abort();
 169     }
 170     dmtcp_get_local_ip_addr(&saddr);
 171     char *hostip = inet_ntoa(saddr);
 172     strcpy(hostname, hostip);
 173 
 174     size_t i = 0;
 175     while (argv[i] != NULL) {
 176       // "dmtcp_sshd" may be embedded deep inside the command line.
 177       char *ptr = strstr(argv[i], SSHD_BINARY);
 178       if (ptr != NULL) {
 179         ptr += strlen(SSHD_BINARY);
 180         if (*ptr != '\0') {
 181           *ptr = '\0';
 182           ptr++;
 183         }
 184         sprintf(buf, "%s --host %s --port %d %s",
 185                 argv[i], hostip, port, ptr);
 186         argv[i] = buf;
 187       }
 188       i++;
 189     }
 190     execvp(argv[1], &argv[1]);
 191     printf("%s:%d DMTCP Error detected. Failed to exec.", __FILE__, __LINE__);
 192     abort();
 193   }
 194 
 195   int childSock = waitForConnection(listenSock);
 196 
 197   close(in[0]);
 198   close(out[1]);
 199   close(err[1]);
 200 
 201   ssh_stdinfd = in[1];
 202   ssh_stdoutfd = out[0];
 203   ssh_stderrfd = err[0];
 204 
 205   assert(dmtcp_ssh_register_fds != NULL);
 206   dmtcp_ssh_register_fds(false, ssh_stdinfd, ssh_stdoutfd, ssh_stderrfd,
 207                          childSock, noStrictHostKeyChecking);
 208 
 209   client_loop(ssh_stdinfd, ssh_stdoutfd, ssh_stderrfd, childSock);
 210   wait(&status);
 211   return status;
 212 }

/* [<][>][^][v][top][bottom][index][help] */