【雪狼】选择性SSM与预测编码的结合方案

这个帖子是我来到本坛所主要想发的内容。我本人的专业不是AI领域的,所以希望能在这里跟大家有所交流和讨论,以及如果可以的话希望得到一些帮助。本坛的 @Runnel 同学曾经帮我在本坛转发过我在超理论坛发的帖子:【外部】mamba模型与多模态

先用简单的语言阐述一下我的构想:选择性SSM(例如mamba)有着处理长序列的线性时间复杂度,对比transformer具有处理长序列上的天然优势,但其表现略有不如transformer。预测编码(PC)具备仿生(很有可能是哺乳动物神经系统的底层逻辑)、高效的误差驱动学习机制,但是通过预测编码的经典架构直接构建语言模型有一定困难。因此或许可以尝试将选择性SSM和预测编码结合起来,通过预测编码的误差驱动学习机制,补充选择性SSM对比transformer的不足之处,增强结合后整个系统的推理和泛化能力。

此外,我亦有想法,利用这个结合系统来弥补中小开发者在有限的数据和算力下做出妥协带来的性能下降,进而训练出虽然可能不会做微积分,但学习能力更强、在“不熟悉语境下”表现更像人类的AI。这样的话,可以做出一个完全本地可控(可部署可训练)且能以人的情感互动模式问答的AI,用于满足较为私密的情感需求。

接下来是选择性SSM和PC的结合方案

先定义基本约定:对于全局的输入和输出,我们定义输入处为最低层,输出处为最高层。“低一层”代表更靠近全局输入层,“高一层”代表更靠近全局输出层。

这个架构以选择性SSM为主,所有的选择性SSM作为输入和输出的表征神经元。然后在每层的SSM上增添一层误差神经元,这样就凑齐了预测编码的表征神经元和误差神经元系统。但是,与传统的预测编码不同,我们要求每一层SSM作为表征神经元接受低一层表征神经元SSM的输入,同时接受高一层误差神经元的修正输入,然后将本层表征神经元SSM的输出传递给高一层;本层的SSM表征神经元不接受本层的误差神经元修正。对于本层的误差神经元,它接受本层的表征神经元输入和低一层的表征神经元输入,具体做法是将本层的表征神经元输入通过一个通常是非线性的预测函数f_pred 处理以后得到预测值,再用预测值减去低一层的表征神经元输入当作误差输出。

公式如下:

h_{t}^{l}=SSM^{l}\left( h_{t-1}^{l}, x_{t}^{l} \right) + W_{down}^{l} \cdot e_{t-1}^{l+1}
e_{t}^{l}=h_{t}^{l}-f_{pred}^{l}\left( h_{t}^{l+1} \right)

这个架构不存在本层表征神经元和本层误差神经元之间的联系,这是与传统预测编码的重要区别。此外,这个架构不允许本层误差神经元接受高一层误差神经元的修正。

接下来是物理边界问题:最底层的神经元接受输入,但它显然理论上是无所谓误差的。最高层的神经元产生输出,但由于没有更高层,所以不可能接受上层的误差输入修正。这两个问题采取如下方式解决:1、在整个选择性SSM-预测编码体系最底下加一个全连接层,用于传入数据;这样第一层选择性SSM-PC的误差神经元就可以根据全连接层和本层的表征神经元传入数据产生误差,但这个误差不用于修正任何值。2、对于最高层,设定其表征神经元为自由,也就是说它拿到的“误差”永远恒为0;编程中的操作方式完全可以把它的误差定位到一排0张量上。

以上只是物理边界问题的一种处理模式。我后来觉得,如果第一层的误差神经元传入一个输入序列的移位(后移一位),这样训练后得到的误差值或许是有意义的——它表征了“下一个”词与上一个的预测区别(或者说是关联程度)。对于最高层表征神经元之上的“误差”,其实可以是一种调制手段,例如将它设定成非常小的高斯噪声用于代替退火算法调制输出的词(输出的词永远取概率分布的最大值,但是由于有高斯噪声的调制,使得不会每次相同上下文都是同一个词)。进一步看待高于最高层表征神经元之上的“误差”,它或许可以作为一种“记忆连接”连接到记忆存储库里,由记忆存储库调制输出。

