马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。
您需要 登录 才可以下载或查看,没有账号?我要加入
x
- clc
- clear
- %step 1=========================
- %定义输入样本;
- t=0:0.01:1.5;
- x=-sin(2*pi*t);
- targ=[0 0 1 1 0 0 ];
- eta=0.02;aerfa=0.935;
- %初始化连接权wjh(输出层和隐层的连接权);whi(隐层和输出层的连接权);
- %假设小波函数节点数为:H个;样本数为P;
- %输出节点数为:J个;输入节点数为:I个;
- H=15;P=2;
- I=length(t);
- J=length(targ);
- %初始化小波参数
- b=rand(H,1);
- a=rand(H,1);
- %初始化权系数;
- whi=rand(I,H);
- wjh=rand(H,J);
- %阈值初始化;
- b1=rand(H,1);
- b2=rand(J,1);
- p=0;
- %保存的误差;
- Err_NetOut=[];
- flag=1;count=0;
- while flag>0
- flag=0;
- count=count+1;
- %step 2=================================
- xhp1=0;
- for h=1:H
- for i=1:I
- xhp1=xhp1+whi(i,h)*x(i);
- end
- ixhp(h)=xhp1+b1(h);
- xhp1=0;
- end
- for h=1:H
- oxhp(h)=fai((ixhp(h)-b(h))/a(h));
- end
- %step 3====================================
- ixjp1=0;
- for j=1:J
- for h=1:H
- ixjp1=ixjp1+wjh(h,j)*oxhp(h);
- end
- ixjp(j)=ixjp1+b2(j);
- ixjp1=0;
- end
- for i=1:J
- oxjp(i)=fnn(ixjp(i));
- end
- %step 6==保存每次误差=====
- wuchayy=1/2*sumsqr(oxjp-targ);
- %E_x=1/2*sumsqr(x);
- Err_NetOut=[Err_NetOut wuchayy];%保存每次的误差;
- %Err_rate=Err_NetOut/E_x;
- %Err_rate
- %oxjp
- %求detaj ,detab2==================================
- for j=1:J
- detaj(j)=-(oxjp(j)-targ(j))*oxjp(j)*(1-oxjp(j));
- end
- for j=1:J
- for h=1:H
- detawjh(h,j)=eta*detaj(j)*oxhp(h);
- end
- end
- detab2=eta*detaj;
- %求detah, detawhi detab1 detab detaa;========================
- sum=0;
- for h=1:H
- for j=1:J
- sum=detaj(j)*wjh(h,j)*diffai((ixhp(h)-b(h))/a(h))/a(h)+sum;
- end
- detah(h)=sum;
- sum=0;
- end
- for h=1:H
- for i=1:I
- detawhi(i,h)=eta*detah(h)*x(i);
- end
- end
- detab1=eta*detah;
- detab=-eta*detah;
- for h=1:H
- detaa(h)=-eta*detah(h)*((ixhp(h)-b(h))/a(h));
- end
- %引入动量因子aerfa,修正各个系数==========================================
- wjh=wjh+(1+aerfa)*detawjh;
- whi=whi+(1+aerfa)*detawhi;
- a=a+(1+aerfa)*detaa';
- b=b+(1+aerfa)*detab';
- b1=b1+(1+aerfa)*detab1';
- b2=b2+(1+aerfa)*detab2';
- %======================================================
- %引入修正算法!!
- %判断所有的样本是否计算完==================================
- p=p+1;
- if p~=P
- flag=flag+1;
- else
- if Err_NetOut(end)>0.05
- flag=flag+1;
- else
- figure;
- plot(Err_NetOut);
- title('误差曲线');
- disp('目标达到');
- %disp(oxjp);
- end
- end
- if count>2000
- figure;
- plot(Err_NetOut);
- title('误差曲线');
- disp('目标未达到');
- disp(oxjp);
- break;
- end
- end
复制代码
这里还需要定义三个函数
1. diffai.m
- function y3=diffai(x);
- y3=-0.75*sin(1.75*x)*exp(-x.^2/2)-cos(1.75*x)*exp(-x.^2/2)*x;
复制代码
2. fai.m
- function yl=fai(x)
- yl=cos(0.75.*x)+exp(-x.^2/2);
复制代码
3. fnn.m
- function y2=fnn(x)
- y2=1/(1+exp(-x));
复制代码 |