mehdi_k Ответов: 0

Простая полностью сверточная сеть в matlab не обучается


Я пытаюсь реализовать простую полностью сверточную сеть в matlab с 3 слоями свертки (это просто для изучения и понимания FCN), я использую функции фильтра Matlab (такие как imfilter , filter2 и ...) и градиентный спуск в качестве алгоритма обучения. Я использовал как сигмовидную, так и Релу в качестве функции активации, но проблема заключается в том, что она не учится, не имеет значения, какую функцию или скорость обучения я использую.

Моя реализация до сих пор :

Что я уже пробовал:

clear
clc

%% image read and convert to gray scale
RGBimg0 = imread('0.jpg');
RGBimg1 = imread('1.jpg');
RGBimg00 = imread('00.jpg');
RGBimg11 = imread('11.jpg');
RGBimg000 = imread('000.jpg');
RGBimg111 = imread('111.jpg');

img = {double(rgb2gray(RGBimg0)) double(rgb2gray(RGBimg1)) ...
       double(rgb2gray(RGBimg00)) double(rgb2gray(RGBimg11)) ...
       double(rgb2gray(RGBimg000)) double(rgb2gray(RGBimg111))};

img{1} = imresize(img{1},[256 256]);
img{2} = imresize(img{2},[256 256]);
img{3} = imresize(img{3},[256 256]);
img{4} = imresize(img{4},[256 256]);
img{5} = imresize(img{5},[256 256]);
img{6} = imresize(img{6},[256 256]);



d0 = double(rgb2gray(imread('d0.jpg')));
d00 = double(rgb2gray(imread('d00.jpg')));
d000 = double(rgb2gray(imread('d000.jpg')));
d1 = double(rgb2gray(imread('d1.jpg')));
d11 = double(rgb2gray(imread('d11.jpg')));
d111 = double(rgb2gray(imread('d111.jpg')));

desired = {d0 d1 d00 d11 d000 d111};   

desired{1} = imresize(desired{1},[256 256]);
desired{2} = imresize(desired{2},[256 256]);
desired{3} = imresize(desired{3},[256 256]);
desired{4} = imresize(desired{4},[256 256]);
desired{5} = imresize(desired{5},[256 256]);
desired{6} = imresize(desired{6},[256 256]);


%Normalise image (0 to 1)
for i=1 : 6
    a = img{i}(:);
    range = max(a) - min(a);
    normm = (a - min(a)) / range;
    img{i} = reshape(normm,256,256);
end

%Normalise desired (0 to 1)
for i=1 : 6
    a = desired{i}(:);
    range = max(a) - min(a);
    normm = (a - min(a)) / range;
    desired{i} = reshape(normm,256,256);
end


%% Convolution initialize

%filter definitions (randiom 3x3, range 0 to 1)
f11 = (-1 + (1+1)*rand(3));
f12 = (-1 + (1+1)*rand(3));
f13 = (-1 + (1+1)*rand(3));

f211 = (-1 + (1+1)*rand(3));
f212 = (-1 + (1+1)*rand(3));
f213 = (-1 + (1+1)*rand(3));
f221 = (-1 + (1+1)*rand(3));
f222 = (-1 + (1+1)*rand(3));
f223 = (-1 + (1+1)*rand(3));
f231 = (-1 + (1+1)*rand(3));
f232 = (-1 + (1+1)*rand(3));
f233 = (-1 + (1+1)*rand(3));


f311 = (-1 + (1+1)*rand(3));
f312 = (-1 + (1+1)*rand(3));
f313 = (-1 + (1+1)*rand(3));
f321 = (-1 + (1+1)*rand(3));
f322 = (-1 + (1+1)*rand(3));
f323 = (-1 + (1+1)*rand(3));
f331 = (-1 + (1+1)*rand(3));
f332 = (-1 + (1+1)*rand(3));
f333 = (-1 + (1+1)*rand(3));

f41 = (-1 + (1+1)*rand(3));
f42 = (-1 + (1+1)*rand(3));
f43 = (-1 + (1+1)*rand(3));


imgOut = zeros(256,256,6);

learnRate = 0.00000000001;

reluSlop = 0.9;

maxIt = 2000000;

err = zeros(35,1);
ETotal = zeros(maxIt,1);


