GNU Linux-libre 5.4.274-gnu1
arch/um/kernel/tlb.c
// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2000 - 2007 Jeff Dike (jdike@{addtoit,linux.intel}.com)
 */

#include <linux/mm.h>
#include <linux/module.h>
#include <linux/sched/signal.h>

#include <asm/pgtable.h>
#include <asm/tlbflush.h>
#include <as-layout.h>
#include <mem_user.h>
#include <os.h>
#include <skas.h>
#include <kern_util.h>

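/*
 * A host_vm_change batches the mmap/munmap/mprotect operations generated
 * while walking the page tables, so that do_ops() can issue them to the
 * host in one go.  'userspace' selects whether they are applied to the
 * process's host address space or to the kernel's.
 */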
struct host_vm_change {
        struct host_vm_op {
                enum { NONE, MMAP, MUNMAP, MPROTECT } type;
                union {
                        struct {
                                unsigned long addr;
                                unsigned long len;
                                unsigned int prot;
                                int fd;
                                __u64 offset;
                        } mmap;
                        struct {
                                unsigned long addr;
                                unsigned long len;
                        } munmap;
                        struct {
                                unsigned long addr;
                                unsigned long len;
                                unsigned int prot;
                        } mprotect;
                } u;
        } ops[1];
        int userspace;
        int index;
        struct mm_struct *mm;
        void *data;
        int force;
};

#define INIT_HVC(mm, force, userspace) \
        ((struct host_vm_change) \
         { .ops         = { { .type = NONE } }, \
           .mm          = mm, \
           .data        = NULL, \
           .userspace   = userspace, \
           .index       = 0, \
           .force       = force })

static void report_enomem(void)
{
        printk(KERN_ERR "UML ran out of memory on the host side! "
                        "This can happen due to a memory limitation or "
                        "because vm.max_map_count has been reached.\n");
}

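/*
 * Issue the first 'end' queued operations.  For a userspace address
 * space they go through the mm's host context id; for the kernel
 * address space they are applied directly via map_memory(),
 * os_unmap_memory() and os_protect_memory().
 */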
static int do_ops(struct host_vm_change *hvc, int end,
                  int finished)
{
        struct host_vm_op *op;
        int i, ret = 0;

        for (i = 0; i < end && !ret; i++) {
                op = &hvc->ops[i];
                switch (op->type) {
                case MMAP:
                        if (hvc->userspace)
                                ret = map(&hvc->mm->context.id, op->u.mmap.addr,
                                          op->u.mmap.len, op->u.mmap.prot,
                                          op->u.mmap.fd,
                                          op->u.mmap.offset, finished,
                                          &hvc->data);
                        else
                                map_memory(op->u.mmap.addr, op->u.mmap.offset,
                                           op->u.mmap.len, 1, 1, 1);
                        break;
                case MUNMAP:
                        if (hvc->userspace)
                                ret = unmap(&hvc->mm->context.id,
                                            op->u.munmap.addr,
                                            op->u.munmap.len, finished,
                                            &hvc->data);
                        else
                                ret = os_unmap_memory(
                                        (void *) op->u.munmap.addr,
                                        op->u.munmap.len);

                        break;
                case MPROTECT:
                        if (hvc->userspace)
                                ret = protect(&hvc->mm->context.id,
                                              op->u.mprotect.addr,
                                              op->u.mprotect.len,
                                              op->u.mprotect.prot,
                                              finished, &hvc->data);
                        else
                                ret = os_protect_memory(
                                        (void *) op->u.mprotect.addr,
                                        op->u.mprotect.len,
                                        1, 1, 1);
                        break;
                default:
                        printk(KERN_ERR "Unknown op type %d in do_ops\n",
                               op->type);
                        BUG();
                        break;
                }
        }

        if (ret == -ENOMEM)
                report_enomem();

        return ret;
}

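/*
 * Queue an mmap of the physical memory backing [virt, virt + len) with
 * host protection 'prot'.  The request is merged into the previous op
 * when it is contiguous and compatible; a full op array is flushed
 * through do_ops() first.
 */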
static int add_mmap(unsigned long virt, unsigned long phys, unsigned long len,
                    unsigned int prot, struct host_vm_change *hvc)
{
        __u64 offset;
        struct host_vm_op *last;
        int fd = -1, ret = 0;

        if (virt + len > STUB_START && virt < STUB_END)
                return -EINVAL;

        if (hvc->userspace)
                fd = phys_mapping(phys, &offset);
        else
                offset = phys;
        if (hvc->index != 0) {
                last = &hvc->ops[hvc->index - 1];
                if ((last->type == MMAP) &&
                   (last->u.mmap.addr + last->u.mmap.len == virt) &&
                   (last->u.mmap.prot == prot) && (last->u.mmap.fd == fd) &&
                   (last->u.mmap.offset + last->u.mmap.len == offset)) {
                        last->u.mmap.len += len;
                        return 0;
                }
        }

        if (hvc->index == ARRAY_SIZE(hvc->ops)) {
                ret = do_ops(hvc, ARRAY_SIZE(hvc->ops), 0);
                hvc->index = 0;
        }

        hvc->ops[hvc->index++] = ((struct host_vm_op)
                                  { .type       = MMAP,
                                    .u = { .mmap = { .addr      = virt,
                                                     .len       = len,
                                                     .prot      = prot,
                                                     .fd        = fd,
                                                     .offset    = offset }
                                         } });
        return ret;
}

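/*
 * Queue an munmap of [addr, addr + len), merging with the previous op
 * when that is a contiguous munmap.
 */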
static int add_munmap(unsigned long addr, unsigned long len,
                      struct host_vm_change *hvc)
{
        struct host_vm_op *last;
        int ret = 0;

        if (addr + len > STUB_START && addr < STUB_END)
                return -EINVAL;

        if (hvc->index != 0) {
                last = &hvc->ops[hvc->index - 1];
                if ((last->type == MUNMAP) &&
                   (last->u.munmap.addr + last->u.munmap.len == addr)) {
                        last->u.munmap.len += len;
                        return 0;
                }
        }

        if (hvc->index == ARRAY_SIZE(hvc->ops)) {
                ret = do_ops(hvc, ARRAY_SIZE(hvc->ops), 0);
                hvc->index = 0;
        }

        hvc->ops[hvc->index++] = ((struct host_vm_op)
                                  { .type       = MUNMAP,
                                    .u = { .munmap = { .addr    = addr,
                                                       .len     = len } } });
        return ret;
}

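/*
 * Queue an mprotect of [addr, addr + len) to 'prot', merging with the
 * previous op when that is a contiguous mprotect with the same
 * protection.
 */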
static int add_mprotect(unsigned long addr, unsigned long len,
                        unsigned int prot, struct host_vm_change *hvc)
{
        struct host_vm_op *last;
        int ret = 0;

        if (addr + len > STUB_START && addr < STUB_END)
                return -EINVAL;

        if (hvc->index != 0) {
                last = &hvc->ops[hvc->index - 1];
                if ((last->type == MPROTECT) &&
                   (last->u.mprotect.addr + last->u.mprotect.len == addr) &&
                   (last->u.mprotect.prot == prot)) {
                        last->u.mprotect.len += len;
                        return 0;
                }
        }

        if (hvc->index == ARRAY_SIZE(hvc->ops)) {
                ret = do_ops(hvc, ARRAY_SIZE(hvc->ops), 0);
                hvc->index = 0;
        }

        hvc->ops[hvc->index++] = ((struct host_vm_op)
                                  { .type       = MPROTECT,
                                    .u = { .mprotect = { .addr  = addr,
                                                         .len   = len,
                                                         .prot  = prot } } });
        return ret;
}

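/* Next 'inc'-aligned boundary above 'n'; 'inc' must be a power of two. */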
#define ADD_ROUND(n, inc) (((n) + (inc)) & ~((inc) - 1))

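/*
 * Walk the PTEs covering [addr, end), skipping the stub pages.  The host
 * protection is derived from the read/write/exec bits, downgraded for
 * pages that are not young or not dirty, and mmap/munmap/mprotect ops
 * are queued for entries marked as needing a new page or new protection.
 */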
static inline int update_pte_range(pmd_t *pmd, unsigned long addr,
                                   unsigned long end,
                                   struct host_vm_change *hvc)
{
        pte_t *pte;
        int r, w, x, prot, ret = 0;

        pte = pte_offset_kernel(pmd, addr);
        do {
                if ((addr >= STUB_START) && (addr < STUB_END))
                        continue;

                r = pte_read(*pte);
                w = pte_write(*pte);
                x = pte_exec(*pte);
                if (!pte_young(*pte)) {
                        r = 0;
                        w = 0;
                } else if (!pte_dirty(*pte))
                        w = 0;

                prot = ((r ? UM_PROT_READ : 0) | (w ? UM_PROT_WRITE : 0) |
                        (x ? UM_PROT_EXEC : 0));
                if (hvc->force || pte_newpage(*pte)) {
                        if (pte_present(*pte)) {
                                if (pte_newpage(*pte))
                                        ret = add_mmap(addr, pte_val(*pte) & PAGE_MASK,
                                                       PAGE_SIZE, prot, hvc);
                        } else
                                ret = add_munmap(addr, PAGE_SIZE, hvc);
                } else if (pte_newprot(*pte))
                        ret = add_mprotect(addr, PAGE_SIZE, prot, hvc);
                *pte = pte_mkuptodate(*pte);
        } while (pte++, addr += PAGE_SIZE, ((addr < end) && !ret));
        return ret;
}

static inline int update_pmd_range(pud_t *pud, unsigned long addr,
                                   unsigned long end,
                                   struct host_vm_change *hvc)
{
        pmd_t *pmd;
        unsigned long next;
        int ret = 0;

        pmd = pmd_offset(pud, addr);
        do {
                next = pmd_addr_end(addr, end);
                if (!pmd_present(*pmd)) {
                        if (hvc->force || pmd_newpage(*pmd)) {
                                ret = add_munmap(addr, next - addr, hvc);
                                pmd_mkuptodate(*pmd);
                        }
                }
                else ret = update_pte_range(pmd, addr, next, hvc);
        } while (pmd++, addr = next, ((addr < end) && !ret));
        return ret;
}

static inline int update_pud_range(pgd_t *pgd, unsigned long addr,
                                   unsigned long end,
                                   struct host_vm_change *hvc)
{
        pud_t *pud;
        unsigned long next;
        int ret = 0;

        pud = pud_offset(pgd, addr);
        do {
                next = pud_addr_end(addr, end);
                if (!pud_present(*pud)) {
                        if (hvc->force || pud_newpage(*pud)) {
                                ret = add_munmap(addr, next - addr, hvc);
                                pud_mkuptodate(*pud);
                        }
                }
                else ret = update_pmd_range(pud, addr, next, hvc);
        } while (pud++, addr = next, ((addr < end) && !ret));
        return ret;
}

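/*
 * Bring the host mappings of [start_addr, end_addr) in 'mm' back in sync
 * with the page tables ('force' makes the walk act on entries even when
 * they are not flagged as new).  If updating the host fails, the current
 * process is killed, since its host address space is no longer
 * consistent.
 */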
void fix_range_common(struct mm_struct *mm, unsigned long start_addr,
                      unsigned long end_addr, int force)
{
        pgd_t *pgd;
        struct host_vm_change hvc;
        unsigned long addr = start_addr, next;
        int ret = 0, userspace = 1;

        hvc = INIT_HVC(mm, force, userspace);
        pgd = pgd_offset(mm, addr);
        do {
                next = pgd_addr_end(addr, end_addr);
                if (!pgd_present(*pgd)) {
                        if (force || pgd_newpage(*pgd)) {
                                ret = add_munmap(addr, next - addr, &hvc);
                                pgd_mkuptodate(*pgd);
                        }
                }
                else ret = update_pud_range(pgd, addr, next, &hvc);
        } while (pgd++, addr = next, ((addr < end_addr) && !ret));

        if (!ret)
                ret = do_ops(&hvc, hvc.index, 1);

        /* This is not an else because ret is modified above */
        if (ret) {
                printk(KERN_ERR "fix_range_common: failed, killing current "
                       "process: %d\n", task_tgid_vnr(current));
                /* We are under mmap_sem, release it such that current can terminate */
                up_write(&current->mm->mmap_sem);
                force_sig(SIGKILL);
                do_signal(&current->thread.regs);
        }
}

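/*
 * Sync the host's kernel mappings for [start, end) against init_mm.
 * Ranges whose page-table level is not present are handled a whole
 * table-sized step at a time.  Returns nonzero if any mapping was
 * updated.
 */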
static int flush_tlb_kernel_range_common(unsigned long start, unsigned long end)
{
        struct mm_struct *mm;
        pgd_t *pgd;
        pud_t *pud;
        pmd_t *pmd;
        pte_t *pte;
        unsigned long addr, last;
        int updated = 0, err = 0, force = 0, userspace = 0;
        struct host_vm_change hvc;

        mm = &init_mm;
        hvc = INIT_HVC(mm, force, userspace);
        for (addr = start; addr < end;) {
                pgd = pgd_offset(mm, addr);
                if (!pgd_present(*pgd)) {
                        last = ADD_ROUND(addr, PGDIR_SIZE);
                        if (last > end)
                                last = end;
                        if (pgd_newpage(*pgd)) {
                                updated = 1;
                                err = add_munmap(addr, last - addr, &hvc);
                                if (err < 0)
                                        panic("munmap failed, errno = %d\n",
                                              -err);
                        }
                        addr = last;
                        continue;
                }

                pud = pud_offset(pgd, addr);
                if (!pud_present(*pud)) {
                        last = ADD_ROUND(addr, PUD_SIZE);
                        if (last > end)
                                last = end;
                        if (pud_newpage(*pud)) {
                                updated = 1;
                                err = add_munmap(addr, last - addr, &hvc);
                                if (err < 0)
                                        panic("munmap failed, errno = %d\n",
                                              -err);
                        }
                        addr = last;
                        continue;
                }

                pmd = pmd_offset(pud, addr);
                if (!pmd_present(*pmd)) {
                        last = ADD_ROUND(addr, PMD_SIZE);
                        if (last > end)
                                last = end;
                        if (pmd_newpage(*pmd)) {
                                updated = 1;
                                err = add_munmap(addr, last - addr, &hvc);
                                if (err < 0)
                                        panic("munmap failed, errno = %d\n",
                                              -err);
                        }
                        addr = last;
                        continue;
                }

                pte = pte_offset_kernel(pmd, addr);
                if (!pte_present(*pte) || pte_newpage(*pte)) {
                        updated = 1;
                        err = add_munmap(addr, PAGE_SIZE, &hvc);
                        if (err < 0)
                                panic("munmap failed, errno = %d\n",
                                      -err);
                        if (pte_present(*pte))
                                err = add_mmap(addr, pte_val(*pte) & PAGE_MASK,
                                               PAGE_SIZE, 0, &hvc);
                }
                else if (pte_newprot(*pte)) {
                        updated = 1;
                        err = add_mprotect(addr, PAGE_SIZE, 0, &hvc);
                }
                addr += PAGE_SIZE;
        }
        if (!err)
                err = do_ops(&hvc, hvc.index, 1);

        if (err < 0)
                panic("flush_tlb_kernel failed, errno = %d\n", err);
        return updated;
}

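/*
 * Update the host mapping of the single page containing 'address' in the
 * VMA's address space.  If the page tables cannot be walked, or the host
 * update fails, the process is killed.
 */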
void flush_tlb_page(struct vm_area_struct *vma, unsigned long address)
{
        pgd_t *pgd;
        pud_t *pud;
        pmd_t *pmd;
        pte_t *pte;
        struct mm_struct *mm = vma->vm_mm;
        void *flush = NULL;
        int r, w, x, prot, err = 0;
        struct mm_id *mm_id;

        address &= PAGE_MASK;

        if (address >= STUB_START && address < STUB_END)
                goto kill;

        pgd = pgd_offset(mm, address);
        if (!pgd_present(*pgd))
                goto kill;

        pud = pud_offset(pgd, address);
        if (!pud_present(*pud))
                goto kill;

        pmd = pmd_offset(pud, address);
        if (!pmd_present(*pmd))
                goto kill;

        pte = pte_offset_kernel(pmd, address);

        r = pte_read(*pte);
        w = pte_write(*pte);
        x = pte_exec(*pte);
        if (!pte_young(*pte)) {
                r = 0;
                w = 0;
        } else if (!pte_dirty(*pte)) {
                w = 0;
        }

        mm_id = &mm->context.id;
        prot = ((r ? UM_PROT_READ : 0) | (w ? UM_PROT_WRITE : 0) |
                (x ? UM_PROT_EXEC : 0));
        if (pte_newpage(*pte)) {
                if (pte_present(*pte)) {
                        unsigned long long offset;
                        int fd;

                        fd = phys_mapping(pte_val(*pte) & PAGE_MASK, &offset);
                        err = map(mm_id, address, PAGE_SIZE, prot, fd, offset,
                                  1, &flush);
                }
                else err = unmap(mm_id, address, PAGE_SIZE, 1, &flush);
        }
        else if (pte_newprot(*pte))
                err = protect(mm_id, address, PAGE_SIZE, prot, 1, &flush);

        if (err) {
                if (err == -ENOMEM)
                        report_enomem();

                goto kill;
        }

        *pte = pte_mkuptodate(*pte);

        return;

kill:
        printk(KERN_ERR "Failed to flush page for address 0x%lx\n", address);
        force_sig(SIGKILL);
}

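/*
 * Thin, non-static wrappers around the page-table walking macros, for
 * callers elsewhere in the UML code.
 */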
pgd_t *pgd_offset_proc(struct mm_struct *mm, unsigned long address)
{
        return pgd_offset(mm, address);
}

pud_t *pud_offset_proc(pgd_t *pgd, unsigned long address)
{
        return pud_offset(pgd, address);
}

pmd_t *pmd_offset_proc(pud_t *pud, unsigned long address)
{
        return pmd_offset(pud, address);
}

pte_t *pte_offset_proc(pmd_t *pmd, unsigned long address)
{
        return pte_offset_kernel(pmd, address);
}

pte_t *addr_pte(struct task_struct *task, unsigned long addr)
{
        pgd_t *pgd = pgd_offset(task->mm, addr);
        pud_t *pud = pud_offset(pgd, addr);
        pmd_t *pmd = pmd_offset(pud, addr);

        return pte_offset_map(pmd, addr);
}

void flush_tlb_all(void)
{
        /*
         * Don't bother flushing if this address space is about to be
         * destroyed.
         */
        if (atomic_read(&current->mm->mm_users) == 0)
                return;

        flush_tlb_mm(current->mm);
}

void flush_tlb_kernel_range(unsigned long start, unsigned long end)
{
        flush_tlb_kernel_range_common(start, end);
}

void flush_tlb_kernel_vm(void)
{
        flush_tlb_kernel_range_common(start_vm, end_vm);
}

void __flush_tlb_one(unsigned long addr)
{
        flush_tlb_kernel_range_common(addr, addr + PAGE_SIZE);
}

static void fix_range(struct mm_struct *mm, unsigned long start_addr,
                      unsigned long end_addr, int force)
{
        /*
         * Don't bother flushing if this address space is about to be
         * destroyed.
         */
        if (atomic_read(&mm->mm_users) == 0)
                return;

        fix_range_common(mm, start_addr, end_addr, force);
}

void flush_tlb_range(struct vm_area_struct *vma, unsigned long start,
                     unsigned long end)
{
        if (vma->vm_mm == NULL)
                flush_tlb_kernel_range_common(start, end);
        else fix_range(vma->vm_mm, start, end, 0);
}
EXPORT_SYMBOL(flush_tlb_range);

void flush_tlb_mm_range(struct mm_struct *mm, unsigned long start,
                        unsigned long end)
{
        fix_range(mm, start, end, 0);
}

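/* Resync the host mappings for every VMA in 'mm'. */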
void flush_tlb_mm(struct mm_struct *mm)
{
        struct vm_area_struct *vma = mm->mmap;

        while (vma != NULL) {
                fix_range(mm, vma->vm_start, vma->vm_end, 0);
                vma = vma->vm_next;
        }
}

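/* Forcibly resync the host mappings for every VMA of the current process. */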
void force_flush_all(void)
{
        struct mm_struct *mm = current->mm;
        struct vm_area_struct *vma = mm->mmap;

        while (vma != NULL) {
                fix_range(mm, vma->vm_start, vma->vm_end, 1);
                vma = vma->vm_next;
        }
}