Most mainstream deep learning frameworks support automatic differentiation, which computes the first-order derivative of a function with respect to each of its variables without any hand-derived formulas. While working on a big assignment for an intelligent control course, I decided to write an autodiff module of my own and fold it into the assignment for a few easy points. The initial idea was very simple: basic differentiation rules plus recursion. But as I kept extending the functionality, it became clear I had underestimated the problem, and I gave up halfway through. This is a brief record of the attempt.

I wrote two classes in total. The first is agnum. Its constructor keeps a length-3 list recording the operator and the two parent nodes that produced the value: for A = B * C, A's two parents are B and C and the operator is '*'. The backward pass then just traces from the output node back through its parents, applying the differentiation rules recursively; the only bookkeeping needed is to check whether the current node is the differentiation target, and whether a given parent's subtree contains the target variable. For d(a^b)/dx, for example, you first have to determine whether a and b depend on x, because each case has a different rule.
A drawback is that this mechanism records every single value it touches; I hope to add proper matrix support later.
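For reference, the case split for d(a^b)/dx works out to:

$$
\frac{d}{dx}\,a^b =
\begin{cases}
b\,a^{b-1}\,\dfrac{da}{dx}, & \text{only } a \text{ depends on } x,\\[4pt]
a^b \ln a \,\dfrac{db}{dx}, & \text{only } b \text{ depends on } x,\\[4pt]
a^b\left(\ln a\,\dfrac{db}{dx} + \dfrac{b}{a}\,\dfrac{da}{dx}\right), & \text{both depend on } x.
\end{cases}
$$

These three cases are exactly what the pow branch of agnum.diff below implements.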

The second class is adapted from the article 自动微分(Auto differentiation)机制详解, with quite a few extra operations added. But once I got to indexing support there were still plenty of problems: for something like A[:, 1] = B[:, 2] + B[:, 5], B's gradient gets propagated twice. The root cause is that whenever a node is consumed by more than one downstream operation, each backward call re-propagates its accumulated gradient, so earlier contributions get double-counted. By that point my brain was too fried to keep going, so I'm leaving this hole dug for now and will fill it in when I get the chance...
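A minimal repro of the double-counting, using the agc class listed below:

t = agc(2.0) + agc(1.0)  # t has parents, so t.backward() propagates further down
f = t + t                # t is consumed twice
f.backward()
# correct: each of t's parents should end up with grad 2
# actual:  the first t.backward(1) pushes 1 downstream, then the second call
#          pushes t's accumulated grad (1 + 1 = 2) downstream again,
#          so each parent ends up with 1 + 2 = 3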

The code follows.
agnum:

import numpy as np

class agnum():
    def __init__(self, val, dtype=None) -> None:
        self.val = np.array(val, dtype)
        self.ops = [0, 0, 0]
        self.dtype = dtype

    def __num_check(self, num):
        if isinstance(num, int):
            return agnum(num, 'int32')
        if isinstance(num, float):
            return agnum(num, 'float32')
        if isinstance(num, np.ndarray):
            return agnum(num, num.dtype)
        return num

    def __contains__(self, operand): # operand in self
        if self is operand: return True
        if self.ops[0] == 0: return False
        _, a, b = self.ops
        return (operand in a) or (operand in b)

    def __mul__(self, operand):
        operand = self.__num_check(operand)
        t = self.val * operand.val
        r = agnum(t)
        r.ops = ['mul', self, operand]
        return r

    def __add__(self, operand):
        if isinstance(operand, (int, float)) and operand == 0: return self # '== 0' on an agnum/ndarray would misbehave here
        operand = self.__num_check(operand)
        t = self.val + operand.val
        r = agnum(t)
        r.ops = ['add', self, operand]
        return r

    def __sub__(self, operand):
        if isinstance(operand, (int, float)) and operand == 0: return self # '== 0' on an agnum/ndarray would misbehave here
        operand = self.__num_check(operand)
        t = self.val - operand.val
        r = agnum(t)
        r.ops = ['sub', self, operand]
        return r
    
    def __pow__(self, operand):
        operand = self.__num_check(operand)
        t = self.val ** operand.val
        r = agnum(t)
        r.ops = ['pow', self, operand]
        return r

    def dot(self, operand):
        operand = self.__num_check(operand)
        t = np.dot(self.val, operand.val)
        r = agnum(t)
        r.ops = ['dot', self, operand]
        return r

    def T(self):
        t = self.val.T
        r = agnum(t)
        r.ops = ['T', self, agnum(0, self.dtype)]
        return r

    def diff(self, target):
        if self is target: # d(target)/d(target) = 1
            return agnum(1, self.dtype)
        if self.ops[0] == 0: # constant leaf: no parents recorded
            return agnum(0, self.dtype)
        op, a, b = self.ops
        ahasVar = target in a # does each operand's subtree contain the target?
        bhasVar = target in b
        if not (ahasVar or bhasVar): # neither a nor b contains the target
            return agnum(0, self.dtype)
        # differentiation rules
        if op == 'add':
            return a.diff(target) + b.diff(target)
        if op == 'sub':
            return a.diff(target) - b.diff(target)
        if op == 'mul':
            return a.diff(target) * b + a * b.diff(target)
        if op == 'dot':
            return a.diff(target).dot(b) + a.dot(b.diff(target))
        if op == 'pow':
            # check the two-variable case first, otherwise it is unreachable
            if ahasVar and bhasVar: # both depend on target: d(a^b) = a^b * (b' ln(a) + b a'/a)
                return a**b * (b.diff(target) * np.log(a.val) + b * a.diff(target) * (a.val ** -1.0))
            if ahasVar: # only the base depends on target
                return b * a**(b-1) * a.diff(target)
            return a**b * np.log(a.val) * b.diff(target) # only the exponent depends on target
        if op == 'T': # for a transpose node, operand a holds the pre-transpose value
            return a.diff(target).T() # transpose the derivative back to match a's shape
    
    def copy(self):
        r = agnum(self.val, self.dtype)
        r.ops = self.ops.copy()
        return r


if __name__ == "__main__":
    a = agnum([3, 5])
    b = agnum([1.2, 1])
    c = a ** b + b
    print(c.diff(a).val)
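As a sanity check: c = a**b + b differentiates to dc/da = b * a**(b-1), so this should print approximately [1.4949 1.].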

agc:

import numpy as np
import copy
 
class agc:
    '''automatic differentiation'''
    def __init__(self, value, nodes=None, opt=None, cut=None):
        if isinstance(value, list):
            value = np.array(value)
        self.value = value      # value held at this node
        self.nodes = nodes if nodes is not None else [] # the two nodes that produced this one (default avoids a shared mutable list)
        self.opt = opt          # the operation that combined the two child nodes
        self.grad = 0           # gradient of the output w.r.t. this node
        self.cut = cut          # slice index, used by __getitem__

    def __inpChecker(self, num):
        if isinstance(num, (int, float)):
            return agc(num)
        if isinstance(num, np.ndarray):
            return agc(num)
        return num

    def __add__(self, other):
        other = self.__inpChecker(other)
        return agc(value=self.value+other.value, nodes=[self, other], opt='+')
    def __sub__(self, other):
        other = self.__inpChecker(other)
        return agc(value=self.value-other.value, nodes=[self, other], opt='-')
    def __mul__(self, other):
        other = self.__inpChecker(other)
        return agc(value=self.value*other.value, nodes=[self, other], opt='*')
    def __truediv__(self, other):
        other = self.__inpChecker(other)
        return agc(value=self.value/other.value, nodes=[self, other], opt='/')
    def __pow__(self, other):
        other = self.__inpChecker(other)
        return agc(value=self.value**other.value, nodes=[self, other], opt='**')
    def dot(self,other):
        other = self.__inpChecker(other)
        return agc(value=np.dot(self.value, other.value), nodes=[self, other], opt='dot')
    def sigmoid(self):
        return agc(value=1/(1+np.exp(-self.value)), nodes=[self, 0], opt='sigmoid')
    def T(self):
        return agc(value=self.value.T, nodes=[self, 0], opt='T')

    def __getitem__(self, index):
        # print(index)
        return agc(value=self.value[index], nodes=[self, 0], opt='cut', cut=index)
    def __setitem__(self, index, operand):
        # note: assignment only updates the stored value; no gradient is
        # recorded, part of the unfinished indexing support mentioned above
        if isinstance(operand, agc):
            value = operand.value
        else:
            value = operand
        self.value[index] = value

    def backward(self, backward_grad=None):
        if backward_grad is None: # output node: seed the gradient with ones
            if type(self.value) in [int, float]:
                self.grad = 1
            else: # ndarray
                self.grad = np.ones(self.value.shape)
        else: # accumulate; the running total self.grad is what gets pushed
              # downstream below, which is the source of the double-counting
            self.grad += backward_grad

        if self.opt is None:
            return

        a, b = self.nodes

        if self.opt == 'cut': # route the gradient back into the sliced positions; everything outside the slice stays zero
            gd = np.zeros(a.value.shape)
            gd[self.cut] = self.grad
            a.backward(gd)

        elif self.opt == '+': # a + b
            a.backward(self.grad)
            b.backward(self.grad) 

        elif self.opt == '-': # a - b
            a.backward(self.grad)
            b.backward(-self.grad)

        elif self.opt == '*': # a * b
            gd = self.grad * a.value
            b.backward(gd) # d(a*b)/db = a
            gd = self.grad * b.value
            a.backward(gd) # d(a*b)/da = b

        elif self.opt == '/': # a / b
            gd = self.grad * (1/(b.value+1e-15)) # 1e-15 guards against division by zero
            a.backward(gd) # d(a/b)/da = 1/b
            gd = self.grad * (-a.value*b.value**(-2))
            b.backward(gd) # d(a/b)/db = -a/b^2

        elif self.opt == '**': # a^b
            gd = self.grad * (a.value**b.value*np.log(a.value+1e-15)) # 1e-15 keeps the log finite
            b.backward(gd) # d(a^b)/db = a^b ln(a)
            gd = self.grad * (b.value*a.value**(b.value-1))
            a.backward(gd) # d(a^b)/da = b a^(b-1)

        elif self.opt == 'dot': # self = a.dot(b)
            gd = np.dot(self.grad, b.value.T)
            a.backward(gd) # dL/da = grad . b^T
            gd = np.dot(a.value.T, self.grad)
            b.backward(gd) # dL/db = a^T . grad

        elif self.opt == 'sigmoid':
            gd = self.grad*(1/(1+np.exp(-a.value)))*(1-1/(1+np.exp(-a.value)))
            a.backward(gd)

        elif self.opt == 'T': # y = a.T, so dL/da = (dL/dy).T
            a.backward(self.grad.T)

    def zero_grad(self):
        # clears gradients and also tears down the recorded graph,
        # detaching children so the next forward pass starts fresh
        self.grad = 0
        for n in self.nodes:
            if isinstance(n, agc):
                n.zero_grad()
                n.nodes = []
                n.opt = None

    def copy(self):
        return copy.deepcopy(self)

    @property
    def shape(self):
        return self.value.shape

if __name__=='__main__':
    ao = np.random.rand(5, 5)
    bo = np.random.rand(5, 5)
    co = np.random.rand(5, 5)
    a = agc(ao)
    b = agc(bo)
    c = agc(co)
    f1 = a*b
    f2 = f1.dot(c)
    f3 = f2[0][0] + f2[0][1] # bug: f2 is consumed twice, so its gradient gets propagated twice and double-counted
    f3.backward()
    print(a.grad, b.grad, c.grad, sep='\n-------\n')
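For the record, here is a sketch of one way the double-counting could be fixed (my own idea, not taken from the original article): collect the graph in topological order first, then let every node push its final accumulated gradient downstream exactly once. backward_fixed below is a hypothetical helper; it only dispatches the ops the example above uses ('+', '*', 'dot', 'cut'), and the remaining rules from agc.backward would carry over the same way.

def backward_fixed(root):
    # 1. topological order: children are appended before their consumers
    order, seen = [], set()
    def topo(n):
        if id(n) in seen:
            return
        seen.add(id(n))
        for c in n.nodes:
            if isinstance(c, agc):
                topo(c)
        order.append(n)
    topo(root)

    for n in order: # clear stale gradients from earlier passes
        n.grad = 0
    root.grad = np.ones(root.value.shape) if isinstance(root.value, np.ndarray) else 1.0

    # 2. walk consumers-first; each node pushes its final grad downstream once
    for n in reversed(order):
        if n.opt is None:
            continue
        a, b = n.nodes
        if n.opt == '+':
            a.grad += n.grad
            b.grad += n.grad
        elif n.opt == '*':
            a.grad += n.grad * b.value
            b.grad += n.grad * a.value
        elif n.opt == 'dot': # n = a.dot(b)
            a.grad += np.dot(n.grad, b.value.T)
            b.grad += np.dot(a.value.T, n.grad)
        elif n.opt == 'cut':
            gd = np.zeros(a.value.shape)
            gd[n.cut] = n.grad
            a.grad += gd

Calling backward_fixed(f3) in place of f3.backward() then counts each slice's contribution to f2 exactly once.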