GNU Linux-libre 5.10.217-gnu1
arch/um/kernel/tlb.c
// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2000 - 2007 Jeff Dike (jdike@{addtoit,linux.intel}.com)
 */

#include <linux/mm.h>
#include <linux/module.h>
#include <linux/sched/signal.h>

#include <asm/tlbflush.h>
#include <as-layout.h>
#include <mem_user.h>
#include <os.h>
#include <skas.h>
#include <kern_util.h>

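/*
 * A host_vm_change collects pending host mapping changes (mmap, munmap,
 * mprotect) produced while walking the page tables, so that adjacent,
 * compatible operations can be merged before do_ops() issues them to the
 * host.
 */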
struct host_vm_change {
        struct host_vm_op {
                enum { NONE, MMAP, MUNMAP, MPROTECT } type;
                union {
                        struct {
                                unsigned long addr;
                                unsigned long len;
                                unsigned int prot;
                                int fd;
                                __u64 offset;
                        } mmap;
                        struct {
                                unsigned long addr;
                                unsigned long len;
                        } munmap;
                        struct {
                                unsigned long addr;
                                unsigned long len;
                                unsigned int prot;
                        } mprotect;
                } u;
        } ops[1];
        int userspace;
        int index;
        struct mm_struct *mm;
        void *data;
        int force;
};

#define INIT_HVC(mm, force, userspace) \
        ((struct host_vm_change) \
         { .ops         = { { .type = NONE } }, \
           .mm          = mm, \
           .data        = NULL, \
           .userspace   = userspace, \
           .index       = 0, \
           .force       = force })

static void report_enomem(void)
{
        printk(KERN_ERR "UML ran out of memory on the host side! "
                        "This can happen due to a memory limitation or "
                        "because vm.max_map_count has been reached.\n");
}

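/*
 * Issue the queued operations: against the host process that backs a
 * userspace address space (via map/unmap/protect), or directly against
 * UML's own address space (via the os_* helpers) for kernel mappings.
 */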
static int do_ops(struct host_vm_change *hvc, int end,
                  int finished)
{
        struct host_vm_op *op;
        int i, ret = 0;

        for (i = 0; i < end && !ret; i++) {
                op = &hvc->ops[i];
                switch (op->type) {
                case MMAP:
                        if (hvc->userspace)
                                ret = map(&hvc->mm->context.id, op->u.mmap.addr,
                                          op->u.mmap.len, op->u.mmap.prot,
                                          op->u.mmap.fd,
                                          op->u.mmap.offset, finished,
                                          &hvc->data);
                        else
                                map_memory(op->u.mmap.addr, op->u.mmap.offset,
                                           op->u.mmap.len, 1, 1, 1);
                        break;
                case MUNMAP:
                        if (hvc->userspace)
                                ret = unmap(&hvc->mm->context.id,
                                            op->u.munmap.addr,
                                            op->u.munmap.len, finished,
                                            &hvc->data);
                        else
                                ret = os_unmap_memory((void *) op->u.munmap.addr,
                                                      op->u.munmap.len);
                        break;
                case MPROTECT:
                        if (hvc->userspace)
                                ret = protect(&hvc->mm->context.id,
                                              op->u.mprotect.addr,
                                              op->u.mprotect.len,
                                              op->u.mprotect.prot,
                                              finished, &hvc->data);
                        else
                                ret = os_protect_memory((void *) op->u.mprotect.addr,
                                                        op->u.mprotect.len,
                                                        1, 1, 1);
                        break;
                default:
                        printk(KERN_ERR "Unknown op type %d in do_ops\n",
                               op->type);
                        BUG();
                        break;
                }
        }

        if (ret == -ENOMEM)
                report_enomem();

        return ret;
}

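/*
 * add_mmap(), add_munmap() and add_mprotect() queue one operation each,
 * extending the previous entry instead when the new range is contiguous
 * and compatible with it.  A full ops array is flushed with do_ops()
 * before the new entry is stored.
 */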
static int add_mmap(unsigned long virt, unsigned long phys, unsigned long len,
                    unsigned int prot, struct host_vm_change *hvc)
{
        __u64 offset;
        struct host_vm_op *last;
        int fd = -1, ret = 0;

        if (virt + len > STUB_START && virt < STUB_END)
                return -EINVAL;

        if (hvc->userspace)
                fd = phys_mapping(phys, &offset);
        else
                offset = phys;
        if (hvc->index != 0) {
                last = &hvc->ops[hvc->index - 1];
                if ((last->type == MMAP) &&
                   (last->u.mmap.addr + last->u.mmap.len == virt) &&
                   (last->u.mmap.prot == prot) && (last->u.mmap.fd == fd) &&
                   (last->u.mmap.offset + last->u.mmap.len == offset)) {
                        last->u.mmap.len += len;
                        return 0;
                }
        }

        if (hvc->index == ARRAY_SIZE(hvc->ops)) {
                ret = do_ops(hvc, ARRAY_SIZE(hvc->ops), 0);
                hvc->index = 0;
        }

        hvc->ops[hvc->index++] = ((struct host_vm_op)
                                  { .type       = MMAP,
                                    .u = { .mmap = { .addr      = virt,
                                                     .len       = len,
                                                     .prot      = prot,
                                                     .fd        = fd,
                                                     .offset    = offset } } });
        return ret;
}

static int add_munmap(unsigned long addr, unsigned long len,
                      struct host_vm_change *hvc)
{
        struct host_vm_op *last;
        int ret = 0;

        if (addr + len > STUB_START && addr < STUB_END)
                return -EINVAL;

        if (hvc->index != 0) {
                last = &hvc->ops[hvc->index - 1];
                if ((last->type == MUNMAP) &&
                   (last->u.munmap.addr + last->u.munmap.len == addr)) {
                        last->u.munmap.len += len;
                        return 0;
                }
        }

        if (hvc->index == ARRAY_SIZE(hvc->ops)) {
                ret = do_ops(hvc, ARRAY_SIZE(hvc->ops), 0);
                hvc->index = 0;
        }

        hvc->ops[hvc->index++] = ((struct host_vm_op)
                                  { .type       = MUNMAP,
                                    .u = { .munmap = { .addr    = addr,
                                                       .len     = len } } });
        return ret;
}

static int add_mprotect(unsigned long addr, unsigned long len,
                        unsigned int prot, struct host_vm_change *hvc)
{
        struct host_vm_op *last;
        int ret = 0;

        if (addr + len > STUB_START && addr < STUB_END)
                return -EINVAL;

        if (hvc->index != 0) {
                last = &hvc->ops[hvc->index - 1];
                if ((last->type == MPROTECT) &&
                   (last->u.mprotect.addr + last->u.mprotect.len == addr) &&
                   (last->u.mprotect.prot == prot)) {
                        last->u.mprotect.len += len;
                        return 0;
                }
        }

        if (hvc->index == ARRAY_SIZE(hvc->ops)) {
                ret = do_ops(hvc, ARRAY_SIZE(hvc->ops), 0);
                hvc->index = 0;
        }

        hvc->ops[hvc->index++] = ((struct host_vm_op)
                                  { .type       = MPROTECT,
                                    .u = { .mprotect = { .addr  = addr,
                                                         .len   = len,
                                                         .prot  = prot } } });
        return ret;
}

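/*
 * Advance n to the next inc-aligned boundary above it (inc must be a
 * power of two).
 */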
#define ADD_ROUND(n, inc) (((n) + (inc)) & ~((inc) - 1))

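/*
 * The update_*_range() helpers each walk one level of the page tables,
 * turning out-of-date entries into the corresponding host operations and
 * marking them up to date again.
 */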
static inline int update_pte_range(pmd_t *pmd, unsigned long addr,
                                   unsigned long end,
                                   struct host_vm_change *hvc)
{
        pte_t *pte;
        int r, w, x, prot, ret = 0;

        pte = pte_offset_kernel(pmd, addr);
        do {
                if ((addr >= STUB_START) && (addr < STUB_END))
                        continue;

                r = pte_read(*pte);
                w = pte_write(*pte);
                x = pte_exec(*pte);
                if (!pte_young(*pte)) {
                        r = 0;
                        w = 0;
                } else if (!pte_dirty(*pte))
                        w = 0;

                prot = ((r ? UM_PROT_READ : 0) | (w ? UM_PROT_WRITE : 0) |
                        (x ? UM_PROT_EXEC : 0));
                if (hvc->force || pte_newpage(*pte)) {
                        if (pte_present(*pte)) {
                                if (pte_newpage(*pte))
                                        ret = add_mmap(addr, pte_val(*pte) & PAGE_MASK,
                                                       PAGE_SIZE, prot, hvc);
                        } else
                                ret = add_munmap(addr, PAGE_SIZE, hvc);
                } else if (pte_newprot(*pte))
                        ret = add_mprotect(addr, PAGE_SIZE, prot, hvc);
                *pte = pte_mkuptodate(*pte);
        } while (pte++, addr += PAGE_SIZE, ((addr < end) && !ret));
        return ret;
}

static inline int update_pmd_range(pud_t *pud, unsigned long addr,
                                   unsigned long end,
                                   struct host_vm_change *hvc)
{
        pmd_t *pmd;
        unsigned long next;
        int ret = 0;

        pmd = pmd_offset(pud, addr);
        do {
                next = pmd_addr_end(addr, end);
                if (!pmd_present(*pmd)) {
                        if (hvc->force || pmd_newpage(*pmd)) {
                                ret = add_munmap(addr, next - addr, hvc);
                                pmd_mkuptodate(*pmd);
                        }
                } else
                        ret = update_pte_range(pmd, addr, next, hvc);
        } while (pmd++, addr = next, ((addr < end) && !ret));
        return ret;
}

static inline int update_pud_range(p4d_t *p4d, unsigned long addr,
                                   unsigned long end,
                                   struct host_vm_change *hvc)
{
        pud_t *pud;
        unsigned long next;
        int ret = 0;

        pud = pud_offset(p4d, addr);
        do {
                next = pud_addr_end(addr, end);
                if (!pud_present(*pud)) {
                        if (hvc->force || pud_newpage(*pud)) {
                                ret = add_munmap(addr, next - addr, hvc);
                                pud_mkuptodate(*pud);
                        }
                } else
                        ret = update_pmd_range(pud, addr, next, hvc);
        } while (pud++, addr = next, ((addr < end) && !ret));
        return ret;
}

static inline int update_p4d_range(pgd_t *pgd, unsigned long addr,
                                   unsigned long end,
                                   struct host_vm_change *hvc)
{
        p4d_t *p4d;
        unsigned long next;
        int ret = 0;

        p4d = p4d_offset(pgd, addr);
        do {
                next = p4d_addr_end(addr, end);
                if (!p4d_present(*p4d)) {
                        if (hvc->force || p4d_newpage(*p4d)) {
                                ret = add_munmap(addr, next - addr, hvc);
                                p4d_mkuptodate(*p4d);
                        }
                } else
                        ret = update_pud_range(p4d, addr, next, hvc);
        } while (p4d++, addr = next, ((addr < end) && !ret));
        return ret;
}

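/*
 * Bring the host mappings for [start_addr, end_addr) in line with the page
 * tables of mm.  If that fails, the host address space is in an unknown
 * state, so the current process is marked to be killed.
 */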
void fix_range_common(struct mm_struct *mm, unsigned long start_addr,
                      unsigned long end_addr, int force)
{
        pgd_t *pgd;
        struct host_vm_change hvc;
        unsigned long addr = start_addr, next;
        int ret = 0, userspace = 1;

        hvc = INIT_HVC(mm, force, userspace);
        pgd = pgd_offset(mm, addr);
        do {
                next = pgd_addr_end(addr, end_addr);
                if (!pgd_present(*pgd)) {
                        if (force || pgd_newpage(*pgd)) {
                                ret = add_munmap(addr, next - addr, &hvc);
                                pgd_mkuptodate(*pgd);
                        }
                } else
                        ret = update_p4d_range(pgd, addr, next, &hvc);
        } while (pgd++, addr = next, ((addr < end_addr) && !ret));

        if (!ret)
                ret = do_ops(&hvc, hvc.index, 1);

        /* This is not an else because ret is modified above */
        if (ret) {
                struct mm_id *mm_idp = &current->mm->context.id;

                printk(KERN_ERR "fix_range_common: failed, killing current "
                       "process: %d\n", task_tgid_vnr(current));
                mm_idp->kill = 1;
        }
}

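/*
 * Walk the kernel (init_mm) page tables for [start, end) and fix up the
 * corresponding host mappings directly.  Returns nonzero if anything was
 * changed; failures here are fatal.
 */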
static int flush_tlb_kernel_range_common(unsigned long start, unsigned long end)
{
        struct mm_struct *mm;
        pgd_t *pgd;
        p4d_t *p4d;
        pud_t *pud;
        pmd_t *pmd;
        pte_t *pte;
        unsigned long addr, last;
        int updated = 0, err = 0, force = 0, userspace = 0;
        struct host_vm_change hvc;

        mm = &init_mm;
        hvc = INIT_HVC(mm, force, userspace);
        for (addr = start; addr < end;) {
                pgd = pgd_offset(mm, addr);
                if (!pgd_present(*pgd)) {
                        last = ADD_ROUND(addr, PGDIR_SIZE);
                        if (last > end)
                                last = end;
                        if (pgd_newpage(*pgd)) {
                                updated = 1;
                                err = add_munmap(addr, last - addr, &hvc);
                                if (err < 0)
                                        panic("munmap failed, errno = %d\n",
                                              -err);
                        }
                        addr = last;
                        continue;
                }

                p4d = p4d_offset(pgd, addr);
                if (!p4d_present(*p4d)) {
                        last = ADD_ROUND(addr, P4D_SIZE);
                        if (last > end)
                                last = end;
                        if (p4d_newpage(*p4d)) {
                                updated = 1;
                                err = add_munmap(addr, last - addr, &hvc);
                                if (err < 0)
                                        panic("munmap failed, errno = %d\n",
                                              -err);
                        }
                        addr = last;
                        continue;
                }

                pud = pud_offset(p4d, addr);
                if (!pud_present(*pud)) {
                        last = ADD_ROUND(addr, PUD_SIZE);
                        if (last > end)
                                last = end;
                        if (pud_newpage(*pud)) {
                                updated = 1;
                                err = add_munmap(addr, last - addr, &hvc);
                                if (err < 0)
                                        panic("munmap failed, errno = %d\n",
                                              -err);
                        }
                        addr = last;
                        continue;
                }

                pmd = pmd_offset(pud, addr);
                if (!pmd_present(*pmd)) {
                        last = ADD_ROUND(addr, PMD_SIZE);
                        if (last > end)
                                last = end;
                        if (pmd_newpage(*pmd)) {
                                updated = 1;
                                err = add_munmap(addr, last - addr, &hvc);
                                if (err < 0)
                                        panic("munmap failed, errno = %d\n",
                                              -err);
                        }
                        addr = last;
                        continue;
                }

                pte = pte_offset_kernel(pmd, addr);
                if (!pte_present(*pte) || pte_newpage(*pte)) {
                        updated = 1;
                        err = add_munmap(addr, PAGE_SIZE, &hvc);
                        if (err < 0)
                                panic("munmap failed, errno = %d\n",
                                      -err);
                        if (pte_present(*pte))
                                err = add_mmap(addr, pte_val(*pte) & PAGE_MASK,
                                               PAGE_SIZE, 0, &hvc);
                } else if (pte_newprot(*pte)) {
                        updated = 1;
                        err = add_mprotect(addr, PAGE_SIZE, 0, &hvc);
                }
                addr += PAGE_SIZE;
        }
        if (!err)
                err = do_ops(&hvc, hvc.index, 1);

        if (err < 0)
                panic("flush_tlb_kernel failed, errno = %d\n", err);
        return updated;
}

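/*
 * Sync a single page of a userspace address space with the host.  If the
 * page cannot be fixed up, the process is killed.
 */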
void flush_tlb_page(struct vm_area_struct *vma, unsigned long address)
{
        pgd_t *pgd;
        p4d_t *p4d;
        pud_t *pud;
        pmd_t *pmd;
        pte_t *pte;
        struct mm_struct *mm = vma->vm_mm;
        void *flush = NULL;
        int r, w, x, prot, err = 0;
        struct mm_id *mm_id;

        address &= PAGE_MASK;

        if (address >= STUB_START && address < STUB_END)
                goto kill;

        pgd = pgd_offset(mm, address);
        if (!pgd_present(*pgd))
                goto kill;

        p4d = p4d_offset(pgd, address);
        if (!p4d_present(*p4d))
                goto kill;

        pud = pud_offset(p4d, address);
        if (!pud_present(*pud))
                goto kill;

        pmd = pmd_offset(pud, address);
        if (!pmd_present(*pmd))
                goto kill;

        pte = pte_offset_kernel(pmd, address);

        r = pte_read(*pte);
        w = pte_write(*pte);
        x = pte_exec(*pte);
        if (!pte_young(*pte)) {
                r = 0;
                w = 0;
        } else if (!pte_dirty(*pte)) {
                w = 0;
        }

        mm_id = &mm->context.id;
        prot = ((r ? UM_PROT_READ : 0) | (w ? UM_PROT_WRITE : 0) |
                (x ? UM_PROT_EXEC : 0));
        if (pte_newpage(*pte)) {
                if (pte_present(*pte)) {
                        unsigned long long offset;
                        int fd;

                        fd = phys_mapping(pte_val(*pte) & PAGE_MASK, &offset);
                        err = map(mm_id, address, PAGE_SIZE, prot, fd, offset,
                                  1, &flush);
                } else
                        err = unmap(mm_id, address, PAGE_SIZE, 1, &flush);
        } else if (pte_newprot(*pte))
                err = protect(mm_id, address, PAGE_SIZE, prot, 1, &flush);

        if (err) {
                if (err == -ENOMEM)
                        report_enomem();

                goto kill;
        }

        *pte = pte_mkuptodate(*pte);

        return;

kill:
        printk(KERN_ERR "Failed to flush page for address 0x%lx\n", address);
        force_sig(SIGKILL);
}

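/*
 * The flush_tlb_*() entry points below implement the generic TLB flush API:
 * for UML, "flushing the TLB" means resyncing the host address space with
 * the page tables using the helpers above.
 */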
void flush_tlb_all(void)
{
        /*
         * Don't bother flushing if this address space is about to be
         * destroyed.
         */
        if (atomic_read(&current->mm->mm_users) == 0)
                return;

        flush_tlb_mm(current->mm);
}

void flush_tlb_kernel_range(unsigned long start, unsigned long end)
{
        flush_tlb_kernel_range_common(start, end);
}

void flush_tlb_kernel_vm(void)
{
        flush_tlb_kernel_range_common(start_vm, end_vm);
}

void __flush_tlb_one(unsigned long addr)
{
        flush_tlb_kernel_range_common(addr, addr + PAGE_SIZE);
}

static void fix_range(struct mm_struct *mm, unsigned long start_addr,
                      unsigned long end_addr, int force)
{
        /*
         * Don't bother flushing if this address space is about to be
         * destroyed.
         */
        if (atomic_read(&mm->mm_users) == 0)
                return;

        fix_range_common(mm, start_addr, end_addr, force);
}

void flush_tlb_range(struct vm_area_struct *vma, unsigned long start,
                     unsigned long end)
{
        if (vma->vm_mm == NULL)
                flush_tlb_kernel_range_common(start, end);
        else
                fix_range(vma->vm_mm, start, end, 0);
}
EXPORT_SYMBOL(flush_tlb_range);

void flush_tlb_mm_range(struct mm_struct *mm, unsigned long start,
                        unsigned long end)
{
        fix_range(mm, start, end, 0);
}

void flush_tlb_mm(struct mm_struct *mm)
{
        struct vm_area_struct *vma = mm->mmap;

        while (vma != NULL) {
                fix_range(mm, vma->vm_start, vma->vm_end, 0);
                vma = vma->vm_next;
        }
}

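/*
 * Force a full resync of the current process's address space with the
 * host, VMA by VMA.
 */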
void force_flush_all(void)
{
        struct mm_struct *mm = current->mm;
        struct vm_area_struct *vma = mm->mmap;

        while (vma != NULL) {
                fix_range(mm, vma->vm_start, vma->vm_end, 1);
                vma = vma->vm_next;
        }
}