Jupyter
Sebastian J. Schlecht, Monday, 31 October 2022
Test: learning a clipping function with an MLP
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
# define Dataset
class Dataset(torch.utils.data.Dataset):
    """Synthetic regression data for learning a hard-clipping function.

    Inputs are drawn from a standard normal distribution; each label is the
    corresponding input clamped to the interval ``[-clip, clip]``.

    Args:
        num: Number of samples to draw.
        clip: Clipping threshold; must be non-negative.
        generator: Optional ``torch.Generator`` used to draw the inputs so a
            dataset can be reproduced exactly. ``None`` (the default) uses the
            global RNG, as before.

    Raises:
        ValueError: If ``clip`` is negative. ``torch.clamp`` with
            ``min > max`` would silently turn every label into the constant
            ``-abs(clip)`` instead of clipping.
    """

    def __init__(self, num, clip, generator=None):
        if clip < 0:
            raise ValueError(f"clip must be non-negative, got {clip}")
        # Shape (num, 1): one scalar feature per sample.
        self.input = torch.randn(num, 1, generator=generator)
        self.labels = torch.clamp(self.input, min=-clip, max=clip)

    def __len__(self):
        """Return the number of samples."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return the ``(input, label)`` pair at ``index``."""
        return self.input[index], self.labels[index]
# Hyper-parameters and dataset size.
batch_size = 50
number_of_epochs = 20
clip = 0.9
num = 1000

# Training data and its shuffling mini-batch loader.
training_set = Dataset(num, clip)
training_loader = torch.utils.data.DataLoader(
    training_set,
    batch_size=batch_size,
    shuffle=True,
)