然后我要讲一讲迄今为止我做的研究工作

首先,对于选择性SSM的选择,我采用了mamba。除了mamba以外,其实还有S4也可以做选择。预测编码方面就是采取了非典型的高层误差调制低层表征的做法。

因为我并非本专业本领域的,代码大多数是通过AI的帮助而编写来的;中间经历了好几次AI偷懒导致代码出现问题的状况发生,每一次的修改都花了很大精力。目前实现了完整mamba块(带语言模型里的一维卷积层)与预测编码的融合:其中来自高一层误差的调制,是发生在表征神经元输入处的,在特征投影和状态演化之前。这个方案确实实现了我的构想;不过,由于mamba的特征投影和门控等操作,带来了大量的存储空间占用,存储的模型大小从几mB飙升到了140mB,训练时也常常需要小心爆显存。

我对存储空间占用过大的问题进行过反思。如果去掉mamba的投影层等内容,是可以将存储空间大幅减小的。mamba进行去掉投影和门控的操作以后,对矩阵指数部分等进行一阶近似,可以转化为一个近似的双线性系统,不仅参数简化了,需要的显存也降低了不少。这个双线性系统的合理性在于仿生(记得在哪里看过,生物的神经响应最近似一个双线性系统;但即使不是,我们也完全可以认为生物是在一阶线性近似条件下工作的),但是这样会导致mamba里输入依赖的形式步长受到限制(保证不失稳的前提下,作为mamba注意程度调节参数的形式步长Δ会被限制在一个数值之内,亦即不能有无限大的形式步长来作为极强的当前注意力去overwhelm所有的历史状态)。这个思想我尚且没有深入去探究,我只知道mamba的确不是以仿生为前提去设计的。

对于现在SSPC语言模型的训练:我手里的语料库比较小,只有数百kB,但全部是高质量的。目前采用了这样的模型:词汇表长3584,词向量1024维;【输入端嵌入矩阵(3584→1024),全连接层1024维】+(每层1024个表征神经元+1024个误差神经元)*7+输出端嵌入矩阵(与输入端相同,1024→3584);其中对于每层1024维度的输入和输出之外,有mamba参数:投影128维,一维卷积4个词向量。训练的损失函数是cross entropy加上一个预测编码的误差损失。

我训练的场所是Kaggle,比最初用的google colab 稍微强一点,因为可以挂机一晚上12h。

输出的模型文件140mB。训练十几轮后AI勉强能回答问题,例如你是谁之类的,但不太稳定,语义上有一定理解之后的创新但也有很多内容不通顺。

这个模型的训练现象,loss存在表征神经元和误差神经元产生的genenrate loss和pc loss的循环下降:首先是gen loss下降,此时pc loss 基本不变,达到一定程度后pc loss上升;随即是pc loss下降,gen loss稍微上升;最后是二者都下降到手里。但是循环下降的过程可以不止一轮,我观察到过三轮(当然那是256维的极小模型)。

接下来要做的工作:

1、继续当前模型的训练。测试性能并观察新现象。

2、探索双线性近似形式的可能性。

3、继续编制语料库。

4、整理代码(之前真的很乱),整理数学公式,然后发在本论坛上。

2 个赞

前排支持!
看起来好厉害… 我也不是做这块的,似乎不是很懂,但是确实看起来挺有意思的
如果有机会我可以帮忙一起做做,我有相当的算力可以支持一下你的想法

谢谢你的回复与关注。我还要进一步修订代码,整理好了之后会发在此处。

目前的最新代码训练时采取的是变长样本单样本训练的策略。过去试过用常规的语言模型定长截断语料和批量训练的方法,但是太容易爆显存;早期参数少的测试里用过,后来只得作罢。

