Простая полностью сверточная сеть в matlab не обучается
Я пытаюсь реализовать простую полностью сверточную сеть в matlab с 3 слоями свертки (это просто для изучения и понимания FCN), я использую функции фильтра Matlab (такие как imfilter , filter2 и ...) и градиентный спуск в качестве алгоритма обучения. Я использовал как сигмовидную, так и Релу в качестве функции активации, но проблема заключается в том, что она не учится, не имеет значения, какую функцию или скорость обучения я использую.
Моя реализация до сих пор :
Что я уже пробовал:
clear clc %% image read and convert to gray scale RGBimg0 = imread('0.jpg'); RGBimg1 = imread('1.jpg'); RGBimg00 = imread('00.jpg'); RGBimg11 = imread('11.jpg'); RGBimg000 = imread('000.jpg'); RGBimg111 = imread('111.jpg'); img = {double(rgb2gray(RGBimg0)) double(rgb2gray(RGBimg1)) ... double(rgb2gray(RGBimg00)) double(rgb2gray(RGBimg11)) ... double(rgb2gray(RGBimg000)) double(rgb2gray(RGBimg111))}; img{1} = imresize(img{1},[256 256]); img{2} = imresize(img{2},[256 256]); img{3} = imresize(img{3},[256 256]); img{4} = imresize(img{4},[256 256]); img{5} = imresize(img{5},[256 256]); img{6} = imresize(img{6},[256 256]); d0 = double(rgb2gray(imread('d0.jpg'))); d00 = double(rgb2gray(imread('d00.jpg'))); d000 = double(rgb2gray(imread('d000.jpg'))); d1 = double(rgb2gray(imread('d1.jpg'))); d11 = double(rgb2gray(imread('d11.jpg'))); d111 = double(rgb2gray(imread('d111.jpg'))); desired = {d0 d1 d00 d11 d000 d111}; desired{1} = imresize(desired{1},[256 256]); desired{2} = imresize(desired{2},[256 256]); desired{3} = imresize(desired{3},[256 256]); desired{4} = imresize(desired{4},[256 256]); desired{5} = imresize(desired{5},[256 256]); desired{6} = imresize(desired{6},[256 256]); %Normalise image (0 to 1) for i=1 : 6 a = img{i}(:); range = max(a) - min(a); normm = (a - min(a)) / range; img{i} = reshape(normm,256,256); end %Normalise desired (0 to 1) for i=1 : 6 a = desired{i}(:); range = max(a) - min(a); normm = (a - min(a)) / range; desired{i} = reshape(normm,256,256); end %% Convolution initialize %filter definitions (randiom 3x3, range 0 to 1) f11 = (-1 + (1+1)*rand(3)); f12 = (-1 + (1+1)*rand(3)); f13 = (-1 + (1+1)*rand(3)); f211 = (-1 + (1+1)*rand(3)); f212 = (-1 + (1+1)*rand(3)); f213 = (-1 + (1+1)*rand(3)); f221 = (-1 + (1+1)*rand(3)); f222 = (-1 + (1+1)*rand(3)); f223 = (-1 + (1+1)*rand(3)); f231 = (-1 + (1+1)*rand(3)); f232 = (-1 + (1+1)*rand(3)); f233 = (-1 + (1+1)*rand(3)); f311 = (-1 + (1+1)*rand(3)); f312 = (-1 + (1+1)*rand(3)); f313 = (-1 + (1+1)*rand(3)); f321 = (-1 + (1+1)*rand(3)); f322 = (-1 + (1+1)*rand(3)); f323 = (-1 + (1+1)*rand(3)); f331 = (-1 + (1+1)*rand(3)); f332 = (-1 + (1+1)*rand(3)); f333 = (-1 + (1+1)*rand(3)); f41 = (-1 + (1+1)*rand(3)); f42 = (-1 + (1+1)*rand(3)); f43 = (-1 + (1+1)*rand(3)); imgOut = zeros(256,256,6); learnRate = 0.00000000001; reluSlop = 0.9; maxIt = 2000000; err = zeros(35,1); ETotal = zeros(maxIt,1); %% Main Loop for iteration = 1 : maxIt for imgCount = 1 : 1 %size(img,2) %% Covolution % convolve 1 activeMap11 = imfilter(img{imgCount},f11,0); activeMap12 = imfilter(img{imgCount},f12,0); activeMap13 = imfilter(img{imgCount},f13,0); % relu 1 activeMap11 = Relu(activeMap11, reluSlop, false); activeMap12 = Relu(activeMap12, reluSlop, false); activeMap13 = Relu(activeMap13, reluSlop, false); % --------------------------- % convolve 2 activeMap21 = imfilter(activeMap11,f211,0) + imfilter(activeMap12,f221,0) + imfilter(activeMap13,f231,0); activeMap22 = imfilter(activeMap11,f212,0) + imfilter(activeMap12,f222,0) + imfilter(activeMap13,f232,0); activeMap23 = imfilter(activeMap11,f213,0) + imfilter(activeMap12,f223,0) + imfilter(activeMap13,f233,0); % relu 2 activeMap21 = Relu(activeMap21, reluSlop, false); activeMap22 = Relu(activeMap22, reluSlop, false); activeMap23 = Relu(activeMap23, reluSlop, false); % --------------------------- % convolve 3 activeMap31 = imfilter(activeMap21,f311,0) + imfilter(activeMap22,f321,0) + imfilter(activeMap23,f331,0); activeMap32 = imfilter(activeMap21,f312,0) + imfilter(activeMap22,f322,0) + imfilter(activeMap23,f332,0); activeMap33 = imfilter(activeMap21,f313,0) + imfilter(activeMap22,f323,0) + imfilter(activeMap23,f333,0); % relu 3 activeMap31 = Relu(activeMap31, reluSlop, false); activeMap32 = Relu(activeMap32, reluSlop, false); activeMap33 = Relu(activeMap33, reluSlop, false); % --------------------------- % convolve 4 activeMap4 = imfilter(activeMap31,f41,0) + imfilter(activeMap32,f42,0) + imfilter(activeMap33,f43,0); % relu 4 activeMap4 = Relu(activeMap4, reluSlop, false); imgOut(:,:,imgCount) = activeMap4; %% Backpropagation errMat = 0.5 .* ((desired{imgCount} - activeMap4).^2); err(imgCount) = sum(errMat(:)); %filter deltas---------------------------------------------- error = desired{imgCount} - activeMap4; delta_activeMap4 = error .* Relu(activeMap4, reluSlop, true); delta_activeMap31 = imfilter(delta_activeMap4,f41,0); delta_activeMap32 = imfilter(delta_activeMap4,f42,0); delta_activeMap33 = imfilter(delta_activeMap4,f43,0); delta_activeMap31 = delta_activeMap31 .* Relu(activeMap31, reluSlop, true); delta_activeMap32 = delta_activeMap32 .* Relu(activeMap32, reluSlop, true); delta_activeMap33 = delta_activeMap33 .* Relu(activeMap33, reluSlop, true); delta_activeMap21 = imfilter(delta_activeMap31,f311,0) + imfilter(delta_activeMap32,f312,0) + imfilter(delta_activeMap33,f313,0); delta_activeMap22 = imfilter(delta_activeMap31,f321,0) + imfilter(delta_activeMap32,f322,0) + imfilter(delta_activeMap33,f323,0); delta_activeMap23 = imfilter(delta_activeMap31,f331,0) + imfilter(delta_activeMap32,f332,0) + imfilter(delta_activeMap33,f333,0); delta_activeMap21 = delta_activeMap21 .* Relu(activeMap21, reluSlop, false); delta_activeMap22 = delta_activeMap22 .* Relu(activeMap22, reluSlop, false); delta_activeMap23 = delta_activeMap23 .* Relu(activeMap23, reluSlop, false); delta_activeMap11 = imfilter(delta_activeMap21,f211,0) + imfilter(delta_activeMap22,f212,0) + imfilter(delta_activeMap23,f213,0); delta_activeMap12 = imfilter(delta_activeMap21,f221,0) + imfilter(delta_activeMap22,f222,0) + imfilter(delta_activeMap23,f223,0); delta_activeMap13 = imfilter(delta_activeMap21,f231,0) + imfilter(delta_activeMap22,f232,0) + imfilter(delta_activeMap23,f233,0); delta_activeMap11 = delta_activeMap11 .* Relu(activeMap11, reluSlop, false); delta_activeMap12 = delta_activeMap12 .* Relu(activeMap12, reluSlop, false); delta_activeMap13 = delta_activeMap13 .* Relu(activeMap13, reluSlop, false); activeMap31 = padarray(activeMap31,[1 1]); activeMap32 = padarray(activeMap32,[1 1]); activeMap33 = padarray(activeMap33,[1 1]); activeMap21 = padarray(activeMap21,[1 1]); activeMap22 = padarray(activeMap22,[1 1]); activeMap23 = padarray(activeMap23,[1 1]); activeMap11 = padarray(activeMap11,[1 1]); activeMap12 = padarray(activeMap12,[1 1]); activeMap13 = padarray(activeMap13,[1 1]); paddedImg1 = padarray(img{imgCount},[1 1]); paddedImg2 = padarray(img{imgCount},[1 1]); paddedImg3 = padarray(img{imgCount},[1 1]); %Weight update f41 = f41 + (learnRate .* (filterDelta4(activeMap31, delta_activeMap4))); f42 = f41 + (learnRate .* (filterDelta4(activeMap32, delta_activeMap4))); f43 = f41 + (learnRate .* (filterDelta4(activeMap33, delta_activeMap4))); f311 = f311 + (learnRate .* (filterDelta4(activeMap21, delta_activeMap31))); f321 = f321 + (learnRate .* (filterDelta4(activeMap22, delta_activeMap31))); f331 = f331 + (learnRate .* (filterDelta4(activeMap23, delta_activeMap31))); f312 = f312 + (learnRate .* (filterDelta4(activeMap21, delta_activeMap32))); f322 = f322 + (learnRate .* (filterDelta4(activeMap22, delta_activeMap32))); f332 = f332 + (learnRate .* (filterDelta4(activeMap23, delta_activeMap32))); f313 = f313 + (learnRate .* (filterDelta4(activeMap21, delta_activeMap33))); f323 = f323 + (learnRate .* (filterDelta4(activeMap22, delta_activeMap33))); f333 = f333 + (learnRate .* (filterDelta4(activeMap23, delta_activeMap33))); f211 = f211 + (learnRate .* (filterDelta4(activeMap11, delta_activeMap21))); f221 = f221 + (learnRate .* (filterDelta4(activeMap12, delta_activeMap21))); f231 = f231 + (learnRate .* (filterDelta4(activeMap13, delta_activeMap21))); f212 = f212 + (learnRate .* (filterDelta4(activeMap11, delta_activeMap22))); f222 = f222 + (learnRate .* (filterDelta4(activeMap12, delta_activeMap22))); f232 = f232 + (learnRate .* (filterDelta4(activeMap13, delta_activeMap22))); f213 = f213 + (learnRate .* (filterDelta4(activeMap11, delta_activeMap23))); f223 = f223 + (learnRate .* (filterDelta4(activeMap12, delta_activeMap23))); f233 = f233 + (learnRate .* (filterDelta4(activeMap13, delta_activeMap23))); f11 = f11 + (learnRate .* (filterDelta4(img{imgCount}, delta_activeMap11))); f12 = f12 + (learnRate .* (filterDelta4(img{imgCount}, delta_activeMap12))); f13 = f13 + (learnRate .* (filterDelta4(img{imgCount}, delta_activeMap13))); fprintf('[%f]-', err(imgCount)); end fprintf('%i \n', iteration); ETotal(iteration) = err(1); imshow([imgOut(:,:,1)],[]) end plot(ETotal);