# An independently sampled set of the same size, used for evaluation.
test_set = Dataset(num, clip)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=True,
)
# define module
class Net(nn.Module):
    """Small fully connected MLP mapping one scalar feature to one scalar output.

    Architecture: ``1 -> h1 -> h2 -> h3 -> 1``, with ReLU after each hidden
    layer and a linear output layer (no output activation).

    Args:
        hidden_sizes: Widths ``(h1, h2, h3)`` of the three hidden layers.
            Defaults to ``(10, 20, 10)``.
    """

    def __init__(self, hidden_sizes=(10, 20, 10)):
        super().__init__()
        h1, h2, h3 = hidden_sizes
        # The attribute names fc1..fc4 determine the parameter names
        # ("fc1.weight", ...), so they are kept fixed.
        self.fc1 = nn.Linear(1, h1)
        self.fc2 = nn.Linear(h1, h2)
        self.fc3 = nn.Linear(h2, h3)
        self.fc4 = nn.Linear(h3, 1)

    def forward(self, x):
        """Apply the MLP to ``x`` (last dimension 1); the output's last dimension is 1."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return self.fc4(x)

    def print(self):
        """Print the name and current value of every trainable parameter."""
        for name, param in self.named_parameters():
            if param.requires_grad:
                # Inside a method, the bare name `print` resolves to the
                # builtin, not to this method. detach() is the supported way
                # to read values without autograd tracking.
                print(name, param.detach())
# Model, objective and optimiser.
net = Net()
criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters())

# Make `number_of_epochs` passes over the training data, printing the
# mini-batch loss after every optimisation step.
for epoch in range(number_of_epochs):
    for inputs, labels in training_loader:
        optimizer.zero_grad()              # discard gradients from the previous step
        outputs = net(inputs)              # forward pass
        loss = criterion(outputs, labels)
        loss.backward()                    # back-propagate
        optimizer.step()                   # parameter update
        print(loss.item())

print('Finished Training')
net.print()
0.5054024457931519
0.42476707696914673
0.4892147183418274
0.5016235113143921
0.46796905994415283
0.506191074848175
0.3750819265842438
0.43628188967704773
0.41761791706085205
0.3339814841747284
0.3938809335231781
0.3966199457645416
0.46179550886154175
0.27180060744285583
0.4156447649002075
0.3592190444469452
0.42493927478790283
0.24830608069896698
0.32721808552742004
0.33350440859794617
0.41380637884140015
0.3744264543056488
0.3772357702255249
0.339101105928421
0.3356706202030182
0.3446285128593445
0.3767416775226593
0.3104710578918457
0.3535071909427643
0.2872011363506317
0.31624653935432434
0.29231148958206177
0.3035346269607544
0.27156397700309753
0.2991139888763428
0.2952619791030884
0.32593628764152527
0.3445398211479187
0.2786809206008911
0.2983841598033905
0.33857399225234985
0.2698429822921753
0.2899532616138458
0.27511832118034363
0.27700310945510864
0.24798843264579773
0.27581045031547546
0.26187363266944885
0.25144296884536743
0.21957996487617493
0.1566087156534195
0.3100586533546448
0.26605430245399475
0.23568087816238403
0.23832878470420837
0.20436321198940277
0.1831567883491516
0.20349836349487305
0.22161522507667542
0.17290720343589783
0.20474368333816528
0.19474287331104279
0.17118656635284424
0.17147117853164673
0.21577991545200348
0.17923547327518463
0.140830397605896
0.14410808682441711
0.1882449984550476
0.1866573840379715
0.11655132472515106
0.15486755967140198
0.14660781621932983
0.12824159860610962
0.15132743120193481
0.1334061175584793
0.20296142995357513
0.12636318802833557
0.1625654548406601
0.08958902209997177
0.12333649396896362
0.15218965709209442
0.11382555216550827
0.10162606835365295
0.10816758126020432
0.10274434834718704
0.12289168685674667
0.10438963025808334
0.08674982935190201
0.1028674989938736
0.08711696416139603
0.07698816806077957
0.08109548687934875
0.06715775281190872
0.07266759127378464
0.053090039640665054
0.07961154729127884
0.05293424427509308
0.07046812772750854
0.045054301619529724
0.06686286628246307
0.053763989359140396
0.04599469155073166
0.05298565700650215
0.06006673350930214
0.04921303689479828
0.04958082735538483
0.039590805768966675
0.03502899035811424
0.03933865949511528
0.03255784884095192
0.039780888706445694
0.02695368602871895
0.027571601793169975
0.026216330006718636
0.02672029659152031
0.02476891689002514
0.02674967609345913
0.0158847626298666
0.01885199546813965
0.019399970769882202
0.014962169341742992
0.014805333688855171
0.01998368464410305
0.011708715930581093
0.010520897805690765
0.014515080489218235
0.02223592810332775
0.012649240903556347
0.010369312949478626
0.019156325608491898
0.007956527173519135
0.009646212682127953
0.030131597071886063
0.016366995871067047
0.008477347902953625
0.0205403920263052
0.010162584483623505
0.00732888700440526
0.011322029866278172
0.013002569787204266
0.006618652492761612
0.009711711667478085
0.01353226788341999
0.0073477174155414104
0.006182166747748852
0.023682788014411926
0.0071188174188137054
0.008455581963062286
0.010672803036868572
0.008660484105348587
0.012788690626621246
0.008466217666864395
0.012007967568933964
0.0094989612698555
0.01667686551809311
0.005961017683148384
0.010867079719901085
0.006196053232997656
0.008975710719823837
0.006306353956460953
0.011914270929992199
0.004981494043022394
0.01392379030585289
0.008882598020136356
0.008897442370653152
0.006885786075145006
0.007534661330282688
0.006127095781266689
0.018794095143675804
0.008173449896275997
0.009045412763953209
0.014930066652595997
0.005779666360467672
0.008115990087389946
0.008533732034265995
0.0053793941624462605
0.008895311504602432
0.010190223343670368
0.009747040458023548
0.013520555570721626
0.01555438619107008
0.009125531651079655
0.004846183583140373
0.008840315043926239
0.006142988335341215
0.00682963989675045
0.005336494650691748
0.008544986136257648
0.007181923370808363
0.0079987533390522
0.008249332197010517
0.007545026484876871
0.005713548045605421
0.007601512596011162
0.006309641525149345
0.007781594526022673
0.009450632147490978
0.00563865015283227
0.005204386543482542
0.013657323084771633
0.004672298673540354
0.007816526107490063
0.004166348371654749
0.008768604137003422
0.006561242509633303
0.004972578026354313
0.004284904804080725
0.0045523750595748425
0.004305124282836914
0.009094522334635258
0.0031612617895007133
0.002519429661333561
0.005293068941682577
0.011170553974807262
0.010991688817739487
0.0026463111862540245
0.011244051158428192
0.00749776978045702
0.009228956885635853
0.005880804732441902
0.006507831625640392
0.0052168108522892
0.019541669636964798
0.007737858686596155
0.00443919887766242
0.008342987857758999
0.004857320338487625
0.004187632352113724
0.003544635372236371
0.003264597151428461
0.004047788679599762
0.0038327567745000124
0.0033765847329050303
0.004485307727009058
0.004895927384495735
0.003994782455265522
0.010241317562758923
0.0038198409602046013
0.006741153076291084
0.008207053877413273
0.006743066478520632
0.005747954361140728
0.0033938002306967974
0.006985098123550415
0.003965070471167564
0.010252466425299644
0.003549654735252261
0.0037857431452721357
0.006333546247333288
0.00235589942894876
0.0034418809227645397
0.002953439252451062
0.005676734261214733
0.0022003022022545338
0.011668642982840538
0.005117220804095268
0.0028124162927269936
0.0055918567813932896
0.0024021295830607414
0.0036265042144805193
0.0031013020779937506
0.004805361852049828
0.006669086404144764
0.00416877306997776
0.003464368637651205
0.0037992531433701515
0.0038867415860295296
0.0023626270703971386
0.004572764039039612
0.002248072996735573
0.008810912258923054
0.0054762703366577625
0.007094301749020815
0.0025862320326268673
0.004284600727260113
0.0041335029527544975
0.0022353539243340492
0.009093865752220154
0.004370956681668758
0.0031514496076852083
0.0035022764932364225
0.005059209652245045
0.0033674356527626514
0.0030193852726370096
0.0025780329015105963
0.004098314326256514
0.0020134553778916597
0.0028946870006620884
0.0019452879205346107
0.0056954906322062016
0.003126366063952446
0.003821610240265727
0.002129784319549799
0.002042047679424286
0.004874676000326872
0.005857694894075394
0.004085984081029892
0.0028199926018714905
0.009319768287241459
0.0026162774302065372
0.0021293489262461662
0.0027072022203356028
0.0034876144491136074
0.00335272797383368
0.005954155698418617
0.004648131318390369
0.002195154083892703
0.0031238216906785965
0.0033800932578742504
0.004454123787581921
0.0030535110272467136
0.00393053749576211
0.0022039765026420355
0.0031073607970029116
0.0028656134381890297
0.002457177732139826
0.002298101782798767
0.002877561841160059
0.0029112917836755514
0.0018818110693246126
0.0028768032789230347
0.0022500003688037395
0.001954483799636364
0.002482531126588583
0.0011331253917887807
0.0011141860159114003
0.0033472508657723665
0.005585329607129097
0.002410265849903226
0.005820651073008776
0.0017693298868834972
0.0017087000887840986
0.001373580889776349
0.003248336724936962
0.004919007886201143
0.002239574445411563
0.002034867415204644
0.0023711416870355606
0.002297616796568036
0.0013825627975165844
0.0016273170476779342
0.0030522008892148733
0.0013080616481602192
0.0022402415052056313
0.0015440104762092233
0.004494350403547287
0.0017410890432074666
0.0011850498849526048
0.003489218419417739
0.0022658517118543386
0.003485520137473941
0.0012510307133197784
0.0025768885388970375
0.002183937933295965
0.001516107004135847
0.002048900118097663
0.001280587282963097
0.001397680607624352
0.0029074465855956078
0.0016966218827292323
0.0017056555952876806
0.001783750019967556
0.0018331382889300585
0.0010273134103044868
0.0013285300228744745
0.0034452148247510195
0.00437086820602417
0.0009205617243424058
0.001054459484294057
0.0015975538408383727
0.0017761063063517213
0.0010076797334477305
0.0011514350771903992
0.002952246693894267
0.001855501439422369
0.00070126389618963
0.0013156462227925658
0.0021787937730550766
0.0010941127547994256
0.0016162851825356483
0.0009285698761232197
0.0019138300558552146
0.0009639544878154993
0.0013132510939612985
0.0012826721649616957
0.0014462722465395927
0.0009325985447503626
0.0009472963865846395
0.0010517302434891462
0.0008456077193841338
0.0022388482466340065
0.0021026248577982187
0.0030053032096475363
0.0010222006822004914
0.0014408573042601347
0.0012121436884626746
0.0015380540862679482
0.0006259444635361433
0.0006868245545774698
Finished Training
fc1.weight tensor([[ 0.3721],
[-0.3295],
[-0.8585],
[-0.5114],
[-0.3996],
[-0.9245],
[ 0.8753],
[ 0.8704],
[ 0.9115],
[ 0.0135]])
fc1.bias tensor([ 1.1292, -0.2312, -0.3269, -0.4945, 0.3395, 0.7283, -0.2415, 0.9651,
0.4693, -0.9102])
fc2.weight tensor([[ 9.5948e-02, 3.9746e-02, -4.1991e-02, 5.6819e-02, 1.3780e-01,
-1.3759e-01, 2.4883e-01, -2.3085e-01, 3.1752e-01, 2.6596e-01],
[ 1.6440e-01, 1.0791e-01, 2.4116e-01, -3.5249e-01, 2.8502e-03,
-2.1309e-01, 3.4830e-01, -1.2172e-02, 3.1129e-01, 1.9786e-01],
[ 2.0905e-01, -1.5890e-01, -2.8715e-02, -2.8519e-01, 2.0892e-01,
-9.2011e-02, 3.0359e-01, 2.3966e-01, 2.7374e-01, -1.3720e-01],
[-2.5713e-01, -2.0884e-01, 5.9911e-02, -6.8203e-02, 3.8436e-02,
-1.4092e-01, -1.8512e-01, 8.3490e-02, -1.7116e-01, 1.6170e-01],
[ 3.1665e-01, 2.4182e-01, -7.7140e-02, -3.5004e-01, -5.2356e-02,
1.9387e-01, -4.0406e-01, -3.6750e-01, -8.6176e-02, -3.1524e-01],
[-7.7758e-03, 1.9777e-02, -5.0071e-03, 1.4394e-01, 2.6122e-01,
-2.3958e-01, -2.2489e-01, -3.1524e-01, 3.7644e-02, -2.3002e-01],
[-2.2262e-01, -2.2564e-01, -1.7846e-01, -6.8525e-02, 2.9318e-01,
-7.0839e-02, 4.1510e-03, -1.5924e-01, -7.5550e-02, 1.0270e-01],
[ 3.0554e-01, -8.2371e-02, 2.2087e-01, -8.6428e-02, 2.5604e-01,
2.0730e-01, 3.5102e-01, 1.7744e-01, 1.5738e-01, 2.4604e-01],
[ 3.9082e-01, -5.7637e-01, -3.2515e-01, -2.6174e-01, 1.9399e-02,
-2.6776e-01, -2.0496e-01, 2.5031e-01, 9.9010e-02, -3.0642e-01],
[-1.6070e-01, 1.0554e-01, -2.5513e-01, -1.8325e-02, -3.5559e-01,
1.2858e-01, -1.3098e-01, 2.2081e-01, 2.6941e-01, -2.1347e-01],
[ 2.5908e-01, -3.2340e-01, 1.5940e-01, -4.5033e-01, 2.6449e-01,
1.1920e-01, -8.8357e-03, 3.5177e-01, -3.1935e-02, 2.1965e-01],
[ 2.8944e-01, -2.8803e-01, 1.2965e-01, 9.8315e-03, -3.0397e-01,
-1.0221e-01, -1.4344e-01, 7.8904e-02, 2.7641e-01, -2.5588e-01],
[ 2.8773e-01, -3.0796e-01, 7.5191e-02, -2.6509e-01, -5.2185e-02,
-3.6362e-01, 5.1373e-02, 1.0309e-01, 2.5663e-02, 2.2465e-01],
[-2.0094e-01, 5.6251e-02, 1.3748e-01, 2.2335e-01, 4.1196e-01,
-8.6509e-02, -7.5779e-03, 2.1911e-01, 2.1871e-02, 2.8864e-01],
[ 2.4129e-01, -2.5313e-01, 2.6164e-01, -6.0113e-02, -1.7291e-01,
-1.9485e-01, -1.1896e-01, 3.6797e-01, 3.2871e-02, -1.2588e-01],
[-2.9837e-01, -2.0906e-01, -3.7566e-02, -2.7766e-01, -8.1368e-02,
-9.7512e-02, 2.5353e-01, -1.0673e-01, 1.2526e-01, -9.0056e-02],
[-3.0757e-01, -3.1805e-02, 1.8534e-01, -1.2516e-01, -1.8489e-02,
-1.7930e-01, 1.8774e-01, -7.5238e-02, -2.4108e-01, 1.0063e-01],
[ 1.6883e-01, 3.8211e-04, 2.4381e-01, -2.7841e-01, -4.0433e-01,
-7.8925e-02, 3.2364e-03, 3.5408e-01, -6.6366e-02, 1.9972e-01],
[ 1.1284e-01, 2.0313e-01, 2.8658e-01, 7.5850e-02, -5.6530e-02,
1.2204e-01, 1.8304e-01, -1.5863e-01, -2.7456e-01, 1.1286e-01],
[ 9.5115e-02, -2.7967e-01, 3.2540e-01, 6.9701e-02, -1.2832e-01,
2.8414e-01, -3.1242e-01, 1.0254e-01, -2.5878e-01, -1.7183e-01]])
fc2.bias tensor([-0.2677, 0.1064, 0.3526, -0.0374, 0.4561, 0.1586, 0.0320, 0.1695,
0.2299, 0.2226, -0.0681, 0.0921, 0.3121, 0.2696, 0.0507, -0.0712,
-0.2867, 0.1343, 0.3636, 0.1651])
fc3.weight tensor([[ 0.2264, 0.2731, 0.1997, -0.1967, 0.0925, 0.0655, -0.2147, 0.0687,
-0.1583, -0.0250, 0.2710, -0.1579, -0.0012, 0.0079, -0.0343, 0.2083,
-0.2130, -0.1654, 0.0950, -0.0060],
[-0.3148, 0.0697, -0.0594, -0.1206, -0.1136, -0.0563, -0.0753, -0.0802,
0.3156, -0.0267, 0.1237, 0.1390, -0.0907, -0.0419, 0.0921, 0.0035,
-0.0646, 0.2318, -0.0630, -0.1212],
[ 0.0369, -0.0099, -0.0146, -0.2198, 0.0514, 0.1544, 0.0431, -0.0032,
0.0918, 0.2136, 0.0839, 0.2284, 0.0906, -0.0716, -0.0236, 0.0035,
0.1050, 0.0676, -0.0453, 0.0505],
[ 0.2019, 0.1029, -0.1724, 0.2126, -0.0079, -0.1904, -0.0884, 0.1582,
0.0538, -0.0228, 0.0809, 0.0894, -0.1648, 0.1427, 0.0693, 0.0252,
0.0204, 0.0356, 0.2935, 0.0358],
[-0.0382, 0.1262, 0.2815, -0.1656, -0.1745, -0.0436, -0.0300, 0.1551,
0.3153, 0.2605, 0.0904, 0.3011, 0.1789, -0.0059, 0.3071, -0.0023,
0.1951, 0.2652, -0.1214, -0.0429],
[-0.2895, 0.2136, 0.2434, 0.0899, -0.0446, 0.1205, 0.1122, -0.1585,
0.1609, 0.1443, 0.1347, -0.1256, 0.2018, -0.0311, 0.1582, -0.2228,
-0.0495, 0.1253, -0.3098, -0.2082],
[ 0.0753, -0.0625, 0.1167, 0.2068, 0.1717, 0.1977, -0.0191, 0.3729,
-0.1675, -0.1808, 0.2829, -0.2164, -0.1562, 0.0173, -0.0173, 0.1346,
-0.0452, -0.0258, 0.2259, 0.2776],
[ 0.1527, -0.2998, 0.1249, 0.1759, 0.3822, 0.0473, 0.0712, -0.1180,
-0.2744, -0.0193, 0.0439, 0.1360, -0.0755, 0.1174, 0.0792, -0.1337,
-0.2136, 0.0599, 0.2079, 0.0247],
[-0.0572, 0.1760, -0.0109, 0.1384, 0.0633, 0.0511, 0.1249, -0.1667,
0.0084, 0.1752, -0.1058, -0.2064, -0.0048, 0.1084, 0.1272, -0.1980,
-0.1469, -0.2001, 0.1276, -0.1351],
[-0.0523, -0.0670, -0.2114, 0.1301, 0.1551, -0.1630, 0.1385, -0.1210,
0.1904, 0.1850, -0.0549, -0.1425, -0.0239, -0.2057, -0.1383, 0.0558,
0.1950, -0.0719, -0.0827, 0.0355]])
fc3.bias tensor([-0.1290, 0.1935, 0.0259, -0.0953, -0.1686, 0.1954, -0.0455, 0.3036,
0.0078, -0.1845])
fc4.weight tensor([[-0.2992, 0.2344, 0.0157, -0.1589, 0.2604, 0.2133, -0.2932, -0.2656,
-0.2334, 0.0213]])
fc4.bias tensor([-0.1267])
# Evaluate the trained network on one (shuffled) batch of the held-out test set.
dataiter = iter(test_loader)
in_signal, labels = next(dataiter)
# Net contains only Linear/ReLU layers, so eval() changes nothing numerically;
# it is set for correctness if dropout/normalisation layers are ever added.
net.eval()
with torch.no_grad():
    prediction = net(in_signal)
    # Measure the error on this test batch. The `loss` left over from the
    # training loop only reflects the last *training* mini-batch.
    test_loss = criterion(prediction, labels)
out_signal = prediction.numpy()
print('Test loss is:', test_loss.item())
Loss is: 0.00068682455
# Plot the test batch as a signal: input, clipped target and network output.
fig, ax = plt.subplots()
ax.plot(np.asarray(in_signal), label='input')
ax.plot(np.asarray(labels), '--', label=f'target (clipped at +/-{clip})')
ax.plot(np.asarray(out_signal), ':', label='network output')
ax.set_xlabel('sample index within batch')
ax.set_ylabel('value')
ax.legend()
plt.show()