谢谢你的支持,如果需要算力,会和你讲的。

1 个赞

接1楼。这里画一些结构图,用以阐述选择性SSM与预测编码结合架构。

一如之前所示,这个架构是以选择性SSM作为预测编码的表征神经元(记作Chr. SSSM),再加上预测编码典型的误差神经元(记作Err. Neuron)构成的。我不同版本的代码有着不同的结合方式,但是它们一般说来可以用这样的形式来统一:

(y_{t}, h_{t})=SSSM(h_{t-1}, x_{t}, e)

它们的组合方式区别有:1. 误差e的处理方式,例如在SSSM内部是采用了简单加在 x_{t} 上还是参与了选择性的调制(我目前做的几个版本都是 x_{t}’=x_{t}+e_{t} ,除了代码还没修改完成的最新版);2. 上标区别,也就是层连接的区别,具体体现为SSSM之中的e究竟是对于h的本层( h^{l}, e^{l} )还是高一层( h^{l}, e^{l+1} )(我目前做的几个版本都是高一层)。

以下结合结构图进行阐述。

版本1&2:mamba块+PC/双线性系统近似+PC。

它们的结构如下图所示:

可以看出它们遵循的模式是以方框表示的表征神经元之间存在信息的直接逐层流动(这和传统的预测编码不同),而以圆圈表示的误差神经元之间不存在信息的直接流动(和经典预测编码相同)。任何一个误差神经元想要把自己的信息传递给高一层的误差神经元,则先要把信息传递给低一层的表征神经元,由低一层的表征神经元把信息传递给同层表征神经元,再传递给高一层的误差神经元。

这种架构的特征是把误差神经元当作了数值上小的校正,用以减缓异常或高变化数据引起的对表征神经元带来的冲击;误差神经元相当于表征神经元的副手,不能独立构造信号通路。同时,这种架构也完全不能带来传统预测编码所带来的逐层上升的数据抽象化——网络高层拿到的不是数据的抽象特征而是逐层被修饰和修改过的数据。站在语言模型的角度我们很容易理解这个构造:因为从输入端传到输出端的永远是语义空间之中的内容,不然就不能保证输出的内容是词汇。如果执行经典的预测编码构造,输出端输出的是某种特征——这更像是把句子的语气、情感之类的特征分了类——这样是不能胜任应答的语言任务的。但是,当我们把误差神经元的作用弱化为表征神经元副手的时候,实际上带来的效应就是表征神经元在干活,表征神经元起了最主要的作用;这一点可以从训练数据中误差编码的loss约为表征神经元loss的1%-0.1%观察到。

其余可以批评的内容,在于每层的误差神经元连接的方式。可以观察到,因为高一层的误差神经元调整本层的表征神经元,带来的问题是最低层的误差神经元不修饰任何事物;而且除了这个最低层的误差神经元以外,如果将每一个高层的误差神经元都下移一层到本层,其实得到的网络和原来的等效——就只是删去了一个最低层不干任何事情的误差神经元而已。

此外,我还要自我批评一下版本1(mamba+pc)和2(Bilinear+PC)中的误差放入SSSM的算法问题。如下图所示,它们只是参与了对每一层的输入数值的修正,而从未对选择性机制构成独立而主要的影响。

可以看得出,误差e是通过加在输入x上来影响SSSM的状态选择性的;而无论是mamba的状态选择性还是Bilinear近似的状态选择性,都是完全通过输入x来决定的。但因为e通常是个小值,所以对状态选择性的调制能力十分有限。【注:之前1楼的公式里,误差W*e是加在 h^{l} 层的输出之后的,但这样其实就相当于把 h^{l+1} 的输入数值加入了W*e】。

如果日后要在版本1和2的框架下加以微小的改进,我想或许是把误差神经元在每层的位置调整一下,也就是把原本高一层的误差神经元放在下面。如下图所示:

