Genesis: A Neural Network Architecture Inspired by Human Brain Connections

June 28, 2024

Disclaimer

Genesis does not yet match the performance of classic multi-head attention models. Ongoing optimizations and improvements aim to close this gap.

Introduction

In this article, I will walk you through one of my latest projects, Genesis, a neural network architecture designed to mimic human neural connections. This project has been both a challenging and rewarding experience, as it required me to delve deep into the intricacies of neural structures and their functions. The goal was to create a model that could potentially offer more efficient and biologically plausible computations.

GenesisConfig: The Backbone of Configuration

One of the first steps in creating Genesis was defining a robust configuration class. The GenesisConfig class allows for flexible and detailed configurations. It includes parameters for the dimensions of the network, the number of sections, neurons, attentive neurons, dendritics, synapses, and more. This configurability is crucial as it lets me experiment with different network structures and their impact on performance.

from dataclasses import dataclass

@dataclass
class GenesisConfig:
    n_dim: int
    n_sections: int
    n_neurons: int
    n_attentive_neurons: int
    n_attentive_neuron_heads: int
    n_dendritics: int
    n_synapses: int
    n_pos_size: int
    n_neighbors: int
    n_seed: int
    n_input: str
    n_output: str
    n_vocab_size: int
    n_cross_attention: bool
    n_outputs: int
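
To make this concrete, here is how a small configuration might be instantiated. The values below are illustrative placeholders rather than tuned defaults, and the "reading" and "causal_lm" mode strings are my assumptions about how n_input and n_output are interpreted:

config = GenesisConfig(
    n_dim=256,
    n_sections=4,
    n_neurons=16,
    n_attentive_neurons=4,
    n_attentive_neuron_heads=8,
    n_dendritics=2,
    n_synapses=2,
    n_pos_size=512,
    n_neighbors=3,
    n_seed=42,
    n_input="reading",      # assumed mode string for text input
    n_output="causal_lm",   # assumed mode string for language modeling
    n_vocab_size=32000,
    n_cross_attention=False,
    n_outputs=0,            # unused for causal language modeling
)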

The Genesis Class: Combining Sections and Neurons

The Genesis class is where the magic happens. It combines various sections, each containing neurons that can interact and propagate signals. The architecture supports multiple input and output types, making it versatile for different tasks. For instance, you can configure it for causal language modeling or sequence classification.

Initialization

During initialization, the network configures its input, sections, splits for section inputs and outputs, and the final output layer based on the provided configuration.

import torch.nn as nn
from transformers import AutoTokenizer

class Genesis(nn.Module):
    def __init__(self, config: GenesisConfig, tokenizer: AutoTokenizer = None):
        super().__init__()
        self.config = config
        self.tokenizer = tokenizer
        self.n_dim = self.config.n_dim
        # Each section works on an equal slice of n_dim when the input is split
        # across sections (assuming the "split" mode is carried by n_input).
        self.n_inner_dim = self.n_dim // self.config.n_sections if self.config.n_input == "split" else self.n_dim
        self.input = self._init_input()        # e.g. a Reading layer for text
        self.sections = self._init_sections()
        self.a_splits, self.c_splits, self.combiner = self._init_splits()
        self.output = self._init_output()      # e.g. CausalLM or SequenceClassifier

Forward Pass

The forward pass handles the data flow through the network, managing neuron activation and interaction between sections. If configured for cross-attention, it processes secondary inputs accordingly.

def forward(self, inputs, attention_mask=None, y=None, labels=None):
    hidden_states = self._process_input(inputs)
    # Python integers are immutable, so the activation count has to be
    # returned by _process_sections rather than mutated in place.
    hidden_states, total_neuron_activation_count = self._process_sections(hidden_states, attention_mask, y)
    outputs, loss = self._process_output(hidden_states, labels)
    return GenesisOutput(loss=loss, outputs=outputs, neuron_activation_count=total_neuron_activation_count)
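
Putting the pieces together, a forward call might look like the sketch below. The checkpoint name is only a placeholder; in practice n_vocab_size should match the tokenizer's vocabulary size:

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
model = Genesis(config, tokenizer=tokenizer)

input_ids = tokenizer("Hello, Genesis!", return_tensors="pt").input_ids
output = model(input_ids)
print(output.outputs.shape, output.neuron_activation_count)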

Sections and Neurons: The Core Components

Section Class

Each Section contains a set of neurons. The neurons can be either attentive or non-attentive, with different internal structures and functions. The Section class also includes methods for the forward pass and for managing neuron activation counts.

class Section(nn.Module):
    def __init__(self, n_dim, n_neurons, n_attentive_neurons, n_attentive_neuron_heads, n_dendritics, n_synapses, n_pos_size, n_neighbors):
        super().__init__()
        # Store construction parameters; _init_neurons reads them to build the neuron list.
        self.n_dim, self.n_neurons = n_dim, n_neurons
        self.n_attentive_neurons, self.n_attentive_neuron_heads = n_attentive_neurons, n_attentive_neuron_heads
        self.n_dendritics, self.n_synapses = n_dendritics, n_synapses
        self.n_pos_size, self.n_neighbors = n_pos_size, n_neighbors
        self.neuron_activation_count = 0
        self.neurons = self._init_neurons()
        self.indexer = Indexer(self.n_dim, self.neurons, self.n_attentive_neurons)

    def forward(self, x, attention_mask=None, y=None):
        neuron = self.indexer(x)  # route the input to a single neuron
        x = neuron.forward(x, lambda: self.increment_neuron_count(), attention_mask, y)
        self.reset_neurons()  # clear per-neuron activation flags for the next pass
        return x, self.neuron_activation_count
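
The Indexer itself isn't shown in this article. As a rough illustration of the routing idea, here is a minimal hypothetical sketch that scores each neuron's learned key against the mean-pooled input and dispatches to the best match; the actual implementation may work quite differently:

import torch
import torch.nn as nn

class Indexer(nn.Module):
    # Hypothetical sketch, not the project's actual Indexer.
    def __init__(self, n_dim, neurons, n_attentive_neurons):
        super().__init__()
        self.neurons = neurons
        self.keys = nn.Parameter(torch.randn(len(neurons), n_dim))  # one learned key per neuron

    def forward(self, x):
        pooled = x.mean(dim=1)                    # (batch, n_dim)
        scores = pooled @ self.keys.t()           # (batch, n_neurons)
        idx = scores.mean(dim=0).argmax().item()  # route the whole batch to one neuron
        return self.neurons[idx]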

Neuron Class

The Neuron class is the heart of Genesis. It handles dendritic processing, somatic interaction, axonal signal propagation, and synaptic transmission. Neurons can also establish connections with their neighbors, simulating the interconnected nature of biological neurons.

class Neuron(nn.Module):
    def __init__(self, n_dim, attentive_neuron, n_attentive_neuron_heads, n_dendritics, n_synapses, n_pos_size, n_neighbors):
        super().__init__()
        self.n_dim = n_dim
        self.attentive_neuron = attentive_neuron
        self.n_attentive_neuron_heads = n_attentive_neuron_heads
        self.n_dendritics = n_dendritics
        self.n_synapses = n_synapses
        self.n_pos_size = n_pos_size
        self.n_neighbors = n_neighbors
        self.activated = False  # guards against double activation within one pass
        self.dendritics = nn.ModuleList([Dendritic(self.n_dim) for _ in range(self.n_dendritics)])
        self.soma = Soma(self.n_dim)
        self.axon = Axon(self.n_dim, self.attentive_neuron, self.n_attentive_neuron_heads)
        self.synapses = nn.ModuleList([Synapse(self.n_dim) for _ in range(self.n_synapses)])
        self.ln = nn.LayerNorm(self.n_dim, bias=False)
        self.neighbors = []

    def forward(self, x, increment_neuron_count_func, attention_mask=None, y=None):
        # An already-activated neuron returns the signal unchanged, which
        # prevents infinite recursion through cyclic neighbor connections.
        if self.activated:
            return x
        self.activated = True
        increment_neuron_count_func()
        x = self.process(x, attention_mask, y)   # dendrites + soma
        x = self.propagate(x)                    # axon + synapses
        x = self.connect(x, increment_neuron_count_func)  # pass on to neighbors
        return x
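
The process, propagate, and connect helpers live elsewhere in the project. As a minimal sketch of how process and propagate might chain the submodules, assuming dendritic branches are summed before somatic integration, the methods could look like this (hypothetical, not the project's actual code):

    # Hypothetical method sketches for the Neuron class above.
    def process(self, x, attention_mask=None, y=None):
        dendritic_out = sum(dendritic(x) for dendritic in self.dendritics)  # parallel branches
        x = self.soma(x + dendritic_out)  # somatic integration with a residual path
        return self.ln(x)

    def propagate(self, x):
        x = self.axon(x)  # axonal transform (attention-based for attentive neurons)
        for synapse in self.synapses:
            x = synapse(x)  # sequential synaptic transmission
        return x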

Input & Output Layers

Reading & Bridge

The Reading and Bridge classes handle different types of input and output processing. The Reading class is used for text input embedding, while the Bridge class can be used for intermediary processing or as an output layer.

class Reading(nn.Module):
    def __init__(self, n_embed_size, n_vocab_size, n_pos_size):
        super().__init__()
        self.n_embed_size = n_embed_size
        self.n_vocab_size = n_vocab_size
        self.n_pos_size = n_pos_size
        self.wte = nn.Embedding(self.n_vocab_size, self.n_embed_size)  # token embeddings
        self.wpe = nn.Embedding(self.n_pos_size, self.n_embed_size)    # learned positional embeddings
        self.fc_out = nn.Linear(self.n_embed_size, self.n_embed_size)
        self.ln = nn.LayerNorm(self.n_embed_size)
        self.act = nn.SiLU()

    def forward(self, input_ids):
        position_ids = torch.arange(0, input_ids.shape[-1], dtype=torch.long, device=input_ids.device)
        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids.unsqueeze(0))
        hidden_states = inputs_embeds + position_embeds  # broadcast positions over the batch
        hidden_states = self.fc_out(hidden_states)
        hidden_states = self.ln(hidden_states)
        hidden_states = self.act(hidden_states)
        return hidden_states
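
As a quick sanity check, the layer maps a batch of token ids to hidden states of size n_embed_size:

import torch

reading = Reading(n_embed_size=256, n_vocab_size=32000, n_pos_size=512)
input_ids = torch.randint(0, 32000, (2, 10))  # batch of 2 sequences, 10 tokens each
hidden_states = reading(input_ids)
print(hidden_states.shape)  # torch.Size([2, 10, 256])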

CausalLM & SequenceClassifier

These classes define the final output processing, whether it's for language modeling or sequence classification.

class CausalLM(nn.Module):
    def __init__(self, n_dim, n_vocab_size):
        super().__init__()
        self.n_dim = n_dim
        self.n_vocab_size = n_vocab_size
        self.lm_head = nn.Linear(self.n_dim, self.n_vocab_size)

    def forward(self, x):
        return self.lm_head(x)  # per-token logits over the vocabulary

class SequenceClassifier(nn.Module):
    def __init__(self, n_dim, n_outputs):
        super().__init__()
        self.n_dim = n_dim
        self.n_outputs = n_outputs
        self.fc_out = nn.Linear(self.n_dim, self.n_outputs)

    def forward(self, x):
        x = torch.mean(x, dim=1)  # mean-pool over the sequence dimension
        x = self.fc_out(x)
        # Note: losses such as nn.CrossEntropyLoss or BCEWithLogitsLoss expect
        # raw logits, in which case this final activation should be omitted.
        x = torch.softmax(x, dim=-1) if self.n_outputs > 1 else torch.sigmoid(x)
        return x
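
The CausalLM head returns raw logits, so in training it would typically be paired with a shifted cross-entropy loss. A minimal sketch, using the standard next-token shift rather than anything specific to Genesis:

import torch
import torch.nn.functional as F

lm_head = CausalLM(n_dim=256, n_vocab_size=32000)
hidden_states = torch.randn(2, 10, 256)    # stand-in for Genesis hidden states
labels = torch.randint(0, 32000, (2, 10))

logits = lm_head(hidden_states)
# Each position predicts the next token, so shift logits and labels by one.
loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, 32000),
    labels[:, 1:].reshape(-1),
)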

Conclusion

Genesis represents a step forward in neural network architecture by drawing inspiration from the human brain. Its modular design, configurability, and biologically plausible computations make it a promising approach for various AI tasks. Developing this model has been a fascinating journey, combining theoretical research with practical implementation challenges. I look forward to further optimizing and expanding its capabilities.

This project is open-source, and I welcome collaboration from anyone interested in exploring the intersections of artificial and biological intelligence. Let's push the boundaries of what's possible together!