GNU Linux-libre 6.8.7-gnu
[releases.git] / tools / testing / vsock / vsock_test.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * vsock_test - vsock.ko test suite
4  *
5  * Copyright (C) 2017 Red Hat, Inc.
6  *
7  * Author: Stefan Hajnoczi <stefanha@redhat.com>
8  */
9
10 #include <getopt.h>
11 #include <stdio.h>
12 #include <stdlib.h>
13 #include <string.h>
14 #include <errno.h>
15 #include <unistd.h>
16 #include <linux/kernel.h>
17 #include <sys/types.h>
18 #include <sys/socket.h>
19 #include <time.h>
20 #include <sys/mman.h>
21 #include <poll.h>
22 #include <signal.h>
23
24 #include "vsock_test_zerocopy.h"
25 #include "timeout.h"
26 #include "control.h"
27 #include "util.h"
28
29 static void test_stream_connection_reset(const struct test_opts *opts)
30 {
31         union {
32                 struct sockaddr sa;
33                 struct sockaddr_vm svm;
34         } addr = {
35                 .svm = {
36                         .svm_family = AF_VSOCK,
37                         .svm_port = 1234,
38                         .svm_cid = opts->peer_cid,
39                 },
40         };
41         int ret;
42         int fd;
43
44         fd = socket(AF_VSOCK, SOCK_STREAM, 0);
45
46         timeout_begin(TIMEOUT);
47         do {
48                 ret = connect(fd, &addr.sa, sizeof(addr.svm));
49                 timeout_check("connect");
50         } while (ret < 0 && errno == EINTR);
51         timeout_end();
52
53         if (ret != -1) {
54                 fprintf(stderr, "expected connect(2) failure, got %d\n", ret);
55                 exit(EXIT_FAILURE);
56         }
57         if (errno != ECONNRESET) {
58                 fprintf(stderr, "unexpected connect(2) errno %d\n", errno);
59                 exit(EXIT_FAILURE);
60         }
61
62         close(fd);
63 }
64
65 static void test_stream_bind_only_client(const struct test_opts *opts)
66 {
67         union {
68                 struct sockaddr sa;
69                 struct sockaddr_vm svm;
70         } addr = {
71                 .svm = {
72                         .svm_family = AF_VSOCK,
73                         .svm_port = 1234,
74                         .svm_cid = opts->peer_cid,
75                 },
76         };
77         int ret;
78         int fd;
79
80         /* Wait for the server to be ready */
81         control_expectln("BIND");
82
83         fd = socket(AF_VSOCK, SOCK_STREAM, 0);
84
85         timeout_begin(TIMEOUT);
86         do {
87                 ret = connect(fd, &addr.sa, sizeof(addr.svm));
88                 timeout_check("connect");
89         } while (ret < 0 && errno == EINTR);
90         timeout_end();
91
92         if (ret != -1) {
93                 fprintf(stderr, "expected connect(2) failure, got %d\n", ret);
94                 exit(EXIT_FAILURE);
95         }
96         if (errno != ECONNRESET) {
97                 fprintf(stderr, "unexpected connect(2) errno %d\n", errno);
98                 exit(EXIT_FAILURE);
99         }
100
101         /* Notify the server that the client has finished */
102         control_writeln("DONE");
103
104         close(fd);
105 }
106
107 static void test_stream_bind_only_server(const struct test_opts *opts)
108 {
109         union {
110                 struct sockaddr sa;
111                 struct sockaddr_vm svm;
112         } addr = {
113                 .svm = {
114                         .svm_family = AF_VSOCK,
115                         .svm_port = 1234,
116                         .svm_cid = VMADDR_CID_ANY,
117                 },
118         };
119         int fd;
120
121         fd = socket(AF_VSOCK, SOCK_STREAM, 0);
122
123         if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
124                 perror("bind");
125                 exit(EXIT_FAILURE);
126         }
127
128         /* Notify the client that the server is ready */
129         control_writeln("BIND");
130
131         /* Wait for the client to finish */
132         control_expectln("DONE");
133
134         close(fd);
135 }
136
137 static void test_stream_client_close_client(const struct test_opts *opts)
138 {
139         int fd;
140
141         fd = vsock_stream_connect(opts->peer_cid, 1234);
142         if (fd < 0) {
143                 perror("connect");
144                 exit(EXIT_FAILURE);
145         }
146
147         send_byte(fd, 1, 0);
148         close(fd);
149 }
150
151 static void test_stream_client_close_server(const struct test_opts *opts)
152 {
153         int fd;
154
155         fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
156         if (fd < 0) {
157                 perror("accept");
158                 exit(EXIT_FAILURE);
159         }
160
161         /* Wait for the remote to close the connection, before check
162          * -EPIPE error on send.
163          */
164         vsock_wait_remote_close(fd);
165
166         send_byte(fd, -EPIPE, 0);
167         recv_byte(fd, 1, 0);
168         recv_byte(fd, 0, 0);
169         close(fd);
170 }
171
172 static void test_stream_server_close_client(const struct test_opts *opts)
173 {
174         int fd;
175
176         fd = vsock_stream_connect(opts->peer_cid, 1234);
177         if (fd < 0) {
178                 perror("connect");
179                 exit(EXIT_FAILURE);
180         }
181
182         /* Wait for the remote to close the connection, before check
183          * -EPIPE error on send.
184          */
185         vsock_wait_remote_close(fd);
186
187         send_byte(fd, -EPIPE, 0);
188         recv_byte(fd, 1, 0);
189         recv_byte(fd, 0, 0);
190         close(fd);
191 }
192
193 static void test_stream_server_close_server(const struct test_opts *opts)
194 {
195         int fd;
196
197         fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
198         if (fd < 0) {
199                 perror("accept");
200                 exit(EXIT_FAILURE);
201         }
202
203         send_byte(fd, 1, 0);
204         close(fd);
205 }
206
207 /* With the standard socket sizes, VMCI is able to support about 100
208  * concurrent stream connections.
209  */
210 #define MULTICONN_NFDS 100
211
212 static void test_stream_multiconn_client(const struct test_opts *opts)
213 {
214         int fds[MULTICONN_NFDS];
215         int i;
216
217         for (i = 0; i < MULTICONN_NFDS; i++) {
218                 fds[i] = vsock_stream_connect(opts->peer_cid, 1234);
219                 if (fds[i] < 0) {
220                         perror("connect");
221                         exit(EXIT_FAILURE);
222                 }
223         }
224
225         for (i = 0; i < MULTICONN_NFDS; i++) {
226                 if (i % 2)
227                         recv_byte(fds[i], 1, 0);
228                 else
229                         send_byte(fds[i], 1, 0);
230         }
231
232         for (i = 0; i < MULTICONN_NFDS; i++)
233                 close(fds[i]);
234 }
235
236 static void test_stream_multiconn_server(const struct test_opts *opts)
237 {
238         int fds[MULTICONN_NFDS];
239         int i;
240
241         for (i = 0; i < MULTICONN_NFDS; i++) {
242                 fds[i] = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
243                 if (fds[i] < 0) {
244                         perror("accept");
245                         exit(EXIT_FAILURE);
246                 }
247         }
248
249         for (i = 0; i < MULTICONN_NFDS; i++) {
250                 if (i % 2)
251                         send_byte(fds[i], 1, 0);
252                 else
253                         recv_byte(fds[i], 1, 0);
254         }
255
256         for (i = 0; i < MULTICONN_NFDS; i++)
257                 close(fds[i]);
258 }
259
260 #define MSG_PEEK_BUF_LEN 64
261
262 static void test_msg_peek_client(const struct test_opts *opts,
263                                  bool seqpacket)
264 {
265         unsigned char buf[MSG_PEEK_BUF_LEN];
266         int fd;
267         int i;
268
269         if (seqpacket)
270                 fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
271         else
272                 fd = vsock_stream_connect(opts->peer_cid, 1234);
273
274         if (fd < 0) {
275                 perror("connect");
276                 exit(EXIT_FAILURE);
277         }
278
279         for (i = 0; i < sizeof(buf); i++)
280                 buf[i] = rand() & 0xFF;
281
282         control_expectln("SRVREADY");
283
284         send_buf(fd, buf, sizeof(buf), 0, sizeof(buf));
285
286         close(fd);
287 }
288
289 static void test_msg_peek_server(const struct test_opts *opts,
290                                  bool seqpacket)
291 {
292         unsigned char buf_half[MSG_PEEK_BUF_LEN / 2];
293         unsigned char buf_normal[MSG_PEEK_BUF_LEN];
294         unsigned char buf_peek[MSG_PEEK_BUF_LEN];
295         int fd;
296
297         if (seqpacket)
298                 fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
299         else
300                 fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
301
302         if (fd < 0) {
303                 perror("accept");
304                 exit(EXIT_FAILURE);
305         }
306
307         /* Peek from empty socket. */
308         recv_buf(fd, buf_peek, sizeof(buf_peek), MSG_PEEK | MSG_DONTWAIT,
309                  -EAGAIN);
310
311         control_writeln("SRVREADY");
312
313         /* Peek part of data. */
314         recv_buf(fd, buf_half, sizeof(buf_half), MSG_PEEK, sizeof(buf_half));
315
316         /* Peek whole data. */
317         recv_buf(fd, buf_peek, sizeof(buf_peek), MSG_PEEK, sizeof(buf_peek));
318
319         /* Compare partial and full peek. */
320         if (memcmp(buf_half, buf_peek, sizeof(buf_half))) {
321                 fprintf(stderr, "Partial peek data mismatch\n");
322                 exit(EXIT_FAILURE);
323         }
324
325         if (seqpacket) {
326                 /* This type of socket supports MSG_TRUNC flag,
327                  * so check it with MSG_PEEK. We must get length
328                  * of the message.
329                  */
330                 recv_buf(fd, buf_half, sizeof(buf_half), MSG_PEEK | MSG_TRUNC,
331                          sizeof(buf_peek));
332         }
333
334         recv_buf(fd, buf_normal, sizeof(buf_normal), 0, sizeof(buf_normal));
335
336         /* Compare full peek and normal read. */
337         if (memcmp(buf_peek, buf_normal, sizeof(buf_peek))) {
338                 fprintf(stderr, "Full peek data mismatch\n");
339                 exit(EXIT_FAILURE);
340         }
341
342         close(fd);
343 }
344
345 static void test_stream_msg_peek_client(const struct test_opts *opts)
346 {
347         return test_msg_peek_client(opts, false);
348 }
349
350 static void test_stream_msg_peek_server(const struct test_opts *opts)
351 {
352         return test_msg_peek_server(opts, false);
353 }
354
355 #define SOCK_BUF_SIZE (2 * 1024 * 1024)
356 #define MAX_MSG_PAGES 4
357
358 static void test_seqpacket_msg_bounds_client(const struct test_opts *opts)
359 {
360         unsigned long curr_hash;
361         size_t max_msg_size;
362         int page_size;
363         int msg_count;
364         int fd;
365
366         fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
367         if (fd < 0) {
368                 perror("connect");
369                 exit(EXIT_FAILURE);
370         }
371
372         /* Wait, until receiver sets buffer size. */
373         control_expectln("SRVREADY");
374
375         curr_hash = 0;
376         page_size = getpagesize();
377         max_msg_size = MAX_MSG_PAGES * page_size;
378         msg_count = SOCK_BUF_SIZE / max_msg_size;
379
380         for (int i = 0; i < msg_count; i++) {
381                 size_t buf_size;
382                 int flags;
383                 void *buf;
384
385                 /* Use "small" buffers and "big" buffers. */
386                 if (i & 1)
387                         buf_size = page_size +
388                                         (rand() % (max_msg_size - page_size));
389                 else
390                         buf_size = 1 + (rand() % page_size);
391
392                 buf = malloc(buf_size);
393
394                 if (!buf) {
395                         perror("malloc");
396                         exit(EXIT_FAILURE);
397                 }
398
399                 memset(buf, rand() & 0xff, buf_size);
400                 /* Set at least one MSG_EOR + some random. */
401                 if (i == (msg_count / 2) || (rand() & 1)) {
402                         flags = MSG_EOR;
403                         curr_hash++;
404                 } else {
405                         flags = 0;
406                 }
407
408                 send_buf(fd, buf, buf_size, flags, buf_size);
409
410                 /*
411                  * Hash sum is computed at both client and server in
412                  * the same way:
413                  * H += hash('message data')
414                  * Such hash "controls" both data integrity and message
415                  * bounds. After data exchange, both sums are compared
416                  * using control socket, and if message bounds wasn't
417                  * broken - two values must be equal.
418                  */
419                 curr_hash += hash_djb2(buf, buf_size);
420                 free(buf);
421         }
422
423         control_writeln("SENDDONE");
424         control_writeulong(curr_hash);
425         close(fd);
426 }
427
428 static void test_seqpacket_msg_bounds_server(const struct test_opts *opts)
429 {
430         unsigned long sock_buf_size;
431         unsigned long remote_hash;
432         unsigned long curr_hash;
433         int fd;
434         struct msghdr msg = {0};
435         struct iovec iov = {0};
436
437         fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
438         if (fd < 0) {
439                 perror("accept");
440                 exit(EXIT_FAILURE);
441         }
442
443         sock_buf_size = SOCK_BUF_SIZE;
444
445         if (setsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_MAX_SIZE,
446                        &sock_buf_size, sizeof(sock_buf_size))) {
447                 perror("setsockopt(SO_VM_SOCKETS_BUFFER_MAX_SIZE)");
448                 exit(EXIT_FAILURE);
449         }
450
451         if (setsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_SIZE,
452                        &sock_buf_size, sizeof(sock_buf_size))) {
453                 perror("setsockopt(SO_VM_SOCKETS_BUFFER_SIZE)");
454                 exit(EXIT_FAILURE);
455         }
456
457         /* Ready to receive data. */
458         control_writeln("SRVREADY");
459         /* Wait, until peer sends whole data. */
460         control_expectln("SENDDONE");
461         iov.iov_len = MAX_MSG_PAGES * getpagesize();
462         iov.iov_base = malloc(iov.iov_len);
463         if (!iov.iov_base) {
464                 perror("malloc");
465                 exit(EXIT_FAILURE);
466         }
467
468         msg.msg_iov = &iov;
469         msg.msg_iovlen = 1;
470
471         curr_hash = 0;
472
473         while (1) {
474                 ssize_t recv_size;
475
476                 recv_size = recvmsg(fd, &msg, 0);
477
478                 if (!recv_size)
479                         break;
480
481                 if (recv_size < 0) {
482                         perror("recvmsg");
483                         exit(EXIT_FAILURE);
484                 }
485
486                 if (msg.msg_flags & MSG_EOR)
487                         curr_hash++;
488
489                 curr_hash += hash_djb2(msg.msg_iov[0].iov_base, recv_size);
490         }
491
492         free(iov.iov_base);
493         close(fd);
494         remote_hash = control_readulong();
495
496         if (curr_hash != remote_hash) {
497                 fprintf(stderr, "Message bounds broken\n");
498                 exit(EXIT_FAILURE);
499         }
500 }
501
502 #define MESSAGE_TRUNC_SZ 32
503 static void test_seqpacket_msg_trunc_client(const struct test_opts *opts)
504 {
505         int fd;
506         char buf[MESSAGE_TRUNC_SZ];
507
508         fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
509         if (fd < 0) {
510                 perror("connect");
511                 exit(EXIT_FAILURE);
512         }
513
514         send_buf(fd, buf, sizeof(buf), 0, sizeof(buf));
515
516         control_writeln("SENDDONE");
517         close(fd);
518 }
519
520 static void test_seqpacket_msg_trunc_server(const struct test_opts *opts)
521 {
522         int fd;
523         char buf[MESSAGE_TRUNC_SZ / 2];
524         struct msghdr msg = {0};
525         struct iovec iov = {0};
526
527         fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
528         if (fd < 0) {
529                 perror("accept");
530                 exit(EXIT_FAILURE);
531         }
532
533         control_expectln("SENDDONE");
534         iov.iov_base = buf;
535         iov.iov_len = sizeof(buf);
536         msg.msg_iov = &iov;
537         msg.msg_iovlen = 1;
538
539         ssize_t ret = recvmsg(fd, &msg, MSG_TRUNC);
540
541         if (ret != MESSAGE_TRUNC_SZ) {
542                 printf("%zi\n", ret);
543                 perror("MSG_TRUNC doesn't work");
544                 exit(EXIT_FAILURE);
545         }
546
547         if (!(msg.msg_flags & MSG_TRUNC)) {
548                 fprintf(stderr, "MSG_TRUNC expected\n");
549                 exit(EXIT_FAILURE);
550         }
551
552         close(fd);
553 }
554
555 static time_t current_nsec(void)
556 {
557         struct timespec ts;
558
559         if (clock_gettime(CLOCK_REALTIME, &ts)) {
560                 perror("clock_gettime(3) failed");
561                 exit(EXIT_FAILURE);
562         }
563
564         return (ts.tv_sec * 1000000000ULL) + ts.tv_nsec;
565 }
566
567 #define RCVTIMEO_TIMEOUT_SEC 1
568 #define READ_OVERHEAD_NSEC 250000000 /* 0.25 sec */
569
570 static void test_seqpacket_timeout_client(const struct test_opts *opts)
571 {
572         int fd;
573         struct timeval tv;
574         char dummy;
575         time_t read_enter_ns;
576         time_t read_overhead_ns;
577
578         fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
579         if (fd < 0) {
580                 perror("connect");
581                 exit(EXIT_FAILURE);
582         }
583
584         tv.tv_sec = RCVTIMEO_TIMEOUT_SEC;
585         tv.tv_usec = 0;
586
587         if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, (void *)&tv, sizeof(tv)) == -1) {
588                 perror("setsockopt(SO_RCVTIMEO)");
589                 exit(EXIT_FAILURE);
590         }
591
592         read_enter_ns = current_nsec();
593
594         if (read(fd, &dummy, sizeof(dummy)) != -1) {
595                 fprintf(stderr,
596                         "expected 'dummy' read(2) failure\n");
597                 exit(EXIT_FAILURE);
598         }
599
600         if (errno != EAGAIN) {
601                 perror("EAGAIN expected");
602                 exit(EXIT_FAILURE);
603         }
604
605         read_overhead_ns = current_nsec() - read_enter_ns -
606                         1000000000ULL * RCVTIMEO_TIMEOUT_SEC;
607
608         if (read_overhead_ns > READ_OVERHEAD_NSEC) {
609                 fprintf(stderr,
610                         "too much time in read(2), %lu > %i ns\n",
611                         read_overhead_ns, READ_OVERHEAD_NSEC);
612                 exit(EXIT_FAILURE);
613         }
614
615         control_writeln("WAITDONE");
616         close(fd);
617 }
618
619 static void test_seqpacket_timeout_server(const struct test_opts *opts)
620 {
621         int fd;
622
623         fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
624         if (fd < 0) {
625                 perror("accept");
626                 exit(EXIT_FAILURE);
627         }
628
629         control_expectln("WAITDONE");
630         close(fd);
631 }
632
633 static void test_seqpacket_bigmsg_client(const struct test_opts *opts)
634 {
635         unsigned long sock_buf_size;
636         socklen_t len;
637         void *data;
638         int fd;
639
640         len = sizeof(sock_buf_size);
641
642         fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
643         if (fd < 0) {
644                 perror("connect");
645                 exit(EXIT_FAILURE);
646         }
647
648         if (getsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_SIZE,
649                        &sock_buf_size, &len)) {
650                 perror("getsockopt");
651                 exit(EXIT_FAILURE);
652         }
653
654         sock_buf_size++;
655
656         data = malloc(sock_buf_size);
657         if (!data) {
658                 perror("malloc");
659                 exit(EXIT_FAILURE);
660         }
661
662         send_buf(fd, data, sock_buf_size, 0, -EMSGSIZE);
663
664         control_writeln("CLISENT");
665
666         free(data);
667         close(fd);
668 }
669
670 static void test_seqpacket_bigmsg_server(const struct test_opts *opts)
671 {
672         int fd;
673
674         fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
675         if (fd < 0) {
676                 perror("accept");
677                 exit(EXIT_FAILURE);
678         }
679
680         control_expectln("CLISENT");
681
682         close(fd);
683 }
684
685 #define BUF_PATTERN_1 'a'
686 #define BUF_PATTERN_2 'b'
687
688 static void test_seqpacket_invalid_rec_buffer_client(const struct test_opts *opts)
689 {
690         int fd;
691         unsigned char *buf1;
692         unsigned char *buf2;
693         int buf_size = getpagesize() * 3;
694
695         fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
696         if (fd < 0) {
697                 perror("connect");
698                 exit(EXIT_FAILURE);
699         }
700
701         buf1 = malloc(buf_size);
702         if (!buf1) {
703                 perror("'malloc()' for 'buf1'");
704                 exit(EXIT_FAILURE);
705         }
706
707         buf2 = malloc(buf_size);
708         if (!buf2) {
709                 perror("'malloc()' for 'buf2'");
710                 exit(EXIT_FAILURE);
711         }
712
713         memset(buf1, BUF_PATTERN_1, buf_size);
714         memset(buf2, BUF_PATTERN_2, buf_size);
715
716         send_buf(fd, buf1, buf_size, 0, buf_size);
717
718         send_buf(fd, buf2, buf_size, 0, buf_size);
719
720         close(fd);
721 }
722
723 static void test_seqpacket_invalid_rec_buffer_server(const struct test_opts *opts)
724 {
725         int fd;
726         unsigned char *broken_buf;
727         unsigned char *valid_buf;
728         int page_size = getpagesize();
729         int buf_size = page_size * 3;
730         ssize_t res;
731         int prot = PROT_READ | PROT_WRITE;
732         int flags = MAP_PRIVATE | MAP_ANONYMOUS;
733         int i;
734
735         fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
736         if (fd < 0) {
737                 perror("accept");
738                 exit(EXIT_FAILURE);
739         }
740
741         /* Setup first buffer. */
742         broken_buf = mmap(NULL, buf_size, prot, flags, -1, 0);
743         if (broken_buf == MAP_FAILED) {
744                 perror("mmap for 'broken_buf'");
745                 exit(EXIT_FAILURE);
746         }
747
748         /* Unmap "hole" in buffer. */
749         if (munmap(broken_buf + page_size, page_size)) {
750                 perror("'broken_buf' setup");
751                 exit(EXIT_FAILURE);
752         }
753
754         valid_buf = mmap(NULL, buf_size, prot, flags, -1, 0);
755         if (valid_buf == MAP_FAILED) {
756                 perror("mmap for 'valid_buf'");
757                 exit(EXIT_FAILURE);
758         }
759
760         /* Try to fill buffer with unmapped middle. */
761         res = read(fd, broken_buf, buf_size);
762         if (res != -1) {
763                 fprintf(stderr,
764                         "expected 'broken_buf' read(2) failure, got %zi\n",
765                         res);
766                 exit(EXIT_FAILURE);
767         }
768
769         if (errno != EFAULT) {
770                 perror("unexpected errno of 'broken_buf'");
771                 exit(EXIT_FAILURE);
772         }
773
774         /* Try to fill valid buffer. */
775         res = read(fd, valid_buf, buf_size);
776         if (res < 0) {
777                 perror("unexpected 'valid_buf' read(2) failure");
778                 exit(EXIT_FAILURE);
779         }
780
781         if (res != buf_size) {
782                 fprintf(stderr,
783                         "invalid 'valid_buf' read(2), expected %i, got %zi\n",
784                         buf_size, res);
785                 exit(EXIT_FAILURE);
786         }
787
788         for (i = 0; i < buf_size; i++) {
789                 if (valid_buf[i] != BUF_PATTERN_2) {
790                         fprintf(stderr,
791                                 "invalid pattern for 'valid_buf' at %i, expected %hhX, got %hhX\n",
792                                 i, BUF_PATTERN_2, valid_buf[i]);
793                         exit(EXIT_FAILURE);
794                 }
795         }
796
797         /* Unmap buffers. */
798         munmap(broken_buf, page_size);
799         munmap(broken_buf + page_size * 2, page_size);
800         munmap(valid_buf, buf_size);
801         close(fd);
802 }
803
804 #define RCVLOWAT_BUF_SIZE 128
805
806 static void test_stream_poll_rcvlowat_server(const struct test_opts *opts)
807 {
808         int fd;
809         int i;
810
811         fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
812         if (fd < 0) {
813                 perror("accept");
814                 exit(EXIT_FAILURE);
815         }
816
817         /* Send 1 byte. */
818         send_byte(fd, 1, 0);
819
820         control_writeln("SRVSENT");
821
822         /* Wait until client is ready to receive rest of data. */
823         control_expectln("CLNSENT");
824
825         for (i = 0; i < RCVLOWAT_BUF_SIZE - 1; i++)
826                 send_byte(fd, 1, 0);
827
828         /* Keep socket in active state. */
829         control_expectln("POLLDONE");
830
831         close(fd);
832 }
833
834 static void test_stream_poll_rcvlowat_client(const struct test_opts *opts)
835 {
836         unsigned long lowat_val = RCVLOWAT_BUF_SIZE;
837         char buf[RCVLOWAT_BUF_SIZE];
838         struct pollfd fds;
839         short poll_flags;
840         int fd;
841
842         fd = vsock_stream_connect(opts->peer_cid, 1234);
843         if (fd < 0) {
844                 perror("connect");
845                 exit(EXIT_FAILURE);
846         }
847
848         if (setsockopt(fd, SOL_SOCKET, SO_RCVLOWAT,
849                        &lowat_val, sizeof(lowat_val))) {
850                 perror("setsockopt(SO_RCVLOWAT)");
851                 exit(EXIT_FAILURE);
852         }
853
854         control_expectln("SRVSENT");
855
856         /* At this point, server sent 1 byte. */
857         fds.fd = fd;
858         poll_flags = POLLIN | POLLRDNORM;
859         fds.events = poll_flags;
860
861         /* Try to wait for 1 sec. */
862         if (poll(&fds, 1, 1000) < 0) {
863                 perror("poll");
864                 exit(EXIT_FAILURE);
865         }
866
867         /* poll() must return nothing. */
868         if (fds.revents) {
869                 fprintf(stderr, "Unexpected poll result %hx\n",
870                         fds.revents);
871                 exit(EXIT_FAILURE);
872         }
873
874         /* Tell server to send rest of data. */
875         control_writeln("CLNSENT");
876
877         /* Poll for data. */
878         if (poll(&fds, 1, 10000) < 0) {
879                 perror("poll");
880                 exit(EXIT_FAILURE);
881         }
882
883         /* Only these two bits are expected. */
884         if (fds.revents != poll_flags) {
885                 fprintf(stderr, "Unexpected poll result %hx\n",
886                         fds.revents);
887                 exit(EXIT_FAILURE);
888         }
889
890         /* Use MSG_DONTWAIT, if call is going to wait, EAGAIN
891          * will be returned.
892          */
893         recv_buf(fd, buf, sizeof(buf), MSG_DONTWAIT, RCVLOWAT_BUF_SIZE);
894
895         control_writeln("POLLDONE");
896
897         close(fd);
898 }
899
900 #define INV_BUF_TEST_DATA_LEN 512
901
902 static void test_inv_buf_client(const struct test_opts *opts, bool stream)
903 {
904         unsigned char data[INV_BUF_TEST_DATA_LEN] = {0};
905         ssize_t expected_ret;
906         int fd;
907
908         if (stream)
909                 fd = vsock_stream_connect(opts->peer_cid, 1234);
910         else
911                 fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
912
913         if (fd < 0) {
914                 perror("connect");
915                 exit(EXIT_FAILURE);
916         }
917
918         control_expectln("SENDDONE");
919
920         /* Use invalid buffer here. */
921         recv_buf(fd, NULL, sizeof(data), 0, -EFAULT);
922
923         if (stream) {
924                 /* For SOCK_STREAM we must continue reading. */
925                 expected_ret = sizeof(data);
926         } else {
927                 /* For SOCK_SEQPACKET socket's queue must be empty. */
928                 expected_ret = -EAGAIN;
929         }
930
931         recv_buf(fd, data, sizeof(data), MSG_DONTWAIT, expected_ret);
932
933         control_writeln("DONE");
934
935         close(fd);
936 }
937
938 static void test_inv_buf_server(const struct test_opts *opts, bool stream)
939 {
940         unsigned char data[INV_BUF_TEST_DATA_LEN] = {0};
941         int fd;
942
943         if (stream)
944                 fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
945         else
946                 fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
947
948         if (fd < 0) {
949                 perror("accept");
950                 exit(EXIT_FAILURE);
951         }
952
953         send_buf(fd, data, sizeof(data), 0, sizeof(data));
954
955         control_writeln("SENDDONE");
956
957         control_expectln("DONE");
958
959         close(fd);
960 }
961
962 static void test_stream_inv_buf_client(const struct test_opts *opts)
963 {
964         test_inv_buf_client(opts, true);
965 }
966
967 static void test_stream_inv_buf_server(const struct test_opts *opts)
968 {
969         test_inv_buf_server(opts, true);
970 }
971
972 static void test_seqpacket_inv_buf_client(const struct test_opts *opts)
973 {
974         test_inv_buf_client(opts, false);
975 }
976
977 static void test_seqpacket_inv_buf_server(const struct test_opts *opts)
978 {
979         test_inv_buf_server(opts, false);
980 }
981
982 #define HELLO_STR "HELLO"
983 #define WORLD_STR "WORLD"
984
985 static void test_stream_virtio_skb_merge_client(const struct test_opts *opts)
986 {
987         int fd;
988
989         fd = vsock_stream_connect(opts->peer_cid, 1234);
990         if (fd < 0) {
991                 perror("connect");
992                 exit(EXIT_FAILURE);
993         }
994
995         /* Send first skbuff. */
996         send_buf(fd, HELLO_STR, strlen(HELLO_STR), 0, strlen(HELLO_STR));
997
998         control_writeln("SEND0");
999         /* Peer reads part of first skbuff. */
1000         control_expectln("REPLY0");
1001
1002         /* Send second skbuff, it will be appended to the first. */
1003         send_buf(fd, WORLD_STR, strlen(WORLD_STR), 0, strlen(WORLD_STR));
1004
1005         control_writeln("SEND1");
1006         /* Peer reads merged skbuff packet. */
1007         control_expectln("REPLY1");
1008
1009         close(fd);
1010 }
1011
1012 static void test_stream_virtio_skb_merge_server(const struct test_opts *opts)
1013 {
1014         size_t read = 0, to_read;
1015         unsigned char buf[64];
1016         int fd;
1017
1018         fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
1019         if (fd < 0) {
1020                 perror("accept");
1021                 exit(EXIT_FAILURE);
1022         }
1023
1024         control_expectln("SEND0");
1025
1026         /* Read skbuff partially. */
1027         to_read = 2;
1028         recv_buf(fd, buf + read, to_read, 0, to_read);
1029         read += to_read;
1030
1031         control_writeln("REPLY0");
1032         control_expectln("SEND1");
1033
1034         /* Read the rest of both buffers */
1035         to_read = strlen(HELLO_STR WORLD_STR) - read;
1036         recv_buf(fd, buf + read, to_read, 0, to_read);
1037         read += to_read;
1038
1039         /* No more bytes should be there */
1040         to_read = sizeof(buf) - read;
1041         recv_buf(fd, buf + read, to_read, MSG_DONTWAIT, -EAGAIN);
1042
1043         if (memcmp(buf, HELLO_STR WORLD_STR, strlen(HELLO_STR WORLD_STR))) {
1044                 fprintf(stderr, "pattern mismatch\n");
1045                 exit(EXIT_FAILURE);
1046         }
1047
1048         control_writeln("REPLY1");
1049
1050         close(fd);
1051 }
1052
1053 static void test_seqpacket_msg_peek_client(const struct test_opts *opts)
1054 {
1055         return test_msg_peek_client(opts, true);
1056 }
1057
1058 static void test_seqpacket_msg_peek_server(const struct test_opts *opts)
1059 {
1060         return test_msg_peek_server(opts, true);
1061 }
1062
1063 static sig_atomic_t have_sigpipe;
1064
1065 static void sigpipe(int signo)
1066 {
1067         have_sigpipe = 1;
1068 }
1069
1070 static void test_stream_check_sigpipe(int fd)
1071 {
1072         ssize_t res;
1073
1074         have_sigpipe = 0;
1075
1076         res = send(fd, "A", 1, 0);
1077         if (res != -1) {
1078                 fprintf(stderr, "expected send(2) failure, got %zi\n", res);
1079                 exit(EXIT_FAILURE);
1080         }
1081
1082         if (!have_sigpipe) {
1083                 fprintf(stderr, "SIGPIPE expected\n");
1084                 exit(EXIT_FAILURE);
1085         }
1086
1087         have_sigpipe = 0;
1088
1089         res = send(fd, "A", 1, MSG_NOSIGNAL);
1090         if (res != -1) {
1091                 fprintf(stderr, "expected send(2) failure, got %zi\n", res);
1092                 exit(EXIT_FAILURE);
1093         }
1094
1095         if (have_sigpipe) {
1096                 fprintf(stderr, "SIGPIPE not expected\n");
1097                 exit(EXIT_FAILURE);
1098         }
1099 }
1100
1101 static void test_stream_shutwr_client(const struct test_opts *opts)
1102 {
1103         int fd;
1104
1105         struct sigaction act = {
1106                 .sa_handler = sigpipe,
1107         };
1108
1109         sigaction(SIGPIPE, &act, NULL);
1110
1111         fd = vsock_stream_connect(opts->peer_cid, 1234);
1112         if (fd < 0) {
1113                 perror("connect");
1114                 exit(EXIT_FAILURE);
1115         }
1116
1117         if (shutdown(fd, SHUT_WR)) {
1118                 perror("shutdown");
1119                 exit(EXIT_FAILURE);
1120         }
1121
1122         test_stream_check_sigpipe(fd);
1123
1124         control_writeln("CLIENTDONE");
1125
1126         close(fd);
1127 }
1128
1129 static void test_stream_shutwr_server(const struct test_opts *opts)
1130 {
1131         int fd;
1132
1133         fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
1134         if (fd < 0) {
1135                 perror("accept");
1136                 exit(EXIT_FAILURE);
1137         }
1138
1139         control_expectln("CLIENTDONE");
1140
1141         close(fd);
1142 }
1143
1144 static void test_stream_shutrd_client(const struct test_opts *opts)
1145 {
1146         int fd;
1147
1148         struct sigaction act = {
1149                 .sa_handler = sigpipe,
1150         };
1151
1152         sigaction(SIGPIPE, &act, NULL);
1153
1154         fd = vsock_stream_connect(opts->peer_cid, 1234);
1155         if (fd < 0) {
1156                 perror("connect");
1157                 exit(EXIT_FAILURE);
1158         }
1159
1160         control_expectln("SHUTRDDONE");
1161
1162         test_stream_check_sigpipe(fd);
1163
1164         control_writeln("CLIENTDONE");
1165
1166         close(fd);
1167 }
1168
1169 static void test_stream_shutrd_server(const struct test_opts *opts)
1170 {
1171         int fd;
1172
1173         fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
1174         if (fd < 0) {
1175                 perror("accept");
1176                 exit(EXIT_FAILURE);
1177         }
1178
1179         if (shutdown(fd, SHUT_RD)) {
1180                 perror("shutdown");
1181                 exit(EXIT_FAILURE);
1182         }
1183
1184         control_writeln("SHUTRDDONE");
1185         control_expectln("CLIENTDONE");
1186
1187         close(fd);
1188 }
1189
1190 static void test_double_bind_connect_server(const struct test_opts *opts)
1191 {
1192         int listen_fd, client_fd, i;
1193         struct sockaddr_vm sa_client;
1194         socklen_t socklen_client = sizeof(sa_client);
1195
1196         listen_fd = vsock_stream_listen(VMADDR_CID_ANY, 1234);
1197
1198         for (i = 0; i < 2; i++) {
1199                 control_writeln("LISTENING");
1200
1201                 timeout_begin(TIMEOUT);
1202                 do {
1203                         client_fd = accept(listen_fd, (struct sockaddr *)&sa_client,
1204                                            &socklen_client);
1205                         timeout_check("accept");
1206                 } while (client_fd < 0 && errno == EINTR);
1207                 timeout_end();
1208
1209                 if (client_fd < 0) {
1210                         perror("accept");
1211                         exit(EXIT_FAILURE);
1212                 }
1213
1214                 /* Waiting for remote peer to close connection */
1215                 vsock_wait_remote_close(client_fd);
1216         }
1217
1218         close(listen_fd);
1219 }
1220
1221 static void test_double_bind_connect_client(const struct test_opts *opts)
1222 {
1223         int i, client_fd;
1224
1225         for (i = 0; i < 2; i++) {
1226                 /* Wait until server is ready to accept a new connection */
1227                 control_expectln("LISTENING");
1228
1229                 client_fd = vsock_bind_connect(opts->peer_cid, 1234, 4321, SOCK_STREAM);
1230
1231                 close(client_fd);
1232         }
1233 }
1234
1235 #define RCVLOWAT_CREDIT_UPD_BUF_SIZE    (1024 * 128)
1236 /* This define is the same as in 'include/linux/virtio_vsock.h':
1237  * it is used to decide when to send credit update message during
1238  * reading from rx queue of a socket. Value and its usage in
1239  * kernel is important for this test.
1240  */
1241 #define VIRTIO_VSOCK_MAX_PKT_BUF_SIZE   (1024 * 64)
1242
1243 static void test_stream_rcvlowat_def_cred_upd_client(const struct test_opts *opts)
1244 {
1245         size_t buf_size;
1246         void *buf;
1247         int fd;
1248
1249         fd = vsock_stream_connect(opts->peer_cid, 1234);
1250         if (fd < 0) {
1251                 perror("connect");
1252                 exit(EXIT_FAILURE);
1253         }
1254
1255         /* Send 1 byte more than peer's buffer size. */
1256         buf_size = RCVLOWAT_CREDIT_UPD_BUF_SIZE + 1;
1257
1258         buf = malloc(buf_size);
1259         if (!buf) {
1260                 perror("malloc");
1261                 exit(EXIT_FAILURE);
1262         }
1263
1264         /* Wait until peer sets needed buffer size. */
1265         recv_byte(fd, 1, 0);
1266
1267         if (send(fd, buf, buf_size, 0) != buf_size) {
1268                 perror("send failed");
1269                 exit(EXIT_FAILURE);
1270         }
1271
1272         free(buf);
1273         close(fd);
1274 }
1275
1276 static void test_stream_credit_update_test(const struct test_opts *opts,
1277                                            bool low_rx_bytes_test)
1278 {
1279         size_t recv_buf_size;
1280         struct pollfd fds;
1281         size_t buf_size;
1282         void *buf;
1283         int fd;
1284
1285         fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
1286         if (fd < 0) {
1287                 perror("accept");
1288                 exit(EXIT_FAILURE);
1289         }
1290
1291         buf_size = RCVLOWAT_CREDIT_UPD_BUF_SIZE;
1292
1293         if (setsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_SIZE,
1294                        &buf_size, sizeof(buf_size))) {
1295                 perror("setsockopt(SO_VM_SOCKETS_BUFFER_SIZE)");
1296                 exit(EXIT_FAILURE);
1297         }
1298
1299         if (low_rx_bytes_test) {
1300                 /* Set new SO_RCVLOWAT here. This enables sending credit
1301                  * update when number of bytes if our rx queue become <
1302                  * SO_RCVLOWAT value.
1303                  */
1304                 recv_buf_size = 1 + VIRTIO_VSOCK_MAX_PKT_BUF_SIZE;
1305
1306                 if (setsockopt(fd, SOL_SOCKET, SO_RCVLOWAT,
1307                                &recv_buf_size, sizeof(recv_buf_size))) {
1308                         perror("setsockopt(SO_RCVLOWAT)");
1309                         exit(EXIT_FAILURE);
1310                 }
1311         }
1312
1313         /* Send one dummy byte here, because 'setsockopt()' above also
1314          * sends special packet which tells sender to update our buffer
1315          * size. This 'send_byte()' will serialize such packet with data
1316          * reads in a loop below. Sender starts transmission only when
1317          * it receives this single byte.
1318          */
1319         send_byte(fd, 1, 0);
1320
1321         buf = malloc(buf_size);
1322         if (!buf) {
1323                 perror("malloc");
1324                 exit(EXIT_FAILURE);
1325         }
1326
1327         /* Wait until there will be 128KB of data in rx queue. */
1328         while (1) {
1329                 ssize_t res;
1330
1331                 res = recv(fd, buf, buf_size, MSG_PEEK);
1332                 if (res == buf_size)
1333                         break;
1334
1335                 if (res <= 0) {
1336                         fprintf(stderr, "unexpected 'recv()' return: %zi\n", res);
1337                         exit(EXIT_FAILURE);
1338                 }
1339         }
1340
1341         /* There is 128KB of data in the socket's rx queue, dequeue first
1342          * 64KB, credit update is sent if 'low_rx_bytes_test' == true.
1343          * Otherwise, credit update is sent in 'if (!low_rx_bytes_test)'.
1344          */
1345         recv_buf_size = VIRTIO_VSOCK_MAX_PKT_BUF_SIZE;
1346         recv_buf(fd, buf, recv_buf_size, 0, recv_buf_size);
1347
1348         if (!low_rx_bytes_test) {
1349                 recv_buf_size++;
1350
1351                 /* Updating SO_RCVLOWAT will send credit update. */
1352                 if (setsockopt(fd, SOL_SOCKET, SO_RCVLOWAT,
1353                                &recv_buf_size, sizeof(recv_buf_size))) {
1354                         perror("setsockopt(SO_RCVLOWAT)");
1355                         exit(EXIT_FAILURE);
1356                 }
1357         }
1358
1359         fds.fd = fd;
1360         fds.events = POLLIN | POLLRDNORM | POLLERR |
1361                      POLLRDHUP | POLLHUP;
1362
1363         /* This 'poll()' will return once we receive last byte
1364          * sent by client.
1365          */
1366         if (poll(&fds, 1, -1) < 0) {
1367                 perror("poll");
1368                 exit(EXIT_FAILURE);
1369         }
1370
1371         if (fds.revents & POLLERR) {
1372                 fprintf(stderr, "'poll()' error\n");
1373                 exit(EXIT_FAILURE);
1374         }
1375
1376         if (fds.revents & (POLLIN | POLLRDNORM)) {
1377                 recv_buf(fd, buf, recv_buf_size, MSG_DONTWAIT, recv_buf_size);
1378         } else {
1379                 /* These flags must be set, as there is at
1380                  * least 64KB of data ready to read.
1381                  */
1382                 fprintf(stderr, "POLLIN | POLLRDNORM expected\n");
1383                 exit(EXIT_FAILURE);
1384         }
1385
1386         free(buf);
1387         close(fd);
1388 }
1389
1390 static void test_stream_cred_upd_on_low_rx_bytes(const struct test_opts *opts)
1391 {
1392         test_stream_credit_update_test(opts, true);
1393 }
1394
1395 static void test_stream_cred_upd_on_set_rcvlowat(const struct test_opts *opts)
1396 {
1397         test_stream_credit_update_test(opts, false);
1398 }
1399
1400 static struct test_case test_cases[] = {
1401         {
1402                 .name = "SOCK_STREAM connection reset",
1403                 .run_client = test_stream_connection_reset,
1404         },
1405         {
1406                 .name = "SOCK_STREAM bind only",
1407                 .run_client = test_stream_bind_only_client,
1408                 .run_server = test_stream_bind_only_server,
1409         },
1410         {
1411                 .name = "SOCK_STREAM client close",
1412                 .run_client = test_stream_client_close_client,
1413                 .run_server = test_stream_client_close_server,
1414         },
1415         {
1416                 .name = "SOCK_STREAM server close",
1417                 .run_client = test_stream_server_close_client,
1418                 .run_server = test_stream_server_close_server,
1419         },
1420         {
1421                 .name = "SOCK_STREAM multiple connections",
1422                 .run_client = test_stream_multiconn_client,
1423                 .run_server = test_stream_multiconn_server,
1424         },
1425         {
1426                 .name = "SOCK_STREAM MSG_PEEK",
1427                 .run_client = test_stream_msg_peek_client,
1428                 .run_server = test_stream_msg_peek_server,
1429         },
1430         {
1431                 .name = "SOCK_SEQPACKET msg bounds",
1432                 .run_client = test_seqpacket_msg_bounds_client,
1433                 .run_server = test_seqpacket_msg_bounds_server,
1434         },
1435         {
1436                 .name = "SOCK_SEQPACKET MSG_TRUNC flag",
1437                 .run_client = test_seqpacket_msg_trunc_client,
1438                 .run_server = test_seqpacket_msg_trunc_server,
1439         },
1440         {
1441                 .name = "SOCK_SEQPACKET timeout",
1442                 .run_client = test_seqpacket_timeout_client,
1443                 .run_server = test_seqpacket_timeout_server,
1444         },
1445         {
1446                 .name = "SOCK_SEQPACKET invalid receive buffer",
1447                 .run_client = test_seqpacket_invalid_rec_buffer_client,
1448                 .run_server = test_seqpacket_invalid_rec_buffer_server,
1449         },
1450         {
1451                 .name = "SOCK_STREAM poll() + SO_RCVLOWAT",
1452                 .run_client = test_stream_poll_rcvlowat_client,
1453                 .run_server = test_stream_poll_rcvlowat_server,
1454         },
1455         {
1456                 .name = "SOCK_SEQPACKET big message",
1457                 .run_client = test_seqpacket_bigmsg_client,
1458                 .run_server = test_seqpacket_bigmsg_server,
1459         },
1460         {
1461                 .name = "SOCK_STREAM test invalid buffer",
1462                 .run_client = test_stream_inv_buf_client,
1463                 .run_server = test_stream_inv_buf_server,
1464         },
1465         {
1466                 .name = "SOCK_SEQPACKET test invalid buffer",
1467                 .run_client = test_seqpacket_inv_buf_client,
1468                 .run_server = test_seqpacket_inv_buf_server,
1469         },
1470         {
1471                 .name = "SOCK_STREAM virtio skb merge",
1472                 .run_client = test_stream_virtio_skb_merge_client,
1473                 .run_server = test_stream_virtio_skb_merge_server,
1474         },
1475         {
1476                 .name = "SOCK_SEQPACKET MSG_PEEK",
1477                 .run_client = test_seqpacket_msg_peek_client,
1478                 .run_server = test_seqpacket_msg_peek_server,
1479         },
1480         {
1481                 .name = "SOCK_STREAM SHUT_WR",
1482                 .run_client = test_stream_shutwr_client,
1483                 .run_server = test_stream_shutwr_server,
1484         },
1485         {
1486                 .name = "SOCK_STREAM SHUT_RD",
1487                 .run_client = test_stream_shutrd_client,
1488                 .run_server = test_stream_shutrd_server,
1489         },
1490         {
1491                 .name = "SOCK_STREAM MSG_ZEROCOPY",
1492                 .run_client = test_stream_msgzcopy_client,
1493                 .run_server = test_stream_msgzcopy_server,
1494         },
1495         {
1496                 .name = "SOCK_SEQPACKET MSG_ZEROCOPY",
1497                 .run_client = test_seqpacket_msgzcopy_client,
1498                 .run_server = test_seqpacket_msgzcopy_server,
1499         },
1500         {
1501                 .name = "SOCK_STREAM MSG_ZEROCOPY empty MSG_ERRQUEUE",
1502                 .run_client = test_stream_msgzcopy_empty_errq_client,
1503                 .run_server = test_stream_msgzcopy_empty_errq_server,
1504         },
1505         {
1506                 .name = "SOCK_STREAM double bind connect",
1507                 .run_client = test_double_bind_connect_client,
1508                 .run_server = test_double_bind_connect_server,
1509         },
1510         {
1511                 .name = "SOCK_STREAM virtio credit update + SO_RCVLOWAT",
1512                 .run_client = test_stream_rcvlowat_def_cred_upd_client,
1513                 .run_server = test_stream_cred_upd_on_set_rcvlowat,
1514         },
1515         {
1516                 .name = "SOCK_STREAM virtio credit update + low rx_bytes",
1517                 .run_client = test_stream_rcvlowat_def_cred_upd_client,
1518                 .run_server = test_stream_cred_upd_on_low_rx_bytes,
1519         },
1520         {},
1521 };
1522
1523 static const char optstring[] = "";
1524 static const struct option longopts[] = {
1525         {
1526                 .name = "control-host",
1527                 .has_arg = required_argument,
1528                 .val = 'H',
1529         },
1530         {
1531                 .name = "control-port",
1532                 .has_arg = required_argument,
1533                 .val = 'P',
1534         },
1535         {
1536                 .name = "mode",
1537                 .has_arg = required_argument,
1538                 .val = 'm',
1539         },
1540         {
1541                 .name = "peer-cid",
1542                 .has_arg = required_argument,
1543                 .val = 'p',
1544         },
1545         {
1546                 .name = "list",
1547                 .has_arg = no_argument,
1548                 .val = 'l',
1549         },
1550         {
1551                 .name = "skip",
1552                 .has_arg = required_argument,
1553                 .val = 's',
1554         },
1555         {
1556                 .name = "help",
1557                 .has_arg = no_argument,
1558                 .val = '?',
1559         },
1560         {},
1561 };
1562
1563 static void usage(void)
1564 {
1565         fprintf(stderr, "Usage: vsock_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid> [--list] [--skip=<test_id>]\n"
1566                 "\n"
1567                 "  Server: vsock_test --control-port=1234 --mode=server --peer-cid=3\n"
1568                 "  Client: vsock_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
1569                 "\n"
1570                 "Run vsock.ko tests.  Must be launched in both guest\n"
1571                 "and host.  One side must use --mode=client and\n"
1572                 "the other side must use --mode=server.\n"
1573                 "\n"
1574                 "A TCP control socket connection is used to coordinate tests\n"
1575                 "between the client and the server.  The server requires a\n"
1576                 "listen address and the client requires an address to\n"
1577                 "connect to.\n"
1578                 "\n"
1579                 "The CID of the other side must be given with --peer-cid=<cid>.\n"
1580                 "\n"
1581                 "Options:\n"
1582                 "  --help                 This help message\n"
1583                 "  --control-host <host>  Server IP address to connect to\n"
1584                 "  --control-port <port>  Server port to listen on/connect to\n"
1585                 "  --mode client|server   Server or client mode\n"
1586                 "  --peer-cid <cid>       CID of the other side\n"
1587                 "  --list                 List of tests that will be executed\n"
1588                 "  --skip <test_id>       Test ID to skip;\n"
1589                 "                         use multiple --skip options to skip more tests\n"
1590                 );
1591         exit(EXIT_FAILURE);
1592 }
1593
1594 int main(int argc, char **argv)
1595 {
1596         const char *control_host = NULL;
1597         const char *control_port = NULL;
1598         struct test_opts opts = {
1599                 .mode = TEST_MODE_UNSET,
1600                 .peer_cid = VMADDR_CID_ANY,
1601         };
1602
1603         srand(time(NULL));
1604         init_signals();
1605
1606         for (;;) {
1607                 int opt = getopt_long(argc, argv, optstring, longopts, NULL);
1608
1609                 if (opt == -1)
1610                         break;
1611
1612                 switch (opt) {
1613                 case 'H':
1614                         control_host = optarg;
1615                         break;
1616                 case 'm':
1617                         if (strcmp(optarg, "client") == 0)
1618                                 opts.mode = TEST_MODE_CLIENT;
1619                         else if (strcmp(optarg, "server") == 0)
1620                                 opts.mode = TEST_MODE_SERVER;
1621                         else {
1622                                 fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
1623                                 return EXIT_FAILURE;
1624                         }
1625                         break;
1626                 case 'p':
1627                         opts.peer_cid = parse_cid(optarg);
1628                         break;
1629                 case 'P':
1630                         control_port = optarg;
1631                         break;
1632                 case 'l':
1633                         list_tests(test_cases);
1634                         break;
1635                 case 's':
1636                         skip_test(test_cases, ARRAY_SIZE(test_cases) - 1,
1637                                   optarg);
1638                         break;
1639                 case '?':
1640                 default:
1641                         usage();
1642                 }
1643         }
1644
1645         if (!control_port)
1646                 usage();
1647         if (opts.mode == TEST_MODE_UNSET)
1648                 usage();
1649         if (opts.peer_cid == VMADDR_CID_ANY)
1650                 usage();
1651
1652         if (!control_host) {
1653                 if (opts.mode != TEST_MODE_SERVER)
1654                         usage();
1655                 control_host = "0.0.0.0";
1656         }
1657
1658         control_init(control_host, control_port,
1659                      opts.mode == TEST_MODE_SERVER);
1660
1661         run_tests(test_cases, &opts);
1662
1663         control_cleanup();
1664         return EXIT_SUCCESS;
1665 }