root/plugin/ipc/ssh/ssh.cpp

/* [<][>][^][v][top][bottom][index][help] */

DEFINITIONS

This source file includes the following definitions.
  1. dmtcp_SSH_EventHook
  2. drain
  3. refill
  4. receiveFileDescr
  5. sshdReceiveFds
  6. createNewDmtcpSshdProcess
  7. dmtcp_ssh_register_fds
  8. prepareForExec
  9. updateCoordHost
  10. execve
  11. execvp
  12. execvpe

   1 #include <sys/syscall.h>
   2 #include <sys/socket.h>
   3 #include <netinet/in.h>
   4 #include <arpa/inet.h>
   5 #include <sys/un.h>
   6 #include "dmtcp.h"
   7 #include "util.h"
   8 #include "util_ipc.h"
   9 #include "jassert.h"
  10 #include "jfilesystem.h"
  11 #include "ipc.h"
  12 #include "ssh.h"
  13 #include "sshdrainer.h"
  14 #include "shareddata.h"
  15 
  16 using namespace dmtcp;
  17 
  18 #define SSHD_PIPE_FD -1
  19 
  20 static string cmd;
  21 static string prefix;
  22 static string dmtcp_launch_path;
  23 static string dmtcp_ssh_path;
  24 static string dmtcp_sshd_path;
  25 static string dmtcp_nocheckpoint_path;
  26 
  27 static SSHDrainer *theDrainer = NULL;
  28 static int sshStdin = -1;
  29 static int sshStdout = -1;
  30 static int sshStderr = -1;
  31 static int sshSockFd = -1;
  32 static bool isSshdProcess = false;
  33 static int noStrictHostKeyChecking = 0;
  34 
  35 static bool sshPluginEnabled = false;
  36 
  37 extern "C" void process_fd_event(int event, int arg1, int arg2 = -1);
  38 static void drain();
  39 static void refill(bool isRestart);
  40 static void sshdReceiveFds();
  41 static void createNewDmtcpSshdProcess();
  42 
  43 void dmtcp_SSH_EventHook(DmtcpEvent_t event, DmtcpEventData_t *data)
  44 {
  45   if (!sshPluginEnabled) return;
  46   switch (event) {
  47     case DMTCP_EVENT_DRAIN:
  48       drain();
  49       break;
  50 
  51     case DMTCP_EVENT_THREADS_RESUME:
  52       refill(data->refillInfo.isRestart);
  53       break;
  54 
  55     default:
  56       break;
  57   }
  58 }
  59 
  60 static void drain()
  61 {
  62   JASSERT(theDrainer == NULL);
  63   theDrainer = new SSHDrainer();
  64   if (isSshdProcess) { // dmtcp_ssh process
  65     theDrainer->beginDrainOf(STDIN_FILENO, sshStdin);
  66     theDrainer->beginDrainOf(STDOUT_FILENO);
  67     theDrainer->beginDrainOf(STDERR_FILENO);
  68   } else {
  69     theDrainer->beginDrainOf(sshStdin);
  70     theDrainer->beginDrainOf(sshStdout, STDOUT_FILENO);
  71     theDrainer->beginDrainOf(sshStderr, STDERR_FILENO);
  72   }
  73   theDrainer->monitorSockets(DRAINER_CHECK_FREQ);
  74 }
  75 
  76 static void refill(bool isRestart)
  77 {
  78   if (isRestart) {
  79     if (isSshdProcess) { // dmtcp_sshd
  80       sshdReceiveFds();
  81     } else { // dmtcp_ssh
  82       createNewDmtcpSshdProcess();
  83     }
  84   }
  85 
  86   theDrainer->refill();
  87   // Free up the object
  88   delete theDrainer;
  89   theDrainer = NULL;
  90 }
  91 
  92 static void receiveFileDescr(int fd)
  93 {
  94   int data;
  95   int ret = Util::receiveFd(SSHD_RECEIVE_FD, &data, sizeof(data));
  96   if (fd == SSHD_PIPE_FD) {
  97     return;
  98   }
  99   JASSERT(data == fd) (data) (fd);
 100   if (fd != ret) {
 101     _real_close(fd);
 102     _real_dup2(ret, fd);
 103     _real_close(ret);
 104   }
 105 }
 106 
 107 static void sshdReceiveFds()
 108 {
 109   // Add receive-fd data socket.
 110   static struct sockaddr_un fdReceiveAddr;
 111   static socklen_t         fdReceiveAddrLen;
 112 
 113   memset(&fdReceiveAddr, 0, sizeof(fdReceiveAddr));
 114   jalib::JSocket sock(_real_socket(AF_UNIX, SOCK_DGRAM, 0));
 115   JASSERT(sock.isValid());
 116   sock.changeFd(SSHD_RECEIVE_FD);
 117   fdReceiveAddr.sun_family = AF_UNIX;
 118   JASSERT(_real_bind(SSHD_RECEIVE_FD,
 119                      (struct sockaddr*) &fdReceiveAddr,
 120                      sizeof(fdReceiveAddr.sun_family)) == 0) (JASSERT_ERRNO);
 121 
 122   fdReceiveAddrLen = sizeof(fdReceiveAddr);
 123   JASSERT(getsockname(SSHD_RECEIVE_FD,
 124                       (struct sockaddr *)&fdReceiveAddr,
 125                       &fdReceiveAddrLen) == 0);
 126 
 127   // Send this information to dmtcp_ssh process
 128   ssize_t ret = write(sshSockFd, &fdReceiveAddrLen, sizeof(fdReceiveAddrLen));
 129   JASSERT(ret == sizeof(fdReceiveAddrLen)) (sshSockFd) (ret) (JASSERT_ERRNO);
 130   ret = write(sshSockFd, &fdReceiveAddr, fdReceiveAddrLen);
 131   JASSERT(ret == (ssize_t) fdReceiveAddrLen);
 132 
 133   // Now receive fds
 134   receiveFileDescr(STDIN_FILENO);
 135   receiveFileDescr(STDOUT_FILENO);
 136   receiveFileDescr(STDERR_FILENO);
 137   receiveFileDescr(SSHD_PIPE_FD);
 138   _real_close(SSHD_RECEIVE_FD);
 139 }
 140 
 141 static void createNewDmtcpSshdProcess()
 142 {
 143   struct sockaddr_un addr;
 144   socklen_t          addrLen;
 145   static char abstractSockName[20];
 146   int in[2], out[2], err[2];
 147 
 148   ssize_t ret = read(sshSockFd, &addrLen, sizeof(addrLen));
 149   JASSERT(ret == sizeof(addrLen));
 150   memset(&addr, 0, sizeof(addr));
 151   ret = read(sshSockFd, &addr, addrLen);
 152   JASSERT(ret == (ssize_t) addrLen);
 153   JASSERT(strlen(&addr.sun_path[1]) < sizeof(abstractSockName));
 154   strcpy(abstractSockName, &addr.sun_path[1]);
 155 
 156   struct sockaddr_in sshdSockAddr;
 157   socklen_t sshdSockAddrLen = sizeof(sshdSockAddr);
 158   char remoteHost[80];
 159   JASSERT(getpeername(sshSockFd, (struct sockaddr*)&sshdSockAddr,
 160                       &sshdSockAddrLen) == 0);
 161   char *ip = inet_ntoa(sshdSockAddr.sin_addr);
 162   strcpy(remoteHost, ip);
 163 
 164   if (dmtcp_nocheckpoint_path.length() == 0) {
 165     dmtcp_nocheckpoint_path = Util::getPath("dmtcp_nocheckpoint");
 166     dmtcp_sshd_path = Util::getPath("dmtcp_sshd");
 167   }
 168 
 169   JASSERT(pipe(in) == 0) (JASSERT_ERRNO);
 170   JASSERT(pipe(out) == 0) (JASSERT_ERRNO);
 171   JASSERT(pipe(err) == 0) (JASSERT_ERRNO);
 172 
 173   pid_t sshChildPid = fork();
 174   JASSERT(sshChildPid != -1);
 175   if (sshChildPid == 0) {
 176     const int max_args = 16;
 177     char *argv[16];
 178     int idx = 0;
 179 
 180     argv[idx++] = (char*) dmtcp_nocheckpoint_path.c_str();
 181     argv[idx++] = const_cast<char*>("ssh");
 182     if (noStrictHostKeyChecking) {
 183       argv[idx++] = const_cast<char*>("-o");
 184       argv[idx++] = const_cast<char*>("StrictHostKeyChecking=no");
 185     }
 186     argv[idx++] = remoteHost;
 187     argv[idx++] = (char*) dmtcp_sshd_path.c_str();
 188     argv[idx++] = const_cast<char*>("--listenAddr");
 189     argv[idx++] = abstractSockName;
 190     argv[idx++] = NULL;
 191     JASSERT(idx < max_args) (idx);
 192 
 193     process_fd_event(SYS_close, in[1]);
 194     process_fd_event(SYS_close, out[0]);
 195     process_fd_event(SYS_close, err[0]);
 196     dup2(in[0], STDIN_FILENO);
 197     dup2(out[1], STDOUT_FILENO);
 198     dup2(err[1], STDERR_FILENO);
 199 
 200     JTRACE("Launching ") (argv[0]) (argv[1]) (argv[2]) (argv[3]) (argv[4]) (argv[5]);
 201     _real_execvp(argv[0], argv);
 202     JASSERT(false);
 203   }
 204 
 205   dup2(in[1],  500 + sshStdin);
 206   dup2(out[0], 500 + sshStdout);
 207   dup2(err[0], 500 + sshStderr);
 208 
 209   close(in[0]);
 210   close(in[1]);
 211   close(out[0]);
 212   close(out[1]);
 213   close(err[0]);
 214   close(err[1]);
 215 
 216   dup2(500 + sshStdin, sshStdin);
 217   dup2(500 + sshStdout, sshStdout);
 218   dup2(500 + sshStderr, sshStderr);
 219   close(500 + sshStdin);
 220   close(500 + sshStdout);
 221   close(500 + sshStderr);
 222 
 223   process_fd_event(SYS_close, sshStdin);
 224   process_fd_event(SYS_close, sshStdout);
 225   process_fd_event(SYS_close, sshStderr);
 226 }
 227 
 228 extern "C" void dmtcp_ssh_register_fds(int isSshd, int in, int out, int err,
 229                                        int sock, int noStrictChecking)
 230 {
 231   if (isSshd) { // dmtcp_sshd
 232     process_fd_event(SYS_close, STDIN_FILENO);
 233     process_fd_event(SYS_close, STDOUT_FILENO);
 234     process_fd_event(SYS_close, STDERR_FILENO);
 235   } else { // dmtcp_ssh
 236     process_fd_event(SYS_close, in);
 237     process_fd_event(SYS_close, out);
 238     process_fd_event(SYS_close, err);
 239   }
 240   sshStdin = in;
 241   sshStdout = out;
 242   sshStderr = err;
 243   sshSockFd = sock;
 244   isSshdProcess = isSshd;
 245   sshPluginEnabled = true;
 246   noStrictHostKeyChecking = noStrictChecking;
 247 }
 248 
 249 static void prepareForExec(char *const argv[], char ***newArgv)
 250 {
 251   size_t nargs = 0;
 252   bool noStrictChecking = false;
 253   string precmd, postcmd, tempcmd;
 254   while (argv[nargs++] != NULL);
 255 
 256   if (nargs < 3) {
 257     JNOTE("ssh with less than 3 args") (argv[0]) (argv[1]);
 258     *newArgv = (char**) argv;
 259     return;
 260   }
 261 
 262   //find command part
 263   size_t commandStart = 2;
 264   for (size_t i = 1; i < nargs; ++i) {
 265     string s = argv[i];
 266     if (strcmp(argv[i], "-o") == 0) {
 267       if (strcmp(argv[i+1], "StrictHostKeyChecking=no") == 0) {
 268         noStrictChecking = true;
 269       }
 270       i++;
 271       continue;
 272     }
 273 
 274     // The following flags have additional parameters and aren't fully
 275     // supported. We simply forward them to the ssh command.
 276     if (s == "-b" || s == "-c" || s == "-E" || s == "-e" || s == "-F" ||
 277         s == "-I" || s == "-i" || s == "-l" || s == "-O" || s == "-o" ||
 278         s == "-p" || s == "-Q" || s == "-S") {
 279       i++;
 280       continue;
 281     }
 282 
 283     // These options have a higer probability of failure due to binding
 284     // addresses, etc.
 285     if (s == "-b" || s == "-D" || s == "-L" || s == "-m" || s == "-R" ||
 286         s == "-W" || s == "-w") {
 287       JNOTE("The '" + s + "' ssh option isn't fully supported!");
 288       i++;
 289       continue;
 290     }
 291 
 292     if (argv[i][0] != '-') {
 293       commandStart = i + 1;
 294       break;
 295     }
 296   }
 297   JASSERT(commandStart < nargs && argv[commandStart][0] != '-')
 298     (commandStart) (nargs) (argv[commandStart])
 299     .Text("failed to parse ssh command line");
 300 
 301   vector<string> dmtcp_args;
 302   Util::getDmtcpArgs(dmtcp_args);
 303 
 304   dmtcp_launch_path = Util::getPath("dmtcp_launch");
 305   dmtcp_ssh_path = Util::getPath("dmtcp_ssh");
 306   dmtcp_sshd_path = Util::getPath("dmtcp_sshd");
 307   dmtcp_nocheckpoint_path = Util::getPath("dmtcp_nocheckpoint");
 308 
 309   prefix = dmtcp_launch_path + " ";
 310   for(size_t i = 0; i < dmtcp_args.size(); i++){
 311     prefix += dmtcp_args[i] + " ";
 312   }
 313   prefix += dmtcp_sshd_path + " ";
 314   JTRACE("Prefix")(prefix);
 315 
 316   // process command
 317   size_t semipos, pos;
 318   size_t actpos = string::npos;
 319   tempcmd = argv[commandStart];
 320   for(semipos = 0; (pos = tempcmd.find(';',semipos+1)) != string::npos;
 321       semipos = pos, actpos = pos);
 322 
 323   if (actpos > 0 && actpos != string::npos) {
 324     precmd = tempcmd.substr(0, actpos + 1);
 325     postcmd = tempcmd.substr(actpos + 1);
 326     postcmd = postcmd.substr(postcmd.find_first_not_of(" "));
 327   } else {
 328     precmd = "";
 329     postcmd = tempcmd;
 330   }
 331 
 332   cmd = precmd;
 333   // convert "exec cmd" to "exec <dmtcp-prefix> cmd"
 334   if (Util::strStartsWith(postcmd, "exec")) {
 335     cmd += "exec " + prefix + postcmd.substr(strlen("exec"));
 336   } else {
 337     cmd += prefix + postcmd;
 338   }
 339 
 340   //now repack args
 341   char** new_argv = (char**) JALLOC_HELPER_MALLOC(sizeof(char*) * (nargs + 10));
 342   memset(new_argv, 0, sizeof(char*) * (nargs + 10));
 343 
 344   size_t idx = 0;
 345   new_argv[idx++] = (char*) dmtcp_ssh_path.c_str();
 346   if (noStrictChecking) {
 347     new_argv[idx++] = const_cast<char*>("--noStrictHostKeyChecking");
 348   }
 349   new_argv[idx++] = (char*) dmtcp_nocheckpoint_path.c_str();
 350 
 351   string newCommand = string(new_argv[0]) + " " + string(new_argv[1]) + " ";
 352   for (size_t i = 0; i < commandStart; ++i) {
 353     new_argv[idx++] = ( char* ) argv[i];
 354     if (argv[i] != NULL) {
 355       newCommand += argv[i];
 356       newCommand += ' ';
 357     }
 358   }
 359   new_argv[idx++] = (char*) cmd.c_str();
 360   newCommand += cmd + " ";
 361 
 362   for (size_t i = commandStart + 1; i < nargs; ++i) {
 363     new_argv[idx++] = (char*) argv[i];
 364     if (argv[i] != NULL) {
 365       newCommand += argv[i];
 366       newCommand += ' ';
 367     }
 368   }
 369   JNOTE("New ssh command") (newCommand);
 370   *newArgv = new_argv;
 371   return;
 372 }
 373 
 374 // This code is copied from dmtcp_coordinator.cpp:calLocalAddr()
 375 static void updateCoordHost() {
 376   if (SharedData::coordHost() != "127.0.0.1")  return;
 377 
 378   struct in_addr localhostIPAddr;
 379   string cmd;
 380   char hostname[HOST_NAME_MAX];
 381   JASSERT(gethostname(hostname, sizeof hostname) == 0) (JASSERT_ERRNO);
 382   struct addrinfo *result;
 383   struct addrinfo *res;
 384   int error;
 385   struct addrinfo hints;
 386 
 387   memset(&localhostIPAddr, 0, sizeof localhostIPAddr);
 388   memset(&hints, 0, sizeof(struct addrinfo));
 389   hints.ai_family = AF_INET;
 390   hints.ai_socktype = SOCK_STREAM;
 391   hints.ai_flags = AI_PASSIVE;
 392   hints.ai_protocol = 0;
 393   hints.ai_canonname = NULL;
 394   hints.ai_addr = NULL;
 395   hints.ai_next = NULL;
 396 
 397   /* resolve the domain name into a list of addresses */
 398   error = getaddrinfo(hostname, NULL, &hints, &result);
 399   if (error == 0) {
 400     /* loop over all returned results and do inverse lookup */
 401     bool success = false;
 402     for (res = result; res != NULL; res = res->ai_next) {
 403       char name[NI_MAXHOST] = "";
 404       struct sockaddr_in *s = (struct sockaddr_in*) res->ai_addr;
 405 
 406       error = getnameinfo(res->ai_addr, res->ai_addrlen, name, NI_MAXHOST, NULL, 0, 0);
 407       if (error != 0) {
 408         JTRACE("getnameinfo() failed.") (gai_strerror(error));
 409         continue;
 410       }
 411       if (Util::strStartsWith(name, hostname) ||
 412           Util::strStartsWith(hostname, name)) {
 413         JASSERT(sizeof localhostIPAddr == sizeof s->sin_addr);
 414         success = true;
 415         memcpy(&localhostIPAddr, &s->sin_addr, sizeof s->sin_addr);
 416       }
 417     }
 418     if (!success) {
 419       JWARNING("Failed to find coordinator IP address.  DMTCP may fail.") (hostname) ;
 420     }
 421   } else {
 422     if (error == EAI_SYSTEM) {
 423       perror("getaddrinfo");
 424     } else {
 425       JTRACE("Error in getaddrinfo") (gai_strerror(error));
 426     }
 427     inet_aton("127.0.0.1", &localhostIPAddr);
 428   }
 429 
 430   SharedData::setCoordHost(&localhostIPAddr);
 431 }
 432 
 433 extern "C" int execve (const char *filename, char *const argv[],
 434                        char *const envp[])
 435 {
 436   if (jalib::Filesystem::BaseName(filename) != "ssh") {
 437     return _real_execve(filename, argv, envp);
 438   }
 439 
 440   updateCoordHost();
 441 
 442   char **newArgv = NULL;
 443   prepareForExec(argv, &newArgv);
 444   int ret = _real_execve (newArgv[0], newArgv, envp);
 445   JALLOC_HELPER_FREE(newArgv);
 446   return ret;
 447 }
 448 
 449 extern "C" int execvp (const char *filename, char *const argv[])
 450 {
 451   if (jalib::Filesystem::BaseName(filename) != "ssh") {
 452     return _real_execvp(filename, argv);
 453   }
 454 
 455   updateCoordHost();
 456 
 457   char **newArgv;
 458   prepareForExec(argv, &newArgv);
 459   int ret = _real_execvp (newArgv[0], newArgv);
 460   JALLOC_HELPER_FREE(newArgv);
 461   return ret;
 462 }
 463 
 464 // This function first appeared in glibc 2.11
 465 extern "C" int execvpe (const char *filename, char *const argv[],
 466                          char *const envp[])
 467 {
 468   if (jalib::Filesystem::BaseName(filename) != "ssh") {
 469     return _real_execvpe(filename, argv, envp);
 470   }
 471 
 472   updateCoordHost();
 473 
 474   char **newArgv;
 475   prepareForExec(argv, &newArgv);
 476   int ret = _real_execvpe(newArgv[0], newArgv, envp);
 477   JALLOC_HELPER_FREE(newArgv);
 478   return ret;
 479 }

/* [<][>][^][v][top][bottom][index][help] */