接下来是说版本3:版本3的代码目前还没有写完,确切的说写完了,但是必须通过优化才能跑。版本3已经完全抛弃了mamba架构,在双线性的基础上走的更远(注意之前的bilinear其实不是完全的bilinear,这不是由非线性激活函数造成而是由x*Wx项造成)。之前的版本2是把mamba里的exp(Δ*A)矩阵指数拆掉了(变成一阶近似),导致的结果就是必须加上layernorm防止数值失稳爆掉。版本3则引入了一个正规化标量因子f(r)=1/(1+r),除去内部状态的模长来防止数值爆炸,同时也是一种遗忘机制。之所以它在双线性上走的更远,是因为它对输入x和高一层的误差e,这两个变量对内部状态h都是双线性的;这样的设定针对的是之前误差e不调制状态选择性的问题。版本3还引进了误差神经元之间的信号通路,亦即:高一层误差神经元直接传递信息到本层误差神经元,用于调制f_pred()对本层神经元的操作。

如图所示:

这份架构实现了两条信息通路,即表征神经元向上的信息传递和误差神经元向下的信息传递。这种架构的要点在于不能容许信息未经调制的直接传递——也就是说不能容许残差神经元的形式,不能有直接加的x——它必须保证每一层的信息都充分地与其他信息发生了作用。

版本3的代码还未经测试。事实上问题不在架构而在代码实现。由于完全双线性要引入12个N*N的矩阵,这些矩阵在PyTorch的自动微分作用之下每一个时间步都会被保存下来,导致的问题是即使对序列长T=1024的训练数据,对于512维模型,其占用内存都会超过40GB. 因此必须重写forward里的算子来避免保存大的矩阵。512维模型已经是我能接受的最小的维数了。因此做下去就必须解决这个问题。

至于架构方面未来该如何改良,我想或许也就是把最底层的孤立误差神经元去掉,然后所有的高层误差神经元下移一层。如图所示:

如果更进一步对架构进行改良,就会是将误差神经元拆成抑制神经元和兴奋神经元两部分;那会是更加复杂但也可能是更加仿生的构造。

不过不管怎么说,这一类实现了信息上行和指导下行的神经网络结构,都有在输出端之后加入决策器的构造可能。那样或许可以应用到一类更广阔的事物之上。

附:版本1(mamba-pc)模型代码

%%writefile models.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class SSSM(nn.Module):
    """
    本模块负责核心的状态空间演进,不包含外部的卷积和门控。
    """
    def __init__(self, d_model, d_state):
        super().__init__()
        self.d_model, self.d_state = d_model, d_state
        # 初始化 A 矩阵
        A = torch.arange(1, d_state + 1).repeat(d_model, 1).float()
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(d_model))
        
        # 选择性机制:生成 dt, B, C
        self.x_proj = nn.Linear(d_model, d_state * 2 + 1, bias=False)
        self.dt_proj = nn.Linear(1, d_model)
        nn.init.constant_(self.dt_proj.bias, -4.0)
        
        # 输出线性层
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, h_prev):
        # x: (batch, d_model)
        A = -torch.exp(self.A_log)
        projected = self.x_proj(x)
        dt_raw, B, C = torch.split(projected, [1, self.d_state, self.d_state], dim=-1)
        
        # 离散化
        dt = torch.clamp(F.softplus(self.dt_proj(dt_raw)), min=1e-4, max=0.1)
        A_bar = torch.exp(dt.unsqueeze(-1) * A.unsqueeze(0))
        B_bar = dt.unsqueeze(-1) * B.unsqueeze(1)
        
        # 状态更新 (SSM 核心)
        h_t = A_bar * h_prev + B_bar * x.unsqueeze(-1)
        
        # 计算输出
        y = torch.einsum('bds,bs->bd', h_t, C) + x * self.D
        return self.out_proj(F.silu(y)), h_t

