source: trunk/toolbox/MLPfit.m @ 35

Last change on this file since 35 was 4, checked in by pinsard, 15 years ago

dealing with CR/LF and encoding

File size: 7.8 KB
Line 
1function  [w1,w2,errtot,rx] = MLPfit(Xv,Yv,Xa,Ya,w1,w2,F1,F2,OptimArg,Visu)
2
3%MLPfit performs iterations of the MLP by the Harris method
4
5%
6
7%       [W1,W2,errtot,Yr] = MLPfit(Xv,Yv,Xa,Ya,W1,W2,F1,F2,OptimArg,Visu)
8
9%
10
11% Xv  input validation data
12
13% Yv  output validation data
14
15% Xa the input learning data
16
17% Ya the output leraning data
18
19% W1 the initial parameter matrix from the input to the hidden layer
20
21% W2 the initial parameter matrix from the hidden layer to the output
22
23% F1 (default 'tah') the activation function of the hidden layer units.
24
25%    Choose in {'tah', 'sig'}
26
27% F2 (default 'lin') the activation function of the output layer units
28
29%    Choose in {'tah', 'sig', 'lin'}
30
31%
32
33% OptimArg (default [1000, 0.000001]) Optimisation parameters: NbIter or
34
35%         [NbIter, Threshold]
36
37% Visu (default [ -, 1]) Visulalisation parameters: DisplayFrequency or
38
39%          [DisplayFrequency, DisplaYdevice]
40
41%
42
43% W1 and W2 Final parameter matrix
44
45% errtot the training error
46
47% Yr the the output of the network at point x
48
49%
50
51% modified on Nov. 15th, 1999 (Evry, France)
52
53
54
55if nargin < 6;
56
57   help MLPfit
58
59   error(sprintf('\n *** MLPfit error: invalid call***\n\n\t[W1,W2,errtot,Yr] =MLPfit(Xa,Ya,W1,W2,F1,F2,OptimArg,Visu);\n\n'));
60
61end;
62
63
64
65nbre_err=10 ;%nbre d'erreur de suites maxi
66i_err=0; %compteur du nbre d'erreur
67
68% Optimization control parameters
69
70
71
72nbitemax = 1000; % default values
73
74seuil = 10^(-6);
75
76
77
78if nargin >= 9
79
80        nbitemax = OptimArg(1);
81
82        lO = length(OptimArg);
83
84        if lO > 1
85
86                seuil = OptimArg(2);
87
88        end;
89
90end;
91
92
93
94%%plot(Xa,MLPval(Xa,w1,w2,F1,F2),'b-',X,Yb,'r-')
95
96%plot(Xa,MLPval(Xa,w1,w2,F1,F2),'b-')
97
98
99
100%pause(0.1)
101
102
103
104
105
106% Optimisation Process monitoring visualisation parameters
107
108
109
110df = nbitemax+2;                % Defaults is NO output
111
112fid = 1;                            % default is standart output               
113
114if nargin == 10
115
116 df = Visu(1);                  % display frequency
117
118 if df == 0; df =  nbitemax+2; end;
119
120 lV = length(Visu);
121
122 if lV > 1
123
124        fid = Visu(2);      % Display Device
125
126 end;
127
128end;
129
130
131
132% Check for F1 and F2
133
134
135
136  if (strcmp(lower(F1),'tah') == 0) & (strcmp(lower(F1),'sig') == 0)
137
138    F1 = 'tah'; disp('Unknown activation function 1. - Set as hyperbolictangent');
139
140  end
141
142
143
144  if (strcmp(lower(F2),'tah') == 0) & (strcmp(lower(F2),'sig') == ...
145                                       0) & (strcmp(lower(F2),'lin') ...
146                                             == 0) & (strcmp(lower(F2),'exp')==0)
147
148    F2 = 'lin'; disp('Unknown output activation function - Set as linear');
149
150  end
151
152
153
154% initialisation
155
156
157
158  ell=size(Xa,1);
159
160  onell = ones(ell,1);
161
162  [t nout] = size(w2);
163
164  errold = 1e16;
165
166  nbite = 0;
167
168  w1p = w1;               % sauvegardes
169
170  w2p = w2;
171
172  w1pp = zeros(size(w1));                 % sauvegardes
173
174  w2pp = zeros(size(w2));
175
176  gradw1p = zeros(size(w1));
177
178  gradw2p = zeros(size(w2));
179
180  descw1 = gradw1p;
181
182  descw2 = gradw2p;
183
184  pas1 = 0.1 * ones(size(w1))*2/ell;
185
186  pas2 = 0.1 * ones(size(w2))*2/ell;
187
188  dim = 0.5;              % diminution du pas en cas d'augmentation de l'erreur
189
190  a = 1.5;
191
192  b = 1/a;
193
194  alpha = .9/(1+.9);                    % momentum
195
196
197
198 lambda = 0.*ones(size(w2));
199
200 lambda(t,:) = zeros(1,nout);
201
202  errtmin = 10000000000000000000000;
203
204
205
206  ovf1 = 1e3 * ones(size(w1));
207
208  ovf2 = 1e3 * ones(size(w2));
209
210
211
212  errtot = [];                          % memoire erreur pour sortie
213
214
215
216% boucle principale
217
218
219
220continumongars=1;
221
222
223
224while (nbite < nbitemax) & continumongars ;  %
225 
226 
227 
228  %Prop
229 
230 
231 
232  a1 = [Xa onell]*w1;
233 
234 
235 
236  if strcmp(lower(F1),'tah')
237   
238    x1 = tanh(a1);
239   
240  elseif strcmp(lower(F1),'sig')
241   
242    x1 = phi(a1);
243   
244  else
245   
246    error('Unknown activation function 1.')
247   
248  end
249 
250 
251 
252  a2 = [x1 onell]*w2;
253 
254 
255 
256  if strcmp(lower(F2),'tah')
257   
258    y = tanh(a2);
259   
260  elseif strcmp(lower(F2),'sig')
261   
262    y = phi(a2);
263   
264  elseif strcmp(lower(F2),'lin')
265   
266    y = a2;
267   
268   
269  elseif strcmp(lower(F2),'exp')
270   
271    y = exp(a2);
272  else
273   
274    error('Unknown output activation function')
275   
276  end
277 
278 
279 
280  err =  (y - Ya);
281 
282  errnew = sum(sum(err'.*err')) + sum(sum(lambda.*w2.^2));
283 
284 
285  if ~isempty(Xv)
286    Yvi=MLPval(Xv,w1,w2,F1,F2);
287    err2=(Yv-Yvi);
288    errval = sum(sum(err2'.*err2'));
289    % keyboard
290  else
291    errval=0;
292  end
293 
294  errtotnew=[errnew,errval];
295 
296  errtot = [errtot ;errtotnew];
297 
298 
299 
300 
301
302  if errnew >= errold ;                 % ca merde => un pas en arriere
303   
304   
305    i_err=i_err+1;
306   
307    if i_err>nbre_err
308      disp('erreur dans l''apprentissage')
309      rx = y;
310      return
311    end
312   
313    w1 = w1p ;                         
314   
315    w2 = w2p ;
316   
317    gradw1 = gradw1p ;
318   
319    gradw2 = gradw2p ;
320   
321    descw1 = gradw1;
322   
323    descw2 = gradw2;
324   
325    pas1 = dim * pas1 ;
326   
327    pas2 = dim * pas2 ;
328   
329    w1 = w1 - pas1 .* gradw1;
330   
331    w2 = w2 - pas2 .* gradw2;
332   
333   
334   
335  else                                  % ca marche => un pas en avant
336   
337   
338    i_err=0;
339    % BP
340   
341   
342   
343    dJdy = 2*err;
344   
345   
346   
347    if strcmp(lower(F2),'tah')
348     
349      dJda2 = dJdy.*(1-y.*y);
350     
351    elseif strcmp(lower(F2),'sig')
352     
353      dJda2 = dJdy.*(y-y.*y);
354     
355    elseif strcmp(lower(F2),'lin')
356     
357      dJda2 = dJdy;
358     
359    elseif strcmp(lower(F2),'exp')
360     
361      dJda2 = dJdy.*y;
362     
363    else
364
365      error('Unknown output activation function')
366     
367    end
368   
369   
370   
371    gradw2 = (([x1 onell]'*dJda2) + (2*lambda.*w2))./ell;
372   
373   
374   
375    if strcmp(lower(F1),'tah')
376     
377      dJdx1 =  (w2(1:t-1,:) * dJda2')' .*(1-x1.*x1);
378     
379    elseif strcmp(lower(F1),'sig')
380     
381      dJdx1 =  (w2(1:t-1,:) * dJda2')' .*(x1-x1.*x1);
382
383    else
384     
385      error('Unknown activation function 1.')
386     
387    end
388   
389   
390
391    gradw1 = [Xa onell]' * dJdx1  ./ell;
392   
393   
394   
395    errold = errnew;                   
396   
397    w1p = w1 ;
398   
399    w2p = w2 ;
400   
401    gradw1p = gradw1 ;
402   
403    gradw2p = gradw2 ;
404   
405    descw1 = (1-alpha) * gradw1 + alpha * descw1;
406   
407    descw2 = (1-alpha) * gradw2 + alpha * descw2;
408   
409   
410   
411    test1 = (gradw1 .* descw1) >= 0;
412   
413    pas1 = ((test1 * a) + ((~test1) * b)) .* pas1;
414   
415    pas1 = (pas1 <= ovf1) .* pas1 + (pas1>ovf1) .* ovf1;
416   
417   
418   
419    test2 = (gradw2 .* descw2) >= 0;
420   
421    pas2 = ((test2 * a) + ((~test2) * b)) .* pas2;
422   
423    pas2 = (pas2 <= ovf2) .* pas2 + (pas2>ovf2) .* ovf2;
424   
425   
426   
427    w1 = w1 - pas1 .* descw1;
428   
429    w2 = w2 - pas2 .* descw2;
430   
431   
432   
433    nbite = nbite + 1;       % il n'y a que les bonnes iterations qui comptent
434   
435   
436   
437    if (rem(nbite,df) == 0),
438     
439      if (rem(nbite,20*df) == df),
440       
441       
442       
443        disp(['|  #epoch | Errorapp|Errorval| Mean Gw1 |  Max Gw1 | Mean Gw2 |  Max Gw2 |Max: ' num2str(nbitemax)])
444       
445       
446       
447       
448       
449      end
450
451     
452     
453      fprintf(fid,'| %7.0f | %7.4f  | %7.4f | %6.6f | %6.6f | %6.6f | %6.6f|\n',[nbite errnew/size(Xa,1)  ...
454                          errval/size(Xv,1),mean(mean(abs(gradw1))) max(max(abs(gradw1))) mean(mean(abs(gradw2))) max(max(abs(gradw2))) ])
455     
456     
457     
458      %figure(2)
459        %hold on
460        %%plotplot(Xa,MLPval(Xa,w1,w2,F1,F2),'b-',X,Yb,'r-')
461        %plot(Xa,MLPval(Xa,w1,w2,F1,F2),'b-')
462        %S=sprintf('en rouge fonction utilisée pour simuler les données en bleu courbe apprise par le MLP');
463        %title(S);
464     
465      %pause(0.1)
466     
467     
468
469      w1min = w1;
470     
471      w2min = w2;
472     
473     
474     
475      bougeti=max([max(max((abs(w1pp-w1p)./max(abs(w1p),1)))) max(max((abs(w2pp-w2p)./max(abs(w2p),1))))]);
476     
477      if bougeti< seuil
478       
479        continumongars=0;
480       
481      end;
482     
483      %        disp([nbite/100 errnew-sum(sum(lambda.*w2.^2)) mean(mean(pas1)) bougeti sum(mean(abs(errtest)))]);
484     
485      if (rem(nbite,1000) == 0),
486       
487        w1pp=w1p;w2pp=w2p;
488       
489        pas1 = mean(mean(pas1)) * ones(size(w1))*2/ell;
490       
491        pas2 = mean(mean(pas2)) * ones(size(w2))*2/ell;
492       
493      end;
494     
495    end;
496   
497   
498    %disp(int2str(nbite))
499   
500   
501  end % fin de l'adaptation des pas.
502 
503 
504 
505 
506end;            % fin de la boucle principale
507
508
509
510%w1min = w1p;w2min = w2p;
511
512
513
514rx = y;
515
516
517
518
519
520
Note: See TracBrowser for help on using the repository browser.