/* [<][>][^][v][top][bottom][index][help] */
DEFINITIONS
This source file includes the following definitions.
- getport
- createStdioFds
- openListenSocket
- signal_handler
- waitForConnection
- main
#include <unistd.h>
#include <sys/stat.h>
#include <sys/fcntl.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <sys/errno.h>
#include <sys/socket.h>
#include <linux/limits.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <errno.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>
#include "ssh.h"
16
// TCP listening socket created by main() and accepted on once; -1 until it
// has been opened. Its port is forwarded to SSHD_BINARY via "--port".
static int listenSock = -1;

// Becomes 1 when argv[1] is "--noStrictHostKeyChecking"; passed through to
// dmtcp_ssh_register_fds(). Static storage, so it starts out as 0.
static int noStrictHostKeyChecking;

// Presumably provided by the DMTCP runtime (TODO confirm). Declared weak so
// this binary still links without it; main() checks it against NULL first.
extern "C" void dmtcp_get_local_ip_addr(struct in_addr *addr)
  __attribute__((weak));
21
22 //static bool strEndsWith(const char *str, const char *pattern)
23 //{
24 // assert(str != NULL && pattern != NULL);
25 // int len1 = strlen(str);
26 // int len2 = strlen(pattern);
27 // if (len1 >= len2) {
28 // size_t idx = len1 - len2;
29 // return strncmp(str+idx, pattern, len2) == 0;
30 // }
31 // return false;
32 //}
33
/*
 * Report the local port that socket `fd` is bound to, converted to host
 * byte order. Returns -1 if getsockname() fails (e.g. `fd` is not a socket).
 */
static int getport(int fd)
{
  struct sockaddr_in local;
  socklen_t len = sizeof local;
  int rc = getsockname(fd, (struct sockaddr *)&local, &len);
  return (rc == -1) ? -1 : (int)ntohs(local.sin_port);
}
43
/*
 * If descriptor `target` is not open, make it refer to /dev/null, so that it
 * is not silently reused by the next open()/pipe(). Best effort: if
 * /dev/null cannot be opened, `target` is left closed.
 */
static void ensureStdFdOpen(int target)
{
  struct stat sb;
  if (fstat(target, &sb) != -1) {
    return;  // already open, nothing to do
  }
  int fd = open("/dev/null", O_RDWR);
  if (fd == -1) {
    return;  // do not dup2()/close() an invalid descriptor
  }
  if (fd != target) {
    dup2(fd, target);
    close(fd);
  }
}

/*
 * Create a pipe in `fds`, or print the reason and exit. Callers use `fds`
 * unconditionally, so continuing with uninitialized descriptors is not an
 * option.
 */
static void makePipeOrDie(int fds[2])
{
  if (pipe(fds) != 0) {
    perror("Error creating pipe");  // perror() appends ": <reason>" itself
    exit(EXIT_FAILURE);
  }
}

/*
 * Prepare the process's descriptors:
 *  - make sure fds 0, 1 and 2 are open (falling back to /dev/null), so the
 *    pipes below are never allocated in the stdio slots;
 *  - close every other descriptor from 3 up to sysconf(_SC_OPEN_MAX);
 *  - create the three pipes returned in `in`, `out` and `err`
 *    (each element [0] is the read end and [1] the write end).
 * Exits the process if a pipe cannot be created.
 */
static void createStdioFds(int *in, int *out, int *err)
{
  ensureStdFdOpen(STDIN_FILENO);
  ensureStdFdOpen(STDOUT_FILENO);
  ensureStdFdOpen(STDERR_FILENO);

  // Close all other descriptors. sysconf() returns -1 if the limit is
  // indeterminate, in which case the loop does nothing.
  long maxfd = sysconf(_SC_OPEN_MAX);
  for (long fd = 3; fd < maxfd; fd++) {
    close((int)fd);
  }

  makePipeOrDie(in);
  makePipeOrDie(out);
  makePipeOrDie(err);
}
85
/*
 * Create an IPv4 TCP socket bound to an ephemeral port on all local
 * interfaces and put it in listening state with a backlog of 1 (main()
 * accepts a single connection on it).
 *
 * Returns the listening socket, or -1 on failure after printing the reason
 * with perror(). A socket that was created but could not be bound or put
 * into listening state is closed rather than leaked.
 */