class SSPCLayer(nn.Module):
    """
    重点:
    1. 合并了输入投影 z_proj 和 x_proj 减小显存开销。
    2. f_pred 为带 SiLU 激活的预测函数。
    3. 集成了 1D 卷积。
    """
    def __init__(self, dim, d_state, d_conv=4):
        super().__init__()
        self.dim = dim
        
        # --- 显存优化:合并投影层 ---
        # 原本是 z_proj(dim->dim) 和内部输入投影,现在统一合并为一个大层 (dim -> dim * 2)
        # 一半用于 SSM 支路 (x),一半用于门控支路 (z)
        self.merged_in_proj = nn.Linear(dim, dim * 2, bias=False)
        
        # 局部卷积层 (Mamba 标配)
        self.conv1d = nn.Conv1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=d_conv,
            groups=dim,
            padding=d_conv - 1
        )
        
        # 实例化 SSSM 
        self.sssm = SSSM(dim, d_state)
        
        # 预测函数 f_pred 为 Linear + SiLU
        self.f_pred = nn.Sequential(
            nn.Linear(dim, dim),
            nn.SiLU()
        )
        
        self.down_weight = nn.Parameter(torch.full((1,), 0.05))

    def forward(self, x_step, h_prev, z_step, e_top_prev=None):
        """
        x_step: 当前时刻经过卷积处理后的 SSM 支路输入
        z_step: 当前时刻的门控支路输入
        e_top_prev: 来自高层的预测误差反馈
        """
        # 融合预测编码的自上而下反馈:将误差修正引入 SSM 输入
        ssm_input = x_step + self.down_weight * e_top_prev if e_top_prev is not None else x_step
        
        # 进入核心 SSSM 单元
        y, h_t = self.sssm(ssm_input, h_prev)
        
        # 门控组合:y * SiLU(z) -> Mamba 的标志性结构
        # 注意:z_step 已经在外部或传入前经过了初步处理
        output = y * F.silu(z_step)
        
        return output, h_t

class SSPCModel(nn.Module):
    def __init__(self, vec_dim, model_dim, d_state, num_layers):
        super().__init__()
        self.model_dim, self.num_layers, self.d_state = model_dim, num_layers, d_state
        self.input_proj = nn.Linear(vec_dim, model_dim)
        self.layers = nn.ModuleList([SSPCLayer(model_dim, d_state) for _ in range(num_layers)])
        self.output_head = nn.Linear(model_dim, vec_dim)

    def forward(self, x_seq, states=None):
        batch_size, seq_len, _ = x_seq.shape
        device = x_seq.device
        
        # 初始化 SSM 隐状态
        if states is None:
            states = [torch.zeros(batch_size, self.model_dim, self.d_state).to(device) for _ in range(self.num_layers)]
        
        # 初始化预测误差
        errors = [torch.zeros(batch_size, self.model_dim).to(device) for _ in range(self.num_layers + 1)]
        
        # 初始词嵌入投影
        x_emb = self.input_proj(x_seq)
        
        all_vec_outputs, pc_loss_sum = [], 0
        
        # --- 优化策略:在时间循环外预计算卷积和合并投影,以节省显存和加速 ---
        # 注意:由于层级间的 SSPC 依赖,每一层的输入需要实时计算,
        # 但我们可以在每一层内部对输入序列进行一次性卷积预处理。
        
        current_layer_input = x_emb
        
        for t in range(seq_len):
            new_states, layer_outputs = [], []
            
            for l in range(self.num_layers):
                # 获取当前层的输入
                h_in = current_layer_input[:, t] if l == 0 else layer_outputs[l-1]
                
                # --- 使用 chunk 切分合并后的投影 ---
                # 这一步将 h_in 一次性投射到 2*dim,然后切分为 SSM 支路和门控支路
                combined = self.layers[l].merged_in_proj(h_in)
                x_ssm_raw, z_gate = combined.chunk(2, dim=-1)
                
                # 为了保持简单且适宜家用机,卷积可以在这里简化处理或作为状态维护
                # 此处为了演示原理,假设卷积已通过某种方式作用于 x_ssm_raw
                # 在实际针对序列的训练中,卷积通常在循环外对整个 x_ssm 序列预执行一次
                
                # 调用更新逻辑 (应用了改名后的 sssm 和新 f_pred)
                out_l, h_t_l = self.layers[l](x_ssm_raw, states[l], z_gate, errors[l+1])
                
                new_states.append(h_t_l)
                layer_outputs.append(out_l)
            
            # 预测编码损失计算:使用升级后的带 SiLU 的 f_pred
            step_pc_loss = 0
            for l in range(self.num_layers - 1):
                # 高层状态通过 f_pred (Linear + SiLU) 预测低层状态
                h_pred = self.layers[l].f_pred(layer_outputs[l+1])
                error = layer_outputs[l] - h_pred
                errors[l] = error.detach()
                step_pc_loss += torch.mean(error**2)
            
            pc_loss_sum += step_pc_loss
            states = new_states
            # 【关键注释】:原本输出的是网络最底层(Layer 0)的表征神经元数值 all_layer0_outputs.append(layer_outputs[0])
            # 【必须】改成输出网络最高层的表征神经元数值 
            all_layer0_outputs.append(layer_outputs[-1])
            
        return torch.stack(all_vec_outputs, dim=1), states, pc_loss_sum / seq_len

