GmCapsule [gsorg-style]

More memory debugging

bf996f25b9bdf6b074465de866a8d19c75f1918f
diff --git a/gmcapsule/gemini.py b/gmcapsule/gemini.py
index a930c73..43b98a9 100644
--- a/gmcapsule/gemini.py
+++ b/gmcapsule/gemini.py
@@ -2,6 +2,7 @@
 # License: BSD 2-Clause
 
 import fnmatch
+import gc
 import hashlib
 import queue
 import socket
@@ -18,6 +19,54 @@ def report_error(stream, code, msg):
     stream.sendall(f'{code} {msg}\r\n'.encode('utf-8'))
 
 
+memtrace_lock = threading.Lock()
+
+
+def display_memtop(snapshot, prev_snapshot, key_type='lineno', limit=10):
+    import tracemalloc
+    import linecache
+    filters = (
+        tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
+        tracemalloc.Filter(False, "<unknown>"),
+        tracemalloc.Filter(False, "*/linecache.py"),
+        tracemalloc.Filter(False, "*/tracemalloc.py")
+    )
+    snapshot = snapshot.filter_traces(filters)
+    if prev_snapshot:
+        prev_snapshot = prev_snapshot.filter_traces(filters)
+        top_stats = snapshot.compare_to(prev_snapshot, key_type)
+        top_type = 'delta'
+        limit = 200
+    else:
+        top_stats = snapshot.statistics(key_type)
+        top_type = 'malloc'
+
+    with memtrace_lock:
+        print("\n\nTop %s %s" % (limit, top_type))
+        for index, stat in enumerate(top_stats[:limit], 1):
+            frame = stat.traceback[0]
+            if prev_snapshot:
+                if stat.size_diff <= 0:
+                    continue
+                print("#%s: %35s:%-5s \x1b[1m%.1f\x1b[0m KiB (%+.1f KiB) count=%d (%+d)"
+                    % (index, frame.filename[-35:], str(frame.lineno) + ':',
+                    stat.size / 1024, stat.size_diff / 1024, stat.count, stat.count_diff))
+            else:
+                print("#%s: %35s:%-5s \x1b[1m%.1f\x1b[0m KiB count=%d"
+                    % (index, frame.filename[-35:], str(frame.lineno) + ':',
+                    stat.size / 1024, stat.count))
+            line = linecache.getline(frame.filename, frame.lineno).strip()
+            if line:
+                print('\x1b[2m    %s\x1b[0m' % line)
+
+        other = top_stats[limit:]
+        if other:
+            size = sum(stat.size for stat in other)
+            print("%s other: %.1f KiB" % (len(other), size / 1024))
+        total = sum(stat.size for stat in top_stats)
+        print("Total size: %.1f KiB\n\n" % (total / 1024))
+
+
 class Identity:
     def __init__(self, cert):
         self.cert = cert
@@ -81,7 +130,13 @@ class Worker(threading.Thread):
         self.jobs = server.work_queue
 
     def run(self):
+        #if self.server.memtrace:
+        #    import tracemalloc
+        #prev_malloc_snapshot = None
         while True:
+            #if self.server.memtrace:
+            #    malloc_snapshot = tracemalloc.take_snapshot()
+
             job = self.jobs.get()
             if job is None:
                 break
@@ -92,20 +147,28 @@ class Worker(threading.Thread):
             except OpenSSL.SSL.SysCallError as error:
                 self.log(f'OpenSSL error: ' + str(error))
             except Exception as error:
-                import traceback
-                traceback.print_exc()
+                #import traceback
+                #traceback.print_exc()
                 try:
                     report_error(stream, 42, str(error))
                 except:
                     pass
-            finally:
-                try:
-                    stream.shutdown()
-                    stream.close()
-                except:
-                    pass
+
+            try:
+                stream.shutdown()
+                stream.close()
+            except:
+                pass
             del from_addr
             del stream
+            gc.collect()
+
+            # if self.server.memtrace:
+            #     malloc_snapshot = tracemalloc.take_snapshot()
+            #     if prev_malloc_snapshot:
+            #         display_memtop(malloc_snapshot, prev_malloc_snapshot)
+            #     prev_malloc_snapshot = tracemalloc.take_snapshot()
+
 
     def log(self, *args):
         print(time.strftime('%Y-%m-%d %H:%M:%S'), f'[{self.id}]', '--', *args)
@@ -253,34 +316,6 @@ class Worker(threading.Thread):
             report_error(stream, 50, 'Permanent failure')
 
 
-def display_memtop(snapshot, key_type='lineno', limit=10):
-    import tracemalloc
-    import linecache
-    snapshot = snapshot.filter_traces((
-        tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
-        tracemalloc.Filter(False, "<unknown>"),
-        tracemalloc.Filter(False, "*/linecache.py"),
-        tracemalloc.Filter(False, "*/tracemalloc.py")
-    ))
-    top_stats = snapshot.statistics(key_type)
-
-    print("\n\nTop %s lines" % limit)
-    for index, stat in enumerate(top_stats[:limit], 1):
-        frame = stat.traceback[0]
-        print("#%s: %35s:%-5s \x1b[1m%.1f\x1b[0m KiB"
-              % (index, frame.filename[-35:], str(frame.lineno) + ':', stat.size / 1024))
-        line = linecache.getline(frame.filename, frame.lineno).strip()
-        if line:
-            print('\x1b[2m    %s\x1b[0m' % line)
-
-    other = top_stats[limit:]
-    if other:
-        size = sum(stat.size for stat in other)
-        print("%s other: %.1f KiB" % (len(other), size / 1024))
-    total = sum(stat.size for stat in top_stats)
-    print("Total allocated size: %.1f KiB\n\n" % (total / 1024))
-
-
 class Server:
     def __init__(self, hostname_or_hostnames, cert_path, key_path,
                  address='localhost', port=1965,
@@ -331,6 +366,10 @@ class Server:
             self.add_entrypoint('gemini', hostname, key, value)
 
     def run(self, memtrace=False):
+        self.memtrace = memtrace
+        if self.memtrace:
+            import tracemalloc
+            tracemalloc.start(10)
         attempts = 60
         print(f'Opening port {self.port}...')
         while True:
@@ -353,9 +392,8 @@ class Server:
             worker.start()
         print(len(self.workers), 'worker(s) started')
 
-        if memtrace:
-            import tracemalloc
-            tracemalloc.start()
+        snapshot = None
+
         while True:
             try:
                 stream = None
@@ -368,14 +406,30 @@ class Server:
                     print('\nStopping the server...')
                     break
                 except Exception as ex:
-                    import traceback
-                    traceback.print_exc()
+                    #import traceback
+                    #traceback.print_exc()
                     print(ex)
+                    #del traceback
             except Exception as ex:
                 print(ex)
 
-            if memtrace:
-                display_memtop(tracemalloc.take_snapshot())
+            if self.memtrace:
+                time.sleep(2)
+                old_snapshot = snapshot
+                gc.collect()
+                snapshot = tracemalloc.take_snapshot()
+                filters = (
+                    tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
+                    tracemalloc.Filter(False, "<unknown>"),
+                    tracemalloc.Filter(False, "*/linecache.py"),
+                    tracemalloc.Filter(False, "*/tracemalloc.py"),
+                    tracemalloc.Filter(False, "*/mimetypes.py"),
+                    tracemalloc.Filter(False, "*/fnmatch.py")
+                )
+                snapshot = snapshot.filter_traces(filters)
+                top_stats = snapshot.statistics('lineno')
+                display_memtop(snapshot, old_snapshot)
+
 
         # Close the server socket.
         self.sv_conn = None