/* [<][>][^][v][top][bottom][index][help] */
DEFINITIONS
This source file includes the following definitions.
- dmtcp_SSH_EventHook
- drain
- refill
- receiveFileDescr
- sshdReceiveFds
- createNewDmtcpSshdProcess
- dmtcp_ssh_register_fds
- prepareForExec
- updateCoordHost
- execve
- execvp
- execvpe
1 #include <sys/syscall.h>
2 #include <sys/socket.h>
3 #include <netinet/in.h>
4 #include <arpa/inet.h>
5 #include <sys/un.h>
6 #include "dmtcp.h"
7 #include "util.h"
8 #include "util_ipc.h"
9 #include "jassert.h"
10 #include "jfilesystem.h"
11 #include "ipc.h"
12 #include "ssh.h"
13 #include "sshdrainer.h"
14 #include "shareddata.h"
15
16 using namespace dmtcp;
17
18 #define SSHD_PIPE_FD -1
19
20 static string cmd;
21 static string prefix;
22 static string dmtcp_launch_path;
23 static string dmtcp_ssh_path;
24 static string dmtcp_sshd_path;
25 static string dmtcp_nocheckpoint_path;
26
27 static SSHDrainer *theDrainer = NULL;
28 static int sshStdin = -1;
29 static int sshStdout = -1;
30 static int sshStderr = -1;
31 static int sshSockFd = -1;
32 static bool isSshdProcess = false;
33 static int noStrictHostKeyChecking = 0;
34
35 static bool sshPluginEnabled = false;
36
37 extern "C" void process_fd_event(int event, int arg1, int arg2 = -1);
38 static void drain();
39 static void refill(bool isRestart);
40 static void sshdReceiveFds();
41 static void createNewDmtcpSshdProcess();
42
43 void dmtcp_SSH_EventHook(DmtcpEvent_t event, DmtcpEventData_t *data)
44 {
45 if (!sshPluginEnabled) return;
46 switch (event) {
47 case DMTCP_EVENT_DRAIN:
48 drain();
49 break;
50
51 case DMTCP_EVENT_THREADS_RESUME:
52 refill(data->refillInfo.isRestart);
53 break;
54
55 default:
56 break;
57 }
58 }
59
60 static void drain()
61 {
62 JASSERT(theDrainer == NULL);
63 theDrainer = new SSHDrainer();
64 if (isSshdProcess) { // dmtcp_ssh process
65 theDrainer->beginDrainOf(STDIN_FILENO, sshStdin);
66 theDrainer->beginDrainOf(STDOUT_FILENO);
67 theDrainer->beginDrainOf(STDERR_FILENO);
68 } else {
69 theDrainer->beginDrainOf(sshStdin);
70 theDrainer->beginDrainOf(sshStdout, STDOUT_FILENO);
71 theDrainer->beginDrainOf(sshStderr, STDERR_FILENO);
72 }
73 theDrainer->monitorSockets(DRAINER_CHECK_FREQ);
74 }
75
76 static void refill(bool isRestart)
77 {
78 if (isRestart) {
79 if (isSshdProcess) { // dmtcp_sshd
80 sshdReceiveFds();
81 } else { // dmtcp_ssh
82 createNewDmtcpSshdProcess();
83 }
84 }
85
86 theDrainer->refill();
87 // Free up the object
88 delete theDrainer;
89 theDrainer = NULL;
90 }
91
92 static void receiveFileDescr(int fd)
93 {
94 int data;
95 int ret = Util::receiveFd(SSHD_RECEIVE_FD, &data, sizeof(data));
96 if (fd == SSHD_PIPE_FD) {
97 return;
98 }
99 JASSERT(data == fd) (data) (fd);
100 if (fd != ret) {
101 _real_close(fd);
102 _real_dup2(ret, fd);
103 _real_close(ret);
104 }
105 }
106
107 static void sshdReceiveFds()
108 {
109 // Add receive-fd data socket.
110 static struct sockaddr_un fdReceiveAddr;
111 static socklen_t fdReceiveAddrLen;
112
113 memset(&fdReceiveAddr, 0, sizeof(fdReceiveAddr));
114 jalib::JSocket sock(_real_socket(AF_UNIX, SOCK_DGRAM, 0));
115 JASSERT(sock.isValid());
116 sock.changeFd(SSHD_RECEIVE_FD);
117 fdReceiveAddr.sun_family = AF_UNIX;
118 JASSERT(_real_bind(SSHD_RECEIVE_FD,
119 (struct sockaddr*) &fdReceiveAddr,
120 sizeof(fdReceiveAddr.sun_family)) == 0) (JASSERT_ERRNO);
121
122 fdReceiveAddrLen = sizeof(fdReceiveAddr);
123 JASSERT(getsockname(SSHD_RECEIVE_FD,
124 (struct sockaddr *)&fdReceiveAddr,
125 &fdReceiveAddrLen) == 0);
126
127 // Send this information to dmtcp_ssh process
128 ssize_t ret = write(sshSockFd, &fdReceiveAddrLen, sizeof(fdReceiveAddrLen));
129 JASSERT(ret == sizeof(fdReceiveAddrLen)) (sshSockFd) (ret) (JASSERT_ERRNO);
130 ret = write(sshSockFd, &fdReceiveAddr, fdReceiveAddrLen);
131 JASSERT(ret == (ssize_t) fdReceiveAddrLen);
132
133 // Now receive fds
134 receiveFileDescr(STDIN_FILENO);
135 receiveFileDescr(STDOUT_FILENO);
136 receiveFileDescr(STDERR_FILENO);
137 receiveFileDescr(SSHD_PIPE_FD);
138 _real_close(SSHD_RECEIVE_FD);
139 }
140
141 static void createNewDmtcpSshdProcess()
142 {
143 struct sockaddr_un addr;
144 socklen_t addrLen;
145 static char abstractSockName[20];
146 int in[2], out[2], err[2];
147
148 ssize_t ret = read(sshSockFd, &addrLen, sizeof(addrLen));
149 JASSERT(ret == sizeof(addrLen));
150 memset(&addr, 0, sizeof(addr));
151 ret = read(sshSockFd, &addr, addrLen);
152 JASSERT(ret == (ssize_t) addrLen);
153 JASSERT(strlen(&addr.sun_path[1]) < sizeof(abstractSockName));
154 strcpy(abstractSockName, &addr.sun_path[1]);
155
156 struct sockaddr_in sshdSockAddr;
157 socklen_t sshdSockAddrLen = sizeof(sshdSockAddr);
158 char remoteHost[80];
159 JASSERT(getpeername(sshSockFd, (struct sockaddr*)&sshdSockAddr,
160 &sshdSockAddrLen) == 0);
161 char *ip = inet_ntoa(sshdSockAddr.sin_addr);
162 strcpy(remoteHost, ip);
163
164 if (dmtcp_nocheckpoint_path.length() == 0) {
165 dmtcp_nocheckpoint_path = Util::getPath("dmtcp_nocheckpoint");
166 dmtcp_sshd_path = Util::getPath("dmtcp_sshd");
167 }
168
169 JASSERT(pipe(in) == 0) (JASSERT_ERRNO);
170 JASSERT(pipe(out) == 0) (JASSERT_ERRNO);
171 JASSERT(pipe(err) == 0) (JASSERT_ERRNO);
172
173 pid_t sshChildPid = fork();
174 JASSERT(sshChildPid != -1);
175 if (sshChildPid == 0) {
176 const int max_args = 16;
177 char *argv[16];
178 int idx = 0;
179
180 argv[idx++] = (char*) dmtcp_nocheckpoint_path.c_str();
181 argv[idx++] = const_cast<char*>("ssh");
182 if (noStrictHostKeyChecking) {
183 argv[idx++] = const_cast<char*>("-o");
184 argv[idx++] = const_cast<char*>("StrictHostKeyChecking=no");
185 }
186 argv[idx++] = remoteHost;
187 argv[idx++] = (char*) dmtcp_sshd_path.c_str();
188 argv[idx++] = const_cast<char*>("--listenAddr");
189 argv[idx++] = abstractSockName;
190 argv[idx++] = NULL;
191 JASSERT(idx < max_args) (idx);
192
193 process_fd_event(SYS_close, in[1]);
194 process_fd_event(SYS_close, out[0]);
195 process_fd_event(SYS_close, err[0]);
196 dup2(in[0], STDIN_FILENO);
197 dup2(out[1], STDOUT_FILENO);
198 dup2(err[1], STDERR_FILENO);
199
200 JTRACE("Launching ") (argv[0]) (argv[1]) (argv[2]) (argv[3]) (argv[4]) (argv[5]);
201 _real_execvp(argv[0], argv);
202 JASSERT(false);
203 }
204
205 dup2(in[1], 500 + sshStdin);
206 dup2(out[0], 500 + sshStdout);
207 dup2(err[0], 500 + sshStderr);
208
209 close(in[0]);
210 close(in[1]);
211 close(out[0]);
212 close(out[1]);
213 close(err[0]);
214 close(err[1]);
215
216 dup2(500 + sshStdin, sshStdin);
217 dup2(500 + sshStdout, sshStdout);
218 dup2(500 + sshStderr, sshStderr);
219 close(500 + sshStdin);
220 close(500 + sshStdout);
221 close(500 + sshStderr);
222
223 process_fd_event(SYS_close, sshStdin);
224 process_fd_event(SYS_close, sshStdout);
225 process_fd_event(SYS_close, sshStderr);
226 }
227
228 extern "C" void dmtcp_ssh_register_fds(int isSshd, int in, int out, int err,
229 int sock, int noStrictChecking)
230 {
231 if (isSshd) { // dmtcp_sshd
232 process_fd_event(SYS_close, STDIN_FILENO);
233 process_fd_event(SYS_close, STDOUT_FILENO);
234 process_fd_event(SYS_close, STDERR_FILENO);
235 } else { // dmtcp_ssh
236 process_fd_event(SYS_close, in);
237 process_fd_event(SYS_close, out);
238 process_fd_event(SYS_close, err);
239 }
240 sshStdin = in;
241 sshStdout = out;
242 sshStderr = err;
243 sshSockFd = sock;
244 isSshdProcess = isSshd;
245 sshPluginEnabled = true;
246 noStrictHostKeyChecking = noStrictChecking;
247 }
248
249 static void prepareForExec(char *const argv[], char ***newArgv)
250 {
251 size_t nargs = 0;
252 bool noStrictChecking = false;
253 string precmd, postcmd, tempcmd;
254 while (argv[nargs++] != NULL);
255
256 if (nargs < 3) {
257 JNOTE("ssh with less than 3 args") (argv[0]) (argv[1]);
258 *newArgv = (char**) argv;
259 return;
260 }
261
262 //find command part
263 size_t commandStart = 2;
264 for (size_t i = 1; i < nargs; ++i) {
265 string s = argv[i];
266 if (strcmp(argv[i], "-o") == 0) {
267 if (strcmp(argv[i+1], "StrictHostKeyChecking=no") == 0) {
268 noStrictChecking = true;
269 }
270 i++;
271 continue;
272 }
273
274 // The following flags have additional parameters and aren't fully
275 // supported. We simply forward them to the ssh command.
276 if (s == "-b" || s == "-c" || s == "-E" || s == "-e" || s == "-F" ||
277 s == "-I" || s == "-i" || s == "-l" || s == "-O" || s == "-o" ||
278 s == "-p" || s == "-Q" || s == "-S") {
279 i++;
280 continue;
281 }
282
283 // These options have a higer probability of failure due to binding
284 // addresses, etc.
285 if (s == "-b" || s == "-D" || s == "-L" || s == "-m" || s == "-R" ||
286 s == "-W" || s == "-w") {
287 JNOTE("The '" + s + "' ssh option isn't fully supported!");
288 i++;
289 continue;
290 }
291
292 if (argv[i][0] != '-') {
293 commandStart = i + 1;
294 break;
295 }
296 }
297 JASSERT(commandStart < nargs && argv[commandStart][0] != '-')
298 (commandStart) (nargs) (argv[commandStart])
299 .Text("failed to parse ssh command line");
300
301 vector<string> dmtcp_args;
302 Util::getDmtcpArgs(dmtcp_args);
303
304 dmtcp_launch_path = Util::getPath("dmtcp_launch");
305 dmtcp_ssh_path = Util::getPath("dmtcp_ssh");
306 dmtcp_sshd_path = Util::getPath("dmtcp_sshd");
307 dmtcp_nocheckpoint_path = Util::getPath("dmtcp_nocheckpoint");
308
309 prefix = dmtcp_launch_path + " ";
310 for(size_t i = 0; i < dmtcp_args.size(); i++){
311 prefix += dmtcp_args[i] + " ";
312 }
313 prefix += dmtcp_sshd_path + " ";
314 JTRACE("Prefix")(prefix);
315
316 // process command
317 size_t semipos, pos;
318 size_t actpos = string::npos;
319 tempcmd = argv[commandStart];
320 for(semipos = 0; (pos = tempcmd.find(';',semipos+1)) != string::npos;
321 semipos = pos, actpos = pos);
322
323 if (actpos > 0 && actpos != string::npos) {
324 precmd = tempcmd.substr(0, actpos + 1);
325 postcmd = tempcmd.substr(actpos + 1);
326 postcmd = postcmd.substr(postcmd.find_first_not_of(" "));
327 } else {
328 precmd = "";
329 postcmd = tempcmd;
330 }
331
332 cmd = precmd;
333 // convert "exec cmd" to "exec <dmtcp-prefix> cmd"
334 if (Util::strStartsWith(postcmd, "exec")) {
335 cmd += "exec " + prefix + postcmd.substr(strlen("exec"));
336 } else {
337 cmd += prefix + postcmd;
338 }
339
340 //now repack args
341 char** new_argv = (char**) JALLOC_HELPER_MALLOC(sizeof(char*) * (nargs + 10));
342 memset(new_argv, 0, sizeof(char*) * (nargs + 10));
343
344 size_t idx = 0;
345 new_argv[idx++] = (char*) dmtcp_ssh_path.c_str();
346 if (noStrictChecking) {
347 new_argv[idx++] = const_cast<char*>("--noStrictHostKeyChecking");
348 }
349 new_argv[idx++] = (char*) dmtcp_nocheckpoint_path.c_str();
350
351 string newCommand = string(new_argv[0]) + " " + string(new_argv[1]) + " ";
352 for (size_t i = 0; i < commandStart; ++i) {
353 new_argv[idx++] = ( char* ) argv[i];
354 if (argv[i] != NULL) {
355 newCommand += argv[i];
356 newCommand += ' ';
357 }
358 }
359 new_argv[idx++] = (char*) cmd.c_str();
360 newCommand += cmd + " ";
361
362 for (size_t i = commandStart + 1; i < nargs; ++i) {
363 new_argv[idx++] = (char*) argv[i];
364 if (argv[i] != NULL) {
365 newCommand += argv[i];
366 newCommand += ' ';
367 }
368 }
369 JNOTE("New ssh command") (newCommand);
370 *newArgv = new_argv;
371 return;
372 }
373
// This code is copied from dmtcp_coordinator.cpp:calLocalAddr()
//
// If the coordinator host is currently recorded as "127.0.0.1", replaces it
// with an IPv4 address of this machine: the local hostname is resolved with
// getaddrinfo(), each result is reverse-resolved with getnameinfo(), and an
// address is accepted when the reverse name and the hostname share a prefix
// (the last such match wins). When no address matches, a warning is issued
// and the zero-initialised address is stored; when getaddrinfo() itself
// fails, 127.0.0.1 is stored.
static void updateCoordHost() {
  if (SharedData::coordHost() != "127.0.0.1") return;

  struct in_addr localhostIPAddr;
  // HOST_NAME_MAX does not count the terminating NUL, and gethostname() may
  // leave a truncated name unterminated, so reserve and force the last byte.
  char hostname[HOST_NAME_MAX + 1];
  JASSERT(gethostname(hostname, sizeof hostname) == 0) (JASSERT_ERRNO);
  hostname[sizeof hostname - 1] = '\0';
  struct addrinfo *result;
  struct addrinfo *res;
  int error;
  struct addrinfo hints;

  memset(&localhostIPAddr, 0, sizeof localhostIPAddr);
  memset(&hints, 0, sizeof(struct addrinfo));
  hints.ai_family = AF_INET;
  hints.ai_socktype = SOCK_STREAM;
  hints.ai_flags = AI_PASSIVE;
  hints.ai_protocol = 0;
  hints.ai_canonname = NULL;
  hints.ai_addr = NULL;
  hints.ai_next = NULL;

  /* resolve the domain name into a list of addresses */
  error = getaddrinfo(hostname, NULL, &hints, &result);
  if (error == 0) {
    /* loop over all returned results and do inverse lookup */
    bool success = false;
    for (res = result; res != NULL; res = res->ai_next) {
      char name[NI_MAXHOST] = "";
      struct sockaddr_in *s = (struct sockaddr_in*) res->ai_addr;

      error = getnameinfo(res->ai_addr, res->ai_addrlen, name, NI_MAXHOST, NULL, 0, 0);
      if (error != 0) {
        JTRACE("getnameinfo() failed.") (gai_strerror(error));
        continue;
      }
      if (Util::strStartsWith(name, hostname) ||
          Util::strStartsWith(hostname, name)) {
        JASSERT(sizeof localhostIPAddr == sizeof s->sin_addr);
        success = true;
        memcpy(&localhostIPAddr, &s->sin_addr, sizeof s->sin_addr);
      }
    }
    // The address has been copied out; release the list so that repeated
    // exec interceptions do not leak it.
    freeaddrinfo(result);
    if (!success) {
      JWARNING("Failed to find coordinator IP address.  DMTCP may fail.") (hostname) ;
    }
  } else {
    if (error == EAI_SYSTEM) {
      perror("getaddrinfo");
    } else {
      JTRACE("Error in getaddrinfo") (gai_strerror(error));
    }
    inet_aton("127.0.0.1", &localhostIPAddr);
  }

  SharedData::setCoordHost(&localhostIPAddr);
}
432
433 extern "C" int execve (const char *filename, char *const argv[],
434 char *const envp[])
435 {
436 if (jalib::Filesystem::BaseName(filename) != "ssh") {
437 return _real_execve(filename, argv, envp);
438 }
439
440 updateCoordHost();
441
442 char **newArgv = NULL;
443 prepareForExec(argv, &newArgv);
444 int ret = _real_execve (newArgv[0], newArgv, envp);
445 JALLOC_HELPER_FREE(newArgv);
446 return ret;
447 }
448
449 extern "C" int execvp (const char *filename, char *const argv[])
450 {
451 if (jalib::Filesystem::BaseName(filename) != "ssh") {
452 return _real_execvp(filename, argv);
453 }
454
455 updateCoordHost();
456
457 char **newArgv;
458 prepareForExec(argv, &newArgv);
459 int ret = _real_execvp (newArgv[0], newArgv);
460 JALLOC_HELPER_FREE(newArgv);
461 return ret;
462 }
463
464 // This function first appeared in glibc 2.11
465 extern "C" int execvpe (const char *filename, char *const argv[],
466 char *const envp[])
467 {
468 if (jalib::Filesystem::BaseName(filename) != "ssh") {
469 return _real_execvpe(filename, argv, envp);
470 }
471
472 updateCoordHost();
473
474 char **newArgv;
475 prepareForExec(argv, &newArgv);
476 int ret = _real_execvpe(newArgv[0], newArgv, envp);
477 JALLOC_HELPER_FREE(newArgv);
478 return ret;
479 }