# SSPCLanguageModel 保持不变,它会自动调用重命名后的核心组件
class SSPCLanguageModel(nn.Module):
    def __init__(self, vocab_size, vec_dim, model_dim, d_state, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, vec_dim)
        self.sspc_core = SSPCModel(vec_dim, model_dim, d_state, num_layers)
        self.lm_head = nn.Linear(vec_dim, vocab_size)
        self.lm_head.weight = self.embedding.weight

    def forward(self, input_ids, states=None):
        x_vecs = self.embedding(input_ids)
        # 【关键注释】:输出的是网络哪一层的表征神经元数值? 
        pred_vecs, next_states, pc_loss = self.sspc_core(x_vecs, states)
        logits = self.lm_head(pred_vecs)
        return logits, next_states, pc_loss

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens, temperature=1.0, top_k=50):
        self.eval()
        generated = input_ids
        states = None
        for _ in range(max_new_tokens):
            logits, states, _ = self.forward(generated[:, -1:], states)
            next_token_logits = logits[:, -1, :] / (temperature + 1e-8)
            v, _ = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1)))
            next_token_logits[next_token_logits < v[:, [-1]]] = -float('Inf')
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)
            if next_token.item() == 3: break # [EOS]
        return generated

附:版本2(双线性SSM-PC)代码:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SSSM_Bilinear(nn.Module):
    """
    双线性mamba近似状态空间模型的核心实现
    h_t = (A0 + delta_t * A_delta) h_{t-1} + (B0 + delta_t * B_delta) x_t
    y_t = C0 * h_t + g_t * (C1 * h_t)
    """
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        # 标量控制头:生成步长 delta 和 选择性门控 g
        self.delta_head = nn.Linear(d_model, 1)
        self.gate_head = nn.Linear(d_model, 1)

        # 核心矩阵 A, B, C 
        self.A0 = nn.Parameter(torch.eye(d_model) * 0.9)
        self.A_delta = nn.Parameter(torch.randn(d_model, d_model) * 0.001)
        self.B0 = nn.Parameter(torch.randn(d_model, d_model) * 0.001)
        self.B_delta = nn.Parameter(torch.randn(d_model, d_model) * 0.001)
        self.C0 = nn.Parameter(torch.randn(d_model, d_model) * 0.001)
        self.C1 = nn.Parameter(torch.randn(d_model, d_model) * 0.001)

        # 增加一个内部归一化,防止状态 h_t 溢出
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x_t, h_prev):
        # 计算控制变量 
        # 钳定 delta_t,防止步长过大导致状态飞出
        # 限制 delta 在 [1e-6, 1.0] 之间
        delta_t = torch.clamp(F.softplus(self.delta_head(x_t)), min=1e-6, max=1.0)
        g_t = torch.sigmoid(self.gate_head(x_t))

        # 构造有效 A 和 B 矩阵 
        # A_eff: (batch, d_model, d_model) 
        A_eff = self.A0.unsqueeze(0) + delta_t.unsqueeze(-1) * self.A_delta.unsqueeze(0)
        B_eff = self.B0.unsqueeze(0) + delta_t.unsqueeze(-1) * self.B_delta.unsqueeze(0)

        # 状态更新
        # 使用 tanh 限制状态更新的幅度,或者在更新后进行 norm
        h_t = torch.bmm(A_eff, h_prev.unsqueeze(-1)).squeeze(-1)
        h_t = h_t + torch.bmm(B_eff, x_t.unsqueeze(-1)).squeeze(-1)
        
        # 这一步非常重要:防止递归过程中数值无限增长
        h_t = self.norm(h_t)

        # 输出计算(双线性门控输出)
        y_t = h_t @ self.C0.T + g_t * (h_t @ self.C1.T) 
        return y_t, h_t

