diff --git a/predtuner/torchutil/summary.py b/predtuner/torchutil/summary.py
index 7491fa3bf563274df7b5737c1fdd71101b66f9d7..25d346f98f19f64e61cda3346b0c467cb1252b84 100644
--- a/predtuner/torchutil/summary.py
+++ b/predtuner/torchutil/summary.py
@@ -1,66 +1,109 @@
 from collections import OrderedDict
-from typing import Tuple
+from collections.abc import Iterable
+from typing import Tuple
 
 import pandas
 import torch
 import torch.nn as nn
-
 from .indexing import ModuleIndexer
 
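+# Flipped to True once a summary has been built; see _print_once below.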
+_summary_used = False
+
 
 def get_flops(module: nn.Module, input_shape, output_shape):
-    if output_shape is None:
-        return None
-    n_elem = torch.prod(torch.tensor(output_shape)).item()
-    if isinstance(module, nn.Linear):
-        if input_shape is None:
-            return None
-        _, n = input_shape
+    # Partially following impl here:
+    # https://github.com/juliagusak/flopco-pytorch/blob/c9679785d802f4984c9c5e5d47958e3b82044ce9/flopco/compute_layer_flops.py
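+    # Imported inside the function so torchvision is only required when FLOPs are computed.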
+    from torchvision.models.detection.transform import GeneralizedRCNNTransform
+
+    def linear_flops():
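+        # (m, n) input times (k, n) weight: m * n * k multiply-accumulates.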
+        m, n = input_shape
         k, n_ = module.weight.shape
         assert n == n_
-        return n * n * k
-    if isinstance(module, nn.Conv2d):
+        return m * n * k
+
+    def conv2d_flops():
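+        # Each output spatial position applies every filter once: numel(weight) MACs.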
         _, _, h, w = output_shape
         return module.weight.numel() * h * w
-    if isinstance(module, nn.BatchNorm2d):
-        return 6 * n_elem
-    return None
+
+    def pool2d_flops():
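+        # One read per kernel element for every output element.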
+        ksize = module.kernel_size
+        if isinstance(ksize, int):
+            ksize = ksize, ksize
+        k_area = ksize[0] * ksize[1]
+        return k_area * _get_numel(output_shape)
+
+    def ntimes_input_numel(n: int):
+        return lambda: n * _get_numel(input_shape)
+
+    def ntimes_output_numel(n: int):
+        return lambda: n * _get_numel(output_shape)
+
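+    # Dispatch on the exact module type; unknown leaf modules are reported
+    # once and yield no FLOP estimate.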
+    type_dispatch = {
+        nn.Linear: linear_flops,
+        nn.Conv2d: conv2d_flops,
+        nn.BatchNorm2d: ntimes_output_numel(6),
+        nn.ReLU: ntimes_output_numel(1),
+        nn.AvgPool2d: pool2d_flops,
+        nn.MaxPool2d: pool2d_flops,
+        # Resizing likely costs more than 2 ops per input element; 2x numel is a rough estimate.
+        GeneralizedRCNNTransform: ntimes_input_numel(2),
+    }
+    handler = type_dispatch.get(type(module))
+    if handler is None:
+        if not list(module.children()):
+            _print_once(f"Leaf module {module} cannot be handled")
+        return None
+    try:
+        return handler()
+    except (RuntimeError, TypeError, ValueError) as e:
+        _print_once(f'Error "{e}" when handling {module}')
+        return None
 
 
 def get_summary(model: nn.Module, model_args: Tuple) -> pandas.DataFrame:
+    from torchvision.ops.feature_pyramid_network import LastLevelMaxPool
+
     include = lambda m: (
-            not isinstance(m, nn.Sequential) and not isinstance(m, nn.ModuleList) and not (m == model)
+        not isinstance(m, nn.Sequential)
+        and not isinstance(m, nn.ModuleList)
+        and not (m == model)
     )
     indexed = ModuleIndexer(model, include, lambda m: True)
     find_by_module = lambda m: indexed.find_by_module(m)[0]
     summary = OrderedDict()
     hooks = []
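+    # Modules whose forward signature takes/returns lists of feature maps
+    # need dedicated shape/FLOP extraction instead of default_io.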
+    special_ops = {LastLevelMaxPool: last_level_max_pool_io}
 
     def hook(module: nn.Module, inputs, outputs):
         module_name = find_by_module(module)
-
-        try:
-            input_shape = list(inputs[0].size())
-        except AttributeError:
-            input_shape = None
-        try:
-            if isinstance(outputs, (list, tuple)):
-                output_shape = [[-1] + list(o.size())[1:] for o in outputs]
-            else:
-                output_shape = list(outputs.size())
-        except AttributeError:
-            output_shape = None
+        special_handler = special_ops.get(type(module))
+        if special_handler:
+            input_shape, output_shape, flops = special_handler(module, inputs, outputs)
+        else:
+            input_shape, output_shape, flops = default_io(module, inputs, outputs)
 
         n_params = sum(param.numel() for param in module.parameters())
         trainable = any(param.requires_grad for param in module.parameters())
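+        # A leaf module has no children; containers also get a summary row.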
+        is_leaf = not list(module.children())
 
         summary[module_name] = OrderedDict(
             type=module.__class__.__name__,
             input_shape=input_shape,
             output_shape=output_shape,
             params=n_params,
-            flops=get_flops(module, input_shape, output_shape),
-            trainable=trainable
+            flops=flops,
+            trainable=trainable,
+            is_leaf=is_leaf
         )
 
     def register_hook(module: nn.Module):
@@ -74,4 +117,53 @@ def get_summary(model: nn.Module, model_args: Tuple) -> pandas.DataFrame:
     # remove these hooks
     for h in hooks:
         h.remove()
-    return pandas.DataFrame(summary)
+    global _summary_used
+    _summary_used = True  # Prevent further error printing
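+    # Transpose so each row is a module and each column an attribute.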
+    return pandas.DataFrame(summary).T
+
+
+def last_level_max_pool_io(_, inputs, outputs):
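+    # LastLevelMaxPool maps a list of FPN feature maps to a list of feature
+    # maps; record per-level shapes and charge ~1 op per input element.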
+    input_shapes = [list(i.size()) for i in inputs[0]]
+    output_shapes = [list(o.size()) for o in outputs[0]]
+    total_numel = sum(_get_numel(s) for s in input_shapes)
+    return input_shapes, output_shapes, total_numel
+
+
+def default_handle_sizes(value):
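+    # Normalize a tensor, a dict of tensors, or an iterable of tensors into
+    # plain shape lists; anything else is reported once and yields None.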
+    try:
+        if isinstance(value, torch.Tensor):
+            return list(value.size())
+        if isinstance(value, dict):
+            return {k: list(v.size()) for k, v in value.items()}
+        if isinstance(value, Iterable):
+            return [list(i.size()) for i in value]
+    except AttributeError as e:
+        _print_once(f"Cannot handle {type(value)}: error {e}")
+        return None
+    _print_once(f"Cannot handle {type(value)}")
+    return None
+
+
+def default_io(module: nn.Module, inputs, outputs):
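+    # Forward hooks receive the positional args as a tuple; assume the first
+    # one is the module's primary input.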
+    input_shape = default_handle_sizes(inputs[0])
+    output_shape = default_handle_sizes(outputs)
+    return input_shape, output_shape, get_flops(module, input_shape, output_shape)
+
+
+def _get_numel(shape):
+    return torch.prod(torch.tensor(shape)).item()
+
+
+def _print_once(*args, **kwargs):
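+    # Prints freely until the first summary completes; _summary_used then
+    # silences all later diagnostics.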
+    if _summary_used:
+        return
+    print(*args, **kwargs)