Posted on 2024-11-07 20:03
The following is simple linear regression/ML code that I have modified. For every initial weight and bias I tried (e.g. weight = np.array([0.03, 0.04, 0.02]), bias = 0.01), training blows up and never converges.
Is there a bug in the code, or is there a way to choose good initial values (weight and bias) so that it converges?
#Adopted from
import numpy as np
from numpy import genfromtxt


def predict(X, weight, bias):
    # Linear model: one prediction per row of X.
    return np.dot(X, weight) + bias


def cost_function(X, Y, weight, bias):
    # Mean squared error over all rows (companies).
    companies = X.shape[0]
    return np.sum((predict(X, weight, bias) - Y) ** 2) / companies


def update_weights(X, Y, weight, bias, learning_rate):
    companies = X.shape[0]
    # Vectorized gradients of the summed squared error.
    dW = 2 * np.dot(X.T, predict(X, weight, bias) - Y)
    db = 2 * np.sum(predict(X, weight, bias) - Y)
    # Per-sample loop that the two lines above replace (left disabled):
    # for i in range(companies):
    #     # Calculate partial derivatives
    #     # -2x(y - (mx + b))
    #     dw += -2*X[i] * (sales[i] - (weight*X[i] + bias))
    #     # -2(y - (mx + b))
    #     db += -2*(sales[i] - (weight*X[i] + bias))
    #print(dW, db)
    # We subtract because the derivatives point in direction of steepest ascent
    return weight - (dW / companies) * learning_rate, bias - (db / companies) * learning_rate


def train(X, Y, weight, bias, learning_rate, iters):
    cost_history = []
    for i in range(iters):
        weight, bias = update_weights(X, Y, weight, bias, learning_rate)
        # Calculate cost for auditing purposes
        cost = cost_function(X, Y, weight, bias)
        cost_history.append(cost)
        # Log progress
        if i % 100 == 0:
            print("iter: " + str(i) + " cost: " + str(cost) + "\n")
    return weight, bias, cost_history


#the Advertising.csv is from
if __name__ == "__main__":
    my_data = genfromtxt('Advertising.csv', delimiter=',')
    X = my_data[1:, 1:4]  # the three feature columns
    Y = my_data[1:, 4]  # the sales
    a, b, _ = train(X, Y, np.array([0.03, 0.04, 0.02]), 0.01, 0.001, 1000)
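One way to test the "is there a bug in the code" half of the question is to compare the analytic gradients against finite differences of the cost. The sketch below is only an illustration. It assumes the weight gradient is 2 * np.dot(X.T, predict(...) - Y) divided by the number of rows. It runs on small synthetic data rather than Advertising.csv, and analytic_gradients and numeric_gradients are hypothetical helper names that are not in the original post.

```python
import numpy as np

def predict(X, weight, bias):
    return np.dot(X, weight) + bias

def cost_function(X, Y, weight, bias):
    return np.sum((predict(X, weight, bias) - Y) ** 2) / X.shape[0]

def analytic_gradients(X, Y, weight, bias):
    # Same formulas as update_weights, already divided by the row count.
    n = X.shape[0]
    residual = predict(X, weight, bias) - Y
    return 2 * np.dot(X.T, residual) / n, 2 * np.sum(residual) / n

def numeric_gradients(X, Y, weight, bias, eps=1e-6):
    # Central finite differences of cost_function, one parameter at a time.
    dW = np.zeros_like(weight)
    for j in range(weight.size):
        step = np.zeros_like(weight)
        step[j] = eps
        dW[j] = (cost_function(X, Y, weight + step, bias)
                 - cost_function(X, Y, weight - step, bias)) / (2 * eps)
    db = (cost_function(X, Y, weight, bias + eps)
          - cost_function(X, Y, weight, bias - eps)) / (2 * eps)
    return dW, db

# Synthetic stand-in data (assumption: NOT the real Advertising.csv).
rng = np.random.default_rng(0)
X = rng.uniform(0.0, 1.0, size=(50, 3))
Y = X @ np.array([1.0, -2.0, 0.5]) + 3.0
weight = np.array([0.03, 0.04, 0.02])
bias = 0.01

dW_a, db_a = analytic_gradients(X, Y, weight, bias)
dW_n, db_n = numeric_gradients(X, Y, weight, bias)
gradients_match = bool(np.allclose(dW_a, dW_n, atol=1e-5)
                       and np.isclose(db_a, db_n, atol=1e-5))
print("gradients match:", gradients_match)
```

If the two sets of gradients agree, the update rule itself is probably fine and the divergence has to come from somewhere else, such as the step size.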
The problem is that whatever values I use for the initial weight and bias (e.g. weight = np.array([0.03, 0.04, 0.02]), bias = 0.01), the cost blows up instead of converging:
train(X, Y, weight, bias, 0.001, 1000)
When I ran the above code snippet, I got
$ python
iter: 0 cost: 212337.75728564826
/Users/joe/anaconda3/lib/python3.6/site-packages/numpy/core/ RuntimeWarning: overflow encountered in reduce
  return umr_sum(a, axis, dtype, out, keepdims)
RuntimeWarning: overflow encountered in square
  return np.sum((predict(X, weight, bias) - Y) **2) / companies
iter: 100 cost: inf
RuntimeWarning: invalid value encountered in subtract
  return weight - dW * learning_rate / companies , bias - db * learning_rate / companies
iter: 200 cost: nan
iter: 300 cost: nan
iter: 400 cost: nan
iter: 500 cost: nan
iter: 600 cost: nan
iter: 700 cost: nan
iter: 800 cost: nan
iter: 900 cost: nan
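The inf followed by nan pattern in this log is what float64 arithmetic produces once values grow without bound: a square overflows to inf, and a later subtraction involving infinities becomes nan. A minimal sketch of that mechanism, using made-up numbers rather than the actual weights from the run:

```python
import numpy as np

# Silence the same RuntimeWarnings that appear in the log above.
with np.errstate(over="ignore", invalid="ignore"):
    big = np.float64(1e200)          # a residual that has grown very large
    squared = big ** 2               # exceeds the float64 range, becomes inf
    difference = squared - squared   # inf - inf is undefined, becomes nan

print(squared, difference)
```

Once a single nan reaches the weights, every later prediction and cost is nan as well, which matches the remaining iterations in the log.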
Figured out the cause of the problem: the learning rate, 0.001 in this case, is too high.
Changing it to 0.00001 works, i.e. changing the last line of the original snippet to the following makes training converge:
a,b, _ = train(X, Y, np.array([0.03, 0.04, 0.02]), 0.01, 0.00001, 1000)
Here is the output:
iter: 0 cost: 23.07411798374272
iter: 100 cost: 6.479930413738248
iter: 200 cost: 5.097751463999494
iter: 300 cost: 4.528064099014893
iter: 400 cost: 4.263917598438141
iter: 500 cost: 4.1398851132621655
iter: 600 cost: 4.081383875535448
iter: 700 cost: 4.053584811192947
iter: 800 cost: 4.040172367398533
iter: 900 cost: 4.033501506011401
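Why a smaller learning rate helps can be made concrete. For a mean-squared-error cost, plain gradient descent is stable only when the learning rate is below 2 divided by the largest eigenvalue of the cost's Hessian, which is (2/n) * A.T @ A with A being X plus a column of ones for the bias. The sketch below estimates that limit. It uses synthetic unscaled features as a stand-in, since the real Advertising.csv values are not reproduced here, and max_stable_learning_rate and final_cost are hypothetical helpers rather than part of the original code.

```python
import numpy as np

def max_stable_learning_rate(X):
    # Append a column of ones so the bias is treated as one more weight.
    n = X.shape[0]
    A = np.hstack([X, np.ones((n, 1))])
    hessian = 2.0 * A.T @ A / n
    largest_eigenvalue = np.linalg.eigvalsh(hessian)[-1]  # ascending order
    return 2.0 / largest_eigenvalue

def final_cost(X, Y, learning_rate, iters=200):
    # Same update rule as update_weights, without the logging.
    weight = np.array([0.03, 0.04, 0.02])
    bias = 0.01
    n = X.shape[0]
    for _ in range(iters):
        residual = X @ weight + bias - Y
        weight = weight - 2 * (X.T @ residual) / n * learning_rate
        bias = bias - 2 * np.sum(residual) / n * learning_rate
    return np.sum((X @ weight + bias - Y) ** 2) / n

# Synthetic, unscaled stand-in features (assumption: NOT the real data set).
rng = np.random.default_rng(0)
X = rng.uniform(0.0, 300.0, size=(200, 3))
Y = X @ np.array([0.05, 0.2, 0.001]) + 3.0

limit = max_stable_learning_rate(X)
initial_cost = final_cost(X, Y, limit, iters=0)
cost_below = final_cost(X, Y, 0.5 * limit)   # inside the stable range
cost_above = final_cost(X, Y, 1.5 * limit)   # outside it, so the cost grows

print("limit:", limit)
print("cost below limit:", cost_below, "cost above limit:", cost_above)
```

Calling max_stable_learning_rate on the real X from the script would give the threshold for that data. Standardizing the feature columns is another common way to shrink the largest eigenvalue so that a larger learning rate becomes usable, though the original post does not do this.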
source:python black hole net