class SSPCLayer(nn.Module):
    def __init__(self, dim, d_conv=4):
        super().__init__()
        self.dim = dim

        # 保留一维卷积提取局部特征 
        self.conv1d = nn.Conv1d(dim, dim, kernel_size=d_conv, groups=dim, padding=d_conv - 1) 

        # 替换为双线性核心
        self.sssm = SSSM_Bilinear(dim)

        # 预测函数 f_pred 
        self.f_pred = nn.Sequential(nn.Linear(dim, dim), nn.SiLU()) 

        # 误差反馈权重
        self.down_weight = nn.Parameter(torch.full((1,), 0.01)) 

    def forward(self, x_step, h_prev, e_top_prev=None):
        # 融合自上而下的误差反馈 
        ssm_input = x_step + self.down_weight * e_top_prev if e_top_prev is not None else x_step

        # 通过双线性核心
        y, h_t = self.sssm(ssm_input, h_prev)

        # 注意:此处移除了原 Mamba 的 z_gate 门控,因为 SSSM_Bilinear 内部已包含 g_t 门控
        return y, h_t

class SSPCModel(nn.Module):
    def __init__(self, model_dim, num_layers):
        super().__init__()
        self.model_dim = model_dim
        self.num_layers = num_layers

        # 构建层级 
        self.layers = nn.ModuleList([SSPCLayer(model_dim) for _ in range(num_layers)]) 

    def forward(self, x_emb, states=None, top_error_noise=None):
        batch_size, seq_len, _ = x_emb.shape
        device = x_emb.device
        
        # 初始化状态 (双线性版本状态维度为 [B, D]) 
        if states is None:
            states = [torch.zeros(batch_size, self.model_dim).to(device) for _ in range(self.num_layers)] 
        
        # 初始化预测误差 
        errors = [torch.zeros(batch_size, self.model_dim).to(device) for _ in range(self.num_layers + 1)]
        
        # 如果传入了生物启发式噪声,将其注入到高于最高层的误差神经元中
        if top_error_noise is not None:
            errors[self.num_layers] = top_error_noise

        # 序列一维卷积预处理 
        # 形状变换: (B, L, D) -> (B, D, L) 
        x_conv = x_emb.transpose(1, 2)
        # 对每一层应用独立的卷积处理(此处简化为共用,若需每层独立可移动到循环内) 
        x_conv = self.layers[0].conv1d(x_conv)[:, :, :seq_len].transpose(1, 2)
        
        all_layer0_outputs = []
        pc_loss_sum = 0
        
        for t in range(seq_len):
            new_states = []
            current_input = x_conv[:, t] 

            # 自底向上逐层传递
            layer_outputs = []
            for l in range(self.num_layers): 

                # 输入来自于上一层输出或嵌入层
                h_in = current_input if l == 0 else layer_outputs[l-1]
                # 传入上一时刻隐状态和来自高层的误差
                out_l, h_t_l = self.layers[l](h_in, states[l], errors[l+1])
                new_states.append(h_t_l)
                layer_outputs.append(out_l)
            
            # 计算预测编码损失 (自上而下) 
            step_pc_loss = 0
            for l in range(self.num_layers - 1):
                h_pred = self.layers[l].f_pred(layer_outputs[l+1])
                error = layer_outputs[l] - h_pred
                errors[l] = error.detach()
                step_pc_loss += torch.mean(error**2)
            
            pc_loss_sum += step_pc_loss
            states = new_states
            # 【关键注释】:原本输出的是网络最底层(Layer 0)的表征神经元数值 all_layer0_outputs.append(layer_outputs[0])
            # 【必须】改成输出网络最高层的表征神经元数值 
            all_layer0_outputs.append(layer_outputs[-1])
            
        return torch.stack(all_layer0_outputs, dim=1), states, pc_loss_sum / seq_len