static int openListenSocket()
{
  int sock = socket(AF_INET, SOCK_STREAM, 0);
  if (sock == -1) {
    perror("Error creating socket");
    return -1;
  }

  struct sockaddr_in saddr;
  memset(&saddr, 0, sizeof(saddr));
  saddr.sin_family = AF_INET;
  saddr.sin_addr.s_addr = htonl(INADDR_ANY);
  saddr.sin_port = htons(0);  // port 0: let the kernel choose a free port

  if (bind(sock, (struct sockaddr *)&saddr, sizeof(saddr)) == -1) {
    perror("Error binding socket");
    close(sock);
    return -1;
  }

  if (listen(sock, 1) == -1) {
    perror("Error listening on socket");
    close(sock);
    return -1;
  }
  return sock;
}
106
/*
 * SIGCHLD handler: reap one terminated child and terminate this process with
 * an exit code describing that child's outcome:
 *   - the child's own exit code if it exited normally;
 *   - 128 + signal number if it was killed by a signal (shell convention);
 *   - EXIT_FAILURE if wait() itself failed.
 * Any other signal is ignored.
 *
 * The raw wait() status must be decoded: passing it straight to exit() (as
 * before) truncates it to its low 8 bits, so a child exiting with 1 (status
 * 0x100) was reported as success.
 *
 * NOTE(review): exit() is not async-signal-safe. It is kept (instead of
 * _exit()) so that stdio buffers are flushed and atexit handlers run as
 * before; revisit if this handler can interrupt stdio or malloc.
 */
static void signal_handler(int sig)
{
  if (sig != SIGCHLD) {
    return;
  }

  int status = 0;
  int code = EXIT_FAILURE;
  if (wait(&status) != -1) {
    if (WIFEXITED(status)) {
      code = WEXITSTATUS(status);
    } else if (WIFSIGNALED(status)) {
      code = 128 + WTERMSIG(status);
    }
  }
  exit(code);
}
115
/*
 * Block until a peer connects to the listening socket `sock`, then close
 * `sock` (only this one connection is accepted) and return the connected
 * socket. accept() is retried if a signal interrupts it (EINTR); any other
 * accept() failure is reported with perror() and the process aborts.
 */
static int waitForConnection(int sock)
{
  int fd;
  do {
    fd = accept(sock, NULL, NULL);
  } while (fd == -1 && errno == EINTR);

  if (fd == -1) {
    perror("accept failed");  // perror() appends ": <reason>" itself
    abort();                  // does not return
  }
  close(sock);
  return fd;
}
127
128 int main(int argc, char *argv[], char *envp[])
129 {
130 int in[2], out[2], err[2];
131 int status;
132 int ssh_stdinfd, ssh_stdoutfd, ssh_stderrfd;
133
134 if (argc < 2) {
135 printf("***ERROR: This program shouldn't be used directly.\n");
136 exit(1);
137 }
138
139 if (strcmp(argv[1], "--noStrictHostKeyChecking") == 0) {
140 noStrictHostKeyChecking = 1;
141 argv++;
142 }
143
144 createStdioFds(in, out, err);
145 listenSock = openListenSocket();
146 signal(SIGCHLD, signal_handler);
147
148 pid_t sshChildPid = fork();
149 if (sshChildPid == 0) {
150 char buf[PATH_MAX + 80];
151 char hostname[80];
152 int port = getport(listenSock);
153 close(listenSock);
154
155 close(in[1]);
156 close(out[0]);
157 close(err[0]);
158 dup2(in[0], STDIN_FILENO);
159 dup2(out[1], STDOUT_FILENO);
160 dup2(err[1], STDERR_FILENO);
161
162 unsetenv("LD_PRELOAD");
163
164 // Replace dmtcp_sshd replace with "dmtcp_sshd --host <host> --port <port>"
165 struct in_addr saddr;
166 if (dmtcp_get_local_ip_addr == NULL) {
167 printf("ERROR: Unable to find dmtcp_get_local_ip_addr.\n");
168 abort();
169 }
170 dmtcp_get_local_ip_addr(&saddr);
171 char *hostip = inet_ntoa(saddr);
172 strcpy(hostname, hostip);
173
174 size_t i = 0;
175 while (argv[i] != NULL) {
176 // "dmtcp_sshd" may be embedded deep inside the command line.
177 char *ptr = strstr(argv[i], SSHD_BINARY);
178 if (ptr != NULL) {
179 ptr += strlen(SSHD_BINARY);
180 if (*ptr != '\0') {
181 *ptr = '\0';
182 ptr++;
183 }
184 sprintf(buf, "%s --host %s --port %d %s",
185 argv[i], hostip, port, ptr);
186 argv[i] = buf;
187 }
188 i++;
189 }
190 execvp(argv[1], &argv[1]);
191 printf("%s:%d DMTCP Error detected. Failed to exec.", __FILE__, __LINE__);
192 abort();
193 }
194
195 int childSock = waitForConnection(listenSock);
196
197 close(in[0]);
198 close(out[1]);
199 close(err[1]);
200
201 ssh_stdinfd = in[1];
202 ssh_stdoutfd = out[0];
203 ssh_stderrfd = err[0];
204
205 assert(dmtcp_ssh_register_fds != NULL);
206 dmtcp_ssh_register_fds(false, ssh_stdinfd, ssh_stdoutfd, ssh_stderrfd,
207 childSock, noStrictHostKeyChecking);
208
209 client_loop(ssh_stdinfd, ssh_stdoutfd, ssh_stderrfd, childSock);
210 wait(&status);
211 return status;
212 }