GNU Linux-libre 5.15.137-gnu
[releases.git] / net / 9p / protocol.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * net/9p/protocol.c
4  *
5  * 9P Protocol Support Code
6  *
7  *  Copyright (C) 2008 by Eric Van Hensbergen <ericvh@gmail.com>
8  *
9  *  Base on code from Anthony Liguori <aliguori@us.ibm.com>
10  *  Copyright (C) 2008 by IBM, Corp.
11  */
12
13 #include <linux/module.h>
14 #include <linux/errno.h>
15 #include <linux/kernel.h>
16 #include <linux/uaccess.h>
17 #include <linux/slab.h>
18 #include <linux/sched.h>
19 #include <linux/stddef.h>
20 #include <linux/types.h>
21 #include <linux/uio.h>
22 #include <net/9p/9p.h>
23 #include <net/9p/client.h>
24 #include "protocol.h"
25
26 #include <trace/events/9p.h>
27
/* Declared early because p9pdu_vwritef() recurses through it. */
static int p9pdu_writef(struct p9_fcall *pdu, int proto_version,
			const char *fmt, ...);
30
31 void p9stat_free(struct p9_wstat *stbuf)
32 {
33         kfree(stbuf->name);
34         stbuf->name = NULL;
35         kfree(stbuf->uid);
36         stbuf->uid = NULL;
37         kfree(stbuf->gid);
38         stbuf->gid = NULL;
39         kfree(stbuf->muid);
40         stbuf->muid = NULL;
41         kfree(stbuf->extension);
42         stbuf->extension = NULL;
43 }
44 EXPORT_SYMBOL(p9stat_free);
45
46 size_t pdu_read(struct p9_fcall *pdu, void *data, size_t size)
47 {
48         size_t len = min(pdu->size - pdu->offset, size);
49
50         memcpy(data, &pdu->sdata[pdu->offset], len);
51         pdu->offset += len;
52         return size - len;
53 }
54
55 static size_t pdu_write(struct p9_fcall *pdu, const void *data, size_t size)
56 {
57         size_t len = min(pdu->capacity - pdu->size, size);
58
59         memcpy(&pdu->sdata[pdu->size], data, len);
60         pdu->size += len;
61         return size - len;
62 }
63
64 static size_t
65 pdu_write_u(struct p9_fcall *pdu, struct iov_iter *from, size_t size)
66 {
67         size_t len = min(pdu->capacity - pdu->size, size);
68         struct iov_iter i = *from;
69
70         if (!copy_from_iter_full(&pdu->sdata[pdu->size], len, &i))
71                 len = 0;
72
73         pdu->size += len;
74         return size - len;
75 }
76
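/*	b - int8_t
 *	w - int16_t
 *	d - int32_t
 *	q - int64_t
 *	s - string (int16_t length followed by the bytes; no NUL on the wire)
 *	u - numeric uid
 *	g - numeric gid
 *	S - stat
 *	Q - qid
 *	D - data blob (int32_t size followed by void *); read only.  The
 *	    returned pointer points into the PDU buffer and must not be freed.
 *	T - array of strings (int16_t count, followed by strings)
 *	R - array of qids (int16_t count, followed by qids)
 *	A - stat for 9p2000.L (p9_stat_dotl); read only
 *	V - data blob taken from an iov_iter (uint32_t count, iov_iter *);
 *	    write only
 *	I - iattr for 9p2000.L (p9_iattr_dotl); write only
 *	? - stop here and report success unless the protocol version is
 *	    9p2000.u or 9p2000.L; the fields that follow are only encoded
 *	    for those versions
 */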
92
93 static int
94 p9pdu_vreadf(struct p9_fcall *pdu, int proto_version, const char *fmt,
95              va_list ap)
96 {
97         const char *ptr;
98         int errcode = 0;
99
100         for (ptr = fmt; *ptr; ptr++) {
101                 switch (*ptr) {
102                 case 'b':{
103                                 int8_t *val = va_arg(ap, int8_t *);
104                                 if (pdu_read(pdu, val, sizeof(*val))) {
105                                         errcode = -EFAULT;
106                                         break;
107                                 }
108                         }
109                         break;
110                 case 'w':{
111                                 int16_t *val = va_arg(ap, int16_t *);
112                                 __le16 le_val;
113                                 if (pdu_read(pdu, &le_val, sizeof(le_val))) {
114                                         errcode = -EFAULT;
115                                         break;
116                                 }
117                                 *val = le16_to_cpu(le_val);
118                         }
119                         break;
120                 case 'd':{
121                                 int32_t *val = va_arg(ap, int32_t *);
122                                 __le32 le_val;
123                                 if (pdu_read(pdu, &le_val, sizeof(le_val))) {
124                                         errcode = -EFAULT;
125                                         break;
126                                 }
127                                 *val = le32_to_cpu(le_val);
128                         }
129                         break;
130                 case 'q':{
131                                 int64_t *val = va_arg(ap, int64_t *);
132                                 __le64 le_val;
133                                 if (pdu_read(pdu, &le_val, sizeof(le_val))) {
134                                         errcode = -EFAULT;
135                                         break;
136                                 }
137                                 *val = le64_to_cpu(le_val);
138                         }
139                         break;
140                 case 's':{
141                                 char **sptr = va_arg(ap, char **);
142                                 uint16_t len;
143
144                                 errcode = p9pdu_readf(pdu, proto_version,
145                                                                 "w", &len);
146                                 if (errcode)
147                                         break;
148
149                                 *sptr = kmalloc(len + 1, GFP_NOFS);
150                                 if (*sptr == NULL) {
151                                         errcode = -ENOMEM;
152                                         break;
153                                 }
154                                 if (pdu_read(pdu, *sptr, len)) {
155                                         errcode = -EFAULT;
156                                         kfree(*sptr);
157                                         *sptr = NULL;
158                                 } else
159                                         (*sptr)[len] = 0;
160                         }
161                         break;
162                 case 'u': {
163                                 kuid_t *uid = va_arg(ap, kuid_t *);
164                                 __le32 le_val;
165                                 if (pdu_read(pdu, &le_val, sizeof(le_val))) {
166                                         errcode = -EFAULT;
167                                         break;
168                                 }
169                                 *uid = make_kuid(&init_user_ns,
170                                                  le32_to_cpu(le_val));
171                         } break;
172                 case 'g': {
173                                 kgid_t *gid = va_arg(ap, kgid_t *);
174                                 __le32 le_val;
175                                 if (pdu_read(pdu, &le_val, sizeof(le_val))) {
176                                         errcode = -EFAULT;
177                                         break;
178                                 }
179                                 *gid = make_kgid(&init_user_ns,
180                                                  le32_to_cpu(le_val));
181                         } break;
182                 case 'Q':{
183                                 struct p9_qid *qid =
184                                     va_arg(ap, struct p9_qid *);
185
186                                 errcode = p9pdu_readf(pdu, proto_version, "bdq",
187                                                       &qid->type, &qid->version,
188                                                       &qid->path);
189                         }
190                         break;
191                 case 'S':{
192                                 struct p9_wstat *stbuf =
193                                     va_arg(ap, struct p9_wstat *);
194
195                                 memset(stbuf, 0, sizeof(struct p9_wstat));
196                                 stbuf->n_uid = stbuf->n_muid = INVALID_UID;
197                                 stbuf->n_gid = INVALID_GID;
198
199                                 errcode =
200                                     p9pdu_readf(pdu, proto_version,
201                                                 "wwdQdddqssss?sugu",
202                                                 &stbuf->size, &stbuf->type,
203                                                 &stbuf->dev, &stbuf->qid,
204                                                 &stbuf->mode, &stbuf->atime,
205                                                 &stbuf->mtime, &stbuf->length,
206                                                 &stbuf->name, &stbuf->uid,
207                                                 &stbuf->gid, &stbuf->muid,
208                                                 &stbuf->extension,
209                                                 &stbuf->n_uid, &stbuf->n_gid,
210                                                 &stbuf->n_muid);
211                                 if (errcode)
212                                         p9stat_free(stbuf);
213                         }
214                         break;
215                 case 'D':{
216                                 uint32_t *count = va_arg(ap, uint32_t *);
217                                 void **data = va_arg(ap, void **);
218
219                                 errcode =
220                                     p9pdu_readf(pdu, proto_version, "d", count);
221                                 if (!errcode) {
222                                         *count =
223                                             min_t(uint32_t, *count,
224                                                   pdu->size - pdu->offset);
225                                         *data = &pdu->sdata[pdu->offset];
226                                 }
227                         }
228                         break;
229                 case 'T':{
230                                 uint16_t *nwname = va_arg(ap, uint16_t *);
231                                 char ***wnames = va_arg(ap, char ***);
232
233                                 errcode = p9pdu_readf(pdu, proto_version,
234                                                                 "w", nwname);
235                                 if (!errcode) {
236                                         *wnames =
237                                             kmalloc_array(*nwname,
238                                                           sizeof(char *),
239                                                           GFP_NOFS);
240                                         if (!*wnames)
241                                                 errcode = -ENOMEM;
242                                 }
243
244                                 if (!errcode) {
245                                         int i;
246
247                                         for (i = 0; i < *nwname; i++) {
248                                                 errcode =
249                                                     p9pdu_readf(pdu,
250                                                                 proto_version,
251                                                                 "s",
252                                                                 &(*wnames)[i]);
253                                                 if (errcode)
254                                                         break;
255                                         }
256                                 }
257
258                                 if (errcode) {
259                                         if (*wnames) {
260                                                 int i;
261
262                                                 for (i = 0; i < *nwname; i++)
263                                                         kfree((*wnames)[i]);
264                                         }
265                                         kfree(*wnames);
266                                         *wnames = NULL;
267                                 }
268                         }
269                         break;
270                 case 'R':{
271                                 uint16_t *nwqid = va_arg(ap, uint16_t *);
272                                 struct p9_qid **wqids =
273                                     va_arg(ap, struct p9_qid **);
274
275                                 *wqids = NULL;
276
277                                 errcode =
278                                     p9pdu_readf(pdu, proto_version, "w", nwqid);
279                                 if (!errcode) {
280                                         *wqids =
281                                             kmalloc_array(*nwqid,
282                                                           sizeof(struct p9_qid),
283                                                           GFP_NOFS);
284                                         if (*wqids == NULL)
285                                                 errcode = -ENOMEM;
286                                 }
287
288                                 if (!errcode) {
289                                         int i;
290
291                                         for (i = 0; i < *nwqid; i++) {
292                                                 errcode =
293                                                     p9pdu_readf(pdu,
294                                                                 proto_version,
295                                                                 "Q",
296                                                                 &(*wqids)[i]);
297                                                 if (errcode)
298                                                         break;
299                                         }
300                                 }
301
302                                 if (errcode) {
303                                         kfree(*wqids);
304                                         *wqids = NULL;
305                                 }
306                         }
307                         break;
308                 case 'A': {
309                                 struct p9_stat_dotl *stbuf =
310                                     va_arg(ap, struct p9_stat_dotl *);
311
312                                 memset(stbuf, 0, sizeof(struct p9_stat_dotl));
313                                 errcode =
314                                     p9pdu_readf(pdu, proto_version,
315                                         "qQdugqqqqqqqqqqqqqqq",
316                                         &stbuf->st_result_mask,
317                                         &stbuf->qid,
318                                         &stbuf->st_mode,
319                                         &stbuf->st_uid, &stbuf->st_gid,
320                                         &stbuf->st_nlink,
321                                         &stbuf->st_rdev, &stbuf->st_size,
322                                         &stbuf->st_blksize, &stbuf->st_blocks,
323                                         &stbuf->st_atime_sec,
324                                         &stbuf->st_atime_nsec,
325                                         &stbuf->st_mtime_sec,
326                                         &stbuf->st_mtime_nsec,
327                                         &stbuf->st_ctime_sec,
328                                         &stbuf->st_ctime_nsec,
329                                         &stbuf->st_btime_sec,
330                                         &stbuf->st_btime_nsec,
331                                         &stbuf->st_gen,
332                                         &stbuf->st_data_version);
333                         }
334                         break;
335                 case '?':
336                         if ((proto_version != p9_proto_2000u) &&
337                                 (proto_version != p9_proto_2000L))
338                                 return 0;
339                         break;
340                 default:
341                         BUG();
342                         break;
343                 }
344
345                 if (errcode)
346                         break;
347         }
348
349         return errcode;
350 }
351
352 int
353 p9pdu_vwritef(struct p9_fcall *pdu, int proto_version, const char *fmt,
354         va_list ap)
355 {
356         const char *ptr;
357         int errcode = 0;
358
359         for (ptr = fmt; *ptr; ptr++) {
360                 switch (*ptr) {
361                 case 'b':{
362                                 int8_t val = va_arg(ap, int);
363                                 if (pdu_write(pdu, &val, sizeof(val)))
364                                         errcode = -EFAULT;
365                         }
366                         break;
367                 case 'w':{
368                                 __le16 val = cpu_to_le16(va_arg(ap, int));
369                                 if (pdu_write(pdu, &val, sizeof(val)))
370                                         errcode = -EFAULT;
371                         }
372                         break;
373                 case 'd':{
374                                 __le32 val = cpu_to_le32(va_arg(ap, int32_t));
375                                 if (pdu_write(pdu, &val, sizeof(val)))
376                                         errcode = -EFAULT;
377                         }
378                         break;
379                 case 'q':{
380                                 __le64 val = cpu_to_le64(va_arg(ap, int64_t));
381                                 if (pdu_write(pdu, &val, sizeof(val)))
382                                         errcode = -EFAULT;
383                         }
384                         break;
385                 case 's':{
386                                 const char *sptr = va_arg(ap, const char *);
387                                 uint16_t len = 0;
388                                 if (sptr)
389                                         len = min_t(size_t, strlen(sptr),
390                                                                 USHRT_MAX);
391
392                                 errcode = p9pdu_writef(pdu, proto_version,
393                                                                 "w", len);
394                                 if (!errcode && pdu_write(pdu, sptr, len))
395                                         errcode = -EFAULT;
396                         }
397                         break;
398                 case 'u': {
399                                 kuid_t uid = va_arg(ap, kuid_t);
400                                 __le32 val = cpu_to_le32(
401                                                 from_kuid(&init_user_ns, uid));
402                                 if (pdu_write(pdu, &val, sizeof(val)))
403                                         errcode = -EFAULT;
404                         } break;
405                 case 'g': {
406                                 kgid_t gid = va_arg(ap, kgid_t);
407                                 __le32 val = cpu_to_le32(
408                                                 from_kgid(&init_user_ns, gid));
409                                 if (pdu_write(pdu, &val, sizeof(val)))
410                                         errcode = -EFAULT;
411                         } break;
412                 case 'Q':{
413                                 const struct p9_qid *qid =
414                                     va_arg(ap, const struct p9_qid *);
415                                 errcode =
416                                     p9pdu_writef(pdu, proto_version, "bdq",
417                                                  qid->type, qid->version,
418                                                  qid->path);
419                         } break;
420                 case 'S':{
421                                 const struct p9_wstat *stbuf =
422                                     va_arg(ap, const struct p9_wstat *);
423                                 errcode =
424                                     p9pdu_writef(pdu, proto_version,
425                                                  "wwdQdddqssss?sugu",
426                                                  stbuf->size, stbuf->type,
427                                                  stbuf->dev, &stbuf->qid,
428                                                  stbuf->mode, stbuf->atime,
429                                                  stbuf->mtime, stbuf->length,
430                                                  stbuf->name, stbuf->uid,
431                                                  stbuf->gid, stbuf->muid,
432                                                  stbuf->extension, stbuf->n_uid,
433                                                  stbuf->n_gid, stbuf->n_muid);
434                         } break;
435                 case 'V':{
436                                 uint32_t count = va_arg(ap, uint32_t);
437                                 struct iov_iter *from =
438                                                 va_arg(ap, struct iov_iter *);
439                                 errcode = p9pdu_writef(pdu, proto_version, "d",
440                                                                         count);
441                                 if (!errcode && pdu_write_u(pdu, from, count))
442                                         errcode = -EFAULT;
443                         }
444                         break;
445                 case 'T':{
446                                 uint16_t nwname = va_arg(ap, int);
447                                 const char **wnames = va_arg(ap, const char **);
448
449                                 errcode = p9pdu_writef(pdu, proto_version, "w",
450                                                                         nwname);
451                                 if (!errcode) {
452                                         int i;
453
454                                         for (i = 0; i < nwname; i++) {
455                                                 errcode =
456                                                     p9pdu_writef(pdu,
457                                                                 proto_version,
458                                                                  "s",
459                                                                  wnames[i]);
460                                                 if (errcode)
461                                                         break;
462                                         }
463                                 }
464                         }
465                         break;
466                 case 'R':{
467                                 uint16_t nwqid = va_arg(ap, int);
468                                 struct p9_qid *wqids =
469                                     va_arg(ap, struct p9_qid *);
470
471                                 errcode = p9pdu_writef(pdu, proto_version, "w",
472                                                                         nwqid);
473                                 if (!errcode) {
474                                         int i;
475
476                                         for (i = 0; i < nwqid; i++) {
477                                                 errcode =
478                                                     p9pdu_writef(pdu,
479                                                                 proto_version,
480                                                                  "Q",
481                                                                  &wqids[i]);
482                                                 if (errcode)
483                                                         break;
484                                         }
485                                 }
486                         }
487                         break;
488                 case 'I':{
489                                 struct p9_iattr_dotl *p9attr = va_arg(ap,
490                                                         struct p9_iattr_dotl *);
491
492                                 errcode = p9pdu_writef(pdu, proto_version,
493                                                         "ddugqqqqq",
494                                                         p9attr->valid,
495                                                         p9attr->mode,
496                                                         p9attr->uid,
497                                                         p9attr->gid,
498                                                         p9attr->size,
499                                                         p9attr->atime_sec,
500                                                         p9attr->atime_nsec,
501                                                         p9attr->mtime_sec,
502                                                         p9attr->mtime_nsec);
503                         }
504                         break;
505                 case '?':
506                         if ((proto_version != p9_proto_2000u) &&
507                                 (proto_version != p9_proto_2000L))
508                                 return 0;
509                         break;
510                 default:
511                         BUG();
512                         break;
513                 }
514
515                 if (errcode)
516                         break;
517         }
518
519         return errcode;
520 }
521
/**
 * p9pdu_readf - variadic front end for p9pdu_vreadf()
 * @pdu: PDU to decode from
 * @proto_version: negotiated protocol version
 * @fmt: format string; each specifier consumes one or more output pointers
 *
 * Return: the result of p9pdu_vreadf() (0 or a negative errno).
 */
int p9pdu_readf(struct p9_fcall *pdu, int proto_version, const char *fmt, ...)
{
	va_list args;
	int rc;

	va_start(args, fmt);
	rc = p9pdu_vreadf(pdu, proto_version, fmt, args);
	va_end(args);

	return rc;
}
533
/*
 * Variadic front end for p9pdu_vwritef(); returns its result
 * (0 or -EFAULT).
 */
static int
p9pdu_writef(struct p9_fcall *pdu, int proto_version, const char *fmt, ...)
{
	va_list args;
	int rc;

	va_start(args, fmt);
	rc = p9pdu_vwritef(pdu, proto_version, fmt, args);
	va_end(args);

	return rc;
}
546
547 int p9stat_read(struct p9_client *clnt, char *buf, int len, struct p9_wstat *st)
548 {
549         struct p9_fcall fake_pdu;
550         int ret;
551
552         fake_pdu.size = len;
553         fake_pdu.capacity = len;
554         fake_pdu.sdata = buf;
555         fake_pdu.offset = 0;
556
557         ret = p9pdu_readf(&fake_pdu, clnt->proto_version, "S", st);
558         if (ret) {
559                 p9_debug(P9_DEBUG_9P, "<<< p9stat_read failed: %d\n", ret);
560                 trace_9p_protocol_dump(clnt, &fake_pdu);
561                 return ret;
562         }
563
564         return fake_pdu.offset;
565 }
566 EXPORT_SYMBOL(p9stat_read);
567
568 int p9pdu_prepare(struct p9_fcall *pdu, int16_t tag, int8_t type)
569 {
570         pdu->id = type;
571         return p9pdu_writef(pdu, 0, "dbw", 0, type, tag);
572 }
573
574 int p9pdu_finalize(struct p9_client *clnt, struct p9_fcall *pdu)
575 {
576         int size = pdu->size;
577         int err;
578
579         pdu->size = 0;
580         err = p9pdu_writef(pdu, 0, "d", size);
581         pdu->size = size;
582
583         trace_9p_protocol_dump(clnt, pdu);
584         p9_debug(P9_DEBUG_9P, ">>> size=%d type: %d tag: %d\n",
585                  pdu->size, pdu->id, pdu->tag);
586
587         return err;
588 }
589
590 void p9pdu_reset(struct p9_fcall *pdu)
591 {
592         pdu->offset = 0;
593         pdu->size = 0;
594 }
595
596 int p9dirent_read(struct p9_client *clnt, char *buf, int len,
597                   struct p9_dirent *dirent)
598 {
599         struct p9_fcall fake_pdu;
600         int ret;
601         char *nameptr;
602
603         fake_pdu.size = len;
604         fake_pdu.capacity = len;
605         fake_pdu.sdata = buf;
606         fake_pdu.offset = 0;
607
608         ret = p9pdu_readf(&fake_pdu, clnt->proto_version, "Qqbs", &dirent->qid,
609                           &dirent->d_off, &dirent->d_type, &nameptr);
610         if (ret) {
611                 p9_debug(P9_DEBUG_9P, "<<< p9dirent_read failed: %d\n", ret);
612                 trace_9p_protocol_dump(clnt, &fake_pdu);
613                 return ret;
614         }
615
616         ret = strscpy(dirent->d_name, nameptr, sizeof(dirent->d_name));
617         if (ret < 0) {
618                 p9_debug(P9_DEBUG_ERROR,
619                          "On the wire dirent name too long: %s\n",
620                          nameptr);
621                 kfree(nameptr);
622                 return ret;
623         }
624         kfree(nameptr);
625
626         return fake_pdu.offset;
627 }
628 EXPORT_SYMBOL(p9dirent_read);