class SSPCLanguageModel(nn.Module):
    def __init__(self, vocab_size, model_dim, num_layers):
        super().__init__()
        # 1. 词嵌入:维度直接设为 model_dim,不再需要中间的全连接层
        self.embedding = nn.Embedding(vocab_size, model_dim) 

        # 全连接层:确保输入与模型维度匹配并提供初始变换
        # self.input_fc = nn.Linear(model_dim, model_dim)

        self.sspc_core = SSPCModel(model_dim, num_layers)

        # 词向量到词汇表的投影,共享权重以减少参数 
        self.lm_head = nn.Linear(model_dim, vocab_size)
        self.lm_head.weight = self.embedding.weight # Weight Tying 

    def forward(self, input_ids, states=None, top_error_noise=None):
        # 直接连接进 SSPC 层
        x_emb = self.embedding(input_ids)
        # 【关键注释】:注意核心计算后,返回的是基于哪一层表征的预测? 
        pred_vecs, next_states, pc_loss = self.sspc_core(x_emb, states, top_error_noise)
        logits = self.lm_head(pred_vecs)
        return logits, next_states, pc_loss

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens, temperature=1.0, top_k=50):
        """标准 T 生成算法"""
        self.eval()
        generated = input_ids
        states = None
        for _ in range(max_new_tokens):
            logits, states, _ = self.forward(generated[:, -1:], states)
            next_token_logits = logits[:, -1, :] / (temperature + 1e-8)
            v, _ = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1))) 
            next_token_logits[next_token_logits < v[:, [-1]]] = -float('Inf')
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)
            if next_token.item() == 3: break # [EOS]
        return generated

    @torch.no_grad()
    def generate_bio(self, input_ids, max_new_tokens, noise_intensity=0.01, temperature=1.0):
        """
        生物启发式生成算法:
        在高于最高层的误差神经元 (errors[num_layers]) 注入噪声来指导生成。
        """
        self.eval()
        generated = input_ids
        states = None
        device = input_ids.device
        
        for _ in range(max_new_tokens):
            # 生成一束针对最高层误差神经元的微弱噪声
            # 这个噪声通过 down_weight 逐层向下渗透,影响最底层的输出
            noise = torch.randn(input_ids.size(0), self.sspc_core.model_dim).to(device) * noise_intensity
            
            logits, states, _ = self.forward(generated[:, -1:], states, top_error_noise=noise)
            
            # 此处仍需基础采样以获得具体的 token,但分布已被顶层噪声“扰动”
            next_token_logits = logits[:, -1, :] / (temperature + 1e-8)
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            
            generated = torch.cat((generated, next_token), dim=1)
            if next_token.item() == 3: break
        return generated
1 个赞

好有深度的帖子hhh

btw楼主这里是能插入数学符号的:

a^2 +b^2=c^2

($$ a^2 +b^2=c^2$$)
或者inline: a^2 +b^2=c^2 ( $ a^2 +b^2=c^2$ )
希望有助于您的帖子排版

看看有没有懂mamba的uu来讨论讨论