GNU Linux-libre 6.8.7-gnu
[releases.git] / tools / testing / vsock / msg_zerocopy_common.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /* Some common code for MSG_ZEROCOPY logic
3  *
4  * Copyright (C) 2023 SberDevices.
5  *
6  * Author: Arseniy Krasnov <avkrasnov@salutedevices.com>
7  */
8
9 #include <stdio.h>
10 #include <stdlib.h>
11 #include <sys/types.h>
12 #include <sys/socket.h>
13 #include <linux/errqueue.h>
14
15 #include "msg_zerocopy_common.h"
16
17 void enable_so_zerocopy(int fd)
18 {
19         int val = 1;
20
21         if (setsockopt(fd, SOL_SOCKET, SO_ZEROCOPY, &val, sizeof(val))) {
22                 perror("setsockopt");
23                 exit(EXIT_FAILURE);
24         }
25 }
26
27 void vsock_recv_completion(int fd, const bool *zerocopied)
28 {
29         struct sock_extended_err *serr;
30         struct msghdr msg = { 0 };
31         char cmsg_data[128];
32         struct cmsghdr *cm;
33         ssize_t res;
34
35         msg.msg_control = cmsg_data;
36         msg.msg_controllen = sizeof(cmsg_data);
37
38         res = recvmsg(fd, &msg, MSG_ERRQUEUE);
39         if (res) {
40                 fprintf(stderr, "failed to read error queue: %zi\n", res);
41                 exit(EXIT_FAILURE);
42         }
43
44         cm = CMSG_FIRSTHDR(&msg);
45         if (!cm) {
46                 fprintf(stderr, "cmsg: no cmsg\n");
47                 exit(EXIT_FAILURE);
48         }
49
50         if (cm->cmsg_level != SOL_VSOCK) {
51                 fprintf(stderr, "cmsg: unexpected 'cmsg_level'\n");
52                 exit(EXIT_FAILURE);
53         }
54
55         if (cm->cmsg_type != VSOCK_RECVERR) {
56                 fprintf(stderr, "cmsg: unexpected 'cmsg_type'\n");
57                 exit(EXIT_FAILURE);
58         }
59
60         serr = (void *)CMSG_DATA(cm);
61         if (serr->ee_origin != SO_EE_ORIGIN_ZEROCOPY) {
62                 fprintf(stderr, "serr: wrong origin: %u\n", serr->ee_origin);
63                 exit(EXIT_FAILURE);
64         }
65
66         if (serr->ee_errno) {
67                 fprintf(stderr, "serr: wrong error code: %u\n", serr->ee_errno);
68                 exit(EXIT_FAILURE);
69         }
70
71         /* This flag is used for tests, to check that transmission was
72          * performed as expected: zerocopy or fallback to copy. If NULL
73          * - don't care.
74          */
75         if (!zerocopied)
76                 return;
77
78         if (*zerocopied && (serr->ee_code & SO_EE_CODE_ZEROCOPY_COPIED)) {
79                 fprintf(stderr, "serr: was copy instead of zerocopy\n");
80                 exit(EXIT_FAILURE);
81         }
82
83         if (!*zerocopied && !(serr->ee_code & SO_EE_CODE_ZEROCOPY_COPIED)) {
84                 fprintf(stderr, "serr: was zerocopy instead of copy\n");
85                 exit(EXIT_FAILURE);
86         }
87 }