Overview

The base module system provides the foundation for all modules in BallonTranslator, including text detectors, OCR engines, inpainters, and translators. It handles module registration, parameter management, model loading, and device management.

Module Registry

BallonTranslator uses a registry pattern to manage different module types. Each module type has its own registry:
from modules.base import init_module_registries

# Initialize all module registries
init_module_registries()

# Initialize specific module types
from modules.base import (
    init_textdetector_registries,
    init_ocr_registries,
    init_inpainter_registries,
    init_translator_registries
)

Module Scripts Configuration

Module discovery is configured through the MODULE_SCRIPTS dictionary:
MODULE_SCRIPTS = {
    'translator': {'module_dir': 'modules/translators', 'module_pattern': r'trans_(.*?).py'},
    'textdetector': {'module_dir': 'modules/textdetector', 'module_pattern': r'detector_(.*?).py'},
    'inpainter': {'module_dir': 'modules/inpaint', 'module_pattern': r'inpaint_(.*?).py'},
    'ocr': {'module_dir': 'modules/ocr', 'module_pattern': r'ocr_(.*?).py'},
}
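As a sketch of how discovery can use these entries (the real loader in BallonTranslator may differ), a filename is matched against `module_pattern` and the module name is taken from the capture group:

```python
import re

# Illustrative helper, not BallonTranslator's actual loader: match a
# filename against a MODULE_SCRIPTS 'module_pattern' and extract the
# module name from the first capture group.
def match_module_name(filename: str, module_pattern: str):
    """Return the captured module name if filename matches, else None."""
    m = re.match(module_pattern, filename)
    return m.group(1) if m else None

# A translator script named trans_google.py registers as 'google'.
print(match_module_name('trans_google.py', r'trans_(.*?).py'))  # google
```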

BaseModule Class

All modules inherit from the BaseModule class, which provides common functionality.

Class Definition

from modules.base import BaseModule, DEVICE_SELECTOR

class MyModule(BaseModule):
    params = {
        'device': DEVICE_SELECTOR(),
        'my_param': {
            'type': 'selector',
            'options': [1, 2, 3],
            'value': 1
        }
    }
    
    _load_model_keys = {'model'}  # Model attributes to track
    
    def __init__(self, **params):
        super().__init__(**params)
        self.model = None
    
    def _load_model(self):
        # Load your model here
        pass

Class Attributes

params
Dict
Module parameters configuration. Each parameter can be a simple value or a dictionary with metadata:
params = {
    'simple_param': 1.0,
    'complex_param': {
        'type': 'selector',  # UI widget type
        'options': [1, 2, 3],
        'value': 1,
        'data_type': int,
        'description': 'Parameter description'
    }
}
_load_model_keys
set
Set of attribute names that hold loaded models. Used for automatic model lifecycle management.
_preprocess_hooks
OrderedDict
Preprocessing hooks that run before module execution. Shared across all instances of the class.
_postprocess_hooks
OrderedDict
Postprocessing hooks that run after module execution. Shared across all instances of the class.
download_file_list
List[Dict]
List of files to download for the module. Each entry specifies:
download_file_list = [{
    'url': 'https://example.com/model.ckpt',
    'sha256_pre_calculated': 'abc123...',
    'files': 'data/models/model.ckpt',
    'save_dir': 'data/models',
}]
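The `sha256_pre_calculated` field presumably lets the downloader verify file integrity. A minimal sketch of that check (illustrative, not BallonTranslator's actual download code):

```python
import hashlib

# Illustrative integrity check: compare a payload's SHA-256 digest
# against the expected hex string from download_file_list.
def sha256_matches(data: bytes, expected_hex: str) -> bool:
    return hashlib.sha256(data).hexdigest() == expected_hex

payload = b'model weights'
digest = hashlib.sha256(payload).hexdigest()
print(sha256_matches(payload, digest))  # True
```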

Methods

Parameter Management

get_param_value
method
Get the current value of a parameter.
Parameters:
  • param_key (str): Parameter name
Returns: Parameter value
device = self.get_param_value('device')
set_param_value
method
Set a parameter value with optional type conversion.
Parameters:
  • param_key (str): Parameter name
  • param_value: New value
  • convert_dtype (bool): Whether to convert to expected type (default: True)
self.set_param_value('device', 'cuda')
updateParam
method
Update a parameter value. Override this to react to parameter changes.
Parameters:
  • param_key (str): Parameter name
  • param_content: New value
def updateParam(self, param_key: str, param_content):
    super().updateParam(param_key, param_content)
    if param_key == 'device':
        self.model.to(self.get_param_value('device'))

Model Lifecycle

load_model
method
Load the module’s model. Acquires a loading lock to prevent concurrent loads.
self.load_model()
_load_model
method
Internal method to load the model. Override this in your module.
def _load_model(self):
    self.model = load_my_model()
unload_model
method
Unload the module’s model to free memory.
Parameters:
  • empty_cache (bool): Whether to empty GPU cache after unloading
Returns: bool - Whether any model was deleted
self.unload_model(empty_cache=True)
all_model_loaded
method
Check if all required models are loaded.
Returns: bool
if not self.all_model_loaded():
    self.load_model()

Hook Management

register_preprocess_hooks
classmethod
Register preprocessing hooks (shared across all instances).
Parameters:
  • callbacks (Union[List, Callable, Dict]): Hook functions
def my_hook(**kwargs):
    print("Preprocessing...")

MyModule.register_preprocess_hooks(my_hook)
register_postprocess_hooks
classmethod
Register postprocessing hooks (shared across all instances).
Parameters:
  • callbacks (Union[List, Callable, Dict]): Hook functions
MyModule.register_postprocess_hooks(my_hook)

Properties

low_vram_mode
bool
Whether low VRAM mode is enabled (from params).
debug_mode
bool
Whether debug mode is enabled (from shared config).

Device Detection

is_cpu_intensive
method
Check if module runs on CPU.
Returns: bool
is_gpu_intensive
method
Check if module runs on GPU (cuda/mps/xpu/privateuseone).
Returns: bool
is_computational_intensive
method
Check if module has device parameter (implies computational load).
Returns: bool

Device Management

Available Devices

from modules.base import DEFAULT_DEVICE, AVAILABLE_DEVICES, DEVICE_SELECTOR

print(f"Default device: {DEFAULT_DEVICE}")  # e.g., 'cuda'
print(f"Available: {AVAILABLE_DEVICES}")     # e.g., ['cpu', 'cuda']

# Use in module params
params = {
    'device': DEVICE_SELECTOR()  # Creates device selector with available devices
}

Device Constants

DEFAULT_DEVICE
str
Auto-detected default device (cuda, mps, xpu, or cpu).
AVAILABLE_DEVICES
List[str]
List of all available devices on the system.
BF16_SUPPORTED
bool
Whether bfloat16 is supported on the default device.
GPUINTENSIVE_SET
set
Set of GPU device types: cuda, mps, xpu, privateuseone.
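A common pattern (a sketch, not from the source) is to pick a half-precision key for TORCH_DTYPE_MAP based on bfloat16 support:

```python
# Illustrative selection logic: prefer bfloat16 when the hardware
# supports it, otherwise fall back to float16. BF16_SUPPORTED would
# come from modules.base.
def pick_half_precision(bf16_supported: bool) -> str:
    return 'bf16' if bf16_supported else 'fp16'

print(pick_half_precision(True))   # bf16
print(pick_half_precision(False))  # fp16
```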

Utility Functions

soft_empty_cache
function
Empty GPU cache and run garbage collection.
from modules.base import soft_empty_cache

soft_empty_cache()  # Frees GPU memory
is_nvidia
function
Check if running on NVIDIA CUDA.
Returns: bool
is_intel
function
Check if running on Intel XPU.
Returns: bool

Torch Utilities

Data Type Mapping

from modules.base import TORCH_DTYPE_MAP

# Convert string to torch dtype
precision = TORCH_DTYPE_MAP['fp16']  # torch.float16

# Available mappings:
# 'fp32' -> torch.float32
# 'fp16' -> torch.float16
# 'bf16' -> torch.bfloat16

Parameter Utilities

Helper Functions

standardize_module_params
function
Convert simple params to standardized dict format.
from modules.base import standardize_module_params

params = {'device': 'cuda', 'size': 1024}
standardize_module_params(params)
# Now: {'device': {'value': 'cuda', 'data_type': str}, ...}
patch_module_params
function
Merge saved config params with module default params.
Parameters:
  • cfg_param (dict): Config parameters
  • module_params (dict): Module default parameters
  • module_name (str): Module name for logging
Returns: Patched parameters
merge_config_module_params
function
Merge config with multiple modules’ parameters.
Parameters:
  • config_params (Dict): Configuration
  • module_keys (List): Module names
  • get_module (Callable): Function to get module by key
Returns: Merged config
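A sketch of the likely merge behavior, assuming saved config values override defaults while stale config keys the module no longer declares are dropped (this helper only mimics `patch_module_params`; consult the source for its exact semantics):

```python
# Illustrative merge, NOT the real patch_module_params: start from the
# module's defaults, override values that also appear in the saved
# config, and discard config keys the module no longer has.
def mimic_patch(cfg_param: dict, module_params: dict) -> dict:
    patched = dict(module_params)
    for key, value in cfg_param.items():
        if key in patched:
            patched[key] = value
    return patched

defaults = {'device': 'cpu', 'size': 1024}
saved = {'device': 'cuda', 'obsolete': True}
print(mimic_patch(saved, defaults))  # {'device': 'cuda', 'size': 1024}
```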

Example: Custom Module

from modules.base import BaseModule, DEVICE_SELECTOR, DEFAULT_DEVICE
import torch

class MyCustomModule(BaseModule):
    """Example custom module."""
    
    params = {
        'device': DEVICE_SELECTOR(),
        'threshold': {
            'type': 'selector',
            'options': [0.5, 0.7, 0.9],
            'value': 0.7
        },
        'description': 'My custom module'
    }
    
    _load_model_keys = {'model', 'processor'}
    
    def __init__(self, **params):
        super().__init__(**params)
        self.device = self.get_param_value('device')
        self.threshold = self.get_param_value('threshold')
        self.model = None
        self.processor = None
    
    def _load_model(self):
        """Load the model."""
        self.model = torch.nn.Linear(10, 10).to(self.device)
        self.processor = lambda x: x
        self.logger.info(f"Model loaded on {self.device}")
    
    def process(self, data):
        """Process data."""
        if not self.all_model_loaded():
            self.load_model()
        
        # Your processing logic
        return self.model(data)
    
    def updateParam(self, param_key: str, param_content):
        """React to parameter changes."""
        super().updateParam(param_key, param_content)
        
        if param_key == 'device':
            new_device = self.get_param_value('device')
            if self.model is not None:
                self.model.to(new_device)
            self.device = new_device
        
        elif param_key == 'threshold':
            self.threshold = self.get_param_value('threshold')

# Usage
module = MyCustomModule()
result = module.process(torch.randn(10))

Registry Class

The Registry class manages module registration and retrieval.
from utils.registry import Registry

# Create a registry
MY_MODULES = Registry('my_modules')
register_my_module = MY_MODULES.register_module

# Register a module
@register_my_module('my_module')
class MyModule:
    pass

# Get a registered module
ModuleClass = MY_MODULES['my_module']
instance = ModuleClass()

Registry Methods

register_module
method
Register a module class.
Parameters:
  • name (str | None): Registration name (defaults to class name)
  • force (bool): Override existing registration
  • module (type): Module class
# As decorator
@REGISTRY.register_module('my_name')
class MyClass:
    pass

# As function
REGISTRY.register_module(name='my_name', module=MyClass)
get
method
Get a registered module class by name.
Parameters:
  • key (str): Module name
Returns: Module class or None
module_dict
dict
Dictionary of all registered modules.
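For instance, `module_dict` can be used to enumerate registrations. The stand-in registry below only mirrors the access pattern; the real `utils.registry.Registry` class has more behavior:

```python
# Minimal stand-in for utils.registry.Registry, just to show how
# module_dict maps registration names to classes.
class _MiniRegistry:
    def __init__(self, name):
        self.name = name
        self.module_dict = {}

    def register_module(self, name=None):
        def wrap(cls):
            self.module_dict[name or cls.__name__] = cls
            return cls
        return wrap

REG = _MiniRegistry('demo')

@REG.register_module('alpha')
class Alpha:
    pass

print(sorted(REG.module_dict))  # ['alpha']
```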