%% Main Loop
for iteration = 1 : maxIt

    for imgCount = 1 : 1 %size(img,2)

        %% Covolution
        % convolve 1
        activeMap11 = imfilter(img{imgCount},f11,0);
        activeMap12 = imfilter(img{imgCount},f12,0);
        activeMap13 = imfilter(img{imgCount},f13,0);

        % relu 1
        activeMap11 = Relu(activeMap11, reluSlop, false);
        activeMap12 = Relu(activeMap12, reluSlop, false);
        activeMap13 = Relu(activeMap13, reluSlop, false);

        % ---------------------------
        % convolve 2
        activeMap21 = imfilter(activeMap11,f211,0) + imfilter(activeMap12,f221,0) + imfilter(activeMap13,f231,0); 
        activeMap22 = imfilter(activeMap11,f212,0) + imfilter(activeMap12,f222,0) + imfilter(activeMap13,f232,0); 
        activeMap23 = imfilter(activeMap11,f213,0) + imfilter(activeMap12,f223,0) + imfilter(activeMap13,f233,0); 

        % relu 2
        activeMap21 = Relu(activeMap21, reluSlop, false);
        activeMap22 = Relu(activeMap22, reluSlop, false);
        activeMap23 = Relu(activeMap23, reluSlop, false);

        % ---------------------------
        % convolve 3
        activeMap31 = imfilter(activeMap21,f311,0) + imfilter(activeMap22,f321,0) + imfilter(activeMap23,f331,0); 
        activeMap32 = imfilter(activeMap21,f312,0) + imfilter(activeMap22,f322,0) + imfilter(activeMap23,f332,0); 
        activeMap33 = imfilter(activeMap21,f313,0) + imfilter(activeMap22,f323,0) + imfilter(activeMap23,f333,0); 

        % relu 3
        activeMap31 = Relu(activeMap31, reluSlop, false);
        activeMap32 = Relu(activeMap32, reluSlop, false);
        activeMap33 = Relu(activeMap33, reluSlop, false);

        % ---------------------------

        % convolve 4
        activeMap4 = imfilter(activeMap31,f41,0) + imfilter(activeMap32,f42,0) + imfilter(activeMap33,f43,0); 

        % relu 4
        activeMap4 = Relu(activeMap4, reluSlop, false);

        imgOut(:,:,imgCount) = activeMap4;


        %% Backpropagation

        errMat = 0.5 .* ((desired{imgCount} - activeMap4).^2);
        err(imgCount) = sum(errMat(:));


        %filter deltas----------------------------------------------
        error = desired{imgCount} - activeMap4;

        delta_activeMap4 = error .* Relu(activeMap4, reluSlop, true);

        delta_activeMap31 = imfilter(delta_activeMap4,f41,0);
        delta_activeMap32 = imfilter(delta_activeMap4,f42,0);
        delta_activeMap33 = imfilter(delta_activeMap4,f43,0);

        delta_activeMap31 = delta_activeMap31 .* Relu(activeMap31, reluSlop, true);
        delta_activeMap32 = delta_activeMap32 .* Relu(activeMap32, reluSlop, true);
        delta_activeMap33 = delta_activeMap33 .* Relu(activeMap33, reluSlop, true);


        delta_activeMap21 = imfilter(delta_activeMap31,f311,0) + imfilter(delta_activeMap32,f312,0) + imfilter(delta_activeMap33,f313,0);
        delta_activeMap22 = imfilter(delta_activeMap31,f321,0) + imfilter(delta_activeMap32,f322,0) + imfilter(delta_activeMap33,f323,0);
        delta_activeMap23 = imfilter(delta_activeMap31,f331,0) + imfilter(delta_activeMap32,f332,0) + imfilter(delta_activeMap33,f333,0);

        delta_activeMap21 = delta_activeMap21 .* Relu(activeMap21, reluSlop, false);
        delta_activeMap22 = delta_activeMap22 .* Relu(activeMap22, reluSlop, false);
        delta_activeMap23 = delta_activeMap23 .* Relu(activeMap23, reluSlop, false);


        delta_activeMap11 = imfilter(delta_activeMap21,f211,0) + imfilter(delta_activeMap22,f212,0) + imfilter(delta_activeMap23,f213,0);
        delta_activeMap12 = imfilter(delta_activeMap21,f221,0) + imfilter(delta_activeMap22,f222,0) + imfilter(delta_activeMap23,f223,0);
        delta_activeMap13 = imfilter(delta_activeMap21,f231,0) + imfilter(delta_activeMap22,f232,0) + imfilter(delta_activeMap23,f233,0);

        delta_activeMap11 = delta_activeMap11 .* Relu(activeMap11, reluSlop, false);
        delta_activeMap12 = delta_activeMap12 .* Relu(activeMap12, reluSlop, false);
        delta_activeMap13 = delta_activeMap13 .* Relu(activeMap13, reluSlop, false);




        activeMap31 = padarray(activeMap31,[1 1]);
        activeMap32 = padarray(activeMap32,[1 1]);
        activeMap33 = padarray(activeMap33,[1 1]);

        activeMap21 = padarray(activeMap21,[1 1]);
        activeMap22 = padarray(activeMap22,[1 1]);
        activeMap23 = padarray(activeMap23,[1 1]);

        activeMap11 = padarray(activeMap11,[1 1]);
        activeMap12 = padarray(activeMap12,[1 1]);
        activeMap13 = padarray(activeMap13,[1 1]);

        paddedImg1 = padarray(img{imgCount},[1 1]);
        paddedImg2 = padarray(img{imgCount},[1 1]);
        paddedImg3 = padarray(img{imgCount},[1 1]);


        %Weight update
        f41 = f41 + (learnRate .* (filterDelta4(activeMap31, delta_activeMap4)));
        f42 = f41 + (learnRate .* (filterDelta4(activeMap32, delta_activeMap4)));
        f43 = f41 + (learnRate .* (filterDelta4(activeMap33, delta_activeMap4)));

        f311 = f311 + (learnRate .* (filterDelta4(activeMap21, delta_activeMap31)));
        f321 = f321 + (learnRate .* (filterDelta4(activeMap22, delta_activeMap31)));
        f331 = f331 + (learnRate .* (filterDelta4(activeMap23, delta_activeMap31)));
        f312 = f312 + (learnRate .* (filterDelta4(activeMap21, delta_activeMap32)));
        f322 = f322 + (learnRate .* (filterDelta4(activeMap22, delta_activeMap32)));
        f332 = f332 + (learnRate .* (filterDelta4(activeMap23, delta_activeMap32)));
        f313 = f313 + (learnRate .* (filterDelta4(activeMap21, delta_activeMap33)));
        f323 = f323 + (learnRate .* (filterDelta4(activeMap22, delta_activeMap33)));
        f333 = f333 + (learnRate .* (filterDelta4(activeMap23, delta_activeMap33)));

        f211 = f211 + (learnRate .* (filterDelta4(activeMap11, delta_activeMap21)));
        f221 = f221 + (learnRate .* (filterDelta4(activeMap12, delta_activeMap21)));
        f231 = f231 + (learnRate .* (filterDelta4(activeMap13, delta_activeMap21)));
        f212 = f212 + (learnRate .* (filterDelta4(activeMap11, delta_activeMap22)));
        f222 = f222 + (learnRate .* (filterDelta4(activeMap12, delta_activeMap22)));
        f232 = f232 + (learnRate .* (filterDelta4(activeMap13, delta_activeMap22)));
        f213 = f213 + (learnRate .* (filterDelta4(activeMap11, delta_activeMap23)));
        f223 = f223 + (learnRate .* (filterDelta4(activeMap12, delta_activeMap23)));
        f233 = f233 + (learnRate .* (filterDelta4(activeMap13, delta_activeMap23)));

        f11 = f11 + (learnRate .* (filterDelta4(img{imgCount}, delta_activeMap11)));
        f12 = f12 + (learnRate .* (filterDelta4(img{imgCount}, delta_activeMap12)));
        f13 = f13 + (learnRate .* (filterDelta4(img{imgCount}, delta_activeMap13)));





        fprintf('[%f]-', err(imgCount));
    end

    fprintf('%i \n', iteration);

    ETotal(iteration) = err(1);

    imshow([imgOut(:,:,1)],[])

end

plot(ETotal);

0 Ответов