
Sebastian J. Schlecht, Monday, 31. October 2022

Test: learning a clipping function with an MLP

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
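
Seeding the random number generators (an optional addition, not part of the original run) makes the dataset draws and the weight initialization reproducible:

# optional: fix RNG seeds for reproducible dataset draws and weight init
torch.manual_seed(0)
np.random.seed(0)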
# define dataset: random normal inputs with hard-clipped targets
class ClipDataset(torch.utils.data.Dataset):
    def __init__(self, num, clip):
        self.input = torch.randn(num, 1)
        self.labels = torch.clamp(self.input, min=-clip, max=clip)
        # self.labels = self.input ** 2  # alternative target: squaring

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        # Select sample
        y = self.labels[index]
        x = self.input[index]

        return x, y
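
A quick sanity check (mine, not in the original notebook) confirms that each sample pairs an input with its clamped copy:

# hypothetical sanity check: labels are the inputs clamped to [-clip, clip]
ds = ClipDataset(num=5, clip=0.9)
x, y = ds[0]
assert torch.equal(y, torch.clamp(x, min=-0.9, max=0.9))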

batch_size = 50
number_of_epochs = 20
clip = 0.9
num = 1000
training_set = ClipDataset(num, clip)
training_loader = torch.utils.data.DataLoader(training_set, batch_size=batch_size, shuffle=True)

test_set = ClipDataset(num, clip)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=True)
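
Peeking at one batch verifies the shapes the loaders deliver (a quick check, not from the original):

xb, yb = next(iter(training_loader))
print(xb.shape, yb.shape)  # expected: torch.Size([50, 1]) torch.Size([50, 1])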
# define the model: a small fully connected ReLU network
class Net(nn.Module):
    def __init__(self):
        super().__init__()

        self.fc1 = nn.Linear(1, 10)
        self.fc2 = nn.Linear(10, 20)
        self.fc3 = nn.Linear(20, 10)
        self.fc4 = nn.Linear(10, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x

    def print(self):
        for name, param in self.named_parameters():
            if param.requires_grad:
                print(name, param.data)
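
Two asides (mine, not the notebook's). First, this architecture has 461 trainable parameters (20 + 220 + 210 + 11 across the four layers). Second, the target is exactly representable with ReLUs, since clamp(x, -c, c) = -c + relu(x + c) - relu(x - c), so in principle a single hidden layer with two units suffices; the deeper network here simply has to discover an equivalent function.

# aside: count the trainable parameters of the architecture above
print(sum(p.numel() for p in Net().parameters() if p.requires_grad))  # 461

# aside: hard clipping is exactly a two-ReLU function
x = torch.randn(5)
c = 0.9
assert torch.allclose(-c + F.relu(x + c) - F.relu(x - c), torch.clamp(x, min=-c, max=c))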

net = Net()
criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters())
for epoch in range(number_of_epochs):  # loop over the dataset multiple times
    for i, data in enumerate(training_loader):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        print(loss.item())

print('Finished Training')
net.print()
0.5054024457931519
0.42476707696914673
0.4892147183418274
... (intermediate per-batch losses omitted; the loss falls steadily from about 0.5 to below 0.001 over the 20 epochs) ...
0.0006259444635361433
0.0006868245545774698
Finished Training
fc1.weight tensor([[ 0.3721],
        [-0.3295],
        [-0.8585],
        [-0.5114],
        [-0.3996],
        [-0.9245],
        [ 0.8753],
        [ 0.8704],
        [ 0.9115],
        [ 0.0135]])
fc1.bias tensor([ 1.1292, -0.2312, -0.3269, -0.4945,  0.3395,  0.7283, -0.2415,  0.9651,
         0.4693, -0.9102])
fc2.weight tensor([[ 9.5948e-02,  3.9746e-02, -4.1991e-02,  5.6819e-02,  1.3780e-01,
         -1.3759e-01,  2.4883e-01, -2.3085e-01,  3.1752e-01,  2.6596e-01],
        [ 1.6440e-01,  1.0791e-01,  2.4116e-01, -3.5249e-01,  2.8502e-03,
         -2.1309e-01,  3.4830e-01, -1.2172e-02,  3.1129e-01,  1.9786e-01],
        [ 2.0905e-01, -1.5890e-01, -2.8715e-02, -2.8519e-01,  2.0892e-01,
         -9.2011e-02,  3.0359e-01,  2.3966e-01,  2.7374e-01, -1.3720e-01],
        [-2.5713e-01, -2.0884e-01,  5.9911e-02, -6.8203e-02,  3.8436e-02,
         -1.4092e-01, -1.8512e-01,  8.3490e-02, -1.7116e-01,  1.6170e-01],
        [ 3.1665e-01,  2.4182e-01, -7.7140e-02, -3.5004e-01, -5.2356e-02,
          1.9387e-01, -4.0406e-01, -3.6750e-01, -8.6176e-02, -3.1524e-01],
        [-7.7758e-03,  1.9777e-02, -5.0071e-03,  1.4394e-01,  2.6122e-01,
         -2.3958e-01, -2.2489e-01, -3.1524e-01,  3.7644e-02, -2.3002e-01],
        [-2.2262e-01, -2.2564e-01, -1.7846e-01, -6.8525e-02,  2.9318e-01,
         -7.0839e-02,  4.1510e-03, -1.5924e-01, -7.5550e-02,  1.0270e-01],
        [ 3.0554e-01, -8.2371e-02,  2.2087e-01, -8.6428e-02,  2.5604e-01,
          2.0730e-01,  3.5102e-01,  1.7744e-01,  1.5738e-01,  2.4604e-01],
        [ 3.9082e-01, -5.7637e-01, -3.2515e-01, -2.6174e-01,  1.9399e-02,
         -2.6776e-01, -2.0496e-01,  2.5031e-01,  9.9010e-02, -3.0642e-01],
        [-1.6070e-01,  1.0554e-01, -2.5513e-01, -1.8325e-02, -3.5559e-01,
          1.2858e-01, -1.3098e-01,  2.2081e-01,  2.6941e-01, -2.1347e-01],
        [ 2.5908e-01, -3.2340e-01,  1.5940e-01, -4.5033e-01,  2.6449e-01,
          1.1920e-01, -8.8357e-03,  3.5177e-01, -3.1935e-02,  2.1965e-01],
        [ 2.8944e-01, -2.8803e-01,  1.2965e-01,  9.8315e-03, -3.0397e-01,
         -1.0221e-01, -1.4344e-01,  7.8904e-02,  2.7641e-01, -2.5588e-01],
        [ 2.8773e-01, -3.0796e-01,  7.5191e-02, -2.6509e-01, -5.2185e-02,
         -3.6362e-01,  5.1373e-02,  1.0309e-01,  2.5663e-02,  2.2465e-01],
        [-2.0094e-01,  5.6251e-02,  1.3748e-01,  2.2335e-01,  4.1196e-01,
         -8.6509e-02, -7.5779e-03,  2.1911e-01,  2.1871e-02,  2.8864e-01],
        [ 2.4129e-01, -2.5313e-01,  2.6164e-01, -6.0113e-02, -1.7291e-01,
         -1.9485e-01, -1.1896e-01,  3.6797e-01,  3.2871e-02, -1.2588e-01],
        [-2.9837e-01, -2.0906e-01, -3.7566e-02, -2.7766e-01, -8.1368e-02,
         -9.7512e-02,  2.5353e-01, -1.0673e-01,  1.2526e-01, -9.0056e-02],
        [-3.0757e-01, -3.1805e-02,  1.8534e-01, -1.2516e-01, -1.8489e-02,
         -1.7930e-01,  1.8774e-01, -7.5238e-02, -2.4108e-01,  1.0063e-01],
        [ 1.6883e-01,  3.8211e-04,  2.4381e-01, -2.7841e-01, -4.0433e-01,
         -7.8925e-02,  3.2364e-03,  3.5408e-01, -6.6366e-02,  1.9972e-01],
        [ 1.1284e-01,  2.0313e-01,  2.8658e-01,  7.5850e-02, -5.6530e-02,
          1.2204e-01,  1.8304e-01, -1.5863e-01, -2.7456e-01,  1.1286e-01],
        [ 9.5115e-02, -2.7967e-01,  3.2540e-01,  6.9701e-02, -1.2832e-01,
          2.8414e-01, -3.1242e-01,  1.0254e-01, -2.5878e-01, -1.7183e-01]])
fc2.bias tensor([-0.2677,  0.1064,  0.3526, -0.0374,  0.4561,  0.1586,  0.0320,  0.1695,
         0.2299,  0.2226, -0.0681,  0.0921,  0.3121,  0.2696,  0.0507, -0.0712,
        -0.2867,  0.1343,  0.3636,  0.1651])
fc3.weight tensor([[ 0.2264,  0.2731,  0.1997, -0.1967,  0.0925,  0.0655, -0.2147,  0.0687,
         -0.1583, -0.0250,  0.2710, -0.1579, -0.0012,  0.0079, -0.0343,  0.2083,
         -0.2130, -0.1654,  0.0950, -0.0060],
        [-0.3148,  0.0697, -0.0594, -0.1206, -0.1136, -0.0563, -0.0753, -0.0802,
          0.3156, -0.0267,  0.1237,  0.1390, -0.0907, -0.0419,  0.0921,  0.0035,
         -0.0646,  0.2318, -0.0630, -0.1212],
        [ 0.0369, -0.0099, -0.0146, -0.2198,  0.0514,  0.1544,  0.0431, -0.0032,
          0.0918,  0.2136,  0.0839,  0.2284,  0.0906, -0.0716, -0.0236,  0.0035,
          0.1050,  0.0676, -0.0453,  0.0505],
        [ 0.2019,  0.1029, -0.1724,  0.2126, -0.0079, -0.1904, -0.0884,  0.1582,
          0.0538, -0.0228,  0.0809,  0.0894, -0.1648,  0.1427,  0.0693,  0.0252,
          0.0204,  0.0356,  0.2935,  0.0358],
        [-0.0382,  0.1262,  0.2815, -0.1656, -0.1745, -0.0436, -0.0300,  0.1551,
          0.3153,  0.2605,  0.0904,  0.3011,  0.1789, -0.0059,  0.3071, -0.0023,
          0.1951,  0.2652, -0.1214, -0.0429],
        [-0.2895,  0.2136,  0.2434,  0.0899, -0.0446,  0.1205,  0.1122, -0.1585,
          0.1609,  0.1443,  0.1347, -0.1256,  0.2018, -0.0311,  0.1582, -0.2228,
         -0.0495,  0.1253, -0.3098, -0.2082],
        [ 0.0753, -0.0625,  0.1167,  0.2068,  0.1717,  0.1977, -0.0191,  0.3729,
         -0.1675, -0.1808,  0.2829, -0.2164, -0.1562,  0.0173, -0.0173,  0.1346,
         -0.0452, -0.0258,  0.2259,  0.2776],
        [ 0.1527, -0.2998,  0.1249,  0.1759,  0.3822,  0.0473,  0.0712, -0.1180,
         -0.2744, -0.0193,  0.0439,  0.1360, -0.0755,  0.1174,  0.0792, -0.1337,
         -0.2136,  0.0599,  0.2079,  0.0247],
        [-0.0572,  0.1760, -0.0109,  0.1384,  0.0633,  0.0511,  0.1249, -0.1667,
          0.0084,  0.1752, -0.1058, -0.2064, -0.0048,  0.1084,  0.1272, -0.1980,
         -0.1469, -0.2001,  0.1276, -0.1351],
        [-0.0523, -0.0670, -0.2114,  0.1301,  0.1551, -0.1630,  0.1385, -0.1210,
          0.1904,  0.1850, -0.0549, -0.1425, -0.0239, -0.2057, -0.1383,  0.0558,
          0.1950, -0.0719, -0.0827,  0.0355]])
fc3.bias tensor([-0.1290,  0.1935,  0.0259, -0.0953, -0.1686,  0.1954, -0.0455,  0.3036,
         0.0078, -0.1845])
fc4.weight tensor([[-0.2992,  0.2344,  0.0157, -0.1589,  0.2604,  0.2133, -0.2932, -0.2656,
         -0.2334,  0.0213]])
fc4.bias tensor([-0.1267])
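
Printing every batch loss is verbose; a variant of the loop above (a sketch, equivalent in effect) reports only the mean loss per epoch:

# sketch: same training loop, but log the mean loss once per epoch
for epoch in range(number_of_epochs):
    running = 0.0
    for inputs, labels in training_loader:
        optimizer.zero_grad()
        loss = criterion(net(inputs), labels)
        loss.backward()
        optimizer.step()
        running += loss.item() * inputs.size(0)
    print(f'epoch {epoch}: mean loss {running / len(training_set):.6f}')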
dataiter = iter(test_loader)
in_signal, labels = next(dataiter)
out_signal = net(in_signal).detach().numpy()
print('Loss is:', loss.detach().numpy())  # note: this is the last training-batch loss, not a test loss
Loss is: 0.00068682455
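
The value above is the loss of the final training batch. A proper test error (a sketch, reusing the same MSE criterion) averages over the whole test set:

# sketch: mean squared error over the entire test set
with torch.no_grad():
    test_mse = sum(criterion(net(x), y).item() * x.size(0) for x, y in test_loader) / len(test_set)
print('Test MSE:', test_mse)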
# plot one test batch: input (solid), clipped target (dashed), MLP output (dotted)
fig, ax = plt.subplots()
ax.plot(in_signal)
ax.plot(labels, '--')
ax.plot(out_signal, ':')
plt.show()

Figure: one test batch plotted over sample index, with the input signal (solid), the clipped target (dashed), and the MLP output (dotted); the network output closely follows the clipped target.
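
Because the test samples are drawn at random, the curves above are jagged. Sweeping a sorted input grid (an extra visualization, not in the original) reveals the learned transfer function directly; note that outside the range covered by the standard-normal training inputs the fit is not guaranteed:

# sketch: evaluate the trained net on a sorted grid to expose the learned clipping curve
x_grid = torch.linspace(-3, 3, 200).unsqueeze(1)
with torch.no_grad():
    y_grid = net(x_grid)
fig, ax = plt.subplots()
ax.plot(x_grid.squeeze(), torch.clamp(x_grid, min=-clip, max=clip).squeeze(), '--', label='target')
ax.plot(x_grid.squeeze(), y_grid.squeeze(), ':', label='MLP')
ax.legend()
plt.show()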
